Caching nearest neighbors
=========================
This example demonstrates how to precompute the k nearest neighbors before using them in KNeighborsClassifier. KNeighborsClassifier can compute the nearest neighbors internally, but precomputing them can have several benefits, such as finer parameter control, caching for multiple uses, or custom implementations.
Here we use the caching property of pipelines to cache the nearest neighbors graph between multiple fits of KNeighborsClassifier. The first call is slow since it computes the neighbors graph, while subsequent calls are faster as they do not need to recompute the graph. Here the durations are small since the dataset is small, but the gain can be more substantial when the dataset grows larger, or when the grid of parameters to search is large.
Caching nearest neighbors in grid search pipelines
--------------------------------------------------
Precomputing the neighbor graph to avoid redundant computation: KNeighborsTransformer computes the k-nearest neighbors graph once as a sparse distance matrix, which KNeighborsClassifier with metric="precomputed" then uses directly for classification. When tuning n_neighbors via GridSearchCV, the classifier only needs to filter the precomputed graph (selecting fewer neighbors from the already-computed set) rather than recomputing distances from scratch for each hyperparameter value. The transformer is initialized with n_neighbors=max(n_neighbors_list) to ensure the graph contains enough neighbors for all grid search configurations.
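The handoff between the two estimators can also be seen in isolation. Below is a minimal sketch, separate from the pipeline built later in this example; the variable names (transformer, X_graph, classifier) are illustrative:

# Standalone sketch of the precomputed-graph handoff (no pipeline, no caching).
from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier, KNeighborsTransformer

X, y = load_digits(return_X_y=True)

# Sparse (n_samples, n_samples) matrix holding distances to the 9 nearest
# neighbors of each sample (the maximum value we intend to search over).
transformer = KNeighborsTransformer(n_neighbors=9, mode="distance")
X_graph = transformer.fit_transform(X)

# With metric="precomputed", the classifier only filters this graph:
# n_neighbors=3 keeps the 3 closest of the 9 precomputed neighbors,
# with no new distance computation.
classifier = KNeighborsClassifier(n_neighbors=3, metric="precomputed")
classifier.fit(X_graph, y)
print(classifier.score(X_graph, y))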
Pipeline memory caching eliminates repeated graph computation: a Pipeline constructed with memory=tmpdir uses joblib to cache the output of each transformer step to disk. During GridSearchCV's cross-validation, the KNeighborsTransformer step produces identical output for a given fold regardless of the classifier's n_neighbors setting, so the cached result is reused across all 9 hyperparameter values. Without caching, the neighbor graph would be recomputed 9 times per fold. This pattern becomes critical for large datasets, where the neighbor search (O(n² · d) distance computations for a brute-force search over n samples in d dimensions) dominates the total computation time, making the one-time graph computation plus disk I/O far cheaper than repeated exact searches.
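To make the caching effect concrete, here is a minimal sketch of two successive fits; the parameter values and timing output are illustrative. The second fit reuses the transformer output cached on disk, so only the cheap classifier step runs again:

# Sketch of the caching effect: the second fit loads the graph from the cache.
from tempfile import TemporaryDirectory
from time import perf_counter

from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier, KNeighborsTransformer
from sklearn.pipeline import Pipeline

X, y = load_digits(return_X_y=True)

with TemporaryDirectory(prefix="sklearn_graph_cache_") as tmpdir:
    model = Pipeline(
        steps=[
            ("graph", KNeighborsTransformer(n_neighbors=9, mode="distance")),
            ("classifier", KNeighborsClassifier(n_neighbors=5, metric="precomputed")),
        ],
        memory=tmpdir,
    )

    start = perf_counter()
    model.fit(X, y)  # computes the neighbors graph and writes it to the cache
    print(f"first fit:  {perf_counter() - start:.3f}s")

    # Changing only the classifier's parameter leaves the transformer's cache
    # key unchanged, so the graph is loaded from disk instead of recomputed.
    model.set_params(classifier__n_neighbors=3)
    start = perf_counter()
    model.fit(X, y)
    print(f"second fit: {perf_counter() - start:.3f}s")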
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from tempfile import TemporaryDirectory
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier, KNeighborsTransformer
from sklearn.pipeline import Pipeline
X, y = load_digits(return_X_y=True)
n_neighbors_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# The transformer computes the nearest neighbors graph using the maximum number
# of neighbors necessary in the grid search. The classifier model filters the
# nearest neighbors graph as required by its own n_neighbors parameter.
graph_model = KNeighborsTransformer(n_neighbors=max(n_neighbors_list), mode="distance")
classifier_model = KNeighborsClassifier(metric="precomputed")
# Note that we give `memory` a directory to cache the graph computation
# that will be used several times when tuning the hyperparameters of the
# classifier.
with TemporaryDirectory(prefix="sklearn_graph_cache_") as tmpdir:
    full_model = Pipeline(
        steps=[("graph", graph_model), ("classifier", classifier_model)], memory=tmpdir
    )
    param_grid = {"classifier__n_neighbors": n_neighbors_list}
    grid_model = GridSearchCV(full_model, param_grid)
    grid_model.fit(X, y)
# Plot the results of the grid search.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].errorbar(
    x=n_neighbors_list,
    y=grid_model.cv_results_["mean_test_score"],
    yerr=grid_model.cv_results_["std_test_score"],
)
axes[0].set(xlabel="n_neighbors", title="Classification accuracy")
axes[1].errorbar(
    x=n_neighbors_list,
    y=grid_model.cv_results_["mean_fit_time"],
    yerr=grid_model.cv_results_["std_fit_time"],
    color="r",
)
axes[1].set(xlabel="n_neighbors", title="Fit time (with caching)")
fig.tight_layout()
plt.show()
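After the search, the winning configuration can also be read off the fitted GridSearchCV object via its standard attributes; the printed values depend on the run:

print(grid_model.best_params_)  # e.g. {'classifier__n_neighbors': ...}
print(f"best CV accuracy: {grid_model.best_score_:.3f}")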