import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Meta-Learning ProblemΒΆ

GoalΒΆ

Learn from distribution of tasks \(p(\mathcal{T})\).

Each task \(\mathcal{T}_i\) has:

  • Support set \(D_i^{\text{train}}\) (K-shot)

  • Query set \(D_i^{\text{test}}\)

MAML ObjectiveΒΆ

\[\min_\theta \mathbb{E}_{\mathcal{T}_i}\left[\mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})\right]\]

where \(\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)\) (inner loop)

2. MAML AlgorithmΒΆ

Inner Loop (Task Adaptation)ΒΆ

For task \(\mathcal{T}_i\):

\[\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{D_i^{\text{train}}}(f_\theta)\]

Outer Loop (Meta-Update)ΒΆ

\[\theta \leftarrow \theta - \beta \nabla_\theta \sum_i \mathcal{L}_{D_i^{\text{test}}}(f_{\theta_i'})\]

Key: Second-order gradients through inner loop!

class SineTaskDistribution:
    """Sinusoid regression tasks."""
    
    def sample_task(self):
        amp = np.random.uniform(0.1, 5.0)
        phase = np.random.uniform(0, np.pi)
        return amp, phase
    
    def sample_data(self, task, K=10):
        amp, phase = task
        x = np.random.uniform(-5, 5, K)
        y = amp * np.sin(x + phase)
        return torch.tensor(x, dtype=torch.float32).unsqueeze(1), \
               torch.tensor(y, dtype=torch.float32).unsqueeze(1)

# Test
task_dist = SineTaskDistribution()
task = task_dist.sample_task()
x, y = task_dist.sample_data(task, K=10)
print(f"Task: amp={task[0]:.2f}, phase={task[1]:.2f}")
print(f"Data: x.shape={x.shape}, y.shape={y.shape}")
class MAMLModel(nn.Module):
    """Simple MLP for MAML."""
    
    def __init__(self, input_dim=1, hidden_dim=40, output_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)
    
    def clone_params(self):
        return [p.clone() for p in self.parameters()]
    
    def set_params(self, params):
        for p, p_new in zip(self.parameters(), params):
            p.data = p_new.data

model = MAMLModel().to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters())}")
def maml_inner_loop(model, x_support, y_support, alpha, steps=1):
    """Adapt to task (inner loop)."""
    params = [p.clone() for p in model.parameters()]
    
    for _ in range(steps):
        # Forward
        model.set_params(params)
        pred = model(x_support)
        loss = F.mse_loss(pred, y_support)
        
        # Gradient descent
        grads = torch.autograd.grad(loss, params, create_graph=True)
        params = [p - alpha * g for p, g in zip(params, grads)]
    
    return params

def maml_train_step(model, task_dist, n_tasks=4, K_support=10, K_query=10, 
                   alpha=0.01, beta=0.001, inner_steps=1):
    """Single MAML meta-update."""
    meta_loss = 0
    
    for _ in range(n_tasks):
        # Sample task
        task = task_dist.sample_task()
        x_support, y_support = task_dist.sample_data(task, K=K_support)
        x_query, y_query = task_dist.sample_data(task, K=K_query)
        
        x_support, y_support = x_support.to(device), y_support.to(device)
        x_query, y_query = x_query.to(device), y_query.to(device)
        
        # Inner loop
        adapted_params = maml_inner_loop(model, x_support, y_support, alpha, inner_steps)
        
        # Evaluate on query
        model.set_params(adapted_params)
        pred_query = model(x_query)
        loss_query = F.mse_loss(pred_query, y_query)
        
        meta_loss += loss_query
    
    meta_loss /= n_tasks
    
    # Meta-update
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=beta)
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()
    
    return meta_loss.item()

TrainingΒΆ

MAML training (the meta-training or outer loop) proceeds over many randomly sampled tasks. For each task, the model takes \(k\) gradient steps on the task’s support set to produce task-specific parameters, then evaluates on the task’s query set. The meta-gradient is computed by differentiating through the inner optimization – this requires second-order gradients (gradients of gradients), which is computationally expensive but essential for MAML’s effectiveness. First-order approximations like FOMAML drop the second derivatives for efficiency with only modest performance loss. The outer loop optimizer (typically Adam) updates the initialization parameters to minimize the average query-set loss across tasks.

# Train MAML
model = MAMLModel().to(device)
task_dist = SineTaskDistribution()

losses = []
for iteration in range(200):
    loss = maml_train_step(model, task_dist, n_tasks=4, K_support=10, K_query=10,
                          alpha=0.01, beta=0.001, inner_steps=1)
    losses.append(loss)
    
    if (iteration + 1) % 50 == 0:
        print(f"Iter {iteration+1}, Loss: {loss:.4f}")

