import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Select GPU when available; every tensor/module below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Few-Shot ClassificationΒΆ

N-way K-shotΒΆ

  • N classes with K examples each

  • Support set: \(S = \{(x_i, y_i)\}_{i=1}^{NK}\)

  • Query: classify new examples

Prototypical Networks IdeaΒΆ

  1. Embed examples: \(f_\theta: \mathbb{R}^d \to \mathbb{R}^m\)

  2. Compute class prototypes (means)

  3. Classify by nearest prototype

📚 Reference Materials:

2. AlgorithmΒΆ

Class PrototypeΒΆ

\[c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\theta(x_i)\]

ClassificationΒΆ

\[p(y=k|x) = \frac{\exp(-d(f_\theta(x), c_k))}{\sum_{k'} \exp(-d(f_\theta(x), c_{k'}))}\]

where \(d(\cdot, \cdot)\) is distance (e.g., Euclidean).

class ConvEmbedding(nn.Module):
    """CNN encoder mapping images to fixed-size embedding vectors.

    Three conv/BN/ReLU/max-pool stages halve the spatial size each time
    (28x28 -> 14 -> 7 -> 3 for MNIST-sized input), followed by a linear
    projection to `output_dim` dimensions.
    """

    def __init__(self, in_channels=1, hidden_dim=64, output_dim=64):
        super().__init__()
        stages = []
        channels = in_channels
        for _ in range(3):
            stages += [
                nn.Conv2d(channels, hidden_dim, 3, padding=1),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            channels = hidden_dim
        # 3 rounds of 2x pooling on 28x28 leave a 3x3 feature map.
        stages += [nn.Flatten(), nn.Linear(hidden_dim * 3 * 3, output_dim)]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        # x: (B, in_channels, 28, 28) -> (B, output_dim)
        return self.encoder(x)

def compute_prototypes(embeddings, labels, n_classes):
    """Average the embeddings of each class to obtain one prototype per class.

    Args:
        embeddings: (N, D) tensor of embedded examples.
        labels: (N,) tensor of class indices in [0, n_classes).
        n_classes: number of distinct classes.

    Returns:
        (n_classes, D) tensor; row k is the mean embedding of class k.
    """
    class_means = [embeddings[labels == k].mean(dim=0) for k in range(n_classes)]
    return torch.stack(class_means)

def euclidean_distance(x, y):
    """Pairwise L2 distances: (N, D) x (M, D) -> (N, M)."""
    # torch.cdist defaults to the L2 norm (p=2).
    return torch.cdist(x, y)

# Test
# Smoke-test the encoder: embed a batch of 5 fake MNIST-sized images.
model = ConvEmbedding().to(device)
x_test = torch.randn(5, 1, 28, 28).to(device)
emb = model(x_test)
print(f"Embedding shape: {emb.shape}")  # expected: torch.Size([5, 64])

Episodic TrainingΒΆ

Prototypical networks are trained episodically: each training iteration samples a random \(N\)-way, \(K\)-shot task from the training classes. The support set (K examples per class) is used to compute class prototypes via averaging in embedding space, and the query set is classified by nearest-prototype distance. The training loss is the negative log-probability of the correct class under the softmax over distances. This episodic protocol mimics the few-shot evaluation setting during training, ensuring the learned embedding space is well-suited for prototype-based classification at test time.

def sample_episode(data, labels, n_way, k_shot, k_query):
    """Sample an N-way K-shot episode (support + query sets).

    Args:
        data: tensor of examples, indexable along dim 0.
        labels: 1-D array-like of class labels; values need NOT be a
            contiguous 0-based range.
        n_way: number of classes in the episode.
        k_shot: support examples per class.
        k_query: query examples per class.

    Returns:
        (support_x, support_y, query_x, query_y). Labels are remapped to
        episode-local indices 0..n_way-1 in sampling order.
    """
    # Fix: sample the actual label VALUES rather than indices into
    # range(num_unique); the original assumed labels were 0..C-1, which
    # breaks for non-contiguous or non-zero-based label sets.
    classes = np.random.choice(np.unique(labels), n_way, replace=False)

    support_x, support_y = [], []
    query_x, query_y = [], []

    for i, c in enumerate(classes):
        # Indices of all examples belonging to class c.
        idx = np.where(labels == c)[0]
        # Draw disjoint support and query examples for this class.
        samples = np.random.choice(idx, k_shot + k_query, replace=False)

        support_x.append(data[samples[:k_shot]])
        support_y.extend([i] * k_shot)

        query_x.append(data[samples[k_shot:]])
        query_y.extend([i] * k_query)

    support_x = torch.cat(support_x)
    support_y = torch.tensor(support_y)
    query_x = torch.cat(query_x)
    query_y = torch.tensor(query_y)

    return support_x, support_y, query_x, query_y

# Load MNIST
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)

# Convert to tensors
# Keep the full training set in memory: (60000, 1, 28, 28) floats in [0, 1].
train_data = mnist.data.unsqueeze(1).float() / 255.0
train_labels = mnist.targets.numpy()

print(f"Dataset: {train_data.shape}")
def train_prototypical(model, data, labels, n_episodes=1000, n_way=5, k_shot=5, k_query=15):
    """Train a prototypical network with episodic sampling.

    Args:
        model: embedding network mapping inputs to feature vectors.
        data: training examples, indexable along dim 0.
        labels: 1-D array of class labels for `data`.
        n_episodes: number of episodes (one optimizer step each).
        n_way: classes per episode.
        k_shot: support examples per class.
        k_query: query examples per class.

    Returns:
        List of per-episode loss values.
    """
    # Fix: explicitly enter train mode so BatchNorm uses batch statistics,
    # even if the model was previously switched to eval() elsewhere.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    losses = []
    
    for episode in range(n_episodes):
        # Sample a fresh N-way K-shot task for this episode.
        support_x, support_y, query_x, query_y = sample_episode(
            data, labels, n_way, k_shot, k_query
        )
        
        support_x = support_x.to(device)
        support_y = support_y.to(device)
        query_x = query_x.to(device)
        query_y = query_y.to(device)
        
        # Embed support and query sets with the shared encoder.
        support_emb = model(support_x)
        query_emb = model(query_x)
        
        # Class prototypes: mean support embedding per class.
        prototypes = compute_prototypes(support_emb, support_y, n_way)
        
        # Distances from each query embedding to each prototype.
        dists = euclidean_distance(query_emb, prototypes)
        
        # Softmax over negative distances gives class log-probabilities;
        # NLL of the true class is the prototypical-network loss.
        log_probs = F.log_softmax(-dists, dim=1)
        loss = F.nll_loss(log_probs, query_y)
        
        # Standard optimizer step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
        if (episode + 1) % 200 == 0:
            # Query accuracy on the current episode (rough progress signal).
            acc = (log_probs.argmax(dim=1) == query_y).float().mean()
            print(f"Episode {episode+1}, Loss: {loss.item():.4f}, Acc: {acc:.3f}")
    
    return losses

# Train
model = ConvEmbedding().to(device)
losses = train_prototypical(model, train_data, train_labels, n_episodes=1000, 
                           n_way=5, k_shot=5, k_query=15)

