ROC Curve with Visualization API
================================
Scikit-learn defines a simple API for creating visualizations for machine learning. The key feature of this API is to allow for quick plotting and visual adjustments without recalculation. In this example, we will demonstrate how to use the visualization API by comparing ROC curves.
Imports for ROC Curve Comparison Using the Visualization API
RocCurveDisplay.from_estimator computes and caches the ROC curve in a reusable display object, enabling efficient multi-model comparison without recomputation. The ROC curve plots the True Positive Rate (sensitivity) against the False Positive Rate (1 - specificity) at every possible classification threshold, summarizing a binary classifier's ability to discriminate between classes regardless of the chosen operating point. The from_estimator method internally calls decision_function or predict_proba (depending on what the estimator provides), computes roc_curve, and stores the resulting fpr, tpr arrays along with the AUC score in the display object. Once created, calling .plot(ax=...) on the cached object renders the curve on any matplotlib axes without re-querying the model.
Overlaying ROC curves from different models on the same axes provides an immediate visual comparison of discriminative performance. By passing an existing axes object via the ax parameter (either from plt.gca() or from a previous display's .plot() call), multiple curves share the same coordinate system, making it trivial to identify which model dominates across threshold ranges. The curve_kwargs=dict(alpha=0.8) parameter adjusts curve transparency for better readability when curves overlap. In this example, an SVC (which uses decision_function as the scoring method) and a RandomForestClassifier (which uses predict_proba) are compared on a binarized wine dataset, demonstrating that the visualization API abstracts away differences in how models produce scores.
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
# %%
# Load Data and Train an SVC
# --------------------------
# First, we load the wine dataset and convert it to a binary classification
# problem. Then, we train a support vector classifier on a training dataset.
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
X, y = load_wine(return_X_y=True)
y = y == 2
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)
# %%
# Plotting the ROC Curve
# ----------------------
# Next, we plot the ROC curve with a single call to
# :func:`sklearn.metrics.RocCurveDisplay.from_estimator`. The returned
# `svc_disp` object allows us to continue using the already computed ROC curve
# for the SVC in future plots.
svc_disp = RocCurveDisplay.from_estimator(svc, X_test, y_test)
plt.show()
# %%
# Training a Random Forest and Plotting the ROC Curve
# ---------------------------------------------------
# We train a random forest classifier and create a plot comparing it to the SVC
# ROC curve. Notice how `svc_disp` uses
# :func:`~sklearn.metrics.RocCurveDisplay.plot` to plot the SVC ROC curve
# without recomputing the values of the ROC curve itself. Furthermore, we
# pass `curve_kwargs=dict(alpha=0.8)` to the plot functions to adjust the
# transparency of the curves.
rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = RocCurveDisplay.from_estimator(
rfc, X_test, y_test, ax=ax, curve_kwargs=dict(alpha=0.8)
)
svc_disp.plot(ax=ax, curve_kwargs=dict(alpha=0.8))
plt.show()