Plot Mean Shift

A demo of the mean-shift clustering algorithm

Reference:

Dorin Comaniciu and Peter Meer, “Mean Shift: A robust approach toward feature space analysis”. IEEE Transactions on Pattern Analysis and Machine Intelligence. 2002. pp. 603-619.

Imports for Mean Shift Clustering

Mean Shift is a non-parametric, mode-seeking algorithm. Starting from a set of seed points, it repeatedly moves each seed toward the densest part of its neighborhood until convergence; seeds that converge to the same mode (a local maximum of the estimated density) define one cluster, and each sample is then labeled with its nearest cluster center. Each step computes the mean of all samples inside the bandwidth window around the current position and moves the position to that mean, which climbs a kernel density estimate (KDE) of the data. scikit-learn's MeanShift uses a flat kernel, so every sample inside the window gets equal weight. The iteration stops once the shift falls below a small tolerance or a maximum number of iterations is reached. The key parameter is the bandwidth: it sets the scale of the density estimate and therefore directly controls how many clusters are found and how large they are.
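To make the iteration concrete, here is a minimal from-scratch sketch in NumPy. It assumes a flat kernel, uses every sample as a seed, merges modes that end up closer than one bandwidth with a simple greedy pass, and labels each sample by its nearest center. `flat_kernel_mean_shift` is a hypothetical helper written for illustration; it is not scikit-learn's implementation, which among other things merges duplicate modes in order of how many samples surround them.

```python
import numpy as np


def flat_kernel_mean_shift(X, bandwidth, tol=1e-3, max_iter=300):
    """Illustrative flat-kernel mean shift (hypothetical helper, not sklearn's code).

    Every sample is used as a seed. Returns (centers, labels).
    """
    X = np.asarray(X, dtype=float)
    modes = []
    for seed in X:
        center = seed.copy()
        for _ in range(max_iter):
            # Flat kernel: all samples within `bandwidth` of the center count equally
            in_window = np.linalg.norm(X - center, axis=1) <= bandwidth
            new_center = X[in_window].mean(axis=0)
            shift = np.linalg.norm(new_center - center)
            center = new_center
            # Stop when the shift is small relative to the bandwidth
            if shift < tol * bandwidth:
                break
        modes.append(center)

    # Greedy merge: keep a mode only if it is farther than `bandwidth` from all kept ones
    centers = []
    for mode in modes:
        if all(np.linalg.norm(mode - c) > bandwidth for c in centers):
            centers.append(mode)
    centers = np.array(centers)

    # Label every sample with the index of its nearest center
    dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    labels = dists.argmin(axis=1)
    return centers, labels


# Usage on three well-separated synthetic blobs (50 points each)
rng = np.random.default_rng(0)
true_centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
X_demo = np.vstack([c + 0.3 * rng.standard_normal((50, 2)) for c in true_centers])
centers, labels = flat_kernel_mean_shift(X_demo, bandwidth=1.5)
print(f"found {len(centers)} clusters")
```

Because all seeds in a blob climb to roughly the same point, the number of clusters falls out of the data rather than being passed in.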

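Automatic bandwidth selection and scalability: The estimate_bandwidth function picks a data-driven bandwidth. For each point (optionally from a random subsample of n_samples points) it finds the distance to its k-th nearest neighbor, where k is the quantile fraction of the number of points, and it returns the average of those distances. Setting bin_seeding=True snaps the samples onto a grid whose cell size equals the bandwidth and uses the occupied grid points, rather than every sample, as initial seeds; with far fewer seeds to shift, computation on large datasets becomes much faster.

Mean Shift determines the number of clusters from the data, so there is no need to specify it in advance. In its general form, where each sample belongs to the mode its own trajectory reaches, it can also recover non-convex clusters; note that scikit-learn's implementation instead labels each sample by its nearest cluster center. This mode-seeking behavior makes Mean Shift popular in computer vision applications such as image segmentation and object tracking, where the modes of color or feature histograms correspond to meaningful regions.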
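Both ideas can be sketched in a few lines of NumPy. The sketch assumes the bandwidth heuristic averages each point's distance to its k-th nearest neighbor with k = int(quantile × number of points), counting the point itself, and that bin seeding rounds each sample to a grid of spacing equal to the bandwidth. `knn_bandwidth` and `grid_seeds` are hypothetical helpers for illustration, not scikit-learn functions, and the brute-force distance matrix is only suitable for small subsamples.

```python
import numpy as np


def knn_bandwidth(X, quantile=0.3):
    """Hypothetical helper: mean distance from each point to its k-th nearest neighbor.

    k = int(quantile * n_points), at least 1; the point itself counts as neighbor 1.
    Uses an O(n^2) distance matrix, so only use it on a small subsample.
    """
    X = np.asarray(X, dtype=float)
    k = max(int(X.shape[0] * quantile), 1)
    dists = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
    # After sorting each row, column 0 is the distance to the point itself (0.0)
    kth = np.sort(dists, axis=1)[:, k - 1]
    return kth.mean()


def grid_seeds(X, bin_size, min_bin_freq=1):
    """Hypothetical helper: snap samples to a grid and return occupied grid points."""
    X = np.asarray(X, dtype=float)
    cells = np.round(X / bin_size) + 0.0  # "+ 0.0" turns -0.0 into 0.0
    unique_cells, counts = np.unique(cells, axis=0, return_counts=True)
    return unique_cells[counts >= min_bin_freq] * bin_size


# Usage: estimate a bandwidth, then see how many seeds the grid leaves
rng = np.random.default_rng(0)
X_demo = rng.standard_normal((300, 2))
bw = knn_bandwidth(X_demo, quantile=0.2)
seeds = grid_seeds(X_demo, bin_size=bw)
print(f"bandwidth ~ {bw:.3f}, {len(seeds)} seeds instead of {len(X_demo)}")
```

A larger quantile looks further out in each neighborhood and so yields a larger bandwidth, which in turn means coarser grid cells and fewer seeds.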

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np

from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# %%
# Generate sample data
# --------------------
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# %%
# Compute clustering with MeanShift
# ---------------------------------

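# Estimate the bandwidth from the data. n_samples=500 bases the estimate on a
# random subsample of 500 points, which keeps the nearest-neighbor search cheap.
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

# bin_seeding=True starts from grid-binned seeds instead of every sample (faster)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)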
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# %%
# Plot result
# -----------
import matplotlib.pyplot as plt

plt.figure(1)
plt.clf()

colors = ["#dede00", "#377eb8", "#f781bf"]
markers = ["x", "o", "^"]

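# One color and marker per cluster; zip stops after three clusters, so any
# additional clusters (e.g. with a smaller bandwidth) would not be drawn.
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k  # boolean mask of samples assigned to cluster k
    cluster_center = cluster_centers[k]
    plt.plot(X[my_members, 0], X[my_members, 1], markers[k], color=col)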
    plt.plot(
        cluster_center[0],
        cluster_center[1],
        markers[k],
        markerfacecolor=col,
        markeredgecolor="k",
        markersize=14,
    )
plt.title("Estimated number of clusters: %d" % n_clusters_)
plt.show()
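Because the bandwidth directly controls how many clusters are found, a quick sweep over fixed bandwidths makes the effect visible. The smaller seeded 500-point dataset and the bandwidth values below are arbitrary choices for illustration; the exact cluster counts depend on them.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# Smaller, seeded version of the dataset above so the sweep runs quickly
centers = [[1, 1], [-1, -1], [1, -1]]
X_small, _ = make_blobs(n_samples=500, centers=centers, cluster_std=0.6, random_state=0)

n_clusters_by_bw = {}
for bw in [0.2, 0.5, 1.0, 10.0]:
    ms = MeanShift(bandwidth=bw, bin_seeding=True).fit(X_small)
    n_clusters_by_bw[bw] = len(np.unique(ms.labels_))
    print(f"bandwidth={bw:>5}: {n_clusters_by_bw[bw]} clusters")
```

Very small bandwidths pick up local density fluctuations as separate modes, while a bandwidth wider than the whole dataset collapses everything into a single cluster.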