SGD: Maximum margin separating hyperplane
Plot the maximum margin separating hyperplane within a two-class separable dataset using a linear Support Vector Machine classifier trained with SGD.
Imports for SGD Maximum-Margin Hyperplane
A maximum-margin separating hyperplane is the decision boundary that lies as far as possible from the nearest data points of each class. This is the core idea behind Support Vector Machines (SVMs). Here we train a linear SVM with Stochastic Gradient Descent, using SGDClassifier with loss="hinge". This minimizes the same L2-regularized hinge-loss objective as a linear soft-margin SVM, and because SGD updates the model one sample at a time, it scales better to large datasets than batch SVM solvers.
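To make the SGD training concrete, here is a minimal NumPy sketch of plain SGD on a regularized hinge-loss objective, assuming labels in {-1, +1}: the mean of max(0, 1 - y (w·x + b)) over the samples plus alpha/2 · ||w||², with the intercept b left unregularized. It is not SGDClassifier's implementation. The function name `sgd_hinge_svm`, the constant step size `eta`, and the synthetic two-blob data are illustrative assumptions, and SGDClassifier uses its own learning-rate schedule and stopping criteria.

```python
import numpy as np


def sgd_hinge_svm(X, y, alpha=0.01, n_epochs=200, eta=0.01, seed=0):
    """Hypothetical helper: plain SGD on the L2-regularized hinge loss.

    Labels must be -1 or +1. Objective (assumed for illustration):
    mean(max(0, 1 - y * (X @ w + b))) + alpha / 2 * ||w||^2, b unregularized.
    """
    rng = np.random.default_rng(seed)
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    b = 0.0
    for _ in range(n_epochs):
        # Visit the samples in a fresh random order each epoch.
        for i in rng.permutation(n_samples):
            # The regularizer contributes to every step and shrinks w.
            grad_w = alpha * w
            grad_b = 0.0
            # The hinge term is active only for samples on or inside the margin.
            if y[i] * (X[i] @ w + b) < 1:
                grad_w = grad_w - y[i] * X[i]
                grad_b = -float(y[i])
            w = w - eta * grad_w
            b = b - eta * grad_b
    return w, b


# Two well-separated Gaussian blobs (a hypothetical stand-in for make_blobs).
rng = np.random.default_rng(0)
X = np.vstack([rng.normal([-2.0, -2.0], 0.5, size=(25, 2)),
               rng.normal([2.0, 2.0], 0.5, size=(25, 2))])
y = np.array([-1] * 25 + [1] * 25)

w, b = sgd_hinge_svm(X, y)
accuracy = np.mean(np.sign(X @ w + b) == y)
print("w =", w, "b =", b, "training accuracy =", accuracy)
```

Every step shrinks w through the alpha term, but a sample only pushes the hyperplane while it is on or inside the margin, which is the mechanism behind support vectors.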
What the visualization shows: The solid contour line is the decision boundary (where the decision function equals zero), and the dashed lines are the margin boundaries (where the decision function equals +1 or -1). Samples that lie on or between the dashed lines have a nonzero hinge loss and act as the support vectors: only they contribute to the hinge-loss term of the objective, so moving a sample that stays outside the margin does not change the hyperplane. The alpha parameter controls the regularization strength, trading off a wider margin (larger alpha) against fewer margin violations on the training set (smaller alpha).
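The quantities in the figure can also be read off a fitted model. A sketch, assuming the same make_blobs data as the example: for a linear model, the distance between the two dashed lines is 2 / ||w||, where w is `clf.coef_`, and the samples on or inside the margin are those with y·f(x) ≤ 1 after mapping the labels to -1/+1. The loop over several alpha values and the fixed `random_state=0` (added only for repeatable runs) are not part of the original example, and the exact numbers depend on the SGD run.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.linear_model import SGDClassifier

# Same data as the example's make_blobs call.
X, Y = make_blobs(n_samples=50, centers=2, random_state=0, cluster_std=0.60)

results = {}
for alpha in (0.001, 0.01, 0.1):
    # random_state is fixed only so that repeated runs give the same model.
    clf = SGDClassifier(loss="hinge", alpha=alpha, max_iter=200, random_state=0)
    clf.fit(X, Y)
    w = clf.coef_.ravel()
    # Distance between the dashed lines f(x) = -1 and f(x) = +1.
    margin_width = 2.0 / np.linalg.norm(w)
    # Map labels to -1/+1 so that y * f(x) > 1 means "strictly outside the margin".
    signed_y = np.where(Y == clf.classes_[1], 1, -1)
    functional_margin = signed_y * clf.decision_function(X)
    # Samples on or inside the margin have a nonzero hinge loss.
    n_in_margin = int(np.sum(functional_margin <= 1))
    results[alpha] = (margin_width, n_in_margin)
    print(f"alpha={alpha}: margin width={margin_width:.3f}, "
          f"samples on or inside the margin={n_in_margin}")
```

Larger alpha values should generally give a smaller ||w|| and therefore a wider margin containing more samples.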
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.linear_model import SGDClassifier
# we create 50 separable points
X, Y = make_blobs(n_samples=50, centers=2, random_state=0, cluster_std=0.60)
# fit the model
clf = SGDClassifier(loss="hinge", alpha=0.01, max_iter=200)
clf.fit(X, Y)
# evaluate the decision function on a grid, then plot the boundary, the margins, and the points
xx = np.linspace(-1, 5, 10)
yy = np.linspace(-1, 5, 10)
X1, X2 = np.meshgrid(xx, yy)
Z = np.empty(X1.shape)
for (i, j), val in np.ndenumerate(X1):
    x1 = val
    x2 = X2[i, j]
    p = clf.decision_function([[x1, x2]])
    Z[i, j] = p[0]
levels = [-1.0, 0.0, 1.0]
linestyles = ["dashed", "solid", "dashed"]
colors = "k"
plt.contour(X1, X2, Z, levels, colors=colors, linestyles=linestyles)
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired, edgecolor="black", s=20)
plt.axis("tight")
plt.show()
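The loop over `np.ndenumerate(X1)` calls `decision_function` once per grid point. Because `decision_function` accepts a 2-D array of samples, the whole grid can be evaluated in a single call by flattening the meshgrid and reshaping the result. This sketch does that and also circles the samples on or inside the margin (the "nearest vectors" to the plane), which the plotting code above does not mark. The ring-marker styling and `random_state=0` are illustrative choices, not part of the original example.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.linear_model import SGDClassifier

X, Y = make_blobs(n_samples=50, centers=2, random_state=0, cluster_std=0.60)
# random_state=0 is an assumption added for reproducibility.
clf = SGDClassifier(loss="hinge", alpha=0.01, max_iter=200, random_state=0)
clf.fit(X, Y)

# Evaluate the decision function on the whole grid in one call.
xx = np.linspace(-1, 5, 10)
yy = np.linspace(-1, 5, 10)
X1, X2 = np.meshgrid(xx, yy)
grid = np.column_stack([X1.ravel(), X2.ravel()])
Z = clf.decision_function(grid).reshape(X1.shape)

# Boolean mask of samples on or inside the dashed margin lines.
signed_y = np.where(Y == clf.classes_[1], 1, -1)
in_margin = signed_y * clf.decision_function(X) <= 1

plt.contour(X1, X2, Z, [-1.0, 0.0, 1.0], colors="k",
            linestyles=["dashed", "solid", "dashed"])
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired, edgecolor="black", s=20)
# Draw open rings around the margin samples so they stand out.
plt.scatter(X[in_margin, 0], X[in_margin, 1], s=100, facecolors="none",
            edgecolors="k")
plt.axis("tight")
plt.show()
```

The vectorized call produces the same values as the per-point loop, but avoids one Python-level call per grid cell, which matters once the grid is made finer.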