import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# Select GPU when available; all models/tensors below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Meta-Learning ProblemΒΆ
GoalΒΆ
Learn from distribution of tasks \(p(\mathcal{T})\).
Each task \(\mathcal{T}_i\) has:
Support set \(D_i^{\text{train}}\) (K-shot)
Query set \(D_i^{\text{test}}\)
MAML ObjectiveΒΆ
where \(\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)\) (inner loop)
2. MAML AlgorithmΒΆ
Inner Loop (Task Adaptation)ΒΆ
For task \(\mathcal{T}_i\):
Outer Loop (Meta-Update)ΒΆ
Key: Second-order gradients through inner loop!
class SineTaskDistribution:
    """Distribution over sinusoid regression tasks y = A * sin(x + b)."""

    def sample_task(self):
        """Draw one task: amplitude A ~ U(0.1, 5.0), phase b ~ U(0, pi)."""
        amplitude = np.random.uniform(0.1, 5.0)
        shift = np.random.uniform(0, np.pi)
        return amplitude, shift

    def sample_data(self, task, K=10):
        """Draw K (x, y) pairs from `task`.

        Returns two float32 tensors of shape [K, 1] (inputs, targets).
        """
        amplitude, shift = task
        inputs = np.random.uniform(-5, 5, K)
        targets = amplitude * np.sin(inputs + shift)
        x = torch.tensor(inputs, dtype=torch.float32).unsqueeze(1)
        y = torch.tensor(targets, dtype=torch.float32).unsqueeze(1)
        return x, y
