Plot Kmeans Plusplus

An example of K-Means++ initialization

This example shows the output of the sklearn.cluster.kmeans_plusplus function, which generates initial seeds for clustering.

K-Means++ is used as the default initialization for K-Means (see the K-Means section of the scikit-learn user guide).
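As a quick illustration of this default, the sketch below checks the default init value of KMeans and then computes seeds with kmeans_plusplus and passes them to KMeans explicitly. It reuses the make_blobs settings from this example. Passing an array as init and setting n_init=1, since the seeds are fixed, is one reasonable way to do this, not the only one.

```python
from sklearn.cluster import KMeans, kmeans_plusplus
from sklearn.datasets import make_blobs

# Same synthetic data settings as the example: 4 blobs, 4000 points
X, _ = make_blobs(n_samples=4000, centers=4, cluster_std=0.60, random_state=0)

# KMeans uses k-means++ seeding unless told otherwise
print(KMeans(n_clusters=4).init)  # k-means++

# Compute the seeds explicitly, then hand them to KMeans as a fixed init.
# n_init=1 because rerunning with the same fixed seeds would not change anything.
centers_init, indices = kmeans_plusplus(X, n_clusters=4, random_state=0)
km = KMeans(n_clusters=4, init=centers_init, n_init=1).fit(X)
print(km.cluster_centers_.shape)  # (4, 2)
```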

Imports for K-Means++ Initialization

K-Means++ is an initialization strategy for the K-Means algorithm that selects initial centroids spread far apart, which typically speeds up convergence and improves the quality of the final clustering. The first centroid is chosen uniformly at random from the data points. Each subsequent centroid is sampled from the data points with probability proportional to its squared distance from the nearest centroid already chosen, so points far from every existing seed are much more likely to be picked and the seeds tend to be well separated across the data distribution. scikit-learn's implementation adds a greedy refinement: at each step it samples several candidates this way (2 + log(k) by default, controlled by n_local_trials) and keeps the one that most reduces the inertia. The kmeans_plusplus function exposes this initialization step directly, returning both the coordinates of the initial centers and the indices of the selected seed points.
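To make the sampling rule concrete, here is a minimal NumPy sketch of plain (non-greedy) k-means++ seeding. The function name kmeans_pp_seeds and the four-blob toy data are hypothetical, chosen only for illustration. This is not scikit-learn's implementation, which additionally evaluates several candidates per step and keeps the best one.

```python
import numpy as np


def kmeans_pp_seeds(X, n_clusters, rng):
    """Plain k-means++ seeding via D^2 sampling (illustrative sketch only)."""
    n_samples = X.shape[0]
    indices = np.empty(n_clusters, dtype=int)
    # 1) First seed: chosen uniformly at random among the data points.
    indices[0] = rng.integers(n_samples)
    # Squared distance from every point to its nearest seed so far.
    closest_sq = np.sum((X - X[indices[0]]) ** 2, axis=1)
    for c in range(1, n_clusters):
        # 2) Next seed: sampled with probability proportional to D(x)^2.
        probs = closest_sq / closest_sq.sum()
        indices[c] = rng.choice(n_samples, p=probs)
        # 3) Update each point's nearest-seed distance with the new seed.
        new_sq = np.sum((X - X[indices[c]]) ** 2, axis=1)
        closest_sq = np.minimum(closest_sq, new_sq)
    return X[indices], indices


# Hypothetical toy data: four Gaussian blobs of 100 points each
rng = np.random.default_rng(0)
blob_centers = [(0, 0), (10, 0), (0, 10), (10, 10)]
X = np.vstack([rng.normal(loc, 0.5, size=(100, 2)) for loc in blob_centers])

centers, idx = kmeans_pp_seeds(X, 4, rng)
print(centers)
```

Because an already-selected point has a squared distance of zero to the nearest seed, it can never be drawn twice, and on well-separated blobs like these the seeds usually end up one per blob.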

Why initialization matters: with purely random initialization, K-Means is sensitive to the starting centroids and can converge to a poor local minimum, so many restarts (n_init) are often needed to find a good solution. K-Means++ seeding comes with a provable guarantee: the expected cost of the seeds alone is within a factor of O(log k) of the optimal clustering cost. In practice this means a single initialization is usually a good starting point, which is why scikit-learn uses init='k-means++' as the default for KMeans. In this example, make_blobs generates 4 well-separated isotropic Gaussian clusters, and the K-Means++ seeds are plotted as large blue dots. They tend to land in different clusters, showing how the seeding spreads the initial centroids across the natural structure of the data.
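The O(log k) guarantee concerns the cost of the seeds themselves, so one rough way to see the effect is to compare that cost for k-means++ seeds against seeds drawn uniformly at random from the data. The sketch below assumes the same make_blobs settings as this example. The helper seeding_cost is hypothetical, and averaging over 20 runs is an arbitrary choice; the script prints the two averages rather than claiming specific values.

```python
import numpy as np
from sklearn.cluster import kmeans_plusplus
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=4000, centers=4, cluster_std=0.60, random_state=0)


def seeding_cost(X, centers):
    """Hypothetical helper: sum of squared distances to the nearest seed."""
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return d2.min(axis=1).sum()


rng = np.random.default_rng(0)
pp_costs, random_costs = [], []
for seed in range(20):
    # k-means++ seeds (scikit-learn's greedy variant)
    pp_centers, _ = kmeans_plusplus(X, n_clusters=4, random_state=seed)
    pp_costs.append(seeding_cost(X, pp_centers))
    # Baseline: 4 data points drawn uniformly at random, without replacement
    rand_idx = rng.choice(len(X), size=4, replace=False)
    random_costs.append(seeding_cost(X, X[rand_idx]))

print(f"k-means++ seeding cost (mean of 20 runs): {np.mean(pp_costs):.0f}")
print(f"random seeding cost (mean of 20 runs):    {np.mean(random_costs):.0f}")
```

With uniform sampling, two seeds often fall in the same blob and another blob is left without one, which inflates the cost; D^2 sampling makes that outcome much less likely.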

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt

from sklearn.cluster import kmeans_plusplus
from sklearn.datasets import make_blobs

# Generate sample data
n_samples = 4000
n_components = 4

X, y_true = make_blobs(
    n_samples=n_samples, centers=n_components, cluster_std=0.60, random_state=0
)
# Swap the two feature columns (only changes the orientation of the plot)
X = X[:, ::-1]

# Calculate seeds from k-means++
centers_init, indices = kmeans_plusplus(X, n_clusters=n_components, random_state=0)

# Plot init seeds alongside sample data
plt.figure(1)
colors = ["#4EACC5", "#FF9C34", "#4E9A06", "m"]

for k, col in enumerate(colors):
    cluster_data = y_true == k
    plt.scatter(X[cluster_data, 0], X[cluster_data, 1], c=col, marker=".", s=10)

plt.scatter(centers_init[:, 0], centers_init[:, 1], c="b", s=50)
plt.title("K-Means++ Initialization")
plt.xticks([])
plt.yticks([])
plt.show()
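To look more closely at what kmeans_plusplus returns, the following sketch checks the shapes of the centers and indices, confirms that for dense input each seed is an actual row of X, and shows that a fixed random_state makes the seeding reproducible. It also passes n_local_trials=1, which the scikit-learn documentation describes as disabling the greedy candidate selection and recovering vanilla k-means++.

```python
import numpy as np
from sklearn.cluster import kmeans_plusplus
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=4000, centers=4, cluster_std=0.60, random_state=0)

centers, indices = kmeans_plusplus(X, n_clusters=4, random_state=0)
print(centers.shape, indices.shape)  # (4, 2) (4,)

# For dense input, each seed is one of the original data points
print(np.array_equal(centers, X[indices]))  # True

# The same random_state gives the same seeds
centers_again, _ = kmeans_plusplus(X, n_clusters=4, random_state=0)
print(np.array_equal(centers, centers_again))  # True

# n_local_trials=1: one candidate per step, i.e. plain D^2 sampling
plain_centers, plain_indices = kmeans_plusplus(
    X, n_clusters=4, random_state=0, n_local_trials=1
)
```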