import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Select the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Device: {device}")

1. Knowledge Distillation

Soft Targets

\[q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}\]

where \(T\) is temperature.

Distillation Loss

\[\mathcal{L} = \alpha \cdot \mathcal{L}_{\text{hard}} + (1-\alpha) \cdot T^2 \cdot \mathcal{L}_{\text{soft}}\]

Reference Materials:

class TeacherNet(nn.Module):
    """Large teacher model: three conv stages followed by a two-layer classifier.

    Spatial resolution shrinks 28x28 -> 14x14 -> 7x7 -> 3x3 via the 2x2
    max-pool applied after every convolution in forward().
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        # 256 channels on a 3x3 map after the third pool.
        self.fc1 = nn.Linear(256 * 3 * 3, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Conv backbone: ReLU + 2x2 max-pool after each conv layer.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        # Classifier head on the flattened feature map.
        hidden = F.relu(self.fc1(x.flatten(1)))
        return self.fc2(self.dropout(hidden))

class StudentNet(nn.Module):
    """Small student model: two conv stages and a compact classifier.

    Spatial resolution shrinks 28x28 -> 14x14 -> 7x7 via the 2x2 max-pool
    applied after each convolution in forward().
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # 32 channels on a 7x7 map after the second pool.
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        hidden = F.relu(self.fc1(x.flatten(1)))
        return self.fc2(hidden)

def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

teacher = TeacherNet().to(device)
student = StudentNet().to(device)

# Report the size of each model and the resulting compression ratio.
n_teacher = count_parameters(teacher)
n_student = count_parameters(student)
print(f"Teacher parameters: {n_teacher:,}")
print(f"Student parameters: {n_student:,}")
print(f"Compression: {n_teacher/n_student:.1f}x")

Train Teacher

Knowledge distillation begins by training a large, high-capacity teacher model to achieve the best possible accuracy on the task. The teacher's value lies not just in its hard predictions but in its soft probability outputs — the full distribution over classes captures inter-class similarities (e.g., a '7' looks somewhat like a '1'). These soft targets, when used to train a smaller student, transfer more information per training example than hard one-hot labels, effectively compressing the teacher's knowledge into a compact model.

# MNIST train/test splits (downloaded to ./data on first use).
to_tensor = transforms.Compose([transforms.ToTensor()])
train_set = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
test_set = datasets.MNIST('./data', train=False, download=True, transform=to_tensor)

# Shuffle only the training stream; evaluation uses large fixed batches.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000)