# Test
# Smoke-test the task distribution: one sampled task, one 10-shot data set.
task_dist = SineTaskDistribution()
task = task_dist.sample_task()
x, y = task_dist.sample_data(task, K=10)
print(f"Task: amp={task[0]:.2f}, phase={task[1]:.2f}")
print(f"Data: x.shape={x.shape}, y.shape={y.shape}")
class MAMLModel(nn.Module):
    """Two-hidden-layer ReLU MLP (default 1 -> 40 -> 40 -> 1) used by MAML."""

    def __init__(self, input_dim=1, hidden_dim=40, output_dim=1):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map inputs [B, input_dim] to predictions [B, output_dim]."""
        return self.net(x)

    def clone_params(self):
        """Return graph-connected clones of all parameters, in order."""
        return [param.clone() for param in self.parameters()]

    def set_params(self, params):
        """Overwrite parameter storage from a list of tensors, in order.

        NOTE(review): `p.data = p_new.data` rebinds each parameter to the new
        tensor's storage and bypasses autograd — fine for evaluation, but it
        does not preserve any gradient graph carried by `params`.
        """
        for own, replacement in zip(self.parameters(), params):
            own.data = replacement.data
# Instantiate the meta-model and report its trainable parameter count.
model = MAMLModel().to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters())}")
def maml_inner_loop(model, x_support, y_support, alpha, steps=1):
    """Adapt `model` to one task with `steps` gradient updates (MAML inner loop).

    Args:
        model: any nn.Module; it is evaluated functionally and never mutated.
        x_support, y_support: support-set tensors.
        alpha: inner-loop learning rate.
        steps: number of inner gradient steps.

    Returns:
        List of adapted parameter tensors (same order as model.parameters()).
        They stay graph-connected to the meta-parameters, so a later query
        loss can backpropagate second-order gradients into the model.

    Bug fixed: the previous version copied the candidate parameters into the
    model via `set_params` (a `.data` copy) and then ran the forward pass
    through the model's own leaf parameters. The loss graph never touched
    `params`, so `torch.autograd.grad(loss, params)` raised a "not used in
    the graph" error. The forward pass now uses the candidate parameters
    directly via `torch.func.functional_call`.
    """
    param_names = [name for name, _ in model.named_parameters()]
    # clone() keeps autograd edges back to the original parameters.
    params = [p.clone() for p in model.parameters()]
    for _ in range(steps):
        # Functional forward: run the model with the candidate parameters.
        pred = torch.func.functional_call(
            model, dict(zip(param_names, params)), (x_support,)
        )
        loss = F.mse_loss(pred, y_support)
        # create_graph=True makes the update itself differentiable
        # (required for second-order MAML).
        grads = torch.autograd.grad(loss, params, create_graph=True)
        params = [p - alpha * g for p, g in zip(params, grads)]
    return params
def maml_train_step(model, task_dist, n_tasks=4, K_support=10, K_query=10,
                    alpha=0.01, beta=0.001, inner_steps=1, meta_optimizer=None):
    """Single MAML meta-update over a batch of sampled tasks.

    Args:
        model: meta-model whose initialization is being learned.
        task_dist: object providing sample_task() and sample_data(task, K=...).
        n_tasks: number of tasks per meta-batch.
        K_support, K_query: support / query set sizes per task.
        alpha: inner-loop learning rate.
        beta: outer-loop (meta) learning rate.
        inner_steps: inner-loop gradient steps per task.
        meta_optimizer: optional persistent optimizer over model.parameters().
            If None, a fresh Adam is created each call (backward compatible,
            but its moment buffers then reset every step — prefer passing a
            single optimizer in from the training loop).

    Returns:
        Average query loss across the task batch (float).

    Bug fixed: the query set used to be evaluated after copying the adapted
    parameters into the model via `set_params` (a `.data` copy), which
    severed the autograd graph — the meta-gradient never flowed through the
    inner loop. The query forward pass is now evaluated functionally with
    the adapted parameters so second-order MAML gradients reach `model`.
    """
    if meta_optimizer is None:
        meta_optimizer = torch.optim.Adam(model.parameters(), lr=beta)
    param_names = [name for name, _ in model.named_parameters()]
    meta_loss = 0.0
    for _ in range(n_tasks):
        # Sample a task together with its support and query sets.
        task = task_dist.sample_task()
        x_support, y_support = task_dist.sample_data(task, K=K_support)
        x_query, y_query = task_dist.sample_data(task, K=K_query)
        x_support, y_support = x_support.to(device), y_support.to(device)
        x_query, y_query = x_query.to(device), y_query.to(device)
        # Inner loop: task adaptation (returns graph-connected parameters).
        adapted_params = maml_inner_loop(model, x_support, y_support, alpha, inner_steps)
        # Evaluate the adapted parameters on the query set functionally, so
        # the loss stays differentiable w.r.t. the meta-parameters.
        pred_query = torch.func.functional_call(
            model, dict(zip(param_names, adapted_params)), (x_query,)
        )
        meta_loss = meta_loss + F.mse_loss(pred_query, y_query)
    meta_loss = meta_loss / n_tasks
    # Meta-update: backprop through the inner loop (second-order) and step.
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()
    return meta_loss.item()
TrainingΒΆ
MAML training (the meta-training or outer loop) proceeds over many randomly sampled tasks. For each task, the model takes \(k\) gradient steps on the task's support set to produce task-specific parameters, then evaluates on the task's query set. The meta-gradient is computed by differentiating through the inner optimization — this requires second-order gradients (gradients of gradients), which is computationally expensive but essential for MAML's effectiveness. First-order approximations like FOMAML drop the second derivatives for efficiency with only a modest performance loss. The outer loop optimizer (typically Adam) updates the initialization parameters to minimize the average query-set loss across tasks.
# Train MAML
# Meta-train for 200 outer iterations; each iteration samples 4 sinusoid
# tasks, adapts with one inner gradient step per task, and meta-updates
# from the averaged query losses.
model = MAMLModel().to(device)
task_dist = SineTaskDistribution()
losses = []
for iteration in range(200):
    loss = maml_train_step(model, task_dist, n_tasks=4, K_support=10, K_query=10,
                           alpha=0.01, beta=0.001, inner_steps=1)
    losses.append(loss)
    if (iteration + 1) % 50 == 0:
        print(f"Iter {iteration+1}, Loss: {loss:.4f}")
# Plot the meta-training loss curve.
plt.figure(figsize=(8, 5))
plt.plot(losses)
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('Meta Loss', fontsize=11)
plt.title('MAML Training', fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()
Few-Shot AdaptationΒΆ
The payoff of MAML is at meta-test time: given a new, unseen task with only a handful of labeled examples (the support set), we take a few gradient steps from the learned initialization and evaluate on the query set. Because MAML has learned an initialization that is maximally sensitive to task-relevant features, even 1-5 gradient steps are enough to achieve strong performance on the new task. This is fundamentally different from training from scratch or fine-tuning a pre-trained model — MAML explicitly optimizes for fast adaptation, making it especially powerful for applications where labeled data is scarce, such as drug discovery, robotics, and personalized recommendation.
# Test on new task
# Meta-test: sample an unseen sinusoid task and adapt from 5 support points.
test_task = task_dist.sample_task()
x_support, y_support = task_dist.sample_data(test_task, K=5)
x_support, y_support = x_support.to(device), y_support.to(device)
# Test points
# Dense grid over the input range for plotting the learned function.
x_test = torch.linspace(-5, 5, 100).unsqueeze(1).to(device)
y_true = test_task[0] * torch.sin(x_test + test_task[1])
# Before adaptation
# Predictions of the raw meta-initialization (0-shot).
model.eval()
with torch.no_grad():
    y_before = model(x_test)
# After adaptation (5-shot)
# Five inner gradient steps on the support set, then predict again.
adapted_params = maml_inner_loop(model, x_support, y_support, alpha=0.01, steps=5)
model.set_params(adapted_params)
with torch.no_grad():
    y_after = model(x_test)
# Plot
plt.figure(figsize=(10, 6))
plt.plot(x_test.cpu(), y_true.cpu(), 'k-', label='True', linewidth=2)
plt.plot(x_test.cpu(), y_before.cpu(), 'b--', label='Before (0-shot)', alpha=0.7)
plt.plot(x_test.cpu(), y_after.cpu(), 'r-', label='After (5-shot)', linewidth=2)
plt.scatter(x_support.cpu(), y_support.cpu(), s=100, c='red', marker='x', label='Support', zorder=5)
plt.xlabel('x', fontsize=11)
plt.ylabel('y', fontsize=11)
plt.title('MAML Few-Shot Adaptation', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
SummaryΒΆ
MAML:ΒΆ
Inner loop: Task adaptation via gradient descent
Outer loop: Meta-learning across tasks
Algorithm:ΒΆ
Sample tasks from \(p(\mathcal{T})\)
Adapt: \(\theta' = \theta - \alpha \nabla \mathcal{L}_{\text{train}}\)
Meta-update: \(\theta \leftarrow \theta - \beta \nabla \mathcal{L}_{\text{test}}(\theta')\)
Key Features:ΒΆ
Model-agnostic (works with any gradient-based model)
Few-shot adaptation
Second-order gradients
Applications:ΒΆ
Few-shot classification
Rapid RL adaptation
Personalization
Drug discovery
Variants:ΒΆ
Reptile: First-order approximation
ANIL: Almost no inner loop
Meta-SGD: Learn inner LR
Next Steps:ΒΆ
Explore Prototypical Networks
Apply to image classification
Study task distribution design
Advanced Meta-Learning TheoryΒΆ
1. MAML Bi-Level Optimization FrameworkΒΆ
Mathematical FormulationΒΆ
Meta-learning aims to learn a good initialization \(\theta\) that can quickly adapt to new tasks. The bi-level optimization problem is:
where the inner loop (task adaptation) computes:
and the outer loop (meta-update) performs:
Second-Order GradientsΒΆ
The key challenge: \(\theta_i'\) depends on \(\theta\), so computing \(\nabla_\theta \mathcal{L}^{\text{test}}(f_{\theta_i'})\) requires the chain rule through the inner loop:
For one inner step:
where \(H\) is the Hessian matrix of the loss w.r.t. \(\theta\). This is computationally expensive!
First-Order MAML (FOMAML)ΒΆ
Approximate by ignoring second-order terms:
i.e., treat \(\theta_i'\) as independent of \(\theta\) for gradient computation. FOMAML is much faster with minimal performance loss in practice.
2. Prototypical NetworksΒΆ
Metric Learning FrameworkΒΆ
Instead of learning task-specific parameters, learn an embedding function \(f_\phi: \mathcal{X} \to \mathbb{R}^d\) that maps inputs to a metric space where classification uses distance to class prototypes.
Class PrototypesΒΆ
For N-way K-shot classification, compute prototype for class \(k\) as:
where \(S_k\) is the support set for class \(k\) (K examples).
Classification RuleΒΆ
Probability that query \(x\) belongs to class \(k\):
where \(d(\cdot, \cdot)\) is a distance metric (typically Euclidean distance or cosine similarity).
Training ObjectiveΒΆ
Minimize negative log-likelihood:
Episode-based training: Each training episode samples N classes, K support examples per class, and Q query examples.
Why It WorksΒΆ
Inductive bias: Similar examples should have similar embeddings
Non-parametric: No task-specific parameters (prototypes computed from data)
Fast adaptation: Single forward pass, no gradient steps
Scalable: Works with large N (many classes)
3. Matching NetworksΒΆ
Attention-Based MatchingΒΆ
Matching Networks classify by comparing query to all support examples using attention:
where \(a(x, x_i)\) is the attention weight from query \(x\) to support example \(x_i\):
with \(c(\cdot, \cdot)\) being cosine similarity.
Full Context EmbeddingsΒΆ
Unlike Prototypical Networks, embeddings use full context of the support set:
Support encoding \(g\): Bidirectional LSTM over support set
Query encoding \(f\): LSTM with read attention to support set
This allows embeddings to be task-dependent.
One-Shot Learning FormulationΒΆ
For one-shot (K=1), Matching Networks directly weight support labels:
This is a differentiable nearest neighbors approach.
4. Few-Shot Learning TheoryΒΆ
N-Way K-Shot ProblemΒΆ
Problem setup:
N-way: Classify among N classes
K-shot: Only K labeled examples per class in support set
Task distribution: \(p(\mathcal{T})\) generates diverse tasks
Episode construction:
Sample N classes from dataset
Sample K examples per class → Support set \(S\) (N×K examples)
Sample Q examples per class → Query set \(Q\) (N×Q examples)
Meta-Training vs Meta-TestingΒΆ
Meta-training:
Tasks sampled from training classes \(\mathcal{C}_{\text{train}}\)
Learn \(\phi\) or \(\theta\) to perform well on query sets after adapting to support sets
Meta-testing:
Tasks sampled from disjoint test classes \(\mathcal{C}_{\text{test}}\)
Evaluate few-shot performance on unseen classes
Key: \(\mathcal{C}_{\text{train}} \cap \mathcal{C}_{\text{test}} = \emptyset\) (no class overlap)
Generalization in Meta-LearningΒΆ
Meta-learning generalizes across tasks rather than across examples:
where \(h_{\mathcal{T}}\) is the task-specific hypothesis (e.g., adapted parameters or prototypes).
Challenge: Need sufficient task diversity in \(p(\mathcal{T})\) to generalize to new tasks.
5. Comparison of Meta-Learning ApproachesΒΆ
Approach |
Adaptation |
Parameters |
Computation |
Strengths |
Weaknesses |
|---|---|---|---|---|---|
MAML |
Gradient descent (inner loop) |
Task-specific \(\theta'\) |
High (second-order gradients) |
Model-agnostic, flexible |
Slow adaptation, expensive |
FOMAML |
Gradient descent (1st order) |
Task-specific \(\theta'\) |
Medium |
Faster than MAML |
Slightly lower performance |
Prototypical |
Non-parametric (prototypes) |
None (prototypes from data) |
Low (single forward pass) |
Fast, simple, scalable to many classes |
Assumes Euclidean structure |
Matching |
Attention over support |
None (attention weights) |
Medium (bi-LSTM encoding) |
Full context, differentiable NN |
Complex encoding, slower |
Relation Net |
Learn comparison metric |
Task-specific relation module |
Medium |
Flexible metric |
Requires meta-training |
When to Use Each:ΒΆ
MAML/FOMAML: When you need flexible adaptation with gradient-based learning; good for RL and diverse tasks
Prototypical Networks: When classes have clear cluster structure; best for classification with many classes
Matching Networks: When task context is important; one-shot learning scenarios
Hybrid: Combine approaches (e.g., Prototypical MAML)
6. Advanced InsightsΒΆ
Sample ComplexityΒΆ
Few-shot learning aims to achieve high accuracy with limited data:
Traditional: \(O(VC \text{ dim} / \epsilon^2)\) samples needed for generalization
Meta-learning: Amortizes learning across tasks, reducing per-task sample complexity
Trade-off: More meta-training tasks → better few-shot performance per task
Task DiversityΒΆ
Performance depends critically on task distribution \(p(\mathcal{T})\):
High diversity: Better generalization to new tasks
Task relatedness: Meta-learning assumes tasks share structure
Theorem (informal): If \(\mathcal{T}_{\text{test}}\) is far from \(\mathcal{T}_{\text{train}}\) in task space, meta-learning provides no benefit over learning from scratch.
Computational ComplexityΒΆ
For N-way K-shot with embedding dimension \(d\):
Method |
Time per episode |
Memory |
|---|---|---|
MAML |
\(O(T \cdot N \cdot K \cdot |\theta|)\) |
\(O(T \cdot |\theta|)\) (stores the inner-loop graph) |
FOMAML |
\(O(T \cdot N \cdot K \cdot |\theta|)\) |
\(O(|\theta|)\) |
Prototypical |
\(O(N \cdot K \cdot d + N \cdot Q \cdot d)\) |
\(O(N \cdot d)\) (prototypes) |
Matching |
\(O(N \cdot K \cdot d^2)\) (bi-LSTM) |
\(O(N \cdot K \cdot d)\) |
\(T\) = inner loop steps, \(|\theta|\) = number of parameters, \(Q\) = query set size
7. Practical ConsiderationsΒΆ
Dataset DesignΒΆ
Common benchmarks:
Omniglot: 1623 characters, 20 examples each (character recognition)
Mini-ImageNet: 100 classes, 600 images each (5-way 1-shot / 5-shot)
Tiered-ImageNet: 608 classes, hierarchical structure
Episode sampling: Ensure balanced classes and sufficient query examples
Hyperparameter TuningΒΆ
MAML:
Inner learning rate \(\alpha\): 0.01 - 0.1 (task-specific)
Outer learning rate \(\beta\): 0.001 - 0.01 (meta-update)
Inner steps \(T\): 1-5 (more steps → better adaptation but slower)
Prototypical:
Embedding dimension \(d\): 64-512 (trade-off capacity vs. overfitting)
Distance metric: Euclidean (classification), cosine (semantic tasks)
OverfittingΒΆ
Symptoms:
High meta-training accuracy, low meta-test accuracy
Model memorizes support sets rather than learning to adapt
Solutions:
Increase task diversity (data augmentation, more classes)
Regularization (dropout, weight decay)
Early stopping on meta-validation set
PDF ReferenceΒΆ
📚 Related Papers:
Finn et al. (2017): "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks"
Snell et al. (2017): "Prototypical Networks for Few-shot Learning"
Vinyals et al. (2016): "Matching Networks for One Shot Learning"
Ravi & Larochelle (2017): "Optimization as a Model for Few-Shot Learning"
π GitHub:
# ============================================================================
# Advanced Meta-Learning Implementations
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import copy
# Set random seeds
# Fix both torch and numpy seeds so episode sampling and weight init repeat.
torch.manual_seed(42)
np.random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# ============================================================================
# 1. Complete MAML with Second-Order Gradients
# ============================================================================
class MAMLConvNet(nn.Module):
    """
    Standard 4-block ConvNet backbone for few-shot image classification
    (the architecture used in Omniglot / Mini-ImageNet experiments).

    Each block is conv3x3 -> batch norm -> ReLU -> 2x2 max-pool, halving the
    spatial resolution four times; a 28x28 input ends up 1x1 spatially, which
    is why the classifier is Linear(hidden_dim, num_classes).
    """

    def __init__(self, in_channels=1, num_classes=5, hidden_dim=64):
        super().__init__()
        # Build the four blocks programmatically; the OrderedDict keys
        # ('conv1', 'bn1', ...) match the original hand-written names so
        # state_dicts stay compatible.
        blocks = OrderedDict()
        channels = in_channels
        for i in range(1, 5):
            blocks[f'conv{i}'] = nn.Conv2d(channels, hidden_dim, 3, padding=1)
            blocks[f'bn{i}'] = nn.BatchNorm2d(hidden_dim)
            blocks[f'relu{i}'] = nn.ReLU(inplace=True)
            blocks[f'pool{i}'] = nn.MaxPool2d(2)
            channels = hidden_dim
        self.features = nn.Sequential(blocks)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map images [B, C, 28, 28] to class logits [B, num_classes]."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)  # Flatten
        return self.classifier(flat)
class MAML:
    """
    Model-Agnostic Meta-Learning (Finn et al., 2017).

    Key features:
    - Bi-level optimization (inner task adaptation + outer meta-update)
    - Second-order gradient computation through the inner loop
      (set first_order=True for the cheaper FOMAML approximation)
    - Support for multiple inner gradient steps
    - Episode-based meta-training
    """

    def __init__(self, model, inner_lr=0.01, meta_lr=0.001,
                 inner_steps=5, first_order=False):
        self.model = model
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.inner_steps = inner_steps
        self.first_order = first_order  # Use FOMAML if True
        # Meta-optimizer (for outer loop)
        self.meta_optimizer = torch.optim.Adam(self.model.parameters(), lr=meta_lr)

    def inner_loop(self, x_support, y_support, params=None):
        """
        Task adaptation via gradient descent (inner loop).

        Args:
            x_support: Support set inputs [K, ...]
            y_support: Support set labels [K]
            params: OrderedDict of parameters to start from
                (None uses the model's current named parameters)

        Returns:
            adapted_params: OrderedDict of parameters after adaptation; with
            first_order=False they remain graph-connected to the
            meta-parameters, so the outer loop can differentiate through them.
        """
        if params is None:
            params = OrderedDict(self.model.named_parameters())
        for step in range(self.inner_steps):
            # Forward pass with the current candidate parameters.
            logits = self._forward_with_params(x_support, params)
            loss = F.cross_entropy(logits, y_support)
            # create_graph=True keeps the update differentiable so the outer
            # loop can compute second-order gradients (disabled for FOMAML).
            grads = torch.autograd.grad(
                loss, params.values(),
                create_graph=not self.first_order
            )
            # Inner update: theta' = theta - alpha * grad
            params = OrderedDict(
                (name, param - self.inner_lr * grad)
                for ((name, param), grad) in zip(params.items(), grads)
            )
        return params

    def _forward_with_params(self, x, params):
        """
        Forward pass of the wrapped model using `params` in place of its own
        parameters.

        Bug fixed: the previous implementation flattened the input and only
        applied parameters whose names contained 'fc' — names that do not
        exist in MAMLConvNet (or most models) — so adaptation silently never
        changed the computation. `torch.func.functional_call` runs the real
        model with the candidate parameters while preserving the autograd
        graph, and works for any architecture (model-agnostic, as intended).
        """
        return torch.func.functional_call(self.model, dict(params), (x,))

    def outer_loop(self, tasks):
        """
        Meta-update across multiple tasks (outer loop).

        Args:
            tasks: List of (x_support, y_support, x_query, y_query) tuples

        Returns:
            meta_loss: Average query loss across tasks (float)
        """
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0
        for x_support, y_support, x_query, y_query in tasks:
            # Inner loop: adapt to support set
            adapted_params = self.inner_loop(x_support, y_support)
            # Evaluate on query set with adapted params
            logits_query = self._forward_with_params(x_query, adapted_params)
            loss_query = F.cross_entropy(logits_query, y_query)
            meta_loss += loss_query
        # Average over tasks
        meta_loss = meta_loss / len(tasks)
        # Outer update: theta <- theta - beta * grad of the query loss; this
        # differentiates through the inner loop when first_order=False.
        meta_loss.backward()
        self.meta_optimizer.step()
        return meta_loss.item()
# ============================================================================
# 2. Prototypical Networks
# ============================================================================
class PrototypicalNetwork(nn.Module):
    """
    Prototypical Networks (Snell et al., 2017) for few-shot classification.

    An embedding network f_phi maps images into a metric space; each class is
    summarized by the mean of its support embeddings (its prototype), and
    queries are classified by distance to those prototypes.
    """

    def __init__(self, in_channels=1, embedding_dim=64):
        super().__init__()
        # Four conv blocks: conv3x3 -> batch norm -> ReLU -> 2x2 max-pool.
        widths = [64, 64, 64, embedding_dim]
        modules = []
        prev_ch = in_channels
        for width in widths:
            modules.extend([
                nn.Conv2d(prev_ch, width, 3, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ])
            prev_ch = width
        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        """Embed a batch of images and flatten to [batch, features]."""
        feats = self.encoder(x)
        return feats.flatten(start_dim=1)

    def compute_prototypes(self, x_support, y_support, n_way):
        """
        Average the support embeddings of each class into a prototype:
        c_k = (1/|S_k|) * sum of f_phi(x_i) over (x_i, y_i) in S_k.

        Args:
            x_support: Support set inputs [N*K, C, H, W]
            y_support: Support set labels [N*K]
            n_way: Number of classes

        Returns:
            prototypes: Class prototypes [N, embedding_dim]
        """
        embedded = self(x_support)  # [N*K, embedding_dim]
        return torch.stack(
            [embedded[y_support == label].mean(dim=0) for label in range(n_way)]
        )  # [N, embedding_dim]

    def classify(self, x_query, prototypes, distance='euclidean'):
        """
        Score each query against every prototype:
        p(y = k | x) proportional to exp(-d(f(x), c_k)).

        Args:
            x_query: Query inputs [Q, C, H, W]
            prototypes: Class prototypes [N, embedding_dim]
            distance: 'euclidean' or 'cosine'

        Returns:
            logits: Class logits [Q, N]

        Raises:
            ValueError: for an unrecognized distance name.
        """
        embedded = self(x_query)  # [Q, embedding_dim]
        if distance == 'euclidean':
            # Logits are negative squared Euclidean distances.
            return -torch.cdist(embedded, prototypes, p=2).pow(2)
        if distance == 'cosine':
            # Cosine similarity between unit-normalized vectors.
            q_unit = F.normalize(embedded, p=2, dim=1)
            p_unit = F.normalize(prototypes, p=2, dim=1)
            return q_unit @ p_unit.t()  # [Q, N]
        raise ValueError(f"Unknown distance: {distance}")

    def loss(self, x_support, y_support, x_query, y_query, n_way, distance='euclidean'):
        """Episode loss: cross-entropy of query logits against query labels."""
        prototypes = self.compute_prototypes(x_support, y_support, n_way)
        logits = self.classify(x_query, prototypes, distance)
        return F.cross_entropy(logits, y_query)
# ============================================================================
# 3. Matching Networks (Simplified)
# ============================================================================
class MatchingNetwork(nn.Module):
    """
    Simplified Matching Network (Vinyals et al., 2016) for one-shot learning.

    Queries are classified by a softmax-attention-weighted vote over support
    labels. This version shares one conv embedding for queries and supports
    and omits the full-context bi-LSTM encodings of the original paper.
    """

    def __init__(self, in_channels=1, embedding_dim=64):
        super().__init__()
        # Three conv blocks: conv3x3 -> batch norm -> ReLU -> 2x2 max-pool.
        widths = [64, 64, embedding_dim]
        modules = []
        prev_ch = in_channels
        for width in widths:
            modules.extend([
                nn.Conv2d(prev_ch, width, 3, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ])
            prev_ch = width
        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        """Embed a batch of images and flatten to [batch, features]."""
        feats = self.encoder(x)
        return feats.flatten(start_dim=1)

    def attention(self, query_embedding, support_embeddings):
        """
        Softmax over cosine similarities between the query and each support
        embedding: a(x, x_i) = softmax_i(cosine(f(x), g(x_i))).

        Args:
            query_embedding: [embedding_dim]
            support_embeddings: [K, embedding_dim]

        Returns:
            attention_weights: [K], non-negative and summing to 1.
        """
        unit_query = F.normalize(query_embedding, p=2, dim=0)
        unit_support = F.normalize(support_embeddings, p=2, dim=1)
        similarities = unit_support @ unit_query  # [K]
        return F.softmax(similarities, dim=0)

    def predict(self, x_query, x_support, y_support, n_way):
        """
        Attention-weighted vote over support labels: y_hat = sum_i a(x, x_i) y_i.

        Args:
            x_query: Single query image [C, H, W]
            x_support: Support set inputs [N*K, C, H, W]
            y_support: Support set labels [N*K]
            n_way: Number of classes

        Returns:
            [n_way] class scores; a probability vector, since the attention
            weights sum to 1.
        """
        query_emb = self(x_query.unsqueeze(0)).squeeze(0)  # [embedding_dim]
        support_emb = self(x_support)  # [N*K, embedding_dim]
        weights = self.attention(query_emb, support_emb)  # [N*K]
        one_hot = F.one_hot(y_support, num_classes=n_way).float()  # [N*K, n_way]
        # Weighted sum of one-hot label vectors.
        return (weights.unsqueeze(1) * one_hot).sum(dim=0)  # [n_way]

    def loss(self, x_support, y_support, x_query, y_query, n_way):
        """
        Episode loss: cross-entropy over per-query predictions.

        NOTE(review): `predict` returns probabilities, yet they are passed to
        F.cross_entropy as if they were logits — kept as-is to preserve the
        original behavior.
        """
        scores = torch.stack([
            self.predict(query, x_support, y_support, n_way)
            for query in x_query
        ])  # [Q, n_way]
        return F.cross_entropy(scores, y_query)
# ============================================================================
# 4. Few-Shot Episode Sampler
# ============================================================================
class FewShotEpisode:
    """
    Sample N-way K-shot episodes for meta-learning.

    Each episode contains:
    - Support set: N classes x K examples
    - Query set: N classes x Q examples

    Classes are relabeled 0..N-1 within each episode.
    """

    def __init__(self, images, labels, n_way=5, k_shot=1, q_query=15):
        """
        Args:
            images: All available images [num_samples, C, H, W]
            labels: Corresponding labels [num_samples]
            n_way: Number of classes per episode
            k_shot: Number of support examples per class
            q_query: Number of query examples per class

        Raises:
            ValueError: if fewer than `n_way` distinct classes are available.
        """
        self.images = images
        self.labels = labels
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        # Group example indices by class for fast per-class sampling.
        self.classes = torch.unique(labels)
        self.class_to_indices = {
            c.item(): (labels == c).nonzero(as_tuple=True)[0]
            for c in self.classes
        }
        if len(self.classes) < n_way:
            raise ValueError(
                f"Need at least n_way={n_way} classes, found {len(self.classes)}."
            )

    def sample_episode(self):
        """
        Sample a single N-way K-shot episode.

        Returns:
            x_support: [N*K, C, H, W]
            y_support: [N*K] (relabeled 0 to N-1)
            x_query: [N*Q, C, H, W]
            y_query: [N*Q] (relabeled 0 to N-1)

        Raises:
            ValueError: if a sampled class has fewer than k_shot + q_query
                examples. (The previous version silently produced fewer
                images than labels in that case, misaligning images and
                labels within the episode.)
        """
        # Sample N distinct classes for this episode.
        episode_classes = np.random.choice(
            len(self.classes), self.n_way, replace=False
        )
        needed = self.k_shot + self.q_query
        support_images, support_labels = [], []
        query_images, query_labels = [], []
        for new_label, class_idx in enumerate(episode_classes):
            class_label = self.classes[class_idx]
            indices = self.class_to_indices[class_label.item()]
            if len(indices) < needed:
                raise ValueError(
                    f"Class {class_label.item()} has {len(indices)} examples, "
                    f"but k_shot + q_query = {needed} are required."
                )
            # Sample K+Q distinct examples, then split into support / query.
            selected = indices[torch.randperm(len(indices))[:needed]]
            support_indices = selected[:self.k_shot]
            query_indices = selected[self.k_shot:]
            support_images.append(self.images[support_indices])
            support_labels.extend([new_label] * self.k_shot)
            query_images.append(self.images[query_indices])
            query_labels.extend([new_label] * self.q_query)
        # Concatenate per-class chunks into episode tensors.
        x_support = torch.cat(support_images, dim=0)
        y_support = torch.tensor(support_labels, dtype=torch.long)
        x_query = torch.cat(query_images, dim=0)
        y_query = torch.tensor(query_labels, dtype=torch.long)
        return x_support, y_support, x_query, y_query
# ============================================================================
# 5. Visualization: Meta-Learning Comparison
# ============================================================================
def visualize_embeddings(model, x_support, y_support, x_query, y_query, n_way, title):
    """
    Visualize 2D embeddings with prototypes.
    (Assumes embedding_dim = 2 or uses PCA)

    Args:
        model: embedding network; called on image batches, returns [B, D].
        x_support, y_support: support images and labels (labels 0..n_way-1).
        x_query, y_query: query images and labels.
        n_way: number of classes in the episode.
        title: plot title.
    """
    # NOTE(review): requires scikit-learn, which is not imported at file top —
    # confirm it is available in the runtime environment.
    from sklearn.decomposition import PCA
    model.eval()
    with torch.no_grad():
        # Get embeddings
        support_emb = model(x_support).cpu().numpy()
        query_emb = model(x_query).cpu().numpy()
    # Apply PCA if embedding_dim > 2
    if support_emb.shape[1] > 2:
        pca = PCA(n_components=2)
        # Fit PCA on the support embeddings and reuse the same projection for
        # the queries so both sets live in the same 2D space.
        support_emb = pca.fit_transform(support_emb)
        query_emb = pca.transform(query_emb)
    # Compute prototypes
    # Per-class mean of the (projected) support embeddings.
    prototypes = []
    for k in range(n_way):
        mask = (y_support.cpu().numpy() == k)
        proto = support_emb[mask].mean(axis=0)
        prototypes.append(proto)
    prototypes = np.array(prototypes)
    # Plot
    plt.figure(figsize=(10, 8))
    colors = plt.cm.rainbow(np.linspace(0, 1, n_way))
    for k in range(n_way):
        # Support examples
        support_mask = (y_support.cpu().numpy() == k)
        plt.scatter(support_emb[support_mask, 0], support_emb[support_mask, 1],
                    c=[colors[k]], marker='s', s=100,
                    label=f'Support Class {k}', edgecolors='black', linewidth=1.5)
        # Query examples
        query_mask = (y_query.cpu().numpy() == k)
        plt.scatter(query_emb[query_mask, 0], query_emb[query_mask, 1],
                    c=[colors[k]], marker='o', s=60, alpha=0.6)
        # Prototype
        plt.scatter(prototypes[k, 0], prototypes[k, 1],
                    c=[colors[k]], marker='*', s=500,
                    edgecolors='black', linewidth=2, zorder=10)
    plt.xlabel('Embedding Dimension 1', fontsize=12)
    plt.ylabel('Embedding Dimension 2', fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.legend(loc='best', fontsize=9)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
# Test prototypical network with synthetic data.
# NOTE: statement order below is load-bearing — the manual seed fixes the
# synthetic dataset AND every subsequent episode sample; reordering any
# sampling call changes all downstream results.
print("\n" + "="*80)
print("Prototypical Network Example (Synthetic Data)")
print("="*80)
# Create synthetic 5-way 5-shot data (28x28 images)
torch.manual_seed(42)
n_way, k_shot, q_query = 5, 5, 15
image_size = 28
# Generate synthetic images (different patterns per class)
all_images, all_labels = [], []
for class_id in range(10): # 10 total classes
    # Create class-specific pattern
    base_pattern = torch.randn(1, 1, image_size, image_size)
    # 100 noisy variants of the pattern; small noise keeps classes separable
    images = base_pattern + 0.1 * torch.randn(100, 1, image_size, image_size)
    labels = torch.full((100,), class_id, dtype=torch.long)
    all_images.append(images)
    all_labels.append(labels)
all_images = torch.cat(all_images, dim=0)
all_labels = torch.cat(all_labels, dim=0)
# Create episode sampler
episode_sampler = FewShotEpisode(all_images, all_labels, n_way, k_shot, q_query)
# Sample one episode (shape sanity check)
x_support, y_support, x_query, y_query = episode_sampler.sample_episode()
print(f"Episode shapes:")
print(f" Support: {x_support.shape}, labels: {y_support.shape}")
print(f" Query: {x_query.shape}, labels: {y_query.shape}")
# Train prototypical network
proto_net = PrototypicalNetwork(in_channels=1, embedding_dim=64).to(device)
optimizer = torch.optim.Adam(proto_net.parameters(), lr=0.001)
print("\nTraining Prototypical Network...")
losses = []
for episode in range(100):
    # Episodic training: a fresh N-way K-shot episode every step
    x_sup, y_sup, x_que, y_que = episode_sampler.sample_episode()
    x_sup, y_sup = x_sup.to(device), y_sup.to(device)
    x_que, y_que = x_que.to(device), y_que.to(device)
    optimizer.zero_grad()
    loss = proto_net.loss(x_sup, y_sup, x_que, y_que, n_way)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if (episode + 1) % 20 == 0:
        print(f"Episode {episode+1}/100, Loss: {loss.item():.4f}")
# Plot training curve
plt.figure(figsize=(10, 5))
plt.plot(losses, linewidth=2)
plt.xlabel('Episode', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Prototypical Network Training', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Evaluate on test episode
proto_net.eval()
x_sup, y_sup, x_que, y_que = episode_sampler.sample_episode()
x_sup, y_sup = x_sup.to(device), y_sup.to(device)
x_que, y_que = x_que.to(device), y_que.to(device)
with torch.no_grad():
    # Nearest-prototype classification on the query set
    prototypes = proto_net.compute_prototypes(x_sup, y_sup, n_way)
    logits = proto_net.classify(x_que, prototypes)
    predictions = logits.argmax(dim=1)
    accuracy = (predictions == y_que).float().mean().item()
print(f"\nTest Episode Accuracy: {accuracy*100:.2f}%")
print(f"Random baseline: {100/n_way:.2f}%")
print("\n" + "="*80)
print("Implementation Complete!")
print("="*80)
print("\nKey Insights:")
print("1. MAML learns an initialization that adapts quickly via gradient descent")
print("2. Prototypical Networks classify using distance to class prototypes")
print("3. Matching Networks use attention over support set (differentiable NN)")
print("4. Episode-based training is crucial for meta-learning generalization")
print("5. Trade-offs: MAML (flexible but slow) vs Prototypical (fast but assumes metric space)")
print("\nNext: Apply to real datasets (Omniglot, Mini-ImageNet) for stronger results!")
Advanced Meta-Learning and MAML TheoryΒΆ
1. Introduction to Meta-LearningΒΆ
Meta-learning (learning to learn) aims to design models that can:
Quickly adapt to new tasks with few examples
Leverage prior experience across tasks
Generalize to unseen task distributions
1.1 Problem FormulationΒΆ
Given task distribution p(T), where each task T = {D_train, D_test}:
D_train: Support set (few examples for adaptation)
D_test: Query set (evaluation)
Goal: Learn ΞΈ such that model adapted on D_train generalizes to D_test.
1.2 Meta-Learning ApproachesΒΆ
Metric-based: Learn embedding space (Prototypical Networks, Matching Networks, Relation Networks)
Model-based: Learn update rules via RNNs or memory (Meta-Networks, SNAIL)
Optimization-based: Learn good initialization for gradient descent (MAML, Reptile)
2. Model-Agnostic Meta-Learning (MAML)ΒΆ
MAML [Finn et al., 2017] learns initialization ΞΈ that enables fast adaptation via few gradient steps.
2.1 AlgorithmΒΆ
Bi-level optimization:
Outer loop (meta-update):
ΞΈ β ΞΈ - Ξ² β_ΞΈ Ξ£_i L_{T_i}(ΞΈ'_i)
Inner loop (task-specific adaptation):
ΞΈ'_i = ΞΈ - Ξ± β_ΞΈ L_{T_i}(ΞΈ)
Where:
ΞΈ: Meta-parameters (initialization)
ΞΈ'_i: Task-specific parameters after adaptation
Ξ±: Inner learning rate (adaptation)
Ξ²: Outer learning rate (meta-learning)
L_{T_i}: Loss on task T_i
2.2 Computational GraphΒΆ
ΞΈ β [Inner gradient] β ΞΈ'β, ΞΈ'β, ..., ΞΈ'_N β [Evaluate on query] β Meta-loss
β
[Outer gradient] β Update ΞΈ
Key insight: Gradients flow through inner optimization! (Second-order)
2.3 MAML ObjectiveΒΆ
Meta-objective:
min_ΞΈ Ξ£_{T_i ~ p(T)} L_{T_i}(U_Ξ±(ΞΈ, D_train^i), D_test^i)
Where U_Ξ± is the adaptation operator (one or more gradient steps).
2.4 First-Order MAML (FOMAML)ΒΆ
Challenge: Computing second-order derivatives is expensive.
FOMAML: Ignore second-order terms:
β_ΞΈ L_{T_i}(ΞΈ'_i) β β_{ΞΈ'_i} L_{T_i}(ΞΈ'_i)
Treats ΞΈ'_i as constant w.r.t. ΞΈ. Much faster, surprisingly effective.
3. Mathematical FoundationsΒΆ
3.1 Taylor Expansion InterpretationΒΆ
After one inner step:
ΞΈ' = ΞΈ - Ξ± β_ΞΈ L_train(ΞΈ)
Loss on query set:
L_test(ΞΈ') β L_test(ΞΈ) - Ξ± β_ΞΈ L_test(ΞΈ)^T β_ΞΈ L_train(ΞΈ)
MAML optimizes for: Alignment between train and test gradients!
3.2 Gradient of Meta-LossΒΆ
Meta-gradient:
β_ΞΈ L_test(ΞΈ') = β_{ΞΈ'} L_test(ΞΈ') Β· β_ΞΈ ΞΈ'
Where:
β_ΞΈ ΞΈ' = I - Ξ± βΒ²_ΞΈ L_train(ΞΈ)
Hessian term H = βΒ²_ΞΈ L_train(ΞΈ) captures second-order effects.
Full gradient:
β_ΞΈ L_test(ΞΈ') = β_{ΞΈ'} L_test(ΞΈ') Β· (I - Ξ± H)
FOMAML approximation: β_ΞΈ ΞΈ' β I (ignore Hessian).
3.3 Implicit DifferentiationΒΆ
Alternative to backpropagating through inner loop:
At convergence of inner optimization ΞΈ* = argmin_ΞΈβ L_train(ΞΈβ):
β_{ΞΈ*} L_train(ΞΈ*) = 0
Implicit function theorem:
β_ΞΈ ΞΈ* = -(βΒ²_{ΞΈ*} L_train)^{-1} β_{ΞΈ*} β_ΞΈ L_train
Used in iMAML [Rajeswaran et al., 2019].
4. Variants and ExtensionsΒΆ
4.1 Reptile [Nichol et al., 2018]ΒΆ
Simpler: Just move toward adapted parameters.
ΞΈ β ΞΈ + Ξ² (ΞΈ' - ΞΈ)
Where ΞΈ' is the result of K inner steps. No meta-gradient computation!
Connection to MAML: Reptile β MAML + averaging over all inner steps.
4.2 MAML++ [Antoniou et al., 2019]ΒΆ
Improvements:
Multi-step loss: Use loss from all inner steps, not just final
Per-parameter learning rates: Ξ± per layer
Learn learning rates: Ξ±, Ξ² as parameters
Batch normalization: Use running stats from support set
Meta-loss:
L_meta = Ξ£_k w_k L_test(ΞΈ_k)
Where ΞΈ_k is parameters after k inner steps, w_k are weights.
4.3 Meta-SGD [Li et al., 2017]ΒΆ
Learn both initialization ΞΈ and learning rates Ξ±:
ΞΈ'_i = ΞΈ - Ξ± β β_ΞΈ L_{T_i}(ΞΈ)
Meta-parameters: {ΞΈ, Ξ±} (both updated in outer loop).
4.4 ANIL (Almost No Inner Loop) [Raghu et al., 2020]ΒΆ
Finding: Only adapting final layer often suffices!
Freezes feature extractor during inner loop:
Features: Ο(x; ΞΈ_features) [frozen]
Head: h(Ο; ΞΈ_head) [adapted]
Faster, similar performance on many tasks.
4.5 Meta-Curvature [Park & Oliva, 2019]ΒΆ
Incorporate curvature information:
ΞΈ' = ΞΈ - Ξ± (H + Ξ»I)^{-1} β_ΞΈ L_train(ΞΈ)
Where H is Hessian approximation (e.g., Fisher information).
5. Task-Conditional ArchitecturesΒΆ
5.1 CAVIA [Zintgraf et al., 2019]ΒΆ
Context adaptation via context parameters Ο:
f(x; ΞΈ, Ο)
Inner loop: Adapt only Ο (low-dimensional) Outer loop: Update ΞΈ
Reduces inner loop computation significantly.
5.2 Conditional Neural Processes (CNPs)ΒΆ
Amortized inference via encoder-decoder:
Encoder: Context set (x_c, y_c) β representation r Decoder: (r, x_query) β y_query
Training: Sample context/target splits from tasks.
Advantage: Single forward pass at test (no inner loop).
6. Few-Shot Learning ApplicationsΒΆ
6.1 N-way K-shot ClassificationΒΆ
Task: Classify into N classes with K examples each.
Episode construction:
Sample N classes from dataset
Sample K support + Q query examples per class
Train to classify query given support
MAML approach:
Inner loop: Adapt on support set
Outer loop: Evaluate on query set, update ΞΈ
6.2 Few-Shot RegressionΒΆ
Task distribution: Functions f ~ p(f)
Example: Sine wave regression
Sample amplitude A, phase Ο
Support: Few (x, f(x)) pairs
Query: Predict f(x) at new x
MAML learns: Good initialization for function fitting.
6.3 Reinforcement LearningΒΆ
Task: Different reward functions or environments
Inner loop: Adapt policy Ο_ΞΈ to task via RL algorithm Outer loop: Meta-update ΞΈ for fast adaptation
Applications: Robot locomotion (varying terrains), manipulation (different objects).
7. Theoretical AnalysisΒΆ
7.1 Generalization BoundΒΆ
For MAML, PAC-Bayes bound:
With probability β₯ 1-Ξ΄:
E_{T~p(T)} [L_test(ΞΈ_T)] β€ E_{T~p(T)} [L_train(ΞΈ_T)] + O(β(KL(P_ΞΈ || P_prior) / N))
Where:
P_ΞΈ: Distribution over task-specific parameters
P_prior: Prior distribution
N: Number of tasks
Insight: More tasks β better meta-learning generalization.
7.2 Convergence RateΒΆ
Under smoothness assumptions, MAML converges at rate:
O(1/βT)
for T meta-iterations (same as SGD).
MAML++: Improved constant factors via multi-step loss.
7.3 ExpressivenessΒΆ
Theorem [Finn & Levine, 2018]: MAML can represent any learning algorithm that:
Uses gradient descent for adaptation
Has bounded Hessian
Limitation: Fixed number of inner steps limits expressiveness.
8. Practical ConsiderationsΒΆ
8.1 HyperparametersΒΆ
Critical choices:
Inner steps K: 1-10 (more = better adaptation, slower meta-training)
Inner LR Ξ±: 0.01-0.1 (task-dependent)
Outer LR Ξ²: 0.001-0.01
Batch size: Number of tasks per meta-update (4-32)
Tuning: Grid search or learning Ξ±, Ξ².
8.2 StabilityΒΆ
Issue: Second-order gradients can explode/vanish.
Solutions:
Gradient clipping
Layer normalization in network
Lower outer learning rate
Use FOMAML (more stable)
8.3 Computational CostΒΆ
MAML: O(K Β· |ΞΈ|Β²) per task (Hessian computation) FOMAML: O(K Β· |ΞΈ|) per task Reptile: O(K Β· |ΞΈ|) per task (no backprop through inner loop)
For 1M parameters, K=5 steps:
MAML: ~10Γ slower than standard training
FOMAML: ~2Γ slower
8.4 MemoryΒΆ
Challenge: Store computational graph for K inner steps.
Solution: Checkpointing (trade computation for memory).
PyTorch example:
from torch.utils.checkpoint import checkpoint
def inner_loop(ΞΈ, data):
return checkpoint(adaptation_function, ΞΈ, data)
9. Comparison with Other ApproachesΒΆ
9.1 vs. Transfer LearningΒΆ
| Aspect | Transfer Learning | Meta-Learning |
|---|---|---|
| Goal | Adapt to one target task | Adapt to many tasks quickly |
| Training | Pre-train + fine-tune | Learn across tasks |
| Adaptation | Many examples (1000s) | Few examples (1-10) |
| Optimization | Single-level | Bi-level |
9.2 vs. Metric LearningΒΆ
| Approach | Method | Pros | Cons |
|---|---|---|---|
| MAML | Learn initialization | General, model-agnostic | Slow, second-order |
| Prototypical | Learn embedding + nearest neighbor | Fast, simple | Fixed comparison metric |
| Matching Nets | Attention over support set | Fast inference | Less general |
Hybrid: MAML for embedding, then metric comparison.
9.3 vs. Multi-Task LearningΒΆ
Multi-task: Shared parameters for all tasks simultaneously.
ΞΈ = argmin Ξ£_i L_i(ΞΈ)
Meta-learning: Learn initialization, then adapt per-task.
ΞΈ = argmin Ξ£_i L_i(ΞΈ - Ξ± β_ΞΈ L_i(ΞΈ))
Meta-learning advantage: Better for dissimilar tasks (no negative transfer).
10. Advanced TopicsΒΆ
10.1 Task Distribution ShiftΒΆ
Problem: Test tasks differ from training tasks.
Solutions:
Domain randomization: Diverse training tasks
Meta-regularization: Penalize overfitting to training tasks
Uncertainty estimation: Detect out-of-distribution tasks
10.2 Online Meta-LearningΒΆ
Scenario: Tasks arrive sequentially, no revisiting.
Approach: Update ΞΈ after each task:
ΞΈ_{t+1} = ΞΈ_t - Ξ² β_ΞΈ L_{T_t}(ΞΈ'_t)
Challenge: Catastrophic forgetting of earlier tasks.
Solution: Replay buffer of past tasks.
10.3 Hierarchical Meta-LearningΒΆ
Multiple levels of adaptation:
Global meta-parameters ΞΈ_0
Domain-specific ΞΈ_d (e.g., visual vs. audio)
Task-specific ΞΈ_t
Training: Nested bi-level optimization.
10.4 Meta-Learning with Pre-trained ModelsΒΆ
Modern approach: Initialize MAML with pre-trained features (e.g., ImageNet).
Procedure:
Load ΞΈ_pretrained from large-scale pre-training
Meta-train ΞΈ starting from ΞΈ_pretrained
Fine-tune on few-shot tasks
Benefit: Best of both worlds (pre-training + meta-learning).
11. Connections to Other FieldsΒΆ
11.1 Neural Architecture SearchΒΆ
Meta-learning NAS: Learn search strategy that generalizes across datasets.
MAML for NAS: Adapt architecture quickly to new dataset.
11.2 Bayesian OptimizationΒΆ
Task: Optimize black-box function with few evaluations.
Meta-BO: Learn GP kernel or acquisition function from related tasks.
11.3 Continual LearningΒΆ
Overlap: Both deal with learning from sequence of tasks.
Difference:
Meta-learning: Assumes task distribution, goal is fast adaptation
Continual learning: No repeated tasks, goal is avoid forgetting
Synergy: Meta-learned initialization resists catastrophic forgetting.
12. Recent Advances (2020-2024)ΒΆ
12.1 Transformer-Based Meta-LearningΒΆ
In-context learning (GPT-3 style):
Concatenate support examples in prompt
No gradient-based adaptation!
Relation to MAML: Implicit meta-learning during pre-training.
12.2 Meta-Learning for PromptsΒΆ
Prompt tuning: Learn soft prompts for language models.
Meta-prompt learning: MAML over prompt parameters.
12.3 Bi-level Optimization Beyond MAMLΒΆ
Applications:
Hyperparameter optimization: Inner = training, outer = validation loss
Data distillation: Inner = train on synthetic, outer = test on real
Neural architecture search: Inner = train model, outer = architecture loss
12.4 Federated Meta-LearningΒΆ
Scenario: Meta-learn from decentralized data (privacy-preserving).
Approach:
Local adaptation on each client
Aggregate meta-gradients on server
Challenge: Communication cost, heterogeneous data.
13. Implementation TipsΒΆ
13.1 Debugging MAMLΒΆ
Common issues:
Exploding gradients: Use gradient clipping, lower Ξ± or Ξ²
No improvement: Check inner loop is actually reducing loss
Overfitting: More tasks in meta-train set, data augmentation
Slow convergence: Try MAML++ (multi-step loss), increase batch size
Sanity checks:
Adapted model should outperform initialization on support set
Meta-loss should decrease over meta-iterations
Test on simple task (e.g., sine regression) first
13.2 Efficient ImplementationΒΆ
Batching:
Parallelize inner loops across tasks
Use vmap (JAX) or functorch (PyTorch) for efficient gradient computation
Mixed precision: Train in FP16 to reduce memory (with gradient scaling).
Distributed: Split tasks across GPUs, aggregate meta-gradients.
13.3 Choosing Inner Steps KΒΆ
Trade-off:
K small (1-3): Fast, less adaptation, more meta-learning
K large (5-10): Better task performance, slower, risk overfitting
Rule of thumb: K β number of gradient steps needed to see loss decrease on task.
14. Complexity AnalysisΒΆ
14.1 Time ComplexityΒΆ
Per meta-iteration with B tasks, K inner steps, |ΞΈ| parameters:
MAML: O(B Β· K Β· |ΞΈ|Β²) (Hessian-vector products) FOMAML: O(B Β· K Β· |ΞΈ|) Reptile: O(B Β· K Β· |ΞΈ|)
Typical values:
B = 4-32 tasks
K = 1-10 steps
|ΞΈ| = 10β΄-10β· parameters
Example: ResNet-18 (11M params), B=8, K=5:
MAML: ~60GB memory, ~10s per iteration (GPU)
FOMAML: ~15GB memory, ~2s per iteration
14.2 Sample ComplexityΒΆ
Meta-train: Requires N_tasks tasks, each with support + query sets.
Guideline: N_tasks β₯ 100-1000 for good generalization.
Few-shot episodes: Generate via sampling from base dataset.
Example (Omniglot):
1623 characters (tasks)
20 examples per character
5-way 1-shot: 5 classes Γ (1 support + 15 query) = 80 examples per episode
15. LimitationsΒΆ
Computational cost: Second-order gradients expensive
Hyperparameter sensitivity: Ξ±, Ξ², K require tuning
Task diversity: Needs sufficient variation in p(T)
Failure modes: Can collapse to memorization or initialization-only
Non-stationarity: Struggles if task distribution shifts
When NOT to use MAML:
Very few total tasks (N < 20): Just multi-task learning
Many-shot regime (100+ examples): Standard transfer learning better
Tasks too dissimilar: No shared structure to meta-learn
16. Software and LibrariesΒΆ
16.1 ImplementationsΒΆ
PyTorch:
learn2learn: High-level MAML API
higher: Functional programming for bi-level optimization
torchmeta: Datasets + benchmarks
TensorFlow:
tensorflow-maml: Official implementation
JAX:
Native support for grad(grad(...)) (second-order)
Efficient with vmap for batching
16.2 BenchmarksΒΆ
Classification:
Omniglot: 1623 handwritten characters (5-way 1-shot)
Mini-ImageNet: 100 classes, 600 images each (5-way 5-shot)
Tiered-ImageNet: Hierarchical version of ImageNet
Regression:
Sinusoid: Sample amplitude, phase, frequency
Polynomial: Random coefficients
Reinforcement Learning:
MuJoCo: HalfCheetah, Ant with varying dynamics
Meta-World: 50 robotic manipulation tasks
17. Key TakeawaysΒΆ
MAML learns initialization for fast adaptation via gradient descent
Bi-level optimization: Inner loop (task adaptation) + outer loop (meta-update)
Second-order: Gradients flow through inner optimization (FOMAML approximates)
Model-agnostic: Works with any gradient-based model
Few-shot learning: Excel with 1-10 examples per task
Trade-offs: Computational cost vs. adaptation speed vs. generalization
Extensions: MAML++, Reptile, Meta-SGD improve on vanilla MAML
Modern use: Combined with pre-training, prompt learning, transformers
Intuition: MAML finds ΞΈ such that gradient descent quickly moves toward good solutions across tasks. Itβs learning to learn via gradient descent.
18. Mathematical SummaryΒΆ
MAML objective:
min_ΞΈ E_{T~p(T)} [L_T(ΞΈ - Ξ± β_ΞΈ L_T^train(ΞΈ))]
Meta-gradient:
β_ΞΈ L_T^test(ΞΈ') = β_{ΞΈ'} L_T^test(ΞΈ') Β· (I - Ξ± βΒ²_ΞΈ L_T^train(ΞΈ))
βββββββββββββββ
Hessian term
FOMAML approximation:
β_ΞΈ L_T^test(ΞΈ') β β_{ΞΈ'} L_T^test(ΞΈ')
Update rule:
ΞΈ β ΞΈ - Ξ² β_ΞΈ Ξ£_{T_i} L_{T_i}^test(ΞΈ'_i)
ReferencesΒΆ
Finn et al. (2017) βModel-Agnostic Meta-Learning for Fast Adaptation of Deep Networksβ
Nichol et al. (2018) βOn First-Order Meta-Learning Algorithms (Reptile)β
Antoniou et al. (2019) βHow to Train Your MAML (MAML++)β
Li et al. (2017) βMeta-SGD: Learning to Learn Quickly for Few-Shot Learningβ
Raghu et al. (2020) βRapid Learning or Feature Reuse? (ANIL)β
Rajeswaran et al. (2019) βMeta-Learning with Implicit Gradients (iMAML)β
Zintgraf et al. (2019) βFast Context Adaptation via Meta-Learning (CAVIA)β
Hospedales et al. (2021) βMeta-Learning in Neural Networks: A Surveyβ
"""
Complete MAML and Meta-Learning Implementations
===============================================
Includes: MAML, FOMAML, Reptile, MAML++, ANIL, few-shot classification,
sine wave regression, metric comparison.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict
import matplotlib.pyplot as plt
# ============================================================================
# 1. Utility Functions for MAML
# ============================================================================
def clone_parameters(model):
    """Return an OrderedDict mapping parameter names to cloned tensors.

    ``clone`` is differentiable, so the copies keep the autograd link to the
    originals and can seed an inner-loop adaptation.
    """
    cloned = OrderedDict()
    for name, param in model.named_parameters():
        cloned[name] = param.clone()
    return cloned
def set_parameters(model, params):
    """Copy values from ``params`` (name -> tensor) into the model in place.

    Only names the model actually owns are consulted; extra keys in
    ``params`` are ignored.
    """
    for name, target in model.named_parameters():
        target.data = params[name].data
def get_grad_as_tensor(model):
    """Flatten all parameter gradients into a single 1-D tensor.

    Parameters whose ``.grad`` is ``None`` contribute zeros, so the result
    always has ``sum(p.numel() for p in model.parameters())`` elements.
    """
    pieces = [
        p.grad.view(-1) if p.grad is not None else torch.zeros_like(p).view(-1)
        for p in model.parameters()
    ]
    return torch.cat(pieces)
# ============================================================================
# 2. Simple Convolutional Model for Few-Shot Learning
# ============================================================================
class ConvBlock(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2); halves the spatial size."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # momentum=1.0 makes the running statistics track only the most
        # recent batch (common in MAML-style episodic training).
        self.bn = nn.BatchNorm2d(out_channels, momentum=1.0)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return self.pool(x)
class SimpleConvNet(nn.Module):
    """4-layer ConvNet for Omniglot/Mini-ImageNet style few-shot tasks.

    Each ConvBlock halves the spatial resolution, so a 28x28 input collapses
    to 1x1, leaving exactly ``hidden_dim`` features for the linear head.
    """

    def __init__(self, input_channels=1, hidden_dim=64, output_dim=5):
        super().__init__()
        stages = [ConvBlock(input_channels, hidden_dim)]
        stages += [ConvBlock(hidden_dim, hidden_dim) for _ in range(3)]
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        feats = self.features(x)
        feats = feats.flatten(start_dim=1)
        return self.classifier(feats)
# ============================================================================
# 3. MAML Algorithm
# ============================================================================
class MAML:
"""
Model-Agnostic Meta-Learning.
Args:
model: Neural network model
inner_lr: Learning rate for inner loop (task adaptation)
outer_lr: Learning rate for outer loop (meta-update)
inner_steps: Number of gradient steps in inner loop
first_order: If True, use FOMAML (no second-order gradients)
"""
def __init__(self, model, inner_lr=0.01, outer_lr=0.001,
inner_steps=5, first_order=False):
self.model = model
self.inner_lr = inner_lr
self.outer_lr = outer_lr
self.inner_steps = inner_steps
self.first_order = first_order
# Outer optimizer
self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=outer_lr)
def inner_loop(self, task_data, labels):
"""
Perform inner loop adaptation on a single task.
Args:
task_data: Support set data [K*N, ...]
labels: Support set labels [K*N]
Returns:
adapted_params: Adapted parameters after K gradient steps
"""
# Clone current parameters
adapted_params = clone_parameters(self.model)
# Create computational graph for adaptation
for step in range(self.inner_steps):
# Set model to adapted parameters
set_parameters(self.model, adapted_params)
# Forward pass
logits = self.model(task_data)
loss = F.cross_entropy(logits, labels)
# Compute gradients
grads = torch.autograd.grad(
loss,
self.model.parameters(),
create_graph=not self.first_order # Second-order if MAML
)
# Update adapted parameters
adapted_params = OrderedDict({
name: param - self.inner_lr * grad
for (name, param), grad in zip(adapted_params.items(), grads)
})
return adapted_params
def meta_update(self, tasks):
"""
Perform meta-update on a batch of tasks.
Args:
tasks: List of (support_data, support_labels, query_data, query_labels)
"""
self.meta_optimizer.zero_grad()
meta_loss = 0.0
meta_acc = 0.0
for support_x, support_y, query_x, query_y in tasks:
# Inner loop: Adapt to task
adapted_params = self.inner_loop(support_x, support_y)
# Set model to adapted parameters
set_parameters(self.model, adapted_params)
# Evaluate on query set
query_logits = self.model(query_x)
task_loss = F.cross_entropy(query_logits, query_y)
# Accumulate meta-loss
meta_loss += task_loss
# Compute accuracy
with torch.no_grad():
pred = query_logits.argmax(dim=1)
meta_acc += (pred == query_y).float().mean()
# Average over tasks
meta_loss = meta_loss / len(tasks)
meta_acc = meta_acc / len(tasks)
# Meta-gradient and update
meta_loss.backward()
self.meta_optimizer.step()
return meta_loss.item(), meta_acc.item()
def evaluate(self, tasks):
"""Evaluate on validation/test tasks."""
total_loss = 0.0
total_acc = 0.0
with torch.no_grad():
for support_x, support_y, query_x, query_y in tasks:
# Adapt (no gradients needed)
for step in range(self.inner_steps):
logits = self.model(support_x)
loss = F.cross_entropy(logits, support_y)
# Manual gradient descent
grads = torch.autograd.grad(loss, self.model.parameters())
for param, grad in zip(self.model.parameters(), grads):
param.data -= self.inner_lr * grad
# Evaluate on query
query_logits = self.model(query_x)
task_loss = F.cross_entropy(query_logits, query_y)
total_loss += task_loss.item()
pred = query_logits.argmax(dim=1)
total_acc += (pred == query_y).float().mean().item()
return total_loss / len(tasks), total_acc / len(tasks)
# ============================================================================
# 4. Reptile Algorithm
# ============================================================================
class Reptile:
    """
    Reptile: First-order meta-learning algorithm.
    Simpler than MAML: Just move toward adapted parameters.
    """
    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, inner_steps=5):
        # model: network whose initialization is meta-learned
        # inner_lr: SGD step size for per-task adaptation
        # outer_lr: interpolation factor toward the adapted parameters
        # inner_steps: number of SGD steps per task
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.inner_steps = inner_steps

    def meta_update(self, tasks):
        """Meta-update via Reptile.

        For each task: reset the model to the current initialization, run
        ``inner_steps`` SGD steps on the support set, then move the
        initialization a fraction ``outer_lr`` toward the adapted weights
        (theta <- theta + beta * (theta' - theta)).  No gradients flow
        through the inner loop.

        Returns:
            (mean query loss, mean query accuracy) over the task batch,
            measured after adaptation — logging only, not used for updates.
        """
        init_params = clone_parameters(self.model)
        total_loss = 0.0
        total_acc = 0.0
        for support_x, support_y, query_x, query_y in tasks:
            # Reset to initial parameters
            set_parameters(self.model, init_params)
            # Inner loop: Adapt to task
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.inner_lr)
            for step in range(self.inner_steps):
                optimizer.zero_grad()
                logits = self.model(support_x)
                loss = F.cross_entropy(logits, support_y)
                loss.backward()
                optimizer.step()
            # Evaluate on query (for logging)
            with torch.no_grad():
                query_logits = self.model(query_x)
                task_loss = F.cross_entropy(query_logits, query_y)
                total_loss += task_loss.item()
                pred = query_logits.argmax(dim=1)
                total_acc += (pred == query_y).float().mean().item()
            # Meta-update: Move toward adapted parameters
            adapted_params = clone_parameters(self.model)
            for (name, init_param), (_, adapted_param) in zip(
                init_params.items(), adapted_params.items()
            ):
                init_param.data += self.outer_lr * (adapted_param.data - init_param.data)
        # Set model to updated parameters
        set_parameters(self.model, init_params)
        return total_loss / len(tasks), total_acc / len(tasks)
