Plot Coin Ward Segmentation

A demo of structured Ward hierarchical clustering on an image of coins

Compute the segmentation of a 2D image with Ward hierarchical clustering. The clustering is spatially constrained in order for each segmented region to be in one piece.
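To make the "one piece" property concrete, here is a minimal sketch on a toy image rather than the coins data. The 12x12 array with two bright squares is a hypothetical input invented for illustration; the check counts connected components per cluster with scipy.ndimage.label, whose default 2D structure uses the same 4-neighbor adjacency as the pixel grid.

```python
import numpy as np
from scipy import ndimage
from sklearn.cluster import AgglomerativeClustering
from sklearn.feature_extraction.image import grid_to_graph

# Hypothetical toy image: two bright squares on a dark background.
image = np.zeros((12, 12))
image[2:5, 2:5] = 1.0
image[7:10, 7:10] = 1.0

# Spatial constraint: each pixel is linked to its grid neighbors only.
connectivity = grid_to_graph(*image.shape)
model = AgglomerativeClustering(
    n_clusters=3, linkage="ward", connectivity=connectivity
)
labels = model.fit_predict(image.reshape(-1, 1)).reshape(image.shape)

# For every cluster, count how many connected pieces it is made of.
pieces_per_cluster = [
    int(ndimage.label(labels == k)[1]) for k in np.unique(labels)
]
print(pieces_per_cluster)
```

Because merges are only allowed between adjacent regions, every cluster should come out as a single connected piece, and the two squares cannot share a label even though their intensities are identical.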

Imports for Ward Hierarchical Clustering on Image Segmentation

Structured Ward clustering applies hierarchical agglomerative clustering to image data with spatial connectivity constraints, ensuring each segmented region forms a contiguous spatial block. The grid_to_graph function builds a connectivity matrix from the image's pixel grid where each pixel connects only to its immediate spatial neighbors. When passed to AgglomerativeClustering with linkage='ward', the algorithm can only merge adjacent regions, preventing it from grouping visually similar but spatially distant pixels into the same cluster.
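The connectivity matrix can be inspected directly on a small grid. The sketch below uses a hypothetical 3x4 grid; the `neighbors` helper is introduced here for illustration and is not part of scikit-learn. It assumes pixels are indexed in row-major order, which matches how the image is later flattened with a reshape.

```python
import numpy as np
from sklearn.feature_extraction.image import grid_to_graph

# Hypothetical 3x4 pixel grid; only the shape matters for the graph.
n_rows, n_cols = 3, 4
connectivity = grid_to_graph(n_rows, n_cols)

# Dense boolean adjacency; the matrix also links each pixel to itself,
# so clear the diagonal to list true neighbors only.
adjacency = connectivity.toarray().astype(bool)
np.fill_diagonal(adjacency, False)


def neighbors(row, col):
    """Return the (row, col) positions connected to the given pixel."""
    flat = row * n_cols + col  # row-major index of the pixel
    return [
        (int(i) // n_cols, int(i) % n_cols)
        for i in np.flatnonzero(adjacency[flat])
    ]


corner_neighbors = neighbors(0, 0)
center_neighbors = neighbors(1, 1)
print(corner_neighbors)  # [(0, 1), (1, 0)]
print(center_neighbors)  # [(0, 1), (1, 0), (1, 2), (2, 1)]
```

A corner pixel has two neighbors and an interior pixel has four, with no diagonal links.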

How Ward linkage works for image segmentation: At each step, Ward linkage merges the pair of spatially adjacent regions whose fusion gives the smallest increase in total within-cluster variance (the sum of squared deviations of pixel intensities from the region mean). With a single feature (pixel intensity) and spatial constraints, this partitions the image into contiguous regions of homogeneous intensity. The Gaussian smoothing and downscaling steps reduce noise and computational cost: rescaling each side to 20% leaves about 4% of the pixels, and the cost of Ward clustering grows with the number of pixels and connectivity edges. The result segments each coin as a distinct region, though background heterogeneity may require setting n_clusters higher than the actual number of objects.
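The merge criterion can be checked by hand on a few numbers. The sketch below is a plain NumPy illustration, not scikit-learn's internal code: it uses the standard identity that merging regions A and B raises the within-cluster sum of squares by n_A * n_B / (n_A + n_B) * (mean_A - mean_B)^2. The intensity values and the helper names `within_ss` and `ward_merge_cost` are made up for the example.

```python
import numpy as np


def within_ss(values):
    """Sum of squared deviations from the mean of one region."""
    values = np.asarray(values, dtype=float)
    return float(np.sum((values - values.mean()) ** 2))


def ward_merge_cost(a, b):
    """Increase in total within-cluster sum of squares if a and b merge."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    weight = a.size * b.size / (a.size + b.size)
    return float(weight * (a.mean() - b.mean()) ** 2)


# Hypothetical pixel intensities for three neighboring regions.
dark = np.array([0.10, 0.12, 0.11])
also_dark = np.array([0.14, 0.15])
bright = np.array([0.80, 0.85, 0.82])

cost_similar = ward_merge_cost(dark, also_dark)
cost_different = ward_merge_cost(dark, bright)

# The closed form agrees with recomputing the sums of squares directly.
direct = (
    within_ss(np.concatenate([dark, bright]))
    - within_ss(dark)
    - within_ss(bright)
)
print(cost_similar < cost_different)
print(np.isclose(direct, cost_different))
```

Merging the two dark regions is far cheaper than merging a dark region with the bright one, which is why regions of similar intensity are fused first.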

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

# %%
# Generate data
# -------------

from skimage.data import coins

orig_coins = coins()

# %%
# Resize it to 20% of the original size to speed up the processing
# Applying a Gaussian filter for smoothing prior to down-scaling
# reduces aliasing artifacts.

import numpy as np
from scipy.ndimage import gaussian_filter
from skimage.transform import rescale

smoothened_coins = gaussian_filter(orig_coins, sigma=2)
rescaled_coins = rescale(
    smoothened_coins,
    0.2,
    mode="reflect",
    anti_aliasing=False,
)

X = np.reshape(rescaled_coins, (-1, 1))

# %%
# Define structure of the data
# ----------------------------
#
# Pixels are connected to their neighbors.

from sklearn.feature_extraction.image import grid_to_graph

connectivity = grid_to_graph(*rescaled_coins.shape)

# %%
# Compute clustering
# ------------------

import time

from sklearn.cluster import AgglomerativeClustering

print("Compute structured hierarchical clustering...")
st = time.time()
n_clusters = 27  # number of regions
ward = AgglomerativeClustering(
    n_clusters=n_clusters, linkage="ward", connectivity=connectivity
)
ward.fit(X)
label = np.reshape(ward.labels_, rescaled_coins.shape)
print(f"Elapsed time: {time.time() - st:.3f}s")
print(f"Number of pixels: {label.size}")
print(f"Number of clusters: {np.unique(label).size}")

# %%
# Plot the results on an image
# ----------------------------
#
# Agglomerative clustering is able to segment each coin. However, we have had
# to use an ``n_clusters`` larger than the number of coins because the
# segmentation splits the background into several regions.

import matplotlib.pyplot as plt

plt.figure(figsize=(5, 5))
plt.imshow(rescaled_coins, cmap=plt.cm.gray)
for region in range(n_clusters):
    # Draw the outline of each region in its own color.
    plt.contour(
        label == region,
        colors=[
            plt.cm.nipy_spectral(region / float(n_clusters)),
        ],
    )
plt.axis("off")
plt.show()