
SVM: Weighted samples

Plot the decision function of a weighted dataset, where the size of each point is proportional to its weight.

The sample weighting rescales the C parameter, which means that the classifier puts more emphasis on getting the heavily weighted points right. The effect is often subtle. To emphasize it here, we particularly increase the weight of the positive class, making the deformation of the decision boundary more visible.

Imports for SVM with Weighted Samples

Sample weighting in SVM allows individual data points to have different importance during training. Internally, each sample’s weight rescales its contribution to the C penalty, so a sample with weight w effectively has regularization parameter C*w. This is useful for handling class imbalance, incorporating prior knowledge about data reliability, or implementing cost-sensitive learning where misclassifying certain samples is more expensive.
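The C-rescaling described above can be checked directly: a uniform sample weight of w with penalty C should produce the same model as an unweighted fit with penalty C*w. A minimal sketch of that check (our own, not part of the original example; variable names are hypothetical):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.svm import SVC

X, y = make_classification(
    n_samples=200, n_features=2, n_informative=2, n_redundant=0, random_state=0
)

# C=1 with every sample weighted 2x ...
clf_weighted = SVC(C=1.0, gamma=1).fit(X, y, sample_weight=2 * np.ones(len(X)))
# ... should match C=2 with unit weights, since the effective
# per-sample penalty is C * w_i in both cases.
clf_scaled = SVC(C=2.0, gamma=1).fit(X, y)

print(
    np.allclose(
        clf_weighted.decision_function(X), clf_scaled.decision_function(X)
    )
)
```

Both fits solve the same quadratic program, so their decision functions should agree up to solver tolerance.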

Visualization approach: Point sizes in the scatter plot are proportional to their weights, making it visually clear which samples the model should prioritize. By comparing the decision boundaries with and without sample weights, we can observe how up-weighting the minority positive class shifts the boundary to reduce false negatives at the cost of more false positives on the majority class.
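The boundary shift can also be quantified rather than eyeballed. A small sketch (names are ours, not from the example) measuring training-set recall on the minority positive class with and without the up-weighting:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import recall_score
from sklearn.svm import SVC

X, y = make_classification(
    n_samples=1_000,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    class_sep=1.1,
    weights=[0.9, 0.1],
    random_state=0,
)

weights = np.ones(len(X))
weights[y == 1] *= 15  # emphasize the minority positive class

recall_plain = recall_score(y, SVC(gamma=1).fit(X, y).predict(X))
recall_weighted = recall_score(
    y, SVC(gamma=1).fit(X, y, sample_weight=weights).predict(X)
)
print(f"recall without weights: {recall_plain:.2f}")
print(f"recall with weights:    {recall_weighted:.2f}")
```

Up-weighting should raise (or at least not lower) the positive-class recall, at the cost of more false positives on the majority class.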

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import make_classification
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.svm import SVC

X, y = make_classification(
    n_samples=1_000,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    class_sep=1.1,
    weights=[0.9, 0.1],
    random_state=0,
)
# down-sample for plotting
rng = np.random.RandomState(0)
plot_indices = rng.choice(np.arange(X.shape[0]), size=100, replace=False)
X_plot, y_plot = X[plot_indices], y[plot_indices]

Plotting Helper: Decision Function with Weighted Data Points

The visualization function renders both the data points (sized by their sample weights) and the SVM decision function as a continuous color map. Larger circles indicate higher-weight samples that exert more pull on the decision boundary. Comparing the β€œconstant weights” and β€œmodified weights” panels side by side reveals exactly how the increased weight on the positive class deforms the boundary – the classifier sacrifices accuracy on the larger negative class to better accommodate the up-weighted positive samples.

def plot_decision_function(classifier, sample_weight, axis, title):
    """Plot the synthetic data and the classifier decision function. Points with
    larger sample_weight are mapped to larger circles in the scatter plot."""
    axis.scatter(
        X_plot[:, 0],
        X_plot[:, 1],
        c=y_plot,
        s=100 * sample_weight[plot_indices],
        alpha=0.9,
        cmap=plt.cm.bone,
        edgecolors="black",
    )
    DecisionBoundaryDisplay.from_estimator(
        classifier,
        X_plot,
        response_method="decision_function",
        alpha=0.75,
        ax=axis,
        cmap=plt.cm.bone,
    )
    axis.axis("off")
    axis.set_title(title)


# we define constant weights as expected by the plotting function
sample_weight_constant = np.ones(len(X))
# assign random weights to all points
sample_weight_modified = abs(rng.randn(len(X)))
# assign bigger weights to the positive class
positive_class_indices = np.asarray(y == 1).nonzero()[0]
sample_weight_modified[positive_class_indices] *= 15
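When the weights are constant within each class, as above, the same effect can be obtained through SVC's class_weight parameter, which multiplies C per class just as sample_weight multiplies it per sample. A minimal sketch of that equivalence (our own, not in the original; fresh variable names are used to avoid clobbering the example's data):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.svm import SVC

X_demo, y_demo = make_classification(
    n_samples=200, n_features=2, n_informative=2, n_redundant=0,
    weights=[0.9, 0.1], random_state=0,
)

# Route 1: per-sample weights that are constant within each class.
w_demo = np.where(y_demo == 1, 15.0, 1.0)
clf_sw = SVC(gamma=1).fit(X_demo, y_demo, sample_weight=w_demo)

# Route 2: the equivalent class_weight mapping.
clf_cw = SVC(gamma=1, class_weight={0: 1.0, 1: 15.0}).fit(X_demo, y_demo)

print(
    np.allclose(
        clf_sw.decision_function(X_demo), clf_cw.decision_function(X_demo)
    )
)
```

Both routes hand libsvm the same effective per-sample penalties, so the fitted decision functions should coincide.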

# This model does not include sample weights.
clf_no_weights = SVC(gamma=1)
clf_no_weights.fit(X, y)

# This other model includes sample weights.
clf_weights = SVC(gamma=1)
clf_weights.fit(X, y, sample_weight=sample_weight_modified)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))
plot_decision_function(
    clf_no_weights, sample_weight_constant, axes[0], "Constant weights"
)
plot_decision_function(clf_weights, sample_weight_modified, axes[1], "Modified weights")

plt.show()