Plot Digits Agglomeration
Feature agglomeration
These images show how similar features are merged together using feature agglomeration.
Imports for Feature Agglomeration on Digit Images
FeatureAgglomeration applies hierarchical clustering to features (columns) rather than to samples (rows), merging similar features into super-features to reduce dimensionality. For image data, grid_to_graph builds a connectivity matrix from the spatial adjacency of pixels, so only spatially adjacent pixels (or groups of pixels) can be merged. This preserves the spatial structure of the image while compressing the feature space from 64 pixels (each digit is an 8×8 image) down to 32 super-pixels. Each super-feature is the mean of all original features in its cluster (the default pooling_func is np.mean).
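To make the connectivity constraint concrete, here is a small NumPy-only sketch that builds the same kind of 4-neighbor pixel graph for an 8×8 grid as a dense array. The helper grid_adjacency is hypothetical and not part of scikit-learn. To our understanding grid_to_graph returns a sparse matrix with the same pattern, including ones on the diagonal, but treat that correspondence as an assumption.

```python
import numpy as np


def grid_adjacency(n_rows: int, n_cols: int) -> np.ndarray:
    """Dense 4-neighbor adjacency for an n_rows x n_cols pixel grid.

    Hypothetical helper. Pixels are numbered in row-major order, matching
    np.reshape(image, -1). The diagonal is also set to 1 (self-connections),
    which we assume mirrors grid_to_graph's output.
    """
    n = n_rows * n_cols
    idx = np.arange(n).reshape(n_rows, n_cols)
    adj = np.eye(n, dtype=int)
    # Horizontal neighbors: (r, c) <-> (r, c + 1)
    horizontal = (idx[:, :-1].ravel(), idx[:, 1:].ravel())
    # Vertical neighbors: (r, c) <-> (r + 1, c)
    vertical = (idx[:-1, :].ravel(), idx[1:, :].ravel())
    for a, b in (horizontal, vertical):
        adj[a, b] = 1
        adj[b, a] = 1  # the graph is undirected
    return adj


adj = grid_adjacency(8, 8)
print(adj.shape)            # (64, 64)
print(int(adj.sum()))       # 288 = 64 self-loops + 2 * 112 neighbor pairs
print(adj[0].nonzero()[0])  # [0 1 8]: pixel 0, its right and its lower neighbor
```

Because each row only touches its immediate left, right, upper and lower neighbors, a clustering restricted to this graph can only ever grow clusters into contiguous regions.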
Dimensionality reduction with spatial awareness: unlike PCA and other global projections, feature agglomeration respects the local spatial relationships between pixels, so each super-pixel in the compressed representation corresponds to a contiguous image region and remains interpretable. The inverse_transform method maps the reduced representation back to the original feature space by assigning every pixel the mean value of its cluster; the result is a blocky, lower-resolution version of each digit that makes the information loss visible. The labels_ attribute records which cluster each original pixel was assigned to and is visualized as a color-coded segmentation map. The technique is most useful for high-dimensional structured data (images, time series, spatial grids) in which neighboring features are expected to carry redundant information.
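The transform and inverse_transform round trip can be sketched in a few lines of NumPy, assuming the default mean pooling. pool_features and unpool_features are hypothetical helpers, and the 6-feature toy matrix and its labels array stand in for X and agglo.labels_; this mirrors what we understand FeatureAgglomeration to do, not its actual source.

```python
import numpy as np


def pool_features(X: np.ndarray, labels: np.ndarray, n_clusters: int) -> np.ndarray:
    """Hypothetical transform: replace each cluster of columns by their mean."""
    return np.stack(
        [X[:, labels == k].mean(axis=1) for k in range(n_clusters)], axis=1
    )


def unpool_features(X_reduced: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Hypothetical inverse: copy each cluster value back to all its columns."""
    return X_reduced[:, labels]


# Toy data: 2 samples, 6 features, grouped into 3 clusters of 2 columns each.
X = np.array([[1.0, 3.0, 5.0, 5.0, 0.0, 2.0],
              [2.0, 2.0, 4.0, 8.0, 6.0, 6.0]])
labels = np.array([0, 0, 1, 1, 2, 2])  # stands in for agglo.labels_

X_reduced = pool_features(X, labels, n_clusters=3)
X_restored = unpool_features(X_reduced, labels)
print(X_reduced)   # [[2. 5. 1.]
                   #  [2. 6. 6.]]
print(X_restored)  # rows [2 2 5 5 1 1] and [2 2 6 6 6 6]: piecewise constant

# Reconstruction error: the within-cluster variation discarded by pooling.
print(np.abs(X - X_restored).max())  # 2.0
```

Pooling X_restored again returns X_reduced unchanged, so the inverse step itself loses nothing; all of the information loss happens in the pooling step, which is why the restored digits look blocky rather than noisy.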
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import matplotlib.pyplot as plt
import numpy as np
from sklearn import cluster, datasets
from sklearn.feature_extraction.image import grid_to_graph
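digits = datasets.load_digits()
images = digits.images  # shape (n_samples, 8, 8)
# Flatten each 8x8 image into a 64-dimensional feature vector
X = np.reshape(images, (len(images), -1))
# Connectivity graph linking each pixel to its grid neighbors
connectivity = grid_to_graph(*images[0].shape)
# Merge the 64 pixels into 32 spatially contiguous clusters
agglo = cluster.FeatureAgglomeration(connectivity=connectivity, n_clusters=32)
agglo.fit(X)
# Each sample is now described by 32 cluster means
X_reduced = agglo.transform(X)
# Map back to 64 features: every pixel takes its cluster's mean value
X_restored = agglo.inverse_transform(X_reduced)
images_restored = np.reshape(X_restored, images.shape)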
plt.figure(1, figsize=(4, 3.5))
plt.clf()
plt.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.91)
for i in range(4):
    plt.subplot(3, 4, i + 1)
    plt.imshow(images[i], cmap=plt.cm.gray, vmax=16, interpolation="nearest")
    plt.xticks(())
    plt.yticks(())
    if i == 1:
        plt.title("Original data")
    plt.subplot(3, 4, 4 + i + 1)
    plt.imshow(images_restored[i], cmap=plt.cm.gray, vmax=16, interpolation="nearest")
    if i == 1:
        plt.title("Agglomerated data")
    plt.xticks(())
    plt.yticks(())
plt.subplot(3, 4, 10)
plt.imshow(
    np.reshape(agglo.labels_, images[0].shape),
    interpolation="nearest",
    cmap=plt.cm.nipy_spectral,
)
plt.xticks(())
plt.yticks(())
plt.title("Labels")
plt.show()
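The fit call above uses FeatureAgglomeration's default linkage, which is Ward. As an illustration of what connectivity-constrained Ward clustering of columns means, the following is a deliberately naive sketch (roughly cubic in the number of features): it repeatedly merges the pair of adjacent clusters whose merge least increases the within-cluster sum of squares. It is not scikit-learn's implementation, and grid_adjacency and ward_feature_agglomeration are hypothetical helpers.

```python
import numpy as np


def grid_adjacency(n_rows, n_cols):
    """Hypothetical helper: 4-neighbor adjacency for a row-major pixel grid."""
    idx = np.arange(n_rows * n_cols).reshape(n_rows, n_cols)
    adj = np.zeros((idx.size, idx.size), dtype=bool)
    for a, b in ((idx[:, :-1], idx[:, 1:]), (idx[:-1, :], idx[1:, :])):
        adj[a.ravel(), b.ravel()] = True
        adj[b.ravel(), a.ravel()] = True
    return adj


def ward_feature_agglomeration(X, adjacency, n_clusters):
    """Naive connectivity-constrained Ward clustering of the columns of X.

    Returns one label in 0..n_clusters-1 per column. Two clusters may only
    merge if at least one column of one is adjacent to a column of the other.
    """
    n_features = X.shape[1]
    clusters = {j: [j] for j in range(n_features)}  # start from singletons
    while len(clusters) > n_clusters:
        best = None
        keys = list(clusters)
        for i, a in enumerate(keys):
            for b in keys[i + 1:]:
                A, B = clusters[a], clusters[b]
                # Connectivity constraint: skip pairs that do not touch.
                if not adjacency[np.ix_(A, B)].any():
                    continue
                # Ward merge cost: |A||B| / (|A| + |B|) * ||mean_A - mean_B||^2,
                # where each column is a point in n_samples-dimensional space.
                diff = X[:, A].mean(axis=1) - X[:, B].mean(axis=1)
                cost = len(A) * len(B) / (len(A) + len(B)) * float(diff @ diff)
                if best is None or cost < best[0]:
                    best = (cost, a, b)
        if best is None:
            raise ValueError("graph is disconnected; cannot reach n_clusters")
        _, a, b = best
        clusters[a] = clusters[a] + clusters.pop(b)
    labels = np.empty(n_features, dtype=int)
    for k, members in enumerate(clusters.values()):
        labels[members] = k
    return labels


# Toy 2x3 "image" (pixels 0 1 2 / 3 4 5) observed in two samples.
# Columns 0, 1, 3, 4 are identical, and so are columns 2 and 5.
X_toy = np.array([[0.0, 0.0, 10.0, 0.0, 0.0, 10.0],
                  [1.0, 1.0, 11.0, 1.0, 1.0, 11.0]])
toy_labels = ward_feature_agglomeration(X_toy, grid_adjacency(2, 3), n_clusters=2)
print(toy_labels)  # [0 0 1 0 0 1]
```

Applied to the digits matrix X with grid_adjacency(8, 8) and n_clusters=32, this sketch should produce contiguous pixel regions comparable in spirit to agglo.labels_, although tie-breaking and label numbering may differ and it runs far more slowly than scikit-learn's heap-based implementation.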