Gradient Boosting regularization


Illustration of the effect of different regularization strategies for Gradient Boosting. The example is taken from Hastie et al. (2009) [1]_.

The loss function used is binomial deviance. Regularization via shrinkage (learning_rate < 1.0) improves performance considerably. In combination with shrinkage, stochastic gradient boosting (subsample < 1.0) can produce more accurate models by reducing the variance via bagging. Subsampling without shrinkage usually does poorly. Another strategy to reduce the variance is to subsample the features, analogous to the random splits in Random Forests (via the max_features parameter).

.. [1] T. Hastie, R. Tibshirani and J. Friedman, "Elements of Statistical Learning Ed. 2", Springer, 2009.

Imports for Gradient Boosting Regularization Strategies

Regularization in gradient boosting controls overfitting through three complementary mechanisms: shrinkage (learning_rate < 1.0) scales down each tree’s contribution, requiring more iterations but producing better generalization; stochastic gradient boosting (subsample < 1.0) trains each tree on a random subset of the data, introducing bagging-style variance reduction; and feature subsampling (max_features < n_features) restricts each tree to a random subset of features, analogous to Random Forests. Each mechanism independently reduces variance, and combining them (shrinkage + subsampling, or shrinkage + feature subsampling) typically yields the best results.
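The three mechanisms described above can be combined directly through constructor parameters. A minimal sketch (the dataset size and hyperparameter values here are illustrative, not the ones used in the experiment below):

```python
from sklearn.datasets import make_hastie_10_2
from sklearn.ensemble import GradientBoostingClassifier

X, y = make_hastie_10_2(n_samples=500, random_state=0)
y = (y == 1).astype(int)  # map labels from {-1, 1} to {0, 1}

# All three regularization mechanisms at once:
clf = GradientBoostingClassifier(
    n_estimators=50,
    learning_rate=0.2,  # shrinkage: scale down each tree's contribution
    subsample=0.5,      # stochastic gradient boosting: fit each tree on half the rows
    max_features=2,     # feature subsampling: consider 2 random features per split
    random_state=0,
)
clf.fit(X, y)
```

Lowering learning_rate generally calls for raising n_estimators, since each tree contributes less per iteration.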

Key findings from the experiment: The test deviance curves show that shrinkage alone (learning_rate=0.2) significantly outperforms no regularization (learning_rate=1.0). Subsampling without shrinkage performs poorly because each tree overfits its random subsample. The combination of shrinkage with either data subsampling or feature subsampling achieves the lowest test deviance. The staged_predict_proba method efficiently computes test performance at each iteration without refitting. These results, drawn from Hastie et al. (2009), demonstrate that the learning_rate is the most important regularization parameter in gradient boosting – it should almost always be set well below 1.0, with n_estimators increased accordingly.
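The incremental evaluation via staged_predict_proba can be sketched as follows: one fitted model yields per-iteration predictions in a single pass, so the whole test-deviance curve costs no refitting. The split and hyperparameters below are illustrative:

```python
import numpy as np
from sklearn.datasets import make_hastie_10_2
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

X, y = make_hastie_10_2(n_samples=1000, random_state=0)
y = (y == 1).astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = GradientBoostingClassifier(n_estimators=100, learning_rate=0.2, random_state=0)
clf.fit(X_train, y_train)

# staged_predict_proba yields the class probabilities after each boosting
# iteration; binomial deviance is twice the log loss
deviance = np.array(
    [2 * log_loss(y_test, proba[:, 1]) for proba in clf.staged_predict_proba(X_test)]
)
best_iter = int(np.argmin(deviance)) + 1  # iteration with lowest test deviance
```

Reading off the argmin of such a curve is one simple way to choose n_estimators after the fact.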

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np

from sklearn import datasets, ensemble
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

X, y = datasets.make_hastie_10_2(n_samples=4000, random_state=1)

# map labels from {-1, 1} to {0, 1}
labels, y = np.unique(y, return_inverse=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=0)

original_params = {
    "n_estimators": 400,
    "max_leaf_nodes": 4,
    "max_depth": None,
    "random_state": 2,
    "min_samples_split": 5,
}

plt.figure()

for label, color, setting in [
    ("No shrinkage", "orange", {"learning_rate": 1.0, "subsample": 1.0}),
    ("learning_rate=0.2", "turquoise", {"learning_rate": 0.2, "subsample": 1.0}),
    ("subsample=0.5", "blue", {"learning_rate": 1.0, "subsample": 0.5}),
    (
        "learning_rate=0.2, subsample=0.5",
        "gray",
        {"learning_rate": 0.2, "subsample": 0.5},
    ),
    (
        "learning_rate=0.2, max_features=2",
        "magenta",
        {"learning_rate": 0.2, "max_features": 2},
    ),
]:
    params = dict(original_params)
    params.update(setting)

    clf = ensemble.GradientBoostingClassifier(**params)
    clf.fit(X_train, y_train)

    # compute test set deviance (binomial deviance = 2 * log loss) at each stage
    test_deviance = np.zeros((params["n_estimators"],), dtype=np.float64)

    for i, y_proba in enumerate(clf.staged_predict_proba(X_test)):
        test_deviance[i] = 2 * log_loss(y_test, y_proba[:, 1])

    plt.plot(
        (np.arange(test_deviance.shape[0]) + 1)[::5],
        test_deviance[::5],
        "-",
        color=color,
        label=label,
    )

plt.legend(loc="upper right")
plt.xlabel("Boosting Iterations")
plt.ylabel("Test Set Deviance")

plt.show()
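As a sanity check on the deviance computed in the loop above: binomial deviance is minus twice the mean Bernoulli log-likelihood, which equals twice scikit-learn's log_loss. A small verification with made-up probabilities:

```python
import numpy as np
from sklearn.metrics import log_loss

y_true = np.array([0, 1, 1, 0, 1])
p = np.array([0.1, 0.8, 0.6, 0.3, 0.9])  # predicted P(y=1)

# binomial deviance: -2 * mean log-likelihood under the Bernoulli model
manual = -2 * np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))
assert np.isclose(manual, 2 * log_loss(y_true, p))
```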