import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import cv2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
Advanced Neural Network Interpretability Theory
1. Foundations and Taxonomy
Definition: Interpretability is the degree to which a human can understand the cause of a decision made by a model.
Taxonomy of Interpretability:
| Property | Intrinsic | Post-hoc |
|---|---|---|
| Model-specific | Linear/Tree models | Grad-CAM, CAM |
| Model-agnostic | Sparse models | LIME, SHAP |
| Local | Rule-based | Single-prediction explanation |
| Global | Feature importance | Overall behavior understanding |
Fidelity vs. Interpretability Trade-off:
High fidelity (complex models): Better performance, harder to interpret
High interpretability (simple models): Easier to understand, may sacrifice performance
Goal: Post-hoc methods bridge this gap
2. Gradient-Based Attribution Methods
2.1 Vanilla Gradients (Saliency Maps)
Definition: Attribution is the gradient of the output w.r.t. the input:
\[S(x) = \frac{\partial f(x)}{\partial x}\]
Where \(f(x)\) is the model output for input \(x\).
First-order Taylor approximation:
\[f(x) \approx f(x_0) + \nabla f(x_0)^\top (x - x_0)\]
The gradient \(\nabla f\) indicates local sensitivity: how much the output changes with small input perturbations.
Limitations:
Saturation: Gradients vanish in saturated regions (ReLU dead neurons, sigmoid plateaus)
Noise: High-frequency artifacts due to local nature
No baseline: Doesn't distinguish important from unimportant features
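The vanilla-gradient recipe above fits in a few lines of PyTorch. This is a minimal sketch in which the two-layer `model` is an untrained stand-in for a real classifier:

```python
import torch
import torch.nn as nn

# Minimal saliency-map sketch: gradient of the top-class score w.r.t. the input.
# The model here is an untrained stand-in, not a real classifier.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.ReLU(), nn.Linear(16, 10))
model.eval()

x = torch.rand(1, 3, 8, 8, requires_grad=True)  # track gradients w.r.t. the input
logits = model(x)
score = logits[0, logits.argmax(1).item()]      # top-class score
score.backward()                                # populates x.grad = d(score)/d(input)

saliency = x.grad.abs().max(dim=1)[0]           # max over channels -> (1, 8, 8) map
print(saliency.shape)
```

Taking the absolute value and the channel-wise maximum is the common visualization choice from the original saliency-map paper.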
2.2 Integrated Gradients (IG)
Axioms for attribution:
Sensitivity: If feature differs from baseline and affects output, attribution should be non-zero
Implementation invariance: Functionally equivalent models should have identical attributions
IG Definition:
\[\text{IG}_i(x) = (x_i - x_i') \int_0^1 \frac{\partial f\big(x' + \alpha (x - x')\big)}{\partial x_i} \, d\alpha\]
Where:
\(x'\): Baseline (typically zeros or an average image)
\(\alpha \in [0,1]\): Interpolation parameter
Path: Straight line from \(x'\) to \(x\)
Discrete approximation (Riemann sum):
\[\text{IG}_i(x) \approx (x_i - x_i') \cdot \frac{1}{m} \sum_{k=1}^{m} \frac{\partial f\big(x' + \tfrac{k}{m}(x - x')\big)}{\partial x_i}\]
Typical: \(m = 50\) steps
Completeness property:
\[\sum_i \text{IG}_i(x) = f(x) - f(x')\]
Attribution sums to the difference between output and baseline.
Proof sketch: By the fundamental theorem of calculus applied along the path, \(\int_0^1 \nabla f\big(x' + \alpha(x - x')\big) \cdot (x - x') \, d\alpha = f(x) - f(x')\).
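The Riemann-sum approximation above can be sketched directly; the model and zero baseline here are illustrative stand-ins, and the completeness property gives a built-in sanity check:

```python
import torch
import torch.nn as nn

# Integrated Gradients sketch following the Riemann-sum approximation.
# The small Tanh network is an untrained stand-in, not a real classifier.
model = nn.Sequential(nn.Flatten(), nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
model.eval()

def integrated_gradients(model, x, baseline, target, steps=50):
    # Accumulate gradients along the straight-line path from baseline to x
    total_grad = torch.zeros_like(x)
    for k in range(1, steps + 1):
        x_interp = (baseline + (k / steps) * (x - baseline)).requires_grad_()
        model(x_interp)[0, target].backward()
        total_grad += x_interp.grad
    return (x - baseline) * total_grad / steps  # scale by (x - x')

x = torch.rand(1, 4)
baseline = torch.zeros(1, 4)
target = model(x).argmax(1).item()
attr = integrated_gradients(model, x, baseline, target)

# Completeness check: attributions should sum to f(x) - f(baseline)
diff = (model(x)[0, target] - model(baseline)[0, target]).item()
print(attr.sum().item(), diff)
```

With 50 steps the sum matches the output difference up to a small discretization error.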
2.3 SmoothGrad
Problem: Gradients are noisy and sensitive to small perturbations
Solution: Average gradients over noisy copies of the input:
\[\hat{S}(x) = \frac{1}{n} \sum_{i=1}^{n} \frac{\partial f(x + \epsilon_i)}{\partial x}, \qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2)\]
Where \(\mathcal{N}(0, \sigma^2)\) is Gaussian noise, typically with \(\sigma = 0.15 \times (\max(x) - \min(x))\)
Effect: Reduces high-frequency noise, produces visually cleaner attribution maps
Variants:
SmoothGrad-Squared: \(\frac{1}{n} \sum (\partial f / \partial x)^2\)
VarGrad: Variance of gradients across samples
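The averaging scheme above is a small wrapper around the vanilla-gradient computation. A minimal sketch, with an untrained stand-in model and the rule-of-thumb noise scale:

```python
import torch
import torch.nn as nn

# SmoothGrad sketch: average saliency over Gaussian-noised copies of the input.
# The linear model is an untrained stand-in; sigma follows the 0.15*(max-min) rule.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
model.eval()

def smoothgrad(model, x, target, n=25, sigma_frac=0.15):
    sigma = sigma_frac * (x.max() - x.min()).item()
    grads = torch.zeros_like(x)
    for _ in range(n):
        noisy = (x + sigma * torch.randn_like(x)).requires_grad_()
        model(noisy)[0, target].backward()
        grads += noisy.grad
    return grads / n  # averaged gradient, same shape as the input

x = torch.rand(1, 3, 8, 8)
target = model(x).argmax(1).item()
sg = smoothgrad(model, x, target)
print(sg.shape)
```

Squaring each gradient before averaging turns this into the SmoothGrad-Squared variant.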
3. Class Activation Mapping (CAM) Methods
3.1 CAM (Original)
Architecture requirement: Global Average Pooling (GAP) before the final classification layer. The class score is then
\[S^c = \sum_k w_k^c \cdot \frac{1}{Z} \sum_{i,j} A_{ij}^k\]
Where:
\(A^k\): Feature map k from the last convolutional layer
\(w_k^c\): Weight connecting feature k to class c
\(Z\): Spatial size (height × width)
Class Activation Map:
\[M_{ij}^c = \sum_k w_k^c A_{ij}^k\]
Interpretation: \(M^c\) highlights regions most discriminative for class c
Limitation: Requires GAP architecture, not applicable to arbitrary networks
3.2 Grad-CAM (Gradient-weighted CAM)
Key innovation: Use gradients to compute importance weights, removing the architecture constraint
Gradient-based weights:
\[\alpha_k^c = \frac{1}{Z} \sum_{i,j} \frac{\partial S^c}{\partial A_{ij}^k}\]
Intuition: \(\alpha_k^c\) measures how much feature map k contributes to class c
Grad-CAM:
\[L_{\text{Grad-CAM}}^c = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)\]
ReLU: Keeps only positive influences (features increasing the class score)
Why this works: By the chain rule, \(\frac{\partial S^c}{\partial A_{ij}^k} = \sum_z \frac{\partial S^c}{\partial z} \frac{\partial z}{\partial A_{ij}^k}\), where \(z\) are downstream activations. Averaging spatially gives a global importance weight per feature map.
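The Grad-CAM computation above can be sketched with forward and backward hooks; the tiny CNN and the choice of hooked layer here are illustrative stand-ins, not a trained model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Grad-CAM sketch on a tiny untrained CNN. A forward hook stores the conv
# feature maps; a backward hook stores their gradients.
conv = nn.Conv2d(3, 8, 3, padding=1)
model = nn.Sequential(conv, nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 5))
model.eval()

feats, grads = {}, {}
conv.register_forward_hook(lambda m, i, o: feats.update(a=o))
conv.register_full_backward_hook(lambda m, gi, go: grads.update(a=go[0]))

x = torch.rand(1, 3, 16, 16)
logits = model(x)
logits[0, logits.argmax(1).item()].backward()

alpha = grads['a'].mean(dim=[2, 3], keepdim=True)  # spatially averaged gradients
cam = F.relu((alpha * feats['a']).sum(dim=1))      # weighted sum over channels + ReLU
cam = (cam / (cam.max() + 1e-8)).detach()          # normalize to [0, 1]
print(cam.shape)
```

For a real network the map would then be upsampled to the input resolution and overlaid on the image.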
3.3 Grad-CAM++ (Improved Localization)
Problem: Grad-CAM produces coarse localization, especially for multiple objects
Solution: Weighted combination of pixel-wise gradients
Pixel-wise weights:
\[\alpha_{ij}^{kc} = \frac{\frac{\partial^2 S^c}{(\partial A_{ij}^k)^2}}{2 \frac{\partial^2 S^c}{(\partial A_{ij}^k)^2} + \sum_{a,b} A_{ab}^k \frac{\partial^3 S^c}{(\partial A_{ij}^k)^3}}, \qquad w_k^c = \sum_{i,j} \alpha_{ij}^{kc} \, \text{ReLU}\!\left(\frac{\partial S^c}{\partial A_{ij}^k}\right)\]
Advantages:
Better localization for multiple occurrences
Handles multiple objects of same class
Provides confidence through gradient magnitude
3.4 Score-CAM (Gradient-Free)
Motivation: Gradients can be noisy and require backpropagation
Approach: Forward-pass only, using activation maps as masks
Algorithm:
For each feature map \(A^k\), upsample to input size: \(M^k\)
Compute masked input: \(X_k = X \odot M^k\) (element-wise product)
Forward pass through network: \(S_k = f(X_k)\)
Weight by increase in target class score: \(\alpha_k^c = S_k^c - f(X_{\text{baseline}})^c\)
Score-CAM:
\[L_{\text{Score-CAM}}^c = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)\]
Advantages:
No gradients needed (faster, more stable)
Works with any differentiable or non-differentiable model
Less sensitive to gradient saturation
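The four algorithm steps above can be sketched with forward passes only. This is a simplified illustration on an untrained stand-in model: the paper's softmax normalization of the weights is omitted, and the feature maps already match the input resolution here, so no upsampling step is needed:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Score-CAM sketch: mask the input with each normalized feature map and weight
# it by the increase of the target-class score over a zero baseline.
conv = nn.Conv2d(3, 4, 3, padding=1)
model = nn.Sequential(conv, nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 5))
model.eval()

x = torch.rand(1, 3, 16, 16)
with torch.no_grad():
    acts = conv(x)                                      # (1, K, 16, 16) feature maps
    target = model(x).argmax(1).item()
    base_score = model(torch.zeros_like(x))[0, target]  # baseline score
    cam = torch.zeros(16, 16)
    for k in range(acts.size(1)):
        m = acts[0, k]
        m = (m - m.min()) / (m.max() - m.min() + 1e-8)  # normalize mask to [0, 1]
        score = model(x * m)[0, target] - base_score    # increase over baseline
        cam += score * m
    cam = F.relu(cam)
print(cam.shape)
```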
4. Perturbation-Based Methods
4.1 Occlusion Sensitivity
Idea: Occlude parts of the input and measure the change in output
Procedure:
Slide a patch (e.g., a 5×5 gray square) over the input
At each position, compute the drop in output
Higher drop → more important region
Advantages:
Model-agnostic
Easy to understand
Limitations:
Computationally expensive: \(O(H \times W)\) forward passes
Patch size affects results
Doesn't handle feature interactions
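The sliding-patch procedure above can be sketched as follows; the model, patch size, and gray value 0.5 are illustrative stand-ins:

```python
import torch
import torch.nn as nn

# Occlusion-sensitivity sketch: slide a gray patch over the input and record
# the drop in the target-class score. The linear model is an untrained stand-in.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 5))
model.eval()

x = torch.rand(1, 3, 16, 16)
with torch.no_grad():
    target = model(x).argmax(1).item()
    full_score = model(x)[0, target].item()
    patch, stride = 4, 4
    heatmap = torch.zeros(16 // stride, 16 // stride)
    for i in range(0, 16, stride):
        for j in range(0, 16, stride):
            occluded = x.clone()
            occluded[:, :, i:i + patch, j:j + patch] = 0.5  # gray patch
            drop = full_score - model(occluded)[0, target].item()
            heatmap[i // stride, j // stride] = drop        # larger drop = more important
print(heatmap.shape)
```

Each cell of `heatmap` corresponds to one patch position; note the quadratic number of forward passes mentioned above.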
4.2 RISE (Randomized Input Sampling for Explanation)
Approach: Generate random masks, aggregate weighted by output
Algorithm:
Generate N random binary masks: \(\{M_1, ..., M_N\}\)
For each mask, forward pass: \(f(x \odot M_i)\)
Weight masks by output: \(w_i = f(x \odot M_i)^c\)
Aggregate: \(S^c = \frac{1}{N} \sum_i w_i \cdot M_i\)
Advantages:
Smooth saliency maps
Probabilistic interpretation
Handles complex interactions
Hyperparameters:
N: Number of masks (typically 2000-8000)
Mask size: Controls granularity
Probability: Fraction of pixels kept (typically 0.5)
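The RISE aggregation can be sketched in a few lines; N is far below the paper's 2000-8000 range to keep the example fast, and the per-pixel masks skip the paper's low-resolution upsampling step:

```python
import torch
import torch.nn as nn

# RISE sketch: random binary masks, aggregated with weights given by the
# masked-input class score. The model is an untrained stand-in.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 5))
model.eval()

x = torch.rand(1, 3, 16, 16)
N, p = 200, 0.5                                    # number of masks, keep probability
with torch.no_grad():
    target = model(x).argmax(1).item()
    saliency = torch.zeros(16, 16)
    for _ in range(N):
        mask = (torch.rand(16, 16) < p).float()    # keep each pixel with probability p
        score = model(x * mask)[0, target].item()
        saliency += score * mask
    saliency /= N * p                               # normalize by expected pixel coverage
print(saliency.shape)
```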
5. SHAP (SHapley Additive exPlanations)
5.1 Shapley Values from Game Theory
Setup: Cooperative game with players (features) and a value function (model output)
Shapley value for feature i:
\[\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|! \, (|N| - |S| - 1)!}{|N|!} \left[ v(S \cup \{i\}) - v(S) \right]\]
Where:
N: Set of all features
S: Subset of features
v(S): Model output with features in S
\(v(S \cup \{i\}) - v(S)\): Marginal contribution of feature i to subset S
Axioms satisfied:
Efficiency: \(\sum_i \phi_i = v(N) - v(\emptyset)\) (complete attribution)
Symmetry: If features i, j contribute equally, \(\phi_i = \phi_j\)
Dummy: If feature has no effect, \(\phi_i = 0\)
Additivity: A linear combination of games → a linear combination of values
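The Shapley formula and the efficiency axiom can be checked exactly on a toy game; the value function `v` below is an arbitrary illustrative three-player game, not derived from any model:

```python
from itertools import combinations
from math import factorial

# Exact Shapley values for a toy 3-feature cooperative game, verifying the
# efficiency axiom. The payoffs in `base` are arbitrary illustrative numbers.
def v(S):
    base = {frozenset(): 0.0, frozenset({0}): 1.0, frozenset({1}): 2.0,
            frozenset({2}): 0.0, frozenset({0, 1}): 4.0, frozenset({0, 2}): 1.5,
            frozenset({1, 2}): 2.5, frozenset({0, 1, 2}): 5.0}
    return base[frozenset(S)]

def shapley(i, n=3):
    players = set(range(n)) - {i}
    total = 0.0
    for size in range(n):
        for S in combinations(players, size):
            # Weight |S|! (n - |S| - 1)! / n! from the formula above
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            total += weight * (v(set(S) | {i}) - v(S))
    return total

phi = [shapley(i) for i in range(3)]
# Efficiency: attributions sum to v(N) - v(empty set) = 5.0
print(phi, sum(phi))
```

Scaling this exact enumeration to d features costs \(O(2^d)\) evaluations of v, which is why the approximations in 5.2 exist.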
5.2 SHAP for Deep Networks
Challenge: Computing Shapley values requires \(2^n\) evaluations (exponential in features)
DeepSHAP (approximation): Uses backpropagation-like approach to approximate Shapley values
Linear SHAP (for linear models):
\[\phi_i = w_i \, (x_i - \mathbb{E}[x_i])\]
Simple: Deviation from the mean × weight
Kernel SHAP (model-agnostic):
Weighted linear regression to approximate Shapley values:
Where:
\(z\): Binary vector indicating feature presence
\(\pi(z)\): SHAP kernel weight \(\frac{|N|-1}{\binom{|N|}{|z|} |z| (|N| - |z|)}\)
\(h_x(z)\): Maps binary vector to feature values
Complexity: \(O(2^n)\) exact, \(O(n^2)\) with sampling approximations
6. LIME (Local Interpretable Model-agnostic Explanations)
6.1 Core Idea
Approximate the complex model locally with an interpretable model (linear model, decision tree)
Optimization objective:
\[\xi(x) = \arg\min_{g \in G} \; \mathcal{L}(f, g, \pi_x) + \Omega(g)\]
Where:
\(g \in G\): Interpretable model (e.g., linear)
\(\mathcal{L}\): Fidelity loss (how well g approximates f locally)
\(\pi_x\): Proximity measure (kernel, e.g., exponential)
\(\Omega(g)\): Complexity penalty (e.g., number of features)
6.2 LIME Algorithm
Sample: Generate N perturbed samples around x: \(\{x_1', ..., x_N'\}\)
Evaluate: Get model predictions: \(\{f(x_1'), ..., f(x_N')\}\)
Weight: Compute proximity: \(\pi_x(x_i') = \exp(-D(x, x_i')^2 / \sigma^2)\)
Fit: Train interpretable model g on weighted samples
Explain: Extract coefficients or rules from g
For images:
Segment image into superpixels (e.g., SLIC algorithm)
Perturb by turning superpixels on/off
Fit linear model to predict class probability
Example linear model:
\[g(z) = w_0 + \sum_i w_i z_i\]
Top-k positive \(w_i\) indicate the most important superpixels
6.3 Comparison: LIME vs SHAP
| Property | LIME | SHAP |
|---|---|---|
| Framework | Local surrogate | Shapley values |
| Guarantees | None (heuristic) | Efficiency, symmetry |
| Sampling | Random perturbations | Coalitional |
| Consistency | Not guaranteed | Guaranteed |
| Speed | Fast | Slower (exact) |
| Use case | Quick explanations | Rigorous attribution |
7. Attention Mechanisms as Interpretability
7.1 Attention Weights
For Transformer models: Attention weight \(\alpha_{ij}\) indicates "how much position j contributes to position i"
Naive interpretation: High \(\alpha_{ij}\) → token j is important for token i
Problem: Attention is NOT explanation
Attention weights don't necessarily indicate importance
Multiple heads can have different patterns
Downstream layers complicate interpretation
7.2 Attention Rollout
Idea: Propagate attention through layers to see token-to-token flow
Recursive formula:
\[A_{\text{rollout}}^{(l)} = \bar{A}^{(l)} \cdot A_{\text{rollout}}^{(l-1)}\]
Where \(\bar{A}^{(l)}\) is the attention at layer l (averaged over heads, with the identity added to account for residual connections and rows renormalized)
Initialization: \(A_{\text{rollout}}^{(0)} = I\) (identity)
Final attribution: The last layer's rolled-out attention \(A_{\text{rollout}}^{(L)}\)
7.3 Attention Flow
Max flow through attention graph:
Compute maximum information flow from input to output through attention edges
Algorithm:
Build graph: Nodes = tokens at each layer, edges = attention weights
Run max-flow algorithm from source (input token) to sink (output)
Flow value = attribution
8. Concept-Based Interpretability
8.1 TCAV (Testing with Concept Activation Vectors)
Motivation: Gradients/attention show low-level features, not high-level concepts
Idea: Test whether the model uses human-defined concepts (e.g., "stripes" for zebra)
Concept Activation Vector (CAV): Normal vector \(v_C\) of a linear classifier separating concept examples from random examples in activation space
TCAV Score: Fraction of examples where the directional derivative along the CAV is positive:
\[\text{TCAV}_{C,c,l} = \frac{\left|\{x \in X_c : \nabla h_{l,c}(x) \cdot v_C > 0\}\right|}{|X_c|}\]
Where:
\(X_c\): Examples of class c
\(h_l\): Activations at layer l; \(h_{l,c}\) maps these activations to the class-c logit
Interpretation: TCAV = 0.8 means the concept positively influences the class-c prediction for 80% of class-c examples
8.2 Network Dissection
Goal: Identify what individual neurons detect
Approach:
Collect semantic segmentation dataset (labels for concepts: sky, grass, car)
For each neuron, compute IoU with each concept
Assign the neuron to its highest-IoU concept (if IoU > threshold)
\[\text{IoU}_{n,c} = \frac{\left|\{p : A_p^n > t_n\} \cap \{p : L_p = c\}\right|}{\left|\{p : A_p^n > t_n\} \cup \{p : L_p = c\}\right|}\]
Where:
\(A^n\): Activation map of neuron n
\(t_n\): Activation threshold (top quantile, e.g., 0.05)
\(L_p = c\): Pixel p labeled as concept c
Results: Neurons specialize to detect objects, textures, and parts (e.g., "dog face detector")
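The IoU scoring step can be sketched with NumPy; the activation map and the concept mask below are random illustrative stand-ins for a real neuron and a real segmentation label:

```python
import numpy as np

# IoU between a thresholded activation map and a concept segmentation mask,
# as in Network Dissection. Both arrays are illustrative stand-ins.
act = np.random.rand(8, 8)              # neuron activation map (stand-in)
concept_mask = np.zeros((8, 8), dtype=bool)
concept_mask[:4, :] = True              # "concept" occupies the top half

t = np.quantile(act, 0.95)              # top-quantile activation threshold
act_mask = act > t                      # pixels where the neuron fires strongly

inter = np.logical_and(act_mask, concept_mask).sum()
union = np.logical_or(act_mask, concept_mask).sum()
iou = inter / union                     # overlap score in [0, 1]
print(iou)
```

In the full method this score is computed for every (neuron, concept) pair over a dataset, and a neuron is labeled with its best concept when the IoU clears the threshold.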
9. Counterfactual Explanations
9.1 Definition
"What minimal change to the input would flip the model's prediction?"
\[x^* = \arg\min_{x'} D(x, x') \quad \text{s.t.} \quad f(x') \neq f(x)\]
Where:
\(D(x, x')\): Distance metric (e.g., L2, L1, or a feature-space distance)
Constraint: The prediction changes
9.2 Optimization Approach
Loss function:
\[\mathcal{L}(x') = \lambda_1 \cdot D(x, x') + \lambda_2 \cdot \max\!\big(0, \; f_{y}(x') - \max_{c \neq y} f_c(x') + \kappa\big)\]
Where \(y = f(x)\) is the original class and:
First term: Minimize distance to the original input
Second term: Push the prediction to a different class (margin \(\kappa\))
Algorithm:
Initialize: \(x' = x\)
Gradient descent on \(\mathcal{L}(x')\)
Project to valid input space (e.g., [0,1] for images)
Stop when \(f(x') \neq f(x)\)
9.3 Diverse Counterfactuals
Generate multiple diverse counterfactuals:
Diversity loss: e.g., penalize similarity between counterfactual pairs, \(\mathcal{L}_{\text{div}} = -\sum_{i \neq j} D(x_i', x_j')\)
Encourages counterfactuals to differ from each other
10. Faithfulness and Evaluation
10.1 Deletion/Insertion Metrics
Deletion: Remove features in order of importance and measure the output drop
Good attribution: Output drops quickly
Insertion: Add features in order of importance (starting from a blank input) and measure the output rise
Good attribution: Output rises quickly
Area Under Curve (AUC):
Deletion AUC: Lower is better (rapid drop)
Insertion AUC: Higher is better (rapid rise)
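Turning a deletion or insertion curve into an AUC is a one-line trapezoidal integral. The score arrays below are illustrative stand-ins for curves produced by the FaithfulnessEvaluator implementation later in this notebook:

```python
import numpy as np

# AUC of deletion/insertion curves via the trapezoidal rule.
# The two score arrays are illustrative stand-ins, not real model outputs.
deletion_scores = np.array([0.9, 0.4, 0.2, 0.1, 0.05])   # output as top pixels are removed
insertion_scores = np.array([0.05, 0.3, 0.6, 0.8, 0.9])  # output as top pixels are added

def curve_auc(scores):
    x = np.linspace(0, 1, len(scores))         # fraction of pixels changed
    mids = (scores[1:] + scores[:-1]) / 2      # trapezoid midpoints
    return float(np.sum(mids * np.diff(x)))

print(curve_auc(deletion_scores))   # low for a good attribution (rapid drop)
print(curve_auc(insertion_scores))  # high for a good attribution (rapid rise)
```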
10.2 Sensitivity-n
Definition: Maximum change in attribution under an n-pixel perturbation:
\[\text{Sens}_n = \max_{|S| = n} \left\| \phi(x) - \phi(x_S) \right\|\]
Where \(x_S\) has the pixels in set S perturbed
Lower is better: Attribution should not change drastically with small input changes
10.3 Infidelity
Measures: How well the attribution approximates actual model behavior under perturbation:
\[\text{INFD}(\phi, f, x) = \mathbb{E}_{I}\left[\left(I^\top \phi(x) - \big(f(x) - f(x - I)\big)\right)^2\right]\]
Where:
\(I\): Random perturbation
\(\phi(x)\): Attribution
Expectation is taken over perturbations
Lower is better: The effect predicted by the attribution should match the actual effect of perturbations
11. Practical Considerations
11.1 Method Selection Guide
| Goal | Method | Pros | Cons |
|---|---|---|---|
| Quick visualization | Grad-CAM | Fast, class-specific | Coarse resolution |
| Pixel-level attribution | IG, SmoothGrad | Detailed | Baseline-dependent |
| Model-agnostic | LIME, SHAP | Any model | Slow, approximations |
| Conceptual | TCAV | Human concepts | Needs concept datasets |
| Debugging | Activation max, Dissection | Direct neuron insight | Requires analysis |
11.2 Common Pitfalls
Confusing correlation with causation: Attribution ≠ causal importance
Ignoring baseline choice: IG results depend heavily on the baseline
Over-interpreting attention: Attention weights ≠ feature importance
Not validating: Always check faithfulness metrics
Cherry-picking examples: Show failures, not just successes
11.3 Hyperparameters
Integrated Gradients:
Steps: 50-100 (more = smoother but slower)
Baseline: Zero, mean, blur, or black image
LIME:
Samples: 1000-5000
Kernel width: 0.25 × \(\sqrt{n\_features}\)
Superpixels (images): 50-200
SHAP:
Background samples: 50-100 (DeepSHAP)
Max evaluations: 2000 (KernelSHAP)
SmoothGrad:
Noise std: 0.1-0.2 × (max - min)
Samples: 20-50
12. Advanced Topics and Research Frontiers
12.1 Robustness of Explanations
Problem: Explanations can be fragile to adversarial attacks
Adversarial explanation attack: Perturb the input so that the prediction stays the same while the explanation changes
Same prediction, different explanation → undermines trust
Defense: Robust attribution methods (e.g., robust integrated gradients)
12.2 Multi-Modal Explanations
For vision-language models (CLIP, GPT-4V):
Cross-modal attention: How image regions align with text tokens
Grad-CAM for vision encoder + attention weights for text
Integrated multimodal gradients
12.3 Explanations for Generated Outputs
Diffusion models: Which parts of the noise contribute to the final image?
LLMs: Token-level attribution for generated text
Challenge: Long generation chains complicate credit assignment
12.4 Interactive Explanations
What-if analysis: User perturbs the input and sees the effect on prediction and explanation
Counterfactual exploration: Navigate alternative scenarios
Concept activation: Manually adjust concept importance
Key Papers and Timeline
Foundation (2013-2016):
Simonyan et al. 2013: Saliency Maps - First gradient-based visualization
Zhou et al. 2016: CAM - Class activation mapping with GAP
Ribeiro et al. 2016: LIME - Local interpretable models
Gradient Methods (2017-2018):
Sundararajan et al. 2017: Integrated Gradients - Axioms and path integration
Selvaraju et al. 2017: Grad-CAM - Gradient-weighted activation maps
Smilkov et al. 2017: SmoothGrad - Noise-based gradient smoothing
Game Theory (2017-2020):
Lundberg & Lee 2017: SHAP - Shapley values for ML
Chattopadhyay et al. 2018: Grad-CAM++ - Improved localization
Wang et al. 2020: Score-CAM - Gradient-free activation maps
Concepts & Validation (2018-2022):
Kim et al. 2018: TCAV - Concept activation vectors
Bau et al. 2017: Network Dissection - Neuron semantic analysis
Hooker et al. 2019: Deletion/Insertion - Faithfulness metrics
Recent (2020-2024):
Chefer et al. 2021: Transformer Interpretability - Attention rollout
Achtibat et al. 2023: Attribution Robustness - Adversarial explanations
Diverse counterfactuals and causal interpretability
Computational Complexity:
Grad-CAM: \(O(1)\) (single backward pass)
Integrated Gradients: \(O(m)\) where m = steps
LIME: \(O(n \cdot k)\) where n = samples, k = model evals
SHAP: \(O(2^d)\) exact, \(O(d^2)\) approximation (d = features)
"""
Advanced Interpretability Implementations
This cell provides production-ready implementations of:
1. SHAP (KernelSHAP, DeepSHAP approximation)
2. LIME for images
3. Attention visualization (rollout, flow)
4. Concept Activation Vectors (TCAV)
5. Counterfactual generation
6. Faithfulness evaluation metrics
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.svm import LinearSVC
from skimage.segmentation import slic
from scipy.special import comb
import warnings
warnings.filterwarnings('ignore')
# ============================================================================
# SHAP Implementation
# ============================================================================
class KernelSHAP:
"""
Model-agnostic SHAP using weighted linear regression
Theory:
- Approximates Shapley values by fitting weighted linear model
- SHAP kernel weight: π(z) = (M-1) / (C(M,|z|) * |z| * (M-|z|))
- Where M = number of features, |z| = number of present features
"""
def __init__(self, model, background_data, num_samples=1000):
"""
Args:
model: Prediction function (input -> output)
background_data: Reference samples (N, D) for creating coalitions
num_samples: Number of coalitional samples to generate
"""
self.model = model
self.background = background_data
self.num_samples = num_samples
self.M = background_data.shape[1] # Number of features
def kernel_weight(self, z):
"""SHAP kernel weight for coalition z (binary vector)"""
num_present = z.sum()
if num_present == 0 or num_present == self.M:
return 1e10 # High weight for empty/full coalitions
return (self.M - 1) / (comb(self.M, num_present) * num_present * (self.M - num_present))
def explain(self, x, verbose=False):
"""
Compute SHAP values for input x
Returns:
shap_values: (D,) array of feature attributions
"""
# Sample coalitions (binary vectors indicating feature presence)
coalitions = np.random.binomial(1, 0.5, size=(self.num_samples, self.M))
# Ensure we have empty and full coalitions
coalitions[0] = np.zeros(self.M)
coalitions[1] = np.ones(self.M)
# Create masked samples
masked_samples = []
for z in coalitions:
# For each coalition, take x where the feature is present, background otherwise
sample = np.where(z[None, :] == 1, x, self.background)
masked_samples.append(sample)
masked_samples = np.array(masked_samples)  # (num_samples, num_background, M)
# Evaluate model on masked samples
with torch.no_grad():
# Average predictions over background samples
predictions = []
for samples in masked_samples:
preds = self.model(torch.tensor(samples, dtype=torch.float32))
predictions.append(preds.mean().item())
predictions = np.array(predictions)
# Compute kernel weights
weights = np.array([self.kernel_weight(z) for z in coalitions])
# Weighted linear regression: f(z) ≈ φ₀ + Σ φᵢ zᵢ
model = Ridge(alpha=0.01)
model.fit(coalitions, predictions, sample_weight=weights)
# SHAP values are the coefficients
shap_values = model.coef_
if verbose:
print(f"Expected value (intercept): {model.intercept_:.4f}")
print(f"Prediction (full coalition): {predictions[1]:.4f}")
print(f"Sum of SHAP values: {shap_values.sum():.4f}")
print(f"Difference (should match): {predictions[1] - predictions[0]:.4f}")
return shap_values
class DeepSHAP:
"""
DeepSHAP: Efficient SHAP approximation for neural networks
Theory:
- Uses backpropagation-like approach to compute approximate Shapley values
- Compares activation with reference (background) activation
- Distributes attributions through network layers
"""
def __init__(self, model, background_data):
"""
Args:
model: PyTorch neural network
background_data: Reference samples (N, C, H, W)
"""
self.model = model
self.background = background_data
self.model.eval()
# Compute reference activations
with torch.no_grad():
self.ref_output = model(background_data).mean(0)
def explain(self, x, target_class=None):
"""
Compute DeepSHAP attributions
Approximation: Use Integrated Gradients as proxy for Shapley values
IG satisfies efficiency axiom (completeness) like SHAP
"""
x = x.requires_grad_()
# Compute average baseline
baseline = self.background.mean(0, keepdim=True)
# Integrated Gradients (fast approximation to SHAP)
num_steps = 50
attributions = torch.zeros_like(x)
for alpha in np.linspace(0, 1, num_steps):
x_interp = baseline + alpha * (x - baseline)
x_interp.requires_grad_()
output = self.model(x_interp)
if target_class is None:
target_class = output.argmax(1)
output[0, target_class].backward()
attributions += x_interp.grad
attributions = attributions * (x - baseline) / num_steps
return attributions.detach()
# ============================================================================
# LIME Implementation
# ============================================================================
class LIME:
"""
Local Interpretable Model-agnostic Explanations
Theory:
- Approximates complex model f locally with interpretable model g
- Samples around instance x, weights by proximity
- Fits linear model: g(z) = w₀ + Σ wᵢ zᵢ
"""
def __init__(self, model, kernel_width=0.25):
"""
Args:
model: Prediction function
kernel_width: Controls locality (smaller = more local)
"""
self.model = model
self.kernel_width = kernel_width
def explain_image(self, image, num_samples=1000, num_features=10, num_superpixels=100):
"""
Explain image classification using superpixels
Args:
image: Input image (C, H, W) tensor
num_samples: Number of perturbations
num_features: Number of top features to return
num_superpixels: Number of image segments
Returns:
weights: Importance of each superpixel
segments: Superpixel segmentation
"""
# Convert to numpy for segmentation
img_np = image.permute(1, 2, 0).cpu().numpy()
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
# Segment image into superpixels using SLIC
segments = slic(img_np, n_segments=num_superpixels, compactness=10, sigma=1, start_label=0)
num_segments = len(np.unique(segments))
# Sample perturbations (binary vectors indicating superpixel presence)
perturbations = np.random.binomial(1, 0.5, size=(num_samples, num_segments))
# Evaluate model on perturbed images
predictions = []
distances = []
for pert in perturbations:
# Create perturbed image
perturbed = image.clone()
for i, keep in enumerate(pert):
if keep == 0:
# Zero out superpixel (black)
mask = segments == i
perturbed[:, mask] = 0
# Get prediction
with torch.no_grad():
pred = self.model(perturbed.unsqueeze(0))
predictions.append(pred.softmax(1).cpu().numpy()[0])
# Compute distance (fraction of superpixels changed)
distance = np.sum(pert != 1) / num_segments
distances.append(distance)
predictions = np.array(predictions)
distances = np.array(distances)
# Compute kernel weights (exponential kernel)
kernel_width = self.kernel_width * np.sqrt(num_segments)
weights = np.exp(-(distances ** 2) / (kernel_width ** 2))
# Fit weighted linear regression for each class
all_weights = []
for class_idx in range(predictions.shape[1]):
y = predictions[:, class_idx]
# Weighted least squares
model = Ridge(alpha=1.0)
model.fit(perturbations, y, sample_weight=weights)
all_weights.append(model.coef_)
# Return weights for predicted class
pred_class = self.model(image.unsqueeze(0)).argmax(1).item()
superpixel_weights = all_weights[pred_class]
return superpixel_weights, segments
def visualize_explanation(self, image, weights, segments, num_features=5):
"""Visualize top positive and negative superpixels"""
# Get top positive features
top_positive = np.argsort(weights)[-num_features:]
# Create mask
mask = np.zeros(segments.shape, dtype=bool)
for idx in top_positive:
mask |= (segments == idx)
# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# Original image
img_np = image.permute(1, 2, 0).cpu().numpy()
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
axes[0].imshow(img_np)
axes[0].set_title("Original Image")
axes[0].axis('off')
# Superpixel boundaries
from skimage.segmentation import mark_boundaries
axes[1].imshow(mark_boundaries(img_np, segments))
axes[1].set_title(f"Superpixels (n={len(np.unique(segments))})")
axes[1].axis('off')
# Highlighted regions
highlighted = img_np.copy()
highlighted[~mask] = highlighted[~mask] * 0.3 # Dim non-important regions
axes[2].imshow(highlighted)
axes[2].set_title(f"Top {num_features} Important Regions")
axes[2].axis('off')
plt.tight_layout()
return fig
# ============================================================================
# Attention Visualization
# ============================================================================
class AttentionRollout:
"""
Attention Rollout for Vision Transformers
Theory:
- Recursively multiply attention matrices through layers
- A_rollout^(l) = A^(l) @ A_rollout^(l-1)
- Captures token-to-token information flow
"""
def __init__(self, model, head_fusion='mean', discard_ratio=0.1):
"""
Args:
model: Vision Transformer with attention weights
head_fusion: How to combine multi-head attention ('mean', 'max', 'min')
discard_ratio: Fraction of lowest attention to discard
"""
self.model = model
self.head_fusion = head_fusion
self.discard_ratio = discard_ratio
self.attentions = []
# Register hooks to capture attention weights
self._register_hooks()
def _register_hooks(self):
"""Register forward hooks to capture attention matrices"""
def hook_fn(module, input, output):
# Assuming attention is returned as second element
if isinstance(output, tuple) and len(output) > 1:
self.attentions.append(output[1])
# This is model-specific; adapt to your architecture
# For ViT: hook into each transformer block's attention module
for module in self.model.modules():
if hasattr(module, 'attn') or 'attention' in module.__class__.__name__.lower():
module.register_forward_hook(hook_fn)
def compute_rollout(self, input_tensor):
"""
Compute attention rollout
Returns:
rollout: (num_patches, num_patches) attention flow matrix
"""
self.attentions = []
# Forward pass to collect attentions
with torch.no_grad():
_ = self.model(input_tensor)
# Fuse multi-head attention
fused_attentions = []
for attn in self.attentions:
# attn shape: (batch, heads, tokens, tokens)
if self.head_fusion == 'mean':
attn_fused = attn.mean(dim=1)
elif self.head_fusion == 'max':
attn_fused = attn.max(dim=1)[0]
elif self.head_fusion == 'min':
attn_fused = attn.min(dim=1)[0]
else:
raise ValueError(f"Unknown fusion: {self.head_fusion}")
fused_attentions.append(attn_fused)
# Rollout: multiply attention matrices
rollout = torch.eye(fused_attentions[0].size(1)).to(input_tensor.device)
for attn in fused_attentions:
# Add residual connection
attn = attn + torch.eye(attn.size(1)).to(attn.device)
attn = attn / attn.sum(dim=-1, keepdim=True)
# Multiply
rollout = torch.matmul(attn[0], rollout)
# Focus on CLS token (first token) to all patches
rollout_cls = rollout[0, 1:] # Exclude CLS itself
return rollout_cls
# ============================================================================
# Concept Activation Vectors (TCAV)
# ============================================================================
class TCAV:
"""
Testing with Concept Activation Vectors
Theory:
- Train linear classifier to separate concept from random examples
- CAV = normal vector to decision boundary
- TCAV score = fraction of examples where ∇h·CAV > 0
"""
def __init__(self, model, layer_name):
"""
Args:
model: Neural network
layer_name: Name of layer to extract activations from
"""
self.model = model
self.layer_name = layer_name
self.activations = None
# Register hook
self._register_hook()
def _register_hook(self):
"""Capture activations at specified layer"""
def hook_fn(module, input, output):
self.activations = output.detach()
# Find layer by name
for name, module in self.model.named_modules():
if name == self.layer_name:
module.register_forward_hook(hook_fn)
return
raise ValueError(f"Layer {self.layer_name} not found")
def get_activations(self, images):
"""Extract activations for images"""
activations = []
with torch.no_grad():
for img in images:
_ = self.model(img.unsqueeze(0))
# Global average pool if spatial
act = self.activations
if act.dim() == 4: # (B, C, H, W)
act = act.mean(dim=[2, 3])
activations.append(act.cpu().numpy().flatten())
return np.array(activations)
def train_cav(self, concept_images, random_images):
"""
Train Concept Activation Vector
Args:
concept_images: Images containing concept (e.g., "stripes")
random_images: Random images (negative examples)
Returns:
cav: Concept activation vector (normal to decision boundary)
"""
# Get activations
concept_acts = self.get_activations(concept_images)
random_acts = self.get_activations(random_images)
# Prepare data
X = np.vstack([concept_acts, random_acts])
y = np.array([1] * len(concept_acts) + [0] * len(random_acts))
# Train linear classifier
clf = LinearSVC(C=1.0, max_iter=5000)
clf.fit(X, y)
# CAV is the normal vector (coefficients)
cav = clf.coef_[0]
cav = cav / np.linalg.norm(cav) # Normalize
return cav, clf.score(X, y)
def compute_tcav_score(self, class_images, cav, target_class):
"""
Compute TCAV score: fraction of class examples whose class-logit gradient
(taken w.r.t. the layer activations) aligns with the CAV
Args:
class_images: Images of target class
cav: Concept activation vector
target_class: Index of the class whose logit is differentiated
Returns:
tcav_score: Value between 0 and 1
"""
cav_tensor = torch.tensor(cav, dtype=torch.float32)
# Temporary hook that keeps the graph (the default hook detaches activations)
captured = {}
layer = dict(self.model.named_modules())[self.layer_name]
handle = layer.register_forward_hook(lambda m, i, o: captured.update(act=o))
alignments = []
for img in class_images:
logits = self.model(img.unsqueeze(0))
act = captured['act']
# Gradient of the class logit w.r.t. the layer activations
grad = torch.autograd.grad(logits[0, target_class], act)[0]
if grad.dim() == 4:
grad = grad.mean(dim=[2, 3])
# Directional derivative along the CAV: ∇h·CAV
directional_derivative = (grad.flatten() * cav_tensor).sum()
alignments.append(1 if directional_derivative.item() > 0 else 0)
handle.remove()
tcav_score = np.mean(alignments)
return tcav_score
# ============================================================================
# Counterfactual Generation
# ============================================================================
class CounterfactualGenerator:
"""
Generate minimal perturbations that flip model prediction
Theory:
- Minimize: λ₁·||δ|| + λ₂·max(0, f(x+δ)_orig - max_{c≠orig} f(x+δ)_c + κ)
- First term: Keep change small
- Second term: Push to different class
"""
def __init__(self, model, lambda_dist=1.0, lambda_class=10.0, kappa=0.0):
"""
Args:
model: Classification model
lambda_dist: Weight for distance term
lambda_class: Weight for classification term
kappa: Confidence margin
"""
self.model = model
self.lambda_dist = lambda_dist
self.lambda_class = lambda_class
self.kappa = kappa
def generate(self, x, target_class=None, num_steps=500, lr=0.01):
"""
Generate counterfactual
Args:
x: Original input (C, H, W)
target_class: Target class (None = any different class)
num_steps: Optimization steps
lr: Learning rate
Returns:
x_cf: Counterfactual input
history: Loss history
"""
# Get original prediction
with torch.no_grad():
orig_pred = self.model(x.unsqueeze(0)).argmax(1).item()
# Initialize perturbation
delta = torch.zeros_like(x, requires_grad=True)
optimizer = torch.optim.Adam([delta], lr=lr)
history = {'loss': [], 'dist': [], 'class_loss': []}
for step in range(num_steps):
optimizer.zero_grad()
# Perturbed input (clip to valid range)
x_pert = torch.clamp(x + delta, 0, 1)
# Get predictions
logits = self.model(x_pert.unsqueeze(0))[0]
# Distance term (L2)
dist_loss = torch.norm(delta)
# Classification term
if target_class is not None:
# Targeted: maximize target class score
class_loss = -logits[target_class]
else:
# Untargeted: minimize original class, maximize others
orig_score = logits[orig_pred]
other_scores = torch.cat([logits[:orig_pred], logits[orig_pred+1:]])
max_other = other_scores.max()
# Hinge loss: want orig_score < max_other - kappa
class_loss = F.relu(orig_score - max_other + self.kappa)
# Combined loss
loss = self.lambda_dist * dist_loss + self.lambda_class * class_loss
loss.backward()
optimizer.step()
# Record
history['loss'].append(loss.item())
history['dist'].append(dist_loss.item())
history['class_loss'].append(class_loss.item())
# Check if successful
with torch.no_grad():
new_pred = self.model(x_pert.unsqueeze(0)).argmax(1).item()
if new_pred != orig_pred:
if target_class is None or new_pred == target_class:
print(f"Counterfactual found at step {step}")
print(f"Original class: {orig_pred}, New class: {new_pred}")
print(f"L2 distance: {dist_loss.item():.4f}")
break
x_cf = torch.clamp(x + delta, 0, 1).detach()
return x_cf, history
# ============================================================================
# Faithfulness Metrics
# ============================================================================
class FaithfulnessEvaluator:
    """
    Evaluate attribution faithfulness using deletion/insertion metrics

    Theory:
    - Deletion: Remove features by importance, measure output drop
    - Insertion: Add features by importance, measure output rise
    - Good attribution → sharp curves
    """
    def __init__(self, model):
        self.model = model

    def deletion_metric(self, x, attribution, target_class, num_steps=20):
        """
        Deletion curve: Remove top features, track output

        Args:
            x: Input image (C, H, W)
            attribution: Pixel attributions (C, H, W)
            target_class: Class to track
            num_steps: Number of deletion steps

        Returns:
            scores: Output scores after each deletion step
        """
        # Sort pixels by importance (most important first)
        attr_flat = attribution.flatten()
        indices = torch.argsort(attr_flat, descending=True)
        # Work on a copy so the caller's tensor is not modified;
        # x_flat is a view into x_current, so writes propagate
        x_current = x.clone()
        x_flat = x_current.view(-1)
        scores = []
        pixels_per_step = len(indices) // num_steps
        for step in range(num_steps + 1):
            # Get current output
            with torch.no_grad():
                output = self.model(x_current.unsqueeze(0))
                score = output[0, target_class].item()
            scores.append(score)
            # Remove next batch of pixels
            if step < num_steps:
                remove_indices = indices[step * pixels_per_step:(step + 1) * pixels_per_step]
                x_flat[remove_indices] = 0
        return np.array(scores)
    def insertion_metric(self, x, attribution, target_class, num_steps=20):
        """
        Insertion curve: Add top features to blank image, track output
        """
        # Start with a blank image; x_current_flat is a view into it
        x_current = torch.zeros_like(x)
        x_current_flat = x_current.view(-1)
        # Sort pixels by importance
        attr_flat = attribution.flatten()
        indices = torch.argsort(attr_flat, descending=True)
        x_flat = x.flatten()
        scores = []
        pixels_per_step = len(indices) // num_steps
        for step in range(num_steps + 1):
            # Get current output
            with torch.no_grad():
                output = self.model(x_current.unsqueeze(0))
                score = output[0, target_class].item()
            scores.append(score)
            # Add next batch of pixels
            if step < num_steps:
                add_indices = indices[step * pixels_per_step:(step + 1) * pixels_per_step]
                x_current_flat[add_indices] = x_flat[add_indices]
        return np.array(scores)
    def plot_curves(self, deletion_scores, insertion_scores):
        """Plot deletion and insertion curves"""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        # Express the x-axis as the fraction of pixels changed
        fractions = np.linspace(0, 1, len(deletion_scores))
        # Deletion
        ax1.plot(fractions, deletion_scores, 'r-', linewidth=2)
        ax1.fill_between(fractions, deletion_scores, alpha=0.3, color='red')
        ax1.set_xlabel('Fraction of Features Removed')
        ax1.set_ylabel('Model Output')
        ax1.set_title('Deletion Curve (faster drop is better)')
        ax1.grid(True, alpha=0.3)
        # Insertion
        ax2.plot(fractions, insertion_scores, 'g-', linewidth=2)
        ax2.fill_between(fractions, insertion_scores, alpha=0.3, color='green')
        ax2.set_xlabel('Fraction of Features Added')
        ax2.set_ylabel('Model Output')
        ax2.set_title('Insertion Curve (faster rise is better)')
        ax2.grid(True, alpha=0.3)
        plt.tight_layout()
        return fig
# ============================================================================
# Demonstration
# ============================================================================
print("Advanced Interpretability Methods Implemented:")
print("=" * 60)
print("1. KernelSHAP - Model-agnostic Shapley value approximation")
print("2. DeepSHAP - Efficient SHAP for neural networks")
print("3. LIME - Local interpretable model-agnostic explanations")
print("4. AttentionRollout - Attention flow in transformers")
print("5. TCAV - Testing with concept activation vectors")
print("6. CounterfactualGenerator - Minimal prediction-flipping perturbations")
print("7. FaithfulnessEvaluator - Deletion/insertion metrics")
print("=" * 60)
# Example: LIME on our trained model
print("\nExample: LIME Explanation")
print("-" * 60)
# Get a test sample
test_sample = next(iter(test_loader))[0][0].to(device)
with torch.no_grad():
    original_pred = model(test_sample.unsqueeze(0)).argmax(1).item()
print(f"Original prediction: {original_pred}")
# LIME explanation
lime = LIME(model, kernel_width=0.25)
superpixel_weights, segments = lime.explain_image(
    test_sample,
    num_samples=500,
    num_features=10,
    num_superpixels=50
)
print(f"Number of superpixels: {len(np.unique(segments))}")
print(f"Top 5 important superpixels: {np.argsort(superpixel_weights)[-5:]}")
print(f"Importance range: [{superpixel_weights.min():.4f}, {superpixel_weights.max():.4f}]")
# Visualize
fig = lime.visualize_explanation(test_sample, superpixel_weights, segments, num_features=5)
plt.savefig('lime_explanation.png', dpi=150, bbox_inches='tight')
print("\nLIME visualization saved to 'lime_explanation.png'")
plt.show()
# Example: Faithfulness evaluation
print("\nExample: Faithfulness Evaluation")
print("-" * 60)
# Generate attribution (using integrated gradients from earlier)
test_sample_ig = test_sample.unsqueeze(0)
baseline = torch.zeros_like(test_sample_ig)
attribution = torch.zeros_like(test_sample_ig)
num_steps_ig = 50
for alpha in np.linspace(0, 1, num_steps_ig):
    # Detach so each interpolated point is a leaf tensor with its own grad
    x_interp = (baseline + alpha * (test_sample_ig - baseline)).detach().requires_grad_()
    output = model(x_interp)
    model.zero_grad()
    output[0, original_pred].backward()
    attribution += x_interp.grad
attribution = (attribution * (test_sample_ig - baseline) / num_steps_ig).detach()[0]
# Evaluate faithfulness
evaluator = FaithfulnessEvaluator(model)
deletion_scores = evaluator.deletion_metric(test_sample, attribution, original_pred, num_steps=20)
insertion_scores = evaluator.insertion_metric(test_sample, attribution, original_pred, num_steps=20)
print(f"Deletion AUC: {np.trapz(deletion_scores, dx=1/(len(deletion_scores)-1)):.4f}")
print(f"Insertion AUC: {np.trapz(insertion_scores, dx=1/(len(insertion_scores)-1)):.4f}")
fig = evaluator.plot_curves(deletion_scores, insertion_scores)
plt.savefig('faithfulness_curves.png', dpi=150, bbox_inches='tight')
print("\nFaithfulness curves saved to 'faithfulness_curves.png'")
plt.show()
print("\n" + "=" * 60)
print("Key Takeaways:")
print("=" * 60)
print("1. SHAP: Theoretically grounded (Shapley values), but computationally expensive")
print("2. LIME: Fast and model-agnostic, but no theoretical guarantees")
print("3. Deletion/Insertion: Quantify attribution quality objectively")
print("4. Different methods suitable for different use cases")
print("=" * 60)
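To ground the first takeaway, here is a hedged toy illustration of why exact Shapley values are expensive: for n features, the exact formula enumerates every coalition, so the cost grows exponentially, which is what KernelSHAP approximates away. The value function `v` below is a hypothetical additive game, not part of the notebook's models.

```python
from itertools import combinations
from math import factorial

def exact_shapley(value_fn, n):
    """Exact Shapley values for an n-player value function.

    Enumerates all coalitions per player (2^(n-1) subsets), so cost grows
    exponentially with n -- the reason KernelSHAP samples instead.
    """
    phi = [0.0] * n
    for i in range(n):
        others = [p for p in range(n) if p != i]
        for k in range(len(others) + 1):
            for S in combinations(others, k):
                # Shapley weight |S|!(n-|S|-1)!/n!
                w = factorial(len(S)) * factorial(n - len(S) - 1) / factorial(n)
                phi[i] += w * (value_fn(set(S) | {i}) - value_fn(set(S)))
    return phi

# Hypothetical toy model: output is the sum of the present features' weights
weights = {0: 1.0, 1: 2.0, 2: 4.0}
v = lambda S: sum(weights[j] for j in S)
print(exact_shapley(v, 3))  # additive game: each feature gets its own weight
```

For an additive game the Shapley value of each feature equals its weight, which makes the toy easy to sanity-check; for n features the triple loop visits all 2^(n-1) coalitions per feature.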
1. GradCAM Theory¶
Gradient-weighted Class Activation Mapping¶
\(L^c_{\text{Grad-CAM}} = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)\)
where \(A^k\) are feature maps and \(\alpha_k^c = \frac{1}{Z}\sum_{i,j} \frac{\partial y^c}{\partial A^k_{ij}}\) are importance weights.
Reference Materials:
foundation_neural_network.pdf - Foundation Neural Network
class ConvNet(nn.Module):
    """CNN for MNIST with feature extraction."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 10)
        self.features = None
        self.gradients = None

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)
        # Save features for GradCAM
        self.features = x
        # Register hook to capture gradients on the backward pass
        if x.requires_grad:
            x.register_hook(self.save_gradient)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def save_gradient(self, grad):
        self.gradients = grad
model = ConvNet().to(device)
print("Model created")
Train Model¶
We train a convolutional neural network on an image classification task to serve as the subject of our interpretability analysis. The model should achieve high accuracy so that its learned features are meaningful; interpreting a poorly trained model yields noisy, uninformative explanations. Here we use the small three-layer CNN defined above on MNIST, which keeps training fast and makes the feature maps easy to relate back to the input.
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_mnist = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_mnist, batch_size=1000)
def train_model(model, train_loader, n_epochs=5):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(n_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = F.cross_entropy(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1} complete")
train_model(model, train_loader, n_epochs=5)
Implement GradCAM¶
Gradient-weighted Class Activation Mapping (Grad-CAM) produces a coarse localization heatmap highlighting which regions of the input image contributed most to the model's prediction for a given class. It works by computing the gradient of the target class score with respect to the feature maps of the last convolutional layer, then weighting each feature map channel by its average gradient and summing: \(L^c_{\text{Grad-CAM}} = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)\), where \(\alpha_k^c = \frac{1}{Z}\sum_{i,j} \frac{\partial y^c}{\partial A^k_{ij}}\). The ReLU removes negative contributions, keeping only features that positively influence the target class. This method is architecture-agnostic and requires no retraining or architectural modification.
def generate_gradcam(model, img, target_class=None):
    """Generate GradCAM heatmap."""
    model.eval()
    # Forward pass
    img.requires_grad = True
    output = model(img)
    if target_class is None:
        target_class = output.argmax(dim=1).item()
    # Backward pass
    model.zero_grad()
    target = output[0, target_class]
    target.backward()
    # Get gradients and features
    gradients = model.gradients
    features = model.features
    # Compute channel weights: global-average-pool the gradients
    weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
    # Weighted combination of feature maps, then ReLU
    cam = torch.sum(weights * features, dim=1).squeeze()
    cam = F.relu(cam)
    # Normalize to [0, 1] (guard against an all-zero map)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    return cam.cpu().detach().numpy(), target_class
# Test
x_test, y_test = next(iter(test_loader))
img = x_test[0:1].to(device)
cam, pred_class = generate_gradcam(model, img)
print(f"Predicted: {pred_class}, CAM shape: {cam.shape}")
Visualize GradCAM¶
Overlaying the Grad-CAM heatmap on the original image reveals whether the model is looking at the right regions for the right reasons. For a correctly classified image, the heatmap should highlight the object of interest rather than background features or spurious correlations. Comparing heatmaps across different classes for the same image shows how the model's attention shifts depending on the target class. This visualization is widely used in healthcare, autonomous driving, and other safety-critical domains to build trust in model predictions.
def overlay_gradcam(img, cam):
    """Overlay heatmap on image."""
    # Resize CAM to image size
    cam_resized = cv2.resize(cam, (28, 28))
    # Create RGB heatmap from the jet colormap (drop the alpha channel)
    heatmap = plt.cm.jet(cam_resized)[:, :, :3]
    # Overlay
    img_gray = img.squeeze().detach().cpu().numpy()
    img_rgb = np.stack([img_gray] * 3, axis=-1)
    overlay = 0.6 * img_rgb + 0.4 * heatmap
    overlay = np.clip(overlay, 0, 1)
    return overlay, heatmap
# Visualize multiple samples
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for i in range(4):
    img = x_test[i:i+1].to(device)
    cam, pred = generate_gradcam(model, img)
    overlay, heatmap = overlay_gradcam(img, cam)
    # Detach before imshow: generate_gradcam set requires_grad on img
    axes[i, 0].imshow(img.squeeze().detach().cpu(), cmap='gray')
    axes[i, 0].set_title(f'Original (True: {y_test[i].item()})', fontsize=9)
    axes[i, 0].axis('off')
    axes[i, 1].imshow(cam, cmap='jet')
    axes[i, 1].set_title(f'GradCAM (Pred: {pred})', fontsize=9)
    axes[i, 1].axis('off')
    axes[i, 2].imshow(heatmap)
    axes[i, 2].set_title('Heatmap', fontsize=9)
    axes[i, 2].axis('off')
    axes[i, 3].imshow(overlay)
    axes[i, 3].set_title('Overlay', fontsize=9)
    axes[i, 3].axis('off')
plt.suptitle('GradCAM Visualizations', fontsize=13)
plt.tight_layout()
plt.show()
Saliency Maps¶
Saliency maps are computed by taking the gradient of the output class score with respect to the input pixels: \(\frac{\partial y^c}{\partial x}\), then visualizing the absolute values. Unlike Grad-CAM which operates at the feature map level (coarse), saliency maps provide pixel-level attribution showing exactly which input pixels the model is most sensitive to. The resulting maps tend to highlight edges and textures, reflecting the features that convolutional networks rely on most heavily for classification.
def generate_saliency(model, img, target_class=None):
    """Generate saliency map."""
    model.eval()
    img.requires_grad = True
    # Clear any gradient left over from a previous attribution call,
    # otherwise backward() would accumulate into img.grad
    if img.grad is not None:
        img.grad.zero_()
    output = model(img)
    if target_class is None:
        target_class = output.argmax(dim=1).item()
    model.zero_grad()
    target = output[0, target_class]
    target.backward()
    saliency = img.grad.abs().squeeze()
    return saliency.cpu().numpy(), target_class
# Visualize
fig, axes = plt.subplots(3, 6, figsize=(15, 7))
for i in range(6):
    img = x_test[i:i+1].to(device)
    # Original
    axes[0, i].imshow(img.squeeze().cpu(), cmap='gray')
    axes[0, i].axis('off')
    if i == 0:
        axes[0, i].set_ylabel('Original', fontsize=11)
    # GradCAM
    cam, _ = generate_gradcam(model, img)
    axes[1, i].imshow(cam, cmap='jet')
    axes[1, i].axis('off')
    if i == 0:
        axes[1, i].set_ylabel('GradCAM', fontsize=11)
    # Saliency
    saliency, _ = generate_saliency(model, img)
    axes[2, i].imshow(saliency, cmap='hot')
    axes[2, i].axis('off')
    if i == 0:
        axes[2, i].set_ylabel('Saliency', fontsize=11)
plt.suptitle('Comparison: GradCAM vs Saliency', fontsize=13)
plt.tight_layout()
plt.show()
Integrated Gradients¶
Integrated Gradients (Sundararajan et al., 2017) provides a more principled attribution method that satisfies two important axioms: completeness (attributions sum to the model's output difference from a baseline) and sensitivity (any feature that changes the output receives non-zero attribution). It computes attributions by integrating the gradient along a straight-line path from a baseline input (typically all zeros) to the actual input: \(\text{IG}_i(x) = (x_i - x'_i) \times \int_0^1 \frac{\partial F(x' + \alpha(x - x'))}{\partial x_i} d\alpha\). In practice, the integral is approximated with 50-300 interpolation steps, making it more computationally expensive than Grad-CAM or saliency maps but significantly more reliable.
def integrated_gradients(model, img, target_class=None, steps=50):
    """Compute integrated gradients."""
    model.eval()
    # Baseline (zeros)
    baseline = torch.zeros_like(img)
    # Get target class
    if target_class is None:
        with torch.no_grad():
            output = model(img)
        target_class = output.argmax(dim=1).item()
    # Accumulate gradients along the straight-line path
    integrated_grads = torch.zeros_like(img)
    for alpha in np.linspace(0, 1, steps):
        # Detach so each interpolated point is a leaf with its own grad
        interpolated = (baseline + alpha * (img - baseline)).detach()
        interpolated.requires_grad = True
        output = model(interpolated)
        target = output[0, target_class]
        model.zero_grad()
        target.backward()
        integrated_grads += interpolated.grad
    # Average and multiply by input difference (completeness axiom)
    integrated_grads = integrated_grads / steps
    integrated_grads = (img - baseline) * integrated_grads
    return integrated_grads.abs().squeeze().cpu().detach().numpy()
# Compare methods
img = x_test[0:1].to(device)
ig = integrated_gradients(model, img)
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(img.squeeze().cpu(), cmap='gray')
axes[0].set_title('Original', fontsize=11)
axes[0].axis('off')
saliency, _ = generate_saliency(model, img)
axes[1].imshow(saliency, cmap='hot')
axes[1].set_title('Saliency', fontsize=11)
axes[1].axis('off')
axes[2].imshow(ig, cmap='hot')
axes[2].set_title('Integrated Gradients', fontsize=11)
axes[2].axis('off')
plt.tight_layout()
plt.show()
Summary¶
Interpretability Methods:¶
GradCAM: Class activation via gradients
Saliency Maps: Input gradient magnitude
Integrated Gradients: Path integral of gradients
SHAP: Shapley value-based attribution
Key Insights:¶
GradCAM shows spatial attention
Saliency highlights important pixels
IG reduces noise via integration
Different methods reveal different aspects
Applications:¶
Model debugging
Trust and transparency
Feature validation
Bias detection
Medical diagnosis explanation
Best Practices:¶
Use multiple methods
Validate with domain experts
Consider method limitations
Combine with quantitative metrics
Extensions:¶
Grad-CAM++: Improved localization
Score-CAM: Gradient-free
LIME: Local explanations
Attention rollout: Transformer interpretation
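The extensions above mention Score-CAM as a gradient-free alternative. A minimal sketch of the idea, under the assumption that the model caches its last convolutional activations in `model.features` (as the ConvNet in this notebook does): each activation map is upsampled to the input size, normalized, used to mask the input, and then weighted by the softmax score the masked input receives for the target class. This is an illustrative sketch, not the full published algorithm.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def score_cam(model, img, target_class):
    """Gradient-free Score-CAM sketch.

    Assumes model(img) stores the last conv activations in model.features,
    shape (1, K, h, w), as the ConvNet defined earlier does.
    """
    model.eval()
    with torch.no_grad():
        model(img)                       # populate model.features
        acts = model.features[0]         # (K, h, w)
        K = acts.shape[0]
        # Upsample each activation map to the input resolution
        masks = F.interpolate(acts.unsqueeze(1), size=img.shape[-2:],
                              mode='bilinear', align_corners=False)  # (K, 1, H, W)
        # Normalize each mask to [0, 1]
        flat = masks.view(K, -1)
        mn = flat.min(dim=1, keepdim=True).values
        mx = flat.max(dim=1, keepdim=True).values
        masks = ((flat - mn) / (mx - mn + 1e-8)).view_as(masks)
        # Score the input masked by each activation map
        masked = img * masks             # broadcast to (K, C, H, W)
        logits = model(masked)           # (K, num_classes)
        weights = F.softmax(logits[:, target_class], dim=0)  # (K,)
        # Weighted combination of the original activation maps
        cam = F.relu((weights.view(K, 1, 1) * acts).sum(dim=0))
        cam = cam / (cam.max() + 1e-8)
    return cam

# Possible usage with the notebook's model (uncomment in the notebook):
# cam = score_cam(model, img, target_class=pred_class)  # (h, w) heatmap
```

Because no backward pass is required, this runs entirely under `torch.no_grad()`, at the cost of K extra forward passes (one per channel).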