# Plot per-episode loss plus a 50-episode moving average for readability.
plt.figure(figsize=(10, 5))
plt.plot(losses, alpha=0.5)
plt.plot(np.convolve(losses, np.ones(50)/50, mode='valid'), linewidth=2, label='Moving Avg')
plt.xlabel('Episode', fontsize=11)
plt.ylabel('Loss', fontsize=11)
plt.title('Prototypical Network Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

EvaluationΒΆ

Evaluating a prototypical network follows the same episodic protocol as training: sample many random \(N\)-way, \(K\)-shot tasks from held-out test classes (classes never seen during training), compute prototypes from the support set, and measure classification accuracy on the query set. Reporting mean accuracy and 95% confidence intervals over many episodes gives a reliable estimate of few-shot performance. The ability to classify novel classes with only a few examples per class – without any gradient updates at test time – is what makes metric-based meta-learning appealing for real-world deployment.

def evaluate_prototypical(model, data, labels, n_episodes=100, n_way=5, k_shot=5, k_query=15):
    """Measure few-shot query accuracy over many randomly sampled episodes.

    Returns a numpy array with one accuracy value per episode.
    """
    model.eval()
    episode_accs = []
    
    with torch.no_grad():
        for _ in range(n_episodes):
            episode = sample_episode(data, labels, n_way, k_shot, k_query)
            s_x, s_y, q_x, q_y = (t.to(device) for t in episode)
            
            # Prototypes from the support set, then nearest-prototype
            # classification of the query set.
            protos = compute_prototypes(model(s_x), s_y, n_way)
            dists = euclidean_distance(model(q_x), protos)
            preds = dists.argmin(dim=1)
            
            episode_accs.append((preds == q_y).float().mean().item())
    
    return np.array(episode_accs)

# Load test data
mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_data = mnist_test.data.unsqueeze(1).float() / 255.0
test_labels = mnist_test.targets.numpy()

# Evaluate
# NOTE(review): MNIST test digits are the SAME classes as training, so this
# measures within-domain episodic accuracy, not true novel-class few-shot.
accs = evaluate_prototypical(model, test_data, test_labels, n_episodes=200)

print(f"5-way 5-shot accuracy: {accs.mean():.3f} Β± {accs.std():.3f}")

# Histogram of per-episode accuracies with the mean marked.
plt.figure(figsize=(10, 5))
plt.hist(accs, bins=20, edgecolor='black', alpha=0.7)
plt.axvline(accs.mean(), color='r', linestyle='--', label=f'Mean: {accs.mean():.3f}')
plt.xlabel('Accuracy', fontsize=11)
plt.ylabel('Count', fontsize=11)
plt.title('Test Episode Accuracies', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

VisualizationΒΆ

Visualizing the embedding space using t-SNE or UMAP reveals whether the learned representations form tight, well-separated clusters for different classes. For a well-trained prototypical network, same-class embeddings should cluster tightly around their prototype, with clear separation between classes. Plotting the prototypes (class means) alongside individual embeddings provides intuition for why nearest-prototype classification works: if the clusters are compact and distant, the decision boundaries are clean and few-shot accuracy will be high.

# Sample episode for visualization
support_x, support_y, query_x, query_y = sample_episode(
    test_data, test_labels, n_way=5, k_shot=5, k_query=10
)

model.eval()
with torch.no_grad():
    support_x = support_x.to(device)
    support_y = support_y.to(device)
    query_x = query_x.to(device)
    query_y = query_y.to(device)
    
    support_emb = model(support_x).cpu().numpy()
    query_emb = model(query_x).cpu().numpy()
    
    # Prototypes are computed on CPU from the already-embedded support set.
    prototypes = compute_prototypes(
        torch.from_numpy(support_emb), support_y.cpu(), 5
    ).numpy()

# t-SNE for 2D visualization (the code uses scikit-learn's TSNE, not UMAP)
from sklearn.manifold import TSNE

# Embed support, prototypes and queries jointly so they share one 2-D map.
all_emb = np.vstack([support_emb, prototypes, query_emb])
tsne = TSNE(n_components=2, random_state=42)
emb_2d = tsne.fit_transform(all_emb)

n_support = len(support_emb)
n_proto = len(prototypes)

# Split the joint 2-D embedding back into its three groups.
support_2d = emb_2d[:n_support]
proto_2d = emb_2d[n_support:n_support+n_proto]
query_2d = emb_2d[n_support+n_proto:]

plt.figure(figsize=(12, 10))

colors = plt.cm.tab10(range(5))

# Support examples: filled circles, one color per episode class.
for k in range(5):
    mask = support_y.cpu().numpy() == k
    plt.scatter(support_2d[mask, 0], support_2d[mask, 1], 
               c=[colors[k]], marker='o', s=100, alpha=0.6, label=f'Class {k} Support')

# Prototypes: large stars with black outlines.
for k in range(5):
    plt.scatter(proto_2d[k, 0], proto_2d[k, 1], 
               c=[colors[k]], marker='*', s=500, edgecolors='black', linewidths=2)

# Query examples: crosses in the matching class color.
for k in range(5):
    mask = query_y.cpu().numpy() == k
    plt.scatter(query_2d[mask, 0], query_2d[mask, 1], 
               c=[colors[k]], marker='x', s=150, alpha=0.8)

plt.xlabel('t-SNE 1', fontsize=11)
plt.ylabel('t-SNE 2', fontsize=11)
plt.title('Prototypical Network Embedding Space', fontsize=12)
plt.legend(ncol=2, fontsize=9)
plt.grid(True, alpha=0.3)
plt.show()

Different MetricsΒΆ

The choice of distance metric in embedding space significantly affects prototypical network performance. The original paper uses squared Euclidean distance, but cosine distance and Mahalanobis distance are common alternatives. Euclidean distance assumes isotropic clusters of similar scale; cosine distance normalizes for magnitude and focuses on direction; Mahalanobis distance accounts for per-dimension variance. Comparing metrics on the same task reveals how the embedding geometry interacts with the distance function, helping practitioners select the best combination for their domain.

def cosine_distance(x, y):
    """Pairwise cosine distance between rows of x and rows of y.

    Returns an (N, M) tensor with entries 1 - cos(x_i, y_j).
    F.normalize guards against zero-norm rows via its internal eps.
    """
    similarity = F.normalize(x, dim=1) @ F.normalize(y, dim=1).T
    return 1 - similarity

# Compare metrics
metrics = {
    'Euclidean': euclidean_distance,
    'Cosine': cosine_distance
}

results = {}

# Re-evaluate the SAME trained model under each metric, so the comparison
# isolates the distance function (the embedding network is fixed).
for name, metric_fn in metrics.items():
    accs = []
    model.eval()
    
    # Pure inference over 100 random test episodes; no gradients needed.
    with torch.no_grad():
        for _ in range(100):
            support_x, support_y, query_x, query_y = sample_episode(
                test_data, test_labels, 5, 5, 15
            )
            
            support_x = support_x.to(device)
            support_y = support_y.to(device)
            query_x = query_x.to(device)
            query_y = query_y.to(device)
            
            support_emb = model(support_x)
            query_emb = model(query_x)
            prototypes = compute_prototypes(support_emb, support_y, 5)
            
            # Nearest prototype under the current metric (argmax of -dist).
            dists = metric_fn(query_emb, prototypes)
            preds = (-dists).argmax(dim=1)
            acc = (preds == query_y).float().mean().item()
            accs.append(acc)
    
    results[name] = accs

# Plot
# Box plot of per-episode accuracies, one box per distance metric.
fig, ax = plt.subplots(figsize=(10, 6))
positions = np.arange(len(metrics))
bp = ax.boxplot([results[name] for name in metrics.keys()], 
                 labels=metrics.keys(), patch_artist=True)

for patch in bp['boxes']:
    patch.set_facecolor('lightblue')

ax.set_ylabel('Accuracy', fontsize=11)
ax.set_title('Distance Metrics Comparison', fontsize=12)
ax.grid(True, alpha=0.3, axis='y')
plt.show()

for name, accs in results.items():
    print(f"{name}: {np.mean(accs):.3f} Β± {np.std(accs):.3f}")

SummaryΒΆ

Prototypical Networks:ΒΆ

Learn embedding where classification is nearest centroid.

Algorithm:ΒΆ

  1. Embed support examples

  2. Compute class prototypes (means)

  3. Classify by distance to prototypes

Advantages:ΒΆ

  • Simple and efficient

  • Interpretable (class centers)

  • Works with any metric

  • Scales to many classes

Applications:ΒΆ

  • Few-shot image classification

  • Zero-shot learning

  • Domain adaptation

  • Cold-start recommendation

Extensions:ΒΆ

  • Relation Networks (learned metric)

  • Matching Networks (attention)

  • TADAM (task conditioning)

Next Steps:ΒΆ

  • 16_maml_meta_learning.ipynb - Optimization-based

  • Study transductive inference

  • Explore semi-supervised episodes

Advanced Prototypical Networks TheoryΒΆ

1. Few-Shot Learning Problem FormulationΒΆ

1.1 Mathematical SetupΒΆ

Few-shot classification: Learn to classify examples with very few labeled samples per class.

N-way K-shot learning:

  • N classes

  • K examples per class (support set)

  • M query examples per class

Support set: S = {(x₁, y₁), …, (x_{NK}, y_{NK})} where |{i : yᵢ = c}| = K for each class c

Query set: Q = {(x̃₁, ỹ₁), …, (x̃_M, ỹ_M)}

Goal: Classify query examples using only K support examples per class

1.2 Episodic TrainingΒΆ

Meta-learning paradigm: Train on many tasks, generalize to new tasks.

Episode structure:

  1. Sample N classes from training set

  2. Sample K examples per class (support)

  3. Sample M examples per class (query)

  4. Train to classify query given support

  5. Repeat for many episodes

Key insight: Training mimics test-time scenario (few examples per class)

2. Prototypical NetworksΒΆ

2.1 Core IdeaΒΆ

Prototype: Representative embedding for each class, computed as mean of support embeddings.

Prototype computation:

c_k = (1/K) Σ_{(xᵢ,yᵢ)∈Sₖ} f_θ(xᵢ)

where:

  • c_k: prototype for class k

  • S_k: support examples for class k

  • f_ΞΈ: embedding function (neural network)

Classification rule:

p(y = k | x) = softmax(-d(f_θ(x), c_k))
            = exp(-d(f_θ(x), c_k)) / Σⱼ exp(-d(f_θ(x), cⱼ))

where d(·,·) is a distance metric (typically Euclidean).

2.2 Distance MetricsΒΆ

Euclidean distance (most common):

d(x, y) = ||x - y||₂ = √(Σᵢ (xᵢ - yᵢ)²)

Squared Euclidean:

d(x, y) = ||x - y||β‚‚Β²

Mathematically equivalent (monotonic transformation) but simpler gradients.

Cosine distance:

d(x, y) = 1 - (xΒ·y) / (||x|| ||y||)

Useful when magnitude doesn’t matter.

Mahalanobis distance:

d(x, y) = √((x-y)α΅€ Σ⁻¹ (x-y))

Accounts for correlations between dimensions.

2.3 Loss FunctionΒΆ

Negative log-likelihood:

L(ΞΈ) = -log p(y = y_true | x)
     = -log [exp(-d(f_ΞΈ(x), c_y_true)) / Ξ£β±Ό exp(-d(f_ΞΈ(x), cβ±Ό))]
     = d(f_ΞΈ(x), c_y_true) + log Ξ£β±Ό exp(-d(f_ΞΈ(x), cβ±Ό))

Intuition:

  • Minimize distance to correct prototype

  • Maximize distance to incorrect prototypes

  • Softmax provides probabilistic interpretation

2.4 Gradient AnalysisΒΆ

Gradient w.r.t. embedding:

βˆ‚L/βˆ‚f_ΞΈ(x) = βˆ‚d(f_ΞΈ(x), c_y_true)/βˆ‚f_ΞΈ(x) - Ξ£β±Ό p(y=j|x) βˆ‚d(f_ΞΈ(x), cβ±Ό)/βˆ‚f_ΞΈ(x)

For Euclidean distance:

βˆ‚d/βˆ‚f_ΞΈ(x) = 2(f_ΞΈ(x) - c_k)

Gradient interpretation:

  • Push embedding towards correct prototype

  • Pull away from incorrect prototypes (weighted by probability)

3. Theoretical FoundationsΒΆ

3.1 Connection to k-Nearest NeighborsΒΆ

Prototypes as cluster centers: Prototypical Networks can be seen as soft k-NN with k=K.

k-NN decision:

ŷ = argmax_k Σ_{(xᡒ,yᡒ)∈S, yᡒ=k} δ(x, xᡒ)

where Ξ΄ is indicator for x_i among k-nearest neighbors.

Prototypical Networks smoothed version:

Ε· = argmax_k p(y=k|x) where prototypes = mean(k-NN)

3.2 Mixture Density EstimationΒΆ

Generative view: Each class modeled as Gaussian distribution.

Class-conditional distribution:

p(x | y=k) = N(x; c_k, σ²I)

Posterior via Bayes rule:

p(y=k | x) = p(x|y=k)p(y=k) / Ξ£β±Ό p(x|y=j)p(y=j)

With uniform prior p(y=k) = 1/N:

p(y=k | x) ∝ exp(-||x - c_k||Β²/(2σ²))

Prototypical Networks: Equivalent to maximum likelihood estimation of mixture model!

3.3 Bregman DivergencesΒΆ

Generalization of distance: Prototypical Networks work with any Bregman divergence.

Bregman divergence:

D_Ο†(x, y) = Ο†(x) - Ο†(y) - βŸ¨βˆ‡Ο†(y), x-y⟩

for convex function Ο†.

Examples:

  • Ο†(x) = ||x||Β²/2 β†’ Squared Euclidean

  • Ο†(x) = Ξ£α΅’ xα΅’ log xα΅’ β†’ KL divergence

  • Ο†(x) = -log det(x) β†’ Log-determinant divergence

Theorem (Snell et al., 2017): Prototypical Networks with Bregman divergence D_Ο† correspond to exponential family distributions with sufficient statistic βˆ‡Ο†(x).

4. Embedding Function DesignΒΆ

4.1 Common ArchitecturesΒΆ

Convolutional Networks (images):

  • 4-layer CNN (Snell et al., 2017 baseline)

  • ResNet-12 (modern standard)

  • WideResNet-28-10 (state-of-the-art)

Architecture pattern:

Conv blocks β†’ Global pooling β†’ Embedding

Output dimension: Typically 64-1600 dimensions

Recurrent Networks (sequences):

  • LSTM/GRU for text

  • Bidirectional encoding

  • Attention mechanisms

Transformers (general):

  • Self-attention for long-range dependencies

  • Pre-trained models (BERT, ViT) as backbone

4.2 Design PrinciplesΒΆ

1. High-dimensional embeddings:

  • Typical: 64-1600 dimensions

  • Higher dimensions β†’ better separation

  • But: overfitting risk with very few samples

2. Normalized embeddings:

  • L2 normalization: x ← x/||x||

  • Converts to cosine similarity metric

  • Prevents magnitude dominating distance

3. Batch normalization:

  • Stabilizes training

  • But: tricky with small support sets

  • Alternative: Layer normalization, Group normalization

4. Pooling strategies:

  • Global average pooling (most common)

  • Max pooling

  • Attention pooling (weighted average)

4.3 Pre-training StrategiesΒΆ

Transfer learning:

  1. Pre-train on large dataset (ImageNet, etc.)

  2. Fine-tune with episodic training

  3. Helps with limited meta-training data

Self-supervised pre-training:

  • Contrastive learning (SimCLR, MoCo)

  • Rotation prediction

  • Jigsaw puzzles

  • Improves embedding quality without labels

5. Advanced VariantsΒΆ

5.1 Gaussian Prototypical NetworksΒΆ

Motivation: Model class variance, not just mean.

Per-class covariance:

c_k = (1/K) Ξ£ f_ΞΈ(xα΅’)         (mean)
Ξ£_k = (1/K) Ξ£ (f_ΞΈ(xα΅’) - c_k)(f_ΞΈ(xα΅’) - c_k)α΅€  (covariance)

Mahalanobis distance:

d(x, k) = (f_ΞΈ(x) - c_k)α΅€ Ξ£_k⁻¹ (f_ΞΈ(x) - c_k)

Challenge: Singular covariance with K < embedding_dim.

Solutions:

  • Diagonal covariance only

  • Regularization: Ξ£_k ← Ξ£_k + Ξ»I

  • Shared covariance across classes

5.2 Semi-Prototypical NetworksΒΆ

Use unlabeled data in support set:

Soft prototypes:

c_k = (1/Z_k) [Σ_{labeled} f_θ(xᡒ) + Σ_{unlabeled} p(y=k|x̃ⱼ) f_θ(x̃ⱼ)]

Iterative refinement:

  1. Compute initial prototypes from labeled data

  2. Pseudo-label unlabeled data

  3. Update prototypes with pseudo-labels

  4. Repeat

Benefit: Better prototypes with limited labels + abundant unlabeled data.

5.3 Transductive Prototypical NetworksΒΆ

Use query set to refine prototypes:

Standard (inductive): Prototypes only from support set.

Transductive: Prototypes from support + query.

Algorithm:

  1. Initialize prototypes from support

  2. Pseudo-label query examples

  3. Update prototypes using query embeddings

  4. Re-classify query

  5. Iterate until convergence

Advantage: Query examples provide more data for better prototypes.

Disadvantage: Not applicable when query arrives one-by-one.

5.4 Task-Adaptive Prototypical NetworksΒΆ

Adapt prototypes per task:

Learnable scaling:

d(x, c_k) = ||W_task(f_ΞΈ(x) - c_k)||Β²

where W_task is task-specific diagonal matrix.

Feature selection:

α_task = softmax(g_ψ(c_1, ..., c_N))  (attention over dimensions)
d(x, c_k) = ||Ξ±_task βŠ™ (f_ΞΈ(x) - c_k)||Β²

Benefit: Different tasks may need different feature dimensions.

6. Training TechniquesΒΆ

6.1 Episodic Sampling StrategiesΒΆ

Uniform sampling:

  • Sample N classes uniformly from training set

  • Simple, balanced

Class-balanced sampling:

  • Oversample rare classes

  • Ensures all classes seen equally

Curriculum learning:

  • Start with easy tasks (more shots, fewer ways)

  • Gradually increase difficulty

  • Faster convergence

Hard task mining:

  • Identify difficult class combinations

  • Oversample hard episodes

  • Improves worst-case performance

6.2 Data AugmentationΒΆ

Standard augmentations:

  • Random crops

  • Horizontal flips

  • Color jitter

  • Rotation

Few-shot specific:

  • Mixing: Average support examples within class

  • Hallucination: Generate synthetic examples

  • Adversarial: Small perturbations for robustness

Augmentation timing:

  • Support set: Usually yes (increases effective K)

  • Query set: Typically no (want clean evaluation)

6.3 OptimizationΒΆ

Learning rate scheduling:

  • Warmup for first few thousand episodes

  • Cosine annealing or step decay

  • Lower learning rate than standard supervised

Gradient clipping:

||βˆ‡ΞΈ|| > threshold β‡’ βˆ‡ΞΈ ← threshold Β· βˆ‡ΞΈ/||βˆ‡ΞΈ||

Prevents instability from variable episode difficulty.

Optimizer choice:

  • Adam (most common, adaptive)

  • SGD with momentum (better generalization)

  • RAdam (combines benefits)

Batch size (episodes per update):

  • Typical: 1-4 episodes

  • Larger batch β†’ more stable but slower

  • Smaller batch β†’ faster iteration

7. Evaluation ProtocolsΒΆ

7.1 Standard BenchmarksΒΆ

Omniglot:

  • 1623 handwritten characters

  • 20 examples per class

  • Task: 5-way 1-shot, 5-way 5-shot, 20-way 1-shot

miniImageNet:

  • 100 classes, 600 examples each

  • 64 train, 16 validation, 20 test classes

  • Task: 5-way 1-shot, 5-way 5-shot

tieredImageNet:

  • 608 classes from ImageNet

  • Hierarchical split (avoids train/test similarity)

  • Task: 5-way 1-shot, 5-way 5-shot

CIFAR-FS:

  • 100 CIFAR-100 classes

  • 64/16/20 train/val/test split

  • Task: 5-way 1-shot, 5-way 5-shot

7.2 Evaluation MetricsΒΆ

Accuracy:

Acc = (1/MΒ·N) Ξ£ 1[Ε·α΅’ = yα΅’]

Averaged over query examples and episodes.

95% confidence intervals:

CI = mean ± 1.96 · std/√(num_episodes)

Typical: 600-10,000 test episodes.

Per-class accuracy: Identify which classes are hard (for analysis).

Calibration: Are predicted probabilities p(y=k|x) well-calibrated?

Expected Calibration Error (ECE):

ECE = Ξ£_b (|B_b|/N) |acc(B_b) - conf(B_b)|

where B_b are prediction bins.

7.3 Cross-Domain EvaluationΒΆ

Test on different domain:

  • Train: miniImageNet

  • Test: CUB-200 (birds), Cars, etc.

Measures: Transferability of learned embedding.

Finding: Prototypical Networks transfer better than metric learning methods (more generalizable distance function).

8. Comparison with Other MethodsΒΆ

8.1 Matching NetworksΒΆ

Attention-based:

p(y=k|x) = Σ_{(xᡒ,yᡒ)∈S, yᡒ=k} a(x, xᡒ) where a = softmax(cos(f(x), f(xᡒ)))

Differences from Prototypical:

  • Attends to individual examples, not prototypes

  • More expressive but more parameters

  • Prototypical often performs similarly with simpler approach

8.2 Relation NetworksΒΆ

Learned metric:

d(x, x') = g_Ο†(concat(f_ΞΈ(x), f_ΞΈ(x')))