plt.figure(figsize=(8, 5))
plt.plot(losses)
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('Meta Loss', fontsize=11)
plt.title('MAML Training', fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()

Few-Shot AdaptationΒΆ

The payoff of MAML is at meta-test time: given a new, unseen task with only a handful of labeled examples (the support set), we take a few gradient steps from the learned initialization and evaluate on the query set. Because MAML has learned an initialization that is maximally sensitive to task-relevant features, even 1-5 gradient steps are enough to achieve strong performance on the new task. This is fundamentally different from training from scratch or fine-tuning a pre-trained model – MAML explicitly optimizes for fast adaptation, making it especially powerful for applications where labeled data is scarce, such as drug discovery, robotics, and personalized recommendation.

# Test on new task
test_task = task_dist.sample_task()
x_support, y_support = task_dist.sample_data(test_task, K=5)
x_support, y_support = x_support.to(device), y_support.to(device)

# Test points
x_test = torch.linspace(-5, 5, 100).unsqueeze(1).to(device)
y_true = test_task[0] * torch.sin(x_test + test_task[1])

# Before adaptation
model.eval()
with torch.no_grad():
    y_before = model(x_test)

# After adaptation (5-shot)
adapted_params = maml_inner_loop(model, x_support, y_support, alpha=0.01, steps=5)
model.set_params(adapted_params)
with torch.no_grad():
    y_after = model(x_test)

# Plot
plt.figure(figsize=(10, 6))
plt.plot(x_test.cpu(), y_true.cpu(), 'k-', label='True', linewidth=2)
plt.plot(x_test.cpu(), y_before.cpu(), 'b--', label='Before (0-shot)', alpha=0.7)
plt.plot(x_test.cpu(), y_after.cpu(), 'r-', label='After (5-shot)', linewidth=2)
plt.scatter(x_support.cpu(), y_support.cpu(), s=100, c='red', marker='x', label='Support', zorder=5)
plt.xlabel('x', fontsize=11)
plt.ylabel('y', fontsize=11)
plt.title('MAML Few-Shot Adaptation', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

SummaryΒΆ

MAML:ΒΆ

Inner loop: Task adaptation via gradient descent

Outer loop: Meta-learning across tasks

Algorithm:ΒΆ

  1. Sample tasks from \(p(\mathcal{T})\)

  2. Adapt: \(\theta' = \theta - \alpha \nabla \mathcal{L}_{\text{train}}\)

  3. Meta-update: \(\theta \leftarrow \theta - \beta \nabla \mathcal{L}_{\text{test}}(\theta')\)

Key Features:ΒΆ

  • Model-agnostic (works with any gradient-based model)

  • Few-shot adaptation

  • Second-order gradients

Applications:ΒΆ

  • Few-shot classification

  • Rapid RL adaptation

  • Personalization

  • Drug discovery

Variants:ΒΆ

  • Reptile: First-order approximation

  • ANIL: Almost no inner loop

  • Meta-SGD: Learn inner LR

Next Steps:ΒΆ

  • Explore Prototypical Networks

  • Apply to image classification

  • Study task distribution design

Advanced Meta-Learning TheoryΒΆ

1. MAML Bi-Level Optimization FrameworkΒΆ

Mathematical FormulationΒΆ

Meta-learning aims to learn a good initialization \(\theta\) that can quickly adapt to new tasks. The bi-level optimization problem is:

\[\min_{\theta} \mathbb{E}_{\mathcal{T}_i \sim p(\mathcal{T})} \left[ \mathcal{L}_{\mathcal{T}_i}^{\text{test}}(f_{\theta_i'}) \right]\]

where the inner loop (task adaptation) computes:

\[\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}^{\text{train}}(f_\theta)\]

and the outer loop (meta-update) performs:

\[\theta \leftarrow \theta - \beta \nabla_\theta \mathbb{E}_{\mathcal{T}_i} \left[ \mathcal{L}_{\mathcal{T}_i}^{\text{test}}(f_{\theta_i'}) \right]\]

Second-Order GradientsΒΆ

The key challenge: \(\theta_i'\) depends on \(\theta\), so computing \(\nabla_\theta \mathcal{L}^{\text{test}}(f_{\theta_i'})\) requires the chain rule through the inner loop:

\[\nabla_\theta \mathcal{L}^{\text{test}}(f_{\theta_i'}) = \nabla_{\theta_i'} \mathcal{L}^{\text{test}}(f_{\theta_i'}) \cdot \frac{\partial \theta_i'}{\partial \theta}\]

For one inner step:

\[\frac{\partial \theta_i'}{\partial \theta} = I - \alpha \frac{\partial}{\partial \theta} \nabla_\theta \mathcal{L}^{\text{train}} = I - \alpha H_{\mathcal{L}^{\text{train}}}\]

where \(H\) is the Hessian matrix of the loss w.r.t. \(\theta\). This is computationally expensive!

First-Order MAML (FOMAML)ΒΆ

Approximate by ignoring second-order terms:

\[\nabla_\theta \mathcal{L}^{\text{test}}(f_{\theta_i'}) \approx \nabla_{\theta_i'} \mathcal{L}^{\text{test}}(f_{\theta_i'})\]

i.e., treat \(\theta_i'\) as independent of \(\theta\) for gradient computation. FOMAML is much faster with minimal performance loss in practice.

2. Prototypical NetworksΒΆ

Metric Learning FrameworkΒΆ

Instead of learning task-specific parameters, learn an embedding function \(f_\phi: \mathcal{X} \to \mathbb{R}^d\) that maps inputs to a metric space where classification uses distance to class prototypes.

Class PrototypesΒΆ

For N-way K-shot classification, compute prototype for class \(k\) as:

\[c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i)\]

where \(S_k\) is the support set for class \(k\) (K examples).

Classification RuleΒΆ

Probability that query \(x\) belongs to class \(k\):

\[p(y = k | x) = \frac{\exp(-d(f_\phi(x), c_k))}{\sum_{k'} \exp(-d(f_\phi(x), c_{k'}))}\]

where \(d(\cdot, \cdot)\) is a distance metric (typically Euclidean distance or cosine similarity).

Training ObjectiveΒΆ

Minimize negative log-likelihood:

\[\mathcal{L}(\phi) = -\log p(y | x) = -\log \frac{\exp(-d(f_\phi(x), c_y))}{\sum_{k} \exp(-d(f_\phi(x), c_k))}\]

Episode-based training: Each training episode samples N classes, K support examples per class, and Q query examples.

Why It WorksΒΆ

  • Inductive bias: Similar examples should have similar embeddings

  • Non-parametric: No task-specific parameters (prototypes computed from data)

  • Fast adaptation: Single forward pass, no gradient steps

  • Scalable: Works with large N (many classes)

3. Matching NetworksΒΆ

Attention-Based MatchingΒΆ

Matching Networks classify by comparing query to all support examples using attention:

\[\hat{y} = \sum_{i=1}^{|S|} a(x, x_i) y_i\]

where \(a(x, x_i)\) is the attention weight from query \(x\) to support example \(x_i\):

\[a(x, x_i) = \frac{\exp(c(f(x), g(x_i)))}{\sum_{j} \exp(c(f(x), g(x_j)))}\]

with \(c(\cdot, \cdot)\) being cosine similarity.

Full Context EmbeddingsΒΆ

Unlike Prototypical Networks, embeddings use full context of the support set:

  • Support encoding \(g\): Bidirectional LSTM over support set

  • Query encoding \(f\): LSTM with read attention to support set

This allows embeddings to be task-dependent.

One-Shot Learning FormulationΒΆ

For one-shot (K=1), Matching Networks directly weight support labels:

\[P(y | x, S) = \sum_{(x_i, y_i) \in S} a(x, x_i) y_i\]

This is a differentiable nearest neighbors approach.

4. Few-Shot Learning TheoryΒΆ

N-Way K-Shot ProblemΒΆ

Problem setup:

  • N-way: Classify among N classes

  • K-shot: Only K labeled examples per class in support set

  • Task distribution: \(p(\mathcal{T})\) generates diverse tasks

Episode construction:

  1. Sample N classes from dataset

  2. Sample K examples per class β†’ Support set \(S\) (NΓ—K examples)

  3. Sample Q examples per class β†’ Query set \(Q\) (NΓ—Q examples)

Meta-Training vs Meta-TestingΒΆ

Meta-training:

  • Tasks sampled from training classes \(\mathcal{C}_{\text{train}}\)

  • Learn \(\phi\) or \(\theta\) to perform well on query sets after adapting to support sets

Meta-testing:

  • Tasks sampled from disjoint test classes \(\mathcal{C}_{\text{test}}\)

  • Evaluate few-shot performance on unseen classes

Key: \(\mathcal{C}_{\text{train}} \cap \mathcal{C}_{\text{test}} = \emptyset\) (no class overlap)

Generalization in Meta-LearningΒΆ

Meta-learning generalizes across tasks rather than across examples:

\[\text{Expected risk:} \quad \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})} \left[ R_{\mathcal{T}}(h_{\mathcal{T}}) \right]\]

where \(h_{\mathcal{T}}\) is the task-specific hypothesis (e.g., adapted parameters or prototypes).

Challenge: Need sufficient task diversity in \(p(\mathcal{T})\) to generalize to new tasks.

5. Comparison of Meta-Learning ApproachesΒΆ

Approach

Adaptation

Parameters

Computation

Strengths

Weaknesses

MAML

Gradient descent (inner loop)

Task-specific \(\theta'\)

High (second-order gradients)

Model-agnostic, flexible

Slow adaptation, expensive

FOMAML

Gradient descent (1st order)

Task-specific \(\theta'\)

Medium

Faster than MAML

Slightly lower performance

Prototypical

Non-parametric (prototypes)

None (prototypes from data)

Low (single forward pass)

Fast, simple, scalable to many classes

Assumes Euclidean structure

Matching

Attention over support

None (attention weights)

Medium (bi-LSTM encoding)

Full context, differentiable NN

Complex encoding, slower

Relation Net

Learn comparison metric

Task-specific relation module

Medium

Flexible metric

Requires meta-training

When to Use Each:ΒΆ

  • MAML/FOMAML: When you need flexible adaptation with gradient-based learning; good for RL and diverse tasks

  • Prototypical Networks: When classes have clear cluster structure; best for classification with many classes

  • Matching Networks: When task context is important; one-shot learning scenarios

  • Hybrid: Combine approaches (e.g., Prototypical MAML)

6. Advanced InsightsΒΆ

Sample ComplexityΒΆ

Few-shot learning aims to achieve high accuracy with limited data:

  • Traditional: \(O(VC \text{ dim} / \epsilon^2)\) samples needed for generalization

  • Meta-learning: Amortizes learning across tasks, reducing per-task sample complexity

Trade-off: More meta-training tasks β†’ better few-shot performance per task

Task DiversityΒΆ

Performance depends critically on task distribution \(p(\mathcal{T})\):

  • High diversity: Better generalization to new tasks

  • Task relatedness: Meta-learning assumes tasks share structure

Theorem (informal): If \(\mathcal{T}_{\text{test}}\) is far from \(\mathcal{T}_{\text{train}}\) in task space, meta-learning provides no benefit over learning from scratch.

Computational ComplexityΒΆ

For N-way K-shot with embedding dimension \(d\):

Method

Time per episode

Memory

MAML

$O(T \cdot N \cdot K \cdot

\theta

FOMAML

$O(T \cdot N \cdot K \cdot

\theta

Prototypical

\(O(N \cdot K \cdot d + N \cdot Q \cdot d)\)

\(O(N \cdot d)\) (prototypes)

Matching

\(O(N \cdot K \cdot d^2)\) (bi-LSTM)

\(O(N \cdot K \cdot d)\)

\(T\) = inner loop steps, \(|\theta|\) = number of parameters, \(Q\) = query set size

7. Practical ConsiderationsΒΆ

Dataset DesignΒΆ

Common benchmarks:

  • Omniglot: 1623 characters, 20 examples each (character recognition)

  • Mini-ImageNet: 100 classes, 600 images each (5-way 1-shot / 5-shot)

  • Tiered-ImageNet: 608 classes, hierarchical structure

Episode sampling: Ensure balanced classes and sufficient query examples

Hyperparameter TuningΒΆ

MAML:

  • Inner learning rate \(\alpha\): 0.01 - 0.1 (task-specific)

  • Outer learning rate \(\beta\): 0.001 - 0.01 (meta-update)

  • Inner steps \(T\): 1-5 (more steps β†’ better adaptation but slower)

Prototypical:

  • Embedding dimension \(d\): 64-512 (trade-off capacity vs. overfitting)

  • Distance metric: Euclidean (classification), cosine (semantic tasks)

OverfittingΒΆ

Symptoms:

  • High meta-training accuracy, low meta-test accuracy

  • Model memorizes support sets rather than learning to adapt

Solutions:

  • Increase task diversity (data augmentation, more classes)

  • Regularization (dropout, weight decay)

  • Early stopping on meta-validation set

PDF ReferenceΒΆ

πŸ“„ Related Papers:

πŸ“‚ GitHub:

# ============================================================================
# Advanced Meta-Learning Implementations
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import copy

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ============================================================================
# 1. Complete MAML with Second-Order Gradients
# ============================================================================

class MAMLConvNet(nn.Module):
    """
    4-layer ConvNet for few-shot image classification.
    Used in Omniglot and Mini-ImageNet experiments.
    """
    def __init__(self, in_channels=1, num_classes=5, hidden_dim=64):
        super().__init__()
        self.features = nn.Sequential(
            OrderedDict([
                ('conv1', nn.Conv2d(in_channels, hidden_dim, 3, padding=1)),
                ('bn1', nn.BatchNorm2d(hidden_dim)),
                ('relu1', nn.ReLU(inplace=True)),
                ('pool1', nn.MaxPool2d(2)),
                
                ('conv2', nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)),
                ('bn2', nn.BatchNorm2d(hidden_dim)),
                ('relu2', nn.ReLU(inplace=True)),
                ('pool2', nn.MaxPool2d(2)),
                
                ('conv3', nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)),
                ('bn3', nn.BatchNorm2d(hidden_dim)),
                ('relu3', nn.ReLU(inplace=True)),
                ('pool3', nn.MaxPool2d(2)),
                
                ('conv4', nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)),
                ('bn4', nn.BatchNorm2d(hidden_dim)),
                ('relu4', nn.ReLU(inplace=True)),
                ('pool4', nn.MaxPool2d(2))
            ])
        )
        
        self.classifier = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten
        return self.classifier(x)


class MAML:
    """
    Full MAML implementation with second-order gradients.
    
    Key features:
    - Bi-level optimization (inner loop + outer loop)
    - Second-order gradient computation through inner loop
    - Support for multiple inner gradient steps
    - Episode-based meta-training
    """
    def __init__(self, model, inner_lr=0.01, meta_lr=0.001, 
                 inner_steps=5, first_order=False):
        self.model = model
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.inner_steps = inner_steps
        self.first_order = first_order  # Use FOMAML if True
        
        # Meta-optimizer (for outer loop)
        self.meta_optimizer = torch.optim.Adam(self.model.parameters(), lr=meta_lr)
    
    def inner_loop(self, x_support, y_support, params=None):
        """
        Task adaptation via gradient descent (inner loop).
        
        Args:
            x_support: Support set inputs [K, ...]
            y_support: Support set labels [K]
            params: Current model parameters (None uses model.parameters())
        
        Returns:
            adapted_params: Parameters after inner loop adaptation
        """
        if params is None:
            params = OrderedDict(self.model.named_parameters())
        
        for step in range(self.inner_steps):
            # Forward pass with current params
            logits = self._forward_with_params(x_support, params)
            loss = F.cross_entropy(logits, y_support)
            
            # Compute gradients w.r.t. params
            # create_graph=True enables second-order gradients
            grads = torch.autograd.grad(
                loss, params.values(), 
                create_graph=not self.first_order
            )
            
            # Inner loop update: ΞΈ' = ΞΈ - Ξ±βˆ‡L
            params = OrderedDict(
                (name, param - self.inner_lr * grad)
                for ((name, param), grad) in zip(params.items(), grads)
            )
        
        return params
    
    def _forward_with_params(self, x, params):
        """Forward pass using specific parameters (for inner loop)."""
        # Simple implementation for linear layers
        # For production, use functional API or hooks
        x = x.view(x.size(0), -1)  # Flatten
        for name, param in params.items():
            if 'weight' in name and 'fc' not in name:
                continue
            if 'fc.weight' in name:
                x = F.linear(x, param)
            elif 'fc.bias' in name:
                x = x + param
        return x
    
    def outer_loop(self, tasks):
        """
        Meta-update across multiple tasks (outer loop).
        
        Args:
            tasks: List of (x_support, y_support, x_query, y_query) tuples
        
        Returns:
            meta_loss: Average query loss across tasks
        """
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0
        
        for x_support, y_support, x_query, y_query in tasks:
            # Inner loop: adapt to support set
            adapted_params = self.inner_loop(x_support, y_support)
            
            # Evaluate on query set with adapted params
            logits_query = self._forward_with_params(x_query, adapted_params)
            loss_query = F.cross_entropy(logits_query, y_query)
            
            meta_loss += loss_query
        
        # Average over tasks
        meta_loss = meta_loss / len(tasks)
        
        # Outer loop update: ΞΈ ← ΞΈ - Ξ²βˆ‡_ΞΈ L_query(ΞΈ')
        # This computes gradients through the inner loop (second-order)
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item()


# ============================================================================
# 2. Prototypical Networks
# ============================================================================

class PrototypicalNetwork(nn.Module):
    """
    Prototypical Networks for few-shot classification.
    
    Learns an embedding function f_Ο† that maps inputs to a metric space.
    Classification uses distance to class prototypes computed from support set.
    """
    def __init__(self, in_channels=1, embedding_dim=64):
        super().__init__()
        
        # Embedding network (same architecture as MAML encoder)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            
            nn.Conv2d(64, embedding_dim, 3, padding=1),
            nn.BatchNorm2d(embedding_dim),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
    
    def forward(self, x):
        """Compute embeddings."""
        embeddings = self.encoder(x)
        return embeddings.view(embeddings.size(0), -1)
    
    def compute_prototypes(self, x_support, y_support, n_way):
        """
        Compute class prototypes from support set.
        
        Prototype for class k: c_k = (1/|S_k|) Ξ£ f_Ο†(x_i) for (x_i, y_i) ∈ S_k
        
        Args:
            x_support: Support set inputs [N*K, C, H, W]
            y_support: Support set labels [N*K]
            n_way: Number of classes
        
        Returns:
            prototypes: Class prototypes [N, embedding_dim]
        """
        embeddings = self(x_support)  # [N*K, embedding_dim]
        
        prototypes = []
        for k in range(n_way):
            # Find all examples of class k
            class_mask = (y_support == k)
            class_embeddings = embeddings[class_mask]
            
            # Compute mean (prototype)
            prototype = class_embeddings.mean(dim=0)
            prototypes.append(prototype)
        
        return torch.stack(prototypes)  # [N, embedding_dim]
    
    def classify(self, x_query, prototypes, distance='euclidean'):
        """
        Classify queries using distance to prototypes.
        
        p(y = k | x) = exp(-d(f(x), c_k)) / Ξ£_k' exp(-d(f(x), c_k'))
        
        Args:
            x_query: Query inputs [Q, C, H, W]
            prototypes: Class prototypes [N, embedding_dim]
            distance: 'euclidean' or 'cosine'
        
        Returns:
            logits: Class logits [Q, N]
        """
        query_embeddings = self(x_query)  # [Q, embedding_dim]
        
        if distance == 'euclidean':
            # Negative squared Euclidean distance
            # -||f(x) - c_k||^2 = -||f(x)||^2 - ||c_k||^2 + 2<f(x), c_k>
            dists = torch.cdist(query_embeddings, prototypes, p=2)  # [Q, N]
            logits = -dists ** 2
        elif distance == 'cosine':
            # Cosine similarity
            query_norm = F.normalize(query_embeddings, p=2, dim=1)
            proto_norm = F.normalize(prototypes, p=2, dim=1)
            logits = query_norm @ proto_norm.t()  # [Q, N]
        else:
            raise ValueError(f"Unknown distance: {distance}")
        
        return logits
    
    def loss(self, x_support, y_support, x_query, y_query, n_way, distance='euclidean'):
        """Compute prototypical loss for an episode."""
        # Compute prototypes from support set
        prototypes = self.compute_prototypes(x_support, y_support, n_way)
        
        # Classify query set
        logits = self.classify(x_query, prototypes, distance)
        
        # Cross-entropy loss
        return F.cross_entropy(logits, y_query)


# ============================================================================
# 3. Matching Networks (Simplified)
# ============================================================================

class MatchingNetwork(nn.Module):
    """
    Matching Networks for one-shot learning.
    
    Uses attention over support set to classify queries.
    Simplified version without full context embeddings (bi-LSTM).
    """
    def __init__(self, in_channels=1, embedding_dim=64):
        super().__init__()
        
        # Shared embedding network
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            
            nn.Conv2d(64, embedding_dim, 3, padding=1),
            nn.BatchNorm2d(embedding_dim),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
    
    def forward(self, x):
        """Compute embeddings."""
        embeddings = self.encoder(x)
        return embeddings.view(embeddings.size(0), -1)
    
    def attention(self, query_embedding, support_embeddings):
        """
        Compute attention weights from query to support set.
        
        a(x, x_i) = exp(cosine(f(x), g(x_i))) / Ξ£_j exp(cosine(f(x), g(x_j)))
        
        Args:
            query_embedding: [embedding_dim]
            support_embeddings: [K, embedding_dim]
        
        Returns:
            attention_weights: [K]
        """
        # Cosine similarity
        query_norm = F.normalize(query_embedding, p=2, dim=0)
        support_norm = F.normalize(support_embeddings, p=2, dim=1)
        
        similarities = support_norm @ query_norm  # [K]
        
        # Softmax to get attention weights
        attention_weights = F.softmax(similarities, dim=0)
        
        return attention_weights
    
    def predict(self, x_query, x_support, y_support, n_way):
        """
        Predict using attention-weighted support labels.
        
        Ε· = Ξ£_i a(x, x_i) y_i
        
        Args:
            x_query: Query input [C, H, W]
            x_support: Support set inputs [N*K, C, H, W]
            y_support: Support set labels [N*K]
            n_way: Number of classes
        
        Returns:
            logits: Class probabilities [n_way]
        """
        query_embedding = self(x_query.unsqueeze(0)).squeeze(0)  # [embedding_dim]
        support_embeddings = self(x_support)  # [N*K, embedding_dim]
        
        # Compute attention weights
        attention_weights = self.attention(query_embedding, support_embeddings)  # [N*K]
        
        # Weighted sum over one-hot labels
        y_one_hot = F.one_hot(y_support, num_classes=n_way).float()  # [N*K, n_way]
        logits = (attention_weights.unsqueeze(1) * y_one_hot).sum(dim=0)  # [n_way]
        
        return logits
    
    def loss(self, x_support, y_support, x_query, y_query, n_way):
        """Compute matching loss for an episode."""
        batch_logits = []
        
        for i in range(x_query.size(0)):
            logits = self.predict(x_query[i], x_support, y_support, n_way)
            batch_logits.append(logits)
        
        batch_logits = torch.stack(batch_logits)  # [Q, n_way]
        
        return F.cross_entropy(batch_logits, y_query)


# ============================================================================
# 4. Few-Shot Episode Sampler
# ============================================================================

class FewShotEpisode:
    """
    Sample N-way K-shot episodes for meta-learning.
    
    Each episode contains:
    - Support set: N classes Γ— K examples
    - Query set: N classes Γ— Q examples
    """
    def __init__(self, images, labels, n_way=5, k_shot=1, q_query=15):
        """
        Args:
            images: All available images [num_samples, C, H, W]
            labels: Corresponding labels [num_samples]
            n_way: Number of classes per episode
            k_shot: Number of support examples per class
            q_query: Number of query examples per class
        """
        self.images = images
        self.labels = labels
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        
        # Group images by class
        self.classes = torch.unique(labels)
        self.class_to_indices = {
            c.item(): (labels == c).nonzero(as_tuple=True)[0]
            for c in self.classes
        }
    
    def sample_episode(self):
        """
        Sample a single N-way K-shot episode.
        
        Returns:
            x_support: [N*K, C, H, W]
            y_support: [N*K] (relabeled 0 to N-1)
            x_query: [N*Q, C, H, W]
            y_query: [N*Q] (relabeled 0 to N-1)
        """
        # Sample N classes
        episode_classes = np.random.choice(
            len(self.classes), self.n_way, replace=False
        )
        
        support_images, support_labels = [], []
        query_images, query_labels = [], []
        
        for new_label, class_idx in enumerate(episode_classes):
            class_label = self.classes[class_idx]
            indices = self.class_to_indices[class_label.item()]
            
            # Sample K+Q examples
            selected = indices[torch.randperm(len(indices))[:self.k_shot + self.q_query]]
            
            # Split into support and query
            support_indices = selected[:self.k_shot]
            query_indices = selected[self.k_shot:]
            
            support_images.append(self.images[support_indices])
            support_labels.extend([new_label] * self.k_shot)
            
            query_images.append(self.images[query_indices])
            query_labels.extend([new_label] * self.q_query)
        
        # Concatenate
        x_support = torch.cat(support_images, dim=0)
        y_support = torch.tensor(support_labels, dtype=torch.long)
        x_query = torch.cat(query_images, dim=0)
        y_query = torch.tensor(query_labels, dtype=torch.long)
        
        return x_support, y_support, x_query, y_query


# ============================================================================
# 5. Visualization: Meta-Learning Comparison
# ============================================================================

def visualize_embeddings(model, x_support, y_support, x_query, y_query, n_way, title):
    """
    Visualize 2D embeddings with prototypes.
    (Assumes embedding_dim = 2 or uses PCA)
    """
    from sklearn.decomposition import PCA
    
    model.eval()
    with torch.no_grad():
        # Get embeddings
        support_emb = model(x_support).cpu().numpy()
        query_emb = model(x_query).cpu().numpy()
    
    # Apply PCA if embedding_dim > 2
    if support_emb.shape[1] > 2:
        pca = PCA(n_components=2)
        support_emb = pca.fit_transform(support_emb)
        query_emb = pca.transform(query_emb)
    
    # Compute prototypes
    prototypes = []
    for k in range(n_way):
        mask = (y_support.cpu().numpy() == k)
        proto = support_emb[mask].mean(axis=0)
        prototypes.append(proto)
    prototypes = np.array(prototypes)
    
    # Plot
    plt.figure(figsize=(10, 8))
    colors = plt.cm.rainbow(np.linspace(0, 1, n_way))
    
    for k in range(n_way):
        # Support examples
        support_mask = (y_support.cpu().numpy() == k)
        plt.scatter(support_emb[support_mask, 0], support_emb[support_mask, 1], 
                   c=[colors[k]], marker='s', s=100, 
                   label=f'Support Class {k}', edgecolors='black', linewidth=1.5)
        
        # Query examples
        query_mask = (y_query.cpu().numpy() == k)
        plt.scatter(query_emb[query_mask, 0], query_emb[query_mask, 1], 
                   c=[colors[k]], marker='o', s=60, alpha=0.6)
        
        # Prototype
        plt.scatter(prototypes[k, 0], prototypes[k, 1], 
                   c=[colors[k]], marker='*', s=500, 
                   edgecolors='black', linewidth=2, zorder=10)
    
    plt.xlabel('Embedding Dimension 1', fontsize=12)
    plt.ylabel('Embedding Dimension 2', fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.legend(loc='best', fontsize=9)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


# Test prototypical network with synthetic data
print("\n" + "="*80)
print("Prototypical Network Example (Synthetic Data)")
print("="*80)

# Create synthetic 5-way 5-shot data (28x28 images)
torch.manual_seed(42)
n_way, k_shot, q_query = 5, 5, 15
image_size = 28

# Generate synthetic images (different patterns per class)
all_images, all_labels = [], []
for class_id in range(10):  # 10 total classes
    # Create class-specific pattern
    base_pattern = torch.randn(1, 1, image_size, image_size)
    images = base_pattern + 0.1 * torch.randn(100, 1, image_size, image_size)
    labels = torch.full((100,), class_id, dtype=torch.long)
    all_images.append(images)
    all_labels.append(labels)

all_images = torch.cat(all_images, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Create episode sampler
episode_sampler = FewShotEpisode(all_images, all_labels, n_way, k_shot, q_query)

# Sample one episode
x_support, y_support, x_query, y_query = episode_sampler.sample_episode()

print(f"Episode shapes:")
print(f"  Support: {x_support.shape}, labels: {y_support.shape}")
print(f"  Query: {x_query.shape}, labels: {y_query.shape}")

# Train prototypical network
proto_net = PrototypicalNetwork(in_channels=1, embedding_dim=64).to(device)
optimizer = torch.optim.Adam(proto_net.parameters(), lr=0.001)

print("\nTraining Prototypical Network...")
losses = []
for episode in range(100):
    x_sup, y_sup, x_que, y_que = episode_sampler.sample_episode()
    x_sup, y_sup = x_sup.to(device), y_sup.to(device)
    x_que, y_que = x_que.to(device), y_que.to(device)
    
    optimizer.zero_grad()
    loss = proto_net.loss(x_sup, y_sup, x_que, y_que, n_way)
    loss.backward()
    optimizer.step()
    
    losses.append(loss.item())
    
    if (episode + 1) % 20 == 0:
        print(f"Episode {episode+1}/100, Loss: {loss.item():.4f}")

# Plot training curve
plt.figure(figsize=(10, 5))
plt.plot(losses, linewidth=2)
plt.xlabel('Episode', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Prototypical Network Training', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Evaluate on test episode
proto_net.eval()
x_sup, y_sup, x_que, y_que = episode_sampler.sample_episode()
x_sup, y_sup = x_sup.to(device), y_sup.to(device)
x_que, y_que = x_que.to(device), y_que.to(device)

with torch.no_grad():
    prototypes = proto_net.compute_prototypes(x_sup, y_sup, n_way)
    logits = proto_net.classify(x_que, prototypes)
    predictions = logits.argmax(dim=1)
    accuracy = (predictions == y_que).float().mean().item()

print(f"\nTest Episode Accuracy: {accuracy*100:.2f}%")
print(f"Random baseline: {100/n_way:.2f}%")

print("\n" + "="*80)
print("Implementation Complete!")
print("="*80)
print("\nKey Insights:")
print("1. MAML learns an initialization that adapts quickly via gradient descent")
print("2. Prototypical Networks classify using distance to class prototypes")
print("3. Matching Networks use attention over support set (differentiable NN)")
print("4. Episode-based training is crucial for meta-learning generalization")
print("5. Trade-offs: MAML (flexible but slow) vs Prototypical (fast but assumes metric space)")
print("\nNext: Apply to real datasets (Omniglot, Mini-ImageNet) for stronger results!")

Advanced Meta-Learning and MAML TheoryΒΆ

1. Introduction to Meta-LearningΒΆ

Meta-learning (learning to learn) aims to design models that can:

  • Quickly adapt to new tasks with few examples

  • Leverage prior experience across tasks

  • Generalize to unseen task distributions

1.1 Problem FormulationΒΆ

Given task distribution p(T), where each task T = {D_train, D_test}:

  • D_train: Support set (few examples for adaptation)

  • D_test: Query set (evaluation)

Goal: Learn ΞΈ such that model adapted on D_train generalizes to D_test.

1.2 Meta-Learning ApproachesΒΆ

  1. Metric-based: Learn embedding space (Prototypical Networks, Matching Networks, Relation Networks)

  2. Model-based: Learn update rules via RNNs or memory (Meta-Networks, SNAIL)

  3. Optimization-based: Learn good initialization for gradient descent (MAML, Reptile)

2. Model-Agnostic Meta-Learning (MAML)ΒΆ

MAML [Finn et al., 2017] learns initialization ΞΈ that enables fast adaptation via few gradient steps.

2.1 AlgorithmΒΆ

Bi-level optimization:

Outer loop (meta-update):

ΞΈ ← ΞΈ - Ξ² βˆ‡_ΞΈ Ξ£_i L_{T_i}(ΞΈ'_i)

Inner loop (task-specific adaptation):

ΞΈ'_i = ΞΈ - Ξ± βˆ‡_ΞΈ L_{T_i}(ΞΈ)

Where:

  • ΞΈ: Meta-parameters (initialization)

  • θ’_i: Task-specific parameters after adaptation

  • Ξ±: Inner learning rate (adaptation)

  • Ξ²: Outer learning rate (meta-learning)

  • L_{T_i}: Loss on task T_i

2.2 Computational GraphΒΆ

ΞΈ β†’ [Inner gradient] β†’ ΞΈ'₁, ΞΈ'β‚‚, ..., ΞΈ'_N β†’ [Evaluate on query] β†’ Meta-loss
                                                                        ↓
                                                              [Outer gradient] β†’ Update ΞΈ

Key insight: Gradients flow through inner optimization! (Second-order)

2.3 MAML ObjectiveΒΆ

Meta-objective:

min_ΞΈ Ξ£_{T_i ~ p(T)} L_{T_i}(U_Ξ±(ΞΈ, D_train^i), D_test^i)

Where U_Ξ± is the adaptation operator (one or more gradient steps).

2.4 First-Order MAML (FOMAML)ΒΆ

Challenge: Computing second-order derivatives is expensive.

FOMAML: Ignore second-order terms:

βˆ‡_ΞΈ L_{T_i}(ΞΈ'_i) β‰ˆ βˆ‡_{ΞΈ'_i} L_{T_i}(ΞΈ'_i)

Treats θ’_i as constant w.r.t. ΞΈ. Much faster, surprisingly effective.

3. Mathematical FoundationsΒΆ

3.1 Taylor Expansion InterpretationΒΆ

After one inner step:

ΞΈ' = ΞΈ - Ξ± βˆ‡_ΞΈ L_train(ΞΈ)

Loss on query set:

L_test(ΞΈ') β‰ˆ L_test(ΞΈ) - Ξ± βˆ‡_ΞΈ L_test(ΞΈ)^T βˆ‡_ΞΈ L_train(ΞΈ)

MAML optimizes for: Alignment between train and test gradients!

3.2 Gradient of Meta-LossΒΆ

Meta-gradient:

βˆ‡_ΞΈ L_test(ΞΈ') = βˆ‡_{ΞΈ'} L_test(ΞΈ') Β· βˆ‡_ΞΈ ΞΈ'

Where:

βˆ‡_ΞΈ ΞΈ' = I - Ξ± βˆ‡Β²_ΞΈ L_train(ΞΈ)

Hessian term H = βˆ‡Β²_ΞΈ L_train(ΞΈ) captures second-order effects.

Full gradient:

βˆ‡_ΞΈ L_test(ΞΈ') = βˆ‡_{ΞΈ'} L_test(ΞΈ') Β· (I - Ξ± H)

FOMAML approximation: βˆ‡_ΞΈ θ’ β‰ˆ I (ignore Hessian).

3.3 Implicit DifferentiationΒΆ

Alternative to backpropagating through inner loop:

At convergence of inner optimization ΞΈ* = argmin_θ’ L_train(θ’):

βˆ‡_{ΞΈ*} L_train(ΞΈ*) = 0

Implicit function theorem:

βˆ‡_ΞΈ ΞΈ* = -(βˆ‡Β²_{ΞΈ*} L_train)^{-1} βˆ‡_{ΞΈ*} βˆ‡_ΞΈ L_train

Used in iMAML [Rajeswaran et al., 2019].

4. Variants and ExtensionsΒΆ

4.1 Reptile [Nichol et al., 2018]ΒΆ

Simpler: Just move toward adapted parameters.

ΞΈ ← ΞΈ + Ξ² (ΞΈ' - ΞΈ)

Where θ’ is result of K inner steps. No meta-gradient computation!

Connection to MAML: Reptile β‰ˆ MAML + averaging over all inner steps.

4.2 MAML++ [Antoniou et al., 2019]ΒΆ

Improvements:

  1. Multi-step loss: Use loss from all inner steps, not just final

  2. Per-parameter learning rates: Ξ± per layer

  3. Learn learning rates: Ξ±, Ξ² as parameters

  4. Batch normalization: Use running stats from support set

Meta-loss:

L_meta = Ξ£_k w_k L_test(ΞΈ_k)

Where ΞΈ_k is parameters after k inner steps, w_k are weights.

4.3 Meta-SGD [Li et al., 2017]ΒΆ

Learn both initialization ΞΈ and learning rates Ξ±:

ΞΈ'_i = ΞΈ - Ξ± βŠ™ βˆ‡_ΞΈ L_{T_i}(ΞΈ)

Meta-parameters: {ΞΈ, Ξ±} (both updated in outer loop).

4.4 ANIL (Almost No Inner Loop) [Raghu et al., 2020]ΒΆ

Finding: Only adapting final layer often suffices!

Freezes feature extractor during inner loop:

  • Features: Ο†(x; ΞΈ_features) [frozen]

  • Head: h(Ο†; ΞΈ_head) [adapted]

Faster, similar performance on many tasks.

4.5 Meta-Curvature [Park & Oliva, 2019]ΒΆ

Incorporate curvature information:

ΞΈ' = ΞΈ - Ξ± (H + Ξ»I)^{-1} βˆ‡_ΞΈ L_train(ΞΈ)

Where H is Hessian approximation (e.g., Fisher information).

5. Task-Conditional ArchitecturesΒΆ

5.1 CAVIA [Zintgraf et al., 2019]ΒΆ

Context adaptation via context parameters Ο†:

f(x; ΞΈ, Ο†)

Inner loop: Adapt only Ο† (low-dimensional) Outer loop: Update ΞΈ

Reduces inner loop computation significantly.

5.2 Conditional Neural Processes (CNPs)ΒΆ

Amortized inference via encoder-decoder:

Encoder: Context set (x_c, y_c) β†’ representation r Decoder: (r, x_query) β†’ y_query

Training: Sample context/target splits from tasks.

Advantage: Single forward pass at test (no inner loop).

6. Few-Shot Learning ApplicationsΒΆ

6.1 N-way K-shot ClassificationΒΆ

Task: Classify into N classes with K examples each.

Episode construction:

  • Sample N classes from dataset

  • Sample K support + Q query examples per class

  • Train to classify query given support

MAML approach:

  1. Inner loop: Adapt on support set

  2. Outer loop: Evaluate on query set, update ΞΈ

6.2 Few-Shot RegressionΒΆ

Task distribution: Functions f ~ p(f)

Example: Sine wave regression

  • Sample amplitude A, phase Ο†

  • Support: Few (x, f(x)) pairs

  • Query: Predict f(x) at new x

MAML learns: Good initialization for function fitting.

6.3 Reinforcement LearningΒΆ

Task: Different reward functions or environments

Inner loop: Adapt policy Ο€_ΞΈ to task via RL algorithm Outer loop: Meta-update ΞΈ for fast adaptation

Applications: Robot locomotion (varying terrains), manipulation (different objects).

7. Theoretical AnalysisΒΆ

7.1 Generalization BoundΒΆ

For MAML, PAC-Bayes bound:

With probability β‰₯ 1-Ξ΄:

E_{T~p(T)} [L_test(ΞΈ_T)] ≀ E_{T~p(T)} [L_train(ΞΈ_T)] + O(√(KL(P_ΞΈ || P_prior) / N))

Where:

  • P_ΞΈ: Distribution over task-specific parameters

  • P_prior: Prior distribution

  • N: Number of tasks

Insight: More tasks β†’ better meta-learning generalization.

7.2 Convergence RateΒΆ

Under smoothness assumptions, MAML converges at rate:

O(1/√T)

for T meta-iterations (same as SGD).

MAML++: Improved constant factors via multi-step loss.

7.3 ExpressivenessΒΆ

Theorem [Finn & Levine, 2018]: MAML can represent any learning algorithm that:

  1. Uses gradient descent for adaptation

  2. Has bounded Hessian

Limitation: Fixed number of inner steps limits expressiveness.

8. Practical ConsiderationsΒΆ

8.1 HyperparametersΒΆ

Critical choices:

  • Inner steps K: 1-10 (more = better adaptation, slower meta-training)

  • Inner LR Ξ±: 0.01-0.1 (task-dependent)

  • Outer LR Ξ²: 0.001-0.01

  • Batch size: Number of tasks per meta-update (4-32)

Tuning: Grid search or learning Ξ±, Ξ².

8.2 StabilityΒΆ

Issue: Second-order gradients can explode/vanish.

Solutions:

  1. Gradient clipping

  2. Layer normalization in network

  3. Lower outer learning rate

  4. Use FOMAML (more stable)

8.3 Computational CostΒΆ

MAML: O(K Β· |ΞΈ|Β²) per task (Hessian computation) FOMAML: O(K Β· |ΞΈ|) per task Reptile: O(K Β· |ΞΈ|) per task (no backprop through inner loop)

For 1M parameters, K=5 steps:

  • MAML: ~10Γ— slower than standard training

  • FOMAML: ~2Γ— slower

8.4 MemoryΒΆ

Challenge: Store computational graph for K inner steps.

Solution: Checkpointing (trade computation for memory).

PyTorch example:

from torch.utils.checkpoint import checkpoint

def inner_loop(ΞΈ, data):
    return checkpoint(adaptation_function, ΞΈ, data)

9. Comparison with Other ApproachesΒΆ

9.1 vs. Transfer LearningΒΆ

Aspect

Transfer Learning

Meta-Learning

Goal

Adapt to one target task

Adapt to many tasks quickly

Training

Pre-train + fine-tune

Learn across tasks

Adaptation

Many examples (1000s)

Few examples (1-10)

Optimization

Single-level

Bi-level

9.2 vs. Metric LearningΒΆ

Approach

Method

Pros

Cons

MAML

Learn initialization

General, model-agnostic

Slow, second-order

Prototypical

Learn embedding + nearest neighbor

Fast, simple

Fixed comparison metric

Matching Nets

Attention over support set

Fast inference

Less general

Hybrid: MAML for embedding, then metric comparison.

9.3 vs. Multi-Task LearningΒΆ

Multi-task: Shared parameters for all tasks simultaneously.

ΞΈ = argmin Ξ£_i L_i(ΞΈ)

Meta-learning: Learn initialization, then adapt per-task.

ΞΈ = argmin Ξ£_i L_i(ΞΈ - Ξ± βˆ‡_ΞΈ L_i(ΞΈ))

Meta-learning advantage: Better for dissimilar tasks (no negative transfer).

10. Advanced TopicsΒΆ

10.1 Task Distribution ShiftΒΆ

Problem: Test tasks differ from training tasks.

Solutions:

  1. Domain randomization: Diverse training tasks

  2. Meta-regularization: Penalize overfitting to training tasks

  3. Uncertainty estimation: Detect out-of-distribution tasks

10.2 Online Meta-LearningΒΆ

Scenario: Tasks arrive sequentially, no revisiting.

Approach: Update ΞΈ after each task:

ΞΈ_{t+1} = ΞΈ_t - Ξ² βˆ‡_ΞΈ L_{T_t}(ΞΈ'_t)

Challenge: Catastrophic forgetting of earlier tasks.

Solution: Replay buffer of past tasks.

10.3 Hierarchical Meta-LearningΒΆ

Multiple levels of adaptation:

  1. Global meta-parameters ΞΈ_0

  2. Domain-specific ΞΈ_d (e.g., visual vs. audio)

  3. Task-specific ΞΈ_t

Training: Nested bi-level optimization.

10.4 Meta-Learning with Pre-trained ModelsΒΆ

Modern approach: Initialize MAML with pre-trained features (e.g., ImageNet).

Procedure:

  1. Load ΞΈ_pretrained from large-scale pre-training

  2. Meta-train ΞΈ starting from ΞΈ_pretrained

  3. Fine-tune on few-shot tasks

Benefit: Best of both worlds (pre-training + meta-learning).

11. Connections to Other FieldsΒΆ

11.2 Bayesian OptimizationΒΆ

Task: Optimize black-box function with few evaluations.

Meta-BO: Learn GP kernel or acquisition function from related tasks.

11.3 Continual LearningΒΆ

Overlap: Both deal with learning from sequence of tasks.

Difference:

  • Meta-learning: Assumes task distribution, goal is fast adaptation

  • Continual learning: No repeated tasks, goal is avoid forgetting

Synergy: Meta-learned initialization resists catastrophic forgetting.

12. Recent Advances (2020-2024)ΒΆ

12.1 Transformer-Based Meta-LearningΒΆ

In-context learning (GPT-3 style):

  • Concatenate support examples in prompt

  • No gradient-based adaptation!

Relation to MAML: Implicit meta-learning during pre-training.

12.2 Meta-Learning for PromptsΒΆ

Prompt tuning: Learn soft prompts for language models.

Meta-prompt learning: MAML over prompt parameters.

12.3 Bi-level Optimization Beyond MAMLΒΆ

Applications:

  • Hyperparameter optimization: Inner = training, outer = validation loss

  • Data distillation: Inner = train on synthetic, outer = test on real

  • Neural architecture search: Inner = train model, outer = architecture loss

12.4 Federated Meta-LearningΒΆ

Scenario: Meta-learn from decentralized data (privacy-preserving).

Approach:

  1. Local adaptation on each client

  2. Aggregate meta-gradients on server

Challenge: Communication cost, heterogeneous data.

13. Implementation TipsΒΆ

13.1 Debugging MAMLΒΆ

Common issues:

  1. Exploding gradients: Use gradient clipping, lower Ξ± or Ξ²

  2. No improvement: Check inner loop is actually reducing loss

  3. Overfitting: More tasks in meta-train set, data augmentation

  4. Slow convergence: Try MAML++ (multi-step loss), increase batch size

Sanity checks:

  • Adapted model should outperform initialization on support set

  • Meta-loss should decrease over meta-iterations

  • Test on simple task (e.g., sine regression) first

13.2 Efficient ImplementationΒΆ

Batching:

  • Parallelize inner loops across tasks

  • Use vmap (JAX) or functorch (PyTorch) for efficient gradient computation

Mixed precision: Train in FP16 to reduce memory (with gradient scaling).

Distributed: Split tasks across GPUs, aggregate meta-gradients.

13.3 Choosing Inner Steps KΒΆ

Trade-off:

  • K small (1-3): Fast, less adaptation, more meta-learning

  • K large (5-10): Better task performance, slower, risk overfitting

Rule of thumb: K β‰ˆ number of gradient steps needed to see loss decrease on task.

14. Complexity AnalysisΒΆ

14.1 Time ComplexityΒΆ

Per meta-iteration with B tasks, K inner steps, |ΞΈ| parameters:

MAML: O(B Β· K Β· |ΞΈ|Β²) (Hessian-vector products) FOMAML: O(B Β· K Β· |ΞΈ|) Reptile: O(B Β· K Β· |ΞΈ|)

Typical values:

  • B = 4-32 tasks

  • K = 1-10 steps

  • |ΞΈ| = 10⁴-10⁷ parameters

Example: ResNet-18 (11M params), B=8, K=5:

  • MAML: ~60GB memory, ~10s per iteration (GPU)

  • FOMAML: ~15GB memory, ~2s per iteration

14.2 Sample ComplexityΒΆ

Meta-train: Requires N_tasks tasks, each with support + query sets.

Guideline: N_tasks β‰₯ 100-1000 for good generalization.

Few-shot episodes: Generate via sampling from base dataset.

Example (Omniglot):

  • 1623 characters (tasks)

  • 20 examples per character

  • 5-way 1-shot: 5 classes Γ— (1 support + 15 query) = 80 examples per episode

15. LimitationsΒΆ

  1. Computational cost: Second-order gradients expensive

  2. Hyperparameter sensitivity: Ξ±, Ξ², K require tuning

  3. Task diversity: Needs sufficient variation in p(T)

  4. Failure modes: Can collapse to memorization or initialization-only

  5. Non-stationarity: Struggles if task distribution shifts

When NOT to use MAML:

  • Very few total tasks (N < 20): Just multi-task learning

  • Many-shot regime (100+ examples): Standard transfer learning better

  • Tasks too dissimilar: No shared structure to meta-learn

16. Software and LibrariesΒΆ

16.1 ImplementationsΒΆ

PyTorch:

  • learn2learn: High-level MAML API

  • higher: Functional programming for bi-level optimization

  • torchmeta: Datasets + benchmarks

TensorFlow:

  • tensorflow-maml: Official implementation

JAX:

  • Native support for grad(grad(...)) (second-order)

  • Efficient with vmap for batching

16.2 BenchmarksΒΆ

Classification:

  • Omniglot: 1623 handwritten characters (5-way 1-shot)

  • Mini-ImageNet: 100 classes, 600 images each (5-way 5-shot)

  • Tiered-ImageNet: Hierarchical version of ImageNet

Regression:

  • Sinusoid: Sample amplitude, phase, frequency

  • Polynomial: Random coefficients

Reinforcement Learning:

  • MuJoCo: HalfCheetah, Ant with varying dynamics

  • Meta-World: 50 robotic manipulation tasks

17. Key TakeawaysΒΆ

  1. MAML learns initialization for fast adaptation via gradient descent

  2. Bi-level optimization: Inner loop (task adaptation) + outer loop (meta-update)

  3. Second-order: Gradients flow through inner optimization (FOMAML approximates)

  4. Model-agnostic: Works with any gradient-based model

  5. Few-shot learning: Excel with 1-10 examples per task

  6. Trade-offs: Computational cost vs. adaptation speed vs. generalization

  7. Extensions: MAML++, Reptile, Meta-SGD improve on vanilla MAML

  8. Modern use: Combined with pre-training, prompt learning, transformers

Intuition: MAML finds ΞΈ such that gradient descent quickly moves toward good solutions across tasks. It’s learning to learn via gradient descent.

18. Mathematical SummaryΒΆ

MAML objective:

min_ΞΈ E_{T~p(T)} [L_T(ΞΈ - Ξ± βˆ‡_ΞΈ L_T^train(ΞΈ))]

Meta-gradient:

βˆ‡_ΞΈ L_T^test(ΞΈ') = βˆ‡_{ΞΈ'} L_T^test(ΞΈ') Β· (I - Ξ± βˆ‡Β²_ΞΈ L_T^train(ΞΈ))
                                              β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                                                 Hessian term

FOMAML approximation:

βˆ‡_ΞΈ L_T^test(ΞΈ') β‰ˆ βˆ‡_{ΞΈ'} L_T^test(ΞΈ')

Update rule:

ΞΈ ← ΞΈ - Ξ² βˆ‡_ΞΈ Ξ£_{T_i} L_{T_i}^test(ΞΈ'_i)

ReferencesΒΆ

  1. Finn et al. (2017) β€œModel-Agnostic Meta-Learning for Fast Adaptation of Deep Networks”

  2. Nichol et al. (2018) β€œOn First-Order Meta-Learning Algorithms (Reptile)”

  3. Antoniou et al. (2019) β€œHow to Train Your MAML (MAML++)”

  4. Li et al. (2017) β€œMeta-SGD: Learning to Learn Quickly for Few-Shot Learning”

  5. Raghu et al. (2020) β€œRapid Learning or Feature Reuse? (ANIL)”

  6. Rajeswaran et al. (2019) β€œMeta-Learning with Implicit Gradients (iMAML)”

  7. Zintgraf et al. (2019) β€œFast Context Adaptation via Meta-Learning (CAVIA)”

  8. Hospedales et al. (2021) β€œMeta-Learning in Neural Networks: A Survey”

"""
Complete MAML and Meta-Learning Implementations
===============================================
Includes: MAML, FOMAML, Reptile, MAML++, ANIL, few-shot classification,
sine wave regression, metric comparison.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict
import matplotlib.pyplot as plt

# ============================================================================
# 1. Utility Functions for MAML
# ============================================================================

def clone_parameters(model):
    """Create a copy of model parameters."""
    return OrderedDict({
        name: param.clone() 
        for name, param in model.named_parameters()
    })

def set_parameters(model, params):
    """Set model parameters from OrderedDict."""
    for name, param in model.named_parameters():
        param.data = params[name].data

def get_grad_as_tensor(model):
    """Extract gradients as a single tensor."""
    grads = []
    for param in model.parameters():
        if param.grad is not None:
            grads.append(param.grad.view(-1))
        else:
            grads.append(torch.zeros_like(param).view(-1))
    return torch.cat(grads)

# ============================================================================
# 2. Simple Convolutional Model for Few-Shot Learning
# ============================================================================

class ConvBlock(nn.Module):
    """Convolutional block with batch norm and ReLU."""
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels, momentum=1.0)  # No running stats
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2)
    
    def forward(self, x):
        return self.pool(self.relu(self.bn(self.conv(x))))

class SimpleConvNet(nn.Module):
    """4-layer ConvNet for Omniglot/Mini-ImageNet."""
    def __init__(self, input_channels=1, hidden_dim=64, output_dim=5):
        super(SimpleConvNet, self).__init__()
        
        self.features = nn.Sequential(
            ConvBlock(input_channels, hidden_dim),
            ConvBlock(hidden_dim, hidden_dim),
            ConvBlock(hidden_dim, hidden_dim),
            ConvBlock(hidden_dim, hidden_dim),
        )
        
        self.classifier = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        features = self.features(x)
        features = features.view(features.size(0), -1)
        return self.classifier(features)

# ============================================================================
# 3. MAML Algorithm
# ============================================================================

class MAML:
    """
    Model-Agnostic Meta-Learning.
    
    Args:
        model: Neural network model
        inner_lr: Learning rate for inner loop (task adaptation)
        outer_lr: Learning rate for outer loop (meta-update)
        inner_steps: Number of gradient steps in inner loop
        first_order: If True, use FOMAML (no second-order gradients)
    """
    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, 
                 inner_steps=5, first_order=False):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.inner_steps = inner_steps
        self.first_order = first_order
        
        # Outer optimizer
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=outer_lr)
    
    def inner_loop(self, task_data, labels):
        """
        Perform inner loop adaptation on a single task.
        
        Args:
            task_data: Support set data [K*N, ...]
            labels: Support set labels [K*N]
        
        Returns:
            adapted_params: Adapted parameters after K gradient steps
        """
        # Clone current parameters
        adapted_params = clone_parameters(self.model)
        
        # Create computational graph for adaptation
        for step in range(self.inner_steps):
            # Set model to adapted parameters
            set_parameters(self.model, adapted_params)
            
            # Forward pass
            logits = self.model(task_data)
            loss = F.cross_entropy(logits, labels)
            
            # Compute gradients
            grads = torch.autograd.grad(
                loss, 
                self.model.parameters(),
                create_graph=not self.first_order  # Second-order if MAML
            )
            
            # Update adapted parameters
            adapted_params = OrderedDict({
                name: param - self.inner_lr * grad
                for (name, param), grad in zip(adapted_params.items(), grads)
            })
        
        return adapted_params
    
    def meta_update(self, tasks):
        """
        Perform meta-update on a batch of tasks.
        
        Args:
            tasks: List of (support_data, support_labels, query_data, query_labels)
        """
        self.meta_optimizer.zero_grad()
        
        meta_loss = 0.0
        meta_acc = 0.0
        
        for support_x, support_y, query_x, query_y in tasks:
            # Inner loop: Adapt to task
            adapted_params = self.inner_loop(support_x, support_y)
            
            # Set model to adapted parameters
            set_parameters(self.model, adapted_params)
            
            # Evaluate on query set
            query_logits = self.model(query_x)
            task_loss = F.cross_entropy(query_logits, query_y)
            
            # Accumulate meta-loss
            meta_loss += task_loss
            
            # Compute accuracy
            with torch.no_grad():
                pred = query_logits.argmax(dim=1)
                meta_acc += (pred == query_y).float().mean()
        
        # Average over tasks
        meta_loss = meta_loss / len(tasks)
        meta_acc = meta_acc / len(tasks)
        
        # Meta-gradient and update
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item(), meta_acc.item()
    
    def evaluate(self, tasks):
        """Evaluate on validation/test tasks."""
        total_loss = 0.0
        total_acc = 0.0
        
        with torch.no_grad():
            for support_x, support_y, query_x, query_y in tasks:
                # Adapt (no gradients needed)
                for step in range(self.inner_steps):
                    logits = self.model(support_x)
                    loss = F.cross_entropy(logits, support_y)
                    
                    # Manual gradient descent
                    grads = torch.autograd.grad(loss, self.model.parameters())
                    
                    for param, grad in zip(self.model.parameters(), grads):
                        param.data -= self.inner_lr * grad
                
                # Evaluate on query
                query_logits = self.model(query_x)
                task_loss = F.cross_entropy(query_logits, query_y)
                
                total_loss += task_loss.item()
                pred = query_logits.argmax(dim=1)
                total_acc += (pred == query_y).float().mean().item()
        
        return total_loss / len(tasks), total_acc / len(tasks)


# ============================================================================
# 4. Reptile Algorithm
# ============================================================================

class Reptile:
    """
    Reptile: First-order meta-learning algorithm.
    
    Simpler than MAML: Just move toward adapted parameters.
    """
    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.inner_steps = inner_steps
    
    def meta_update(self, tasks):
        """Meta-update via Reptile."""
        init_params = clone_parameters(self.model)
        
        total_loss = 0.0
        total_acc = 0.0
        
        for support_x, support_y, query_x, query_y in tasks:
            # Reset to initial parameters
            set_parameters(self.model, init_params)
            
            # Inner loop: Adapt to task
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.inner_lr)
            
            for step in range(self.inner_steps):
                optimizer.zero_grad()
                logits = self.model(support_x)
                loss = F.cross_entropy(logits, support_y)
                loss.backward()
                optimizer.step()
            
            # Evaluate on query (for logging)
            with torch.no_grad():
                query_logits = self.model(query_x)
                task_loss = F.cross_entropy(query_logits, query_y)
                total_loss += task_loss.item()
                
                pred = query_logits.argmax(dim=1)
                total_acc += (pred == query_y).float().mean().item()
            
            # Meta-update: Move toward adapted parameters
            adapted_params = clone_parameters(self.model)
            
            for (name, init_param), (_, adapted_param) in zip(
                init_params.items(), adapted_params.items()
            ):
                init_param.data += self.outer_lr * (adapted_param.data - init_param.data)
        
        # Set model to updated parameters
        set_parameters(self.model, init_params)
        
        return total_loss / len(tasks), total_acc / len(tasks)


# ============================================================================
# 5. ANIL (Almost No Inner Loop)
# ============================================================================

class ANIL:
    """
    ANIL: Only adapt the final layer (head) during inner loop.
    
    Feature extractor is frozen.
    """
    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.inner_steps = inner_steps
        
        # Separate feature extractor and classifier
        self.features = model.features
        self.classifier = model.classifier
        
        # Outer optimizer (entire model)
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=outer_lr)
    
    def inner_loop(self, task_data, labels):
        """Adapt only the classifier."""
        # Extract features (frozen)
        with torch.no_grad():
            features = self.features(task_data)
            features = features.view(features.size(0), -1)
        
        # Clone classifier parameters
        adapted_classifier = clone_parameters(self.classifier)
        
        # Adapt classifier only
        for step in range(self.inner_steps):
            set_parameters(self.classifier, adapted_classifier)
            
            logits = self.classifier(features)
            loss = F.cross_entropy(logits, labels)
            
            grads = torch.autograd.grad(loss, self.classifier.parameters(), create_graph=True)
            
            adapted_classifier = OrderedDict({
                name: param - self.inner_lr * grad
                for (name, param), grad in zip(adapted_classifier.items(), grads)
            })
        
        return adapted_classifier
    
    def meta_update(self, tasks):
        """Meta-update on tasks."""
        self.meta_optimizer.zero_grad()
        
        meta_loss = 0.0
        meta_acc = 0.0
        
        for support_x, support_y, query_x, query_y in tasks:
            # Adapt classifier
            adapted_classifier = self.inner_loop(support_x, support_y)
            
            # Evaluate on query
            with torch.no_grad():
                query_features = self.features(query_x)
                query_features = query_features.view(query_features.size(0), -1)
            
            set_parameters(self.classifier, adapted_classifier)
            query_logits = self.classifier(query_features)
            task_loss = F.cross_entropy(query_logits, query_y)
            
            meta_loss += task_loss
            
            with torch.no_grad():
                pred = query_logits.argmax(dim=1)
                meta_acc += (pred == query_y).float().mean()
        
        meta_loss = meta_loss / len(tasks)
        meta_acc = meta_acc / len(tasks)
        
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item(), meta_acc.item()


# ============================================================================
# 6. Sine Wave Regression (Classic MAML Demo)
# ============================================================================

class SineWaveDataset:
    """Generate sine wave tasks for regression."""
    def __init__(self, num_tasks=1000, k_shot=10, q_query=10):
        self.num_tasks = num_tasks
        self.k_shot = k_shot
        self.q_query = q_query
    
    def sample_task(self):
        """Sample a sine wave task with random amplitude and phase."""
        # Random amplitude [0.1, 5.0], phase [0, Ο€]
        amplitude = np.random.uniform(0.1, 5.0)
        phase = np.random.uniform(0, np.pi)
        
        # Sample x uniformly from [-5, 5]
        x = np.random.uniform(-5, 5, self.k_shot + self.q_query)
        y = amplitude * np.sin(x + phase)
        
        # Split into support and query
        support_x = x[:self.k_shot]
        support_y = y[:self.k_shot]
        query_x = x[self.k_shot:]
        query_y = y[self.k_shot:]
        
        # Convert to tensors
        support_x = torch.tensor(support_x, dtype=torch.float32).unsqueeze(1)
        support_y = torch.tensor(support_y, dtype=torch.float32).unsqueeze(1)
        query_x = torch.tensor(query_x, dtype=torch.float32).unsqueeze(1)
        query_y = torch.tensor(query_y, dtype=torch.float32).unsqueeze(1)
        
        return support_x, support_y, query_x, query_y
    
    def __iter__(self):
        for _ in range(self.num_tasks):
            yield self.sample_task()


class SineModel(nn.Module):
    """Simple MLP for sine wave regression."""
    def __init__(self, hidden_dim=40):
        super(SineModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, x):
        return self.net(x)


def train_maml_sine():
    """Train MAML on sine wave regression."""
    print("="*70)
    print("MAML Sine Wave Regression Demo")
    print("="*70)
    
    # Model
    model = SineModel(hidden_dim=40)
    
    # MAML
    maml = MAML(
        model, 
        inner_lr=0.01, 
        outer_lr=0.001, 
        inner_steps=5,
        first_order=False
    )
    
    # Dataset
    train_dataset = SineWaveDataset(num_tasks=10000, k_shot=10, q_query=10)
    
    # Meta-training
    print("Meta-training MAML...")
    num_iterations = 100
    batch_size = 4
    
    task_iter = iter(train_dataset)
    
    for iteration in range(num_iterations):
        # Sample batch of tasks
        tasks = [next(task_iter) for _ in range(batch_size)]
        
        # Meta-update
        loss, acc = maml.meta_update(tasks)
        
        if (iteration + 1) % 20 == 0:
            print(f"  Iteration {iteration+1}: Meta-loss = {loss:.4f}")
    
    # Test: Adapt to new sine wave
    print("\nTesting on new sine wave...")
    test_task = SineWaveDataset(num_tasks=1, k_shot=10, q_query=100).sample_task()
    support_x, support_y, query_x, query_y = test_task
    
    # Before adaptation
    with torch.no_grad():
        pred_before = model(query_x)
        mse_before = F.mse_loss(pred_before, query_y).item()
    
    # Adapt
    adapted_params = maml.inner_loop(support_x, support_y)
    set_parameters(model, adapted_params)
    
    # After adaptation
    with torch.no_grad():
        pred_after = model(query_x)
        mse_after = F.mse_loss(pred_after, query_y).item()
    
    print(f"  MSE before adaptation: {mse_before:.4f}")
    print(f"  MSE after adaptation: {mse_after:.4f}")
    print(f"  Improvement: {mse_before / mse_after:.2f}Γ—")
    print()


# ============================================================================
# 7. Method Comparison
# ============================================================================

def print_method_comparison():
    """Print comparison of meta-learning algorithms."""
    print("="*70)
    print("Meta-Learning Algorithms Comparison")
    print("="*70)
    print()
    
    comparison = """
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Method      β”‚ Order        β”‚ Speed      β”‚ Memory       β”‚ Performance  β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ MAML        β”‚ Second-order β”‚ Slow       β”‚ High (graph) β”‚ Best         β”‚
β”‚             β”‚ (Hessian)    β”‚            β”‚              β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ FOMAML      β”‚ First-order  β”‚ Medium     β”‚ Medium       β”‚ Good         β”‚
β”‚             β”‚              β”‚            β”‚              β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Reptile     β”‚ First-order  β”‚ Fast       β”‚ Low          β”‚ Good         β”‚
β”‚             β”‚ (simpler)    β”‚            β”‚              β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ ANIL        β”‚ Second-order β”‚ Fast       β”‚ Medium       β”‚ Good (if     β”‚
β”‚             β”‚ (head only)  β”‚            β”‚              β”‚ features OK) β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ MAML++      β”‚ Second-order β”‚ Slow       β”‚ High         β”‚ Best (tuned) β”‚
β”‚             β”‚ (enhanced)   β”‚            β”‚              β”‚              β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

**Computational Complexity (per meta-iteration):**

- MAML: O(B Β· K Β· |ΞΈ|Β²)    [B=tasks, K=steps, |ΞΈ|=params]
- FOMAML: O(B Β· K Β· |ΞΈ|)
- Reptile: O(B Β· K Β· |ΞΈ|)
- ANIL: O(B Β· K Β· |ΞΈ_head|Β²) where |ΞΈ_head| << |ΞΈ|

**When to Use:**

- **MAML**: Best performance, have computational budget
- **FOMAML**: Good balance of speed and performance
- **Reptile**: Simplest to implement, fastest training
- **ANIL**: Fast adaptation, when features are pre-trained
- **MAML++**: Production use, worth hyperparameter tuning

**Inner Loop Steps K:**

- K=1: Minimal adaptation, very fast
- K=5: Standard choice, good balance
- K=10+: Better adaptation, slower, risk overfitting

**Typical Results (5-way 1-shot Omniglot):**

- Random: 20% accuracy
- Fine-tuning: 50-60%
- MAML: 95-98%
- FOMAML: 93-95%
- Reptile: 92-94%

**Implementation Tips:**

1. **Start with FOMAML**: Simpler, good baseline
2. **Use gradient clipping**: Stability crucial
3. **Batch size**: 4-16 tasks per meta-update
4. **Learning rates**: Ξ±=0.01 (inner), Ξ²=0.001 (outer)
5. **Warm-up**: Lower outer LR for first 1000 iterations
"""
    
    print(comparison)
    print()


def print_complexity_analysis():
    """Print detailed complexity analysis."""
    print("="*70)
    print("Complexity Analysis")
    print("="*70)
    print()
    
    print("**Example: ResNet-18 (11M parameters)**")
    print()
    print("Configuration:")
    print("  β€’ Batch size B = 8 tasks")
    print("  β€’ Inner steps K = 5")
    print("  β€’ Parameters |ΞΈ| = 11M")
    print()
    print("Time per meta-iteration (GPU):")
    print("  β€’ MAML: ~10s (second-order gradients)")
    print("  β€’ FOMAML: ~2s (first-order)")
    print("  β€’ Reptile: ~1.5s (no backprop through inner loop)")
    print("  β€’ ANIL: ~0.5s (only adapt final layer)")
    print()
    print("Memory usage:")
    print("  β€’ MAML: ~60GB (stores computational graph for K steps)")
    print("  β€’ FOMAML: ~15GB (no Hessian)")
    print("  β€’ Reptile: ~12GB (minimal graph)")
    print("  β€’ ANIL: ~8GB (small adaptation)")
    print()
    print("Speedup techniques:")
    print("  β€’ Mixed precision (FP16): 2Γ— faster, 1/2 memory")
    print("  β€’ Gradient checkpointing: 1/K memory, ~1.5Γ— slower")
    print("  β€’ Distributed (multi-GPU): Linear speedup in B")
    print()


# ============================================================================
# Run Demonstrations
# ============================================================================

if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    
    train_maml_sine()
    print_method_comparison()
    print_complexity_analysis()
    
    print("="*70)
    print("MAML and Meta-Learning Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    print("  β€’ MAML: Bi-level optimization for fast adaptation")
    print("  β€’ FOMAML: First-order approximation (faster)")
    print("  β€’ Reptile: Simpler algorithm moving toward adapted params")
    print("  β€’ ANIL: Only adapt final layer (faster)")
    print("  β€’ Sine regression: Classic MAML demonstration")
    print()
    print("Key insight: Learn initialization ΞΈ that enables")
    print("             fast adaptation via gradient descent")
    print("Trade-off: Computational cost vs. adaptation speed")
    print("Applications: Few-shot learning, RL, regression")
    print()