# ============================================================================
# 5. ANIL (Almost No Inner Loop)
# ============================================================================
class ANIL:
    """
    ANIL (Almost No Inner Loop, Raghu et al., 2020).

    The inner loop adapts ONLY the classifier head; the feature extractor
    is frozen during adaptation.  The outer loop meta-updates the WHOLE
    model — features included — from the query loss.

    Args:
        model: network exposing ``model.features`` and ``model.classifier``.
        inner_lr: head learning rate for task adaptation.
        outer_lr: meta learning rate.
        inner_steps: number of head gradient steps per task.

    NOTE(review): writing adapted head weights back via ``.data`` severs the
    inner-loop autograd graph, so the head's meta-gradient is effectively
    first-order despite ``create_graph=True`` below.
    """

    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.inner_steps = inner_steps
        # Separate feature extractor and classifier head.
        self.features = model.features
        self.classifier = model.classifier
        # Outer optimizer over the ENTIRE model (features + head).
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=outer_lr)

    # --- parameter bookkeeping helpers (kept private and self-contained) ---
    @staticmethod
    def _clone_params(module):
        """Snapshot a module's parameters as name -> cloned tensor."""
        return OrderedDict((n, p.clone()) for n, p in module.named_parameters())

    @staticmethod
    def _load_params(module, params):
        """Write parameter values back into the module in place."""
        for name, param in module.named_parameters():
            param.data = params[name].data

    def inner_loop(self, task_data, labels):
        """Adapt only the classifier head on the support set.

        Support features are computed under ``no_grad``: keeping the
        extractor frozen during adaptation is the point of ANIL.

        Returns:
            OrderedDict of adapted head parameters.
        """
        with torch.no_grad():
            features = self.features(task_data)
            features = features.view(features.size(0), -1)
        adapted_head = self._clone_params(self.classifier)
        for _ in range(self.inner_steps):
            self._load_params(self.classifier, adapted_head)
            loss = F.cross_entropy(self.classifier(features), labels)
            grads = torch.autograd.grad(
                loss, self.classifier.parameters(), create_graph=True
            )
            adapted_head = OrderedDict(
                (name, param - self.inner_lr * grad)
                for (name, param), grad in zip(adapted_head.items(), grads)
            )
        return adapted_head

    def meta_update(self, tasks):
        """Perform one meta-update over a batch of tasks.

        Bug fixes vs. the naive version:
        - query features are computed WITH gradients, so the feature
          extractor actually receives a meta-gradient (previously it was
          wrapped in ``no_grad`` and never trained);
        - the head is restored to its pre-adaptation values before
          ``optimizer.step()``, so tasks do not leak adapted weights into
          each other or into the update.

        Returns:
            (mean query loss, mean query accuracy).
        """
        init_head = self._clone_params(self.classifier)
        self.meta_optimizer.zero_grad()
        total_loss = 0.0
        total_acc = 0.0
        num_tasks = len(tasks)
        for support_x, support_y, query_x, query_y in tasks:
            self._load_params(self.classifier, init_head)
            adapted_head = self.inner_loop(support_x, support_y)
            # Query features WITH grad: the outer loop trains the extractor.
            query_features = self.features(query_x)
            query_features = query_features.view(query_features.size(0), -1)
            self._load_params(self.classifier, adapted_head)
            query_logits = self.classifier(query_features)
            task_loss = F.cross_entropy(query_logits, query_y)
            # Per-task backward while forward values are still in place.
            (task_loss / num_tasks).backward()
            total_loss += task_loss.item()
            with torch.no_grad():
                pred = query_logits.argmax(dim=1)
                total_acc += (pred == query_y).float().mean().item()
        # Step from the initial head, not the last task's adaptation.
        self._load_params(self.classifier, init_head)
        self.meta_optimizer.step()
        return total_loss / num_tasks, total_acc / num_tasks
