Model Interpretability: SHAP, Permutation Importance & Partial Dependence¶
Why did the model predict that? Explain black-box models using model-agnostic techniques that work on any sklearn estimator.
1. Types of Interpretability¶
Global interpretability answers: "What does the model rely on overall?"
Feature importance, partial dependence plots
Describes the model's general behavior across all predictions
Local interpretability answers: "Why did the model predict THIS for THIS sample?"
SHAP values, LIME
Explains individual predictions
Interpretability Methods
├── Model-specific
│   ├── Linear models: coefficients
│   ├── Decision trees: tree structure
│   └── Random forests: impurity-based importance
└── Model-agnostic (work on ANY model)
    ├── Global: Permutation Importance, PDP, ALE
    └── Local: SHAP, LIME
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
# Load dataset
data = load_breast_cancer()
X, y = data.data, data.target
feature_names = data.feature_names
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
# Train Random Forest
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)
print(f"RF Test Accuracy: {rf.score(X_test, y_test):.4f}")
print(f"Features: {len(feature_names)}")
2. Built-in Feature Importance (Random Forest)¶
# Mean decrease in impurity (MDI): fast but biased toward high-cardinality features
importances = pd.Series(rf.feature_importances_, index=feature_names)
importances_sorted = importances.sort_values(ascending=True)
# Compute standard deviation across trees
std = np.std([tree.feature_importances_ for tree in rf.estimators_], axis=0)
std_sorted = pd.Series(std, index=feature_names).reindex(importances_sorted.index)
fig, ax = plt.subplots(figsize=(8, 9))
y_pos = range(len(importances_sorted))
ax.barh(y_pos, importances_sorted.values, xerr=std_sorted.values,
        align='center', color='steelblue', alpha=0.8, ecolor='gray')
ax.set_yticks(y_pos)
ax.set_yticklabels(importances_sorted.index, fontsize=8)
ax.set_xlabel('Feature Importance (Mean Decrease in Impurity)')
ax.set_title('Random Forest Built-in Feature Importance\n(error bars = std across trees)')
ax.grid(True, axis='x', alpha=0.3)
plt.tight_layout()
plt.show()
print("Top 5 features:")
for feat, imp in importances.nlargest(5).items():
    print(f"  {feat:35s}: {imp:.4f}")
3. Permutation Importance: Model-Agnostic and Unbiased¶
How it works: Randomly shuffle one feature's values, then measure how much model performance drops. A big drop means that feature matters; zero drop means the model doesn't use it.
Advantage over MDI: Works on any model, and doesn't favor high-cardinality features.
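Before reaching for the library version below, the shuffle-and-rescore loop can be sketched by hand. This is a minimal illustration on synthetic data; `manual_permutation_importance` is a hypothetical helper, not sklearn's implementation:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

def manual_permutation_importance(model, X, y, n_repeats=10, seed=0):
    """Average drop in accuracy after shuffling each column."""
    rng = np.random.default_rng(seed)
    baseline = model.score(X, y)
    drops = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        for _ in range(n_repeats):
            X_perm = X.copy()
            rng.shuffle(X_perm[:, j])  # break feature j's link to the target
            drops[j] += baseline - model.score(X_perm, y)
    return drops / n_repeats

# Tiny demo: feature 0 carries the signal, feature 1 is pure noise
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(300, 2))
y_demo = (X_demo[:, 0] > 0).astype(int)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, random_state=0)
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_tr, y_tr)

drops = manual_permutation_importance(model, X_te, y_te)
print(drops)  # feature 0 shows a large drop, feature 1 stays near zero
```

sklearn's `permutation_importance` (used below) does the same thing, plus parallelism and per-repeat bookkeeping.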
from sklearn.inspection import permutation_importance
# Run permutation importance on the TEST set (not train; avoids rewarding memorization)
result = permutation_importance(
    rf, X_test, y_test,
    n_repeats=30,
    random_state=42,
    n_jobs=-1,
    scoring='accuracy'
)
perm_imp = pd.Series(result.importances_mean, index=feature_names)
perm_std = pd.Series(result.importances_std, index=feature_names)
# Compare MDI vs Permutation importance
top_perm = perm_imp.nlargest(10)
top_mdi = importances.nlargest(10)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
for ax, data_dict, title, color in [
    (axes[0], top_mdi.sort_values(ascending=True), 'MDI (built-in)', 'steelblue'),
    (axes[1], top_perm.sort_values(ascending=True), 'Permutation (test set)', 'darkorange')
]:
    y_pos = range(len(data_dict))
    ax.barh(y_pos, data_dict.values, color=color, alpha=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(data_dict.index, fontsize=8)
    ax.set_title(f'Top 10: {title}')
    ax.set_xlabel('Importance')
    ax.grid(True, axis='x', alpha=0.3)
plt.suptitle('MDI vs Permutation Importance: may differ for correlated features!', fontsize=12)
plt.tight_layout()
plt.show()
print("Key differences:")
print("  MDI: fast, biased toward high-cardinality features, reflects training data")
print("  Permutation: slower, unbiased by cardinality, reflects test performance")
print("  When features are correlated, both can mislead: MDI splits credit among them,")
print("  and permutation can understate each one, because the model falls back on the")
print("  correlated partner when one is shuffled")
4. SHAP Values: Theoretically Sound Local + Global Explanations¶
try:
    import shap

    # TreeExplainer is optimized for tree-based models (very fast)
    explainer = shap.TreeExplainer(rf)
    shap_values = explainer.shap_values(X_test)

    # For binary classification, older shap versions return a list
    # [class_0_shaps, class_1_shaps]; newer versions return a 3D array
    # (samples, features, classes). Either way, take class 1.
    if isinstance(shap_values, list):
        shap_vals = shap_values[1]
    elif shap_values.ndim == 3:
        shap_vals = shap_values[:, :, 1]
    else:
        shap_vals = shap_values

    print(f"SHAP values shape: {shap_vals.shape} (samples × features)")

    # Summary plot: global overview
    plt.figure(figsize=(9, 7))
    shap.summary_plot(shap_vals, X_test, feature_names=feature_names, show=False)
    plt.title('SHAP Summary Plot: Feature Impact on Malignant Prediction')
    plt.tight_layout()
    plt.show()

    # Local explanation: single prediction
    sample_idx = 0
    pred = rf.predict(X_test[sample_idx:sample_idx + 1])[0]
    print(f"\nExplaining prediction for sample {sample_idx}:")
    print(f"  True label: {'Malignant' if y_test[sample_idx] == 0 else 'Benign'}")
    print(f"  Predicted:  {'Malignant' if pred == 0 else 'Benign'}")

    # force_plot needs JavaScript; show a bar plot instead
    top_features = np.argsort(np.abs(shap_vals[sample_idx]))[-10:]
    plt.figure(figsize=(8, 4))
    plt.barh(
        [feature_names[i] for i in top_features],
        shap_vals[sample_idx][top_features],
        color=['red' if v > 0 else 'blue' for v in shap_vals[sample_idx][top_features]]
    )
    plt.xlabel('SHAP value (impact on prediction)')
    plt.title(f'Local SHAP: Why sample {sample_idx} was predicted this way')
    plt.grid(True, axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()
except ImportError:
    print("SHAP not installed. Install with: pip install shap")
    print("\nWhat SHAP does:")
    print("  SHAP (SHapley Additive exPlanations) uses game theory to assign")
    print("  each feature a 'credit' for a prediction.")
    print("")
    print("  For a prediction f(x):")
    print("    f(x) = base_value + shap_1 + shap_2 + ... + shap_n")
    print("")
    print("  Where:")
    print("    base_value = average model output on training data")
    print("    shap_i     = contribution of feature i to this specific prediction")
    print("")
    print("  Key guarantee: SHAP values are the UNIQUE attribution satisfying")
    print("  consistency, dummy, efficiency, and symmetry axioms.")
    print("")
    print("  Code (when shap is installed):")
    print("    explainer = shap.TreeExplainer(rf)")
    print("    shap_values = explainer.shap_values(X_test)")
    print("    shap.summary_plot(shap_values[1], X_test, feature_names=feature_names)")
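The additivity identity f(x) = base_value + Σ shap_i can be made concrete with a brute-force Shapley computation on a toy function. This runs without the shap library; `shapley_values` and `v` are hypothetical helpers for illustration, and real SHAP implementations use much faster approximations:

```python
from itertools import combinations
from math import factorial
import numpy as np

def v(S, x, background, f):
    """Value of coalition S: average f with features in S fixed to x's values."""
    Xb = background.copy()
    for j in S:
        Xb[:, j] = x[j]
    return f(Xb).mean()

def shapley_values(x, background, f):
    """Exact Shapley values via the classic weighted sum over coalitions."""
    n = len(x)
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for S in combinations(others, size):
                w = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += w * (v(S + (i,), x, background, f) - v(S, x, background, f))
    return phi

f = lambda X: 2 * X[:, 0] + X[:, 1] * X[:, 2]   # toy "model"
rng = np.random.default_rng(0)
background = rng.normal(size=(500, 3))
x = np.array([1.0, 2.0, 3.0])

phi = shapley_values(x, background, f)
base = f(background).mean()
# Efficiency axiom: contributions sum exactly to f(x) - base_value
assert np.isclose(base + phi.sum(), f(x[None])[0])
print(phi, base)
```

The exponential sum over coalitions is why exact Shapley values are intractable beyond a handful of features, and why TreeExplainer's polynomial-time algorithm for trees matters.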
5. Partial Dependence Plots (PDP)¶
PDP shows the marginal effect of one or two features on the predicted outcome, averaging over all other features. Works without SHAP.
from sklearn.inspection import PartialDependenceDisplay
# Top 4 features by permutation importance
top_4_features = perm_imp.nlargest(4).index.tolist()
top_4_indices = [list(feature_names).index(f) for f in top_4_features]
print(f"Plotting PDP for: {top_4_features}")
fig, ax = plt.subplots(figsize=(12, 8))
display = PartialDependenceDisplay.from_estimator(
    rf, X_train,
    features=top_4_indices,
    feature_names=feature_names,
    grid_resolution=50,
    ax=ax,
    kind='both'  # 'average' for PDP, 'individual' for ICE
)
plt.suptitle('Partial Dependence Plots (PDP) + Individual Conditional Expectation (ICE)', y=1.02)
plt.tight_layout()
plt.show()
print("Blue line = PDP (average effect)")
print("Gray lines = ICE (per-sample effect; reveals heterogeneity)")
print("")
print("PDP limitation: assumes feature independence (marginalizes over other features).")
print("For correlated features, Accumulated Local Effects (ALE) plots are more accurate.")
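What "averaging over all other features" means is easiest to see in a hand-rolled PDP. A minimal sketch on synthetic data with a known quadratic effect; `manual_pdp` is a hypothetical helper, not the sklearn API:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.uniform(-2, 2, size=(400, 3))
y = X[:, 0] ** 2 + X[:, 1] + rng.normal(scale=0.1, size=400)  # quadratic in x0
model = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)

def manual_pdp(model, X, feature, grid):
    """For each grid value: fix the feature for EVERY sample, average predictions."""
    pdp = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value
        pdp.append(model.predict(X_mod).mean())
    return np.array(pdp)

grid = np.linspace(-2, 2, 9)
pdp = manual_pdp(model, X, 0, grid)
print(np.round(pdp, 2))  # traces the U-shape of x0**2
```

The "fix the feature for every sample" step is exactly where the independence assumption bites: for correlated features it creates unrealistic combinations, which is what ALE avoids.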
6. LIME: Local Surrogate Models¶
try:
    import lime
    import lime.lime_tabular

    explainer_lime = lime.lime_tabular.LimeTabularExplainer(
        X_train,
        feature_names=feature_names,
        class_names=['Malignant', 'Benign'],
        mode='classification'
    )

    # Explain one prediction
    exp = explainer_lime.explain_instance(
        X_test[0],
        rf.predict_proba,
        num_features=10
    )

    print("LIME explanation for sample 0:")
    for feat, weight in exp.as_list():
        direction = '+' if weight > 0 else '-'
        print(f"  {direction} {feat}: {weight:.4f}")
except ImportError:
    print("LIME not installed. Install with: pip install lime")
    print("\nHow LIME works:")
    print("  1. Pick the sample to explain: x")
    print("  2. Create perturbed samples around x (random noise)")
    print("  3. Get predictions from the black-box model for perturbed samples")
    print("  4. Weight perturbed samples by proximity to x")
    print("  5. Train a SIMPLE interpretable model (linear regression) on weighted samples")
    print("  6. Linear model coefficients = LIME explanation")
    print("")
    print("  LIME vs SHAP:")
    print("    SHAP: theoretically sound (Shapley values), slower for non-tree models")
    print("    LIME: faster, model-agnostic, but less stable (sensitive to perturbation)")
    print("")
    print("  Code (when lime is installed):")
    print("    explainer = lime.lime_tabular.LimeTabularExplainer(X_train, feature_names=feature_names)")
    print("    exp = explainer.explain_instance(X_test[0], rf.predict_proba, num_features=10)")
    print("    exp.show_in_notebook()")
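The six steps above can be condensed into a few lines with a weighted ridge surrogate. A hypothetical sketch on synthetic data, not the lime library's implementation (which also discretizes features and selects them with LASSO):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)   # only features 0 and 1 matter
black_box = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

x = np.zeros(4)                                  # 1. sample to explain (on the boundary)
Z = x + rng.normal(scale=0.5, size=(1000, 4))    # 2. perturbed neighbors
p = black_box.predict_proba(Z)[:, 1]             # 3. black-box predictions
d = np.linalg.norm(Z - x, axis=1)
w = np.exp(-d ** 2 / 0.5)                        # 4. proximity kernel weights
surrogate = Ridge(alpha=1.0).fit(Z, p, sample_weight=w)  # 5. simple weighted model
print(np.round(surrogate.coef_, 3))              # 6. coefficients = the explanation
```

The coefficient for feature 0 should dominate, feature 1 should be roughly half as large, and the two irrelevant features should sit near zero, mirroring the local slope of the black box around x.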
Exercises¶
1. Correlated features test: Create a dataset where two features are perfectly correlated (e.g., X2 = X1 + noise). Compare MDI, permutation importance, and SHAP for these features. Which method handles correlation best?
2. PDP interaction: Use PartialDependenceDisplay with features=[(0, 1)] to create a 2D PDP (interaction plot). What does a non-linear interaction look like?
3. Model comparison: Train both RandomForestClassifier and GradientBoostingClassifier on the breast cancer dataset. Use permutation importance on both. Do they agree on which features matter most?
4. Explanation stability: Run LIME 10 times on the same test sample (it's stochastic). How much do the feature weights vary? Does SHAP give the same answer every time?
5. Business communication: Pick 3 test samples (one correct, one false positive, one false negative). For each, use permutation importance + PDP to write a 2-sentence explanation a non-technical stakeholder could understand.