where g_Ο† is a learned network.

Prototypical: Fixed distance (Euclidean). Relation: Learned distance (more flexible).

Tradeoff:

  • Relation: Better performance with enough meta-training data

  • Prototypical: More robust with limited meta-training

8.3 MAML (Model-Agnostic Meta-Learning)ΒΆ

Optimization-based:

  • Inner loop: Adapt parameters on support

  • Outer loop: Update initialization

Prototypical: Non-parametric (no inner optimization).

Comparison:

Method           | Inner Loop | Outer Loop      | Speed
─────────────────|────────────|─────────────────|──────
Prototypical     | None       | Update f_ΞΈ      | Fast
MAML             | K gradient | Update ΞΈ_init   | Slow
Relation Net     | None       | Update f_ΞΈ, g_Ο† | Medium

When Prototypical wins: Limited compute, need fast inference. When MAML wins: Complex adaptation required, enough compute.

9. State-of-the-Art ResultsΒΆ

9.1 Benchmark PerformanceΒΆ

miniImageNet (5-way accuracy):

Method                  | 1-shot    | 5-shot
───────────────────────|───────────|──────────
Prototypical (2017)    | 49.4%     | 68.2%
Prototypical + deeper  | 56.5%     | 73.7%
FEAT (2019)            | 55.2%     | 71.5%
MetaOptNet (2019)      | 62.6%     | 78.6%
Meta-Baseline (2020)   | 63.2%     | 79.3%