# ============================================================================
# 6. Sine Wave Regression (Classic MAML Demo)
# ============================================================================
class SineWaveDataset:
    """Generator of few-shot sine-wave regression tasks.

    Each task is a sinusoid ``y = A * sin(x + phi)`` with amplitude
    ``A ~ U(0.1, 5.0)`` and phase ``phi ~ U(0, pi)``; inputs are drawn
    uniformly from [-5, 5].
    """

    def __init__(self, num_tasks=1000, k_shot=10, q_query=10):
        self.num_tasks = num_tasks
        self.k_shot = k_shot
        self.q_query = q_query

    def sample_task(self):
        """Sample one task split into support/query tensors.

        Returns:
            (support_x, support_y, query_x, query_y); shapes [k_shot, 1]
            and [q_query, 1] respectively, dtype float32.
        """
        amplitude = np.random.uniform(0.1, 5.0)
        phase = np.random.uniform(0, np.pi)
        xs = np.random.uniform(-5, 5, self.k_shot + self.q_query)
        ys = amplitude * np.sin(xs + phase)

        def to_column(arr):
            # [n] float64 array -> [n, 1] float32 tensor
            return torch.tensor(arr, dtype=torch.float32).unsqueeze(1)

        split = self.k_shot
        return (to_column(xs[:split]), to_column(ys[:split]),
                to_column(xs[split:]), to_column(ys[split:]))

    def __iter__(self):
        for _ in range(self.num_tasks):
            yield self.sample_task()
