Model Interpretability: SHAP, Permutation Importance & Partial Dependence¶

Why did the model predict that? Explain black-box models using model-agnostic techniques that work on any sklearn estimator.

1. Types of Interpretability¶

Global interpretability answers: "What does the model rely on overall?"

  • Feature importance, partial dependence plots

  • Describes the model’s general behavior across all predictions

Local interpretability answers: "Why did the model predict THIS for THIS sample?"

  • SHAP values, LIME

  • Explains individual predictions

Interpretability Methods
├── Model-specific
│   ├── Linear models: coefficients
│   ├── Decision trees: tree structure
│   └── Random forests: impurity-based importance
└── Model-agnostic (work on ANY model)
    ├── Global: Permutation Importance, PDP, ALE
    └── Local: SHAP, LIME
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline

# Load dataset
data = load_breast_cancer()
X, y = data.data, data.target
feature_names = data.feature_names

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Train Random Forest
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)

print(f"RF Test Accuracy: {rf.score(X_test, y_test):.4f}")
print(f"Features: {len(feature_names)}")

2. Built-in Feature Importance (Random Forest)¶

# Mean decrease in impurity (MDI): fast but biased toward high-cardinality features
importances = pd.Series(rf.feature_importances_, index=feature_names)
importances_sorted = importances.sort_values(ascending=True)

# Compute standard deviation across trees
std = np.std([tree.feature_importances_ for tree in rf.estimators_], axis=0)
std_sorted = pd.Series(std, index=feature_names).reindex(importances_sorted.index)

fig, ax = plt.subplots(figsize=(8, 9))
y_pos = range(len(importances_sorted))
ax.barh(y_pos, importances_sorted.values, xerr=std_sorted.values,
        align='center', color='steelblue', alpha=0.8, ecolor='gray')
ax.set_yticks(y_pos)
ax.set_yticklabels(importances_sorted.index, fontsize=8)
ax.set_xlabel('Feature Importance (Mean Decrease in Impurity)')
ax.set_title('Random Forest Built-in Feature Importance\n(error bars = std across trees)')
ax.grid(True, axis='x', alpha=0.3)
plt.tight_layout()
plt.show()

print("Top 5 features:")
for feat, imp in importances.nlargest(5).items():
    print(f"  {feat:35s}: {imp:.4f}")
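The high-cardinality bias mentioned above is easy to demonstrate: a continuous feature offers many candidate split points, so even pure noise accrues some impurity decrease. A minimal sketch (the noise column and its rank computation are illustrative choices, not part of the pipeline above):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

# Append a pure-noise continuous feature and check its MDI score.
data = load_breast_cancer()
rng = np.random.RandomState(0)
noise = rng.normal(size=(data.data.shape[0], 1))   # carries no signal at all
X_aug = np.hstack([data.data, noise])

rf_demo = RandomForestClassifier(n_estimators=100, random_state=0)
rf_demo.fit(X_aug, data.target)

noise_mdi = rf_demo.feature_importances_[-1]
# rank from the top, counting the noise feature itself
rank = int((rf_demo.feature_importances_ >= noise_mdi).sum())
print(f"MDI of pure-noise feature: {noise_mdi:.4f} (rank {rank} of {X_aug.shape[1]})")
```

The noise feature gets a nonzero score despite being uninformative; permutation importance (next section) would correctly assign it a value near zero.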

3. Permutation Importance: Model-Agnostic and Unbiased¶

How it works: Randomly shuffle one feature’s values, measure how much model performance drops. A big drop = that feature matters. Zero drop = the model doesn’t use it.

Advantage over MDI: Works on any model, and doesn’t favor high-cardinality features.
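The shuffle-and-rescore idea can be written out by hand before reaching for sklearn's helper. A sketch for a single feature ('worst perimeter' is chosen here just as an example of an informative feature):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42, stratify=data.target
)
rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

rng = np.random.RandomState(42)
baseline = rf.score(X_test, y_test)
feat_idx = list(data.feature_names).index('worst perimeter')

drops = []
for _ in range(10):                      # repeat to average out shuffle noise
    X_perm = X_test.copy()
    rng.shuffle(X_perm[:, feat_idx])     # break the feature-target link in place
    drops.append(baseline - rf.score(X_perm, y_test))

print(f"baseline accuracy: {baseline:.4f}")
print(f"mean accuracy drop after shuffling 'worst perimeter': {np.mean(drops):.4f}")
```

`permutation_importance` below does exactly this loop for every feature, with proper bookkeeping of means and standard deviations.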

from sklearn.inspection import permutation_importance

# Run permutation importance on the TEST set (scoring on train can credit memorized noise)
result = permutation_importance(
    rf, X_test, y_test,
    n_repeats=30,
    random_state=42,
    n_jobs=-1,
    scoring='accuracy'
)

perm_imp = pd.Series(result.importances_mean, index=feature_names)
perm_std = pd.Series(result.importances_std, index=feature_names)

# Compare MDI vs Permutation importance
top_perm = perm_imp.nlargest(10)
top_mdi  = importances.nlargest(10)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for ax, data_dict, title, color in [
    (axes[0], top_mdi.sort_values(ascending=True), 'MDI (built-in)', 'steelblue'),
    (axes[1], top_perm.sort_values(ascending=True), 'Permutation (test set)', 'darkorange')
]:
    y_pos = range(len(data_dict))
    ax.barh(y_pos, data_dict.values, color=color, alpha=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(data_dict.index, fontsize=8)
    ax.set_title(f'Top 10: {title}')
    ax.set_xlabel('Importance')
    ax.grid(True, axis='x', alpha=0.3)

plt.suptitle('MDI vs Permutation Importance (may differ for correlated features!)', fontsize=12)
plt.tight_layout()
plt.show()

print("Key differences:")
print("  MDI:         fast, biased, reflects training data")
print("  Permutation: slow, unbiased, reflects test performance")
print("  When features are correlated: MDI splits credit among them, and permutation")
print("  importance is diluted too (shuffling one feature leaves its correlated partner's information intact)")
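The dilution effect for correlated features can be seen directly: add a near-duplicate of an informative feature and watch its permutation importance shrink. A sketch on synthetic data (the duplicate column and the `scale=0.01` noise level are illustrative choices):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

rng = np.random.RandomState(0)
# shuffle=False keeps the 2 informative features in columns 0 and 1
X, y = make_classification(n_samples=600, n_features=4, n_informative=2,
                           n_redundant=0, shuffle=False, random_state=0)
# column 4 is a near-copy of column 0
X_dup = np.hstack([X, X[:, [0]] + rng.normal(scale=0.01, size=(len(X), 1))])

rf_a = RandomForestClassifier(n_estimators=200, random_state=0).fit(X, y)
rf_b = RandomForestClassifier(n_estimators=200, random_state=0).fit(X_dup, y)

imp_a = permutation_importance(rf_a, X, y, n_repeats=10, random_state=0).importances_mean
imp_b = permutation_importance(rf_b, X_dup, y, n_repeats=10, random_state=0).importances_mean

print(f"feature 0 alone:          {imp_a[0]:.4f}")
print(f"feature 0 with duplicate: {imp_b[0]:.4f}  (duplicate: {imp_b[4]:.4f})")
```

With the duplicate present, shuffling either copy barely hurts the model because the other copy still carries the signal, so both receive reduced scores.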

4. SHAP Values: Theoretically Sound Local + Global Explanations¶

try:
    import shap
    
    # TreeExplainer is optimized for tree-based models (very fast)
    explainer = shap.TreeExplainer(rf)
    shap_values = explainer.shap_values(X_test)
    
    # Depending on the shap version, binary-classification output is either a
    # list [class_0_shaps, class_1_shaps] or a 3D array (samples, features, classes).
    # Take class 1 (benign); class 0 is malignant in this dataset.
    if isinstance(shap_values, list):
        shap_vals = shap_values[1]
    elif np.ndim(shap_values) == 3:
        shap_vals = shap_values[:, :, 1]
    else:
        shap_vals = shap_values
    
    print(f"SHAP values shape: {shap_vals.shape}  (samples × features)")
    
    # Summary plot: global overview
    plt.figure(figsize=(9, 7))
    shap.summary_plot(shap_vals, X_test, feature_names=feature_names, show=False)
    plt.title('SHAP Summary Plot: Feature Impact on Benign Prediction (class 1)')
    plt.tight_layout()
    plt.show()
    
    # Local explanation: single prediction
    sample_idx = 0
    print(f"\nExplaining prediction for sample {sample_idx}:")
    print(f"  True label: {'Malignant' if y_test[sample_idx] == 0 else 'Benign'}")
    print(f"  Predicted:  {'Malignant' if rf.predict(X_test[sample_idx:sample_idx+1])[0] == 0 else 'Benign'}")
    
    # force_plot needs JavaScript in a notebook; use a bar chart instead
    top_features = np.argsort(np.abs(shap_vals[sample_idx]))[-10:]
    plt.figure(figsize=(8, 4))
    plt.barh(
        [feature_names[i] for i in top_features],
        shap_vals[sample_idx][top_features],
        color=['red' if v > 0 else 'blue' for v in shap_vals[sample_idx][top_features]]
    )
    plt.xlabel('SHAP value (impact on prediction)')
    plt.title(f'Local SHAP: Why sample {sample_idx} was predicted this way')
    plt.grid(True, axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()
    
except ImportError:
    print("SHAP not installed. Install with: pip install shap")
    print("\nWhat SHAP does:")
    print("  SHAP (SHapley Additive exPlanations) uses game theory to assign")
    print("  each feature a 'credit' for a prediction.")
    print("")
    print("  For a prediction f(x):")
    print("    f(x) = base_value + shap_1 + shap_2 + ... + shap_n")
    print("")
    print("  Where:")
    print("    base_value = average model output on training data")
    print("    shap_i     = contribution of feature i to this specific prediction")
    print("")
    print("  Key guarantee: SHAP values are the UNIQUE attribution satisfying")
    print("  consistency, dummy, efficiency, and symmetry axioms.")
    print("")
    print("  Code (when shap is installed):")
    print("    explainer = shap.TreeExplainer(rf)")
    print("    shap_values = explainer.shap_values(X_test)")
    print("    shap.summary_plot(shap_values[1], X_test, feature_names=feature_names)")
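The additivity identity f(x) = base_value + Σ shap_i can be verified without the shap library by computing exact Shapley values by brute force on a tiny model. A sketch, assuming the interventional convention where an "absent" feature is replaced by values from a background sample (the toy linear model and weights here are illustrative):

```python
import numpy as np
from itertools import combinations
from math import factorial

rng = np.random.RandomState(0)
w = np.array([2.0, -1.0, 0.5])
background = rng.normal(size=(100, 3))

def f(X):
    return X @ w                                   # toy linear model

def value(coalition, x):
    """Average model output with the features in `coalition` fixed to x's values."""
    Xb = background.copy()
    Xb[:, list(coalition)] = x[list(coalition)]
    return f(Xb).mean()

def shapley(x, n=3):
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for S in combinations(others, size):
                # classic Shapley weight |S|! (n-|S|-1)! / n!
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += weight * (value(S + (i,), x) - value(S, x))
    return phi

x = np.array([1.0, 2.0, -1.0])
phi = shapley(x)
base = f(background).mean()
print("shapley values:", np.round(phi, 4))
print("closed form   :", np.round(w * (x - background.mean(axis=0)), 4))
print(f"additivity: {base:.4f} + {phi.sum():.4f} = {f(x[None])[0]:.4f}")
```

For a linear model the exact Shapley value has the closed form w_i (x_i − mean_i), so the brute-force result can be checked against it; TreeExplainer computes the same quantity efficiently for tree ensembles.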

5. Partial Dependence Plots (PDP)¶

PDP shows the marginal effect of one or two features on the predicted outcome, averaging over all other features. Works without SHAP.

from sklearn.inspection import PartialDependenceDisplay

# Top 4 features by permutation importance
top_4_features = perm_imp.nlargest(4).index.tolist()
top_4_indices  = [list(feature_names).index(f) for f in top_4_features]

print(f"Plotting PDP for: {top_4_features}")

fig, ax = plt.subplots(figsize=(12, 8))
display = PartialDependenceDisplay.from_estimator(
    rf, X_train,
    features=top_4_indices,
    feature_names=feature_names,
    grid_resolution=50,
    ax=ax,
    kind='both'  # 'average' for PDP, 'individual' for ICE
)
plt.suptitle('Partial Dependence Plots (PDP) + Individual Conditional Expectation (ICE)', y=1.02)
plt.tight_layout()
plt.show()

print("Blue line = PDP (average effect)")
print("Gray lines = ICE (per-sample effect, reveals heterogeneity)")
print("")
print("PDP limitation: assumes feature independence (marginalizes over other features).")
print("For correlated features, Accumulated Local Effects (ALE) plots are more accurate.")
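The marginalization PDP performs is simple enough to write by hand: sweep a grid of values for one feature, overwrite that feature for every sample, and average the predictions. A sketch for a single feature ('worst perimeter' is an illustrative choice):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

data = load_breast_cancer()
rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(data.data, data.target)

feat_idx = list(data.feature_names).index('worst perimeter')
grid = np.linspace(data.data[:, feat_idx].min(), data.data[:, feat_idx].max(), 20)

pdp = []
for v in grid:
    X_mod = data.data.copy()
    X_mod[:, feat_idx] = v                             # force the feature to v everywhere
    pdp.append(rf.predict_proba(X_mod)[:, 1].mean())   # average P(benign)

print(f"avg P(benign) at smallest 'worst perimeter': {pdp[0]:.3f}")
print(f"avg P(benign) at largest  'worst perimeter': {pdp[-1]:.3f}")
```

Overwriting the feature everywhere is exactly why PDP can produce unrealistic combinations when features are correlated, which is the limitation noted above.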

6. LIME: Local Surrogate Models¶

try:
    import lime
    import lime.lime_tabular
    
    explainer_lime = lime.lime_tabular.LimeTabularExplainer(
        X_train,
        feature_names=feature_names,
        class_names=['Malignant', 'Benign'],
        mode='classification'
    )
    
    # Explain one prediction
    exp = explainer_lime.explain_instance(
        X_test[0],
        rf.predict_proba,
        num_features=10
    )
    
    print("LIME explanation for sample 0:")
    for feat, weight in exp.as_list():
        direction = '+' if weight > 0 else '-'
        print(f"  {direction} {feat}: {weight:.4f}")
    
except ImportError:
    print("LIME not installed. Install with: pip install lime")
    print("\nHow LIME works:")
    print("  1. Pick the sample to explain: x")
    print("  2. Create perturbed samples around x (random noise)")
    print("  3. Get predictions from the black-box model for perturbed samples")
    print("  4. Weight perturbed samples by proximity to x")
    print("  5. Train a SIMPLE interpretable model (linear regression) on weighted samples")
    print("  6. Linear model coefficients = LIME explanation")
    print("")
    print("  LIME vs SHAP:")
    print("    SHAP: theoretically sound (Shapley values), slower for non-tree models")
    print("    LIME: faster, model-agnostic, but less stable (sensitive to perturbation)")
    print("")
    print("  Code (when lime is installed):")
    print("    explainer = lime.lime_tabular.LimeTabularExplainer(X_train, feature_names=feature_names)")
    print("    exp = explainer.explain_instance(X_test[0], rf.predict_proba, num_features=10)")
    print("    exp.show_in_notebook()")
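The six steps above can be sketched by hand without the lime package. This is a simplified tabular LIME, not the library's exact algorithm: Gaussian perturbations in standardized space, an RBF proximity kernel, and a weighted ridge surrogate (the 0.5 perturbation scale and kernel width 25 are arbitrary illustrative choices):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42, stratify=data.target
)
rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

rng = np.random.RandomState(0)
x = X_test[0]                                       # 1. the sample to explain
mu, sigma = X_train.mean(axis=0), X_train.std(axis=0)
x_std = (x - mu) / sigma

Z_std = x_std + rng.normal(scale=0.5, size=(1000, len(x)))  # 2. perturb around x
Z = Z_std * sigma + mu
p = rf.predict_proba(Z)[:, 1]                       # 3. black-box predictions
d2 = ((Z_std - x_std) ** 2).sum(axis=1)
weights = np.exp(-d2 / 25.0)                        # 4. proximity weighting
surrogate = Ridge(alpha=1.0).fit(Z_std, p, sample_weight=weights)  # 5. simple model

top = np.argsort(np.abs(surrogate.coef_))[::-1][:5]  # 6. coefficients = explanation
for i in top:
    print(f"{data.feature_names[i]:25s} {surrogate.coef_[i]:+.4f}")
```

The surrogate's coefficients describe the model's local behavior around x only; rerunning with a different seed changes them somewhat, which is the stability issue exercise 4 explores.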

Exercises¶

  1. Correlated features test: Create a dataset where two features are almost perfectly correlated (e.g., X2 = X1 + small noise). Compare MDI, permutation importance, and SHAP for these features. Which method handles correlation best?

  2. PDP interaction: Use PartialDependenceDisplay with features=[(0, 1)] to create a 2D PDP (interaction plot). What does a non-linear interaction look like?

  3. Model comparison: Train both RandomForestClassifier and GradientBoostingClassifier on the breast cancer dataset. Use permutation importance on both. Do they agree on which features matter most?

  4. Explanation stability: Run LIME 10 times on the same test sample (it’s stochastic). How much do the feature weights vary? Does SHAP give the same answer every time?

  5. Business communication: Pick 3 test samples: one correct, one false positive, one false negative. For each, use permutation importance + PDP to write a 2-sentence explanation a non-technical stakeholder could understand.