tieredImageNet (5-way):

Method                  | 1-shot    | 5-shot
───────────────────────|───────────|──────────
Prototypical (2017)    | 53.3%     | 72.7%
Meta-Baseline (2020)   | 68.1%     | 83.7%
FRN (2021)             | 66.5%     | 82.8%

9.2 Key Improvements Over BaselineΒΆ

Better backbones: ResNet-12 β†’ +7-10% accuracy. Pre-training: Self-supervised β†’ +5-8% accuracy. Data augmentation: MixUp, CutMix β†’ +2-4%. Transductive inference: Query refinement β†’ +3-5%.

10. Limitations and ChallengesΒΆ

10.1 Known IssuesΒΆ

1. Domain shift:

  • Training and test classes must be similar

  • Poor cross-domain generalization without adaptation

2. Prototype quality with K=1:

  • Single example may not represent class well

  • Sensitive to outliers and noise

3. Class imbalance:

  • Assumes balanced support sets

  • Real-world often has varying shots per class

4. Computational cost of embeddings:

  • Must embed all support examples at test time

  • Slow for large support sets

5. Fixed distance metric:

  • Euclidean may not be optimal for all tasks

  • Learned metrics can be better but more complex

10.2 Failure ModesΒΆ

Out-of-distribution query: Query very different from support β†’ assigns to nearest prototype (may be wrong).

