Plot Successive Halving Heatmap
Comparison between grid search and successive halving
This example compares the parameter search performed by
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
:class:`~sklearn.model_selection.GridSearchCV`.
Imports for Successive Halving vs Grid Search Comparison
Successive halving as a resource-efficient search strategy: HalvingGridSearchCV (enabled by importing enable_halving_search_cv) implements a tournament-style elimination. In the first iteration, all candidate hyperparameter configurations are evaluated with a small amount of resources (a few training samples); in each subsequent iteration, only the top-performing fraction (1/factor) of candidates survives and receives more resources (more training samples). This multi-fidelity approach is dramatically faster than an exhaustive GridSearchCV because most candidates are eliminated early, before receiving the full training budget, while the final survivors are evaluated with enough data to reliably identify the best configuration.
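The resulting schedule can be sketched with a few lines of arithmetic. This is a minimal illustration, not part of the example: the 42 candidates mirror the 7 x 6 grid used below, while the starting budget of 20 samples per candidate is an arbitrary assumption, and the real estimator chooses the starting budget from min_resources and caps it at max_resources.
import math

# Hypothetical halving schedule: 42 candidates, factor=2, and an assumed
# starting budget of 20 samples per candidate. Each iteration keeps roughly
# the top 1/factor of candidates and gives the survivors factor times more
# resources.
n_candidates, n_resources, factor = 42, 20, 2
iteration = 0
while n_candidates > 1:
    print(f"iter {iteration}: {n_candidates} candidates, {n_resources} samples each")
    n_candidates = math.ceil(n_candidates / factor)
    n_resources *= factor
    iteration += 1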
Heatmap visualization of the search process: The make_heatmap helper function pivots the cv_results_ DataFrame to create a 2D grid of mean_test_score values across the C and gamma hyperparameters of the SVC. For HalvingGridSearchCV, the iteration number overlaid on each cell shows when a configuration was last evaluated: low numbers mean it was eliminated early, while the highest number marks the candidates that survived to the final round. Comparing the two heatmaps confirms that successive halving identifies the same optimal region as exhaustive grid search at a fraction of the computational cost, which is particularly valuable when the parameter grid is large and model training is expensive.
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from time import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.experimental import enable_halving_search_cv # noqa: F401
from sklearn.model_selection import GridSearchCV, HalvingGridSearchCV
from sklearn.svm import SVC
# %%
# We first define the parameter space for an :class:`~sklearn.svm.SVC`
# estimator, and compute the time required to train a
# :class:`~sklearn.model_selection.HalvingGridSearchCV` instance, as well as a
# :class:`~sklearn.model_selection.GridSearchCV` instance.
rng = np.random.RandomState(0)
X, y = datasets.make_classification(n_samples=1000, random_state=rng)
gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
param_grid = {"gamma": gammas, "C": Cs}
clf = SVC(random_state=rng)
tic = time()
gsh = HalvingGridSearchCV(
    estimator=clf, param_grid=param_grid, factor=2, random_state=rng
)
gsh.fit(X, y)
gsh_time = time() - tic
tic = time()
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
gs.fit(X, y)
gs_time = time() - tic
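# %%
# As a quick sanity check (a small addition, not part of the original
# example), we can compare the best parameters found by each search before
# plotting; they typically land in the same region of the grid, though they
# need not match exactly. ``best_params_`` is a standard attribute of fitted
# search estimators.
print("Halving best params:    ", gsh.best_params_)
print("Grid search best params:", gs.best_params_)
print("Halving time: {:.3f}s, Grid search time: {:.3f}s".format(gsh_time, gs_time))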
# %%
# We now plot heatmaps for both search estimators.
Heatmap Visualization of Grid Search Results
Comparing search strategies visually: The make_heatmap function creates a color-coded matrix showing mean_test_score for each (C, gamma) combination. For HalvingGridSearchCV, it keeps each candidate's score from the last iteration it participated in (via pivot_table with aggfunc="last" on the frame sorted by iter) and overlays that iteration number as white text on each cell. Configurations that survived to the final iteration (highest number) correspond to the best-performing hyperparameter regions, while low iteration numbers indicate early elimination due to poor performance on the initial small-resource evaluation.
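To make the aggfunc="last" trick concrete, here is a toy frame (hypothetical values, not taken from the search) with a single candidate evaluated over three iterations; after sorting by iter, "last" keeps the score from the highest iteration the candidate reached:
toy = pd.DataFrame(
    {
        "param_gamma": [0.1, 0.1, 0.1],
        "param_C": [1.0, 1.0, 1.0],
        "iter": [0, 1, 2],
        "mean_test_score": [0.70, 0.80, 0.90],  # hypothetical scores
    }
)
# After sorting by "iter", aggfunc="last" keeps 0.90, the score from the
# last iteration this candidate participated in.
print(
    toy.sort_values("iter").pivot_table(
        index="param_gamma",
        columns="param_C",
        values="mean_test_score",
        aggfunc="last",
    )
)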
def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
    """Helper to make a heatmap."""
    results = pd.DataFrame(gs.cv_results_)
    results[["param_C", "param_gamma"]] = results[["param_C", "param_gamma"]].astype(
        np.float64
    )
    if is_sh:
        # SH dataframe: get mean_test_score values for the highest iter
        scores_matrix = results.sort_values("iter").pivot_table(
            index="param_gamma",
            columns="param_C",
            values="mean_test_score",
            aggfunc="last",
        )
    else:
        scores_matrix = results.pivot(
            index="param_gamma", columns="param_C", values="mean_test_score"
        )

    im = ax.imshow(scores_matrix)

    ax.set_xticks(np.arange(len(Cs)))
    ax.set_xticklabels(["{:.0E}".format(x) for x in Cs])
    ax.set_xlabel("C", fontsize=15)

    ax.set_yticks(np.arange(len(gammas)))
    ax.set_yticklabels(["{:.0E}".format(x) for x in gammas])
    ax.set_ylabel("gamma", fontsize=15)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    if is_sh:
        # Overlay, on each cell, the last iteration in which the candidate
        # was evaluated.
        iterations = results.pivot_table(
            index="param_gamma", columns="param_C", values="iter", aggfunc="max"
        ).values
        for i in range(len(gammas)):
            for j in range(len(Cs)):
                ax.text(
                    j,
                    i,
                    iterations[i, j],
                    ha="center",
                    va="center",
                    color="w",
                    fontsize=20,
                )

    if make_cbar:
        # Relies on the module-level ``fig`` created below.
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(im, cax=cbar_ax)
        cbar_ax.set_ylabel("mean_test_score", rotation=-90, va="bottom", fontsize=15)
fig, axes = plt.subplots(ncols=2, sharey=True)
ax1, ax2 = axes
make_heatmap(ax1, gsh, is_sh=True)
make_heatmap(ax2, gs, make_cbar=True)
ax1.set_title("Successive Halving\ntime = {:.3f}s".format(gsh_time), fontsize=15)
ax2.set_title("GridSearch\ntime = {:.3f}s".format(gs_time), fontsize=15)
plt.show()
# %%
# The heatmaps show the mean test score of the parameter combinations for an
# :class:`~sklearn.svm.SVC` instance. The
# :class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the
# iteration at which the combinations were last used. The combinations marked
# as ``0`` were only evaluated at the first iteration, while the ones marked
# ``5`` are the parameter combinations considered the best.
#
# We can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV`
# class is able to find parameter combinations that are just as accurate as
# :class:`~sklearn.model_selection.GridSearchCV`, in much less time.
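# %%
# The fitted halving search also exposes its schedule directly. As a short
# follow-up sketch (not part of the original example), we can print the
# ``n_candidates_`` and ``n_resources_`` attributes of
# :class:`~sklearn.model_selection.HalvingGridSearchCV`, which record how
# many candidates were evaluated and how many samples each received at every
# iteration.
print("candidates per iteration:", gsh.n_candidates_)
print("resources per iteration: ", gsh.n_resources_)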