Plot Nearest Centroid

Nearest Centroid Classification

Sample usage of Nearest Centroid classification. It will plot the decision boundaries for each class.

Imports for Nearest Centroid Classification with Shrinkage

Nearest centroid is the simplest prototype-based classifier: NearestCentroid computes the centroid (mean) of each class in the training data, then classifies new points by assigning them to the class whose centroid is closest in Euclidean distance. This produces linear decision boundaries (hyperplanes equidistant between pairs of centroids), with O(nd) training time and O(kd) per-sample prediction time for k classes and d features, far faster at prediction than KNN, which must search all n training samples. However, it assumes classes have roughly spherical distributions of similar spread, making it brittle for elongated or multi-modal class distributions.
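The classification rule above is simple enough to reproduce by hand. As a minimal sketch (not part of this example's code), the following computes per-class mean vectors with NumPy and checks that nearest-Euclidean-centroid assignment matches scikit-learn's NearestCentroid with default settings:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import NearestCentroid

X, y = load_iris(return_X_y=True)

# Compute one centroid (mean vector) per class.
classes = np.unique(y)
centroids = np.array([X[y == c].mean(axis=0) for c in classes])

# Predict by assigning each sample to the class with the nearest
# centroid in Euclidean distance.
dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
manual_pred = classes[dists.argmin(axis=1)]

# Should agree with scikit-learn's estimator (default Euclidean metric).
sk_pred = NearestCentroid().fit(X, y).predict(X)
print(np.array_equal(manual_pred, sk_pred))  # → True
```

Training reduces to one mean per class, which is why fitting is so cheap relative to neighbor-search methods.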

Shrinkage regularization for high-dimensional robustness: The shrink_threshold parameter applies nearest shrunken centroid regularization, which moves each class centroid toward the overall dataset centroid, shrinking per-feature deviations that fall below the threshold to zero. This performs implicit feature selection: features whose class means do not differ substantially from the global mean are effectively removed from the distance computation. With shrink_threshold=None (no regularization), the full centroid differences determine the boundary; with shrink_threshold=0.2, the boundaries shift as weakly discriminating features are suppressed. The DecisionBoundaryDisplay visualizes how shrinkage simplifies the decision regions on the 2D iris subset.

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

from sklearn import datasets
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.neighbors import NearestCentroid

# import some data to play with
iris = datasets.load_iris()
# we only take the first two features. We could avoid this ugly
# slicing by using a two-dim dataset
X = iris.data[:, :2]
y = iris.target

# Create color maps
cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
cmap_bold = ListedColormap(["darkorange", "c", "darkblue"])

for shrinkage in [None, 0.2]:
    # we create an instance of Nearest Centroid Classifier and fit the data.
    clf = NearestCentroid(shrink_threshold=shrinkage)
    clf.fit(X, y)
    y_pred = clf.predict(X)
    print(f"shrink_threshold={shrinkage!r}: training accuracy={np.mean(y == y_pred):.3f}")

    _, ax = plt.subplots()
    DecisionBoundaryDisplay.from_estimator(
        clf, X, cmap=cmap_light, ax=ax, response_method="predict"
    )

    # Plot also the training points
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor="k", s=20)
    plt.title("3-Class classification (shrink_threshold=%r)" % shrinkage)
    plt.axis("tight")

plt.show()