Solution: Calibration, outlier detection.

Fine-grained classification: When classes are very similar, prototypes overlap.

Solution: Higher-dimensional embeddings, metric learning.

Multi-modal classes: Class has multiple clusters in embedding space.

Solution: Mixture models, multiple prototypes per class.

11. Extensions and ApplicationsΒΆ

11.1 Few-Shot DetectionΒΆ

Object detection with few examples:

  • Extract region proposals

  • Compute prototypes from support boxes

  • Match query regions to prototypes

Challenges: Background class, varying scales.

11.2 Few-Shot SegmentationΒΆ

Semantic segmentation:

  • Prototype per pixel class

  • Support: Segmented images

  • Query: Segment using prototypes

Approach:

For each pixel p in query:
  Embedding z_p = f_ΞΈ(local_patch(p))
  Class = argmin_k d(z_p, c_k)

11.3 Cross-Modal Few-ShotΒΆ

Match across modalities:

  • Image-text: CLIP-style prototypes

  • Audio-visual: Cross-modal embeddings

Joint embedding space:

  • f_image, f_text map to same space

  • Prototypes can be text or image

11.4 Continual Few-Shot LearningΒΆ

Stream of few-shot tasks:

  • Learn task 1, then task 2, etc.

  • Avoid catastrophic forgetting

Approach:

  • Store prototypes for old classes

  • Replay or regularization for stability

12. Implementation Best PracticesΒΆ

12.1 HyperparametersΒΆ

Embedding dimension:

  • Small datasets (Omniglot): 64-128

  • ImageNet-scale: 512-1600

  • Higher β†’ better separation, but overfitting risk

Learning rate:

  • Typical: 1e-3 to 1e-4

  • Warmup for 1k-5k episodes

  • Decay by 0.1 every 20-40k episodes

Number of episodes:

  • Training: 40k-100k episodes

  • Validation: 600 episodes

  • Test: 600-10,000 episodes

Ways and shots:

  • Train on variable N-way, K-shot (e.g., 5-20 way, 1-5 shot)

  • Test on fixed task (e.g., 5-way 1-shot)

12.2 Debugging TipsΒΆ

Check prototype separation:

# Visualize prototypes (e.g., with t-SNE)
# Should see N clusters for N-way task

Embedding norms:

  • Should be similar across classes

  • If not, consider normalization

Sanity checks:

  • 5-way 5-shot should be easier than 1-shot

  • 5-way should be easier than 20-way

  • Training accuracy should exceed test

Common bugs:

  • Leaking test classes into training

  • Not shuffling support/query within episode

  • Incorrect prototype computation (check averaging)

12.3 Computational EfficiencyΒΆ

Pre-compute support embeddings:

# Embed support once per episode, reuse for all queries
support_emb = f_theta(support_images)  # Compute once
prototypes = support_emb.mean(dim=1)   # Average per class

for query in queries:
    query_emb = f_theta(query)
    distances = compute_distance(query_emb, prototypes)

Batched distance computation:

# Compute all query-prototype distances in one operation
# (Q, D) vs (N, D) β†’ (Q, N) distances
distances = torch.cdist(query_emb, prototypes)

Mixed precision:

  • Use FP16 for embedding network

  • Keep distances in FP32 for numerical stability

13. Recent Advances (2020-2024)ΒΆ

13.1 Meta-BaselineΒΆ

Simplified approach:

  • Pre-train with standard cross-entropy

  • Fine-tune last layer only on support

  • Often matches complex meta-learning methods

Insight: Good features + simple adaptation can beat complex meta-learning.

13.2 Distribution CalibrationΒΆ

Problem: Prototypes from few examples have high variance.

Solution: Calibrate using base class statistics.

c_k_calibrated = Ξ±Β·c_k + (1-Ξ±)Β·ΞΌ_base

where ΞΌ_base is mean prototype from base classes.

13.3 Self-Supervision + Few-ShotΒΆ

SimCLR, MoCo for pre-training:

  • Learn general features without labels

  • Then few-shot learning on top

Results: +5-10% accuracy on standard benchmarks.

13.4 Optimal Transport for MatchingΒΆ

Wasserstein distance: Match support and query distributions optimally.

Better than prototypes when:

  • Multi-modal classes

  • Varying shot numbers per class

14. Future DirectionsΒΆ

1. Cross-domain few-shot:

  • Better transfer across very different domains

  • Meta-learning with domain adaptation

2. Task-agnostic meta-learning:

  • Single model for any N-way K-shot combination

  • Currently re-train for each configuration

3. Theoretical understanding:

  • Sample complexity bounds

  • Generalization guarantees for few-shot

4. Efficient architectures:

  • Neural architecture search for few-shot

  • Lightweight models for edge deployment

5. Multi-task few-shot:

  • Learn multiple tasks simultaneously with few examples each

  • Leveraging cross-task structure

15. Key TakeawaysΒΆ

  1. Simplicity wins: Prototypical Networks’ simple mean-based approach often matches complex methods.

  2. Embedding quality is key: Good feature representation (from pre-training or architecture) matters more than meta-learning algorithm.

  3. Episodic training crucial: Training must mimic test scenario (few examples per class).

  4. Trade-offs everywhere:

    • Complexity vs robustness (Prototypical simpler, Relation more flexible)

    • Speed vs performance (Prototypical fast, MAML slow but expressive)

  5. Pre-training helps: Transfer learning or self-supervision improves few-shot significantly.

  6. Distance metric matters: Euclidean works well, but learned metrics can be better for specific domains.

  7. Not just for images: Prototypical Networks work for any domain with meaningful embeddings (text, audio, graphs).

16. ReferencesΒΆ

Foundational:

  • Snell et al. (2017): β€œPrototypical Networks for Few-shot Learning” (NeurIPS)

  • Vinyals et al. (2016): β€œMatching Networks for One Shot Learning”

  • Sung et al. (2018): β€œLearning to Compare: Relation Network for Few-Shot Learning”

Theoretical:

  • Banerjee et al. (2005): β€œClustering with Bregman Divergences”

  • Allen et al. (2019): β€œInfinite Mixture Prototypes for Few-Shot Learning”

Recent advances:

  • Chen et al. (2020): β€œA Closer Look at Few-shot Classification” (Meta-Baseline)

  • Yang et al. (2021): β€œFree Lunch for Few-Shot Learning: Distribution Calibration”

  • Rizve et al. (2021): β€œExploring Complementary Strengths of Invariant and Equivariant Representations”

Applications:

  • Hu et al. (2019): β€œFew-Shot Object Detection”

  • Shaban et al. (2017): β€œOne-Shot Learning for Semantic Segmentation”

# Advanced Prototypical Networks Implementations

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List

# ============================================================================
# 1. Distance Metrics
# ============================================================================