def train_model(model, train_loader, n_epochs=10, eval_loader=None):
    """Train *model* with Adam + cross-entropy, printing test accuracy per epoch.

    Args:
        model: Classifier mapping a batch of inputs to logits.
        train_loader: DataLoader yielding (input, label) training batches.
        n_epochs: Number of passes over the training data.
        eval_loader: Optional DataLoader used for the per-epoch evaluation.
            Defaults to the module-level ``test_loader`` for backward
            compatibility with existing call sites.
    """
    # Run on whichever device the model already lives on, instead of
    # silently depending on the module-level ``device`` global.
    device = next(model.parameters()).device
    if eval_loader is None:
        eval_loader = test_loader  # notebook-level fallback

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(n_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            output = model(x)
            loss = F.cross_entropy(output, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate on held-out data after each epoch.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in eval_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        acc = 100 * correct / total
        print(f"Epoch {epoch+1}, Accuracy: {acc:.2f}%")

# Fit the large teacher first; its soft outputs will supervise the student.
print("Training teacher...")
train_model(teacher, train_loader, n_epochs=10)

Distillation Loss

The distillation loss combines two terms: the standard cross-entropy with hard labels and the KL divergence between teacher and student soft outputs, tempered by a temperature parameter \(T\):

\[\mathcal{L} = \alpha \cdot \text{CE}(y, \sigma(z_s)) + (1-\alpha) \cdot T^2 \cdot \text{KL}(\sigma(z_t/T) \| \sigma(z_s/T))\]

Higher temperatures soften both distributions, revealing more of the teacher's knowledge about inter-class relationships. The \(T^2\) scaling factor compensates for the reduced gradient magnitude at high temperatures. The mixing weight \(\alpha\) balances learning from ground truth versus learning from the teacher, with typical values around 0.1-0.5.

def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):
    """Knowledge distillation loss.

    Combines cross-entropy on the ground-truth labels (weight ``alpha``) with
    the temperature-scaled KL term against the teacher (weight ``1 - alpha``).
    """
    # KL between temperature-softened teacher and student distributions.
    # The T**2 factor restores the gradient magnitude lost by dividing
    # the logits by T.
    log_student = F.log_softmax(student_logits / T, dim=1)
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    soft_loss = F.kl_div(log_student, teacher_probs, reduction='batchmean') * (T ** 2)

    # Standard cross-entropy against the hard labels.
    hard_loss = F.cross_entropy(student_logits, labels)

    return alpha * hard_loss + (1 - alpha) * soft_loss

# Quick sanity check: the loss should produce a finite scalar on random
# logits/labels before we commit to a full training run.
logits_s = torch.randn(4, 10)
logits_t = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = distillation_loss(logits_s, logits_t, labels)
print(f"Test loss: {loss.item():.4f}")

Train Student with Distillation

The student model is significantly smaller than the teacher (fewer layers, narrower hidden dimensions) and would achieve lower accuracy if trained from scratch on hard labels alone. With distillation, the student benefits from the teacher's richer gradient signal: instead of learning that an image is "definitely a 3", it learns that it is "mostly a 3, somewhat like an 8, and a little like a 5". This additional information acts as a powerful regularizer, often allowing the student to match or approach the teacher's accuracy despite having a fraction of the parameters — a critical capability for deploying models on mobile devices and edge hardware.

def train_with_distillation(student, teacher, train_loader, n_epochs=10, T=3.0, alpha=0.5, eval_loader=None):
    """Train *student* to mimic *teacher* using the distillation loss.

    Args:
        student: Model being trained.
        teacher: Frozen reference model providing soft targets.
        train_loader: DataLoader of (input, label) batches.
        n_epochs: Number of training epochs.
        T: Softmax temperature for the soft-target term.
        alpha: Weight of the hard-label term (soft term gets 1 - alpha).
        eval_loader: Optional evaluation DataLoader. Defaults to the
            module-level ``test_loader`` for backward compatibility.

    Returns:
        List of per-batch loss values across all epochs.
    """
    # Use the student's own device instead of the module-level global,
    # mirroring train_model().
    device = next(student.parameters()).device
    if eval_loader is None:
        eval_loader = test_loader  # notebook-level fallback

    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    teacher.eval()  # freeze dropout/batchnorm behavior in the teacher

    losses = []

    for epoch in range(n_epochs):
        student.train()
        epoch_loss = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            # Teacher predictions are targets only -- no gradients needed.
            with torch.no_grad():
                teacher_logits = teacher(x)

            student_logits = student(x)
            loss = distillation_loss(student_logits, teacher_logits, y, T, alpha)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            losses.append(loss.item())

        # Per-epoch accuracy on held-out data.
        student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in eval_loader:
                x, y = x.to(device), y.to(device)
                pred = student(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        acc = 100 * correct / total
        print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(train_loader):.4f}, Acc: {acc:.2f}%")

    return losses

# Distill the teacher into a fresh student and plot the loss curve.
student_distilled = StudentNet().to(device)
losses = train_with_distillation(student_distilled, teacher, train_loader, n_epochs=10, T=3.0)

plt.figure(figsize=(10, 5))
plt.plot(losses, alpha=0.5)
# A 50-step moving average smooths the noisy per-batch loss.
smoothed = np.convolve(losses, np.ones(50) / 50, mode='valid')
plt.plot(smoothed, linewidth=2, label='Moving Avg')
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('Distillation Loss', fontsize=11)
plt.title('Knowledge Distillation Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Compare Models

A side-by-side comparison of the teacher, distilled student, and a student trained from scratch on hard labels quantifies the benefit of distillation. Key metrics include accuracy, parameter count, inference latency, and memory footprint. The distilled student should significantly outperform the scratch-trained student (demonstrating knowledge transfer) while using far fewer parameters than the teacher (demonstrating compression). This comparison is the standard evidence used to justify distillation in production ML pipelines.

# Train an identical student architecture from scratch on hard labels only,
# as a baseline to quantify the benefit of distillation.
student_baseline = StudentNet().to(device)
print("\nTraining baseline student...")
train_model(student_baseline, train_loader, n_epochs=10)

# Compare accuracies
def evaluate(model, test_loader):
    """Return *model*'s top-1 accuracy (percent) over *test_loader*.

    Args:
        model: Classifier producing logits.
        test_loader: DataLoader of (input, label) batches.

    Returns:
        Accuracy as a float percentage in [0, 100].
    """
    # Use the model's own device rather than the module-level ``device``
    # global, so the function works for models placed anywhere.
    device = next(model.parameters()).device
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return 100 * correct / total

# Final comparison: teacher vs. distilled student vs. baseline student.
teacher_acc = evaluate(teacher, test_loader)
student_distilled_acc = evaluate(student_distilled, test_loader)
student_baseline_acc = evaluate(student_baseline, test_loader)

print("\nResults:")  # plain string; the original f-prefix had no placeholders
print(f"Teacher: {teacher_acc:.2f}%")
print(f"Student (distilled): {student_distilled_acc:.2f}%")
print(f"Student (baseline): {student_baseline_acc:.2f}%")
print(f"Improvement: {student_distilled_acc - student_baseline_acc:.2f}%")

Temperature Analysis

The temperature parameter \(T\) controls how much of the teacher's "dark knowledge" is transferred. At \(T = 1\) (standard softmax), the teacher's outputs are often nearly one-hot, providing little more information than hard labels. As \(T\) increases, the distribution flattens and reveals the teacher's confidence rankings among non-target classes. Sweeping \(T\) across a range (e.g., 1 to 20) and measuring student accuracy at each value reveals the optimal temperature for a given teacher-student pair — typically in the range of 3-10.

# Sweep the distillation temperature and record student accuracy at each value.
temperatures = [1.0, 2.0, 3.0, 5.0, 10.0]
temp_accs = []

for T in temperatures:
    fresh_student = StudentNet().to(device)
    train_with_distillation(fresh_student, teacher, train_loader, n_epochs=5, T=T, alpha=0.5)
    acc = evaluate(fresh_student, test_loader)
    temp_accs.append(acc)
    print(f"T={T}: {acc:.2f}%")

plt.figure(figsize=(10, 6))
plt.plot(temperatures, temp_accs, 'bo-', markersize=8)
# Reference lines: the undistilled student and the teacher bound the range.
plt.axhline(y=student_baseline_acc, color='r', linestyle='--', label='Baseline')
plt.axhline(y=teacher_acc, color='g', linestyle='--', label='Teacher')
plt.xlabel('Temperature', fontsize=11)
plt.ylabel('Accuracy (%)', fontsize=11)
plt.title('Effect of Temperature on Distillation', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Soft vs Hard Targets

Comparing training with soft targets only (teacher outputs), hard targets only (ground truth labels), and the combined loss reveals each component's contribution. Soft targets alone often outperform hard targets alone for the student, because the soft distribution provides more bits of information per example. The combined loss typically performs best, leveraging both the teacher's inter-class structure and the ground truth's correctness guarantee. This ablation study is important for understanding why distillation works and for tuning the mixing weight \(\alpha\) in practice.

# Visualize one teacher output distribution at several temperatures.
teacher.eval()
x_sample, y_sample = next(iter(test_loader))
x_sample = x_sample[:1].to(device)
y_sample = y_sample[:1].to(device)

with torch.no_grad():
    logits = teacher(x_sample)
    # T=1 is the ordinary softmax; higher T flattens the distribution,
    # exposing the teacher's ranking of the non-target classes.
    hard = F.softmax(logits, dim=1)
    soft_T3 = F.softmax(logits / 3, dim=1)
    soft_T10 = F.softmax(logits / 10, dim=1)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

panels = [
    (hard, 'Hard Targets (T=1)'),
    (soft_T3, 'Soft Targets (T=3)'),
    (soft_T10, 'Soft Targets (T=10)'),
]
for ax, (probs, title) in zip(axes, panels):
    ax.bar(range(10), probs[0].cpu())
    ax.set_title(title, fontsize=11)
    ax.set_xlabel('Class')
axes[0].set_ylabel('Probability')

plt.tight_layout()
plt.show()

SummaryΒΆ

Knowledge Distillation:ΒΆ

Transfer knowledge from teacher to student via soft targets.

Key Components:ΒΆ

  1. Teacher model - Large, accurate

  2. Student model - Small, efficient

  3. Temperature - Controls softness

  4. Loss combination - Hard + soft

Benefits:ΒΆ

  • Model compression

  • Better generalization

  • Faster inference

  • Dark knowledge transfer

Applications:ΒΆ

  • Mobile deployment

  • Edge devices

  • Ensemble compression

  • Cross-modal transfer

Extensions:ΒΆ

  • Self-distillation

  • Feature distillation

  • Attention transfer

  • Data-free distillation

Next Steps:ΒΆ

  • Study quantization

  • Explore pruning

  • Learn neural architecture search

Advanced Knowledge Distillation TheoryΒΆ

1. Introduction to Knowledge DistillationΒΆ

Knowledge Distillation (KD) transfers knowledge from a large, complex teacher model to a smaller, efficient student model.

1.1 MotivationΒΆ

Why distillation?

  • Model compression: Deploy on edge devices (mobile, IoT)

  • Inference speedup: Smaller models are faster

  • Ensemble knowledge: Distill from multiple teachers

  • Transfer learning: Cross-domain/architecture knowledge transfer

Key insight: Teacher’s soft predictions contain more information than hard labels.

1.2 Problem FormulationΒΆ

Teacher model: T(x; θ_T) → z_T (pre-softmax logits). Student model: S(x; θ_S) → z_S.

Goal: Train student to mimic teacher’s behavior:

min_{ΞΈ_S} L_distill(S(x), T(x)) + Ξ± L_task(S(x), y)

Where:

  • L_distill: Distillation loss (KL divergence, MSE)

  • L_task: Task loss (cross-entropy with ground truth)

  • Ξ±: Balance hyperparameter

2. Classic Knowledge Distillation (Hinton et al.)ΒΆ

Hinton’s KD [Hinton et al., 2015]: Use temperature-scaled softmax.

2.1 Temperature ScalingΒΆ

Softmax with temperature T:

p_i = exp(z_i / T) / Ξ£_j exp(z_j / T)

Where:

  • z_i: Logit for class i

  • T: Temperature parameter

  • T=1: Standard softmax

  • T>1: Softer distribution (more information in wrong classes)

Effect of temperature:

  • T β†’ 0: One-hot (hard labels)

  • T = 1: Standard probabilities

  • T β†’ ∞: Uniform distribution

2.2 Distillation LossΒΆ

KL divergence between teacher and student (both with temperature T):

L_KD = TΒ² Β· KL(softmax(z_T/T) || softmax(z_S/T))

Total loss:

L = (1-Ξ») L_CE(y_true, softmax(z_S)) + Ξ» L_KD

Where:

  • Ξ»: Distillation weight (typically 0.5-0.9)

  • TΒ²: Scaling factor (compensates for softened gradients)

Intuition: Soft targets reveal similarity structure between classes.

2.3 Why Temperature WorksΒΆ

Information in soft targets:

  • Hard label: β€œThis is a cat” (1 bit)

  • Soft target: β€œ90% cat, 8% dog, 2% tiger” (rich information)

Dark knowledge: Information in near-zero probabilities.

Example: For image of β€œ3”, teacher might output:

  • p(3) = 0.95

  • p(8) = 0.03 (similar shape)

  • p(5) = 0.01 (somewhat similar)

  • Others: ~0

Student learns similarity structure!

3. Response-Based DistillationΒΆ

3.1 Logit MatchingΒΆ

Direct logit matching (no softmax):

L = MSE(z_S, z_T)

Advantage: Simpler, no temperature tuning. Disadvantage: Sensitive to logit scale.

3.2 Regression DistillationΒΆ

For regression tasks:

L = ||f_S(x) - f_T(x)||Β²

Application: Object detection bbox regression, depth estimation.

3.3 Ranking DistillationΒΆ

RKD [Park et al., 2019]: Preserve relative distances.

Distance-wise:

L_dist = Ξ£_{i,j} Ο†(||f_T(x_i) - f_T(x_j)||Β² - ||f_S(x_i) - f_S(x_j)||Β²)

Angle-wise:

L_angle = Ξ£_{i,j,k} Ο†(∠(x_i, x_j, x_k)_T - ∠(x_i, x_j, x_k)_S)

Where Ο† is Huber loss.

4. Feature-Based DistillationΒΆ

4.1 FitNet (Hint Learning)ΒΆ

FitNet [Romero et al., 2015]: Match intermediate features.

Regressor: Transform student features to match teacher dimension:

L_hint = ||W_r Β· f_S^l - f_T^m||Β²

Where:

  • f_S^l: Student feature at layer l

  • f_T^m: Teacher feature at layer m

  • W_r: Learned transformation matrix

Two-stage training:

  1. Train student with hint loss to match intermediate layers

  2. Fine-tune with distillation loss on outputs

4.2 Attention TransferΒΆ

AT [Zagoruyko & Komodakis, 2017]: Transfer attention maps.

Spatial attention map:

A(F) = Ξ£_c |F^c|^p

Where F^c is feature channel c.

Loss:

L_AT = Ξ£_l ||A(F_S^l) / ||A(F_S^l)||_2 - A(F_T^l) / ||A(F_T^l)||_2||Β²

Intuition: Student learns where teacher focuses attention.

4.3 Similarity-Preserving KDΒΆ

SP [Tung & Mori, 2019]: Preserve pairwise similarity.

Similarity matrix:

G_T[i,j] = f_T(x_i)^T f_T(x_j) / (||f_T(x_i)|| ||f_T(x_j)||)

Loss:

L_SP = ||G_S - G_T||Β²_F

Advantage: Captures relational structure across samples.

4.4 Correlation CongruenceΒΆ

CRD [Tian et al., 2020]: Contrastive Representation Distillation.

InfoNCE-based loss:

L_CRD = -log[exp(f_S^T f_T / Ο„) / (exp(f_S^T f_T / Ο„) + Ξ£_k exp(f_S^T f_k^- / Ο„))]

Where f_k^- are negative samples.

Benefit: Discriminative feature learning via contrastive objective.

5. Relation-Based DistillationΒΆ

5.1 Relational Knowledge Distillation (RKD)ΒΆ

Transfer structural information between instances.

Distance-wise RKD: Preserve relative Euclidean distances in feature space.

Angle-wise RKD: Preserve angles formed by triplets of points.

5.2 Instance Relationship GraphΒΆ

IRG [Liu et al., 2019]: Model as graph distillation.

Graph: Nodes = samples, Edges = relationships

Loss:

L_IRG = Ξ£_{i,j} w_ij Β· ||r_S(x_i, x_j) - r_T(x_i, x_j)||Β²

Where r(Β·,Β·) is relationship function.

6. Self-DistillationΒΆ

6.1 Born-Again Networks (BAN)ΒΆ

BAN [Furlanello et al., 2018]: Student same capacity as teacher.

Procedure:

  1. Train teacher T_1

  2. Distill to student S_1 (same architecture)

  3. Repeat: S_1 becomes T_2, train S_2, etc.

Finding: Each generation improves (up to 3-4 iterations).

6.2 Deep Mutual Learning (DML)ΒΆ

DML [Zhang et al., 2018]: Multiple students learn collaboratively.

No pre-trained teacher: All models trained from scratch.

Loss for student i:

L_i = L_CE(y, S_i(x)) + Σ_{j≠i} KL(S_j(x) || S_i(x))

Benefit: Ensemble knowledge without ensemble inference.

6.3 Online Knowledge DistillationΒΆ

On-the-fly: Generate teacher predictions during training.

Approaches:

  • Multi-branch: Branches share backbone, distill from each other

  • Temporal ensemble: Teacher is EMA of student

  • Self-training: Pseudo-labels from student’s own predictions

7. Cross-Modal DistillationΒΆ

7.1 Cross-Domain DistillationΒΆ

Transfer across domains: RGB β†’ Depth, Day β†’ Night, etc.

Challenge: Input modalities differ.

Solution: Match feature representations, not inputs.

7.2 Cross-Task DistillationΒΆ

Transfer knowledge between different tasks.

Example: Classification teacher β†’ Detection student

Approach: Align feature spaces via intermediate layers.

7.3 Privileged Information DistillationΒΆ

Learning Using Privileged Information (LUPI):

  • Teacher has extra modality at training (e.g., depth maps)

  • Student uses only RGB at test

Knowledge transfer: Student learns from teacher’s privileged view.

8. Adversarial DistillationΒΆ

8.1 Adversarial Training for DistillationΒΆ

Generator: Student trying to fool discriminator Discriminator: Distinguish teacher vs. student outputs

Loss:

L_S = L_task + Ξ± L_KD + Ξ² L_adv

Where L_adv encourages student to match teacher’s distribution.

8.2 Data-Free DistillationΒΆ

Scenario: No access to original training data.

Solution: Generate synthetic data.

Approaches:

  1. Gradient matching: Generate inputs that match teacher gradients

  2. Activation matching: Maximize activation of teacher neurons

  3. GAN-based: Train GAN to generate realistic inputs

DeepInversion [Yin et al., 2020]:

L = L_task(T(x_gen)) + Ξ± L_BN(x_gen) + Ξ² L_prior(x_gen)

Where:

  • L_BN: Batch normalization statistics matching

  • L_prior: Image prior (smoothness, etc.)

9. Advanced Distillation TechniquesΒΆ

9.1 Layer-wise DistillationΒΆ

Match at multiple layers simultaneously:

L = Ξ£_l w_l L_layer(f_S^l, f_T^l)

Adaptive weights: Learn w_l or use task-dependent heuristics.

9.2 Progressive DistillationΒΆ

Curriculum: Start with easy knowledge, progress to complex.

Temperature scheduling: T decreases over training (hard β†’ soft).

9.3 Quantization-Aware DistillationΒΆ

Goal: Distill to quantized student (INT8, binary).

Approach: Quantize during distillation training.

Loss:

L = L_KD(quant(S(x)), T(x)) + L_CE(quant(S(x)), y)

9.4 Neural Architecture Search + KDΒΆ

AutoKD: Search for optimal student architecture while distilling.

Joint optimization: Architecture search + knowledge transfer.

10. Theoretical FoundationsΒΆ

10.1 Why Does Distillation Work?ΒΆ

Hypothesis 1: Dark knowledge (soft targets) provides regularization.

Hypothesis 2: Teacher smooths decision boundaries.

Hypothesis 3: Distillation implicitly optimizes for generalization.

Empirical evidence: Students often generalize better than teachers!

10.2 Capacity GapΒΆ

Teacher capacity: High (many parameters) Student capacity: Low (few parameters)

Trade-off:

  • Gap too large β†’ Student can’t mimic teacher

  • Gap too small β†’ Limited compression

Optimal gap: Task-dependent, typically teacher 5-10Γ— larger.

10.3 Generalization BoundΒΆ

PAC-Bayesian bound for distilled student:

With probability β‰₯ 1-Ξ΄:

L_test(S) ≀ L_KD(S, T) + O(√(KL(P_S || P_prior) / N))

Implication: Distillation acts as learned prior from teacher.

10.4 Information-Theoretic ViewΒΆ

Mutual information:

I(X; Y) β‰₯ I(f_S(X); Y)

Goal: Maximize mutual information between student features and labels.

Distillation: I(f_S; f_T) + I(f_T; Y) β†’ I(f_S; Y)

11. Practical ConsiderationsΒΆ

11.1 Hyperparameter SelectionΒΆ

Critical hyperparameters:

Parameter

Range

Notes

Temperature T

1-20

Higher for similar tasks

Distillation weight Ξ»

0.1-0.9

0.5-0.7 typical

Feature matching layers

1-5

Middle layers work best

Training epochs

1.5-2Γ— normal

Student needs more time

11.2 Teacher SelectionΒΆ

Bigger not always better: Very large teachers may not help.

Ensemble teachers: Average multiple teachers’ predictions.

Architecture mismatch: CNN β†’ Transformer works if capacity appropriate.

11.3 Student Architecture DesignΒΆ

Width vs. Depth trade-off:

  • Wider, shallower: Faster inference, easier distillation

  • Narrower, deeper: Better feature learning, harder distillation

Skip connections: Help gradient flow in deep students.

11.4 Training StrategiesΒΆ

Two-stage:

  1. Pre-train teacher on task

  2. Distill to student

Co-training:

  • Train teacher and student jointly

  • Mutual learning

Temperature annealing: Start high, decrease over training.

12. ApplicationsΒΆ

12.1 Model CompressionΒΆ

Typical compression ratios:

  • BERT β†’ DistilBERT: 40% size, 97% performance

  • ResNet-152 β†’ ResNet-18: 10Γ— speedup, ~5% accuracy drop

  • EfficientNet-B7 β†’ B0: 30Γ— fewer parameters

12.2 Edge DeploymentΒΆ

Mobile devices: Distilled models fit memory/compute constraints.

Real-time inference: Faster student enables real-time applications.

Battery efficiency: Smaller models consume less power.

12.3 Ensemble CompressionΒΆ

Distill ensemble into single model:

  • Ensemble: 5-10 models (slow, accurate)

  • Distilled student: 1 model (fast, nearly as accurate)

Benefit: Ensemble accuracy without ensemble cost.

12.4 Continual LearningΒΆ

Distillation prevents forgetting:

  • Old model = Teacher

  • Updated model = Student

  • Retain old knowledge via distillation loss

13. Recent Advances (2020-2024)ΒΆ

13.1 Vision Transformer DistillationΒΆ

DeiT [Touvron et al., 2021]: Data-efficient ViT training.

Distillation token: Extra token attending to teacher.

Hard vs. Soft distillation:

  • Soft: KL(teacher_soft || student)

  • Hard: CE(teacher_argmax, student)

Finding: Hard distillation often better for ViTs.

13.2 Large Language Model DistillationΒΆ

Distilling BERT:

  • DistilBERT: 40% smaller, 60% faster

  • TinyBERT: Layer-wise + prediction distillation

  • MobileBERT: Bottleneck architecture

Distilling GPT:

  • GPT-2 β†’ GPT-2 Small/Medium

  • Layer dropping, width reduction

13.3 Multimodal DistillationΒΆ

CLIP distillation: Image-text teacher β†’ Image-only student

Approach: Match vision encoder while removing text tower.

13.4 Test-Time DistillationΒΆ

Adapt student at test time using unlabeled data.

Self-distillation: Student becomes its own teacher.

14. Comparison of MethodsΒΆ

14.1 Performance SummaryΒΆ

ImageNet (ResNet-34 Teacher β†’ Student):

Student Arch

Method

Top-1 Acc

Speedup

ResNet-18

Vanilla

69.8%

1.8Γ—

ResNet-18

KD (Hinton)

71.0%

1.8Γ—

ResNet-18

FitNet

71.5%

1.8Γ—

ResNet-18

AT

71.8%

1.8Γ—

ResNet-18

CRD

72.2%

1.8Γ—

MobileNetV2

KD

68.5%

3.5Γ—

Compression ratio: Teacher (21M params) β†’ Student (11M params)

14.2 Method CharacteristicsΒΆ

Response-based (Hinton KD):

  • βœ“ Simple to implement

  • βœ“ Architecture-agnostic

  • βœ— Only final layer information

Feature-based (FitNet, AT):

  • βœ“ Rich intermediate knowledge

  • βœ“ Better for large capacity gap

  • βœ— Architecture coupling

  • βœ— Hyperparameter sensitive

Relation-based (RKD, SP):

  • βœ“ Captures structural knowledge

  • βœ“ Generalizes across samples

  • βœ— Quadratic complexity O(nΒ²)

15. Limitations and ChallengesΒΆ

15.1 Current LimitationsΒΆ

  1. Capacity gap: Very small students can’t learn from very large teachers

  2. Architecture constraints: Some methods require similar architectures

  3. Computational cost: Training time often 1.5-2Γ— longer

  4. Hyperparameter sensitivity: Temperature, Ξ» require tuning

  5. Task dependence: What works for classification may not for detection

15.2 Open ProblemsΒΆ

  1. Optimal temperature: No principled way to select T

  2. Layer matching: Which teacher layers to use?

  3. Student architecture: How to design for distillability?

  4. Theoretical understanding: Why does it work so well?

  5. Negative transfer: When does distillation hurt?

15.3 Future DirectionsΒΆ

  1. Automated distillation: NAS for student + automatic hyperparameter tuning

  2. Task-agnostic KD: Universal distillation framework

  3. Efficient distillation: Reduce training time overhead

  4. Multi-teacher distillation: Optimally combine multiple teachers

  5. Lifelong distillation: Continual model updates via distillation

16. Key TakeawaysΒΆ

  1. Knowledge distillation transfers knowledge from teacher to student via soft targets

  2. Temperature scaling reveals dark knowledge in near-zero probabilities

  3. Response-based (outputs), feature-based (intermediate), relation-based (structural)

  4. Self-distillation (BAN, DML) improves without pre-trained teacher

  5. Compression: 2-10Γ— smaller models with <5% accuracy drop

  6. Applications: Edge deployment, ensemble compression, continual learning

  7. Hyperparameters: T=4-6, Ξ»=0.5-0.7 typical

  8. Recent: ViT distillation, LLM distillation, data-free methods

  9. Trade-off: Compression ratio vs. performance vs. training time

  10. Best practice: Start with Hinton’s KD, add feature matching if needed

Core insight: Soft targets contain rich information about similarity structure that hard labels cannot provide.

17. Mathematical SummaryΒΆ

Classic distillation loss:

L = (1-Ξ») CE(y, softmax(z_S)) + Ξ» TΒ² KL(softmax(z_T/T) || softmax(z_S/T))

Feature matching:

L_feature = Ξ£_l ||f_S^l - f_T^l||Β²

Relational distillation:

L_relation = Ξ£_{i,j} ||sim(f_S(x_i), f_S(x_j)) - sim(f_T(x_i), f_T(x_j))||Β²

Attention transfer:

L_AT = Ξ£_l ||normalize(A(F_S^l)) - normalize(A(F_T^l))||Β²

ReferencesΒΆ

  1. Hinton et al. (2015) β€œDistilling the Knowledge in a Neural Network”

  2. Romero et al. (2015) β€œFitNets: Hints for Thin Deep Nets”

  3. Zagoruyko & Komodakis (2017) β€œPaying More Attention to Attention”

  4. Park et al. (2019) β€œRelational Knowledge Distillation”

  5. Tian et al. (2020) β€œContrastive Representation Distillation”

  6. Furlanello et al. (2018) β€œBorn Again Neural Networks”

  7. Zhang et al. (2018) β€œDeep Mutual Learning”

  8. Yin et al. (2020) β€œDreaming to Distill: Data-Free Knowledge Transfer”

  9. Touvron et al. (2021) β€œTraining Data-Efficient Image Transformers (DeiT)”

  10. Gou et al. (2021) β€œKnowledge Distillation: A Survey”

"""
Advanced Knowledge Distillation Implementations

This notebook provides production-ready PyTorch implementations of various
knowledge distillation techniques.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from typing import List, Tuple, Dict, Optional, Callable
import matplotlib.pyplot as plt
from dataclasses import dataclass

# ============================================================================
# 1. Classic Knowledge Distillation (Hinton)
# ============================================================================

class DistillationLoss(nn.Module):
    """
    Classic knowledge distillation loss with temperature scaling.

    Loss = (1-alpha) * CE(y_true, student) + alpha * T^2 * KL(teacher || student)

    Args:
        temperature: Temperature for softening distributions (default: 4.0)
        alpha: Weight for distillation loss (default: 0.7)
    """
    def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the combined task + distillation loss.

        Args:
            student_logits: [batch, num_classes]
            teacher_logits: [batch, num_classes]
            labels: [batch] ground truth labels

        Returns:
            (total_loss, loss_dict) where loss_dict has 'total', 'task'
            and 'distill' scalar entries for logging.
        """
        T = self.temperature

        # Hard-label term: standard cross-entropy.
        task_loss = self.ce_loss(student_logits, labels)

        # Soft-target term: F.kl_div(log q, p) computes KL(p || q),
        # i.e. KL(teacher || student) here. The T^2 factor compensates
        # for the gradient scale lost by dividing the logits by T.
        distill_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction='batchmean'
        ) * (T ** 2)

        total_loss = (1 - self.alpha) * task_loss + self.alpha * distill_loss

        return total_loss, {
            'total': total_loss.item(),
            'task': task_loss.item(),
            'distill': distill_loss.item()
        }


# ============================================================================
# 2. Feature-Based Distillation
# ============================================================================

class FeatureDistillationLoss(nn.Module):
    """
    FitNet-style feature (hint) matching at intermediate layers.

    Each selected student feature map is projected to the teacher's channel
    count by a 1x1 conv regressor before an MSE comparison.

    Args:
        student_channels: Student feature channels
        teacher_channels: Teacher feature channels
        layers_to_match: List of layer indices to match
    """
    def __init__(
        self,
        student_channels: List[int],
        teacher_channels: List[int],
        layers_to_match: List[int]
    ):
        super().__init__()
        self.layers_to_match = layers_to_match

        # 1x1 convolutions that lift student channels to teacher channels.
        projections = [
            nn.Conv2d(s, t, kernel_size=1, bias=False)
            for s, t in zip(student_channels, teacher_channels)
        ]
        self.regressors = nn.ModuleList(projections)

    def forward(
        self,
        student_features: List[torch.Tensor],
        teacher_features: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Args:
            student_features: List of [batch, C_s, H, W]
            teacher_features: List of [batch, C_t, H, W]

        Returns:
            feature_loss: Mean MSE over the matched layers
        """
        per_layer = []

        for layer_idx in self.layers_to_match:
            projected = self.regressors[layer_idx](student_features[layer_idx])
            target = teacher_features[layer_idx]

            # Pool the student map down/up to the teacher's spatial size.
            if projected.shape[2:] != target.shape[2:]:
                projected = F.adaptive_avg_pool2d(projected, target.shape[2:])

            per_layer.append(F.mse_loss(projected, target))

        return sum(per_layer) / len(self.layers_to_match)


class AttentionTransferLoss(nn.Module):
    """
    Attention Transfer (Zagoruyko & Komodakis, 2017).

    Matches L2-normalized spatial attention maps between teacher and student,
    where the attention map is A(F) = Ξ£_c |F^c|^p.

    Args:
        p: Power for attention computation (default: 2)
    """
    def __init__(self, p: float = 2.0):
        super().__init__()
        self.p = p

    def attention_map(self, features: torch.Tensor) -> torch.Tensor:
        """
        Collapse channels into a spatial attention map.

        Args:
            features: [batch, channels, height, width]

        Returns:
            attention: [batch, height, width]
        """
        # |F|^p summed over the channel axis.
        return features.abs().pow(self.p).sum(dim=1)

    def normalize_attention(self, attention: torch.Tensor) -> torch.Tensor:
        """Scale each sample's attention map to unit L2 norm."""
        flat = attention.flatten(start_dim=1)
        # Epsilon guards against all-zero maps.
        unit = flat / (flat.norm(p=2, dim=1, keepdim=True) + 1e-8)
        return unit.view_as(attention)

    def forward(
        self,
        student_features: List[torch.Tensor],
        teacher_features: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Args:
            student_features: List of [batch, C, H, W]
            teacher_features: List of [batch, C, H, W]

        Returns:
            at_loss: Mean MSE between normalized attention maps
        """
        pair_losses = []

        for s_map, t_map in zip(student_features, teacher_features):
            a_student = self.normalize_attention(self.attention_map(s_map))
            a_teacher = self.normalize_attention(self.attention_map(t_map))
            pair_losses.append(F.mse_loss(a_student, a_teacher))

        return sum(pair_losses) / len(student_features)


# ============================================================================
# 3. Relation-Based Distillation
# ============================================================================

class RelationalKDLoss(nn.Module):
    """
    Relational Knowledge Distillation (Park et al., 2019).

    Distills the *structure* of the embedding space: pairwise distances and
    pairwise angular (cosine) relations, rather than individual outputs.

    Args:
        distance_weight: Weight for distance-wise loss
        angle_weight: Weight for angle-wise loss
    """
    def __init__(self, distance_weight: float = 1.0, angle_weight: float = 2.0):
        super().__init__()
        self.distance_weight = distance_weight
        self.angle_weight = angle_weight

    def pdist(self, embeddings: torch.Tensor, squared: bool = False) -> torch.Tensor:
        """
        Pairwise (squared) Euclidean distance matrix.

        Args:
            embeddings: [batch, dim]

        Returns:
            distances: [batch, batch]
        """
        gram = embeddings @ embeddings.t()
        sq_norms = gram.diagonal()

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a,b>; clamp fp negatives.
        d2 = (sq_norms.unsqueeze(0) + sq_norms.unsqueeze(1) - 2.0 * gram).clamp(min=0.0)

        if squared:
            return d2

        # sqrt(0) has an infinite gradient; nudge exact zeros, then re-zero.
        zero_mask = (d2 == 0.0).float()
        safe = torch.sqrt(d2 + zero_mask * 1e-16)
        return safe * (1.0 - zero_mask)

    def distance_wise_loss(
        self,
        student_embed: torch.Tensor,
        teacher_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        L_dist = Ξ£_{i,j} Huber(||s_i - s_j||Β² - ||t_i - t_j||Β²)
        """
        d_student = self.pdist(student_embed, squared=True)
        d_teacher = self.pdist(teacher_embed, squared=True)
        # smooth_l1 == Huber with beta=1.
        return F.smooth_l1_loss(d_student, d_teacher, reduction='mean')

    def angle_wise_loss(
        self,
        student_embed: torch.Tensor,
        teacher_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        L_angle = Σ_{i,j,k} Huber(∠(s_i, s_j, s_k) - ∠(t_i, t_j, t_k))

        Simplified: match cosine-similarity matrices instead of explicit angles.
        """
        s_unit = F.normalize(student_embed, p=2, dim=1)
        t_unit = F.normalize(teacher_embed, p=2, dim=1)

        cos_student = s_unit @ s_unit.t()
        cos_teacher = t_unit @ t_unit.t()

        return F.mse_loss(cos_student, cos_teacher)

    def forward(
        self,
        student_embed: torch.Tensor,
        teacher_embed: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            student_embed: [batch, embed_dim]
            teacher_embed: [batch, embed_dim]

        Returns:
            total_loss, loss_dict
        """
        dist_term = self.distance_wise_loss(student_embed, teacher_embed)
        angle_term = self.angle_wise_loss(student_embed, teacher_embed)

        combined = self.distance_weight * dist_term + self.angle_weight * angle_term

        return combined, {
            'distance': dist_term.item(),
            'angle': angle_term.item(),
            'total': combined.item()
        }


class SimilarityPreservingLoss(nn.Module):
    """
    Similarity-Preserving KD (Tung & Mori, 2019).

    Matches the batchwise cosine-similarity (normalized Gram) matrices of
    student and teacher features:

        G[i,j] = <f_i, f_j> / (||f_i|| ||f_j||)
    """
    def __init__(self):
        super().__init__()

    def similarity_matrix(self, features: torch.Tensor) -> torch.Tensor:
        """
        Compute the normalized Gram matrix.

        Args:
            features: [batch, dim]

        Returns:
            similarity: [batch, batch]
        """
        unit = F.normalize(features, p=2, dim=1)
        return unit @ unit.t()

    def forward(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            student_features: [batch, dim]
            teacher_features: [batch, dim]

        Returns:
            sp_loss: Squared Frobenius gap, normalized by batch_size^2
        """
        gap = (
            self.similarity_matrix(student_features)
            - self.similarity_matrix(teacher_features)
        )
        frob_sq = torch.norm(gap, p='fro') ** 2

        batch = student_features.size(0)
        return frob_sq / (batch ** 2)


# ============================================================================
# 4. Self-Distillation
# ============================================================================

class DeepMutualLearning(nn.Module):
    """
    Deep Mutual Learning (Zhang et al., 2018).

    Each student is trained with its own cross-entropy loss plus the average
    KL divergence towards every *other* student's softened (detached)
    predictions, so a cohort learns collaboratively without a pre-trained
    teacher.

    Args:
        num_students: Expected number of student networks (informational;
            the forward pass uses the actual number of logit tensors given).
        temperature: Temperature for distillation
    """
    def __init__(self, num_students: int = 2, temperature: float = 3.0):
        super().__init__()
        self.num_students = num_students
        self.temperature = temperature
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(
        self,
        student_logits: List[torch.Tensor],
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            student_logits: List of [batch, num_classes] from each student
            labels: [batch] ground truth

        Returns:
            total_loss: Sum of per-student losses.
            losses: Per-student loss tensors.
        """
        cohort = len(student_logits)
        losses = []

        for i, logits_i in enumerate(student_logits):
            # Classification loss
            ce_loss = self.ce_loss(logits_i, labels)

            # KL divergence towards each peer's (detached) soft predictions.
            distill_loss = 0.0
            for j, logits_j in enumerate(student_logits):
                if i != j:
                    soft_i = F.log_softmax(logits_i / self.temperature, dim=1)
                    soft_j = F.softmax(logits_j.detach() / self.temperature, dim=1)
                    distill_loss = distill_loss + F.kl_div(
                        soft_i, soft_j, reduction='batchmean'
                    )

            # BUG FIX: average over the *actual* number of peers. The
            # original divided by self.num_students - 1, which diverged from
            # the input list length and raised ZeroDivisionError when a
            # single student was passed.
            if cohort > 1:
                distill_loss = distill_loss / (cohort - 1)

            # Total loss for this student; T^2 rescales KD gradients.
            losses.append(ce_loss + self.temperature ** 2 * distill_loss)

        # Sum across all students
        total_loss = sum(losses)

        return total_loss, losses


class OnlineDistillation(nn.Module):
    """
    Online distillation with a temporal-ensemble (EMA) teacher.

    The teacher's weights are an exponential moving average of the student's:

        ΞΈ_teacher = m * ΞΈ_teacher + (1-m) * ΞΈ_student

    Args:
        momentum: EMA momentum (default: 0.999)
        temperature: Distillation temperature
    """
    def __init__(self, momentum: float = 0.999, temperature: float = 4.0):
        super().__init__()
        self.momentum = momentum
        self.temperature = temperature
        # Created lazily on the first update_teacher() call.
        self.teacher_model = None
        self.ce_loss = nn.CrossEntropyLoss()

    def update_teacher(self, student_model: nn.Module):
        """
        Update the EMA teacher from the current student weights.

        On the first call the teacher is created as a deep copy of the
        student; afterwards its parameters follow the EMA recursion.
        """
        if self.teacher_model is None:
            import copy
            # BUG FIX: the original used type(student_model)(), which assumes
            # a zero-argument constructor (crashes for e.g. nn.Linear) and
            # re-creates the module on the default device/dtype. deepcopy
            # works for any module and preserves device, dtype and buffers.
            self.teacher_model = copy.deepcopy(student_model)
            self.teacher_model.eval()
            for param in self.teacher_model.parameters():
                # The teacher is never trained directly.
                param.requires_grad_(False)
        else:
            # In-place EMA update (no autograd tracking needed).
            with torch.no_grad():
                for teacher_param, student_param in zip(
                    self.teacher_model.parameters(),
                    student_model.parameters()
                ):
                    teacher_param.data.mul_(self.momentum).add_(
                        student_param.data, alpha=1 - self.momentum
                    )

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            student_logits: [batch, num_classes]
            teacher_logits: [batch, num_classes] from EMA teacher
            labels: [batch]

        Returns:
            total_loss, loss_dict
        """
        # Classification loss against hard labels.
        ce_loss = self.ce_loss(student_logits, labels)

        # Distillation loss against the EMA teacher's soft targets.
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)

        # T^2 keeps gradient magnitudes comparable across temperatures.
        kd_loss = F.kl_div(
            soft_student,
            soft_teacher,
            reduction='batchmean'
        ) * (self.temperature ** 2)

        total_loss = ce_loss + kd_loss

        loss_dict = {
            'ce': ce_loss.item(),
            'kd': kd_loss.item(),
            'total': total_loss.item()
        }

        return total_loss, loss_dict


# ============================================================================
# 5. Data-Free Distillation
# ============================================================================

class DataFreeDistillation:
    """
    Data-free knowledge distillation via synthetic data generation.

    Optimizes random inputs so that the (frozen) teacher produces confident
    predictions, regularized by a total-variation smoothness prior. The TV
    prior indexes dims 2 and 3, so inputs are assumed to be 4D
    [batch, C, H, W] (i.e. input_shape has 3 dims).

    Args:
        teacher: Pre-trained teacher model
        input_shape: Shape of a single input (e.g. [3, 32, 32])
        num_classes: Number of output classes
    """
    def __init__(
        self,
        teacher: nn.Module,
        input_shape: Tuple[int, ...],
        num_classes: int
    ):
        self.teacher = teacher
        self.teacher.eval()
        # Freeze the teacher so gradients flow only into the synthetic
        # inputs and no stale .grad accumulates on teacher weights.
        for param in self.teacher.parameters():
            param.requires_grad_(False)
        self.input_shape = input_shape
        self.num_classes = num_classes

    def generate_samples(
        self,
        batch_size: int,
        num_iterations: int = 1000,
        lr: float = 0.1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate synthetic samples via gradient descent on the inputs.

        Minimizes: entropy(teacher(x)) + Ξ» * TV(x)

        Args:
            batch_size: Number of samples to generate
            num_iterations: Optimization steps
            lr: Learning rate

        Returns:
            synthetic_inputs (detached), teacher_logits
        """
        # Create inputs on the teacher's device (fall back to CPU for a
        # parameterless teacher).
        try:
            device = next(self.teacher.parameters()).device
        except StopIteration:
            device = torch.device('cpu')

        inputs = torch.randn(
            batch_size, *self.input_shape,
            device=device,
            requires_grad=True
        )

        optimizer = optim.Adam([inputs], lr=lr)

        for _ in range(num_iterations):
            optimizer.zero_grad()

            # BUG FIX: the original wrapped this forward pass in
            # torch.no_grad(), which cut the entropy term out of the
            # computation graph — only the TV prior was being optimized.
            # Gradients w.r.t. the inputs are required here.
            teacher_logits = self.teacher(inputs)

            # Confidence objective: minimize prediction entropy.
            probs = F.softmax(teacher_logits, dim=1)
            entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1).mean()

            # Image prior: total variation encourages spatial smoothness.
            tv_loss = (
                torch.sum(torch.abs(inputs[:, :, :, :-1] - inputs[:, :, :, 1:])) +
                torch.sum(torch.abs(inputs[:, :, :-1, :] - inputs[:, :, 1:, :]))
            )

            loss = entropy + 0.0001 * tv_loss

            loss.backward()
            optimizer.step()

            # Keep pixels in a normalized valid range.
            inputs.data.clamp_(-1, 1)

        # Final teacher predictions for the generated batch.
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)

        return inputs.detach(), teacher_logits


# ============================================================================
# 6. Distillation Trainer
# ============================================================================

@dataclass
class DistillationConfig:
    """Configuration for distillation training.

    Consumed by DistillationTrainer; note its setup_loss currently only
    wires up 'hinton', 'rkd' and 'sp' (other methods listed below raise).
    """
    # Distillation strategy identifier.
    method: str = 'hinton'  # 'hinton', 'fitnet', 'attention', 'rkd', 'dml'
    # Softmax temperature for temperature-scaled methods.
    temperature: float = 4.0
    # Weight of the distillation term in the combined loss ('hinton').
    alpha: float = 0.7
    num_epochs: int = 100
    learning_rate: float = 0.001
    batch_size: int = 128
    # Falls back to CPU when CUDA is unavailable.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'


class DistillationTrainer:
    """
    Unified trainer for various distillation methods.

    Freezes the teacher, trains the student with the configured distillation
    criterion, and records per-epoch loss and validation accuracy.

    Args:
        teacher: Pre-trained teacher model
        student: Student model to train
        config: Training configuration
    """
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        config: DistillationConfig
    ):
        self.teacher = teacher.to(config.device)
        self.student = student.to(config.device)
        self.config = config
        self.device = config.device

        # Freeze teacher: it only provides targets, never trains.
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False

        # Setup loss (raises for unsupported methods).
        self.setup_loss()

        # Optimizer over student parameters only.
        self.optimizer = optim.Adam(
            self.student.parameters(),
            lr=config.learning_rate
        )

        # Per-epoch metric history.
        self.train_losses = []
        self.val_accuracies = []

    def setup_loss(self):
        """Initialize the loss function based on config.method.

        Raises:
            ValueError: If method is not one of 'hinton', 'rkd', 'sp'.
        """
        if self.config.method == 'hinton':
            self.criterion = DistillationLoss(
                temperature=self.config.temperature,
                alpha=self.config.alpha
            )
        elif self.config.method == 'rkd':
            self.criterion = RelationalKDLoss()
        elif self.config.method == 'sp':
            self.criterion = SimilarityPreservingLoss()
        else:
            raise ValueError(f"Unknown method: {self.config.method}")

    def train_epoch(self, train_loader: DataLoader) -> float:
        """Train for one epoch and return the mean batch loss."""
        self.student.train()
        total_loss = 0.0

        for inputs, labels in train_loader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # Forward passes; teacher needs no graph.
            student_logits = self.student(inputs)
            with torch.no_grad():
                teacher_logits = self.teacher(inputs)

            # Compute loss
            if self.config.method == 'hinton':
                loss, _ = self.criterion(student_logits, teacher_logits, labels)
            elif self.config.method in ('rkd', 'sp'):
                # Use penultimate-layer features.
                # (Models are assumed to expose a .features() method.)
                student_features = self.student.features(inputs)
                with torch.no_grad():
                    teacher_features = self.teacher.features(inputs)

                # BUG FIX: RelationalKDLoss returns (loss, dict) while
                # SimilarityPreservingLoss returns a bare tensor; the
                # original `loss, _ = self.criterion(...)` crashed for
                # method='sp'.
                out = self.criterion(student_features, teacher_features)
                loss = out[0] if isinstance(out, tuple) else out

                # Add the plain classification loss on top.
                loss = loss + F.cross_entropy(student_logits, labels)

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(train_loader)

    def evaluate(self, val_loader: DataLoader) -> float:
        """Evaluate the student and return top-1 accuracy in percent."""
        self.student.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                outputs = self.student(inputs)
                _, predicted = outputs.max(1)

                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        accuracy = 100.0 * correct / total
        return accuracy

    def train(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None
    ) -> Dict[str, List[float]]:
        """
        Full training loop; prints progress every 10 epochs.

        Returns:
            history: Dict with 'loss' and 'val_acc' lists
        """
        print(f"Training with {self.config.method} distillation")
        print(f"Temperature: {self.config.temperature}, Alpha: {self.config.alpha}")

        for epoch in range(self.config.num_epochs):
            # Train
            train_loss = self.train_epoch(train_loader)
            self.train_losses.append(train_loss)

            # Evaluate
            if val_loader is not None:
                val_acc = self.evaluate(val_loader)
                self.val_accuracies.append(val_acc)

                if (epoch + 1) % 10 == 0:
                    print(f"Epoch {epoch+1}/{self.config.num_epochs} - "
                          f"Loss: {train_loss:.4f}, Val Acc: {val_acc:.2f}%")
            else:
                if (epoch + 1) % 10 == 0:
                    print(f"Epoch {epoch+1}/{self.config.num_epochs} - "
                          f"Loss: {train_loss:.4f}")

        return {
            'loss': self.train_losses,
            'val_acc': self.val_accuracies
        }


# ============================================================================
# 7. Demo: Knowledge Distillation Comparison
# ============================================================================

def demo_distillation_comparison():
    """
    Compare different distillation methods.

    Runs a quick smoke test of each loss on random tensors and prints the
    resulting loss breakdowns; no actual training loop is executed here.
    """
    print("=" * 80)
    print("Knowledge Distillation Methods Comparison")
    print("=" * 80)
    
    # Dummy teacher and student
    class TeacherNet(nn.Module):
        """Larger dummy network (128-dim features)."""
        def __init__(self):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1)
            )
            self.fc = nn.Linear(128, 10)
        
        def forward(self, x):
            x = self.conv(x)
            x = x.view(x.size(0), -1)
            return self.fc(x)
        
        def features(self, x):
            """Penultimate-layer features, flattened to [batch, 128]."""
            x = self.conv(x)
            return x.view(x.size(0), -1)
    
    class StudentNet(nn.Module):
        """Smaller dummy network (32-dim features)."""
        def __init__(self):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1)
            )
            self.fc = nn.Linear(32, 10)
        
        def forward(self, x):
            x = self.conv(x)
            x = x.view(x.size(0), -1)
            return self.fc(x)
        
        def features(self, x):
            """Penultimate-layer features, flattened to [batch, 32]."""
            x = self.conv(x)
            return x.view(x.size(0), -1)
    
    # Create dummy data (random CIFAR-shaped tensors, random labels)
    X_train = torch.randn(1000, 3, 32, 32)
    y_train = torch.randint(0, 10, (1000,))
    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    
    X_val = torch.randn(200, 3, 32, 32)
    y_val = torch.randint(0, 10, (200,))
    val_dataset = TensorDataset(X_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=128)
    
    # Pre-trained teacher (dummy — randomly initialized here)
    teacher = TeacherNet()
    
    # 1. Hinton's KD
    print("\n1. Classic Distillation (Hinton)")
    print("-" * 40)
    
    student_hinton = StudentNet()
    config_hinton = DistillationConfig(
        method='hinton',
        temperature=4.0,
        alpha=0.7,
        num_epochs=5,
        device='cpu'
    )
    
    # Trainer is constructed to show the setup; train() is not called here.
    trainer_hinton = DistillationTrainer(teacher, student_hinton, config_hinton)
    
    # Sample forward pass (no gradients needed for the demo printouts)
    sample_input = torch.randn(4, 3, 32, 32)
    with torch.no_grad():
        teacher_logits = teacher(sample_input)
        student_logits = student_hinton(sample_input)
    
    criterion = DistillationLoss(temperature=4.0, alpha=0.7)
    sample_labels = torch.randint(0, 10, (4,))
    loss, loss_dict = criterion(student_logits, teacher_logits, sample_labels)
    
    print(f"Sample loss breakdown:")
    print(f"  Total: {loss_dict['total']:.4f}")
    print(f"  Task (CE): {loss_dict['task']:.4f}")
    print(f"  Distillation (KD): {loss_dict['distill']:.4f}")
    
    # 2. Relational KD
    print("\n2. Relational Knowledge Distillation")
    print("-" * 40)
    
    rkd_criterion = RelationalKDLoss()
    
    # Sample features (teacher side detached via no_grad)
    student_feat = student_hinton.features(sample_input)
    with torch.no_grad():
        teacher_feat = teacher.features(sample_input)
    
    rkd_loss, rkd_dict = rkd_criterion(student_feat, teacher_feat)
    
    print(f"RKD loss breakdown:")
    print(f"  Total: {rkd_dict['total']:.4f}")
    print(f"  Distance-wise: {rkd_dict['distance']:.4f}")
    print(f"  Angle-wise: {rkd_dict['angle']:.4f}")
    
    # 3. Similarity Preserving
    print("\n3. Similarity-Preserving KD")
    print("-" * 40)
    
    sp_criterion = SimilarityPreservingLoss()
    sp_loss = sp_criterion(student_feat, teacher_feat)
    
    print(f"SP loss: {sp_loss.item():.4f}")
    
    # Visualize similarity matrices (top-left 2x2 corner only)
    with torch.no_grad():
        student_sim = sp_criterion.similarity_matrix(student_feat)
        teacher_sim = sp_criterion.similarity_matrix(teacher_feat)
    
    print(f"Student similarity matrix:\n{student_sim.numpy()[:2, :2]}")
    print(f"Teacher similarity matrix:\n{teacher_sim.numpy()[:2, :2]}")
    
    # 4. Deep Mutual Learning
    print("\n4. Deep Mutual Learning")
    print("-" * 40)
    
    dml = DeepMutualLearning(num_students=2, temperature=3.0)
    
    # Second, independently initialized student for the cohort
    student2 = StudentNet()
    with torch.no_grad():
        logits1 = student_hinton(sample_input)
        logits2 = student2(sample_input)
    
    dml_loss, individual_losses = dml([logits1, logits2], sample_labels)
    
    print(f"DML total loss: {dml_loss.item():.4f}")
    print(f"Student 1 loss: {individual_losses[0].item():.4f}")
    print(f"Student 2 loss: {individual_losses[1].item():.4f}")
    
    # 5. Online Distillation
    print("\n5. Online Distillation (EMA Teacher)")
    print("-" * 40)
    
    online_kd = OnlineDistillation(momentum=0.999, temperature=4.0)
    # First call initializes the EMA teacher from the student
    online_kd.update_teacher(student_hinton)
    
    with torch.no_grad():
        teacher_ema_logits = online_kd.teacher_model(sample_input)
    
    online_loss, online_dict = online_kd(
        student_logits,
        teacher_ema_logits,
        sample_labels
    )
    
    print(f"Online KD loss breakdown:")
    print(f"  Total: {online_dict['total']:.4f}")
    print(f"  CE: {online_dict['ce']:.4f}")
    print(f"  KD: {online_dict['kd']:.4f}")


