import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Select GPU when available; every model and batch below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Knowledge Distillation
Soft Targets
\(p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}\), where \(T\) is temperature.
Distillation Loss
Reference Materials:
foundation_neural_network.pdf - Foundation Neural Network
neural_networks.pdf - Neural Networks
class TeacherNet(nn.Module):
    """High-capacity convolutional teacher for MNIST: three conv blocks
    followed by two fully connected layers with dropout."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.fc1 = nn.Linear(256 * 3 * 3, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits for a batch of 1x28x28 images."""
        # Three conv -> ReLU -> 2x2 max-pool stages: 28 -> 14 -> 7 -> 3 spatially.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        hidden = F.relu(self.fc1(x.flatten(1)))
        return self.fc2(self.dropout(hidden))
class StudentNet(nn.Module):
    """Compact convolutional student for MNIST: two conv blocks followed
    by two fully connected layers (no dropout)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits for a batch of 1x28x28 images."""
        # Two conv -> ReLU -> 2x2 max-pool stages: 28 -> 14 -> 7 spatially.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        return self.fc2(F.relu(self.fc1(x.flatten(1))))
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Instantiate both networks and report the parameter-count compression ratio.
teacher = TeacherNet().to(device)
student = StudentNet().to(device)
print(f"Teacher parameters: {count_parameters(teacher):,}")
print(f"Student parameters: {count_parameters(student):,}")
print(f"Compression: {count_parameters(teacher)/count_parameters(student):.1f}x")
Train TeacherΒΆ
Knowledge distillation begins by training a large, high-capacity teacher model to achieve the best possible accuracy on the task. The teacherβs value lies not just in its hard predictions but in its soft probability outputs β the full distribution over classes captures inter-class similarities (e.g., a β7β looks somewhat like a β1β). These soft targets, when used to train a smaller student, transfer more information per training example than hard one-hot labels, effectively compressing the teacherβs knowledge into a compact model.
# Standard MNIST loaders; ToTensor scales pixel values to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_mnist = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_mnist, batch_size=1000)
def train_model(model, train_loader, n_epochs=10, eval_loader=None):
    """Train `model` with Adam + cross-entropy, printing test accuracy per epoch.

    Args:
        model: Classifier producing [batch, num_classes] logits; updated in place.
        train_loader: Iterable of (inputs, labels) training batches.
        n_epochs: Number of passes over `train_loader`.
        eval_loader: Loader used for the per-epoch evaluation. Defaults to the
            module-level `test_loader`, which this function previously read as
            an implicit global (backward compatible, now overridable).
    """
    if eval_loader is None:
        eval_loader = test_loader  # fall back to the notebook-level global
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(n_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            loss = F.cross_entropy(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Per-epoch evaluation on held-out data.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in eval_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        acc = 100 * correct / total
        print(f"Epoch {epoch+1}, Accuracy: {acc:.2f}%")
# Fit the large teacher first; its soft outputs supervise the student below.
print("Training teacher...")
train_model(teacher, train_loader, n_epochs=10)
Distillation LossΒΆ
The distillation loss combines two terms: the standard cross-entropy with hard labels and the KL divergence between teacher and student soft outputs, tempered by a temperature parameter \(T\):
Higher temperatures soften both distributions, revealing more of the teacherβs knowledge about inter-class relationships. The \(T^2\) scaling factor compensates for the reduced gradient magnitude at high temperatures. The mixing weight \(\alpha\) balances learning from ground truth versus learning from the teacher, with typical values around 0.1-0.5.
def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):
    """Hinton-style KD loss: alpha * CE(hard) + (1 - alpha) * T^2 * KL(soft).

    Args:
        student_logits: [batch, classes] raw student outputs.
        teacher_logits: [batch, classes] raw teacher outputs.
        labels: [batch] integer ground-truth classes.
        T: Softmax temperature; higher values soften both distributions.
        alpha: Weight on the hard-label term (1 - alpha goes to distillation).

    Returns:
        Scalar loss tensor.
    """
    # Temperature-scaled distributions; kl_div expects log-probs as input.
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    # T^2 restores the gradient magnitude lost by dividing the logits by T.
    soft_term = F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T * T)
    hard_term = F.cross_entropy(student_logits, labels)
    return alpha * hard_term + (1 - alpha) * soft_term
# Quick smoke test of the loss on random logits and labels.
logits_s = torch.randn(4, 10)
logits_t = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = distillation_loss(logits_s, logits_t, labels)
print(f"Test loss: {loss.item():.4f}")
Train Student with DistillationΒΆ
The student model is significantly smaller than the teacher (fewer layers, narrower hidden dimensions) and would achieve lower accuracy if trained from scratch on hard labels alone. With distillation, the student benefits from the teacherβs richer gradient signal: instead of learning that an image is βdefinitely a 3β, it learns that it is βmostly a 3, somewhat like an 8, and a little like a 5β. This additional information acts as a powerful regularizer, often allowing the student to match or approach the teacherβs accuracy despite having a fraction of the parameters β a critical capability for deploying models on mobile devices and edge hardware.
def train_with_distillation(student, teacher, train_loader, n_epochs=10, T=3.0,
                            alpha=0.5, eval_loader=None):
    """Train `student` to mimic `teacher` using the combined KD loss.

    Args:
        student: Small model being trained; updated in place.
        teacher: Frozen, pre-trained model providing soft targets.
        train_loader: Iterable of (inputs, labels) training batches.
        n_epochs: Number of training epochs.
        T: Distillation temperature passed to `distillation_loss`.
        alpha: Hard-label weight passed to `distillation_loss`.
        eval_loader: Loader for the per-epoch accuracy report. Defaults to the
            module-level `test_loader`, which this function previously read as
            an implicit global (backward compatible, now overridable).

    Returns:
        List of per-batch loss values across all epochs (for plotting).
    """
    if eval_loader is None:
        eval_loader = test_loader  # fall back to the notebook-level global
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    teacher.eval()  # teacher is inference-only: freeze dropout/BN behavior
    losses = []
    for epoch in range(n_epochs):
        student.train()
        epoch_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            # Teacher outputs are targets only -- no gradient flows through them.
            with torch.no_grad():
                teacher_logits = teacher(x)
            student_logits = student(x)
            loss = distillation_loss(student_logits, teacher_logits, y, T, alpha)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            losses.append(loss.item())
        # Per-epoch evaluation on held-out data.
        student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in eval_loader:
                x, y = x.to(device), y.to(device)
                pred = student(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        acc = 100 * correct / total
        print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(train_loader):.4f}, Acc: {acc:.2f}%")
    return losses