class DistanceMetric:
    """Distance functions used to compare embeddings with prototypes."""

    @staticmethod
    def euclidean(x, y):
        """Euclidean distance: ||x - y||_2"""
        return (x - y).norm(p=2, dim=-1)

    @staticmethod
    def squared_euclidean(x, y):
        """Squared Euclidean: ||x - y||_2^2"""
        return ((x - y) ** 2).sum(dim=-1)

    @staticmethod
    def cosine(x, y):
        """Cosine distance: 1 - (xΒ·y)/(||x|| ||y||)"""
        xn = F.normalize(x, p=2, dim=-1)
        yn = F.normalize(y, p=2, dim=-1)
        return 1 - (xn * yn).sum(dim=-1)

    @staticmethod
    def pairwise_distances(x, y, distance='euclidean'):
        """
        Compute pairwise distances between all pairs.

        Args:
            x: (N, D) - N points in D dimensions
            y: (M, D) - M points in D dimensions
            distance: 'euclidean', 'squared_euclidean', or 'cosine'

        Returns:
            distances: (N, M) - distance from each x to each y
        """
        if distance in ('euclidean', 'squared_euclidean'):
            # Expand ||x - y||^2 = ||x||^2 - 2xΒ·y + ||y||^2 via a single matmul.
            sq_x = x.pow(2).sum(dim=1, keepdim=True)  # (N, 1)
            sq_y = y.pow(2).sum(dim=1, keepdim=True)  # (M, 1)
            sq_dist = sq_x + sq_y.T - 2 * (x @ y.T)
            if distance == 'squared_euclidean':
                return sq_dist
            # Clamp guards against tiny negative values from rounding.
            return torch.clamp(sq_dist, min=1e-12).sqrt()

        if distance == 'cosine':
            return 1 - F.normalize(x, p=2, dim=1) @ F.normalize(y, p=2, dim=1).T

        raise ValueError(f"Unknown distance: {distance}")


# ============================================================================
# 2. Embedding Networks
# ============================================================================

