Plot Custom Kernel
SVM with custom kernel
A simple use of Support Vector Machines to classify samples with a user-defined kernel. It plots the decision surface together with the training points.
Imports for SVM with a Custom Kernel Function
Scikit-learn's SVC accepts user-defined kernel functions, enabling domain-specific similarity measures beyond the built-in linear, polynomial, and RBF options. A custom kernel is simply a Python function that takes two data matrices X and Y and returns their pairwise similarity matrix K(X, Y). This flexibility allows encoding expert knowledge about the problem: for example, string kernels for text, graph kernels for molecular data, or weighted feature spaces as demonstrated here.
The custom kernel in this example applies an anisotropic scaling: it places a diagonal matrix between X and Y.T that weights the first feature twice as heavily as the second. This is equivalent to stretching the feature space along the first axis by a factor of sqrt(2) before computing a linear kernel, which makes the SVM more sensitive to differences in that dimension.
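To make that equivalence concrete, the minimal sketch below checks numerically that the weighted kernel matches a plain linear kernel on rescaled inputs. The names `weighted_kernel`, `A`, and `B` are illustrative and not part of the example, and the random data merely stands in for real features. The factor of sqrt(2) comes from writing the diagonal matrix as M = S S with S = diag(sqrt(2), 1).

```python
import numpy as np

# Diagonal weight matrix used by the custom kernel in this example.
M = np.array([[2.0, 0.0], [0.0, 1.0]])
# Element-wise square root of a diagonal matrix is its matrix square root,
# so S scales the first feature by sqrt(2) and leaves the second unchanged.
S = np.sqrt(M)


def weighted_kernel(A, B):
    """Hypothetical helper: k(A, B) = A M B^T."""
    return A @ M @ B.T


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 2))  # made-up data, not the iris features
B = rng.normal(size=(3, 2))

# Plain linear kernel computed on the rescaled features.
linear_on_scaled = (A @ S) @ (B @ S).T

same = np.allclose(weighted_kernel(A, B), linear_on_scaled)
print(same)
```

Under these assumptions, the same classifier could in principle be obtained by rescaling the first feature and using the built-in linear kernel; the callable form simply keeps the weighting explicit.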
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, svm
from sklearn.inspection import DecisionBoundaryDisplay
# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2] # we only take the first two features. We could
# avoid this ugly slicing by using a two-dim dataset
Y = iris.target
My Kernel
We create a custom kernel:

$$k(X, Y) = X \begin{pmatrix} 2 & 0 \\ 0 & 1 \end{pmatrix} Y^{\top}$$
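A kernel callable is expected to take two arrays of shapes (n_samples_1, n_features) and (n_samples_2, n_features) and return an array of shape (n_samples_1, n_samples_2). As a rough sanity check, the sketch below repeats the kernel from this section so that it runs on its own, then inspects the shape, symmetry, and eigenvalues of the resulting matrices. The arrays `X_demo` and `Y_demo` are made-up inputs, and the tolerance of 1e-10 is an arbitrary choice for floating-point noise.

```python
import numpy as np


def my_kernel(X, Y):
    # Same weighted linear kernel as in this section, repeated so the
    # block is self-contained.
    M = np.array([[2, 0], [0, 1.0]])
    return np.dot(np.dot(X, M), Y.T)


rng = np.random.default_rng(42)
X_demo = rng.normal(size=(6, 2))  # hypothetical points, not real data
Y_demo = rng.normal(size=(4, 2))

K_cross = my_kernel(X_demo, Y_demo)  # one row per X_demo point, one column per Y_demo point
K_gram = my_kernel(X_demo, X_demo)   # square Gram matrix of X_demo with itself

# A valid kernel should give a symmetric, positive semidefinite Gram matrix.
is_symmetric = np.allclose(K_gram, K_gram.T)
min_eigenvalue = np.linalg.eigvalsh(K_gram).min()
is_psd = min_eigenvalue >= -1e-10  # allow tiny negative values from rounding

print(K_cross.shape, K_gram.shape, is_symmetric, is_psd)
```

A check like this can catch a transposed or mis-shaped return value before the kernel is handed to the classifier.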
def my_kernel(X, Y):
    """
    We create a custom kernel:

                 (2  0)
    k(X, Y) = X  (    ) Y.T
                 (0  1)
    """
    M = np.array([[2, 0], [0, 1.0]])
    return np.dot(np.dot(X, M), Y.T)
h = 0.02 # step size in the mesh
# we create an instance of SVM and fit our data.
clf = svm.SVC(kernel=my_kernel)
clf.fit(X, Y)
ax = plt.gca()
DecisionBoundaryDisplay.from_estimator(
    clf,
    X,
    cmap=plt.cm.Paired,
    ax=ax,
    response_method="predict",
    plot_method="pcolormesh",
    shading="auto",
)
# Plot also the training points
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired, edgecolors="k")
plt.title("3-Class classification using Support Vector Machine with custom kernel")
plt.axis("tight")
plt.show()
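The fitted classifier also records which training samples became support vectors. The sketch below, which refits the same model so that it runs on its own, recovers them through the `support_` index array. Indexing the training data this way is a deliberate choice here: with a callable kernel the `support_vectors_` attribute may not contain the original feature rows, so treat that detail as an assumption to verify against your scikit-learn version. The name `sv_points` is illustrative.

```python
import numpy as np
from sklearn import datasets, svm


def my_kernel(X, Y):
    # Weighted linear kernel from this example, repeated for self-containment.
    M = np.array([[2, 0], [0, 1.0]])
    return np.dot(np.dot(X, M), Y.T)


iris = datasets.load_iris()
X = iris.data[:, :2]  # first two features, as in the example
y = iris.target

clf = svm.SVC(kernel=my_kernel).fit(X, y)

# support_ holds row indices into the training data, so the support
# vectors can be recovered by indexing X directly.
sv_points = X[clf.support_]

print(sv_points.shape[0], "support vectors out of", X.shape[0], "samples")
print("per class:", clf.n_support_)
```

These points could then be overlaid on the decision surface, for instance with an additional `plt.scatter` call that draws `sv_points` using a larger, unfilled marker.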