import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import cv2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Advanced Neural Network Interpretability Theory

1. Foundations and Taxonomy

Definition: Interpretability is the degree to which a human can understand the cause of a decision made by a model.

Taxonomy of Interpretability:

| Property | Intrinsic | Post-hoc |
|---|---|---|
| Model-specific | Linear/Tree models | GradCAM, CAM |
| Model-agnostic | Sparse models | LIME, SHAP |
| Local | Rule-based | Single prediction explanation |
| Global | Feature importance | Overall behavior understanding |

Fidelity vs. Interpretability Trade-off:

  • High fidelity (complex models): Better performance, harder to interpret

  • High interpretability (simple models): Easier to understand, may sacrifice performance

  • Goal: Post-hoc methods bridge this gap

2. Gradient-Based Attribution Methods

2.1 Vanilla Gradients (Saliency Maps)

Definition: The attribution of input feature i is the magnitude of the gradient of the output with respect to that feature:

\[A_i = \left|\frac{\partial f(x)}{\partial x_i}\right|\]

Where \(f(x)\) is the model output for input \(x\).

First-order Taylor approximation:

\[f(x) \approx f(x_0) + \nabla f(x_0)^T (x - x_0)\]

The gradient \(\nabla f\) indicates local sensitivity: how much output changes with small input perturbations.

Limitations:

  1. Saturation: Gradients vanish in saturated regions (ReLU dead neurons, sigmoid plateaus)

  2. Noise: High-frequency artifacts due to local nature

  3. No baseline: Measures local sensitivity only. It does not compare the input against a reference, so it cannot say how much each feature contributed to the prediction
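Below is a minimal PyTorch sketch of a vanilla-gradient saliency map. It assumes a classifier that returns a (batch, num_classes) score tensor, and the toy MLP is hypothetical. For a purely linear layer the saliency is exactly |w_c|, which makes a handy sanity check.

```python
import torch
import torch.nn as nn

def saliency_map(model, x, target_class):
    """Vanilla gradient saliency A_i = |d f_c(x) / d x_i| for one input x of shape (1, ...).
    Call model.eval() beforehand if it contains dropout or batch norm."""
    x = x.detach().clone().requires_grad_(True)  # fresh leaf tensor so x.grad gets populated
    score = model(x)[0, target_class]            # scalar class score f_c(x)
    score.backward()                             # also accumulates .grad on parameters (harmless here)
    return x.grad.abs()

# Toy usage with a hypothetical two-layer MLP on 4 input features
torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
sal = saliency_map(mlp, torch.randn(1, 4), target_class=1)
print(sal.shape)  # torch.Size([1, 4])
```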

2.2 Integrated Gradients (IG)

Axioms for attribution:

  1. Sensitivity: If the input and the baseline differ in a single feature and produce different predictions, that feature should receive a non-zero attribution

  2. Implementation invariance: Functionally equivalent models should have identical attributions

IG Definition:

\[\text{IG}_i(x) = (x_i - x_i') \int_{\alpha=0}^{1} \frac{\partial f(x' + \alpha(x - x'))}{\partial x_i} d\alpha\]

Where:

  • \(x'\): Baseline (typically zeros or average image)

  • \(\alpha \in [0,1]\): Interpolation parameter

  • Path: Straight line from \(x'\) to \(x\)

Discrete approximation (Riemann sum):

\[\text{IG}_i(x) \approx (x_i - x_i') \cdot \frac{1}{m} \sum_{k=1}^{m} \frac{\partial f(x' + \frac{k}{m}(x - x'))}{\partial x_i}\]

Typical: \(m = 50\) steps

Completeness property:

\[\sum_{i=1}^{n} \text{IG}_i(x) = f(x) - f(x')\]

Attribution sums to the difference between output and baseline.

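Proof sketch: Let \(\gamma(\alpha) = x' + \alpha(x - x')\). By the fundamental theorem of calculus and the chain rule:

\[f(x) - f(x') = \int_{\alpha=0}^{1} \frac{d}{d\alpha} f(\gamma(\alpha))\, d\alpha = \int_{\alpha=0}^{1} \nabla f(\gamma(\alpha)) \cdot (x - x')\, d\alpha = \sum_{i=1}^{n} (x_i - x_i') \int_{\alpha=0}^{1} \frac{\partial f(\gamma(\alpha))}{\partial x_i}\, d\alpha = \sum_{i=1}^{n} \text{IG}_i(x)\]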
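The next sketch implements the discrete Riemann-sum formula with k/m for k = 1..m. It batches all m path points into one forward/backward pass. It assumes `model(batch)` returns (batch, num_classes) and has no cross-sample interactions such as batch norm in training mode. The tanh network is a hypothetical toy. The printed pair should nearly agree, which illustrates completeness; for a linear layer the agreement is exact.

```python
import torch
import torch.nn as nn

def integrated_gradients(model, x, baseline, target_class, steps=50):
    """Riemann-sum IG: (x - x') * (1/m) * sum_k grad f(x' + (k/m)(x - x')), k = 1..m.
    x, baseline: tensors of shape (1, ...)."""
    alphas = torch.arange(1, steps + 1, dtype=x.dtype) / steps          # k/m for k = 1..m
    alphas = alphas.view(-1, *([1] * (x.dim() - 1)))                    # broadcast over features
    path = (baseline + alphas * (x - baseline)).detach().requires_grad_(True)  # (m, ...)
    scores = model(path)[:, target_class]
    # Each path point only affects its own score, so the gradient of the sum
    # gives the per-point gradients in one backward pass
    grads, = torch.autograd.grad(scores.sum(), path)
    return (x - baseline) * grads.mean(dim=0, keepdim=True)

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 2))
x = torch.randn(1, 3)
baseline = torch.zeros_like(x)
ig = integrated_gradients(net, x, baseline, target_class=0, steps=300)
with torch.no_grad():
    gap = (net(x) - net(baseline))[0, 0]
print(ig.sum().item(), gap.item())  # nearly equal, up to discretization error
```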

2.3 SmoothGrad

Problem: Gradients are noisy and sensitive to small perturbations

Solution: Average gradients over noisy samples

\[\text{SmoothGrad}(x) = \frac{1}{n} \sum_{i=1}^{n} \frac{\partial f(x + \mathcal{N}(0, \sigma^2))}{\partial x}\]

Where \(\mathcal{N}(0, \sigma^2)\) is Gaussian noise, typical \(\sigma = 0.15 \times (\max(x) - \min(x))\)

Effect: Reduces high-frequency noise, produces visually cleaner attribution maps

Variants:

  • SmoothGrad-Squared: \(\frac{1}{n} \sum (\partial f / \partial x)^2\)

  • VarGrad: Variance of gradients across samples
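The following PyTorch sketch covers SmoothGrad and both variants. It uses σ = noise_level × (max(x) − min(x)), as stated above, and evaluates all noisy copies in one batch. The `mode` switch is an illustrative design choice, not a library API. For a linear layer every noisy gradient equals w_c, so SmoothGrad returns w_c exactly and VarGrad returns zeros.

```python
import torch
import torch.nn as nn

def smoothgrad(model, x, target_class, n_samples=25, noise_level=0.15, mode="mean"):
    """SmoothGrad and variants for one input x of shape (1, ...).
    mode: "mean" (SmoothGrad), "squared" (SmoothGrad-Squared) or "var" (VarGrad)."""
    sigma = noise_level * (x.max() - x.min())                         # noise std relative to input range
    noisy = x + sigma * torch.randn(n_samples, *x.shape[1:], dtype=x.dtype)
    noisy = noisy.detach().requires_grad_(True)
    scores = model(noisy)[:, target_class]
    grads, = torch.autograd.grad(scores.sum(), noisy)                 # one gradient per noisy copy
    if mode == "mean":
        return grads.mean(dim=0, keepdim=True)
    if mode == "squared":
        return (grads ** 2).mean(dim=0, keepdim=True)
    if mode == "var":
        return grads.var(dim=0, keepdim=True)
    raise ValueError(f"unknown mode: {mode}")

torch.manual_seed(0)
lin = nn.Linear(5, 2)
x = torch.randn(1, 5)
sg = smoothgrad(lin, x, target_class=0)
print(torch.allclose(sg[0], lin.weight[0].detach()))  # True
```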

3. Class Activation Mapping (CAM) Methods

3.1 CAM (Original)

Architecture requirement: Global Average Pooling (GAP) before final classification layer

\[y^c = \sum_k w_k^c \cdot \frac{1}{Z} \sum_{i,j} A_{ij}^k\]

Where:

  • \(A^k\): Feature map k from last convolutional layer

  • \(w_k^c\): Weight connecting feature k to class c

  • Z: Spatial size (height × width)

Class Activation Map:

\[M^c(i,j) = \sum_k w_k^c \cdot A_{ij}^k\]

Interpretation: \(M^c\) highlights regions most discriminative for class c

Limitation: Requires GAP architecture, not applicable to arbitrary networks
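A NumPy sketch of the CAM formula follows. It assumes the last-conv activations A (K, H, W) and the weight matrix of the linear layer after GAP (num_classes, K) are already available; the random arrays are hypothetical. GAP and the linear layer commute, so the class score equals the spatial mean of the CAM plus the bias.

```python
import numpy as np

def class_activation_map(feature_maps, fc_weights, class_idx):
    """CAM: M^c(i, j) = sum_k w_k^c A^k_ij.
    feature_maps: (K, H, W) activations A^k; fc_weights: (num_classes, K)."""
    return np.tensordot(fc_weights[class_idx], feature_maps, axes=(0, 0))  # (H, W)

rng = np.random.default_rng(0)
A = rng.random((8, 7, 7))                 # K = 8 feature maps of size 7x7
W = rng.normal(size=(5, 8))               # 5 classes
b = rng.normal(size=5)
logits = W @ A.mean(axis=(1, 2)) + b      # GAP followed by the linear layer
cam = class_activation_map(A, W, class_idx=3)
print(np.isclose(cam.mean() + b[3], logits[3]))  # True
```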

3.2 Grad-CAM (Gradient-weighted CAM)

Key innovation: Use gradients to compute importance weights, removing architecture constraint

Gradient-based weights:

\[\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A_{ij}^k} = \text{GlobalAvgPool}\left(\frac{\partial y^c}{\partial A^k}\right)\]

Intuition: \(\alpha_k^c\) measures how much feature map k contributes to class c

Grad-CAM:

\[L_{\text{Grad-CAM}}^c = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)\]

ReLU: Keeps only positive influences (features increasing class score)

Why this works: Chain rule decomposition:

\[\frac{\partial y^c}{\partial A_{ij}^k} = \sum_{m,n} \frac{\partial y^c}{\partial z_{mn}} \cdot \frac{\partial z_{mn}}{\partial A_{ij}^k}\]

Where \(z\) are downstream activations. Averaging spatially gives global importance.
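The PyTorch sketch below shows Grad-CAM. It assumes the network can be split into `features` (ending at the target conv layer) and `head` (mapping those maps to pre-softmax class scores). In a real model this split, or an equivalent forward hook, is model-specific, and the toy modules here are hypothetical. The head deliberately has an extra conv layer before pooling, to illustrate that no GAP-only architecture is required.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def grad_cam(features, head, x, target_class):
    """Grad-CAM for a network split as head(features(x)); x has shape (1, C, H, W)."""
    A = features(x)                                   # target conv maps A^k, (1, K, h, w)
    y = head(A)[0, target_class]                      # scalar class score y^c
    dy_dA, = torch.autograd.grad(y, A)                # dy^c / dA^k_ij
    alpha = dy_dA.mean(dim=(2, 3), keepdim=True)      # GAP of gradients -> alpha_k^c
    cam = F.relu((alpha * A).sum(dim=1))              # ReLU(sum_k alpha_k^c A^k), (1, h, w)
    # Upsample to input resolution for overlaying on the image
    cam = F.interpolate(cam.unsqueeze(1), size=x.shape[-2:], mode="bilinear", align_corners=False)
    return cam[0, 0].detach()

torch.manual_seed(0)
features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
head = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
                     nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 4))
x = torch.randn(1, 3, 32, 32)
heatmap = grad_cam(features, head, x, target_class=2)
print(heatmap.shape)  # torch.Size([32, 32])
```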

3.3 Grad-CAM++ (Improved Localization)

Problem: Grad-CAM produces coarse localization, especially for multiple objects

Solution: Weighted combination of pixel-wise gradients

\[\alpha_k^c = \sum_i \sum_j w_{ij}^{kc} \cdot \text{ReLU}\left(\frac{\partial y^c}{\partial A_{ij}^k}\right)\]

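Pixel-wise weights (closed form from the Grad-CAM++ paper, which passes the class score through an exponential and assumes a piecewise-linear network, so that higher-order derivatives reduce to powers of the first-order gradient):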

\[w_{ij}^{kc} = \frac{(\partial y^c / \partial A_{ij}^k)^2}{2 (\partial y^c / \partial A_{ij}^k)^2 + \sum_{a,b} A_{ab}^k (\partial y^c / \partial A_{ij}^k)^3}\]

Advantages:

  • Better localization for multiple occurrences

  • Handles multiple objects of same class

  • Provides confidence through gradient magnitude
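Here is a NumPy sketch of the Grad-CAM++ weighting. It takes the activations and first-order gradients ∂y^c/∂A as arrays of shape (K, h, w); in practice these come from the same backward pass Grad-CAM uses. Where the denominator is zero, the gradient is zero and the pixel contributes nothing, so the sketch substitutes 1 there to avoid 0/0. The tiny all-ones input is hypothetical and was chosen so the result can be checked by hand: each pixel weight is 1/(2 + 4) = 1/6, summed over 4 pixels.

```python
import numpy as np

def grad_cam_pp(A, grads):
    """Grad-CAM++ map and channel weights from activations A and gradients dy^c/dA, both (K, h, w)."""
    g2, g3 = grads ** 2, grads ** 3
    sum_A = A.sum(axis=(1, 2), keepdims=True)                  # sum_{a,b} A_ab^k, shape (K, 1, 1)
    denom = 2 * g2 + sum_A * g3
    denom = np.where(denom != 0, denom, 1.0)                    # avoid 0/0 where the gradient vanishes
    w = g2 / denom                                              # pixel-wise weights w_ij^{kc}
    alpha = (w * np.maximum(grads, 0)).sum(axis=(1, 2))        # alpha_k^c, shape (K,)
    cam = np.maximum(np.tensordot(alpha, A, axes=(0, 0)), 0)    # ReLU(sum_k alpha_k^c A^k)
    return cam, alpha

A = np.ones((1, 2, 2))          # one 2x2 feature map of ones, so sum_ab A_ab = 4
grads = np.ones((1, 2, 2))      # unit gradient everywhere
cam, alpha = grad_cam_pp(A, grads)
print(alpha)  # [0.66666667]
```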

3.4 Score-CAM (Gradient-Free)

Motivation: Gradients can be noisy and require backpropagation

Approach: Forward-pass only, using activation maps as masks

Algorithm:

  1. For each feature map \(A^k\), upsample to input size: \(M^k\)

  2. Compute masked input: \(X_k = X \odot M^k\) (element-wise product)

  3. Forward pass through network: \(S_k = f(X_k)\)

  4. Weight by increase in target class score: \(\alpha_k^c = S_k^c - f(X_{\text{baseline}})^c\)

Score-CAM:

\[L_{\text{Score-CAM}}^c = \text{ReLU}\left(\sum_k \alpha_k^c \cdot \text{Normalize}(A^k)\right)\]

Advantages:

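  • No gradients needed (no backpropagation and no gradient noise, though it costs one forward pass per feature map, so it is slower than Grad-CAM)

  • The weights come from forward passes only, so the part of the network after the chosen layer does not need to be differentiable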

  • Less sensitive to gradient saturation
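The algorithm above can be sketched in NumPy for a 2-D input. The sketch assumes `predict` returns a vector of class scores and that the input size is an integer multiple of the feature-map size, which permits nearest-neighbour upsampling with `np.kron`. It combines the upsampled, normalized masks rather than Normalize(A^k) at feature resolution; the two are equivalent up to upsampling. The toy model and activations are hypothetical.

```python
import numpy as np

def score_cam(predict, x, activations, class_idx, baseline=None):
    """Score-CAM for a 2-D input x (H, W) given feature maps `activations` (K, h, w)."""
    baseline = np.zeros_like(x) if baseline is None else baseline
    base_score = predict(baseline)[class_idx]
    H, W = x.shape
    alphas, masks = [], []
    for A_k in activations:
        M = np.kron(A_k, np.ones((H // A_k.shape[0], W // A_k.shape[1])))  # 1. upsample (nearest)
        span = M.max() - M.min()
        M = (M - M.min()) / span if span > 0 else np.zeros_like(M)         # normalize to [0, 1]
        masks.append(M)
        alphas.append(predict(x * M)[class_idx] - base_score)             # 2-4. mask, forward, weight
    cam = np.tensordot(np.array(alphas), np.array(masks), axes=(0, 0))    # sum_k alpha_k^c M^k
    return np.maximum(cam, 0)                                              # ReLU

def predict(img):  # hypothetical toy model: class 0 responds to the top-left quadrant
    return np.array([img[:4, :4].mean(), img[4:, 4:].mean()])

acts = np.zeros((2, 2, 2))
acts[0, 0, 0] = 1.0   # map 0 fires on the top-left
acts[1, 1, 1] = 1.0   # map 1 fires on the bottom-right
cam = score_cam(predict, np.ones((8, 8)), acts, class_idx=0)
print(cam[:4, :4].min(), cam[4:, 4:].max())  # 1.0 0.0
```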

4. Perturbation-Based Methods

4.1 Occlusion Sensitivity

Idea: Occlude parts of input, measure output change

\[\text{Sensitivity}(i,j) = f(x)^c - f(x_{\text{occluded}(i,j)})^c\]

Procedure:

  1. Slide patch (e.g., 5×5 gray square) over input

  2. At each position, compute output drop

  3. Higher drop → more important region

Advantages:

  • Model-agnostic

  • Easy to understand

Limitations:

  • Computationally expensive: \(O(H \times W)\) forward passes at stride 1 (one per patch position)

  • Patch size affects results

  • Doesn’t handle feature interactions
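The procedure can be sketched in NumPy for a 2-D input. Each patch position's score drop is credited to every pixel the patch covered, then averaged over the number of patches covering that pixel; this averaging is one common choice, not the only one. The toy model, whose score is the sum of a fixed 2×2 region, is hypothetical, so only the patch covering that region produces a drop.

```python
import numpy as np

def occlusion_sensitivity(predict, x, class_idx, patch=5, stride=1, fill=0.5):
    """Slide a patch x patch square of value `fill` over x (H, W) and record the score drop."""
    H, W = x.shape
    base = predict(x)[class_idx]
    heat = np.zeros((H, W))
    counts = np.zeros((H, W))
    for i in range(0, H - patch + 1, stride):
        for j in range(0, W - patch + 1, stride):
            occluded = x.copy()
            occluded[i:i + patch, j:j + patch] = fill          # gray (or other) occluder
            drop = base - predict(occluded)[class_idx]
            heat[i:i + patch, j:j + patch] += drop             # credit every covered pixel
            counts[i:i + patch, j:j + patch] += 1
    return heat / np.maximum(counts, 1)                        # average over overlapping patches

def predict(img):  # hypothetical toy model: the class score is the sum of one 2x2 region
    return np.array([img[4:6, 4:6].sum()])

heat = occlusion_sensitivity(predict, np.ones((10, 10)), class_idx=0, patch=2, stride=2, fill=0.0)
print(heat[4:6, 4:6])  # the only non-zero region, with a score drop of 4.0
```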

4.2 RISE (Randomized Input Sampling for Explanation)

Approach: Generate random masks, aggregate weighted by output

Algorithm:

  1. Generate N random binary masks: \(\{M_1, ..., M_N\}\)

  2. For each mask, forward pass: \(f(x \odot M_i)\)

  3. Weight masks by output: \(w_i = f(x \odot M_i)^c\)

  4. Aggregate: \(S^c = \frac{1}{N} \sum_i w_i \cdot M_i\)

Advantages:

  • Smooth saliency maps

  • Probabilistic interpretation

  • Handles complex interactions

Hyperparameters:

  • N: Number of masks (typically 2000-8000)

  • Mask size: Controls granularity

  • Probability: Fraction of pixels kept (typically 0.5)
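A simplified NumPy sketch of RISE follows, implementing the aggregation formula as written. The simplifications are stated assumptions: masks are coarse random grids upsampled by nearest neighbour, and the input size must be divisible by `cell`. The original method uses smoother, randomly shifted masks and additionally divides by the keep probability. With the hypothetical toy model below, the top-left quadrant should score about p = 0.5 and other regions about p² = 0.25.

```python
import numpy as np

def rise(predict, x, class_idx, n_masks=2000, cell=4, p_keep=0.5, seed=0):
    """S^c = (1/N) sum_i f(x * M_i)^c * M_i for a 2-D input x of shape (H, W)."""
    rng = np.random.default_rng(seed)
    H, W = x.shape
    saliency = np.zeros((H, W))
    for _ in range(n_masks):
        grid = (rng.random((H // cell, W // cell)) < p_keep).astype(float)  # keep each cell w.p. p
        M = np.kron(grid, np.ones((cell, cell)))                            # upsample to (H, W)
        saliency += predict(x * M)[class_idx] * M                           # weight mask by score
    return saliency / n_masks

def predict(img):  # hypothetical toy model: class 0 is the mean of the top-left quadrant
    return np.array([img[:4, :4].mean()])

S = rise(predict, np.ones((8, 8)), class_idx=0)
print(S[:4, :4].mean() > S[4:, 4:].mean())  # True
```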

5. SHAP (SHapley Additive exPlanations)

5.1 Shapley Values from Game Theory

Setup: Cooperative game with players (features), value function (model output)

Shapley value for feature i:

\[\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|! (|N| - |S| - 1)!}{|N|!} [v(S \cup \{i\}) - v(S)]\]

Where:

  • N: Set of all features

  • S: Subset of features

  • v(S): Model output with features in S

  • \(v(S \cup \{i\}) - v(S)\): Marginal contribution of feature i to subset S

Axioms satisfied:

  1. Efficiency: \(\sum_i \phi_i = v(N) - v(\emptyset)\) (complete attribution)

  2. Symmetry: If features i, j contribute equally, \(\phi_i = \phi_j\)

  3. Dummy: If feature has no effect, \(\phi_i = 0\)

  4. Additivity: Linear combination of games → linear combination of values
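The Shapley formula can be checked with a brute-force Python sketch that enumerates every coalition. This costs O(2^n) evaluations of v, so it only works for a handful of features. The game is hypothetical: f(x) = 2x₀ + x₁x₂ at x = (1, 2, 3) with a zero baseline. The pure interaction x₁x₂ = 6 is split evenly, and the attributions sum to v(N) − v(∅) = 8 (efficiency).

```python
import itertools
import math
import numpy as np

def exact_shapley(value_fn, n):
    """Exact Shapley values; value_fn maps a frozenset of player indices S to v(S)."""
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # |S|! (n - |S| - 1)! / n!
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for S in itertools.combinations(others, size):
                S = frozenset(S)
                phi[i] += weight * (value_fn(S | {i}) - value_fn(S))  # marginal contribution
    return phi

x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros(3)

def v(S):  # features in S take their value from x, the rest from the baseline
    z = np.where([j in S for j in range(3)], x, baseline)
    return 2 * z[0] + z[1] * z[2]

phi = exact_shapley(v, 3)
print(np.round(phi, 6))  # [2. 3. 3.]
```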

5.2 SHAP for Deep Networks

Challenge: Computing Shapley values requires \(2^n\) evaluations (exponential in features)

DeepSHAP (approximation): Uses backpropagation-like approach to approximate Shapley values

Linear SHAP (for linear models):

\[\phi_i = (x_i - E[x_i]) \cdot w_i\]

Simple: Deviation from mean × weight (this closed form assumes independent features)
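Linear SHAP can be checked numerically with a short NumPy example. E[x_i] is estimated from a background sample, and the model and data are hypothetical. The attributions sum to f(x) minus the mean background prediction, and the feature with zero weight receives zero attribution (the dummy axiom).

```python
import numpy as np

def linear_shap(x, w, X_background):
    """phi_i = (x_i - E[x_i]) * w_i, with E[x_i] estimated from background data."""
    return (x - X_background.mean(axis=0)) * w

rng = np.random.default_rng(0)
X_background = rng.normal(size=(500, 4))        # hypothetical background dataset
w, b = np.array([1.5, -2.0, 0.0, 0.5]), 0.3      # hypothetical linear model f(x) = w.x + b

def f(X):
    return X @ w + b

x = np.array([1.0, 0.5, -1.0, 2.0])
phi = linear_shap(x, w, X_background)
# Efficiency: attributions sum to f(x) minus the mean prediction on the background
print(np.isclose(phi.sum(), f(x) - f(X_background).mean()))  # True
```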

Kernel SHAP (model-agnostic):

Weighted linear regression to approximate Shapley values:

\[\min_{\phi} \sum_{z \in Z} [\pi(z) (f(h_x(z)) - \sum_i \phi_i z_i)^2]\]

Where:

  • \(z\): Binary vector indicating feature presence

  • \(\pi(z)\): SHAP kernel weight \(\frac{|N|-1}{\binom{|N|}{|z|} |z| (|N| - |z|)}\)

  • \(h_x(z)\): Maps binary vector to feature values

Complexity: \(O(2^n)\) exact, \(O(n^2)\) with sampling approximations
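To connect the regression view to exact Shapley values, here is a NumPy sketch that enumerates all 2^M coalitions. It weights each one with the SHAP kernel π(z) and solves the weighted least-squares problem. It approximates the infinite weights of the empty and full coalitions with a large constant `big`, an assumption. On the same hypothetical game used for exact Shapley values (f = 2x₀ + x₁x₂, x = (1, 2, 3), zero baseline), it recovers (2, 3, 3). Practical Kernel SHAP samples coalitions instead of enumerating them.

```python
import itertools
from math import comb
import numpy as np

def kernel_shap_all_coalitions(value_fn, M, big=1e6):
    """Kernel SHAP with every coalition enumerated (feasible only for small M).
    value_fn maps a binary vector z of length M to v(z) = f(h_x(z))."""
    Z = np.array(list(itertools.product([0, 1], repeat=M)), dtype=float)
    y = np.array([value_fn(z) for z in Z])
    sizes = Z.sum(axis=1).astype(int).tolist()
    pi = np.array([big if s in (0, M) else (M - 1) / (comb(M, s) * s * (M - s)) for s in sizes])
    A = np.hstack([np.ones((len(Z), 1)), Z])            # first column fits phi_0
    sw = np.sqrt(pi)                                     # weighted LSQ via sqrt-weight scaling
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef[0], coef[1:]                             # (phi_0, phi_1..phi_M)

x = np.array([1.0, 2.0, 3.0])

def v(z):  # features with z_i = 1 take their value from x, the rest the baseline 0
    xz = z * x
    return 2 * xz[0] + xz[1] * xz[2]

phi0, phi = kernel_shap_all_coalitions(v, M=3)
print(np.round(phi, 3))  # matches the exact Shapley values (2, 3, 3) of this game
```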

6. LIME (Local Interpretable Model-agnostic Explanations)

6.1 Core Idea

Approximate complex model locally with interpretable model (linear, decision tree)

Optimization objective:

\[\xi(x) = \arg\min_{g \in G} \mathcal{L}(f, g, \pi_x) + \Omega(g)\]

Where:

  • \(g \in G\): Interpretable model (e.g., linear)

  • \(\mathcal{L}\): Fidelity loss (how well g approximates f locally)

  • \(\pi_x\): Proximity measure (kernel, e.g., exponential)

  • \(\Omega(g)\): Complexity penalty (e.g., number of features)

6.2 LIME Algorithm

  1. Sample: Generate N perturbed samples around x: \(\{x_1', ..., x_N'\}\)

  2. Evaluate: Get model predictions: \(\{f(x_1'), ..., f(x_N')\}\)

  3. Weight: Compute proximity: \(\pi_x(x_i') = \exp(-D(x, x_i')^2 / \sigma^2)\)

  4. Fit: Train interpretable model g on weighted samples

  5. Explain: Extract coefficients or rules from g
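The five steps can be sketched for tabular data with NumPy and scikit-learn's `Ridge`. This is a simplified, continuous variant: it adds Gaussian noise to the raw feature values, whereas LIME typically perturbs an interpretable representation such as superpixels or discretized features. `scale`, `kernel_width` and the linear toy model are illustrative assumptions. For a linear model the surrogate recovers the true weights.

```python
import numpy as np
from sklearn.linear_model import Ridge

def lime_tabular(predict, x, n_samples=2000, scale=0.25, kernel_width=0.25, seed=0):
    """Local linear surrogate around x, returning one coefficient per feature."""
    rng = np.random.default_rng(seed)
    X_pert = x + scale * rng.normal(size=(n_samples, x.size))   # 1. sample around x
    y = predict(X_pert)                                          # 2. evaluate the model
    d = np.linalg.norm(X_pert - x, axis=1)
    sigma = kernel_width * np.sqrt(x.size)
    pi = np.exp(-d ** 2 / sigma ** 2)                            # 3. proximity weights
    g = Ridge(alpha=1e-3).fit(X_pert - x, y, sample_weight=pi)   # 4. fit weighted surrogate
    return g.coef_                                               # 5. explanation

w = np.array([3.0, -1.0, 0.0])

def predict(X):  # hypothetical black-box model (linear here, so the answer is known)
    return X @ w

coef = lime_tabular(predict, np.array([0.5, 1.0, -2.0]))
print(np.round(coef, 2))  # close to [3, -1, 0]
```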

For images:

  • Segment image into superpixels (e.g., SLIC algorithm)

  • Perturb by turning superpixels on/off

  • Fit linear model to predict class probability

Example linear model:

\[g(x') = w_0 + \sum_i w_i \cdot \mathbb{1}[\text{superpixel } i \text{ present}]\]

Top-k positive \(w_i\) indicate most important superpixels

6.3 Comparison: LIME vs SHAP

| Property | LIME | SHAP |
|---|---|---|
| Framework | Local surrogate | Shapley values |
| Guarantees | None (heuristic) | Efficiency, symmetry |
| Sampling | Random perturbations | Coalitional |
| Consistency | Not guaranteed | Guaranteed |
| Speed | Fast | Slower (exact) |
| Use case | Quick explanations | Rigorous attribution |

7. Attention Mechanisms as Interpretability

7.1 Attention Weights

For Transformer models: Attention weights \(\alpha_{ij}\) indicate “how much position j contributes to position i”

\[\alpha_{ij} = \frac{\exp(q_i^T k_j / \sqrt{d_k})}{\sum_{j'} \exp(q_i^T k_{j'} / \sqrt{d_k})}\]

Naive interpretation: High \(\alpha_{ij}\) → token j is important for token i
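A NumPy sketch of the attention-weight formula follows, using random queries and keys as hypothetical inputs. Each row of α is a probability distribution over key positions, which is all the formula guarantees. Nothing in it measures how the output would change if token j were removed.

```python
import numpy as np

def attention_weights(Q, K):
    """alpha_ij = softmax_j(q_i . k_j / sqrt(d_k)); Q: (n_queries, d_k), K: (n_keys, d_k)."""
    scores = Q @ K.T / np.sqrt(K.shape[1])
    scores = scores - scores.max(axis=1, keepdims=True)   # stabilize exp; softmax is unchanged
    e = np.exp(scores)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
alpha = attention_weights(rng.normal(size=(4, 8)), rng.normal(size=(6, 8)))
print(alpha.shape, np.allclose(alpha.sum(axis=1), 1.0))  # (4, 6) True
```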

Problem: Attention is NOT explanation

  • Attention weights don’t necessarily indicate importance

  • Multiple heads can have different patterns

  • Downstream layers complicate interpretation

7.2 Attention Rollout

Idea: Propagate attention through layers to see token-to-token flow

Recursive formula:

\[A_{\text{rollout}}^{(l)} = A^{(l)} \cdot A_{\text{rollout}}^{(l-1)}\]

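Where \(A^{(l)}\) is attention at layer l (averaged over heads). Implementations commonly account for residual connections by using \(\tfrac{1}{2}(A^{(l)} + I)\), renormalized so that each row sums to 1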

Initialization: \(A_{\text{rollout}}^{(0)} = I\) (identity)

Final attribution: Last layer’s rolled-out attention \(A_{\text{rollout}}^{(L)}\)
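Attention rollout can be sketched in NumPy on per-layer attention tensors of shape (heads, T, T), ordered from the first layer to the last. The optional residual step (0.5·A + 0.5·I, then row renormalization) is a common implementation choice rather than part of the formula as written. The random attention tensors are hypothetical. Products of row-stochastic matrices stay row-stochastic, which gives an easy check.

```python
import numpy as np

def attention_rollout(attentions, add_residual=True):
    """attentions: list of (heads, T, T) arrays, first layer first; returns A_rollout^(L)."""
    T = attentions[0].shape[-1]
    rollout = np.eye(T)                               # A_rollout^(0) = I
    for A in attentions:
        A = A.mean(axis=0)                            # average over heads
        if add_residual:
            A = 0.5 * A + 0.5 * np.eye(T)             # account for the skip connection
            A = A / A.sum(axis=-1, keepdims=True)     # keep rows summing to 1
        rollout = A @ rollout                         # A^(l) . A_rollout^(l-1)
    return rollout

rng = np.random.default_rng(0)
layers = []
for _ in range(3):                                    # hypothetical: 3 layers, 2 heads, 5 tokens
    A = rng.random((2, 5, 5))
    layers.append(A / A.sum(axis=-1, keepdims=True))  # row-stochastic, like softmax attention
R = attention_rollout(layers)
cls_to_patches = R[0, 1:]                             # flow from the first (CLS) token to the rest
print(np.allclose(R.sum(axis=-1), 1.0))  # True
```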

7.3 Attention Flow

Max flow through attention graph:

Compute maximum information flow from input to output through attention edges

Algorithm:

  1. Build graph: Nodes = tokens at each layer, edges = attention weights

  2. Run max-flow algorithm from source (input token) to sink (output)

  3. Flow value = attribution

8. Concept-Based Interpretability

8.1 TCAV (Testing with Concept Activation Vectors)

Motivation: Gradients/attention show low-level features, not high-level concepts

Idea: Test if model uses human-defined concepts (e.g., “stripes” for zebra)

Concept Activation Vector (CAV): Linear classifier separating concept examples from random examples in activation space

\[\text{CAV}_C = \text{train\_linear\_classifier}(\{\text{activations}(x): x \in \text{concept } C\}, \{\text{activations}(x): x \in \text{random}\})\]

TCAV Score: Fraction of examples where directional derivative along CAV is positive

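\[\text{TCAV}_{C,c} = \frac{1}{|X_c|} \sum_{x \in X_c} \mathbb{1}\left[\nabla f_c\big(h_l(x)\big)^T \text{CAV}_C > 0\right]\]

Where:

  • \(X_c\): Examples of class c

  • \(h_l\): Activations at layer l

  • \(f_c\): Class-c logit computed from the layer-l activations, so \(\nabla f_c(h_l(x))\) is the gradient of the class score with respect to those activations

Interpretation: TCAV = 0.8 means that for 80% of class-c examples, moving the layer-l activations toward concept C increases the class-c score, i.e. the class prediction is positively sensitive to the concept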
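The following sketch covers the two TCAV steps in NumPy and scikit-learn. It assumes the layer-l activations and the gradients of the class-c logit with respect to those activations have already been extracted. It uses logistic regression as the linear classifier, which is one reasonable choice. The activations and gradients are synthetic and hypothetical, built so that 8 of 10 examples point along the concept.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def compute_cav(concept_acts, random_acts):
    """Unit normal of a linear boundary separating concept from random activations (n, d)."""
    X = np.vstack([concept_acts, random_acts])
    y = np.r_[np.ones(len(concept_acts)), np.zeros(len(random_acts))]
    clf = LogisticRegression(max_iter=1000).fit(X, y)
    v = clf.coef_[0]                                  # points toward the concept class (label 1)
    return v / np.linalg.norm(v)

def tcav_score(grads, cav):
    """grads: (n, d) gradients of the class-c logit w.r.t. layer-l activations, one row per x in X_c."""
    return float(np.mean(grads @ cav > 0))            # fraction with positive directional derivative

rng = np.random.default_rng(0)
concept = rng.normal(size=(100, 5)) + np.array([3.0, 0, 0, 0, 0])   # synthetic concept activations
random_ = rng.normal(size=(100, 5))                                  # synthetic random activations
cav = compute_cav(concept, random_)
grads = np.zeros((10, 5))
grads[:8, 0] = 1.0    # 8 examples whose class score rises along the concept
grads[8:, 0] = -1.0   # 2 examples whose class score falls
print(tcav_score(grads, cav))  # 0.8
```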

8.2 Network Dissection

Goal: Identify what individual neurons detect

Approach:

  1. Collect semantic segmentation dataset (labels for concepts: sky, grass, car)

  2. For each neuron, compute IoU with each concept

  3. Assign neuron to highest-IoU concept (if IoU > threshold)

\[\text{IoU}(n, c) = \frac{|\{p: A_p^n > t_n\} \cap \{p: L_p = c\}|}{|\{p: A_p^n > t_n\} \cup \{p: L_p = c\}|}\]

Where:

  • \(A^n\): Activation map of neuron n

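  • \(t_n\): Threshold at a top quantile of the neuron's activations over the dataset (the original Network Dissection work keeps the top 0.5%, i.e. \(P(A^n > t_n) = 0.005\))

  • \(L_p = c\): Pixel p labeled as concept c

Results: Neurons specialize to detect objects, textures, parts (e.g., “dog face detector”)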
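The IoU score can be computed with a short NumPy sketch. It assumes the unit's activation maps have already been upsampled to the label resolution and stacked over the dataset as (N, H, W) alongside boolean concept masks. The default threshold keeps the top 0.5% of activations (quantile 0.995). The toy unit, which fires exactly on a 2-pixel concept, is hypothetical.

```python
import numpy as np

def dissection_iou(act_maps, concept_masks, quantile=0.995):
    """IoU between a unit's thresholded activations and one concept's pixel labels."""
    t_n = np.quantile(act_maps, quantile)             # dataset-wide top-quantile threshold
    unit_mask = act_maps > t_n
    inter = np.logical_and(unit_mask, concept_masks).sum()
    union = np.logical_or(unit_mask, concept_masks).sum()
    return inter / union if union > 0 else 0.0

mask = np.zeros((1, 20, 20), dtype=bool)
mask[0, 5, 5] = mask[0, 5, 6] = True                  # concept covers 2 of 400 pixels (0.5%)
acts = mask.astype(float)                             # hypothetical unit that fires on the concept
print(dissection_iou(acts, mask))  # 1.0
```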

9. Counterfactual Explanations

9.1 Definition

“What minimal change to input would flip the model’s prediction?”

\[x_{\text{cf}} = \arg\min_{x'} D(x, x') \quad \text{s.t.} \quad f(x') \neq f(x)\]

Where:

  • \(D(x, x')\): Distance metric (e.g., L2, L1, or feature distance)

  • Constraint: Prediction changes

9.2 Optimization Approach

Loss function:

\[\mathcal{L}(x') = \lambda_1 \cdot D(x, x') + \lambda_2 \cdot \max(0, f(x')_{\text{original class}} - \max_{c \neq \text{original}} f(x')_c + \kappa)\]

Where:

  • First term: Minimize distance to original

  • Second term: Push to different class (margin \(\kappa\))

Algorithm:

  1. Initialize: \(x' = x\)

  2. Gradient descent on \(\mathcal{L}(x')\)

  3. Project to valid input space (e.g., [0,1] for images)

  4. Stop when \(f(x') \neq f(x)\)
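A PyTorch sketch of the four-step algorithm follows. Its assumptions, all illustrative: D is the squared L2 distance, λ₂ is fixed at 1 (only λ₁ is exposed as `lam_dist`), the optimizer is Adam, and inputs live in [0, 1]. The two-class linear model is hypothetical and is set up so that lowering x₀ and raising x₁ flips the prediction.

```python
import torch
import torch.nn as nn

def counterfactual(model, x, lam_dist=0.1, kappa=0.1, lr=0.05, max_steps=500):
    """Gradient-based counterfactual for one input x of shape (1, D)."""
    model.eval()
    with torch.no_grad():
        orig = model(x).argmax(dim=1).item()
    x_cf = x.clone().requires_grad_(True)                 # 1. initialize x' = x
    opt = torch.optim.Adam([x_cf], lr=lr)
    for _ in range(max_steps):
        logits = model(x_cf)[0]
        if logits.argmax().item() != orig:                # 4. stop once the prediction flips
            break
        others = torch.cat([logits[:orig], logits[orig + 1:]])
        hinge = torch.clamp(logits[orig] - others.max() + kappa, min=0)
        loss = lam_dist * (x_cf - x).pow(2).sum() + hinge  # 2. distance + margin terms
        opt.zero_grad()
        loss.backward()
        opt.step()
        with torch.no_grad():
            x_cf.clamp_(0.0, 1.0)                         # 3. project to the valid input range
    return x_cf.detach(), orig

model = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[1.0, -1.0], [-1.0, 1.0]]))
x = torch.tensor([[0.8, 0.2]])
x_cf, orig = counterfactual(model, x)
print(orig, model(x_cf).argmax(dim=1).item())  # 0 1
```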

9.3 Diverse Counterfactuals

Generate multiple diverse counterfactuals:

\[\min \sum_i D(x, x_i') + \lambda \cdot \text{DiversityLoss}(\{x_i'\})\]

Diversity loss:

\[\text{DiversityLoss} = -\sum_{i \neq j} D(x_i', x_j')\]

Encourages counterfactuals to differ from each other

10. Faithfulness and Evaluation

10.1 Deletion/Insertion Metrics

Deletion: Remove features in order of importance, measure output drop

\[\text{Deletion}(k) = f(x) - f(x_{\text{top-k removed}})\]

Good attribution: Output drops quickly

Insertion: Add features in order of importance (from blank), measure output rise

\[\text{Insertion}(k) = f(x_{\text{top-k added}}) - f(\emptyset)\]

Good attribution: Output rises quickly

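Area Under Curve (AUC), taken over the fraction of features removed or added:

  • Deletion AUC: Higher is better for the drop curve \(\text{Deletion}(k)\) defined above (rapid drop); when the raw score \(f(x_{\text{top-k removed}})\) is plotted instead, as is common, lower is better

  • Insertion AUC: Higher is better (rapid rise)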
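Both curves can be computed with a NumPy sketch that returns the raw class-score curves, so the Deletion(k) drop defined above equals f(x) minus the deletion curve. Under this raw-score convention, lower deletion AUC and higher insertion AUC are better. The zero baseline, the trapezoid-rule AUC and the linear toy scorer are assumptions. An attribution that ranks features correctly should beat its reversed ranking on both metrics.

```python
import numpy as np

def deletion_insertion(predict, x, attribution, class_idx, baseline=None, n_steps=20):
    """Raw score curves for deletion (top-k set to baseline) and insertion (top-k copied from x)."""
    baseline = np.zeros_like(x) if baseline is None else baseline
    x_flat, b_flat = x.ravel(), baseline.ravel()
    order = np.argsort(-attribution.ravel())              # most important feature first
    ks = np.linspace(0, order.size, n_steps + 1).astype(int)
    deletion, insertion = [], []
    for k in ks:
        top = order[:k]
        x_del = x_flat.copy(); x_del[top] = b_flat[top]
        x_ins = b_flat.copy(); x_ins[top] = x_flat[top]
        deletion.append(predict(x_del.reshape(x.shape))[class_idx])
        insertion.append(predict(x_ins.reshape(x.shape))[class_idx])
    frac = ks / order.size                                 # fraction of features removed/added

    def auc(curve):                                        # trapezoid rule over frac
        return float(np.sum((curve[1:] + curve[:-1]) / 2 * np.diff(frac)))

    deletion, insertion = np.array(deletion), np.array(insertion)
    return deletion, insertion, auc(deletion), auc(insertion)

w = np.arange(10, 0, -1, dtype=float)                     # hypothetical linear scorer

def predict(z):
    return np.array([z @ w])

x = np.ones(10)
del_curve, _, del_good, ins_good = deletion_insertion(predict, x, w, 0, n_steps=10)
_, _, del_bad, ins_bad = deletion_insertion(predict, x, -w, 0, n_steps=10)
print(del_good < del_bad, ins_good > ins_bad)  # True True
```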

10.2 Sensitivity-n

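Definition: Maximum change in the model output under a perturbation of at most n pixels

\[\text{Sens}_n = \max_{S: |S| \leq n} |f(x) - f(x_S)|\]

Where \(x_S\) has pixels in set S perturbed

For the explanation itself, the analogous stability measure (max-sensitivity) compares attributions instead of outputs, \(\max_{\|\delta\| \leq r} \|\phi(x + \delta) - \phi(x)\|\). Lower is better: attribution should not change drastically with small input changes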

10.3 Infidelity

Measures: How well attribution approximates actual model behavior

\[\text{Infidelity} = \mathbb{E}_I \left[(I^T \phi(x) - (f(x) - f(x - I)))^2\right]\]

Where:

  • \(I\): Random perturbation

  • \(\phi(x)\): Attribution

  • Expectation over perturbations

Lower is better: Predictions from attribution should match actual perturbations
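The infidelity expectation can be estimated by Monte Carlo with a NumPy sketch. The perturbation distribution is an assumption (Gaussian noise with a fixed std), and the linear model is hypothetical. For a linear model the gradient predicts every perturbation exactly, so its infidelity is zero, while gradient × input generally does not.

```python
import numpy as np

def infidelity(f, x, phi, n_samples=1000, noise_std=0.1, seed=0):
    """Monte-Carlo estimate of E_I[(I^T phi - (f(x) - f(x - I)))^2] with Gaussian I."""
    rng = np.random.default_rng(seed)
    fx = f(x)
    total = 0.0
    for _ in range(n_samples):
        I = noise_std * rng.normal(size=x.shape)          # random perturbation I
        total += (I.ravel() @ phi.ravel() - (fx - f(x - I))) ** 2
    return total / n_samples

w = np.array([2.0, -1.0, 0.5])

def f(z):  # hypothetical linear model
    return float(z @ w)

x = np.array([1.0, 3.0, -2.0])
print(infidelity(f, x, phi=w) < 1e-12)   # True: the gradient is exact for a linear model
print(infidelity(f, x, phi=w * x) > 0)   # True: gradient x input mispredicts these perturbations
```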

11. Practical Considerations

11.1 Method Selection Guide

| Goal | Method | Pros | Cons |
|---|---|---|---|
| Quick visualization | Grad-CAM | Fast, class-specific | Coarse resolution |
| Pixel-level attribution | IG, SmoothGrad | Detailed | Baseline-dependent |
| Model-agnostic | LIME, SHAP | Any model | Slow, approximations |
| Conceptual | TCAV | Human concepts | Needs concept datasets |
| Debugging | Activation max, Dissection | Direct neuron insight | Requires analysis |

11.2 Common Pitfalls

  1. Confusing correlation with causation: Attribution ≠ causal importance

  2. Ignoring baseline choice: IG results depend heavily on baseline

  3. Over-interpreting attention: Attention weights ≠ feature importance

  4. Not validating: Always check faithfulness metrics

  5. Cherry-picking examples: Show failures, not just successes

11.3 Hyperparameters

Integrated Gradients:

  • Steps: 50-100 (more = smoother but slower)

  • Baseline: Zero, mean, blur, or black image

LIME:

  • Samples: 1000-5000

  • Kernel width: 0.25 × \(\sqrt{n\_features}\)

  • Superpixels (images): 50-200

SHAP:

  • Background samples: 50-100 (DeepSHAP)

  • Max evaluations: 2000 (KernelSHAP)

SmoothGrad:

  • Noise std: 0.1-0.2 × (max - min)

  • Samples: 20-50

12. Advanced Topics and Research Frontiers

12.1 Robustness of Explanations

Problem: Explanations can be fragile to adversarial attacks

Adversarial explanation attack:

\[x' = x + \delta \quad \text{s.t.} \quad f(x') = f(x), \quad \phi(x') \neq \phi(x)\]

Same prediction, different explanation → undermines trust

Defense: Robust attribution methods (e.g., robust integrated gradients)

12.2 Multi-Modal Explanations

For vision-language models (CLIP, GPT-4V):

  • Cross-modal attention: How image regions align with text tokens

  • Grad-CAM for vision encoder + attention weights for text

  • Integrated multimodal gradients

12.3 Explanations for Generated Outputs

Diffusion models: Which parts of the noise contribute to the final image?

LLMs: Token-level attribution for generated text

Challenge: Long generation chains complicate credit assignment

12.4 Interactive Explanations

What-if analysis: User perturbs input, sees effect on prediction and explanation

Counterfactual exploration: Navigate alternative scenarios

Concept activation: Manually adjust concept importance

Key Papers and Timeline

Foundation (2013-2016):

  • Simonyan et al. 2013: Saliency Maps - First gradient-based visualization

  • Zhou et al. 2016: CAM - Class activation mapping with GAP

  • Ribeiro et al. 2016: LIME - Local interpretable models

Gradient Methods (2017-2018):

  • Sundararajan et al. 2017: Integrated Gradients - Axioms and path integration

  • Selvaraju et al. 2017: Grad-CAM - Gradient-weighted activation maps

  • Smilkov et al. 2017: SmoothGrad - Noise-based gradient smoothing

Game Theory (2017-2020):

  • Lundberg & Lee 2017: SHAP - Shapley values for ML

  • Chattopadhyay et al. 2018: Grad-CAM++ - Improved localization

  • Wang et al. 2020: Score-CAM - Gradient-free activation maps

Concepts & Validation (2018-2022):

  • Kim et al. 2018: TCAV - Concept activation vectors

  • Bau et al. 2017: Network Dissection - Neuron semantic analysis

  • Hooker et al. 2019: ROAR - Remove-and-retrain faithfulness benchmark (the Deletion/Insertion metrics were introduced by Petsiuk et al. 2018 with RISE)

Recent (2020-2024):

  • Abnar & Zuidema 2020: Attention Rollout and Attention Flow

  • Chefer et al. 2021: Transformer Interpretability - Relevance propagation beyond attention visualization

  • Achtibat et al. 2023: Attribution Robustness - Adversarial explanations

  • Diverse counterfactuals and causal interpretability

Computational Complexity:

  • Grad-CAM: \(O(1)\) (single backward pass)

  • Integrated Gradients: \(O(m)\) where m = steps

  • LIME: \(O(n \cdot k)\) where n = perturbed samples, k = cost of one model evaluation (plus fitting the surrogate)

  • SHAP: \(O(2^d)\) exact, \(O(d^2)\) approximation (d = features)

"""
Advanced Interpretability Implementations

This cell provides reference implementations of:
1. SHAP (KernelSHAP, DeepSHAP approximation)
2. LIME for images
3. Attention visualization (rollout, flow)
4. Concept Activation Vectors (TCAV)
5. Counterfactual generation
6. Faithfulness evaluation metrics
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.svm import LinearSVC
from skimage.segmentation import slic
from scipy.special import comb
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# SHAP Implementation
# ============================================================================

class KernelSHAP:
    """
    Model-agnostic SHAP using weighted linear regression
    
    Theory:
    - Approximates Shapley values by fitting weighted linear model
    - SHAP kernel weight: π(z) = (M-1) / (C(M,|z|) * |z| * (M-|z|))
    - Where M = number of features, |z| = number of present features
    """
    
    def __init__(self, model, background_data, num_samples=1000):
        """
        Args:
            model: Prediction function (input -> output)
            background_data: Reference samples (N, D) for creating coalitions
            num_samples: Number of coalitional samples to generate
        """
        self.model = model
        self.background = background_data
        self.num_samples = num_samples
        self.M = background_data.shape[1]  # Number of features
        
    def kernel_weight(self, z):
        """SHAP kernel weight for coalition z (binary vector)"""
        num_present = z.sum()
        if num_present == 0 or num_present == self.M:
            return 1e10  # High weight for empty/full coalitions
        return (self.M - 1) / (comb(self.M, num_present) * num_present * (self.M - num_present))
    
    def explain(self, x, verbose=False):
        """
        Compute SHAP values for input x
        
        Returns:
            shap_values: (D,) array of feature attributions
        """
        # Sample coalitions (binary vectors indicating feature presence)
        coalitions = np.random.binomial(1, 0.5, size=(self.num_samples, self.M))
        
        # Ensure we have empty and full coalitions
        coalitions[0] = np.zeros(self.M)
        coalitions[1] = np.ones(self.M)
        
        # Evaluate the model on masked samples: features in the coalition come from x,
        # the others from each background row; predictions are averaged over the background.
        # Assumes the model returns one score per sample.
        x = np.asarray(x, dtype=np.float32).reshape(1, self.M)
        predictions = []
        with torch.no_grad():
            for z in coalitions:
                samples = np.where(z[None, :] == 1, x, self.background)  # (N, M)
                preds = self.model(torch.tensor(samples, dtype=torch.float32))
                predictions.append(preds.mean().item())
        predictions = np.array(predictions)
        
        # Compute kernel weights
        weights = np.array([self.kernel_weight(z) for z in coalitions])
        
        # Weighted linear regression: f(z) ≈ φ₀ + Σ φᵢ zᵢ
        model = Ridge(alpha=0.01)
        model.fit(coalitions, predictions, sample_weight=weights)
        
        # SHAP values are the coefficients
        shap_values = model.coef_
        
        if verbose:
            # coalitions[1] is the full coalition (f(x)); coalitions[0] is the empty one (E[f])
            print(f"Expected value (intercept): {model.intercept_:.4f}")
            print(f"Prediction f(x): {predictions[1]:.4f}")
            print(f"Sum of SHAP values: {shap_values.sum():.4f}")
            print(f"f(x) - E[f] (should match): {predictions[1] - predictions[0]:.4f}")
        
        return shap_values


class DeepSHAP:
    """
    DeepSHAP: Efficient SHAP approximation for neural networks
    
    Theory:
    - Uses backpropagation-like approach to compute approximate Shapley values
    - Compares activation with reference (background) activation
    - Distributes attributions through network layers
    """
    
    def __init__(self, model, background_data):
        """
        Args:
            model: PyTorch neural network
            background_data: Reference samples (N, C, H, W)
        """
        self.model = model
        self.background = background_data
        self.model.eval()
        
        # Compute reference activations
        with torch.no_grad():
            self.ref_output = model(background_data).mean(0)
    
    def explain(self, x, target_class=None):
        """
        Compute DeepSHAP-style attributions
        
        Approximation: uses Integrated Gradients, with the mean background sample
        as baseline, as a proxy for Shapley values. IG satisfies the efficiency
        (completeness) axiom, as SHAP does.
        """
        x = x.detach()
        
        # Compute average baseline
        baseline = self.background.mean(0, keepdim=True).to(x.device)
        
        if target_class is None:
            with torch.no_grad():
                target_class = self.model(x).argmax(1).item()
        
        # Integrated Gradients: Riemann sum over alpha = k/m, k = 1..m
        num_steps = 50
        total_grad = torch.zeros_like(x)
        
        for k in range(1, num_steps + 1):
            alpha = k / num_steps
            # Detach so x_interp is a leaf tensor we can differentiate with respect to
            x_interp = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
            output = self.model(x_interp)
            grad, = torch.autograd.grad(output[0, target_class], x_interp)
            total_grad += grad
        
        attributions = (x - baseline) * total_grad / num_steps
        
        return attributions


# ============================================================================
# LIME Implementation
# ============================================================================

class LIME:
    """
    Local Interpretable Model-agnostic Explanations
    
    Theory:
    - Approximates complex model f locally with interpretable model g
    - Samples around instance x, weights by proximity
    - Fits linear model: g(z) = w₀ + Σ wᵢ zᵢ
    """
    
    def __init__(self, model, kernel_width=0.25):
        """
        Args:
            model: Prediction function
            kernel_width: Controls locality (smaller = more local)
        """
        self.model = model
        self.kernel_width = kernel_width
    
    def explain_image(self, image, num_samples=1000, num_features=10, num_superpixels=100):
        """
        Explain image classification using superpixels
        
        Args:
            image: Input image (C, H, W) tensor
            num_samples: Number of perturbations
            num_features: Number of top features to return
            num_superpixels: Number of image segments
            
        Returns:
            weights: Importance of each superpixel
            segments: Superpixel segmentation
        """
        # Convert to numpy for segmentation
        img_np = image.permute(1, 2, 0).cpu().numpy()
        img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
        
        # Segment image into superpixels using SLIC
        segments = slic(img_np, n_segments=num_superpixels, compactness=10, sigma=1, start_label=0)
        num_segments = segments.max() + 1
        segments_t = torch.from_numpy(segments).to(image.device)
        
        # Explain the class the model predicts for the unperturbed image
        with torch.no_grad():
            pred_class = self.model(image.unsqueeze(0)).argmax(1).item()
        
        # Sample perturbations (binary vectors indicating superpixel presence)
        perturbations = np.random.binomial(1, 0.5, size=(num_samples, num_segments))
        
        predictions = []
        distances = []
        
        for pert in perturbations:
            # Zero out (blacken) every superpixel that is switched off
            off = torch.from_numpy(np.flatnonzero(pert == 0)).to(image.device)
            mask = torch.isin(segments_t, off)
            perturbed = image.clone()
            perturbed[:, mask] = 0
            
            # Probability of the predicted class on the perturbed image
            with torch.no_grad():
                prob = self.model(perturbed.unsqueeze(0)).softmax(1)[0, pred_class]
            predictions.append(prob.item())
            
            # Distance: fraction of superpixels switched off
            distances.append(np.mean(pert == 0))
        
        predictions = np.array(predictions)
        distances = np.array(distances)
        
        # Compute kernel weights (exponential kernel)
        kernel_width = self.kernel_width * np.sqrt(num_segments)
        weights = np.exp(-(distances ** 2) / (kernel_width ** 2))
        
        # Weighted ridge regression for the predicted class only
        surrogate = Ridge(alpha=1.0)
        surrogate.fit(perturbations, predictions, sample_weight=weights)
        superpixel_weights = surrogate.coef_
        
        return superpixel_weights, segments
    
    def visualize_explanation(self, image, weights, segments, num_features=5):
        """Visualize top positive and negative superpixels"""
        # Get top positive features
        top_positive = np.argsort(weights)[-num_features:]
        
        # Create mask
        mask = np.zeros(segments.shape, dtype=bool)
        for idx in top_positive:
            mask |= (segments == idx)
        
        # Visualize
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        # Original image
        img_np = image.permute(1, 2, 0).cpu().numpy()
        img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
        axes[0].imshow(img_np)
        axes[0].set_title("Original Image")
        axes[0].axis('off')
        
        # Superpixel boundaries
        from skimage.segmentation import mark_boundaries
        axes[1].imshow(mark_boundaries(img_np, segments))
        axes[1].set_title(f"Superpixels (n={len(np.unique(segments))})")
        axes[1].axis('off')
        
        # Highlighted regions
        highlighted = img_np.copy()
        highlighted[~mask] = highlighted[~mask] * 0.3  # Dim non-important regions
        axes[2].imshow(highlighted)
        axes[2].set_title(f"Top {num_features} Important Regions")
        axes[2].axis('off')
        
        plt.tight_layout()
        return fig


# ============================================================================
# Attention Visualization
# ============================================================================

class AttentionRollout:
    """
    Attention Rollout for Vision Transformers
    
    Theory:
    - Recursively multiply attention matrices through layers
    - A_rollout^(l) = A^(l) @ A_rollout^(l-1)
    - Captures token-to-token information flow
    """
    
    def __init__(self, model, head_fusion='mean', discard_ratio=0.1):
        """
        Args:
            model: Vision Transformer with attention weights
            head_fusion: How to combine multi-head attention ('mean', 'max', 'min')
            discard_ratio: Fraction of lowest attention to discard
        """
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.attentions = []
        
        # Register hooks to capture attention weights
        self._register_hooks()
    
    def _register_hooks(self):
        """Register forward hooks to capture attention matrices"""
        def hook_fn(module, input, output):
            # Assuming attention is returned as second element
            if isinstance(output, tuple) and len(output) > 1:
                self.attentions.append(output[1])
        
        # This is model-specific; adapt to your architecture
        # For ViT: hook into each transformer block's attention module
        for module in self.model.modules():
            if hasattr(module, 'attn') or 'attention' in module.__class__.__name__.lower():
                module.register_forward_hook(hook_fn)
    
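    def compute_rollout(self, input_tensor):
        """
        Compute attention rollout
        
        Returns:
            rollout_cls: (num_tokens - 1,) attention flow from the CLS token to each patch
        """
        self.attentions = []
        
        # Forward pass to collect attentions
        with torch.no_grad():
            _ = self.model(input_tensor)
        if not self.attentions:
            raise RuntimeError("No attention maps captured; adapt _register_hooks to your model")
        
        # Fuse multi-head attention
        fused_attentions = []
        for attn in self.attentions:
            # attn shape: (batch, heads, tokens, tokens)
            if self.head_fusion == 'mean':
                attn_fused = attn.mean(dim=1)
            elif self.head_fusion == 'max':
                attn_fused = attn.max(dim=1)[0]
            elif self.head_fusion == 'min':
                attn_fused = attn.min(dim=1)[0]
            else:
                raise ValueError(f"Unknown fusion: {self.head_fusion}")
            
            # Discard the lowest attention weights (per sample) to reduce noise
            if self.discard_ratio > 0:
                flat = attn_fused.reshape(attn_fused.size(0), -1).clone()
                k = int(flat.size(-1) * self.discard_ratio)
                _, idx = flat.topk(k, dim=-1, largest=False)
                flat.scatter_(-1, idx, 0.0)
                attn_fused = flat.view_as(attn_fused)
            
            fused_attentions.append(attn_fused)
        
        # Rollout: multiply attention matrices
        num_tokens = fused_attentions[0].size(-1)
        rollout = torch.eye(num_tokens, device=input_tensor.device)
        
        for attn in fused_attentions:
            # Add residual connection and re-normalize rows
            attn = attn + torch.eye(num_tokens, device=attn.device)
            attn = attn / attn.sum(dim=-1, keepdim=True)
            
            # A_rollout^(l) = A^(l) @ A_rollout^(l-1), for the first sample in the batch
            rollout = torch.matmul(attn[0], rollout)
        
        # Focus on CLS token (first token) to all patches
        rollout_cls = rollout[0, 1:]  # Exclude CLS itself
        
        return rollout_cls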


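The rollout recursion in the AttentionRollout docstring can be sanity-checked without a real Vision Transformer. The sketch below is a minimal version under stated assumptions: random row-stochastic matrices stand in for hooked ViT attention of shape (heads, tokens, tokens), heads are fused by a plain mean, nothing is discarded, and the residual path is modelled as 0.5·A + 0.5·I. rollout_from_attentions is a hypothetical helper name, not a library function.

```python
import torch


def rollout_from_attentions(attentions):
    """Attention rollout over a list of (heads, tokens, tokens) attention tensors.

    Hypothetical helper: heads are averaged and no low-attention entries are discarded.
    """
    num_tokens = attentions[0].size(-1)
    eye = torch.eye(num_tokens)
    rollout = eye.clone()
    for attn in attentions:
        fused = attn.mean(dim=0)                       # average over heads
        fused = 0.5 * fused + 0.5 * eye                # account for the residual path
        fused = fused / fused.sum(dim=-1, keepdim=True)
        rollout = fused @ rollout                      # A_rollout^(l) = A^(l) @ A_rollout^(l-1)
    return rollout


torch.manual_seed(0)
# Synthetic attention: 4 layers, 3 heads, 1 CLS token + 9 patches (softmax rows sum to 1)
layers = [torch.softmax(torch.randn(3, 10, 10), dim=-1) for _ in range(4)]
rollout = rollout_from_attentions(layers)
cls_to_patches = rollout[0, 1:]                        # CLS row, excluding CLS itself
print(cls_to_patches.shape)  # torch.Size([9])
```

Because each factor is row-stochastic, every row of the product also sums to 1, so the CLS row can be read as a distribution of attention flow over the patches; for a ViT it would be reshaped to the patch grid and upsampled to image size.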
# ============================================================================
# Concept Activation Vectors (TCAV)
# ============================================================================

from sklearn.svm import LinearSVC  # linear classifier used to fit the CAV


class TCAV:
    """
    Testing with Concept Activation Vectors
    
    Theory:
    - Train linear classifier to separate concept from random examples
    - CAV = normal vector to decision boundary
    - TCAV score = fraction of class examples where ∇h_k·CAV > 0
      (∇h_k: gradient of the class-k logit w.r.t. the layer activations)
    """
    
    def __init__(self, model, layer_name):
        """
        Args:
            model: Neural network
            layer_name: Name of layer to extract activations from
        """
        self.model = model
        self.layer_name = layer_name
        self.activations = None
        
        # Register hook
        self._register_hook()
    
    def _register_hook(self):
        """Capture activations at specified layer"""
        def hook_fn(module, input, output):
            # Keep the autograd graph (no detach): compute_tcav_score needs the
            # gradient of the class logit w.r.t. these activations
            self.activations = output
        
        # Find layer by name
        for name, module in self.model.named_modules():
            if name == self.layer_name:
                module.register_forward_hook(hook_fn)
                return
        
        raise ValueError(f"Layer {self.layer_name} not found")
    
    def _device(self):
        """Device holding the model parameters"""
        return next(self.model.parameters()).device
    
    def get_activations(self, images):
        """Extract activations for images (global-average-pooled if spatial)"""
        activations = []
        device = self._device()
        
        with torch.no_grad():
            for img in images:
                _ = self.model(img.unsqueeze(0).to(device))
                act = self.activations.detach()
                # Global average pool if spatial
                if act.dim() == 4:  # (B, C, H, W)
                    act = act.mean(dim=[2, 3])
                activations.append(act.cpu().numpy().flatten())
        
        return np.array(activations)
    
    def train_cav(self, concept_images, random_images):
        """
        Train Concept Activation Vector
        
        Args:
            concept_images: Images containing concept (e.g., "stripes")
            random_images: Random images (negative examples)
            
        Returns:
            cav: Concept activation vector (unit normal to decision boundary)
            accuracy: Training accuracy of the linear classifier
        """
        # Get activations
        concept_acts = self.get_activations(concept_images)
        random_acts = self.get_activations(random_images)
        
        # Prepare data
        X = np.vstack([concept_acts, random_acts])
        y = np.array([1] * len(concept_acts) + [0] * len(random_acts))
        
        # Train linear classifier
        clf = LinearSVC(C=1.0, max_iter=5000)
        clf.fit(X, y)
        
        # CAV is the normal vector (coefficients), pointing toward the concept class
        cav = clf.coef_[0]
        cav = cav / np.linalg.norm(cav)  # Normalize
        
        return cav, clf.score(X, y)
    
    def compute_tcav_score(self, class_images, cav, target_class):
        """
        Compute TCAV score: fraction of class examples whose class-logit gradient
        (w.r.t. the layer activations) has a positive dot product with the CAV
        
        Args:
            class_images: Images of target class
            cav: Concept activation vector (from train_cav)
            target_class: Index k of the class logit h_k to differentiate
            
        Returns:
            tcav_score: Value between 0 and 1
        """
        device = self._device()
        cav_tensor = torch.tensor(cav, dtype=torch.float32, device=device)
        
        alignments = []
        
        for img in class_images:
            # Forward pass; the hook stores activations that are still in the graph
            logits = self.model(img.unsqueeze(0).to(device))
            act = self.activations
            
            # Gradient of the class logit w.r.t. the layer activations
            grad = torch.autograd.grad(logits[0, target_class], act)[0]
            
            # The CAV lives in spatially pooled space: shifting every spatial position
            # of channel c by v_c shifts the pooled activation by v, so sum over H, W
            if grad.dim() == 4:
                grad = grad.sum(dim=[2, 3])
            
            # Directional derivative S = ∇h_k · CAV
            directional_derivative = (grad.flatten() * cav_tensor).sum()
            
            # Positive: moving toward the concept increases the class logit
            alignments.append(1 if directional_derivative.item() > 0 else 0)
        
        tcav_score = np.mean(alignments)
        return tcav_score


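The TCAV score depends only on the sign of the directional derivative ∇h_k·v_C. The sketch below checks that logic on synthetic data, assuming a toy linear class head h(a) = u·a over 8-dimensional activations (tcav_score and linear_head are hypothetical names): the gradient of a linear head is u for every example, so a CAV aligned with u should score 1.0 and the opposite CAV 0.0.

```python
import torch


def tcav_score(head, activations, cav):
    """Fraction of examples whose class-logit gradient has a positive dot product with cav.

    head:        maps activations (N, D) -> class logits (N,)
    activations: layer activations for examples of the target class, shape (N, D)
    cav:         concept activation vector, shape (D,)
    """
    acts = activations.clone().requires_grad_(True)
    logits = head(acts)
    # Each logit depends only on its own row, so the gradient of the sum is per-example
    grads, = torch.autograd.grad(logits.sum(), acts)
    directional = grads @ cav                      # ∇h · v_C for every example
    return (directional > 0).float().mean().item()


torch.manual_seed(0)
u = torch.randn(8)


def linear_head(a):
    """Toy linear class head h(a) = u · a (hypothetical)."""
    return a @ u


acts = torch.randn(32, 8)
cav = u / u.norm()

print(tcav_score(linear_head, acts, cav))    # 1.0
print(tcav_score(linear_head, acts, -cav))   # 0.0
```

With a real network, the head would be the part of the model above the chosen layer. The TCAV paper also compares scores against CAVs trained on several random concept sets to test statistical significance, which this sketch omits.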
# ============================================================================
# Counterfactual Generation
# ============================================================================

class CounterfactualGenerator:
    """
    Generate minimal perturbations that flip model prediction
    
    Theory:
    - Minimize: λ₁·||δ|| + λ₂·max(0, f(x+δ)_orig - max_{c≠orig} f(x+δ)_c + κ)
    - First term: Keep change small
    - Second term: Push to different class
    """
    
    def __init__(self, model, lambda_dist=1.0, lambda_class=10.0, kappa=0.0):
        """
        Args:
            model: Classification model
            lambda_dist: Weight for distance term
            lambda_class: Weight for classification term
            kappa: Confidence margin
        """
        self.model = model
        self.lambda_dist = lambda_dist
        self.lambda_class = lambda_class
        self.kappa = kappa
    
    def generate(self, x, target_class=None, num_steps=500, lr=0.01):
        """
        Generate counterfactual
        
        Args:
            x: Original input (C, H, W)
            target_class: Target class (None = any different class)
            num_steps: Optimization steps
            lr: Learning rate
            
        Returns:
            x_cf: Counterfactual input
            history: Loss history
        """
        # Get original prediction
        with torch.no_grad():
            orig_pred = self.model(x.unsqueeze(0)).argmax(1).item()
        
        # Initialize perturbation
        delta = torch.zeros_like(x, requires_grad=True)
        optimizer = torch.optim.Adam([delta], lr=lr)
        
        history = {'loss': [], 'dist': [], 'class_loss': []}
        
        for step in range(num_steps):
            optimizer.zero_grad()
            
            # Perturbed input (clip to valid range)
            x_pert = torch.clamp(x + delta, 0, 1)
            
            # Get predictions
            logits = self.model(x_pert.unsqueeze(0))[0]
            
            # Stop as soon as the current perturbation flips the prediction, before the
            # optimizer moves delta again (so the returned x_cf is the verified one)
            new_pred = logits.argmax().item()
            if new_pred != orig_pred and (target_class is None or new_pred == target_class):
                print(f"Counterfactual found at step {step}")
                print(f"Original class: {orig_pred}, New class: {new_pred}")
                print(f"L2 distance: {torch.norm(x_pert - x).item():.4f}")
                break
            
            # Distance term (L2)
            dist_loss = torch.norm(delta)
            
            # Classification term (hinge with margin kappa)
            if target_class is not None:
                # Targeted: want logits[target] > max_{c != target} logits[c] + kappa
                target_score = logits[target_class]
                other_scores = torch.cat([logits[:target_class], logits[target_class+1:]])
                class_loss = F.relu(other_scores.max() - target_score + self.kappa)
            else:
                # Untargeted: want logits[orig] < max_{c != orig} logits[c] - kappa
                orig_score = logits[orig_pred]
                other_scores = torch.cat([logits[:orig_pred], logits[orig_pred+1:]])
                class_loss = F.relu(orig_score - other_scores.max() + self.kappa)
            
            # Combined loss
            loss = self.lambda_dist * dist_loss + self.lambda_class * class_loss
            
            loss.backward()
            optimizer.step()
            
            # Record
            history['loss'].append(loss.item())
            history['dist'].append(dist_loss.item())
            history['class_loss'].append(class_loss.item())
        
        x_cf = torch.clamp(x + delta, 0, 1).detach()
        
        return x_cf, history


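For a linear classifier the minimal-L2 counterfactual has a closed form, which gives a useful reference point for the optimizer above: the closest point on the decision boundary w·x + b = 0 is reached by moving along −w by a distance |w·x + b| / ||w||. The sketch assumes a binary linear model with made-up weights; linear_counterfactual and its overshoot margin are hypothetical names for illustration.

```python
import numpy as np


def linear_counterfactual(w, b, x, overshoot=1e-3):
    """Smallest L2 change that flips sign(w @ x + b), plus a small overshoot margin."""
    margin = w @ x + b
    delta = -(margin / (w @ w)) * w              # projection onto the hyperplane w @ x + b = 0
    return x + (1 + overshoot) * delta, delta


w = np.array([2.0, -1.0, 0.5])
b = -0.5
x = np.array([1.0, 0.2, 0.4])                    # w @ x + b = 2.0 - 0.5 = 1.5 > 0

x_cf, delta = linear_counterfactual(w, b, x)
print(np.sign(w @ x + b), np.sign(w @ x_cf + b))  # 1.0 -1.0
print(np.linalg.norm(delta), abs(w @ x + b) / np.linalg.norm(w))
```

No such closed form exists for a deep network, which is why CounterfactualGenerator optimizes a distance term plus a hinge term instead; the projection above is a handy reference when testing that optimizer on a linear model.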
# ============================================================================
# Faithfulness Metrics
# ============================================================================

class FaithfulnessEvaluator:
    """
    Evaluate attribution faithfulness using deletion/insertion metrics
    
    Theory:
    - Deletion: Remove features by importance, measure output drop
    - Insertion: Add features by importance, measure output rise
    - Good attribution → deletion curve falls quickly, insertion curve rises quickly
    """
    
    def __init__(self, model):
        self.model = model
    
    def deletion_metric(self, x, attribution, target_class, num_steps=20):
        """
        Deletion curve: Remove top features, track output
        
        Args:
            x: Input image (C, H, W)
            attribution: Pixel attributions (C, H, W)
            target_class: Class to track
            num_steps: Number of deletion steps
            
        Returns:
            scores: Output scores after each deletion step
        """
        device = next(self.model.parameters()).device
        
        # Flatten and sort by importance (most important first)
        indices = torch.argsort(attribution.flatten().to(device), descending=True)
        
        # Work on a copy: x.flatten() can be a view, and editing it would
        # overwrite the caller's input
        x_current = x.detach().to(device).clone()
        x_flat = x_current.view(-1)
        
        scores = []
        # Step boundaries cover every pixel, including the remainder of len(indices) / num_steps
        bounds = [int(b) for b in np.linspace(0, len(indices), num_steps + 1)]
        
        for step in range(num_steps + 1):
            # Get current output
            with torch.no_grad():
                output = self.model(x_current.unsqueeze(0))
                scores.append(output[0, target_class].item())
            
            # Remove next batch of pixels (set to 0 in place through the view)
            if step < num_steps:
                remove_indices = indices[bounds[step]:bounds[step + 1]]
                x_flat[remove_indices] = 0
        
        return np.array(scores)
    
    def insertion_metric(self, x, attribution, target_class, num_steps=20):
        """
        Insertion curve: Add top features to blank image, track output
        """
        device = next(self.model.parameters()).device
        x = x.detach().to(device)
        
        # Start with blank (all-zero) image
        x_current = torch.zeros_like(x)
        x_current_flat = x_current.view(-1)
        x_flat = x.reshape(-1)
        
        # Sort features by importance (most important first)
        indices = torch.argsort(attribution.flatten().to(device), descending=True)
        
        scores = []
        # Step boundaries cover every pixel, including the remainder of len(indices) / num_steps
        bounds = [int(b) for b in np.linspace(0, len(indices), num_steps + 1)]
        
        for step in range(num_steps + 1):
            # Get current output
            with torch.no_grad():
                output = self.model(x_current.unsqueeze(0))
                scores.append(output[0, target_class].item())
            
            # Add next batch of pixels from the original image
            if step < num_steps:
                add_indices = indices[bounds[step]:bounds[step + 1]]
                x_current_flat[add_indices] = x_flat[add_indices]
        
        return np.array(scores)
    
    def plot_curves(self, deletion_scores, insertion_scores):
        """Plot deletion and insertion curves"""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        # x-axis: fraction of features removed/added, matching the axis labels below
        steps = np.linspace(0, 1, len(deletion_scores))
        
        # Deletion
        ax1.plot(steps, deletion_scores, 'r-', linewidth=2)
        ax1.fill_between(steps, deletion_scores, alpha=0.3, color='red')
        ax1.set_xlabel('Fraction of Features Removed')
        ax1.set_ylabel('Model Output')
        ax1.set_title('Deletion Curve (↓ faster is better)')
        ax1.grid(True, alpha=0.3)
        
        # Insertion
        ax2.plot(steps, insertion_scores, 'g-', linewidth=2)
        ax2.fill_between(steps, insertion_scores, alpha=0.3, color='green')
        ax2.set_xlabel('Fraction of Features Added')
        ax2.set_ylabel('Model Output')
        ax2.set_title('Insertion Curve (↑ faster is better)')
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        return fig


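A sanity check for the deletion metric: on a linear model f(x) = w·x with non-negative contributions, the attribution w ⊙ x decomposes the output exactly, so deleting features in that order should give the lowest possible deletion AUC and deleting them in reverse order the highest. The sketch below uses synthetic numbers, hypothetical helper names (deletion_curve, normalized_auc), and a trapezoidal AUC over the fraction of features removed.

```python
import numpy as np


def deletion_curve(w, x, order):
    """Model output w @ x after zeroing features one at a time in the given order."""
    x_cur = x.copy()
    scores = [w @ x_cur]
    for idx in order:
        x_cur[idx] = 0.0
        scores.append(w @ x_cur)
    return np.array(scores)


def normalized_auc(scores):
    """Trapezoidal area under the curve with the x-axis scaled to [0, 1]."""
    return float(((scores[:-1] + scores[1:]) / 2).mean())


rng = np.random.default_rng(0)
w = rng.uniform(0.1, 1.0, size=20)
x = rng.uniform(0.0, 1.0, size=20)
attribution = w * x                               # exact per-feature contribution to w @ x

best_first = np.argsort(attribution)[::-1]        # most important feature first
worst_first = best_first[::-1]

auc_good = normalized_auc(deletion_curve(w, x, best_first))
auc_bad = normalized_auc(deletion_curve(w, x, worst_first))
print(auc_good < auc_bad)  # True
```

On real images the exact contribution of each pixel is unknown, so in practice an attribution's deletion and insertion AUCs are usually compared with those of a random ordering.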
# ============================================================================
# Demonstration
# ============================================================================

print("Advanced Interpretability Methods Implemented:")
print("=" * 60)
print("1. KernelSHAP - Model-agnostic Shapley value approximation")
print("2. DeepSHAP - Efficient SHAP for neural networks")
print("3. LIME - Local interpretable model-agnostic explanations")
print("4. AttentionRollout - Attention flow in transformers")
print("5. TCAV - Testing with concept activation vectors")
print("6. CounterfactualGenerator - Minimal prediction-flipping perturbations")
print("7. FaithfulnessEvaluator - Deletion/insertion metrics")
print("=" * 60)

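# Example: LIME on the trained ConvNet. This cell assumes `model` and `test_loader`
# already exist; they are created in the training cells below.
print("\nExample: LIME Explanation")
print("-" * 60)

# Get a test sample (kept on the CPU; moved to `device` only for the model call)
test_sample = next(iter(test_loader))[0][0]
model.eval()
with torch.no_grad():
    original_pred = model(test_sample.unsqueeze(0).to(device)).argmax(1).item()
print(f"Original prediction: {original_pred}")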

# LIME explanation
lime = LIME(model, kernel_width=0.25)
superpixel_weights, segments = lime.explain_image(
    test_sample, 
    num_samples=500,
    num_features=10,
    num_superpixels=50
)

print(f"Number of superpixels: {len(np.unique(segments))}")
print(f"Top 5 important superpixels: {np.argsort(superpixel_weights)[-5:]}")
print(f"Importance range: [{superpixel_weights.min():.4f}, {superpixel_weights.max():.4f}]")

# Visualize
fig = lime.visualize_explanation(test_sample, superpixel_weights, segments, num_features=5)
plt.savefig('lime_explanation.png', dpi=150, bbox_inches='tight')
print("\nLIME visualization saved to 'lime_explanation.png'")
plt.show()

# Example: Faithfulness evaluation
print("\nExample: Faithfulness Evaluation")
print("-" * 60)

# Generate an integrated-gradients attribution (zero baseline, 50 path points)
x_ig = test_sample.unsqueeze(0).to(device)
baseline = torch.zeros_like(x_ig)
attribution = torch.zeros_like(x_ig)
num_steps_ig = 50

for alpha in np.linspace(0, 1, num_steps_ig):
    # Each path point is a fresh leaf tensor, so autograd can return its gradient
    x_interp = (baseline + float(alpha) * (x_ig - baseline)).requires_grad_(True)
    output = model(x_interp)
    grad, = torch.autograd.grad(output[0, original_pred], x_interp)
    attribution += grad

attribution = (attribution * (x_ig - baseline) / num_steps_ig).detach()[0].cpu()

# Evaluate faithfulness
evaluator = FaithfulnessEvaluator(model)
deletion_scores = evaluator.deletion_metric(test_sample, attribution, original_pred, num_steps=20)
insertion_scores = evaluator.insertion_metric(test_sample, attribution, original_pred, num_steps=20)

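# Trapezoidal AUC with the x-axis (fraction of features removed/added) scaled to [0, 1]
deletion_auc = ((deletion_scores[:-1] + deletion_scores[1:]) / 2).mean()
insertion_auc = ((insertion_scores[:-1] + insertion_scores[1:]) / 2).mean()
print(f"Deletion AUC (lower is better): {deletion_auc:.4f}")
print(f"Insertion AUC (higher is better): {insertion_auc:.4f}")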

fig = evaluator.plot_curves(deletion_scores, insertion_scores)
plt.savefig('faithfulness_curves.png', dpi=150, bbox_inches='tight')
print("\nFaithfulness curves saved to 'faithfulness_curves.png'")
plt.show()

print("\n" + "=" * 60)
print("Key Takeaways:")
print("=" * 60)
print("1. SHAP: Theoretically grounded (Shapley values), but computationally expensive")
print("2. LIME: Fast and model-agnostic, but no theoretical guarantees")
print("3. Deletion/Insertion: Quantify faithfulness, though deleted pixels create out-of-distribution inputs")
print("4. Different methods suitable for different use cases")
print("=" * 60)

1. GradCAM Theory

Gradient-weighted Class Activation Mapping

\[\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A_{ij}^k}\]
\[L_{\text{GradCAM}}^c = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)\]

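where \(A^k\) is the \(k\)-th feature map of the last convolutional layer, \(Z\) is its number of spatial positions (so \(\alpha_k^c\) is the global-average-pooled gradient), and \(\alpha_k^c\) is the importance weight of feature map \(k\) for class \(c\).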
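To make the two formulas concrete, the sketch below evaluates them on random NumPy arrays standing in for K = 4 feature maps of size 3×3 and their gradients (Z = 9), and checks the vectorized form against an explicit loop over channels. The arrays are synthetic, not taken from the model in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
K, H, W = 4, 3, 3
A = rng.normal(size=(K, H, W))          # feature maps A^k
dY_dA = rng.normal(size=(K, H, W))      # gradients dy^c / dA^k_ij for one class c
Z = H * W

# alpha_k^c = (1/Z) * sum_ij dy^c / dA^k_ij   (global-average-pooled gradients)
alpha = dY_dA.sum(axis=(1, 2)) / Z

# L^c = ReLU(sum_k alpha_k^c * A^k)
cam = np.maximum((alpha[:, None, None] * A).sum(axis=0), 0.0)

# Same computation written as an explicit loop over channels
cam_loop = np.zeros((H, W))
for k in range(K):
    cam_loop += alpha[k] * A[k]
cam_loop = np.maximum(cam_loop, 0.0)

print(np.allclose(cam, cam_loop))  # True
```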


class ConvNet(nn.Module):
    """CNN for MNIST with feature extraction."""
    
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 10)
        
        self.features = None
        self.gradients = None
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)
        
        # Save features
        self.features = x
        
        # Register hook
        if x.requires_grad:
            x.register_hook(self.save_gradient)
        
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
    
    def save_gradient(self, grad):
        self.gradients = grad

model = ConvNet().to(device)
print("Model created")

Train Model

We train the ConvNet above on MNIST so that it can serve as the subject of the interpretability analysis. The model should reach high accuracy before we interpret it: explanations of a poorly trained model mostly reflect noise rather than meaningful learned features, and come out noisy and uninformative. A small custom CNN keeps training fast; a standard architecture such as a small ResNet would make the results easier to compare with published interpretability studies.

transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_mnist = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_mnist, batch_size=1000)

def train_model(model, train_loader, n_epochs=5):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    for epoch in range(n_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            
            output = model(x)
            loss = F.cross_entropy(output, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        print(f"Epoch {epoch+1} complete")

train_model(model, train_loader, n_epochs=5)
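The text above stresses that the model should reach high accuracy before it is interpreted, but the training loop never measures it. A minimal evaluation helper is sketched below; evaluate_accuracy is a hypothetical name, and the usage lines exercise it on a tiny synthetic dataset so the block runs on its own. With this notebook's objects it would be called as evaluate_accuracy(model, test_loader, device).

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate_accuracy(model, loader, device):
    """Top-1 accuracy of model over every batch in loader."""
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.numel()
    return correct / total


# Synthetic check: labels are defined as the toy model's own argmax, so accuracy is 1.0
torch.manual_seed(0)
features = torch.randn(64, 5)
toy_model = nn.Linear(5, 3)
labels = toy_model(features).argmax(dim=1)
loader = DataLoader(TensorDataset(features, labels), batch_size=16)
print(evaluate_accuracy(toy_model, loader, torch.device("cpu")))  # 1.0
```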

Implement GradCAM

Gradient-weighted Class Activation Mapping (Grad-CAM) produces a coarse localization heatmap highlighting which regions of the input image contributed most to the model’s prediction for a given class. It computes the gradient of the target class score with respect to the feature maps of the last convolutional stage, averages each channel’s gradient over its spatial positions to obtain a weight, and sums the weighted feature maps: \(L^c_{\text{Grad-CAM}} = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)\), where \(\alpha_k^c = \frac{1}{Z}\sum_{i,j} \frac{\partial y^c}{\partial A^k_{ij}}\) and \(Z\) is the number of spatial positions. The ReLU removes negative contributions, keeping only features that positively influence the target class. In this ConvNet the feature maps are taken after the final pooling layer, so the heatmap is only 3×3 before it is upsampled. Grad-CAM requires no retraining or architectural modification and applies to any CNN that exposes spatial feature maps.

def generate_gradcam(model, img, target_class=None):
    """Generate GradCAM heatmap."""
    model.eval()
    
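    # Forward pass (the input needs no gradient: ConvNet's feature maps already
    # require grad through the conv weights, so its gradient hook still fires)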
    output = model(img)
    
    if target_class is None:
        target_class = output.argmax(dim=1).item()
    
    # Backward pass
    model.zero_grad()
    target = output[0, target_class]
    target.backward()
    
    # Get gradients and features
    gradients = model.gradients
    features = model.features
    
    # Compute weights
    weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
    
    # Weighted combination
    cam = torch.sum(weights * features, dim=1).squeeze()
    cam = F.relu(cam)
    
    # Normalize to [0, 1] (epsilon guards against an all-zero map)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    
    return cam.cpu().detach().numpy(), target_class

# Test
x_test, y_test = next(iter(test_loader))
img = x_test[0:1].to(device)
cam, pred_class = generate_gradcam(model, img)
print(f"Predicted: {pred_class}, CAM shape: {cam.shape}")
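generate_gradcam relies on the model saving its own features and registering a tensor hook inside forward(). For a model you cannot edit, the same computation can be done from outside with a forward hook on the chosen layer plus torch.autograd.grad. The sketch below assumes a small randomly initialised CNN (TinyCNN, hypothetical) and targets its last convolutional layer; gradcam_with_hooks is likewise an illustrative name. register_forward_hook and torch.autograd.grad are standard PyTorch APIs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """Hypothetical stand-in model: two conv layers and a linear classifier."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)   # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)   # 14x14 -> 7x7
        return self.fc(x.flatten(1))


def gradcam_with_hooks(model, layer, img, target_class=None):
    """Grad-CAM for `layer` without modifying the model's forward()."""
    store = {}
    handle = layer.register_forward_hook(lambda m, inp, out: store.update(act=out))
    try:
        model.eval()
        logits = model(img)
    finally:
        handle.remove()                                    # never leave the hook attached
    if target_class is None:
        target_class = logits.argmax(dim=1).item()
    act = store["act"]                                     # (1, K, H, W), still in the graph
    grads, = torch.autograd.grad(logits[0, target_class], act)
    weights = grads.mean(dim=(2, 3), keepdim=True)         # alpha_k^c
    cam = F.relu((weights * act).sum(dim=1))[0].detach()   # ReLU(sum_k alpha_k^c A^k)
    cam = cam / (cam.max() + 1e-8)                         # scale to [0, 1]
    return cam, target_class


torch.manual_seed(0)
net = TinyCNN()
img = torch.rand(1, 1, 28, 28)
cam, cls = gradcam_with_hooks(net, net.conv2, img)
print(cam.shape)  # torch.Size([14, 14])
```

Passing each class index as target_class in turn gives one map per class for the same image.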

Visualize GradCAM

Overlaying the Grad-CAM heatmap on the original image reveals whether the model is looking at the right regions for the right reasons. For a correctly classified image, the heatmap should highlight the object of interest rather than background features or spurious correlations. Comparing heatmaps across different classes for the same image shows how the model’s attention shifts depending on the target class. This visualization is widely used in healthcare, autonomous driving, and other safety-critical domains to build trust in model predictions.

def overlay_gradcam(img, cam):
    """Overlay heatmap on image."""
    # Resize CAM to image size
    cam_resized = cv2.resize(cam, (28, 28))
    
    # Create heatmap
    heatmap = plt.cm.jet(cam_resized)[:, :, :3]
    
    # Overlay
    img_gray = img.squeeze().cpu().numpy()
    img_rgb = np.stack([img_gray] * 3, axis=-1)
    
    overlay = 0.6 * img_rgb + 0.4 * heatmap
    overlay = np.clip(overlay, 0, 1)
    
    return overlay, heatmap

# Visualize multiple samples
fig, axes = plt.subplots(4, 4, figsize=(12, 12))

for i in range(4):
    img = x_test[i:i+1].to(device)
    cam, pred = generate_gradcam(model, img)
    
    overlay, heatmap = overlay_gradcam(img, cam)
    
    axes[i, 0].imshow(img.squeeze().cpu(), cmap='gray')
    axes[i, 0].set_title(f'Original (True: {y_test[i].item()})', fontsize=9)
    axes[i, 0].axis('off')
    
    axes[i, 1].imshow(cam, cmap='jet')
    axes[i, 1].set_title(f'GradCAM (Pred: {pred})', fontsize=9)
    axes[i, 1].axis('off')
    
    axes[i, 2].imshow(heatmap)
    axes[i, 2].set_title('Heatmap', fontsize=9)
    axes[i, 2].axis('off')
    
    axes[i, 3].imshow(overlay)
    axes[i, 3].set_title('Overlay', fontsize=9)
    axes[i, 3].axis('off')

plt.suptitle('GradCAM Visualizations', fontsize=13)
plt.tight_layout()
plt.show()

Saliency Maps

Saliency maps are computed by taking the gradient of the output class score with respect to the input pixels, \(\frac{\partial y^c}{\partial x}\), and visualizing its absolute value. Unlike Grad-CAM, which operates at the coarse feature-map level, saliency maps give pixel-level attribution: they show which input pixels the class score is most sensitive to under small local perturbations. In practice the maps are often noisy and tend to highlight edges and strokes.

def generate_saliency(model, img, target_class=None):
    """Generate saliency map."""
    model.eval()
    
    # Work on a leaf copy so gradients do not accumulate on the caller's tensor
    img = img.detach().clone().requires_grad_(True)
    output = model(img)
    
    if target_class is None:
        target_class = output.argmax(dim=1).item()
    
    model.zero_grad()
    target = output[0, target_class]
    target.backward()
    
    saliency = img.grad.abs().squeeze()
    return saliency.cpu().numpy(), target_class

# Visualize
fig, axes = plt.subplots(3, 6, figsize=(15, 7))

row_labels = ['Original', 'GradCAM', 'Saliency']

for i in range(6):
    img = x_test[i:i+1].to(device)
    cam, _ = generate_gradcam(model, img)
    saliency, _ = generate_saliency(model, img)
    
    panels = [(x_test[i].squeeze().numpy(), 'gray'), (cam, 'jet'), (saliency, 'hot')]
    for row, (data, cmap) in enumerate(panels):
        ax = axes[row, i]
        ax.imshow(data, cmap=cmap)
        # Hide ticks instead of calling axis('off'), which would also hide the row label
        ax.set_xticks([])
        ax.set_yticks([])
        if i == 0:
            ax.set_ylabel(row_labels[row], fontsize=11)

plt.suptitle('Comparison: GradCAM vs Saliency', fontsize=13)
plt.tight_layout()
plt.show()
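One way to see what a saliency map measures: for a linear score y = w·x + b the input gradient is w everywhere, so the saliency map is simply |w| regardless of the input. The sketch below confirms this with autograd on a randomly initialised nn.Linear over flattened 28×28 inputs (an assumed toy model, not the trained ConvNet); for a deep ReLU network the gradient is only piecewise constant, so the map changes with the input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(28 * 28, 10)                  # toy linear classifier on flattened images
x = torch.rand(1, 28 * 28, requires_grad=True)
target_class = 3

score = linear(x)[0, target_class]
score.backward()
saliency = x.grad.abs().view(28, 28)

expected = linear.weight[target_class].detach().abs().view(28, 28)
print(torch.allclose(saliency, expected))  # True
```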

Integrated Gradients

Integrated Gradients (Sundararajan et al., 2017) is a more principled attribution method. It satisfies the sensitivity axiom (if the input and the baseline differ in a single feature and produce different outputs, that feature receives non-zero attribution) and implementation invariance, and it is complete: the attributions sum to the difference between the model’s output at the input and at the baseline. It computes attributions by integrating the gradient along a straight-line path from a baseline input (typically all zeros) to the actual input: \(\text{IG}_i(x) = (x_i - x'_i) \times \int_0^1 \frac{\partial F(x' + \alpha(x - x'))}{\partial x_i} d\alpha\). In practice the integral is approximated with a Riemann sum over 50-300 interpolation steps, which makes IG more expensive than Grad-CAM or a single saliency map but avoids the gradient saturation that affects plain gradients.

def integrated_gradients(model, img, target_class=None, steps=50):
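    """Compute integrated gradients; returns absolute values for display."""
    model.eval()
    img = img.detach()  # detached input, so each interpolated point below is a leaf
    
    # Baseline (zeros)
    baseline = torch.zeros_like(img)
    
    # Get target class
    if target_class is None:
        with torch.no_grad():
            target_class = model(img).argmax(dim=1).item()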
    
    # Compute gradients along path
    integrated_grads = torch.zeros_like(img)
    
    for alpha in np.linspace(0, 1, steps):
        interpolated = baseline + alpha * (img - baseline)
        interpolated.requires_grad = True
        
        output = model(interpolated)
        target = output[0, target_class]
        
        model.zero_grad()
        target.backward()
        
        integrated_grads += interpolated.grad
    
    # Average and multiply by input difference
    integrated_grads = integrated_grads / steps
    integrated_grads = (img - baseline) * integrated_grads
    
    return integrated_grads.abs().squeeze().cpu().detach().numpy()

# Compare methods
img = x_test[0:1].to(device)
ig = integrated_gradients(model, img)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

axes[0].imshow(img.squeeze().cpu(), cmap='gray')
axes[0].set_title('Original', fontsize=11)
axes[0].axis('off')

saliency, _ = generate_saliency(model, img)
axes[1].imshow(saliency, cmap='hot')
axes[1].set_title('Saliency', fontsize=11)
axes[1].axis('off')

axes[2].imshow(ig, cmap='hot')
axes[2].set_title('Integrated Gradients', fontsize=11)
axes[2].axis('off')

plt.tight_layout()
plt.show()
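The completeness property can be checked numerically: the IG attributions should sum to F(x) − F(x'). The sketch below assumes a small randomly initialised two-layer tanh network (chosen so the path integral is smooth and easy to approximate) and uses a midpoint Riemann sum, a slight variant of the evenly spaced grid with both endpoints used by integrated_gradients. ig_attributions is a hypothetical helper name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(6, 16), nn.Tanh(), nn.Linear(16, 1))


def ig_attributions(model, x, baseline, steps=200):
    """Integrated gradients with a midpoint Riemann sum over `steps` points."""
    alphas = (torch.arange(steps, dtype=x.dtype) + 0.5) / steps      # midpoints in (0, 1)
    path = baseline + alphas[:, None] * (x - baseline)              # (steps, features)
    path.requires_grad_(True)
    outputs = model(path).sum()        # rows are independent, so this yields per-row gradients
    grads, = torch.autograd.grad(outputs, path)
    return (x - baseline) * grads.mean(dim=0)


x = torch.randn(6)
baseline = torch.zeros(6)
attr = ig_attributions(net, x, baseline)

with torch.no_grad():
    gap = (net(x) - net(baseline)).item()
print(attr.sum().item(), gap)  # the two numbers should agree closely
```

This gap also works as a convergence check in practice: if the sum of the attributions is far from F(x) − F(x'), increase the number of steps.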

Summary

Interpretability Methods:

  1. GradCAM: Class activation maps weighted by gradients

  2. Saliency Maps: Input gradient magnitude

  3. Integrated Gradients: Path integral of gradients from a baseline

  4. SHAP: Shapley value-based attribution

Key Insights:

  • GradCAM shows which spatial regions of the last feature maps support a class

  • Saliency highlights the pixels the class score is locally most sensitive to

  • IG counters gradient saturation by integrating along a path from a baseline

  • Different methods reveal different aspects

Applications:

  • Model debugging

  • Trust and transparency

  • Feature validation

  • Bias detection

  • Medical diagnosis explanation

Best Practices:

  • Use multiple methods

  • Validate with domain experts

  • Consider method limitations

  • Combine with quantitative metrics

Extensions:

  • Grad-CAM++: Improved localization

  • Score-CAM: Gradient-free

  • LIME: Local explanations

  • Attention rollout: Transformer interpretation