Kernel Density Estimation
=========================
This example shows how kernel density estimation (KDE), a powerful non-parametric density estimation technique, can be used to learn a generative model for a dataset. With this generative model in place, new samples can be drawn. These new samples reflect the underlying model of the data.
Generative Sampling via Kernel Density Estimation
-------------------------------------------------
KDE as a non-parametric generative model: KernelDensity estimates the probability density function of the training data by placing a kernel (typically Gaussian) centered on each data point, then summing and normalizing these kernels. Unlike parametric models (e.g., GMMs) that assume a specific distributional form, KDE makes no assumptions about the data distribution and can model arbitrarily complex densities. The bandwidth parameter controls the smoothness: too small captures noise as spurious modes; too large oversmooths and loses structure. GridSearchCV optimizes bandwidth by maximizing the log-likelihood on held-out folds, automatically finding the resolution that best balances bias and variance in density estimation.
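To see bandwidth selection in isolation before the digits example, here is a minimal sketch on a synthetic 1-D bimodal sample; the toy data and the 20-point search grid are illustrative assumptions, not part of this example:

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity

# Toy 1-D bimodal sample (illustrative only): two Gaussian clumps.
rng = np.random.RandomState(0)
x = np.concatenate(
    [rng.normal(-2, 0.5, 100), rng.normal(2, 0.5, 100)]
)[:, None]  # KernelDensity expects a 2-D array of shape (n_samples, n_features)

# Cross-validate the bandwidth: KernelDensity.score is the total
# log-likelihood, so GridSearchCV picks the bandwidth that best
# explains held-out points.
params = {"bandwidth": np.logspace(-1, 1, 20)}
grid = GridSearchCV(KernelDensity(kernel="gaussian"), params)
grid.fit(x)
print("best bandwidth:", grid.best_estimator_.bandwidth)
```

The selected bandwidth should land between the two failure modes: narrow enough to keep both modes of the mixture, wide enough to smooth out sampling noise.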
Dimensionality reduction enables tractable density estimation: KDE suffers from the curse of dimensionality. In high-dimensional spaces (like the 64-pixel digit images), data points are sparse and kernel overlap diminishes, making density estimates unreliable. PCA with n_components=15 projects the data into a lower-dimensional subspace that retains most of the variance while making density estimation tractable. After fitting the KDE in PCA space, kde.sample(44) draws new points from the estimated density, and pca.inverse_transform maps these synthetic samples back to the original 8x8 pixel space, producing novel digit images that reflect the learned distribution of real digits.
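The projection round-trip described above can be sketched on its own; this is only a shape check (the array shapes are the point), with n_components=15 matching the example below:

```python
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

digits = load_digits()  # 1797 samples, 64 features (8x8 grayscale images)

# Project into 15 dimensions, where KDE is tractable ...
pca = PCA(n_components=15, whiten=False)
reduced = pca.fit_transform(digits.data)
print(reduced.shape)  # (1797, 15)

# ... and map points in PCA space back to 64-pixel images;
# this is what turns KDE samples into displayable digits.
restored = pca.inverse_transform(reduced)
print(restored.shape)  # (1797, 64)
```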
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity
# load the data
digits = load_digits()
# project the 64-dimensional data to a lower dimension
pca = PCA(n_components=15, whiten=False)
data = pca.fit_transform(digits.data)
# use grid search cross-validation to optimize the bandwidth
params = {"bandwidth": np.logspace(-1, 1, 20)}
grid = GridSearchCV(KernelDensity(), params)
grid.fit(data)
print(f"best bandwidth: {grid.best_estimator_.bandwidth}")
# use the best estimator to compute the kernel density estimate
kde = grid.best_estimator_
# sample 44 new points from the data
new_data = kde.sample(44, random_state=0)
new_data = pca.inverse_transform(new_data)
# turn data into a 4x11 grid
new_data = new_data.reshape((4, 11, -1))
real_data = digits.data[:44].reshape((4, 11, -1))
# plot real digits and resampled digits
fig, ax = plt.subplots(9, 11, subplot_kw=dict(xticks=[], yticks=[]))
for j in range(11):
    ax[4, j].set_visible(False)
    for i in range(4):
        im = ax[i, j].imshow(
            real_data[i, j].reshape((8, 8)), cmap=plt.cm.binary, interpolation="nearest"
        )
        im.set_clim(0, 16)
        im = ax[i + 5, j].imshow(
            new_data[i, j].reshape((8, 8)), cmap=plt.cm.binary, interpolation="nearest"
        )
        im.set_clim(0, 16)
ax[0, 5].set_title("Selection from the input data")
ax[5, 5].set_title('"New" digits drawn from the kernel density model')
plt.show()