Plot Dict Face Patches
Online learning of a dictionary of parts of faces
This example uses a large dataset of faces to learn a set of 20 x 20 image patches that constitute faces.
From a programming standpoint, it is interesting because it shows how to use the online API of scikit-learn to process a very large dataset in chunks. We load one image at a time and randomly extract 50 patches from it. Once we have accumulated 500 of these patches (from 10 images), we run the partial_fit method of the online KMeans object, MiniBatchKMeans.
The verbose setting on MiniBatchKMeans lets us see that some clusters are reassigned during the successive calls to partial_fit. This happens when the number of patches a cluster represents has become too low, in which case it is better to choose a random new cluster.
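The chunked workflow described above can be sketched on synthetic data. This is a minimal illustration, not the example's own code: the three 2-D blobs and the `iter_chunks` helper are hypothetical stand-ins for a dataset that is streamed piece by piece, and only `MiniBatchKMeans.partial_fit` comes from scikit-learn.

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

rng = np.random.RandomState(0)

# Hypothetical stand-in for a large dataset: three well-separated 2-D blobs.
true_centers = np.array([[0.0, 0.0], [10.0, 10.0], [-10.0, 10.0]])


def iter_chunks(n_chunks=20, chunk_size=100):
    """Yield one chunk of synthetic samples at a time (illustrative helper)."""
    for _ in range(n_chunks):
        labels = rng.randint(0, len(true_centers), size=chunk_size)
        noise = rng.normal(scale=0.5, size=(chunk_size, 2))
        yield true_centers[labels] + noise


kmeans = MiniBatchKMeans(n_clusters=3, random_state=0, n_init=3)

n_seen = 0
for chunk in iter_chunks():
    # Each call updates the centroids using only the current chunk,
    # so the full dataset never has to be held in memory.
    kmeans.partial_fit(chunk)
    n_seen += len(chunk)

print(kmeans.cluster_centers_.shape)
print("samples seen:", n_seen)
```

The first call to partial_fit initializes the centroids from the first chunk, so that chunk should contain at least as many samples as there are clusters.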
Imports for Online Dictionary Learning of Face Patches
MiniBatchKMeans learns a visual dictionary of face patches online, processing the data in chunks rather than loading the entire dataset into memory. The partial_fit method updates the cluster centroids incrementally: patches are extracted from the Olivetti faces dataset with extract_patches_2d, accumulated into buffers of 500 patches (from 10 images), mean-centered and standardized per pixel position, and then used to update the K-Means centroids. This pattern matters most for datasets too large to fit in memory. Each partial_fit call acts as a stochastic update that moves the centroids toward the patches in the current batch.
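The patch extraction and standardization steps can be tried in isolation. The sketch below assumes a random 64 x 64 array as a stand-in for one Olivetti image (so that nothing has to be downloaded); the patch size and `max_patches=50` match the values used in this example.

```python
import numpy as np
from sklearn.feature_extraction.image import extract_patches_2d

rng = np.random.RandomState(0)

# Hypothetical stand-in for one 64 x 64 grayscale face image.
img = rng.rand(64, 64)

# Randomly sample 50 patches of 20 x 20 pixels from the image.
patches = extract_patches_2d(img, (20, 20), max_patches=50, random_state=rng)
print(patches.shape)  # (50, 20, 20)

# Flatten each patch into a 400-dimensional feature vector.
data = patches.reshape(len(patches), -1)

# Center and scale each pixel position across the patches in the buffer.
data = data - data.mean(axis=0)
data = data / data.std(axis=0)
print(data.shape)  # (50, 400)
```

Note that the statistics are computed per buffer, so each chunk passed to partial_fit is normalized with its own mean and standard deviation. On real images, a pixel position that is constant across the buffer would give a zero standard deviation, which this sketch does not guard against.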
Visual dictionary interpretation: The 81 learned cluster centers (arranged in a 9x9 grid) represent the most common 20x20 pixel patterns found across all face images: edges, textures, and partial facial features such as eye corners, nose bridges, and mouth curves. This is conceptually similar to how convolutional neural networks learn filter banks, but here the dictionary is learned purely through unsupervised clustering. The verbose=True flag reveals cluster reassignments during training, which occur when a cluster becomes too small (representing too few patches) and is reinitialized randomly. This mechanism helps prevent "dead" clusters and keeps the centroids representative of the data.
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
# %%
# Load the data
# -------------
from sklearn import datasets
faces = datasets.fetch_olivetti_faces()
# %%
# Learn the dictionary of images
# ------------------------------
import time
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.feature_extraction.image import extract_patches_2d
print("Learning the dictionary... ")
rng = np.random.RandomState(0)
kmeans = MiniBatchKMeans(n_clusters=81, random_state=rng, verbose=True, n_init=3)
patch_size = (20, 20)
buffer = []
t0 = time.time()
# The online learning part: cycle over the whole dataset 6 times
index = 0
for _ in range(6):
    for img in faces.images:
        data = extract_patches_2d(img, patch_size, max_patches=50, random_state=rng)
        data = np.reshape(data, (len(data), -1))
        buffer.append(data)
        index += 1
        if index % 10 == 0:
            data = np.concatenate(buffer, axis=0)
            data -= np.mean(data, axis=0)
            data /= np.std(data, axis=0)
            kmeans.partial_fit(data)
            buffer = []
        if index % 100 == 0:
            print("Partial fit of %4i out of %i" % (index, 6 * len(faces.images)))
dt = time.time() - t0
print("done in %.2fs." % dt)
# %%
# Plot the results
# ----------------
import matplotlib.pyplot as plt
plt.figure(figsize=(4.2, 4))
for i, patch in enumerate(kmeans.cluster_centers_):
    plt.subplot(9, 9, i + 1)
    plt.imshow(patch.reshape(patch_size), cmap=plt.cm.gray, interpolation="nearest")
    plt.xticks(())
    plt.yticks(())
# Patches seen in training: 6 passes over the dataset, 50 patches per image
plt.suptitle(
    "Patches of faces\nTrain time %.1fs on %d patches"
    % (dt, 6 * 50 * len(faces.images)),
    fontsize=16,
)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
plt.show()
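The plotting code above draws one subplot per centroid. As a possible alternative, the 81 centers can be tiled into a single 180 x 180 array and shown with one imshow call. The sketch below uses a random array named `centers` as a hypothetical stand-in for `kmeans.cluster_centers_`, so it runs without fitting a model.

```python
import numpy as np

rng = np.random.RandomState(0)

# Hypothetical stand-in for kmeans.cluster_centers_: 81 flattened 20 x 20 patches.
centers = rng.rand(81, 400)

# Reshape to (grid_row, grid_col, patch_row, patch_col), then interleave the
# patch rows with the grid rows so each patch lands in its own 20 x 20 tile.
grid = (
    centers.reshape(9, 9, 20, 20)
    .transpose(0, 2, 1, 3)
    .reshape(9 * 20, 9 * 20)
)

print(grid.shape)  # (180, 180)
```

The resulting array can be passed to `plt.imshow(grid, cmap=plt.cm.gray, interpolation="nearest")`. Unlike the subplot version, this layout leaves no spacing between tiles, which may or may not be preferable.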