import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Select the compute device: prefer GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")
1. Few-Shot Classification
N-way K-shot
N classes with K examples each
Support set: \(S = \{(x_i, y_i)\}_{i=1}^{NK}\)
Query: classify new examples
Prototypical Networks Idea
Embed examples: \(f_\theta: \mathbb{R}^d \to \mathbb{R}^m\)
Compute class prototypes (means)
Classify by nearest prototype
📚 Reference Materials:
bayesian_inference_deep_learning.pdf - Bayesian Inference Deep Learning
2. Algorithm
Class Prototype
Classification
where \(d(\cdot, \cdot)\) is distance (e.g., Euclidean).
class ConvEmbedding(nn.Module):
    """Simple three-stage CNN that maps images to fixed-size embeddings.

    Each stage halves the spatial resolution (conv -> BN -> ReLU -> pool);
    a final linear layer projects the flattened 3x3 feature map to
    `output_dim`. Defaults are sized for 28x28 single-channel inputs.
    """

    def __init__(self, in_channels=1, hidden_dim=64, output_dim=64):
        super().__init__()
        layers = []
        channels = in_channels
        # Three identical conv stages: 28 -> 14 -> 7 -> 3 spatially.
        for _ in range(3):
            layers.extend([
                nn.Conv2d(channels, hidden_dim, 3, padding=1),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
            channels = hidden_dim
        layers.append(nn.Flatten())
        layers.append(nn.Linear(hidden_dim * 3 * 3, output_dim))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return (batch, output_dim) embeddings for a batch of images."""
        return self.encoder(x)
def compute_prototypes(embeddings, labels, n_classes):
    """Return the per-class mean embedding (prototype) for each class.

    Args:
        embeddings: (B, D) tensor of embedded examples.
        labels: (B,) tensor of class indices in [0, n_classes).
        n_classes: number of classes to compute prototypes for.

    Returns:
        (n_classes, D) tensor; row k is the mean of class-k embeddings.
    """
    class_means = [embeddings[labels == cls].mean(dim=0)
                   for cls in range(n_classes)]
    return torch.stack(class_means, dim=0)
def euclidean_distance(x, y):
    """Return the (N, M) matrix of Euclidean distances between rows of x and y."""
    # Broadcast to (N, M, D), reduce over the feature dimension.
    diff = x.unsqueeze(1) - y.unsqueeze(0)
    return diff.pow(2).sum(dim=-1).sqrt()
# Smoke test: embed a small batch of MNIST-sized inputs and verify the shape.
model = ConvEmbedding().to(device)
x_test = torch.randn(5, 1, 28, 28).to(device)  # 5 fake grayscale 28x28 images
emb = model(x_test)
print(f"Embedding shape: {emb.shape}")  # expected: torch.Size([5, 64])
Episodic TrainingΒΆ
Prototypical networks are trained episodically: each training iteration samples a random \(N\)-way, \(K\)-shot task from the training classes. The support set (K examples per class) is used to compute class prototypes via averaging in embedding space, and the query set is classified by nearest-prototype distance. The training loss is the negative log-probability of the correct class under the softmax over distances. This episodic protocol mimics the few-shot evaluation setting during training, ensuring the learned embedding space is well-suited for prototype-based classification at test time.
def sample_episode(data, labels, n_way, k_shot, k_query):
    """Sample an N-way K-shot episode (support + query sets).

    Args:
        data: tensor of examples, indexable along dim 0.
        labels: 1-D array-like of integer class labels aligned with `data`.
            Labels need NOT be contiguous 0..C-1; classes are sampled by
            their actual values.
        n_way: number of distinct classes in the episode.
        k_shot: support examples per class.
        k_query: query examples per class.

    Returns:
        (support_x, support_y, query_x, query_y). Labels are remapped to
        episode-local indices 0..n_way-1 in sampling order.
    """
    labels = np.asarray(labels)
    # Fix: sample actual class *values*, not indices into the unique set.
    # The original drew from range(len(unique)), which silently assumed
    # labels were exactly 0..C-1 and broke for any other label set.
    classes = np.random.choice(np.unique(labels), n_way, replace=False)
    support_x, support_y = [], []
    query_x, query_y = [], []
    for i, c in enumerate(classes):
        # All indices belonging to class c; draw k_shot + k_query of them.
        idx = np.where(labels == c)[0]
        samples = np.random.choice(idx, k_shot + k_query, replace=False)
        support_x.append(data[samples[:k_shot]])
        support_y.extend([i] * k_shot)
        query_x.append(data[samples[k_shot:]])
        query_y.extend([i] * k_query)
    support_x = torch.cat(support_x)
    support_y = torch.tensor(support_y)
    query_x = torch.cat(query_x)
    query_y = torch.tensor(query_y)
    return support_x, support_y, query_x, query_y
# Load the MNIST training split (downloads to ./data on first run).
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
# Convert to raw tensors: add a channel dim and scale pixels to [0, 1].
train_data = mnist.data.unsqueeze(1).float() / 255.0
train_labels = mnist.targets.numpy()
print(f"Dataset: {train_data.shape}")  # expected: torch.Size([60000, 1, 28, 28])
def train_prototypical(model, data, labels, n_episodes=1000, n_way=5, k_shot=5, k_query=15):
    """Train a prototypical network with episodic sampling.

    Each episode samples a fresh N-way K-shot task, computes class
    prototypes from the support set, and minimizes the negative
    log-likelihood of query labels under a softmax over (negative)
    query-to-prototype distances.

    Args:
        model: embedding network mapping inputs to feature vectors.
        data: full training set tensor, indexable along dim 0.
        labels: array of integer class labels aligned with `data`.
        n_episodes: number of episodes (one optimizer step each).
        n_way: classes per episode.
        k_shot: support examples per class.
        k_query: query examples per class.

    Returns:
        List of per-episode training loss values.
    """
    # Fix: explicitly enter training mode. The model may have been left in
    # eval() by a previous evaluation call, which would freeze BatchNorm
    # statistics during training.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    losses = []
    for episode in range(n_episodes):
        # Sample a fresh N-way K-shot task.
        support_x, support_y, query_x, query_y = sample_episode(
            data, labels, n_way, k_shot, k_query
        )
        support_x = support_x.to(device)
        support_y = support_y.to(device)
        query_x = query_x.to(device)
        query_y = query_y.to(device)
        # Embed support and query examples.
        support_emb = model(support_x)
        query_emb = model(query_x)
        # Class prototypes = per-class mean of support embeddings.
        prototypes = compute_prototypes(support_emb, support_y, n_way)
        # Distances from each query embedding to each prototype.
        dists = euclidean_distance(query_emb, prototypes)
        # Softmax over negative distances -> log class probabilities.
        log_probs = F.log_softmax(-dists, dim=1)
        loss = F.nll_loss(log_probs, query_y)
        # Standard gradient step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if (episode + 1) % 200 == 0:
            acc = (log_probs.argmax(dim=1) == query_y).float().mean()
            print(f"Episode {episode+1}, Loss: {loss.item():.4f}, Acc: {acc:.3f}")
    return losses
# Train a fresh embedding network episodically (5-way 5-shot, 15 queries).
model = ConvEmbedding().to(device)
losses = train_prototypical(model, train_data, train_labels, n_episodes=1000,
                            n_way=5, k_shot=5, k_query=15)
# Plot raw per-episode losses plus a 50-episode moving average.
plt.figure(figsize=(10, 5))
plt.plot(losses, alpha=0.5)
plt.plot(np.convolve(losses, np.ones(50)/50, mode='valid'), linewidth=2, label='Moving Avg')
plt.xlabel('Episode', fontsize=11)
plt.ylabel('Loss', fontsize=11)
plt.title('Prototypical Network Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
EvaluationΒΆ
Evaluating a prototypical network follows the same episodic protocol as training: sample many random \(N\)-way, \(K\)-shot tasks from held-out test classes (classes never seen during training), compute prototypes from the support set, and measure classification accuracy on the query set. Reporting mean accuracy and 95% confidence intervals over many episodes gives a reliable estimate of few-shot performance. The ability to classify novel classes with only a few examples per class β without any gradient updates at test time β is what makes metric-based meta-learning appealing for real-world deployment.
def evaluate_prototypical(model, data, labels, n_episodes=100, n_way=5, k_shot=5, k_query=15):
    """Evaluate a prototypical network over random episodes.

    Samples `n_episodes` N-way K-shot tasks, classifies each query example
    by its nearest prototype, and records per-episode accuracy.

    Returns:
        np.ndarray of shape (n_episodes,) with one accuracy per episode.
    """
    model.eval()
    episode_accs = []
    with torch.no_grad():
        for _ in range(n_episodes):
            sup_x, sup_y, qry_x, qry_y = sample_episode(
                data, labels, n_way, k_shot, k_query
            )
            sup_x, sup_y = sup_x.to(device), sup_y.to(device)
            qry_x, qry_y = qry_x.to(device), qry_y.to(device)
            # Prototypes from the support set; classify queries by distance.
            prototypes = compute_prototypes(model(sup_x), sup_y, n_way)
            dists = euclidean_distance(model(qry_x), prototypes)
            preds = (-dists).argmax(dim=1)
            episode_accs.append((preds == qry_y).float().mean().item())
    return np.array(episode_accs)
# Load the MNIST test split and prepare raw tensors (as for training data).
mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_data = mnist_test.data.unsqueeze(1).float() / 255.0
test_labels = mnist_test.targets.numpy()
# Evaluate over 200 random 5-way 5-shot episodes.
accs = evaluate_prototypical(model, test_data, test_labels, n_episodes=200)
print(f"5-way 5-shot accuracy: {accs.mean():.3f} Β± {accs.std():.3f}")
# Histogram of per-episode accuracies, with the mean marked.
plt.figure(figsize=(10, 5))
plt.hist(accs, bins=20, edgecolor='black', alpha=0.7)
plt.axvline(accs.mean(), color='r', linestyle='--', label=f'Mean: {accs.mean():.3f}')
plt.xlabel('Accuracy', fontsize=11)
plt.ylabel('Count', fontsize=11)
plt.title('Test Episode Accuracies', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
VisualizationΒΆ
Visualizing the embedding space using t-SNE or UMAP reveals whether the learned representations form tight, well-separated clusters for different classes. For a well-trained prototypical network, same-class embeddings should cluster tightly around their prototype, with clear separation between classes. Plotting the prototypes (class means) alongside individual embeddings provides intuition for why nearest-prototype classification works: if the clusters are compact and distant, the decision boundaries are clean and few-shot accuracy will be high.
# Sample one test episode and visualize its embedding space in 2D.
support_x, support_y, query_x, query_y = sample_episode(
    test_data, test_labels, n_way=5, k_shot=5, k_query=10
)
model.eval()
with torch.no_grad():
    support_x = support_x.to(device)
    support_y = support_y.to(device)
    query_x = query_x.to(device)
    query_y = query_y.to(device)
    # Embeddings to numpy for t-SNE (which runs on CPU arrays).
    support_emb = model(support_x).cpu().numpy()
    query_emb = model(query_x).cpu().numpy()
    prototypes = compute_prototypes(
        torch.from_numpy(support_emb), support_y.cpu(), 5
    ).numpy()
# t-SNE for 2D visualization of support, prototype, and query embeddings.
from sklearn.manifold import TSNE
all_emb = np.vstack([support_emb, prototypes, query_emb])
tsne = TSNE(n_components=2, random_state=42)
emb_2d = tsne.fit_transform(all_emb)
# Split the joint 2D embedding back into its three groups (stacking order).
n_support = len(support_emb)
n_proto = len(prototypes)
support_2d = emb_2d[:n_support]
proto_2d = emb_2d[n_support:n_support+n_proto]
query_2d = emb_2d[n_support+n_proto:]
plt.figure(figsize=(12, 10))
colors = plt.cm.tab10(range(5))
# Support examples: filled circles, one color per class.
for k in range(5):
    mask = support_y.cpu().numpy() == k
    plt.scatter(support_2d[mask, 0], support_2d[mask, 1],
                c=[colors[k]], marker='o', s=100, alpha=0.6, label=f'Class {k} Support')
# Prototypes: large stars with black outlines.
for k in range(5):
    plt.scatter(proto_2d[k, 0], proto_2d[k, 1],
                c=[colors[k]], marker='*', s=500, edgecolors='black', linewidths=2)
# Query examples: crosses.
for k in range(5):
    mask = query_y.cpu().numpy() == k
    plt.scatter(query_2d[mask, 0], query_2d[mask, 1],
                c=[colors[k]], marker='x', s=150, alpha=0.8)
plt.xlabel('t-SNE 1', fontsize=11)
plt.ylabel('t-SNE 2', fontsize=11)
plt.title('Prototypical Network Embedding Space', fontsize=12)
plt.legend(ncol=2, fontsize=9)
plt.grid(True, alpha=0.3)
plt.show()
Different MetricsΒΆ
The choice of distance metric in embedding space significantly affects prototypical network performance. The original paper uses squared Euclidean distance, but cosine distance and Mahalanobis distance are common alternatives. Euclidean distance assumes isotropic clusters of similar scale; cosine distance normalizes for magnitude and focuses on direction; Mahalanobis distance accounts for per-dimension variance. Comparing metrics on the same task reveals how the embedding geometry interacts with the distance function, helping practitioners select the best combination for their domain.
def cosine_distance(x, y):
    """Return the (N, M) matrix of cosine distances (1 - cosine similarity)."""
    # Normalizing rows first makes the matrix product equal cosine similarity.
    similarity = F.normalize(x, dim=1) @ F.normalize(y, dim=1).T
    return 1 - similarity
# Compare distance metrics: evaluate the same trained model with each
# metric over 100 random 5-way 5-shot episodes.
metrics = {
    'Euclidean': euclidean_distance,
    'Cosine': cosine_distance
}
results = {}
for name, metric_fn in metrics.items():
    accs = []
    model.eval()
    with torch.no_grad():
        for _ in range(100):
            support_x, support_y, query_x, query_y = sample_episode(
                test_data, test_labels, 5, 5, 15
            )
            support_x = support_x.to(device)
            support_y = support_y.to(device)
            query_x = query_x.to(device)
            query_y = query_y.to(device)
            support_emb = model(support_x)
            query_emb = model(query_x)
            prototypes = compute_prototypes(support_emb, support_y, 5)
            dists = metric_fn(query_emb, prototypes)
            # Nearest prototype (smallest distance) wins.
            preds = (-dists).argmax(dim=1)
            acc = (preds == query_y).float().mean().item()
            accs.append(acc)
    results[name] = accs
# Box plot of per-episode accuracies for each metric.
fig, ax = plt.subplots(figsize=(10, 6))
positions = np.arange(len(metrics))
bp = ax.boxplot([results[name] for name in metrics.keys()],
                labels=metrics.keys(), patch_artist=True)
for patch in bp['boxes']:
    patch.set_facecolor('lightblue')
ax.set_ylabel('Accuracy', fontsize=11)
ax.set_title('Distance Metrics Comparison', fontsize=12)
ax.grid(True, alpha=0.3, axis='y')
plt.show()
for name, accs in results.items():
    print(f"{name}: {np.mean(accs):.3f} Β± {np.std(accs):.3f}")
SummaryΒΆ
Prototypical Networks:ΒΆ
Learn embedding where classification is nearest centroid.
Algorithm:ΒΆ
Embed support examples
Compute class prototypes (means)
Classify by distance to prototypes
Advantages:ΒΆ
Simple and efficient
Interpretable (class centers)
Works with any metric
Scales to many classes
Applications:ΒΆ
Few-shot image classification
Zero-shot learning
Domain adaptation
Cold-start recommendation
Extensions:ΒΆ
Relation Networks (learned metric)
Matching Networks (attention)
TADAM (task conditioning)
Next Steps:ΒΆ
16_maml_meta_learning.ipynb - Optimization-based
Study transductive inference
Explore semi-supervised episodes
Advanced Prototypical Networks TheoryΒΆ
1. Few-Shot Learning Problem FormulationΒΆ
1.1 Mathematical SetupΒΆ
Few-shot classification: Learn to classify examples with very few labeled samples per class.
N-way K-shot learning:
N classes
K examples per class (support set)
M query examples per class
Support set: \(S = \{(x_1, y_1), \ldots, (x_{NK}, y_{NK})\}\) where \(|\{i : y_i = c\}| = K\) for each class \(c\)
Query set: \(Q = \{(\tilde{x}_1, \tilde{y}_1), \ldots, (\tilde{x}_M, \tilde{y}_M)\}\)
Goal: Classify query examples using only K support examples per class
1.2 Episodic TrainingΒΆ
Meta-learning paradigm: Train on many tasks, generalize to new tasks.
Episode structure:
Sample N classes from training set
Sample K examples per class (support)
Sample M examples per class (query)
Train to classify query given support
Repeat for many episodes
Key insight: Training mimics test-time scenario (few examples per class)
2. Prototypical NetworksΒΆ
2.1 Core IdeaΒΆ
Prototype: Representative embedding for each class, computed as mean of support embeddings.
Prototype computation:
\(c_k = \frac{1}{K} \sum_{(x_i, y_i) \in S_k} f_\theta(x_i)\)
where:
c_k: prototype for class k
S_k: support examples for class k
f_ΞΈ: embedding function (neural network)
Classification rule:
\(p(y = k \mid x) = \operatorname{softmax}(-d(f_\theta(x), c_k))\)
\(= \exp(-d(f_\theta(x), c_k)) / \sum_j \exp(-d(f_\theta(x), c_j))\)
where d(Β·,Β·) is a distance metric (typically Euclidean).
2.2 Distance MetricsΒΆ
Euclidean distance (most common):
\(d(x, y) = \|x - y\|_2 = \sqrt{\sum_i (x_i - y_i)^2}\)
Squared Euclidean:
d(x, y) = ||x - y||βΒ²
Mathematically equivalent (monotonic transformation) but simpler gradients.
Cosine distance:
d(x, y) = 1 - (xΒ·y) / (||x|| ||y||)
Useful when magnitude doesnβt matter.
Mahalanobis distance:
d(x, y) = β((x-y)α΅ Ξ£β»ΒΉ (x-y))
Accounts for correlations between dimensions.
2.3 Loss FunctionΒΆ
Negative log-likelihood:
L(ΞΈ) = -log p(y = y_true | x)
= -log [exp(-d(f_ΞΈ(x), c_y_true)) / Ξ£β±Ό exp(-d(f_ΞΈ(x), cβ±Ό))]
= d(f_ΞΈ(x), c_y_true) + log Ξ£β±Ό exp(-d(f_ΞΈ(x), cβ±Ό))
Intuition:
Minimize distance to correct prototype
Maximize distance to incorrect prototypes
Softmax provides probabilistic interpretation
2.4 Gradient AnalysisΒΆ
Gradient w.r.t. embedding:
βL/βf_ΞΈ(x) = βd(f_ΞΈ(x), c_y_true)/βf_ΞΈ(x) - Ξ£β±Ό p(y=j|x) βd(f_ΞΈ(x), cβ±Ό)/βf_ΞΈ(x)
For Euclidean distance:
βd/βf_ΞΈ(x) = 2(f_ΞΈ(x) - c_k)
Gradient interpretation:
Push embedding towards correct prototype
Pull away from incorrect prototypes (weighted by probability)
3. Theoretical FoundationsΒΆ
3.1 Connection to k-Nearest NeighborsΒΆ
Prototypes as cluster centers: Prototypical Networks can be seen as soft k-NN with k=K.
k-NN decision:
Ε· = argmax_k Ξ£_{(xα΅’,yα΅’)βS, yα΅’=k} Ξ΄(x, xα΅’)
where Ξ΄ is indicator for x_i among k-nearest neighbors.
Prototypical Networks smoothed version:
Ε· = argmax_k p(y=k|x) where prototypes = mean(k-NN)
3.2 Mixture Density EstimationΒΆ
Generative view: Each class modeled as Gaussian distribution.
Class-conditional distribution:
p(x | y=k) = N(x; c_k, ΟΒ²I)
Posterior via Bayes rule:
p(y=k | x) = p(x|y=k)p(y=k) / Ξ£β±Ό p(x|y=j)p(y=j)
With uniform prior p(y=k) = 1/N:
p(y=k | x) β exp(-||x - c_k||Β²/(2ΟΒ²))
Prototypical Networks: Equivalent to maximum likelihood estimation of mixture model!
3.3 Bregman DivergencesΒΆ
Generalization of distance: Prototypical Networks work with any Bregman divergence.
Bregman divergence:
\(D_\varphi(x, y) = \varphi(x) - \varphi(y) - \langle \nabla\varphi(y),\, x - y \rangle\)
for convex function Ο.
Examples:
Ο(x) = ||x||Β²/2 β Squared Euclidean
Ο(x) = Ξ£α΅’ xα΅’ log xα΅’ β KL divergence
Ο(x) = -log det(x) β Log-determinant divergence
Theorem (Snell et al., 2017): Prototypical Networks with Bregman divergence D_Ο correspond to exponential family distributions with sufficient statistic βΟ(x).
4. Embedding Function DesignΒΆ
4.1 Common ArchitecturesΒΆ
Convolutional Networks (images):
4-layer CNN (Snell et al., 2017 baseline)
ResNet-12 (modern standard)
WideResNet-28-10 (state-of-the-art)
Architecture pattern:
Conv blocks β Global pooling β Embedding
Output dimension: Typically 64-1600 dimensions
Recurrent Networks (sequences):
LSTM/GRU for text
Bidirectional encoding
Attention mechanisms
Transformers (general):
Self-attention for long-range dependencies
Pre-trained models (BERT, ViT) as backbone
4.2 Design PrinciplesΒΆ
1. High-dimensional embeddings:
Typical: 64-1600 dimensions
Higher dimensions β better separation
But: overfitting risk with very few samples
2. Normalized embeddings:
L2 normalization: x β x/||x||
Converts to cosine similarity metric
Prevents magnitude dominating distance
3. Batch normalization:
Stabilizes training
But: tricky with small support sets
Alternative: Layer normalization, Group normalization
4. Pooling strategies:
Global average pooling (most common)
Max pooling
Attention pooling (weighted average)
4.3 Pre-training StrategiesΒΆ
Transfer learning:
Pre-train on large dataset (ImageNet, etc.)
Fine-tune with episodic training
Helps with limited meta-training data
Self-supervised pre-training:
Contrastive learning (SimCLR, MoCo)
Rotation prediction
Jigsaw puzzles
Improves embedding quality without labels
5. Advanced VariantsΒΆ
5.1 Gaussian Prototypical NetworksΒΆ
Motivation: Model class variance, not just mean.
Per-class covariance:
c_k = (1/K) Ξ£ f_ΞΈ(xα΅’) (mean)
Ξ£_k = (1/K) Ξ£ (f_ΞΈ(xα΅’) - c_k)(f_ΞΈ(xα΅’) - c_k)α΅ (covariance)
Mahalanobis distance:
d(x, k) = (f_ΞΈ(x) - c_k)α΅ Ξ£_kβ»ΒΉ (f_ΞΈ(x) - c_k)
Challenge: Singular covariance with K < embedding_dim.
Solutions:
Diagonal covariance only
Regularization: Ξ£_k β Ξ£_k + Ξ»I
Shared covariance across classes
5.2 Semi-Prototypical NetworksΒΆ
Use unlabeled data in support set:
Soft prototypes:
c_k = (1/Z_k) [Ξ£_{labeled} f_ΞΈ(xα΅’) + Ξ£_{unlabeled} p(y=k|xΜβ±Ό) f_ΞΈ(xΜβ±Ό)]
Iterative refinement:
Compute initial prototypes from labeled data
Pseudo-label unlabeled data
Update prototypes with pseudo-labels
Repeat
Benefit: Better prototypes with limited labels + abundant unlabeled data.
5.3 Transductive Prototypical NetworksΒΆ
Use query set to refine prototypes:
Standard (inductive): Prototypes only from support set.
Transductive: Prototypes from support + query.
Algorithm:
Initialize prototypes from support
Pseudo-label query examples
Update prototypes using query embeddings
Re-classify query
Iterate until convergence
Advantage: Query examples provide more data for better prototypes.
Disadvantage: Not applicable when query arrives one-by-one.
5.4 Task-Adaptive Prototypical NetworksΒΆ
Adapt prototypes per task:
Learnable scaling:
d(x, c_k) = ||W_task(f_ΞΈ(x) - c_k)||Β²
where W_task is task-specific diagonal matrix.
Feature selection:
Ξ±_task = softmax(g_Ο(c_1, ..., c_N)) (attention over dimensions)
d(x, c_k) = ||Ξ±_task β (f_ΞΈ(x) - c_k)||Β²
Benefit: Different tasks may need different feature dimensions.
6. Training TechniquesΒΆ
6.1 Episodic Sampling StrategiesΒΆ
Uniform sampling:
Sample N classes uniformly from training set
Simple, balanced
Class-balanced sampling:
Oversample rare classes
Ensures all classes seen equally
Curriculum learning:
Start with easy tasks (more shots, fewer ways)
Gradually increase difficulty
Faster convergence
Hard task mining:
Identify difficult class combinations
Oversample hard episodes
Improves worst-case performance
6.2 Data AugmentationΒΆ
Standard augmentations:
Random crops
Horizontal flips
Color jitter
Rotation
Few-shot specific:
Mixing: Average support examples within class
Hallucination: Generate synthetic examples
Adversarial: Small perturbations for robustness
Augmentation timing:
Support set: Usually yes (increases effective K)
Query set: Typically no (want clean evaluation)
6.3 OptimizationΒΆ
Learning rate scheduling:
Warmup for first few thousand episodes
Cosine annealing or step decay
Lower learning rate than standard supervised
Gradient clipping:
||βΞΈ|| > threshold β βΞΈ β threshold Β· βΞΈ/||βΞΈ||
Prevents instability from variable episode difficulty.
Optimizer choice:
Adam (most common, adaptive)
SGD with momentum (better generalization)
RAdam (combines benefits)
Batch size (episodes per update):
Typical: 1-4 episodes
Larger batch β more stable but slower
Smaller batch β faster iteration
7. Evaluation ProtocolsΒΆ
7.1 Standard BenchmarksΒΆ
Omniglot:
1623 handwritten characters
20 examples per class
Task: 5-way 1-shot, 5-way 5-shot, 20-way 1-shot
miniImageNet:
100 classes, 600 examples each
64 train, 16 validation, 20 test classes
Task: 5-way 1-shot, 5-way 5-shot
tieredImageNet:
608 classes from ImageNet
Hierarchical split (avoids train/test similarity)
Task: 5-way 1-shot, 5-way 5-shot
CIFAR-FS:
100 CIFAR-100 classes
64/16/20 train/val/test split
Task: 5-way 1-shot, 5-way 5-shot
7.2 Evaluation MetricsΒΆ
Accuracy:
Acc = (1/MΒ·N) Ξ£ 1[Ε·α΅’ = yα΅’]
Averaged over query examples and episodes.
95% confidence intervals:
CI = mean Β± 1.96 Β· std/β(num_episodes)
Typical: 600-10,000 test episodes.
Per-class accuracy: Identify which classes are hard (for analysis).
Calibration: Are predicted probabilities p(y=k|x) well-calibrated?
Expected Calibration Error (ECE):
ECE = Ξ£_b (|B_b|/N) |acc(B_b) - conf(B_b)|
where B_b are prediction bins.
7.3 Cross-Domain EvaluationΒΆ
Test on different domain:
Train: miniImageNet
Test: CUB-200 (birds), Cars, etc.
Measures: Transferability of learned embedding.
Finding: Prototypical Networks transfer better than metric learning methods (more generalizable distance function).
8. Comparison with Other MethodsΒΆ
8.1 Matching NetworksΒΆ
Attention-based:
p(y=k|x) = Ξ£_{(xα΅’,yα΅’)βS, yα΅’=k} a(x, xα΅’) where a = softmax(cos(f(x), f(xα΅’)))
Differences from Prototypical:
Attends to individual examples, not prototypes
More expressive but more parameters
Prototypical often performs similarly with simpler approach
8.2 Relation NetworksΒΆ
Learned metric:
d(x, x') = g_Ο(concat(f_ΞΈ(x), f_ΞΈ(x')))
where g_Ο is a learned network.
Prototypical: Fixed distance (Euclidean). Relation: Learned distance (more flexible).
Tradeoff:
Relation: Better performance with enough meta-training data
Prototypical: More robust with limited meta-training
8.3 MAML (Model-Agnostic Meta-Learning)ΒΆ
Optimization-based:
Inner loop: Adapt parameters on support
Outer loop: Update initialization
Prototypical: Non-parametric (no inner optimization).
Comparison:
Method | Inner Loop | Outer Loop | Speed
βββββββββββββββββ|ββββββββββββ|βββββββββββββββββ|ββββββ
Prototypical | None | Update f_ΞΈ | Fast
MAML | K gradient | Update ΞΈ_init | Slow
Relation Net | None | Update f_ΞΈ, g_Ο | Medium
When Prototypical wins: Limited compute, need fast inference. When MAML wins: Complex adaptation required, enough compute.
9. State-of-the-Art ResultsΒΆ
9.1 Benchmark PerformanceΒΆ
miniImageNet (5-way accuracy):
Method | 1-shot | 5-shot
βββββββββββββββββββββββ|βββββββββββ|ββββββββββ
Prototypical (2017) | 49.4% | 68.2%
Prototypical + deeper | 56.5% | 73.7%
FEAT (2019) | 55.2% | 71.5%
MetaOptNet (2019) | 62.6% | 78.6%
Meta-Baseline (2020) | 63.2% | 79.3%
tieredImageNet (5-way):
Method | 1-shot | 5-shot
βββββββββββββββββββββββ|βββββββββββ|ββββββββββ
Prototypical (2017) | 53.3% | 72.7%
Meta-Baseline (2020) | 68.1% | 83.7%
FRN (2021) | 66.5% | 82.8%
9.2 Key Improvements Over BaselineΒΆ
Better backbones: ResNet-12 β +7-10% accuracy. Pre-training: Self-supervised β +5-8% accuracy. Data augmentation: MixUp, CutMix β +2-4%. Transductive inference: Query refinement β +3-5%.
10. Limitations and ChallengesΒΆ
10.1 Known IssuesΒΆ
1. Domain shift:
Training and test classes must be similar
Poor cross-domain generalization without adaptation
2. Prototype quality with K=1:
Single example may not represent class well
Sensitive to outliers and noise
3. Class imbalance:
Assumes balanced support sets
Real-world often has varying shots per class
4. Computational cost of embeddings:
Must embed all support examples at test time
Slow for large support sets
5. Fixed distance metric:
Euclidean may not be optimal for all tasks
Learned metrics can be better but more complex
10.2 Failure ModesΒΆ
Out-of-distribution query: Query very different from support β assigns to nearest prototype (may be wrong).
Solution: Calibration, outlier detection.
Fine-grained classification: When classes are very similar, prototypes overlap.
Solution: Higher-dimensional embeddings, metric learning.
Multi-modal classes: Class has multiple clusters in embedding space.
Solution: Mixture models, multiple prototypes per class.
11. Extensions and ApplicationsΒΆ
11.1 Few-Shot DetectionΒΆ
Object detection with few examples:
Extract region proposals
Compute prototypes from support boxes
Match query regions to prototypes
Challenges: Background class, varying scales.
11.2 Few-Shot SegmentationΒΆ
Semantic segmentation:
Prototype per pixel class
Support: Segmented images
Query: Segment using prototypes
Approach:
For each pixel p in query:
Embedding z_p = f_ΞΈ(local_patch(p))
Class = argmin_k d(z_p, c_k)
11.3 Cross-Modal Few-ShotΒΆ
Match across modalities:
Image-text: CLIP-style prototypes
Audio-visual: Cross-modal embeddings
Joint embedding space:
f_image, f_text map to same space
Prototypes can be text or image
11.4 Continual Few-Shot LearningΒΆ
Stream of few-shot tasks:
Learn task 1, then task 2, etc.
Avoid catastrophic forgetting
Approach:
Store prototypes for old classes
Replay or regularization for stability
12. Implementation Best PracticesΒΆ
12.1 HyperparametersΒΆ
Embedding dimension:
Small datasets (Omniglot): 64-128
ImageNet-scale: 512-1600
Higher β better separation, but overfitting risk
Learning rate:
Typical: 1e-3 to 1e-4
Warmup for 1k-5k episodes
Decay by 0.1 every 20-40k episodes
Number of episodes:
Training: 40k-100k episodes
Validation: 600 episodes
Test: 600-10,000 episodes
Ways and shots:
Train on variable N-way, K-shot (e.g., 5-20 way, 1-5 shot)
Test on fixed task (e.g., 5-way 1-shot)
12.2 Debugging TipsΒΆ
Check prototype separation:
# Visualize prototypes (e.g., with t-SNE)
# Should see N clusters for N-way task
Embedding norms:
Should be similar across classes
If not, consider normalization
Sanity checks:
5-way 5-shot should be easier than 1-shot
5-way should be easier than 20-way
Training accuracy should exceed test
Common bugs:
Leaking test classes into training
Not shuffling support/query within episode
Incorrect prototype computation (check averaging)
12.3 Computational EfficiencyΒΆ
Pre-compute support embeddings:
# Embed support once per episode, reuse for all queries
support_emb = f_theta(support_images) # Compute once
prototypes = support_emb.mean(dim=1) # Average per class
for query in queries:
query_emb = f_theta(query)
distances = compute_distance(query_emb, prototypes)
Batched distance computation:
# Compute all query-prototype distances in one operation
# (Q, D) vs (N, D) β (Q, N) distances
distances = torch.cdist(query_emb, prototypes)
Mixed precision:
Use FP16 for embedding network
Keep distances in FP32 for numerical stability
13. Recent Advances (2020-2024)ΒΆ
13.1 Meta-BaselineΒΆ
Simplified approach:
Pre-train with standard cross-entropy
Fine-tune last layer only on support
Often matches complex meta-learning methods
Insight: Good features + simple adaptation can beat complex meta-learning.
13.2 Distribution CalibrationΒΆ
Problem: Prototypes from few examples have high variance.
Solution: Calibrate using base class statistics.
c_k_calibrated = Ξ±Β·c_k + (1-Ξ±)Β·ΞΌ_base
where ΞΌ_base is mean prototype from base classes.
13.3 Self-Supervision + Few-ShotΒΆ
SimCLR, MoCo for pre-training:
Learn general features without labels
Then few-shot learning on top
Results: +5-10% accuracy on standard benchmarks.
13.4 Optimal Transport for MatchingΒΆ
Wasserstein distance: Match support and query distributions optimally.
Better than prototypes when:
Multi-modal classes
Varying shot numbers per class
14. Future DirectionsΒΆ
1. Cross-domain few-shot:
Better transfer across very different domains
Meta-learning with domain adaptation
2. Task-agnostic meta-learning:
Single model for any N-way K-shot combination
Currently re-train for each configuration
3. Theoretical understanding:
Sample complexity bounds
Generalization guarantees for few-shot
4. Efficient architectures:
Neural architecture search for few-shot
Lightweight models for edge deployment
5. Multi-task few-shot:
Learn multiple tasks simultaneously with few examples each
Leveraging cross-task structure
15. Key TakeawaysΒΆ
Simplicity wins: Prototypical Networksβ simple mean-based approach often matches complex methods.
Embedding quality is key: Good feature representation (from pre-training or architecture) matters more than meta-learning algorithm.
Episodic training crucial: Training must mimic test scenario (few examples per class).
Trade-offs everywhere:
Complexity vs robustness (Prototypical simpler, Relation more flexible)
Speed vs performance (Prototypical fast, MAML slow but expressive)
Pre-training helps: Transfer learning or self-supervision improves few-shot significantly.
Distance metric matters: Euclidean works well, but learned metrics can be better for specific domains.
Not just for images: Prototypical Networks work for any domain with meaningful embeddings (text, audio, graphs).
16. ReferencesΒΆ
Foundational:
Snell et al. (2017): βPrototypical Networks for Few-shot Learningβ (NeurIPS)
Vinyals et al. (2016): βMatching Networks for One Shot Learningβ
Sung et al. (2018): βLearning to Compare: Relation Network for Few-Shot Learningβ
Theoretical:
Banerjee et al. (2005): βClustering with Bregman Divergencesβ
Allen et al. (2019): βInfinite Mixture Prototypes for Few-Shot Learningβ
Recent advances:
Chen et al. (2020): βA Closer Look at Few-shot Classificationβ (Meta-Baseline)
Yang et al. (2021): βFree Lunch for Few-Shot Learning: Distribution Calibrationβ
Rizve et al. (2021): βExploring Complementary Strengths of Invariant and Equivariant Representationsβ
Applications:
Hu et al. (2019): βFew-Shot Object Detectionβ
Shaban et al. (2017): βOne-Shot Learning for Semantic Segmentationβ
# Advanced Prototypical Networks Implementations
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List
# ============================================================================
# 1. Distance Metrics
# ============================================================================
class DistanceMetric:
    """Namespace of distance functions for prototype matching.

    The elementwise methods (`euclidean`, `squared_euclidean`, `cosine`)
    reduce over the last dimension of broadcastable inputs;
    `pairwise_distances` returns the full (N, M) distance matrix.
    """

    @staticmethod
    def euclidean(x, y):
        """Euclidean distance ||x - y||_2 along the last dimension."""
        return torch.norm(x - y, p=2, dim=-1)

    @staticmethod
    def squared_euclidean(x, y):
        """Squared Euclidean distance ||x - y||_2^2 along the last dimension."""
        return torch.sum((x - y) ** 2, dim=-1)

    @staticmethod
    def cosine(x, y):
        """Cosine distance 1 - (x.y)/(||x|| ||y||) along the last dimension."""
        x_norm = F.normalize(x, p=2, dim=-1)
        y_norm = F.normalize(y, p=2, dim=-1)
        return 1 - torch.sum(x_norm * y_norm, dim=-1)

    @staticmethod
    def pairwise_distances(x, y, distance='euclidean'):
        """
        Compute pairwise distances between all pairs.

        Args:
            x: (N, D) - N points in D dimensions
            y: (M, D) - M points in D dimensions
            distance: 'euclidean', 'squared_euclidean', or 'cosine'

        Returns:
            distances: (N, M) - distance from each x to each y

        Raises:
            ValueError: if `distance` is not a supported metric name.
        """
        if distance == 'euclidean':
            # ||x - y||_2 = sqrt(||x||^2 - 2x.y + ||y||^2); the clamp guards
            # against tiny negative values from floating-point cancellation.
            x_norm = (x ** 2).sum(dim=1, keepdim=True)  # (N, 1)
            y_norm = (y ** 2).sum(dim=1, keepdim=True)  # (M, 1)
            dist = x_norm + y_norm.T - 2 * torch.matmul(x, y.T)
            return torch.sqrt(torch.clamp(dist, min=1e-12))
        elif distance == 'squared_euclidean':
            x_norm = (x ** 2).sum(dim=1, keepdim=True)
            y_norm = (y ** 2).sum(dim=1, keepdim=True)
            dist = x_norm + y_norm.T - 2 * torch.matmul(x, y.T)
            # Fix: clamp at zero. The expansion formula can dip slightly
            # below zero numerically, and squared distances must be >= 0.
            return torch.clamp(dist, min=0.0)
        elif distance == 'cosine':
            x_norm = F.normalize(x, p=2, dim=1)
            y_norm = F.normalize(y, p=2, dim=1)
            return 1 - torch.matmul(x_norm, y_norm.T)
        else:
            raise ValueError(f"Unknown distance: {distance}")
# ============================================================================
# 2. Embedding Networks
# ============================================================================
class ConvBlock(nn.Module):
    """Conv(3x3) -> BatchNorm -> ReLU, optionally followed by 2x2 max pooling."""

    def __init__(self, in_channels, out_channels, pool=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Pooling halves the spatial resolution when enabled.
        self.pool = nn.MaxPool2d(2) if pool else None

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        return out if self.pool is None else self.pool(out)
class ConvEmbedding(nn.Module):
    """
    4-layer CNN image encoder (Snell et al., 2017 baseline).

    Four ConvBlocks, each halving the spatial resolution, followed by global
    average pooling to produce one flat embedding vector per image.
    """
    def __init__(self, in_channels=3, hidden_dim=64, embedding_dim=64):
        super().__init__()
        # Channel widths for the four successive blocks.
        widths = [in_channels, hidden_dim, hidden_dim, hidden_dim, embedding_dim]
        stages = [ConvBlock(c_in, c_out, pool=True)
                  for c_in, c_out in zip(widths[:-1], widths[1:])]
        self.encoder = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """
        Args:
            x: (batch, channels, H, W) image batch.
        Returns:
            embeddings: (batch, embedding_dim)
        """
        features = self.pool(self.encoder(x))
        return features.flatten(start_dim=1)
class ResNetBlock(nn.Module):
    """
    Basic residual block: two 3x3 conv+BN layers with an identity (or 1x1
    projection) shortcut, ReLU applied after the addition.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
            in_channels: input channel count.
            out_channels: output channel count.
            stride: stride of the first conv (downsamples when > 1).
        """
        super().__init__()
        # bias=False on convs feeding BatchNorm: BN subtracts the batch mean,
        # so a conv bias is redundant and only wastes parameters.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 projection shortcut when the shape changes, identity otherwise.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Return ReLU(F(x) + shortcut(x)); spatial size follows conv1's stride."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return F.relu(out)
class ResNet12Embedding(nn.Module):
    """
    Deep residual embedding network for few-shot learning (modern standard).
    Substantially stronger than the 4-layer CNN baseline.
    """
    def __init__(self, in_channels=3, embedding_dim=640):
        super().__init__()
        self.layer1 = self._stage(in_channels, 64, nn.MaxPool2d(2))
        self.layer2 = self._stage(64, 128, nn.MaxPool2d(2))
        self.layer3 = self._stage(128, 256, nn.MaxPool2d(2))
        self.layer4 = self._stage(256, 512, nn.AdaptiveAvgPool2d(1))
        self.fc = nn.Linear(512, embedding_dim)

    @staticmethod
    def _stage(c_in, c_out, tail):
        """Three residual blocks (the first changes width) followed by pooling."""
        return nn.Sequential(
            ResNetBlock(c_in, c_out),
            ResNetBlock(c_out, c_out),
            ResNetBlock(c_out, c_out),
            tail,
        )

    def forward(self, x):
        """Map images (batch, C, H, W) to embeddings (batch, embedding_dim)."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.fc(x.flatten(start_dim=1))
# ============================================================================
# 3. Prototypical Networks
# ============================================================================
class PrototypicalNetwork(nn.Module):
    """
    Prototypical Networks for few-shot classification (Snell et al., 2017).

    Class prototypes are the mean support embedding per class; queries are
    scored by negative distance to each prototype.
    """
    def __init__(self, embedding_net, distance='squared_euclidean'):
        """
        Args:
            embedding_net: module mapping images to flat embedding vectors.
            distance: metric name understood by
                DistanceMetric.pairwise_distances.
        """
        super().__init__()
        self.embedding_net = embedding_net
        self.distance = distance

    def compute_prototypes(self, support_embeddings, support_labels, n_way):
        """
        Compute class prototypes as mean of support embeddings.

        Args:
            support_embeddings: (n_support, embedding_dim)
            support_labels: (n_support,) - class labels 0 to n_way-1
            n_way: number of classes
        Returns:
            prototypes: (n_way, embedding_dim)
        """
        prototypes = torch.zeros(n_way, support_embeddings.size(1),
                                 device=support_embeddings.device,
                                 dtype=support_embeddings.dtype)
        for k in range(n_way):
            # Prototype = mean embedding of class k's support examples.
            prototypes[k] = support_embeddings[support_labels == k].mean(dim=0)
        return prototypes

    def forward(self, support_images, support_labels, query_images, n_way, n_shot):
        """
        Args:
            support_images: (n_way * n_shot, C, H, W)
            support_labels: (n_way * n_shot,)
            query_images: (n_query, C, H, W)
            n_way: number of classes
            n_shot: number of examples per class (kept for interface
                compatibility; the support count is taken from the tensor)
        Returns:
            logits: (n_query, n_way) - negative distances to prototypes
        """
        # Fix: embed support and query in a single forward pass so that
        # normalization layers (e.g. BatchNorm in train mode) see identical
        # batch statistics for both sets. Two separate passes previously
        # produced inconsistent embeddings during episodic training.
        n_support = support_images.size(0)
        embeddings = self.embedding_net(
            torch.cat([support_images, query_images], dim=0)
        )
        support_embeddings = embeddings[:n_support]
        query_embeddings = embeddings[n_support:]
        # Compute prototypes and query-to-prototype distances.
        prototypes = self.compute_prototypes(support_embeddings, support_labels, n_way)
        distances = DistanceMetric.pairwise_distances(
            query_embeddings, prototypes, distance=self.distance
        )
        # Negative distance acts as the logit for softmax / cross-entropy.
        return -distances

    def loss(self, logits, query_labels):
        """
        Cross-entropy loss over the softmax of negative distances.

        Args:
            logits: (n_query, n_way)
            query_labels: (n_query,)
        Returns:
            loss: scalar
        """
        return F.cross_entropy(logits, query_labels)
# ============================================================================
# 4. Gaussian Prototypical Networks
# ============================================================================
class GaussianPrototypicalNetwork(nn.Module):
    """
    Prototypical Networks with Gaussian class models.

    Each class is modeled by a mean and a regularized covariance estimated
    from its support embeddings; queries are scored by negative squared
    Mahalanobis distance.
    """
    def __init__(self, embedding_net, regularization=1e-4):
        """
        Args:
            embedding_net: module mapping images to flat embeddings.
            regularization: lambda added to the covariance diagonal
                (guarantees the per-class covariance is invertible even
                for 1-shot episodes, where the sample covariance is zero).
        """
        super().__init__()
        self.embedding_net = embedding_net
        self.regularization = regularization

    def compute_gaussian_prototypes(self, support_embeddings, support_labels, n_way):
        """
        Compute mean and regularized covariance for each class.

        Args:
            support_embeddings: (n_support, D)
            support_labels: (n_support,) labels in [0, n_way)
            n_way: number of classes
        Returns:
            means: (n_way, D)
            covariances: (n_way, D, D)
        """
        D = support_embeddings.size(1)
        device = support_embeddings.device
        dtype = support_embeddings.dtype
        means = torch.zeros(n_way, D, device=device, dtype=dtype)
        covariances = torch.zeros(n_way, D, D, device=device, dtype=dtype)
        eye = torch.eye(D, device=device, dtype=dtype)
        for k in range(n_way):
            class_embeddings = support_embeddings[support_labels == k]  # (K, D)
            means[k] = class_embeddings.mean(dim=0)
            # Biased sample covariance of the class embeddings.
            centered = class_embeddings - means[k]
            cov = torch.matmul(centered.T, centered) / class_embeddings.size(0)
            # Regularize (add scaled identity to ensure invertibility).
            covariances[k] = cov + self.regularization * eye
        return means, covariances

    def mahalanobis_distance(self, x, means, covariances):
        """
        Compute squared Mahalanobis distance from x to each Gaussian.

        Args:
            x: (N, D)
            means: (K, D)
            covariances: (K, D, D)
        Returns:
            distances: (N, K)
        """
        N = x.size(0)
        K = means.size(0)
        distances = torch.zeros(N, K, device=x.device, dtype=x.dtype)
        for k in range(K):
            # (x - mu_k)^T Sigma_k^{-1} (x - mu_k)
            diff = x - means[k]  # (N, D)
            # Fix: solve Sigma_k z = diff^T instead of forming the explicit
            # inverse — numerically more stable and cheaper than
            # torch.inverse for the regularized (well-conditioned) matrices.
            z = torch.linalg.solve(covariances[k], diff.T).T  # (N, D)
            distances[:, k] = (z * diff).sum(dim=1)  # (N,)
        return distances

    def forward(self, support_images, support_labels, query_images, n_way, n_shot):
        """Return (n_query, n_way) logits = negative squared Mahalanobis distances."""
        # Single embedding pass keeps normalization statistics consistent
        # between support and query sets.
        n_support = support_images.size(0)
        embeddings = self.embedding_net(
            torch.cat([support_images, query_images], dim=0)
        )
        means, covariances = self.compute_gaussian_prototypes(
            embeddings[:n_support], support_labels, n_way
        )
        distances = self.mahalanobis_distance(
            embeddings[n_support:], means, covariances
        )
        return -distances
# ============================================================================
# 5. Transductive Prototypical Networks
# ============================================================================
class TransductivePrototypicalNetwork(nn.Module):
    """
    Transductive inference: use query set to refine prototypes.
    Iteratively update prototypes using pseudo-labels on queries.
    """
    def __init__(self, embedding_net, n_iterations=3, distance='squared_euclidean'):
        # embedding_net: module mapping images to flat embedding vectors
        # n_iterations: number of pseudo-label / prototype-update rounds
        # distance: metric name passed to DistanceMetric.pairwise_distances
        super().__init__()
        self.embedding_net = embedding_net
        self.n_iterations = n_iterations
        self.distance = distance

    def forward(self, support_images, support_labels, query_images, n_way, n_shot):
        """
        Classify queries with prototypes refined transductively.

        Args:
            support_images: (n_way * n_shot, C, H, W)
            support_labels: (n_way * n_shot,) labels in [0, n_way)
            query_images: (n_query, C, H, W)
            n_way: number of classes in the episode
            n_shot: support examples per class (not used directly here)
        Returns:
            (n_query, n_way) logits = negative distances to the refined
            prototypes.
        """
        # Initial embeddings
        support_embeddings = self.embedding_net(support_images)
        query_embeddings = self.embedding_net(query_images)
        # Initialize prototypes from support only
        prototypes = torch.zeros(n_way, support_embeddings.size(1),
                                 device=support_embeddings.device)
        for k in range(n_way):
            prototypes[k] = support_embeddings[support_labels == k].mean(dim=0)
        # Iterative refinement: soft-assign queries, then pull each prototype
        # toward the probability-weighted query mean.
        for iteration in range(self.n_iterations):
            # Compute distances and soft assignments
            distances = DistanceMetric.pairwise_distances(
                query_embeddings, prototypes, distance=self.distance
            )
            logits = -distances
            probs = F.softmax(logits, dim=1)  # (n_query, n_way)
            # Update prototypes using weighted query embeddings
            for k in range(n_way):
                # Support contribution
                support_mean = support_embeddings[support_labels == k].mean(dim=0)
                # Query contribution (weighted by probability)
                query_weights = probs[:, k].unsqueeze(1)  # (n_query, 1)
                query_contribution = (query_embeddings * query_weights).sum(dim=0)
                # Epsilon guards against division by zero when essentially no
                # query mass is assigned to class k
                query_contribution /= (query_weights.sum() + 1e-8)
                # Combine (equal weight to support and query)
                prototypes[k] = 0.5 * support_mean + 0.5 * query_contribution
        # Final classification
        distances = DistanceMetric.pairwise_distances(
            query_embeddings, prototypes, distance=self.distance
        )
        return -distances
# ============================================================================
# 6. Episode Sampler
# ============================================================================
class EpisodeSampler:
    """
    Sample N-way K-shot episodes for meta-learning.

    Each episode draws ``n_way`` distinct classes, then ``n_shot`` support
    and ``n_query`` query examples per class, relabelled 0..n_way-1.
    """
    def __init__(self, data, labels, n_way, n_shot, n_query):
        """
        Args:
            data: All training images (indexable array, first axis = example)
            labels: All training labels (array-like of class ids)
            n_way: Number of classes per episode
            n_shot: Number of support examples per class
            n_query: Number of query examples per class

        Raises:
            ValueError: if fewer than ``n_way`` classes have at least
                ``n_shot + n_query`` examples.
        """
        self.data = data
        self.labels = labels
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        # Organize data by class
        label_array = np.asarray(labels)
        all_classes = np.unique(label_array)
        self.class_to_indices = {
            c: np.where(label_array == c)[0] for c in all_classes
        }
        # Fix: only classes with enough examples for one full episode are
        # eligible. Previously an undersized class crashed
        # np.random.choice(replace=False) mid-training with an opaque error.
        needed = n_shot + n_query
        self.classes = np.array(
            [c for c in all_classes if len(self.class_to_indices[c]) >= needed]
        )
        if len(self.classes) < n_way:
            raise ValueError(
                f"Need at least {n_way} classes with >= {needed} examples "
                f"each for {n_way}-way {n_shot}-shot episodes; "
                f"only {len(self.classes)} qualify."
            )

    def sample_episode(self):
        """
        Sample one episode.

        Returns:
            (support_images, support_labels, query_images, query_labels)
            as float/long torch tensors; labels are episode-local 0..n_way-1.
        """
        # Sample N eligible classes without replacement
        episode_classes = np.random.choice(self.classes, self.n_way, replace=False)
        support_images = []
        support_labels = []
        query_images = []
        query_labels = []
        for i, c in enumerate(episode_classes):
            class_indices = self.class_to_indices[c]
            # Draw K + M distinct examples, then split support/query
            sampled_indices = np.random.choice(
                class_indices, self.n_shot + self.n_query, replace=False
            )
            support_idx = sampled_indices[:self.n_shot]
            query_idx = sampled_indices[self.n_shot:]
            support_images.append(self.data[support_idx])
            support_labels.extend([i] * self.n_shot)  # Relabel to 0, 1, ..., N-1
            query_images.append(self.data[query_idx])
            query_labels.extend([i] * self.n_query)
        return (
            torch.from_numpy(np.concatenate(support_images, axis=0)).float(),
            torch.from_numpy(np.array(support_labels)).long(),
            torch.from_numpy(np.concatenate(query_images, axis=0)).float(),
            torch.from_numpy(np.array(query_labels)).long(),
        )
# ============================================================================
# 7. Training Loop
# ============================================================================
class PrototypicalTrainer:
    """
    Episodic trainer/evaluator for prototype-based few-shot models.

    The wrapped model must be callable as
    ``model(support_images, support_labels, query_images, n_way, n_shot)``
    returning per-query logits, and expose ``model.loss(logits, labels)``.
    """
    def __init__(self, model, optimizer, device='cpu'):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device

    def _to_device(self, *tensors):
        """Move every episode tensor onto the trainer's device."""
        return tuple(t.to(self.device) for t in tensors)

    def _metrics(self, logits, targets, loss):
        """Package scalar loss and query accuracy into a result dict."""
        correct = (logits.argmax(dim=1) == targets).float().mean()
        return {
            'loss': loss.item(),
            'accuracy': correct.item()
        }

    def train_episode(self, support_images, support_labels,
                      query_images, query_labels, n_way, n_shot):
        """Run one optimization step on a single episode."""
        self.model.train()
        support_images, support_labels, query_images, query_labels = \
            self._to_device(support_images, support_labels,
                            query_images, query_labels)
        self.optimizer.zero_grad()
        logits = self.model(support_images, support_labels,
                            query_images, n_way, n_shot)
        loss = self.model.loss(logits, query_labels)
        loss.backward()
        # Gradient clipping keeps episodic training stable.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        return self._metrics(logits, query_labels, loss)

    def evaluate_episode(self, support_images, support_labels,
                         query_images, query_labels, n_way, n_shot):
        """Evaluate one episode without gradient tracking."""
        self.model.eval()
        with torch.no_grad():
            support_images, support_labels, query_images, query_labels = \
                self._to_device(support_images, support_labels,
                                query_images, query_labels)
            logits = self.model(support_images, support_labels,
                                query_images, n_way, n_shot)
            loss = self.model.loss(logits, query_labels)
            return self._metrics(logits, query_labels, loss)
# ============================================================================
# Demonstrations
# ============================================================================
print("=" * 70)
print("Prototypical Networks - Advanced Implementations")
print("=" * 70)
# 1. Distance metrics comparison
print("\n1. Distance Metrics:")
x = torch.randn(5, 10)  # 5 points in 10D
y = torch.randn(3, 10)  # 3 points in 10D
euclidean = DistanceMetric.pairwise_distances(x, y, 'euclidean')
squared = DistanceMetric.pairwise_distances(x, y, 'squared_euclidean')
cosine = DistanceMetric.pairwise_distances(x, y, 'cosine')
print(f" Points: {x.shape} vs {y.shape}")
print(f" Euclidean distances: {euclidean.shape}")
print(f" Squared Euclidean: {squared.shape}")
print(f" Cosine distances: {cosine.shape}")
# Fix: restored mojibake-mangled "²" in the output string below.
print(f" Property: squared = euclidean²: {torch.allclose(squared, euclidean**2)}")
# 2. Embedding networks
print("\n2. Embedding Networks:")
conv_net = ConvEmbedding(in_channels=3, embedding_dim=64)
resnet = ResNet12Embedding(in_channels=3, embedding_dim=640)
x_img = torch.randn(4, 3, 84, 84)
emb_conv = conv_net(x_img)
emb_resnet = resnet(x_img)
print(f" Input images: {x_img.shape}")
print(f" Conv embedding: {emb_conv.shape} (64D)")
print(f" ResNet-12 embedding: {emb_resnet.shape} (640D)")
print(f" ")
print(f" Conv-4 parameters: {sum(p.numel() for p in conv_net.parameters()):,}")
print(f" ResNet-12 parameters: {sum(p.numel() for p in resnet.parameters()):,}")
# Fix: restored mojibake-mangled "×" and "→" in the output string below.
print(f" ResNet-12 ~30× more parameters → better performance")
# 3. Prototypical network
print("\n3. Standard Prototypical Network:")
proto_net = PrototypicalNetwork(conv_net, distance='squared_euclidean')
# Simulate 5-way 5-shot episode
n_way, n_shot, n_query = 5, 5, 10
support_images = torch.randn(n_way * n_shot, 3, 84, 84)
support_labels = torch.arange(n_way).repeat_interleave(n_shot)
query_images = torch.randn(n_way * n_query, 3, 84, 84)
logits = proto_net(support_images, support_labels, query_images, n_way, n_shot)
print(f" Task: {n_way}-way {n_shot}-shot")
# Fix: restored mojibake-mangled "×" characters in the output strings below.
print(f" Support: {support_images.shape} ({n_way} classes × {n_shot} shots)")
print(f" Query: {query_images.shape} ({n_way * n_query} examples)")
print(f" Logits: {logits.shape} (queries × classes)")
print(f" ")
print(f" Prototypes: Mean of {n_shot} embeddings per class")
print(f" Classification: Nearest prototype (min distance)")
# 4. Gaussian prototypical
print("\n4. Gaussian Prototypical Network:")
gauss_net = GaussianPrototypicalNetwork(conv_net, regularization=1e-3)
logits_gauss = gauss_net(support_images, support_labels, query_images, n_way, n_shot)
print(f" Enhancement: Models class covariance, not just mean")
# Fix: restored the mojibake-mangled Mahalanobis formula (μ, Σ, ᵀ, ⁻¹, ²).
print(f" Distance: Mahalanobis d²(x,μ) = (x-μ)ᵀΣ⁻¹(x-μ)")
print(f" Regularization: Σ → Σ + λI (ensures invertibility)")
print(f" Logits: {logits_gauss.shape}")
print(f" Benefit: Better when classes have different variances")
# 5. Transductive inference
print("\n5. Transductive Prototypical Network:")
trans_net = TransductivePrototypicalNetwork(conv_net, n_iterations=3)
logits_trans = trans_net(support_images, support_labels, query_images, n_way, n_shot)
print(f" Refinement: Uses query set to improve prototypes")
print(f" Iterations: 3 (pseudo-label → update prototypes)")
print(f" Algorithm:")
print(f" 1. Initial prototypes from support")
print(f" 2. Soft-assign queries to prototypes")
print(f" 3. Update prototypes with weighted queries")
print(f" 4. Repeat")
print(f" Benefit: +3-5% accuracy on standard benchmarks")
# 6. Episode sampling
print("\n6. Episode Sampler:")
# Fix: structured dummy labels guarantee every class has the 20 examples one
# 5-shot / 15-query episode needs. The previous randint(0, 100, 1000) labels
# averaged only ~10 examples per class, so np.random.choice(replace=False)
# raised ValueError as soon as an episode was sampled.
dummy_data = np.random.randn(1000, 3, 84, 84)
dummy_labels = np.repeat(np.arange(50), 20)  # 50 classes × 20 examples each
sampler = EpisodeSampler(dummy_data, dummy_labels, n_way=5, n_shot=5, n_query=15)
s_img, s_lbl, q_img, q_lbl = sampler.sample_episode()
print(f" Dataset: {len(dummy_data)} images, {len(np.unique(dummy_labels))} classes")
print(f" Episode: {n_way}-way {n_shot}-shot, {sampler.n_query} queries/class")
print(f" Support: {s_img.shape}, labels {s_lbl.shape}")
print(f" Query: {q_img.shape}, labels {q_lbl.shape}")
print(f" Label mapping: Original classes → [0, 1, 2, 3, 4]")
# 7. Complexity analysis
# Fix: restored the mojibake-mangled box-drawing table and "×" characters.
print("\n7. Computational Complexity:")
print(" Forward pass breakdown (5-way 5-shot, 15 queries/class):")
print(" ┌─────────────────────────┬──────────────┬─────────────┐")
print(" │ Operation               │ Cost         │ Comment     │")
print(" ├─────────────────────────┼──────────────┼─────────────┤")
print(" │ Embed support (25)      │ 25 × f(x)    │ Once        │")
print(" │ Embed queries (75)      │ 75 × f(x)    │ Once        │")
print(" │ Compute prototypes (5)  │ O(25 × D)    │ Mean        │")
print(" │ Distances (75 × 5)      │ O(375 × D)   │ Pairwise    │")
print(" │ Softmax (75)            │ O(375)       │ Fast        │")
print(" └─────────────────────────┴──────────────┴─────────────┘")
print(" Dominant cost: Embedding network f(x)")
print(" Prototypes add negligible overhead!")
# 8. Method comparison
# Fix: restored the mojibake-mangled box-drawing tables, "×"/"²" characters,
# and list bullets throughout the remaining output.
print("\n8. Comparison with Other Meta-Learning Methods:")
print(" ┌──────────────────┬──────────┬────────────┬──────────┬─────────┐")
print(" │ Method           │ Inner    │ Complexity │ Speed    │ Perf    │")
print(" ├──────────────────┼──────────┼────────────┼──────────┼─────────┤")
print(" │ Prototypical     │ None     │ O(K×D)     │ Fast     │ Good    │")
print(" │ Matching Nets    │ None     │ O(K²×D)    │ Medium   │ Good    │")
print(" │ Relation Net     │ None     │ O(K×D²)    │ Medium   │ Better  │")
print(" │ MAML             │ Gradient │ O(K×B×D²)  │ Slow     │ Best    │")
print(" └──────────────────┴──────────┴────────────┴──────────┴─────────┘")
print(" K: shots, D: embedding dim, B: inner steps")
print(" ")
print(" Prototypical wins on: Speed, simplicity, robustness")
print(" MAML wins on: Flexibility, performance (with enough data)")
# 9. When to use guide
print("\n9. When to Use Prototypical Networks:")
print(" Use Prototypical Networks when:")
print(" ✓ Few examples per class (K=1-10)")
print(" ✓ Need fast inference (real-time systems)")
print(" ✓ Limited meta-training data (<1000 classes)")
print(" ✓ Interpretability important (prototypes = class centers)")
print(" ✓ Classes well-separated in embedding space")
print("\n Use MAML instead when:")
print(" ✓ Have lots of meta-training data")
print(" ✓ Complex task-specific adaptation needed")
print(" ✓ Compute not a bottleneck")
print("\n Use Relation Networks when:")
print(" ✓ Need learned similarity metric")
print(" ✓ Fixed distance (Euclidean) insufficient")
# 10. Performance expectations
print("\n10. Expected Performance (5-way accuracy):")
print(" miniImageNet benchmark:")
print(" ┌─────────────────┬─────────┬─────────┐")
print(" │ Method          │ 1-shot  │ 5-shot  │")
print(" ├─────────────────┼─────────┼─────────┤")
print(" │ Proto (Conv-4)  │ 49.4%   │ 68.2%   │")
print(" │ Proto (ResNet)  │ 56.5%   │ 73.7%   │")
print(" │ + Pre-training  │ 60-63%  │ 75-79%  │")
print(" │ + Transductive  │ +3-5%   │ +3-5%   │")
print(" └─────────────────┴─────────┴─────────┘")
print(" ")
print(" Key insights:")
print(" • Better embedding >> better meta-algorithm")
print(" • ResNet-12 adds ~7% over Conv-4")
print(" • Pre-training adds ~5-10%")
print(" • 5-shot >> 1-shot (~20% absolute gain)")
print("\n" + "=" * 70)