class ConvBlock(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU, with an optional 2x2 max-pool."""

    def __init__(self, in_channels, out_channels, pool=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Pooling halves the spatial resolution when enabled.
        self.pool = nn.MaxPool2d(2) if pool else None

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        return out if self.pool is None else self.pool(out)


class ConvEmbedding(nn.Module):
    """
    4-layer CNN for image embedding (Snell et al., 2017 baseline).
    """

    def __init__(self, in_channels=3, hidden_dim=64, embedding_dim=64):
        super().__init__()
        stages = [
            ConvBlock(in_channels, hidden_dim, pool=True),    # 64x64 -> 32x32
            ConvBlock(hidden_dim, hidden_dim, pool=True),     # 32x32 -> 16x16
            ConvBlock(hidden_dim, hidden_dim, pool=True),     # 16x16 -> 8x8
            ConvBlock(hidden_dim, embedding_dim, pool=True),  # 8x8 -> 4x4
        ]
        self.encoder = nn.Sequential(*stages)
        # Global average pool makes the output size independent of input H/W.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """
        Args:
            x: (batch, channels, H, W)
        Returns:
            embeddings: (batch, embedding_dim)
        """
        features = self.pool(self.encoder(x))
        return features.view(features.size(0), -1)


class ResNetBlock(nn.Module):
    """Basic two-conv residual block with a projection shortcut when needed."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the shape changes (stride or channel count),
        # in which case a 1x1 projection matches dimensions.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)


class ResNet12Embedding(nn.Module):
    """
    ResNet-12 for few-shot learning (modern standard).
    Much better than 4-layer CNN.
    """

    def __init__(self, in_channels=3, embedding_dim=640):
        super().__init__()
        # Four stages of three residual blocks; the first three downsample
        # with max-pool, the last collapses to 1x1 with global average pool.
        self.layer1 = self._stage(in_channels, 64, nn.MaxPool2d(2))
        self.layer2 = self._stage(64, 128, nn.MaxPool2d(2))
        self.layer3 = self._stage(128, 256, nn.MaxPool2d(2))
        self.layer4 = self._stage(256, 512, nn.AdaptiveAvgPool2d(1))
        self.fc = nn.Linear(512, embedding_dim)

    @staticmethod
    def _stage(in_channels, out_channels, tail):
        """Three residual blocks followed by a pooling layer."""
        return nn.Sequential(
            ResNetBlock(in_channels, out_channels),
            ResNetBlock(out_channels, out_channels),
            ResNetBlock(out_channels, out_channels),
            tail,
        )

    def forward(self, x):
        """Map (batch, C, H, W) images to (batch, embedding_dim) vectors."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.fc(x.view(x.size(0), -1))


# ============================================================================
# 3. Prototypical Networks
# ============================================================================

class PrototypicalNetwork(nn.Module):
    """
    Prototypical Networks for few-shot classification.

    Each class prototype is the mean of its support embeddings; queries are
    scored by (negative) distance to every prototype.
    """

    def __init__(self, embedding_net, distance='squared_euclidean'):
        super().__init__()
        self.embedding_net = embedding_net
        self.distance = distance

    def compute_prototypes(self, support_embeddings, support_labels, n_way):
        """
        Average the support embeddings of each class.

        Args:
            support_embeddings: (n_support, embedding_dim)
            support_labels: (n_support,) - class labels 0 to n_way-1
            n_way: number of classes

        Returns:
            prototypes: (n_way, embedding_dim)
        """
        class_means = [
            support_embeddings[support_labels == k].mean(dim=0)
            for k in range(n_way)
        ]
        return torch.stack(class_means, dim=0)

    def forward(self, support_images, support_labels, query_images, n_way, n_shot):
        """
        Args:
            support_images: (n_way * n_shot, C, H, W)
            support_labels: (n_way * n_shot,)
            query_images: (n_query, C, H, W)
            n_way: number of classes
            n_shot: number of examples per class

        Returns:
            logits: (n_query, n_way) - negative distances (log probabilities)
        """
        support_embeddings = self.embedding_net(support_images)
        query_embeddings = self.embedding_net(query_images)

        prototypes = self.compute_prototypes(support_embeddings, support_labels, n_way)

        # Negating distances turns "closest prototype" into "largest logit".
        return -DistanceMetric.pairwise_distances(
            query_embeddings, prototypes, distance=self.distance
        )

    def loss(self, logits, query_labels):
        """
        Cross-entropy loss over the distance-based logits.

        Args:
            logits: (n_query, n_way)
            query_labels: (n_query,)

        Returns:
            loss: scalar
        """
        return F.cross_entropy(logits, query_labels)


# ============================================================================
# 4. Gaussian Prototypical Networks
# ============================================================================

class GaussianPrototypicalNetwork(nn.Module):
    """
    Prototypical Networks with a Gaussian model per class.
    Classifies by Mahalanobis distance under a per-class covariance.
    """

    def __init__(self, embedding_net, regularization=1e-4):
        super().__init__()
        self.embedding_net = embedding_net
        # Ridge term added to every covariance so it is always invertible.
        self.regularization = regularization

    def compute_gaussian_prototypes(self, support_embeddings, support_labels, n_way):
        """
        Estimate mean and (regularized) covariance for each class.

        Returns:
            means: (n_way, D)
            covariances: (n_way, D, D)
        """
        D = support_embeddings.size(1)
        device = support_embeddings.device
        eye = torch.eye(D, device=device)

        mean_list, cov_list = [], []
        for k in range(n_way):
            class_embeddings = support_embeddings[support_labels == k]  # (K, D)

            mu = class_embeddings.mean(dim=0)
            centered = class_embeddings - mu
            # Biased (divide-by-K) sample covariance.
            cov = (centered.T @ centered) / class_embeddings.size(0)

            mean_list.append(mu)
            cov_list.append(cov + self.regularization * eye)

        return torch.stack(mean_list), torch.stack(cov_list)

    def mahalanobis_distance(self, x, means, covariances):
        """
        Squared Mahalanobis distance from each point to each Gaussian.

        Args:
            x: (N, D)
            means: (K, D)
            covariances: (K, D, D)

        Returns:
            distances: (N, K)
        """
        columns = []
        for mu, cov in zip(means, covariances):
            # (x - ΞΌ)^T Ξ£^{-1} (x - ΞΌ), vectorized over the N points.
            diff = x - mu
            columns.append((diff @ torch.inverse(cov) * diff).sum(dim=1))
        return torch.stack(columns, dim=1)

    def forward(self, support_images, support_labels, query_images, n_way, n_shot):
        """Score queries by negative Mahalanobis distance to each class."""
        support_embeddings = self.embedding_net(support_images)
        query_embeddings = self.embedding_net(query_images)

        means, covariances = self.compute_gaussian_prototypes(
            support_embeddings, support_labels, n_way
        )

        return -self.mahalanobis_distance(query_embeddings, means, covariances)


# ============================================================================
# 5. Transductive Prototypical Networks
# ============================================================================

class TransductivePrototypicalNetwork(nn.Module):
    """
    Transductive inference: the (unlabeled) query set refines the prototypes.
    Prototypes are repeatedly re-estimated from soft query assignments.
    """

    def __init__(self, embedding_net, n_iterations=3, distance='squared_euclidean'):
        super().__init__()
        self.embedding_net = embedding_net
        self.n_iterations = n_iterations
        self.distance = distance

    def forward(self, support_images, support_labels, query_images, n_way, n_shot):
        support_embeddings = self.embedding_net(support_images)
        query_embeddings = self.embedding_net(query_images)

        # Per-class support means; these never change during refinement.
        support_means = torch.stack([
            support_embeddings[support_labels == k].mean(dim=0)
            for k in range(n_way)
        ])
        prototypes = support_means

        for _ in range(self.n_iterations):
            # Soft-assign queries to the current prototypes.
            scores = -DistanceMetric.pairwise_distances(
                query_embeddings, prototypes, distance=self.distance
            )
            probs = F.softmax(scores, dim=1)  # (n_query, n_way)

            # Probability-weighted query means, one row per class.
            weighted_sum = probs.T @ query_embeddings          # (n_way, D)
            weight_total = probs.sum(dim=0).unsqueeze(1) + 1e-8  # (n_way, 1)
            query_means = weighted_sum / weight_total

            # Blend support and query evidence with equal weight.
            prototypes = 0.5 * support_means + 0.5 * query_means

        final_distances = DistanceMetric.pairwise_distances(
            query_embeddings, prototypes, distance=self.distance
        )
        return -final_distances


# ============================================================================
# 6. Episode Sampler
# ============================================================================

class EpisodeSampler:
    """
    Sample N-way K-shot episodes for meta-learning.
    """

    def __init__(self, data, labels, n_way, n_shot, n_query):
        """
        Args:
            data: All training images
            labels: All training labels
            n_way: Number of classes per episode
            n_shot: Number of support examples per class
            n_query: Number of query examples per class
        """
        self.data = data
        self.labels = labels
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query

        # Precompute an index list per class for fast sampling.
        self.classes = np.unique(labels)
        self.class_to_indices = {
            c: np.where(labels == c)[0] for c in self.classes
        }

    def sample_episode(self):
        """
        Sample one episode.

        Returns:
            support_images, support_labels, query_images, query_labels
        """
        # Pick N distinct classes for this episode.
        chosen = np.random.choice(self.classes, self.n_way, replace=False)

        support_chunks, query_chunks = [], []
        support_label_chunks, query_label_chunks = [], []

        for new_label, c in enumerate(chosen):
            # Draw K + M distinct examples from this class, then split.
            picks = np.random.choice(
                self.class_to_indices[c], self.n_shot + self.n_query, replace=False
            )
            support_chunks.append(self.data[picks[:self.n_shot]])
            query_chunks.append(self.data[picks[self.n_shot:]])

            # Episode-local relabeling: classes become 0, 1, ..., N-1.
            support_label_chunks.append(np.full(self.n_shot, new_label))
            query_label_chunks.append(np.full(self.n_query, new_label))

        return (
            torch.from_numpy(np.concatenate(support_chunks, axis=0)).float(),
            torch.from_numpy(np.concatenate(support_label_chunks)).long(),
            torch.from_numpy(np.concatenate(query_chunks, axis=0)).float(),
            torch.from_numpy(np.concatenate(query_label_chunks)).long(),
        )


# ============================================================================
# 7. Training Loop
# ============================================================================

class PrototypicalTrainer:
    """
    Trainer for Prototypical Networks with episodic training.
    The model must expose forward(support, labels, query, n_way, n_shot)
    and loss(logits, labels).
    """

    def __init__(self, model, optimizer, device='cpu'):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device

    def _move(self, *tensors):
        """Transfer every tensor to the trainer's device."""
        return tuple(t.to(self.device) for t in tensors)

    def _metrics(self, logits, query_labels, loss):
        """Package scalar loss/accuracy for logging."""
        accuracy = (logits.argmax(dim=1) == query_labels).float().mean()
        return {
            'loss': loss.item(),
            'accuracy': accuracy.item()
        }

    def train_episode(self, support_images, support_labels,
                      query_images, query_labels, n_way, n_shot):
        """Run one optimization step on a single episode."""
        self.model.train()

        support_images, support_labels, query_images, query_labels = self._move(
            support_images, support_labels, query_images, query_labels
        )

        logits = self.model(support_images, support_labels,
                            query_images, n_way, n_shot)
        loss = self.model.loss(logits, query_labels)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients to stabilize episodic training.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return self._metrics(logits, query_labels, loss)

    def evaluate_episode(self, support_images, support_labels,
                         query_images, query_labels, n_way, n_shot):
        """Evaluate on one episode without updating any weights."""
        self.model.eval()

        with torch.no_grad():
            support_images, support_labels, query_images, query_labels = self._move(
                support_images, support_labels, query_images, query_labels
            )
            logits = self.model(support_images, support_labels,
                                query_images, n_way, n_shot)
            loss = self.model.loss(logits, query_labels)

        return self._metrics(logits, query_labels, loss)


# ============================================================================
# Demonstrations
# ============================================================================

# ----------------------------------------------------------------------
# Walkthrough script: exercises each component defined above and prints
# a short tutorial-style summary. Uses random tensors, so numeric output
# varies run to run; shapes and structure are what matter here.
# ----------------------------------------------------------------------
print("=" * 70)
print("Prototypical Networks - Advanced Implementations")
print("=" * 70)

# 1. Distance metrics comparison: same point sets under all three metrics.
print("\n1. Distance Metrics:")
x = torch.randn(5, 10)  # 5 points in 10D
y = torch.randn(3, 10)  # 3 points in 10D

euclidean = DistanceMetric.pairwise_distances(x, y, 'euclidean')
squared = DistanceMetric.pairwise_distances(x, y, 'squared_euclidean')
cosine = DistanceMetric.pairwise_distances(x, y, 'cosine')

print(f"   Points: {x.shape} vs {y.shape}")
print(f"   Euclidean distances: {euclidean.shape}")
print(f"   Squared Euclidean: {squared.shape}")
print(f"   Cosine distances: {cosine.shape}")
print(f"   Property: squared = euclideanΒ²: {torch.allclose(squared, euclidean**2)}")

# 2. Embedding networks: Conv-4 baseline vs the larger ResNet-12.
print("\n2. Embedding Networks:")
conv_net = ConvEmbedding(in_channels=3, embedding_dim=64)
resnet = ResNet12Embedding(in_channels=3, embedding_dim=640)

x_img = torch.randn(4, 3, 84, 84)  # miniImageNet-sized inputs
emb_conv = conv_net(x_img)
emb_resnet = resnet(x_img)

print(f"   Input images: {x_img.shape}")
print(f"   Conv embedding: {emb_conv.shape} (64D)")
print(f"   ResNet-12 embedding: {emb_resnet.shape} (640D)")
print(f"   ")
print(f"   Conv-4 parameters: {sum(p.numel() for p in conv_net.parameters()):,}")
print(f"   ResNet-12 parameters: {sum(p.numel() for p in resnet.parameters()):,}")
print(f"   ResNet-12 ~30Γ— more parameters β†’ better performance")

# 3. Prototypical network on a synthetic 5-way 5-shot episode.
print("\n3. Standard Prototypical Network:")
proto_net = PrototypicalNetwork(conv_net, distance='squared_euclidean')

# Simulate 5-way 5-shot episode
n_way, n_shot, n_query = 5, 5, 10
support_images = torch.randn(n_way * n_shot, 3, 84, 84)
support_labels = torch.arange(n_way).repeat_interleave(n_shot)  # [0,0,...,4,4]
query_images = torch.randn(n_way * n_query, 3, 84, 84)

# Logits are negative distances: the largest logit is the closest prototype.
logits = proto_net(support_images, support_labels, query_images, n_way, n_shot)

print(f"   Task: {n_way}-way {n_shot}-shot")
print(f"   Support: {support_images.shape} ({n_way} classes Γ— {n_shot} shots)")
print(f"   Query: {query_images.shape} ({n_way * n_query} examples)")
print(f"   Logits: {logits.shape} (queries Γ— classes)")
print(f"   ")
print(f"   Prototypes: Mean of {n_shot} embeddings per class")
print(f"   Classification: Nearest prototype (min distance)")

# 4. Gaussian prototypical: same episode, Mahalanobis distance.
print("\n4. Gaussian Prototypical Network:")
gauss_net = GaussianPrototypicalNetwork(conv_net, regularization=1e-3)
logits_gauss = gauss_net(support_images, support_labels, query_images, n_way, n_shot)

print(f"   Enhancement: Models class covariance, not just mean")
print(f"   Distance: Mahalanobis dΒ²(x,ΞΌ) = (x-ΞΌ)ᡀΣ⁻¹(x-ΞΌ)")
print(f"   Regularization: Ξ£ ← Ξ£ + Ξ»I (ensures invertibility)")
print(f"   Logits: {logits_gauss.shape}")
print(f"   Benefit: Better when classes have different variances")

# 5. Transductive inference: prototypes refined with the query set.
print("\n5. Transductive Prototypical Network:")
trans_net = TransductivePrototypicalNetwork(conv_net, n_iterations=3)
logits_trans = trans_net(support_images, support_labels, query_images, n_way, n_shot)

print(f"   Refinement: Uses query set to improve prototypes")
print(f"   Iterations: 3 (pseudo-label β†’ update prototypes)")
print(f"   Algorithm:")
print(f"     1. Initial prototypes from support")
print(f"     2. Soft-assign queries to prototypes")
print(f"     3. Update prototypes with weighted queries")
print(f"     4. Repeat")
print(f"   Benefit: +3-5% accuracy on standard benchmarks")

# 6. Episode sampling from a dummy dataset (random images/labels).
print("\n6. Episode Sampler:")
# Dummy data
dummy_data = np.random.randn(1000, 3, 84, 84)
dummy_labels = np.random.randint(0, 100, 1000)

sampler = EpisodeSampler(dummy_data, dummy_labels, n_way=5, n_shot=5, n_query=15)
s_img, s_lbl, q_img, q_lbl = sampler.sample_episode()

print(f"   Dataset: {len(dummy_data)} images, {len(np.unique(dummy_labels))} classes")
print(f"   Episode: {n_way}-way {n_shot}-shot, {sampler.n_query} queries/class")
print(f"   Support: {s_img.shape}, labels {s_lbl.shape}")
print(f"   Query: {q_img.shape}, labels {q_lbl.shape}")
print(f"   Label mapping: Original classes β†’ [0, 1, 2, 3, 4]")

# ----------------------------------------------------------------------
# Static summary tables (no computation below this point): complexity,
# method comparison, usage guidance, and expected benchmark numbers.
# ----------------------------------------------------------------------
# 7. Complexity analysis
print("\n7. Computational Complexity:")
print("   Forward pass breakdown (5-way 5-shot, 15 queries/class):")
print("   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”")
print("   β”‚ Operation              β”‚ Cost         β”‚ Comment     β”‚")
print("   β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€")
print("   β”‚ Embed support (25)     β”‚ 25 Γ— f(x)    β”‚ Once        β”‚")
print("   β”‚ Embed queries (75)     β”‚ 75 Γ— f(x)    β”‚ Once        β”‚")
print("   β”‚ Compute prototypes (5) β”‚ O(25 Γ— D)    β”‚ Mean        β”‚")
print("   β”‚ Distances (75 Γ— 5)     β”‚ O(375 Γ— D)   β”‚ Pairwise    β”‚")
print("   β”‚ Softmax (75)           β”‚ O(375)       β”‚ Fast        β”‚")
print("   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")
print("   Dominant cost: Embedding network f(x)")
print("   Prototypes add negligible overhead!")

# 8. Method comparison
print("\n8. Comparison with Other Meta-Learning Methods:")
print("   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”")
print("   β”‚ Method           β”‚ Inner    β”‚ Complexity β”‚ Speed    β”‚ Perf    β”‚")
print("   β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€")
print("   β”‚ Prototypical     β”‚ None     β”‚ O(KΓ—D)     β”‚ Fast     β”‚ Good    β”‚")
print("   β”‚ Matching Nets    β”‚ None     β”‚ O(KΒ²Γ—D)    β”‚ Medium   β”‚ Good    β”‚")
print("   β”‚ Relation Net     β”‚ None     β”‚ O(KΓ—DΒ²)    β”‚ Medium   β”‚ Better  β”‚")
print("   β”‚ MAML             β”‚ Gradient β”‚ O(KΓ—BΓ—DΒ²)  β”‚ Slow     β”‚ Best    β”‚")
print("   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")
print("   K: shots, D: embedding dim, B: inner steps")
print("   ")
print("   Prototypical wins on: Speed, simplicity, robustness")
print("   MAML wins on: Flexibility, performance (with enough data)")

# 9. When to use guide
print("\n9. When to Use Prototypical Networks:")
print("   Use Prototypical Networks when:")
print("     βœ“ Few examples per class (K=1-10)")
print("     βœ“ Need fast inference (real-time systems)")
print("     βœ“ Limited meta-training data (<1000 classes)")
print("     βœ“ Interpretability important (prototypes = class centers)")
print("     βœ“ Classes well-separated in embedding space")
print("\n   Use MAML instead when:")
print("     βœ“ Have lots of meta-training data")
print("     βœ“ Complex task-specific adaptation needed")
print("     βœ“ Compute not a bottleneck")
print("\n   Use Relation Networks when:")
print("     βœ“ Need learned similarity metric")
print("     βœ“ Fixed distance (Euclidean) insufficient")

# 10. Performance expectations
print("\n10. Expected Performance (5-way accuracy):")
print("   miniImageNet benchmark:")
print("   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”")
print("   β”‚ Method          β”‚ 1-shot   β”‚ 5-shot   β”‚")
print("   β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€")
print("   β”‚ Proto (Conv-4)  β”‚ 49.4%    β”‚ 68.2%    β”‚")
print("   β”‚ Proto (ResNet)  β”‚ 56.5%    β”‚ 73.7%    β”‚")
print("   β”‚ + Pre-training  β”‚ 60-63%   β”‚ 75-79%   β”‚")
print("   β”‚ + Transductive  β”‚ +3-5%    β”‚ +3-5%    β”‚")
print("   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")
print("   ")
print("   Key insights:")
print("     β€’ Better embedding >> better meta-algorithm")
print("     β€’ ResNet-12 adds ~7% over Conv-4")
print("     β€’ Pre-training adds ~5-10%")
print("     β€’ 5-shot >> 1-shot (~20% absolute gain)")

print("\n" + "=" * 70)