# ============================================================================
# 8. Performance Comparison Table
# ============================================================================

def print_performance_comparison():
    """Print a comprehensive stdout comparison of knowledge-distillation methods.

    Emits static reference tables (method characteristics, benchmark numbers,
    compression ratios, hyperparameter guidance, a decision guide) followed by
    a numbered list of practical takeaways. Takes no arguments; returns None.
    """
    banner = "=" * 80
    print(f"\n{banner}")
    print("KNOWLEDGE DISTILLATION METHODS COMPARISON")
    print(banner)

    # Pre-rendered reference tables; kept as one literal so the layout is
    # obvious at a glance and trivially printable.
    table = """
    Method Characteristics:
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Method           β”‚ Complexity   β”‚ Memory        β”‚ Training     β”‚ Performance β”‚
    β”‚                  β”‚              β”‚ Overhead      β”‚ Time         β”‚ Gain        β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Hinton KD        β”‚ O(NC)        β”‚ Low           β”‚ 1.2Γ—         β”‚ +1.5-2.5%   β”‚
    β”‚ (Response)       β”‚              β”‚               β”‚              β”‚             β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ FitNet           β”‚ O(NHW)       β”‚ Medium        β”‚ 1.5Γ—         β”‚ +2.0-3.0%   β”‚
    β”‚ (Feature)        β”‚              β”‚ (features)    β”‚ (2-stage)    β”‚             β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Attention        β”‚ O(NHW)       β”‚ Low           β”‚ 1.3Γ—         β”‚ +1.8-2.8%   β”‚
    β”‚ Transfer         β”‚              β”‚ (attention)   β”‚              β”‚             β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ RKD              β”‚ O(NΒ²D)       β”‚ Low           β”‚ 1.4Γ—         β”‚ +2.2-3.2%   β”‚
    β”‚ (Relational)     β”‚              β”‚               β”‚              β”‚             β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ SP               β”‚ O(NΒ²)        β”‚ Medium        β”‚ 1.5Γ—         β”‚ +2.0-3.0%   β”‚
    β”‚ (Similarity)     β”‚              β”‚ (Gram matrix) β”‚              β”‚             β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ CRD              β”‚ O(NK)        β”‚ High          β”‚ 1.6Γ—         β”‚ +2.5-3.5%   β”‚
    β”‚ (Contrastive)    β”‚              β”‚ (negatives)   β”‚              β”‚             β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ DML              β”‚ O(MΒ·NC)      β”‚ High          β”‚ 1.0Γ—         β”‚ +1.0-2.0%   β”‚
    β”‚ (Mutual)         β”‚              β”‚ (M students)  β”‚ (parallel)   β”‚             β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Online KD        β”‚ O(NC)        β”‚ High          β”‚ 1.1Γ—         β”‚ +0.8-1.5%   β”‚
    β”‚ (EMA)            β”‚              β”‚ (EMA model)   β”‚              β”‚             β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    N = batch size, C = classes, H/W = spatial dims, D = feature dim, 
    M = num students, K = num negatives
    
    ImageNet Performance (ResNet-34 Teacher β†’ Student):
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Student          β”‚ Baseline      β”‚ +Hinton KD   β”‚ +Best Method β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ ResNet-18        β”‚ 69.8%         β”‚ 71.0%        β”‚ 72.2% (CRD)  β”‚
    β”‚ MobileNetV2      β”‚ 67.5%         β”‚ 68.5%        β”‚ 69.8% (RKD)  β”‚
    β”‚ ShuffleNetV2     β”‚ 65.2%         β”‚ 66.8%        β”‚ 67.5% (AT)   β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    CIFAR-100 Performance (WideResNet-40-2 Teacher β†’ Student):
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Student          β”‚ Baseline      β”‚ +Hinton KD   β”‚ +Best Method β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ WRN-16-2         β”‚ 73.3%         β”‚ 75.6%        β”‚ 76.8% (CRD)  β”‚
    β”‚ ResNet-20        β”‚ 69.1%         β”‚ 71.2%        β”‚ 72.5% (RKD)  β”‚
    β”‚ MobileNetV2      β”‚ 64.6%         β”‚ 67.4%        β”‚ 68.9% (SP)   β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    Compression Ratios (Teacher β†’ Student):
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Model Pair              β”‚ Params     β”‚ Speedup     β”‚ Acc Drop     β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ ResNet-152 β†’ ResNet-18  β”‚ 60M β†’ 11M  β”‚ 8Γ—          β”‚ 4.5% β†’ 2.8%  β”‚
    β”‚ BERT-base β†’ DistilBERT  β”‚ 110M β†’ 66M β”‚ 1.6Γ—        β”‚ 97% retained β”‚
    β”‚ GPT-2 β†’ GPT-2 Small     β”‚ 1.5B β†’ 117Mβ”‚ 13Γ—         β”‚ ~15% perplx  β”‚
    β”‚ EfficientNet-B7 β†’ B0    β”‚ 66M β†’ 5M   β”‚ 30Γ—         β”‚ 6.0% β†’ 3.5%  β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    Hyperparameter Recommendations:
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Hyperparameter β”‚ Typical Range    β”‚ Notes                       β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Temperature T  β”‚ 1-20             β”‚ 4-6 for classification      β”‚
    β”‚                β”‚                  β”‚ Higher for similar tasks    β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Alpha Ξ»        β”‚ 0.1-0.9          β”‚ 0.5-0.7 typical             β”‚
    β”‚                β”‚                  β”‚ Higher when teacher better  β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Learning rate  β”‚ 0.5-2Γ— baseline  β”‚ Start lower, anneal slower  β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Epochs         β”‚ 1.5-2Γ— baseline  β”‚ Student needs more time     β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Feature layers β”‚ 1-5 layers       β”‚ Middle layers work best     β”‚
    β”‚                β”‚                  β”‚ Too early: low-level        β”‚
    β”‚                β”‚                  β”‚ Too late: task-specific     β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    Decision Guide:
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Scenario                     β”‚ Recommended Method              β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Simple classification        β”‚ Hinton KD (response-based)      β”‚
    β”‚ Large capacity gap           β”‚ FitNet (feature-based)          β”‚
    β”‚ Cross-architecture           β”‚ AT (attention transfer)         β”‚
    β”‚ Limited data                 β”‚ RKD (relational)                β”‚
    β”‚ Batch learning matters       β”‚ SP (similarity-preserving)      β”‚
    β”‚ No pre-trained teacher       β”‚ DML (mutual learning)           β”‚
    β”‚ Online/streaming             β”‚ Online KD (EMA teacher)         β”‚
    β”‚ Detection/segmentation       β”‚ Feature + Response combined     β”‚
    β”‚ Resource constrained         β”‚ Hinton KD (lowest overhead)     β”‚
    β”‚ Max performance              β”‚ CRD (contrastive) or ensemble   β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    """

    print(table)

    # Practical takeaways, numbered 1..N in print order.
    insights = [
        "Response-based (Hinton): Simple, effective baseline",
        "Feature-based: Better for large capacity gap, architecture coupling",
        "Relation-based: Captures structural knowledge, scales O(NΒ²)",
        "Self-distillation: No teacher needed, can improve teacher-size models",
        "Combination: Feature + Response often works best",
        "Temperature: 4-6 typical, tune via validation",
        "Alpha: 0.5-0.7 typical, higher when teacher much better",
        "Training time: 1.2-1.6Γ— longer than baseline",
        "Performance gain: +1-3% typical, varies by task/architecture",
        "Best practice: Start with Hinton KD, add features if needed",
    ]
    print("\nKey Insights:")
    for rank, note in enumerate(insights, start=1):
        print(f"{rank}. {note}")


# ============================================================================
# 9. Run Demonstrations
# ============================================================================

if __name__ == "__main__":
    # Run demos
    demo_distillation_comparison()
    print_performance_comparison()
    
    print("\n" + "=" * 80)
    print("Knowledge Distillation Implementations Complete!")
    print("=" * 80)