Plot Sparse Logistic Regression MNIST
MNIST classification using multinomial logistic + L1
Here we fit a multinomial logistic regression with L1 penalty on a subset of the MNIST digits classification task. We use the SAGA algorithm for this purpose: this is a solver that is fast when the number of samples is significantly larger than the number of features and is able to finely optimize non-smooth objective functions, which is the case with the l1-penalty. Test accuracy reaches > 0.8, while weight vectors remain sparse and therefore more easily interpretable.
Note that the accuracy of this l1-penalized linear model is significantly below what can be reached by an l2-penalized linear model or a non-linear multi-layer perceptron model on this dataset.
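One way to see why an l1 penalty yields exact zeros is the soft-thresholding operator, which is the proximal operator of the l1 norm and the kind of step that proximal solvers in the SAGA family apply after a gradient update. The sketch below is a minimal NumPy illustration, not scikit-learn's internal implementation; the helper name `soft_threshold` and the weight values are assumptions made up for this example.

```python
import numpy as np


def soft_threshold(w, threshold):
    """Proximal operator of threshold * ||w||_1 (hypothetical helper)."""
    # Shrink every entry toward zero by `threshold`; entries whose
    # magnitude is at most `threshold` become exactly zero.
    return np.sign(w) * np.maximum(np.abs(w) - threshold, 0.0)


# Illustrative weights after a plain gradient step (made-up values).
w = np.array([0.30, -0.05, 0.02, -0.80, 0.00])
w_sparse = soft_threshold(w, threshold=0.1)

print(w_sparse)
print("number of exact zeros:", int(np.sum(w_sparse == 0)))
```

Small weights are set to exactly zero instead of merely shrinking, which is the behavior an l2 penalty does not provide.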
Imports for Sparse L1 Logistic Regression on MNIST
Applying L1-penalized logistic regression to image classification produces coefficient vectors that are sparse in pixel space: most pixel weights are driven to exactly zero. When each class's 784 coefficients are reshaped to 28x28 and visualized, the non-zero coefficients form interpretable spatial patterns showing which pixel locations the model uses to distinguish each digit class. For example, the model for digit "0" might have negative weights in the center (where the hole is, so ink there argues against a zero) and positive weights around the ring of the digit.
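The reshaping idea can be tried without downloading MNIST. The sketch below is a small offline variant, assuming scikit-learn's bundled 8x8 digits dataset as a stand-in (it is not MNIST) and an arbitrary illustrative value `C=0.05`. It uses the `penalty="l1"` spelling, which is assumed to be accepted by the installed scikit-learn version; the example's own code passes `l1_ratio=1` instead.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Bundled 8x8 digits dataset: 64 pixel features per image (not MNIST).
X, y = load_digits(return_X_y=True)
X = StandardScaler().fit_transform(X)

# C=0.05 is an arbitrary value chosen for illustration only.
clf = LogisticRegression(penalty="l1", C=0.05, solver="saga", tol=0.01, max_iter=200)
clf.fit(X, y)

# One row of coefficients per class; reshape each row back to the image grid.
coef_images = clf.coef_.reshape(10, 8, 8)

for digit, image in enumerate(coef_images):
    n_used = int(np.count_nonzero(image))
    print(f"digit {digit}: {n_used} of {image.size} pixels have non-zero weight")

print("overall sparsity: %.1f%%" % (100 * np.mean(clf.coef_ == 0)))
```

Each `coef_images[digit]` array can be passed to `imshow` in the same way as the 28x28 MNIST coefficients.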
Tradeoff with accuracy: L1 sparsity makes the model easier to interpret and less prone to overfitting on irrelevant pixels, but it sacrifices some predictive accuracy compared to L2 regularization or non-linear models like MLPs. The SAGA solver handles the L1 penalty efficiently on the 784-dimensional MNIST feature space. StandardScaler standardizes each pixel to zero mean and unit variance using statistics from the training split, and C=50.0/train_samples sets the inverse regularization strength (a smaller C means a stronger penalty). The resulting coefficient vectors are roughly 84% sparse, which suggests that the model can separate the digit classes using only a small fraction of the pixels.
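To make the choice of `C` concrete: scikit-learn documents the l1-penalized objective as `||w||_1 + C * sum_i loss_i`, so dividing by `C * n_samples` gives the mean loss plus `(1 / (C * n_samples)) * ||w||_1`. The sketch below works through that arithmetic; the name `alpha` and the helper `per_sample_l1_strength` are assumptions introduced for illustration and are not scikit-learn parameters.

```python
def per_sample_l1_strength(C, n_samples):
    """Weight on ||w||_1 when the data term is the mean log-loss.

    Derived by dividing ||w||_1 + C * sum_i loss_i by C * n_samples.
    """
    return 1.0 / (C * n_samples)


for train_samples in (1000, 5000, 20000):
    C = 50.0 / train_samples  # same rule as in the example
    alpha = per_sample_l1_strength(C, train_samples)
    print(f"train_samples={train_samples:>6}  C={C:.4f}  alpha={alpha:.4f}")
```

Under this reading, dividing 50 by the number of training samples keeps the per-sample penalty at 1/50 = 0.02 however many samples are used, so changing `train_samples` does not silently change how strongly the weights are penalized.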
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import time
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
t0 = time.time()
# Turn down for faster convergence
train_samples = 5000
# Load data from https://www.openml.org/d/554
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
y = y[permutation]
X = X.reshape((X.shape[0], -1))
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=train_samples, test_size=10000
)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# Turn up tolerance for faster convergence
clf = LogisticRegression(C=50.0 / train_samples, l1_ratio=1, solver="saga", tol=0.1)
clf.fit(X_train, y_train)
sparsity = np.mean(clf.coef_ == 0) * 100
score = clf.score(X_test, y_test)
# print('Best C % .4f' % clf.C_)
print("Sparsity with L1 penalty: %.2f%%" % sparsity)
print("Test score with L1 penalty: %.4f" % score)
coef = clf.coef_.copy()
plt.figure(figsize=(10, 5))
scale = np.abs(coef).max()
for i in range(10):
    l1_plot = plt.subplot(2, 5, i + 1)
    l1_plot.imshow(
        coef[i].reshape(28, 28),
        interpolation="nearest",
        cmap=plt.cm.RdBu,
        vmin=-scale,
        vmax=scale,
    )
    l1_plot.set_xticks(())
    l1_plot.set_yticks(())
    l1_plot.set_xlabel(f"Class {i}")
plt.suptitle("Classification vector for...")
run_time = time.time() - t0
print("Example run in %.3f s" % run_time)
plt.show()
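The scaling step in the example fits `StandardScaler` on the training split only and reuses those statistics for the test split. The sketch below reproduces that by hand on a tiny made-up array (the values are illustrative, not MNIST data). It includes a constant column, similar to MNIST's always-blank corner pixels, for which scikit-learn uses a scale of 1 instead of dividing by zero.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Made-up "pixels": column 0 is constant (always 0), like a blank corner pixel.
X_train = np.array([[0.0, 10.0, 200.0],
                    [0.0, 30.0, 100.0],
                    [0.0, 20.0, 0.0]])
X_test = np.array([[0.0, 40.0, 50.0]])

scaler = StandardScaler().fit(X_train)

# Manual equivalent: statistics come from the training split only.
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)  # population std (ddof=0), as StandardScaler uses
std[std == 0] = 1.0        # constant features are left unscaled

X_test_manual = (X_test - mean) / std
print(scaler.transform(X_test))
print(X_test_manual)
```

Because a constant pixel stays at zero after scaling, it contributes nothing to the gradient, and its coefficient counts toward the reported sparsity.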