# Train a fresh student with distillation, then plot the per-batch loss curve.
student_distilled = StudentNet().to(device)
losses = train_with_distillation(student_distilled, teacher, train_loader, n_epochs=10, T=3.0)
plt.figure(figsize=(10, 5))
plt.plot(losses, alpha=0.5)
# 50-batch moving average smooths the noisy per-batch loss.
plt.plot(np.convolve(losses, np.ones(50)/50, mode='valid'), linewidth=2, label='Moving Avg')
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('Distillation Loss', fontsize=11)
plt.title('Knowledge Distillation Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Compare ModelsΒΆ
A side-by-side comparison of the teacher, distilled student, and a student trained from scratch on hard labels quantifies the benefit of distillation. Key metrics include accuracy, parameter count, inference latency, and memory footprint. The distilled student should significantly outperform the scratch-trained student (demonstrating knowledge transfer) while using far fewer parameters than the teacher (demonstrating compression). This comparison is the standard evidence used to justify distillation in production ML pipelines.
# Train an identical student on hard labels only, as a no-distillation control.
student_baseline = StudentNet().to(device)
print("\nTraining baseline student...")
train_model(student_baseline, train_loader, n_epochs=10)
# Compare accuracies
def evaluate(model, test_loader, device=None):
    """Compute top-1 accuracy (%) of `model` over `test_loader`.

    Args:
        model: Classifier returning [batch, num_classes] logits.
        test_loader: Iterable of (inputs, labels) batches.
        device: Optional torch.device. When None, falls back to CUDA if
            available -- matching the module-level `device` this function
            previously read implicitly, which made it untestable in isolation.

    Returns:
        Accuracy as a float percentage in [0, 100].
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return 100 * correct / total
# Final comparison: the distilled student should land between baseline and teacher.
teacher_acc = evaluate(teacher, test_loader)
student_distilled_acc = evaluate(student_distilled, test_loader)
student_baseline_acc = evaluate(student_baseline, test_loader)
print(f"\nResults:")
print(f"Teacher: {teacher_acc:.2f}%")
print(f"Student (distilled): {student_distilled_acc:.2f}%")
print(f"Student (baseline): {student_baseline_acc:.2f}%")
print(f"Improvement: {student_distilled_acc - student_baseline_acc:.2f}%")
Temperature AnalysisΒΆ
The temperature parameter \(T\) controls how much of the teacherβs βdark knowledgeβ is transferred. At \(T = 1\) (standard softmax), the teacherβs outputs are often nearly one-hot, providing little more information than hard labels. As \(T\) increases, the distribution flattens and reveals the teacherβs confidence rankings among non-target classes. Sweeping \(T\) across a range (e.g., 1 to 20) and measuring student accuracy at each value reveals the optimal temperature for a given teacher-student pair β typically in the range of 3-10.
# Sweep the distillation temperature and record student accuracy at each value.
temperatures = [1.0, 2.0, 3.0, 5.0, 10.0]
temp_accs = []
for T in temperatures:
    # Fresh student per temperature so runs are independent.
    student_temp = StudentNet().to(device)
    train_with_distillation(student_temp, teacher, train_loader, n_epochs=5, T=T, alpha=0.5)
    acc = evaluate(student_temp, test_loader)
    temp_accs.append(acc)
    print(f"T={T}: {acc:.2f}%")
# Plot accuracy vs. temperature against teacher/baseline reference lines.
plt.figure(figsize=(10, 6))
plt.plot(temperatures, temp_accs, 'bo-', markersize=8)
plt.axhline(y=student_baseline_acc, color='r', linestyle='--', label='Baseline')
plt.axhline(y=teacher_acc, color='g', linestyle='--', label='Teacher')
plt.xlabel('Temperature', fontsize=11)
plt.ylabel('Accuracy (%)', fontsize=11)
plt.title('Effect of Temperature on Distillation', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Soft vs Hard TargetsΒΆ
Comparing training with soft targets only (teacher outputs), hard targets only (ground truth labels), and the combined loss reveals each componentβs contribution. Soft targets alone often outperform hard targets alone for the student, because the soft distribution provides more bits of information per example. The combined loss typically performs best, leveraging both the teacherβs inter-class structure and the ground truthβs correctness guarantee. This ablation study is important for understanding why distillation works and for tuning the mixing weight \(\alpha\) in practice.
# Visualize how temperature softens the teacher's output distribution
# for a single test image.
teacher.eval()
x_sample, y_sample = next(iter(test_loader))
x_sample = x_sample[:1].to(device)
y_sample = y_sample[:1].to(device)
with torch.no_grad():
    logits = teacher(x_sample)
# Hard targets (standard softmax, T=1)
hard = F.softmax(logits, dim=1)
# Soft targets at higher temperatures
soft_T3 = F.softmax(logits / 3, dim=1)
soft_T10 = F.softmax(logits / 10, dim=1)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].bar(range(10), hard[0].cpu())
axes[0].set_title('Hard Targets (T=1)', fontsize=11)
axes[0].set_xlabel('Class')
axes[0].set_ylabel('Probability')
axes[1].bar(range(10), soft_T3[0].cpu())
axes[1].set_title('Soft Targets (T=3)', fontsize=11)
axes[1].set_xlabel('Class')
axes[2].bar(range(10), soft_T10[0].cpu())
axes[2].set_title('Soft Targets (T=10)', fontsize=11)
axes[2].set_xlabel('Class')
plt.tight_layout()
plt.show()
SummaryΒΆ
Knowledge Distillation:ΒΆ
Transfer knowledge from teacher to student via soft targets.
Key Components:ΒΆ
Teacher model - Large, accurate
Student model - Small, efficient
Temperature - Controls softness
Loss combination - Hard + soft
Benefits:ΒΆ
Model compression
Better generalization
Faster inference
Dark knowledge transfer
Applications:ΒΆ
Mobile deployment
Edge devices
Ensemble compression
Cross-modal transfer
Extensions:ΒΆ
Self-distillation
Feature distillation
Attention transfer
Data-free distillation
Next Steps:ΒΆ
Study quantization
Explore pruning
Learn neural architecture search
Advanced Knowledge Distillation TheoryΒΆ
1. Introduction to Knowledge DistillationΒΆ
Knowledge Distillation (KD) transfers knowledge from a large, complex teacher model to a smaller, efficient student model.
1.1 MotivationΒΆ
Why distillation?
Model compression: Deploy on edge devices (mobile, IoT)
Inference speedup: Smaller models are faster
Ensemble knowledge: Distill from multiple teachers
Transfer learning: Cross-domain/architecture knowledge transfer
Key insight: Teacherβs soft predictions contain more information than hard labels.
1.2 Problem FormulationΒΆ
Teacher model: T(x; θ_T) → ŷ_soft (pre-softmax logits). Student model: S(x; θ_S) → ŷ_student
Goal: Train student to mimic teacherβs behavior:
min_{ΞΈ_S} L_distill(S(x), T(x)) + Ξ± L_task(S(x), y)
Where:
L_distill: Distillation loss (KL divergence, MSE)
L_task: Task loss (cross-entropy with ground truth)
Ξ±: Balance hyperparameter
2. Classic Knowledge Distillation (Hinton et al.)ΒΆ
Hintonβs KD [Hinton et al., 2015]: Use temperature-scaled softmax.
2.1 Temperature ScalingΒΆ
Softmax with temperature T:
p_i = exp(z_i / T) / Ξ£_j exp(z_j / T)
Where:
z_i: Logit for class i
T: Temperature parameter
T=1: Standard softmax
T>1: Softer distribution (more information in wrong classes)
Effect of temperature:
T → 0: One-hot (hard labels)
T = 1: Standard probabilities
T → ∞: Uniform distribution
2.2 Distillation LossΒΆ
KL divergence between teacher and student (both with temperature T):
L_KD = T² · KL(softmax(z_T/T) || softmax(z_S/T))
Total loss:
L = (1-Ξ») L_CE(y_true, softmax(z_S)) + Ξ» L_KD
Where:
Ξ»: Distillation weight (typically 0.5-0.9)
TΒ²: Scaling factor (compensates for softened gradients)
Intuition: Soft targets reveal similarity structure between classes.
2.3 Why Temperature WorksΒΆ
Information in soft targets:
Hard label: βThis is a catβ (1 bit)
Soft target: β90% cat, 8% dog, 2% tigerβ (rich information)
Dark knowledge: Information in near-zero probabilities.
Example: For image of β3β, teacher might output:
p(3) = 0.95
p(8) = 0.03 (similar shape)
p(5) = 0.01 (somewhat similar)
Others: ~0
Student learns similarity structure!
3. Response-Based DistillationΒΆ
3.1 Logit MatchingΒΆ
Direct logit matching (no softmax):
L = MSE(z_S, z_T)
Advantage: Simpler, no temperature tuning. Disadvantage: Sensitive to logit scale.
3.2 Regression DistillationΒΆ
For regression tasks:
L = ||f_S(x) - f_T(x)||Β²
Application: Object detection bbox regression, depth estimation.
3.3 Ranking DistillationΒΆ
RKD [Park et al., 2019]: Preserve relative distances.
Distance-wise:
L_dist = Ξ£_{i,j} Ο(||f_T(x_i) - f_T(x_j)||Β² - ||f_S(x_i) - f_S(x_j)||Β²)
Angle-wise:
L_angle = Ξ£_{i,j,k} Ο(β (x_i, x_j, x_k)_T - β (x_i, x_j, x_k)_S)
Where Ο is Huber loss.
4. Feature-Based DistillationΒΆ
4.1 FitNet (Hint Learning)ΒΆ
FitNet [Romero et al., 2015]: Match intermediate features.
Regressor: Transform student features to match teacher dimension:
L_hint = ||W_r Β· f_S^l - f_T^m||Β²
Where:
f_S^l: Student feature at layer l
f_T^m: Teacher feature at layer m
W_r: Learned transformation matrix
Two-stage training:
Train student with hint loss to match intermediate layers
Fine-tune with distillation loss on outputs
4.2 Attention TransferΒΆ
AT [Zagoruyko & Komodakis, 2017]: Transfer attention maps.
Spatial attention map:
A(F) = Ξ£_c |F^c|^p
Where F^c is feature channel c.
Loss:
L_AT = Ξ£_l ||A(F_S^l) / ||A(F_S^l)||_2 - A(F_T^l) / ||A(F_T^l)||_2||Β²
Intuition: Student learns where teacher focuses attention.
4.3 Similarity-Preserving KDΒΆ
SP [Tung & Mori, 2019]: Preserve pairwise similarity.
Similarity matrix:
G_T[i,j] = f_T(x_i)^T f_T(x_j) / (||f_T(x_i)|| ||f_T(x_j)||)
Loss:
L_SP = ||G_S - G_T||Β²_F
Advantage: Captures relational structure across samples.
4.4 Correlation CongruenceΒΆ
CRD [Tian et al., 2020]: Contrastive Representation Distillation.
InfoNCE-based loss:
L_CRD = -log[exp(f_S^T f_T / Ο) / (exp(f_S^T f_T / Ο) + Ξ£_k exp(f_S^T f_k^- / Ο))]
Where f_k^- are negative samples.
Benefit: Discriminative feature learning via contrastive objective.
5. Relation-Based DistillationΒΆ
5.1 Relational Knowledge Distillation (RKD)ΒΆ
Transfer structural information between instances.
Distance-wise RKD: Preserve relative Euclidean distances in feature space.
Angle-wise RKD: Preserve angles formed by triplets of points.
5.2 Instance Relationship GraphΒΆ
IRG [Liu et al., 2019]: Model as graph distillation.
Graph: Nodes = samples, Edges = relationships
Loss:
L_IRG = Ξ£_{i,j} w_ij Β· ||r_S(x_i, x_j) - r_T(x_i, x_j)||Β²
Where r(Β·,Β·) is relationship function.
6. Self-DistillationΒΆ
6.1 Born-Again Networks (BAN)ΒΆ
BAN [Furlanello et al., 2018]: Student same capacity as teacher.
Procedure:
Train teacher T_1
Distill to student S_1 (same architecture)
Repeat: S_1 becomes T_2, train S_2, etc.
Finding: Each generation improves (up to 3-4 iterations).
6.2 Deep Mutual Learning (DML)ΒΆ
DML [Zhang et al., 2018]: Multiple students learn collaboratively.
No pre-trained teacher: All models trained from scratch.
Loss for student i:
L_i = L_CE(y, S_i(x)) + Ξ£_{jβ i} KL(S_j(x) || S_i(x))
Benefit: Ensemble knowledge without ensemble inference.
6.3 Online Knowledge DistillationΒΆ
On-the-fly: Generate teacher predictions during training.
Approaches:
Multi-branch: Branches share backbone, distill from each other
Temporal ensemble: Teacher is EMA of student
Self-training: Pseudo-labels from studentβs own predictions
7. Cross-Modal DistillationΒΆ
7.1 Cross-Domain DistillationΒΆ
Transfer across domains: RGB β Depth, Day β Night, etc.
Challenge: Input modalities differ.
Solution: Match feature representations, not inputs.
7.2 Cross-Task DistillationΒΆ
Transfer knowledge between different tasks.
Example: Classification teacher β Detection student
Approach: Align feature spaces via intermediate layers.
7.3 Privileged Information DistillationΒΆ
Learning Using Privileged Information (LUPI):
Teacher has extra modality at training (e.g., depth maps)
Student uses only RGB at test
Knowledge transfer: Student learns from teacherβs privileged view.
8. Adversarial DistillationΒΆ
8.1 Adversarial Training for DistillationΒΆ
Generator: Student trying to fool discriminator Discriminator: Distinguish teacher vs. student outputs
Loss:
L_S = L_task + Ξ± L_KD + Ξ² L_adv
Where L_adv encourages student to match teacherβs distribution.
8.2 Data-Free DistillationΒΆ
Scenario: No access to original training data.
Solution: Generate synthetic data.
Approaches:
Gradient matching: Generate inputs that match teacher gradients
Activation matching: Maximize activation of teacher neurons
GAN-based: Train GAN to generate realistic inputs
DeepInversion [Yin et al., 2020]:
L = L_task(T(x_gen)) + Ξ± L_BN(x_gen) + Ξ² L_prior(x_gen)
Where:
L_BN: Batch normalization statistics matching
L_prior: Image prior (smoothness, etc.)
9. Advanced Distillation TechniquesΒΆ
9.1 Layer-wise DistillationΒΆ
Match at multiple layers simultaneously:
L = Ξ£_l w_l L_layer(f_S^l, f_T^l)
Adaptive weights: Learn w_l or use task-dependent heuristics.
9.2 Progressive DistillationΒΆ
Curriculum: Start with easy knowledge, progress to complex.
Temperature scheduling: T decreases over training (hard β soft).
9.3 Quantization-Aware DistillationΒΆ
Goal: Distill to quantized student (INT8, binary).
Approach: Quantize during distillation training.
Loss:
L = L_KD(quant(S(x)), T(x)) + L_CE(quant(S(x)), y)
9.4 Neural Architecture Search + KDΒΆ
AutoKD: Search for optimal student architecture while distilling.
Joint optimization: Architecture search + knowledge transfer.
10. Theoretical FoundationsΒΆ
10.1 Why Does Distillation Work?ΒΆ
Hypothesis 1: Dark knowledge (soft targets) provides regularization.
Hypothesis 2: Teacher smooths decision boundaries.
Hypothesis 3: Distillation implicitly optimizes for generalization.
Empirical evidence: Students often generalize better than teachers!
10.2 Capacity GapΒΆ
Teacher capacity: High (many parameters) Student capacity: Low (few parameters)
Trade-off:
Gap too large β Student canβt mimic teacher
Gap too small β Limited compression
Optimal gap: Task-dependent, typically teacher 5-10Γ larger.
10.3 Generalization BoundΒΆ
PAC-Bayesian bound for distilled student:
With probability β₯ 1-Ξ΄:
L_test(S) β€ L_KD(S, T) + O(β(KL(P_S || P_prior) / N))
Implication: Distillation acts as learned prior from teacher.
10.4 Information-Theoretic ViewΒΆ
Mutual information:
I(X; Y) β₯ I(f_S(X); Y)
Goal: Maximize mutual information between student features and labels.
Distillation: I(f_S; f_T) + I(f_T; Y) β I(f_S; Y)
11. Practical ConsiderationsΒΆ
11.1 Hyperparameter SelectionΒΆ
Critical hyperparameters:
Parameter |
Range |
Notes |
|---|---|---|
Temperature T |
1-20 |
Higher for similar tasks |
Distillation weight Ξ» |
0.1-0.9 |
0.5-0.7 typical |
Feature matching layers |
1-5 |
Middle layers work best |
Training epochs |
1.5-2Γ normal |
Student needs more time |
11.2 Teacher SelectionΒΆ
Bigger not always better: Very large teachers may not help.
Ensemble teachers: Average multiple teachersβ predictions.
Architecture mismatch: CNN β Transformer works if capacity appropriate.
11.3 Student Architecture DesignΒΆ
Width vs. Depth trade-off:
Wider, shallower: Faster inference, easier distillation
Narrower, deeper: Better feature learning, harder distillation
Skip connections: Help gradient flow in deep students.
11.4 Training StrategiesΒΆ
Two-stage:
Pre-train teacher on task
Distill to student
Co-training:
Train teacher and student jointly
Mutual learning
Temperature annealing: Start high, decrease over training.
12. ApplicationsΒΆ
12.1 Model CompressionΒΆ
Typical compression ratios:
BERT β DistilBERT: 40% size, 97% performance
ResNet-152 β ResNet-18: 10Γ speedup, ~5% accuracy drop
EfficientNet-B7 β B0: 30Γ fewer parameters
12.2 Edge DeploymentΒΆ
Mobile devices: Distilled models fit memory/compute constraints.
Real-time inference: Faster student enables real-time applications.
Battery efficiency: Smaller models consume less power.
12.3 Ensemble CompressionΒΆ
Distill ensemble into single model:
Ensemble: 5-10 models (slow, accurate)
Distilled student: 1 model (fast, nearly as accurate)
Benefit: Ensemble accuracy without ensemble cost.
12.4 Continual LearningΒΆ
Distillation prevents forgetting:
Old model = Teacher
Updated model = Student
Retain old knowledge via distillation loss
13. Recent Advances (2020-2024)ΒΆ
13.1 Vision Transformer DistillationΒΆ
DeiT [Touvron et al., 2021]: Data-efficient ViT training.
Distillation token: Extra token attending to teacher.
Hard vs. Soft distillation:
Soft: KL(teacher_soft || student)
Hard: CE(teacher_argmax, student)
Finding: Hard distillation often better for ViTs.
13.2 Large Language Model DistillationΒΆ
Distilling BERT:
DistilBERT: 40% smaller, 60% faster
TinyBERT: Layer-wise + prediction distillation
MobileBERT: Bottleneck architecture
Distilling GPT:
GPT-2 β GPT-2 Small/Medium
Layer dropping, width reduction
13.3 Multimodal DistillationΒΆ
CLIP distillation: Image-text teacher β Image-only student
Approach: Match vision encoder while removing text tower.
13.4 Test-Time DistillationΒΆ
Adapt student at test time using unlabeled data.
Self-distillation: Student becomes its own teacher.
14. Comparison of MethodsΒΆ
14.1 Performance SummaryΒΆ
ImageNet (ResNet-34 Teacher β Student):
Student Arch |
Method |
Top-1 Acc |
Speedup |
|---|---|---|---|
ResNet-18 |
Vanilla |
69.8% |
1.8Γ |
ResNet-18 |
KD (Hinton) |
71.0% |
1.8Γ |
ResNet-18 |
FitNet |
71.5% |
1.8Γ |
ResNet-18 |
AT |
71.8% |
1.8Γ |
ResNet-18 |
CRD |
72.2% |
1.8Γ |
MobileNetV2 |
KD |
68.5% |
3.5Γ |
Compression ratio: Teacher (21M params) β Student (11M params)
14.2 Method CharacteristicsΒΆ
Response-based (Hinton KD):
✓ Simple to implement
✓ Architecture-agnostic
✗ Only final layer information
Feature-based (FitNet, AT):
✓ Rich intermediate knowledge
✓ Better for large capacity gap
✗ Architecture coupling
✗ Hyperparameter sensitive
Relation-based (RKD, SP):
✓ Captures structural knowledge
✓ Generalizes across samples
✗ Quadratic complexity O(n²)
15. Limitations and ChallengesΒΆ
15.1 Current LimitationsΒΆ
Capacity gap: Very small students canβt learn from very large teachers
Architecture constraints: Some methods require similar architectures
Computational cost: Training time often 1.5-2Γ longer
Hyperparameter sensitivity: Temperature, Ξ» require tuning
Task dependence: What works for classification may not for detection
15.2 Open ProblemsΒΆ
Optimal temperature: No principled way to select T
Layer matching: Which teacher layers to use?
Student architecture: How to design for distillability?
Theoretical understanding: Why does it work so well?
Negative transfer: When does distillation hurt?
15.3 Future DirectionsΒΆ
Automated distillation: NAS for student + automatic hyperparameter tuning
Task-agnostic KD: Universal distillation framework
Efficient distillation: Reduce training time overhead
Multi-teacher distillation: Optimally combine multiple teachers
Lifelong distillation: Continual model updates via distillation
16. Key TakeawaysΒΆ
Knowledge distillation transfers knowledge from teacher to student via soft targets
Temperature scaling reveals dark knowledge in near-zero probabilities
Response-based (outputs), feature-based (intermediate), relation-based (structural)
Self-distillation (BAN, DML) improves without pre-trained teacher
Compression: 2-10Γ smaller models with <5% accuracy drop
Applications: Edge deployment, ensemble compression, continual learning
Hyperparameters: T=4-6, Ξ»=0.5-0.7 typical
Recent: ViT distillation, LLM distillation, data-free methods
Trade-off: Compression ratio vs. performance vs. training time
Best practice: Start with Hintonβs KD, add feature matching if needed
Core insight: Soft targets contain rich information about similarity structure that hard labels cannot provide.
17. Mathematical SummaryΒΆ
Classic distillation loss:
L = (1-Ξ») CE(y, softmax(z_S)) + Ξ» TΒ² KL(softmax(z_T/T) || softmax(z_S/T))
Feature matching:
L_feature = Ξ£_l ||f_S^l - f_T^l||Β²
Relational distillation:
L_relation = Ξ£_{i,j} ||sim(f_S(x_i), f_S(x_j)) - sim(f_T(x_i), f_T(x_j))||Β²
Attention transfer:
L_AT = Ξ£_l ||normalize(A(F_S^l)) - normalize(A(F_T^l))||Β²
ReferencesΒΆ
Hinton et al. (2015) βDistilling the Knowledge in a Neural Networkβ
Romero et al. (2015) βFitNets: Hints for Thin Deep Netsβ
Zagoruyko & Komodakis (2017) βPaying More Attention to Attentionβ
Park et al. (2019) βRelational Knowledge Distillationβ
Tian et al. (2020) βContrastive Representation Distillationβ
Furlanello et al. (2018) βBorn Again Neural Networksβ
Zhang et al. (2018) βDeep Mutual Learningβ
Yin et al. (2020) βDreaming to Distill: Data-Free Knowledge Transferβ
Touvron et al. (2021) βTraining Data-Efficient Image Transformers (DeiT)β
Gou et al. (2021) βKnowledge Distillation: A Surveyβ
"""
Advanced Knowledge Distillation Implementations
This notebook provides production-ready PyTorch implementations of various
knowledge distillation techniques.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from typing import List, Tuple, Dict, Optional, Callable
import matplotlib.pyplot as plt
from dataclasses import dataclass
# ============================================================================
# 1. Classic Knowledge Distillation (Hinton)
# ============================================================================
class DistillationLoss(nn.Module):
    """
    Hinton-style knowledge distillation loss with temperature scaling.

    Loss = (1 - alpha) * CE(labels, student) + alpha * T^2 * KL(teacher || student)

    Args:
        temperature: Softening temperature for both distributions (default: 4.0).
        alpha: Weight on the distillation term; 1 - alpha goes to the task
            term (default: 0.7).
    """

    def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            student_logits: [batch, num_classes]
            teacher_logits: [batch, num_classes]
            labels: [batch] ground-truth class indices

        Returns:
            (total_loss, dict with 'total' / 'task' / 'distill' scalars)
        """
        T = self.temperature
        # Hard-label term: plain cross-entropy against the ground truth.
        task_term = self.ce_loss(student_logits, labels)
        # Soft-label term: kl_div expects log-probabilities as its first input.
        log_q = F.log_softmax(student_logits / T, dim=1)
        p = F.softmax(teacher_logits / T, dim=1)
        # T^2 compensates for the gradient shrinkage caused by dividing logits by T.
        distill_term = F.kl_div(log_q, p, reduction='batchmean') * (T * T)
        total = (1 - self.alpha) * task_term + self.alpha * distill_term
        return total, {
            'total': total.item(),
            'task': task_term.item(),
            'distill': distill_term.item(),
        }
# ============================================================================
# 2. Feature-Based Distillation
# ============================================================================
class FeatureDistillationLoss(nn.Module):
    """
    FitNet-style hint loss: match student features to teacher features at
    selected layers through learned 1x1 convolution regressors.

    Args:
        student_channels: Channel counts of the student feature maps.
        teacher_channels: Channel counts of the corresponding teacher maps.
        layers_to_match: Indices (into the feature lists) included in the loss.
    """

    def __init__(
        self,
        student_channels: List[int],
        teacher_channels: List[int],
        layers_to_match: List[int]
    ):
        super().__init__()
        self.layers_to_match = layers_to_match
        # One 1x1 conv per layer pair lifts student channels to teacher channels.
        self.regressors = nn.ModuleList([
            nn.Conv2d(s_ch, t_ch, kernel_size=1, bias=False)
            for s_ch, t_ch in zip(student_channels, teacher_channels)
        ])

    def forward(
        self,
        student_features: List[torch.Tensor],
        teacher_features: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Args:
            student_features: list of [batch, C_s, H, W] maps
            teacher_features: list of [batch, C_t, H, W] maps

        Returns:
            Mean MSE over the matched layers.
        """
        per_layer = []
        for idx in self.layers_to_match:
            projected = self.regressors[idx](student_features[idx])
            target = teacher_features[idx]
            # Pool the projection to the teacher's spatial size when they differ.
            if projected.shape[2:] != target.shape[2:]:
                projected = F.adaptive_avg_pool2d(projected, target.shape[2:])
            per_layer.append(F.mse_loss(projected, target))
        return sum(per_layer) / len(self.layers_to_match)
class AttentionTransferLoss(nn.Module):
    """
    Attention Transfer (Zagoruyko & Komodakis, 2017): match L2-normalized
    spatial attention maps, A(F) = sum_c |F^c|^p, between teacher and student.

    Args:
        p: Exponent applied to absolute activations (default: 2).
    """

    def __init__(self, p: float = 2.0):
        super().__init__()
        self.p = p

    def attention_map(self, features: torch.Tensor) -> torch.Tensor:
        """Collapse [batch, C, H, W] features to a [batch, H, W] attention map
        by summing |F|^p over the channel dimension."""
        return (features.abs() ** self.p).sum(dim=1)

    def normalize_attention(self, attention: torch.Tensor) -> torch.Tensor:
        """Scale each sample's attention map to unit L2 norm (epsilon-guarded)."""
        flat = attention.view(attention.size(0), -1)
        flat = flat / (torch.norm(flat, p=2, dim=1, keepdim=True) + 1e-8)
        return flat.view_as(attention)

    def forward(
        self,
        student_features: List[torch.Tensor],
        teacher_features: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Args:
            student_features: list of [batch, C, H, W] maps
            teacher_features: list of [batch, C, H, W] maps

        Returns:
            Attention-transfer loss averaged over the feature pairs.
        """
        pair_losses = [
            F.mse_loss(
                self.normalize_attention(self.attention_map(s_feat)),
                self.normalize_attention(self.attention_map(t_feat)),
            )
            for s_feat, t_feat in zip(student_features, teacher_features)
        ]
        return sum(pair_losses) / len(student_features)
# ============================================================================
# 3. Relation-Based Distillation
# ============================================================================
class RelationalKDLoss(nn.Module):
    """
    Relational Knowledge Distillation (Park et al., 2019).

    Instead of matching embeddings directly, the student is trained so the
    *structure* of its embedding space (pairwise distances and pairwise
    cosine similarities, used here as an angle proxy) matches the teacher's.

    Args:
        distance_weight: Weight for the distance-wise term.
        angle_weight: Weight for the angle-wise term.
    """

    def __init__(self, distance_weight: float = 1.0, angle_weight: float = 2.0):
        super().__init__()
        self.distance_weight = distance_weight
        self.angle_weight = angle_weight

    def pdist(self, embeddings: torch.Tensor, squared: bool = False) -> torch.Tensor:
        """
        Pairwise Euclidean distance matrix.

        Args:
            embeddings: [batch, dim]
            squared: If True, return squared distances.

        Returns:
            distances: [batch, batch]
        """
        gram = embeddings @ embeddings.t()
        sq_norms = gram.diag()
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a,b>; clamp guards tiny
        # negative values caused by floating-point cancellation.
        dist_sq = torch.clamp(
            sq_norms.unsqueeze(0) + sq_norms.unsqueeze(1) - 2 * gram,
            min=0.0,
        )
        if squared:
            return dist_sq
        # sqrt(0) has an undefined gradient: shift exact zeros by a tiny
        # epsilon before the sqrt, then zero those entries back out.
        zero_mask = (dist_sq == 0.0).float()
        dist = torch.sqrt(dist_sq + zero_mask * 1e-16)
        return dist * (1.0 - zero_mask)

    def distance_wise_loss(
        self,
        student_embed: torch.Tensor,
        teacher_embed: torch.Tensor
    ) -> torch.Tensor:
        """Huber (smooth L1) loss between squared pairwise-distance matrices."""
        return F.smooth_l1_loss(
            self.pdist(student_embed, squared=True),
            self.pdist(teacher_embed, squared=True),
            reduction='mean',
        )

    def angle_wise_loss(
        self,
        student_embed: torch.Tensor,
        teacher_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        MSE between cosine-similarity matrices.

        The paper uses explicit triplet angles; cosine similarity of the
        unit-normalized embeddings is used here as a simpler proxy.
        """
        s_dirs = F.normalize(student_embed, p=2, dim=1)
        t_dirs = F.normalize(teacher_embed, p=2, dim=1)
        return F.mse_loss(s_dirs @ s_dirs.t(), t_dirs @ t_dirs.t())

    def forward(
        self,
        student_embed: torch.Tensor,
        teacher_embed: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            student_embed: [batch, embed_dim]
            teacher_embed: [batch, embed_dim]

        Returns:
            (total_loss, loss_dict) with per-term scalars under
            'distance', 'angle' and 'total'.
        """
        dist_term = self.distance_wise_loss(student_embed, teacher_embed)
        angle_term = self.angle_wise_loss(student_embed, teacher_embed)
        combined = self.distance_weight * dist_term + self.angle_weight * angle_term
        breakdown = {
            'distance': dist_term.item(),
            'angle': angle_term.item(),
            'total': combined.item(),
        }
        return combined, breakdown
class SimilarityPreservingLoss(nn.Module):
    """
    Similarity-Preserving KD (Tung & Mori, 2019).

    Trains the student so its batch-level cosine-similarity (Gram) matrix
    matches the teacher's:

        G[i,j] = <f_i, f_j> / (||f_i|| ||f_j||)
    """

    def __init__(self):
        super().__init__()

    def similarity_matrix(self, features: torch.Tensor) -> torch.Tensor:
        """
        Cosine-similarity Gram matrix of a feature batch.

        Args:
            features: [batch, dim]

        Returns:
            similarity: [batch, batch]
        """
        unit = F.normalize(features, p=2, dim=1)
        return unit @ unit.t()

    def forward(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            student_features: [batch, dim]
            teacher_features: [batch, dim]

        Returns:
            sp_loss: Squared Frobenius distance between the two Gram
            matrices, scaled by 1/batch^2 so the value is comparable
            across batch sizes.
        """
        gap = self.similarity_matrix(student_features) - self.similarity_matrix(teacher_features)
        n = student_features.size(0)
        return (torch.norm(gap, p='fro') ** 2) / (n ** 2)
# ============================================================================
# 4. Self-Distillation
# ============================================================================
class DeepMutualLearning(nn.Module):
    """
    Deep Mutual Learning (Zhang et al., 2018).

    A cohort of students trains collaboratively: each student minimizes its
    own cross-entropy plus KL divergence toward every peer's (detached)
    temperature-softened predictions. No pre-trained teacher is required.

    Args:
        num_students: Expected number of student networks in the cohort.
        temperature: Softmax temperature for the mutual distillation term.
    """
    def __init__(self, num_students: int = 2, temperature: float = 3.0):
        super().__init__()
        self.num_students = num_students
        self.temperature = temperature
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(
        self,
        student_logits: List[torch.Tensor],
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            student_logits: List of [batch, num_classes] logits, one per student.
            labels: [batch] ground-truth class indices.

        Returns:
            total_loss: Sum of every student's individual loss.
            individual_losses: Per-student loss tensors (CE + T^2-scaled KL).
        """
        # BUG FIX: average over the ACTUAL number of peers in the list.
        # The original divided by `self.num_students - 1`, which mis-scales
        # the loss when len(student_logits) != num_students and raises
        # ZeroDivisionError when num_students == 1.
        num_peers = len(student_logits) - 1
        losses = []
        for i, logits_i in enumerate(student_logits):
            # Hard-label classification loss for this student.
            ce_loss = self.ce_loss(logits_i, labels)
            # Mutual distillation: KL toward every other student's softened
            # output. Peers are detached so each student only updates itself.
            distill_loss = 0.0
            for j, logits_j in enumerate(student_logits):
                if i == j:
                    continue
                soft_i = F.log_softmax(logits_i / self.temperature, dim=1)
                soft_j = F.softmax(logits_j.detach() / self.temperature, dim=1)
                distill_loss += F.kl_div(soft_i, soft_j, reduction='batchmean')
            if num_peers > 0:
                distill_loss = distill_loss / num_peers
            # T^2 rescales KL gradients to match the CE term (Hinton et al.).
            total = ce_loss + self.temperature ** 2 * distill_loss
            losses.append(total)
        # One backward pass on the sum updates the whole cohort.
        total_loss = sum(losses)
        return total_loss, losses
class OnlineDistillation(nn.Module):
    """
    Online distillation with a temporal-ensemble (EMA) teacher.

    The teacher is an exponential moving average of the student's weights:

        theta_teacher = m * theta_teacher + (1 - m) * theta_student

    Args:
        momentum: EMA momentum (default: 0.999).
        temperature: Distillation temperature.
    """
    def __init__(self, momentum: float = 0.999, temperature: float = 4.0):
        super().__init__()
        self.momentum = momentum
        self.temperature = temperature
        # Teacher is created lazily on the first update_teacher() call.
        self.teacher_model = None
        self.ce_loss = nn.CrossEntropyLoss()

    def update_teacher(self, student_model: nn.Module):
        """
        Create (first call) or EMA-update the teacher from the student.

        BUG FIX: the original rebuilt the teacher via
        ``type(student_model)()``, which silently assumes a zero-argument
        constructor and crashes for any model whose __init__ takes
        arguments; ``copy.deepcopy`` works for any module and also copies
        buffers.
        """
        import copy
        if self.teacher_model is None:
            # Initialize teacher as an exact, frozen copy of the student.
            self.teacher_model = copy.deepcopy(student_model)
            self.teacher_model.eval()
            # The teacher is never trained directly.
            for param in self.teacher_model.parameters():
                param.requires_grad_(False)
        else:
            # In-place EMA update of every teacher parameter.
            with torch.no_grad():
                for teacher_param, student_param in zip(
                    self.teacher_model.parameters(),
                    student_model.parameters()
                ):
                    teacher_param.data = (
                        self.momentum * teacher_param.data +
                        (1 - self.momentum) * student_param.data
                    )

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            student_logits: [batch, num_classes]
            teacher_logits: [batch, num_classes] from the EMA teacher.
            labels: [batch]

        Returns:
            (total_loss, loss_dict) with 'ce', 'kd' and 'total' scalars.
        """
        # Hard-label classification loss.
        ce_loss = self.ce_loss(student_logits, labels)
        # Softened KL toward the EMA teacher; T^2 keeps gradient magnitudes
        # comparable to the CE term (Hinton et al.).
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        kd_loss = F.kl_div(
            soft_student,
            soft_teacher,
            reduction='batchmean'
        ) * (self.temperature ** 2)
        total_loss = ce_loss + kd_loss
        loss_dict = {
            'ce': ce_loss.item(),
            'kd': kd_loss.item(),
            'total': total_loss.item()
        }
        return total_loss, loss_dict
# ============================================================================
# 5. Data-Free Distillation
# ============================================================================
class DataFreeDistillation:
    """
    Data-free knowledge distillation via synthetic data generation.

    Synthetic inputs are optimized (by gradient descent on the inputs
    themselves) so the frozen teacher produces confident, smooth responses;
    the resulting (input, teacher-logit) pairs can then train a student
    without access to the original dataset.

    Args:
        teacher: Pre-trained teacher model (frozen by this class).
        input_shape: Shape of a single input (e.g., (3, 32, 32)).
        num_classes: Number of output classes.
    """
    def __init__(
        self,
        teacher: nn.Module,
        input_shape: Tuple[int, ...],
        num_classes: int
    ):
        self.teacher = teacher
        self.teacher.eval()
        # Freeze the teacher so backward() through it only produces
        # gradients w.r.t. the synthetic inputs.
        for param in self.teacher.parameters():
            param.requires_grad_(False)
        self.input_shape = input_shape
        self.num_classes = num_classes

    def generate_samples(
        self,
        batch_size: int,
        num_iterations: int = 1000,
        lr: float = 0.1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate synthetic samples via gradient-based optimization.

        Minimizes: entropy(teacher(x)) + 1e-4 * total_variation(x)
        (low entropy = confident teacher predictions; TV = smoothness prior).

        Args:
            batch_size: Number of samples to generate.
            num_iterations: Optimization steps.
            lr: Learning rate for Adam on the inputs.

        Returns:
            synthetic_inputs (detached), teacher_logits
        """
        # Place the synthetic inputs on the same device as the teacher.
        try:
            device = next(self.teacher.parameters()).device
        except StopIteration:
            device = torch.device('cpu')
        inputs = torch.randn(
            batch_size, *self.input_shape,
            device=device,
            requires_grad=True
        )
        optimizer = optim.Adam([inputs], lr=lr)
        for _ in range(num_iterations):
            optimizer.zero_grad()
            # BUG FIX: the original ran this forward pass under
            # torch.no_grad(), which cut the autograd graph so the entropy
            # term produced no gradient w.r.t. `inputs` (only the TV prior
            # was optimized). Gradients must flow from the teacher's logits
            # back to the inputs; the teacher's own parameters are frozen.
            teacher_logits = self.teacher(inputs)
            # Confidence objective: minimize prediction entropy.
            probs = F.softmax(teacher_logits, dim=1)
            entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1).mean()
            # Image prior: total variation encourages spatial smoothness.
            tv_loss = (
                torch.sum(torch.abs(inputs[:, :, :, :-1] - inputs[:, :, :, 1:])) +
                torch.sum(torch.abs(inputs[:, :, :-1, :] - inputs[:, :, 1:, :]))
            )
            loss = entropy + 0.0001 * tv_loss
            loss.backward()
            optimizer.step()
            # Keep inputs inside the normalized image range.
            inputs.data = torch.clamp(inputs.data, -1, 1)
        # Record the teacher's final (grad-free) predictions.
        with torch.no_grad():
            teacher_logits = self.teacher(inputs.detach())
        return inputs.detach(), teacher_logits
# ============================================================================
# 6. Distillation Trainer
# ============================================================================
@dataclass
class DistillationConfig:
    """Configuration for distillation training.

    Note: only 'hinton', 'rkd', and 'sp' are wired up in
    DistillationTrainer.setup_loss; other method values raise ValueError
    there.
    """
    method: str = 'hinton'  # 'hinton', 'fitnet', 'attention', 'rkd', 'dml'
    temperature: float = 4.0  # softmax temperature for soft targets
    alpha: float = 0.7  # passed to DistillationLoss; presumably KD-vs-CE mixing weight — confirm against its definition
    num_epochs: int = 100
    learning_rate: float = 0.001  # Adam learning rate for the student
    batch_size: int = 128
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
class DistillationTrainer:
    """
    Unified trainer for various distillation methods.

    The teacher is frozen (eval mode, requires_grad=False); only the
    student is optimized (Adam). Supported config.method values:
    'hinton' (response-based), 'rkd' and 'sp' (feature-based, require
    the models to expose a .features() method).

    Args:
        teacher: Pre-trained teacher model.
        student: Student model to train.
        config: Training configuration.
    """
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        config: DistillationConfig
    ):
        self.teacher = teacher.to(config.device)
        self.student = student.to(config.device)
        self.config = config
        self.device = config.device
        # Freeze teacher: it only provides targets, never trains.
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False
        # Method-specific criterion.
        self.setup_loss()
        # Optimizer (student parameters only).
        self.optimizer = optim.Adam(
            self.student.parameters(),
            lr=config.learning_rate
        )
        # Training history.
        self.train_losses = []
        self.val_accuracies = []

    def setup_loss(self):
        """Initialize the loss function based on config.method."""
        if self.config.method == 'hinton':
            self.criterion = DistillationLoss(
                temperature=self.config.temperature,
                alpha=self.config.alpha
            )
        elif self.config.method == 'rkd':
            self.criterion = RelationalKDLoss()
        elif self.config.method == 'sp':
            self.criterion = SimilarityPreservingLoss()
        else:
            raise ValueError(f"Unknown method: {self.config.method}")

    def train_epoch(self, train_loader: DataLoader) -> float:
        """Train the student for one epoch; return the mean batch loss."""
        self.student.train()
        total_loss = 0.0
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            student_logits = self.student(inputs)
            with torch.no_grad():
                teacher_logits = self.teacher(inputs)
            if self.config.method == 'hinton':
                loss, _ = self.criterion(student_logits, teacher_logits, labels)
            elif self.config.method in ['rkd', 'sp']:
                # Feature-based methods use penultimate-layer features
                # (models are assumed to expose a .features() method).
                student_features = self.student.features(inputs)
                # Teacher features carry no gradient; no_grad saves memory.
                with torch.no_grad():
                    teacher_features = self.teacher.features(inputs)
                result = self.criterion(student_features, teacher_features)
                # BUG FIX: RelationalKDLoss returns (loss, dict) while
                # SimilarityPreservingLoss returns a bare tensor; the
                # original unconditional `loss, _ = ...` crashed for
                # method='sp'.
                loss = result[0] if isinstance(result, tuple) else result
                # Feature losses are combined with plain classification loss.
                ce_loss = F.cross_entropy(student_logits, labels)
                loss = loss + ce_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        return total_loss / len(train_loader)

    def evaluate(self, val_loader: DataLoader) -> float:
        """Return top-1 accuracy (%) of the student on val_loader."""
        self.student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                outputs = self.student(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        accuracy = 100.0 * correct / total
        return accuracy

    def train(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None
    ) -> Dict[str, List[float]]:
        """
        Full training loop.

        Returns:
            history: Dict with 'loss' and 'val_acc' lists.
        """
        print(f"Training with {self.config.method} distillation")
        print(f"Temperature: {self.config.temperature}, Alpha: {self.config.alpha}")
        for epoch in range(self.config.num_epochs):
            train_loss = self.train_epoch(train_loader)
            self.train_losses.append(train_loss)
            if val_loader is not None:
                val_acc = self.evaluate(val_loader)
                self.val_accuracies.append(val_acc)
                if (epoch + 1) % 10 == 0:
                    print(f"Epoch {epoch+1}/{self.config.num_epochs} - "
                          f"Loss: {train_loss:.4f}, Val Acc: {val_acc:.2f}%")
            else:
                if (epoch + 1) % 10 == 0:
                    print(f"Epoch {epoch+1}/{self.config.num_epochs} - "
                          f"Loss: {train_loss:.4f}")
        return {
            'loss': self.train_losses,
            'val_acc': self.val_accuracies
        }
# ============================================================================
# 7. Demo: Knowledge Distillation Comparison
# ============================================================================
def demo_distillation_comparison():
    """
    Compare different distillation methods.

    Builds tiny dummy teacher/student CNNs and random data, then exercises
    each loss implementation once on a small sample batch and prints the
    loss breakdowns. No actual training loop is run here — the trainer is
    only constructed to demonstrate its setup.
    """
    print("=" * 80)
    print("Knowledge Distillation Methods Comparison")
    print("=" * 80)
    # Dummy teacher and student
    class TeacherNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1)
            )
            self.fc = nn.Linear(128, 10)
        def forward(self, x):
            x = self.conv(x)
            x = x.view(x.size(0), -1)
            return self.fc(x)
        def features(self, x):
            # Penultimate (pooled conv) features, flattened to [batch, 128].
            x = self.conv(x)
            return x.view(x.size(0), -1)
    class StudentNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1)
            )
            self.fc = nn.Linear(32, 10)
        def forward(self, x):
            x = self.conv(x)
            x = x.view(x.size(0), -1)
            return self.fc(x)
        def features(self, x):
            # Penultimate (pooled conv) features, flattened to [batch, 32].
            x = self.conv(x)
            return x.view(x.size(0), -1)
    # Create dummy data (random images/labels — demo only).
    X_train = torch.randn(1000, 3, 32, 32)
    y_train = torch.randint(0, 10, (1000,))
    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    X_val = torch.randn(200, 3, 32, 32)
    y_val = torch.randint(0, 10, (200,))
    val_dataset = TensorDataset(X_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=128)
    # Pre-trained teacher (dummy)
    teacher = TeacherNet()
    # 1. Hinton's KD
    print("\n1. Classic Distillation (Hinton)")
    print("-" * 40)
    student_hinton = StudentNet()
    config_hinton = DistillationConfig(
        method='hinton',
        temperature=4.0,
        alpha=0.7,
        num_epochs=5,
        device='cpu'
    )
    # NOTE(review): trainer is constructed to show the setup path but
    # trainer_hinton.train(...) is intentionally never called in this demo.
    trainer_hinton = DistillationTrainer(teacher, student_hinton, config_hinton)
    # Sample forward pass
    sample_input = torch.randn(4, 3, 32, 32)
    with torch.no_grad():
        teacher_logits = teacher(sample_input)
        student_logits = student_hinton(sample_input)
    criterion = DistillationLoss(temperature=4.0, alpha=0.7)
    sample_labels = torch.randint(0, 10, (4,))
    loss, loss_dict = criterion(student_logits, teacher_logits, sample_labels)
    print(f"Sample loss breakdown:")
    print(f"  Total: {loss_dict['total']:.4f}")
    print(f"  Task (CE): {loss_dict['task']:.4f}")
    print(f"  Distillation (KD): {loss_dict['distill']:.4f}")
    # 2. Relational KD
    print("\n2. Relational Knowledge Distillation")
    print("-" * 40)
    rkd_criterion = RelationalKDLoss()
    # Sample features
    student_feat = student_hinton.features(sample_input)
    with torch.no_grad():
        teacher_feat = teacher.features(sample_input)
    rkd_loss, rkd_dict = rkd_criterion(student_feat, teacher_feat)
    print(f"RKD loss breakdown:")
    print(f"  Total: {rkd_dict['total']:.4f}")
    print(f"  Distance-wise: {rkd_dict['distance']:.4f}")
    print(f"  Angle-wise: {rkd_dict['angle']:.4f}")
    # 3. Similarity Preserving
    print("\n3. Similarity-Preserving KD")
    print("-" * 40)
    sp_criterion = SimilarityPreservingLoss()
    sp_loss = sp_criterion(student_feat, teacher_feat)
    print(f"SP loss: {sp_loss.item():.4f}")
    # Visualize similarity matrices (top-left 2x2 corner only).
    with torch.no_grad():
        student_sim = sp_criterion.similarity_matrix(student_feat)
        teacher_sim = sp_criterion.similarity_matrix(teacher_feat)
    print(f"Student similarity matrix:\n{student_sim.numpy()[:2, :2]}")
    print(f"Teacher similarity matrix:\n{teacher_sim.numpy()[:2, :2]}")
    # 4. Deep Mutual Learning
    print("\n4. Deep Mutual Learning")
    print("-" * 40)
    dml = DeepMutualLearning(num_students=2, temperature=3.0)
    student2 = StudentNet()
    with torch.no_grad():
        logits1 = student_hinton(sample_input)
        logits2 = student2(sample_input)
    dml_loss, individual_losses = dml([logits1, logits2], sample_labels)
    print(f"DML total loss: {dml_loss.item():.4f}")
    print(f"Student 1 loss: {individual_losses[0].item():.4f}")
    print(f"Student 2 loss: {individual_losses[1].item():.4f}")
    # 5. Online Distillation
    print("\n5. Online Distillation (EMA Teacher)")
    print("-" * 40)
    online_kd = OnlineDistillation(momentum=0.999, temperature=4.0)
    # First call initializes the EMA teacher as a copy of the student.
    online_kd.update_teacher(student_hinton)
    with torch.no_grad():
        teacher_ema_logits = online_kd.teacher_model(sample_input)
    online_loss, online_dict = online_kd(
        student_logits,
        teacher_ema_logits,
        sample_labels
    )
    print(f"Online KD loss breakdown:")
    print(f"  Total: {online_dict['total']:.4f}")
    print(f"  CE: {online_dict['ce']:.4f}")
    print(f"  KD: {online_dict['kd']:.4f}")
# ============================================================================
# 8. Performance Comparison Table
# ============================================================================
def print_performance_comparison():
    """
    Print comprehensive comparison of distillation methods.

    All tables below are static reference material (no computation is
    performed); the figures are indicative numbers for the cited methods.
    """
    print("\n" + "=" * 80)
    print("KNOWLEDGE DISTILLATION METHODS COMPARISON")
    print("=" * 80)
    # One large pre-formatted block; printed verbatim.
    comparison = """
Method Characteristics:
┌──────────────────┬──────────────┬───────────────┬──────────────┬─────────────┐
│ Method           │ Complexity   │ Memory        │ Training     │ Performance │
│                  │              │ Overhead      │ Time         │ Gain        │
├──────────────────┼──────────────┼───────────────┼──────────────┼─────────────┤
│ Hinton KD        │ O(NC)        │ Low           │ 1.2×         │ +1.5-2.5%   │
│ (Response)       │              │               │              │             │
├──────────────────┼──────────────┼───────────────┼──────────────┼─────────────┤
│ FitNet           │ O(NHW)       │ Medium        │ 1.5×         │ +2.0-3.0%   │
│ (Feature)        │              │ (features)    │ (2-stage)    │             │
├──────────────────┼──────────────┼───────────────┼──────────────┼─────────────┤
│ Attention        │ O(NHW)       │ Low           │ 1.3×         │ +1.8-2.8%   │
│ Transfer         │              │ (attention)   │              │             │
├──────────────────┼──────────────┼───────────────┼──────────────┼─────────────┤
│ RKD              │ O(N²D)       │ Low           │ 1.4×         │ +2.2-3.2%   │
│ (Relational)     │              │               │              │             │
├──────────────────┼──────────────┼───────────────┼──────────────┼─────────────┤
│ SP               │ O(N²)        │ Medium        │ 1.5×         │ +2.0-3.0%   │
│ (Similarity)     │              │ (Gram matrix) │              │             │
├──────────────────┼──────────────┼───────────────┼──────────────┼─────────────┤
│ CRD              │ O(NK)        │ High          │ 1.6×         │ +2.5-3.5%   │
│ (Contrastive)    │              │ (negatives)   │              │             │
├──────────────────┼──────────────┼───────────────┼──────────────┼─────────────┤
│ DML              │ O(M·NC)      │ High          │ 1.0×         │ +1.0-2.0%   │
│ (Mutual)         │              │ (M students)  │ (parallel)   │             │
├──────────────────┼──────────────┼───────────────┼──────────────┼─────────────┤
│ Online KD        │ O(NC)        │ High          │ 1.1×         │ +0.8-1.5%   │
│ (EMA)            │              │ (EMA model)   │              │             │
└──────────────────┴──────────────┴───────────────┴──────────────┴─────────────┘
N = batch size, C = classes, H/W = spatial dims, D = feature dim,
M = num students, K = num negatives
ImageNet Performance (ResNet-34 Teacher → Student):
┌──────────────────┬───────────────┬──────────────┬──────────────┐
│ Student          │ Baseline      │ +Hinton KD   │ +Best Method │
├──────────────────┼───────────────┼──────────────┼──────────────┤
│ ResNet-18        │ 69.8%         │ 71.0%        │ 72.2% (CRD)  │
│ MobileNetV2      │ 67.5%         │ 68.5%        │ 69.8% (RKD)  │
│ ShuffleNetV2     │ 65.2%         │ 66.8%        │ 67.5% (AT)   │
└──────────────────┴───────────────┴──────────────┴──────────────┘
CIFAR-100 Performance (WideResNet-40-2 Teacher → Student):
┌──────────────────┬───────────────┬──────────────┬──────────────┐
│ Student          │ Baseline      │ +Hinton KD   │ +Best Method │
├──────────────────┼───────────────┼──────────────┼──────────────┤
│ WRN-16-2         │ 73.3%         │ 75.6%        │ 76.8% (CRD)  │
│ ResNet-20        │ 69.1%         │ 71.2%        │ 72.5% (RKD)  │
│ MobileNetV2      │ 64.6%         │ 67.4%        │ 68.9% (SP)   │
└──────────────────┴───────────────┴──────────────┴──────────────┘
Compression Ratios (Teacher → Student):
┌─────────────────────────┬────────────┬─────────────┬──────────────┐
│ Model Pair              │ Params     │ Speedup     │ Acc Drop     │
├─────────────────────────┼────────────┼─────────────┼──────────────┤
│ ResNet-152 → ResNet-18  │ 60M → 11M  │ 8×          │ 4.5% → 2.8%  │
│ BERT-base → DistilBERT  │ 110M → 66M │ 1.6×        │ 97% retained │
│ GPT-2 → GPT-2 Small     │ 1.5B → 117M│ 13×         │ ~15% perplx  │
│ EfficientNet-B7 → B0    │ 66M → 5M   │ 30×         │ 6.0% → 3.5%  │
└─────────────────────────┴────────────┴─────────────┴──────────────┘
Hyperparameter Recommendations:
┌────────────────┬──────────────────┬─────────────────────────────┐
│ Hyperparameter │ Typical Range    │ Notes                       │
├────────────────┼──────────────────┼─────────────────────────────┤
│ Temperature T  │ 1-20             │ 4-6 for classification      │
│                │                  │ Higher for similar tasks    │
├────────────────┼──────────────────┼─────────────────────────────┤
│ Alpha λ        │ 0.1-0.9          │ 0.5-0.7 typical             │
│                │                  │ Higher when teacher better  │
├────────────────┼──────────────────┼─────────────────────────────┤
│ Learning rate  │ 0.5-2× baseline  │ Start lower, anneal slower  │
├────────────────┼──────────────────┼─────────────────────────────┤
│ Epochs         │ 1.5-2× baseline  │ Student needs more time     │
├────────────────┼──────────────────┼─────────────────────────────┤
│ Feature layers │ 1-5 layers       │ Middle layers work best     │
│                │                  │ Too early: low-level        │
│                │                  │ Too late: task-specific     │
└────────────────┴──────────────────┴─────────────────────────────┘
Decision Guide:
┌──────────────────────────────┬─────────────────────────────────┐
│ Scenario                     │ Recommended Method              │
├──────────────────────────────┼─────────────────────────────────┤
│ Simple classification        │ Hinton KD (response-based)      │
│ Large capacity gap           │ FitNet (feature-based)          │
│ Cross-architecture           │ AT (attention transfer)         │
│ Limited data                 │ RKD (relational)                │
│ Batch learning matters       │ SP (similarity-preserving)      │
│ No pre-trained teacher       │ DML (mutual learning)           │
│ Online/streaming             │ Online KD (EMA teacher)         │
│ Detection/segmentation       │ Feature + Response combined     │
│ Resource constrained         │ Hinton KD (lowest overhead)     │
│ Max performance              │ CRD (contrastive) or ensemble   │
└──────────────────────────────┴─────────────────────────────────┘
"""
    print(comparison)
    print("\nKey Insights:")
    print("1. Response-based (Hinton): Simple, effective baseline")
    print("2. Feature-based: Better for large capacity gap, architecture coupling")
    print("3. Relation-based: Captures structural knowledge, scales O(N²)")
    print("4. Self-distillation: No teacher needed, can improve teacher-size models")
    print("5. Combination: Feature + Response often works best")
    print("6. Temperature: 4-6 typical, tune via validation")
    print("7. Alpha: 0.5-0.7 typical, higher when teacher much better")
    print("8. Training time: 1.2-1.6× longer than baseline")
    print("9. Performance gain: +1-3% typical, varies by task/architecture")
    print("10. Best practice: Start with Hinton KD, add features if needed")
# ============================================================================
# 9. Run Demonstrations
# ============================================================================
# Run the demonstrations only when executed as a script (not on import).
if __name__ == "__main__":
    # Run demos
    demo_distillation_comparison()
    print_performance_comparison()
    print("\n" + "=" * 80)
    print("Knowledge Distillation Implementations Complete!")
    print("=" * 80)