class SineModel(nn.Module):
    """Two-hidden-layer MLP (1 -> hidden -> hidden -> 1) for sine regression."""

    def __init__(self, hidden_dim=40):
        super().__init__()
        layers = [
            nn.Linear(1, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
def train_maml_sine():
    """Train MAML on sine wave regression.

    NOTE(review): the MAML class above hardcodes F.cross_entropy as the task
    loss, which does not fit this regression demo (float targets of shape
    [K, 1]); meta_update is expected to raise unless MAML is adapted to use
    an MSE loss — confirm before relying on this demo.
    """
    print("="*70)
    print("MAML Sine Wave Regression Demo")
    print("="*70)
    # Model
    model = SineModel(hidden_dim=40)
    # MAML
    maml = MAML(
        model,
        inner_lr=0.01,       # inner-loop (task adaptation) step size
        outer_lr=0.001,      # outer-loop (meta) step size
        inner_steps=5,
        first_order=False    # request full second-order MAML
    )
    # Dataset
    train_dataset = SineWaveDataset(num_tasks=10000, k_shot=10, q_query=10)
    # Meta-training
    print("Meta-training MAML...")
    num_iterations = 100
    batch_size = 4  # tasks per meta-update
    task_iter = iter(train_dataset)
    for iteration in range(num_iterations):
        # Sample batch of tasks
        tasks = [next(task_iter) for _ in range(batch_size)]
        # Meta-update
        loss, acc = maml.meta_update(tasks)
        if (iteration + 1) % 20 == 0:
            print(f" Iteration {iteration+1}: Meta-loss = {loss:.4f}")
    # Test: Adapt to new sine wave
    print("\nTesting on new sine wave...")
    test_task = SineWaveDataset(num_tasks=1, k_shot=10, q_query=100).sample_task()
    support_x, support_y, query_x, query_y = test_task
    # Before adaptation
    with torch.no_grad():
        pred_before = model(query_x)
        mse_before = F.mse_loss(pred_before, query_y).item()
    # Adapt
    adapted_params = maml.inner_loop(support_x, support_y)
    set_parameters(model, adapted_params)
    # After adaptation
    with torch.no_grad():
        pred_after = model(query_x)
        mse_after = F.mse_loss(pred_after, query_y).item()
    print(f" MSE before adaptation: {mse_before:.4f}")
    print(f" MSE after adaptation: {mse_after:.4f}")
    print(f" Improvement: {mse_before / mse_after:.2f}Γ")
    print()
# ============================================================================
# 7. Method Comparison
# ============================================================================
def print_method_comparison():
    """Print a side-by-side comparison of meta-learning algorithms.

    NOTE(review): the table and math symbols below were mojibake-damaged in
    the original source (cp1253-style garbling); they have been reconstructed
    with proper box-drawing and Greek/math characters.
    """
    print("="*70)
    print("Meta-Learning Algorithms Comparison")
    print("="*70)
    print()

    comparison = """
┌─────────────┬──────────────┬───────────┬──────────────┬──────────────┐
│ Method      │ Order        │ Speed     │ Memory       │ Performance  │
├─────────────┼──────────────┼───────────┼──────────────┼──────────────┤
│ MAML        │ Second-order │ Slow      │ High (graph) │ Best         │
│             │ (Hessian)    │           │              │              │
├─────────────┼──────────────┼───────────┼──────────────┼──────────────┤
│ FOMAML      │ First-order  │ Medium    │ Medium       │ Good         │
│             │              │           │              │              │
├─────────────┼──────────────┼───────────┼──────────────┼──────────────┤
│ Reptile     │ First-order  │ Fast      │ Low          │ Good         │
│             │ (simpler)    │           │              │              │
├─────────────┼──────────────┼───────────┼──────────────┼──────────────┤
│ ANIL        │ Second-order │ Fast      │ Medium       │ Good (if     │
│             │ (head only)  │           │              │ features OK) │
├─────────────┼──────────────┼───────────┼──────────────┼──────────────┤
│ MAML++      │ Second-order │ Slow      │ High         │ Best (tuned) │
│             │ (enhanced)   │           │              │              │
└─────────────┴──────────────┴───────────┴──────────────┴──────────────┘

**Computational Complexity (per meta-iteration):**
- MAML: O(B · K · |θ|²) [B=tasks, K=steps, |θ|=params]
- FOMAML: O(B · K · |θ|)
- Reptile: O(B · K · |θ|)
- ANIL: O(B · K · |θ_head|²) where |θ_head| << |θ|

**When to Use:**
- **MAML**: Best performance, have computational budget
- **FOMAML**: Good balance of speed and performance
- **Reptile**: Simplest to implement, fastest training
- **ANIL**: Fast adaptation, when features are pre-trained
- **MAML++**: Production use, worth hyperparameter tuning

**Inner Loop Steps K:**
- K=1: Minimal adaptation, very fast
- K=5: Standard choice, good balance
- K=10+: Better adaptation, slower, risk overfitting

**Typical Results (5-way 1-shot Omniglot):**
- Random: 20% accuracy
- Fine-tuning: 50-60%
- MAML: 95-98%
- FOMAML: 93-95%
- Reptile: 92-94%

**Implementation Tips:**
1. **Start with FOMAML**: Simpler, good baseline
2. **Use gradient clipping**: Stability crucial
3. **Batch size**: 4-16 tasks per meta-update
4. **Learning rates**: α=0.01 (inner), β=0.001 (outer)
5. **Warm-up**: Lower outer LR for first 1000 iterations
"""
    print(comparison)
    print()
def print_complexity_analysis():
    """Print a detailed time/memory complexity analysis of MAML variants.

    NOTE(review): bullets and symbols were mojibake-damaged in the original
    source ('β’' for '•', 'ΞΈ' for 'θ', 'Γ' for '×'); restored here.
    """
    print("="*70)
    print("Complexity Analysis")
    print("="*70)
    print()
    print("**Example: ResNet-18 (11M parameters)**")
    print()
    print("Configuration:")
    print("  • Batch size B = 8 tasks")
    print("  • Inner steps K = 5")
    print("  • Parameters |θ| = 11M")
    print()
    print("Time per meta-iteration (GPU):")
    print("  • MAML: ~10s (second-order gradients)")
    print("  • FOMAML: ~2s (first-order)")
    print("  • Reptile: ~1.5s (no backprop through inner loop)")
    print("  • ANIL: ~0.5s (only adapt final layer)")
    print()
    print("Memory usage:")
    print("  • MAML: ~60GB (stores computational graph for K steps)")
    print("  • FOMAML: ~15GB (no Hessian)")
    print("  • Reptile: ~12GB (minimal graph)")
    print("  • ANIL: ~8GB (small adaptation)")
    print()
    print("Speedup techniques:")
    print("  • Mixed precision (FP16): 2× faster, 1/2 memory")
    print("  • Gradient checkpointing: 1/K memory, ~1.5× slower")
    print("  • Distributed (multi-GPU): Linear speedup in B")
    print()
# ============================================================================
# Run Demonstrations
# ============================================================================
if __name__ == "__main__":
    # Fixed seeds so the demo is reproducible run-to-run.
    torch.manual_seed(42)
    np.random.seed(42)

    train_maml_sine()
    print_method_comparison()
    print_complexity_analysis()

    print("="*70)
    print("MAML and Meta-Learning Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    # Mojibake bullets ('β’') and theta ('ΞΈ') restored to '•' / 'θ'.
    print("  • MAML: Bi-level optimization for fast adaptation")
    print("  • FOMAML: First-order approximation (faster)")
    print("  • Reptile: Simpler algorithm moving toward adapted params")
    print("  • ANIL: Only adapt final layer (faster)")
    print("  • Sine regression: Classic MAML demonstration")
    print()
    print("Key insight: Learn initialization θ that enables")
    print("  fast adaptation via gradient descent")
    print("Trade-off: Computational cost vs. adaptation speed")
    print("Applications: Few-shot learning, RL, regression")
    print()