import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from copy import deepcopy

# Run on GPU when available; models and batches below are moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Catastrophic Forgetting

Problem

Neural networks forget previous tasks when learning new ones.

EWC Solution

\[\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_{A,i}^*)^2\]

where \(F_i\) is the Fisher information:

\[F_i = \mathbb{E}\left[\left(\frac{\partial \log p(y|x, \theta)}{\partial \theta_i}\right)^2\right]\]

📚 Reference Materials:

class SimpleNet(nn.Module):
    """Three-layer fully connected classifier: 784 -> 256 -> 128 -> 10."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten to 784 features, apply two ReLU layers, return raw logits."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

# Load MNIST (downloaded to ./data on first run).
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_mnist = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Create task splits: Task A (0-4), Task B (5-9).
# Index via the raw `targets` tensor instead of enumerate(mnist): iterating
# the dataset decodes and transforms every image just to read its label.
indices_A = (mnist.targets < 5).nonzero(as_tuple=True)[0].tolist()
indices_B = (mnist.targets >= 5).nonzero(as_tuple=True)[0].tolist()

dataset_A = torch.utils.data.Subset(mnist, indices_A)
dataset_B = torch.utils.data.Subset(mnist, indices_B)

loader_A = torch.utils.data.DataLoader(dataset_A, batch_size=128, shuffle=True)
loader_B = torch.utils.data.DataLoader(dataset_B, batch_size=128, shuffle=True)

print(f"Task A: {len(dataset_A)}, Task B: {len(dataset_B)}")

Compute Fisher Information

Elastic Weight Consolidation (EWC) prevents catastrophic forgetting by penalizing changes to parameters that were important for previous tasks. The importance of each parameter is estimated by the diagonal of the Fisher information matrix, which approximates the curvature of the loss landscape around the current parameters: \(F_i = \mathbb{E}\left[\left(\frac{\partial \log p(y|x; \theta)}{\partial \theta_i}\right)^2\right]\). Parameters with high Fisher information are critical for the current task and should be changed minimally when learning new tasks. Computing the Fisher matrix requires a forward-backward pass over a representative sample of the current task's data.

def compute_fisher(model, data_loader, n_samples=1000):
    """Estimate the diagonal Fisher information matrix of `model`.

    Labels are sampled from the model's own predictive distribution (the
    "true" Fisher), and per-batch squared gradients of the NLL are averaged
    over roughly `n_samples` examples.

    Args:
        model: classifier producing logits of shape (batch, classes).
        data_loader: DataLoader yielding (inputs, labels); labels are ignored.
        n_samples: approximate number of examples to use (stops at the first
            batch boundary at or past this count).

    Returns:
        dict mapping parameter name -> tensor of the same shape holding the
        estimated diagonal Fisher values.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    model.eval()
    # Run on whatever device the model lives on (avoids relying on a
    # module-level `device` global).
    device = next(model.parameters()).device

    fisher = {name: torch.zeros_like(param) for name, param in model.named_parameters()}

    count = 0
    for x, _ in data_loader:
        if count >= n_samples:
            break

        x = x.to(device)

        # Forward pass.
        output = model(x)
        log_probs = F.log_softmax(output, dim=1)

        # Sample labels from the model's predictive distribution.
        # squeeze(1) rather than squeeze(): keeps a 1-D target even when the
        # batch contains a single example (a 0-dim target breaks nll_loss).
        sampled_y = torch.multinomial(torch.exp(log_probs), 1).squeeze(1)

        # Gradients of the sampled NLL.
        loss = F.nll_loss(log_probs, sampled_y)
        model.zero_grad()
        loss.backward()

        # Accumulate squared gradients, weighted by batch size so the final
        # division yields a per-example average.
        for name, param in model.named_parameters():
            if param.grad is not None:
                fisher[name] += param.grad.pow(2) * len(x)

        count += len(x)

    if count == 0:
        raise ValueError("data_loader yielded no batches; cannot estimate Fisher")

    # Normalize to an average over the examples actually seen.
    for name in fisher:
        fisher[name] /= count

    return fisher

print("Fisher computation function defined")

EWC Training

When learning a new task, EWC adds a quadratic penalty to the loss for each parameter, weighted by its Fisher information from the previous task: \(\mathcal{L}_{\text{EWC}} = \mathcal{L}_{\text{new task}} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta^*_i)^2\), where \(\theta^*\) are the parameters after training on the previous task. The hyperparameter \(\lambda\) controls the strength of the consolidation: too low and the model forgets; too high and it cannot learn the new task. This penalty acts as an anisotropic regularizer, allowing free movement in directions that do not affect the previous task while constraining important directions.

def ewc_loss(model, fisher, old_params, lambda_ewc=1000):
    """Quadratic EWC penalty: (lambda/2) * sum_i F_i (theta_i - theta*_i)^2."""
    penalty = sum(
        (fisher[name] * (param - old_params[name]).pow(2)).sum()
        for name, param in model.named_parameters()
        if name in fisher
    )
    return lambda_ewc / 2 * penalty

def train_with_ewc(model, loader, fisher=None, old_params=None, n_epochs=5, lambda_ewc=1000):
    """Train `model` on `loader` with Adam, optionally adding an EWC penalty.

    Args:
        model: classifier whose parameters are optimized in place.
        loader: DataLoader yielding (inputs, integer class labels).
        fisher: per-parameter Fisher diagonals from the previous task, or
            None to train without regularization.
        old_params: detached parameter snapshot taken after the previous task
            (required when `fisher` is given; must NOT be part of the graph).
        n_epochs: number of passes over `loader`.
        lambda_ewc: penalty strength forwarded to `ewc_loss`.
    """
    # Train on whatever device the model already lives on (no reliance on a
    # module-level `device` global).
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0.0

        for x, y in loader:
            x, y = x.to(device), y.to(device)

            output = model(x)
            loss = F.cross_entropy(output, y)

            # Anchor parameters near the previous task's optimum, weighted
            # by their Fisher importance.
            if fisher is not None:
                loss = loss + ewc_loss(model, fisher, old_params, lambda_ewc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(loader):.4f}")

print("EWC training function defined")

Baseline: Sequential Training

The baseline for continual learning is naive sequential training: train on Task 1 until convergence, then train on Task 2 from the same model without any forgetting prevention. This baseline exhibits catastrophic forgetting — performance on Task 1 drops precipitously as the model's parameters are overwritten to accommodate Task 2. Measuring the accuracy on all previous tasks after each new task quantifies the severity of forgetting and provides the benchmark against which EWC and other continual learning methods are evaluated.

def _eval_split(model, test_mnist, keep_label, device):
    """Accuracy (%) of `model` on the subset of `test_mnist` whose labels satisfy `keep_label`."""
    indices = [i for i, (_, label) in enumerate(test_mnist) if keep_label(label)]
    subset = torch.utils.data.Subset(test_mnist, indices)
    loader = torch.utils.data.DataLoader(subset, batch_size=1000)

    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    return 100 * correct / total


def evaluate_tasks(model, test_mnist):
    """Evaluate on both tasks.

    Returns:
        (acc_A, acc_B): accuracy (%) on Task A (labels 0-4) and
        Task B (labels 5-9) of `test_mnist`.
    """
    model.eval()
    # Evaluate on whatever device the model lives on.
    device = next(model.parameters()).device

    acc_A = _eval_split(model, test_mnist, lambda label: label < 5, device)
    acc_B = _eval_split(model, test_mnist, lambda label: label >= 5, device)
    return acc_A, acc_B

# Baseline: sequential training WITHOUT EWC, to measure catastrophic forgetting.
model_baseline = SimpleNet().to(device)

print("Training Task A...")
# Plain training on digits 0-4 (fisher defaults to None -> no penalty).
train_with_ewc(model_baseline, loader_A, n_epochs=5)
acc_A_after_A, acc_B_after_A = evaluate_tasks(model_baseline, test_mnist)
print(f"After Task A - A: {acc_A_after_A:.2f}%, B: {acc_B_after_A:.2f}%")

print("\nTraining Task B...")
# Continue training the SAME network on digits 5-9; Task A accuracy is
# expected to collapse (catastrophic forgetting).
train_with_ewc(model_baseline, loader_B, n_epochs=5)
acc_A_after_B, acc_B_after_B = evaluate_tasks(model_baseline, test_mnist)
print(f"After Task B - A: {acc_A_after_B:.2f}%, B: {acc_B_after_B:.2f}%")
print(f"Forgetting on A: {acc_A_after_A - acc_A_after_B:.2f}%")

EWC Training

With EWC enabled, we train on the sequence of tasks while applying the Fisher-weighted penalty. After each task, we compute the new Fisher information matrix and store the current parameters as the reference point. For a sequence of tasks, the penalty accumulates contributions from all previous tasks, ensuring that parameters important for any past task are protected. Monitoring accuracy on all tasks simultaneously during training reveals how effectively EWC preserves previous knowledge while accommodating new learning.

# EWC run: same architecture; Task B training adds the Fisher-weighted penalty.
model_ewc = SimpleNet().to(device)

print("Training Task A...")
train_with_ewc(model_ewc, loader_A, n_epochs=5)
acc_A_ewc, _ = evaluate_tasks(model_ewc, test_mnist)
print(f"After Task A - A: {acc_A_ewc:.2f}%")

# Compute Fisher and snapshot parameters after Task A.
print("\nComputing Fisher information...")
fisher = compute_fisher(model_ewc, loader_A, n_samples=2000)
# BUG FIX: the snapshot must be detached. A bare .clone() stays in the
# autograd graph, so during Task B the penalty's gradient flowing back
# through the clone exactly cancels the direct gradient and the EWC term
# has no effect on training.
old_params = {name: param.clone().detach() for name, param in model_ewc.named_parameters()}

print("Training Task B with EWC...")
# lambda_ewc=5000 trades plasticity (learning B) for stability (retaining A).
train_with_ewc(model_ewc, loader_B, fisher, old_params, n_epochs=5, lambda_ewc=5000)
acc_A_ewc_final, acc_B_ewc_final = evaluate_tasks(model_ewc, test_mnist)
print(f"After Task B - A: {acc_A_ewc_final:.2f}%, B: {acc_B_ewc_final:.2f}%")
print(f"Forgetting on A: {acc_A_ewc - acc_A_ewc_final:.2f}%")

Compare Results

Plotting the accuracy on each task over the course of sequential learning produces a forgetting matrix: the diagonal shows peak performance on each task, and off-diagonal entries show how much that performance degrades after subsequent tasks. EWC should show dramatically less off-diagonal degradation compared to naive sequential training, with only modest reduction in new-task learning speed. Quantitative metrics include average accuracy (across all tasks at the end) and backward transfer (average change in performance on old tasks after learning new ones).

# Comparison: final per-task accuracies (after both tasks) for baseline vs EWC.
methods = ['Baseline', 'EWC']
task_A = [acc_A_after_B, acc_A_ewc_final]  # Task A accuracy at the end of training
task_B = [acc_B_after_B, acc_B_ewc_final]  # Task B accuracy at the end of training

x = np.arange(len(methods))
width = 0.35  # two bars per method, offset by +/- width/2

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(x - width/2, task_A, width, label='Task A (0-4)', alpha=0.8)
ax.bar(x + width/2, task_B, width, label='Task B (5-9)', alpha=0.8)

ax.set_ylabel('Accuracy (%)', fontsize=11)
ax.set_title('Continual Learning Comparison', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(methods)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Forgetting = accuracy on A right after Task A minus accuracy on A after Task B.
print(f"\nBaseline - Forgetting: {acc_A_after_A - acc_A_after_B:.2f}%")
print(f"EWC - Forgetting: {acc_A_ewc - acc_A_ewc_final:.2f}%")

Summary

Continual Learning:

Problem: Catastrophic forgetting when learning sequentially

EWC Solution:

  1. Compute Fisher information after Task A

  2. Penalize changes to important parameters

  3. Balance new task learning with retention

Key Components:

  • Fisher matrix: Parameter importance

  • Quadratic penalty: Soft constraint

  • Lambda: Retention vs plasticity tradeoff

Other Approaches:

  • Rehearsal: Store subset of old data

  • Progressive networks: Add new columns

  • PackNet: Prune and freeze

  • LwF: Knowledge distillation

Applications:

  • Lifelong learning systems

  • Robotics (new environments)

  • Personalization (user adaptation)

  • Edge devices (limited memory)

Advanced Continual Learning TheoryΒΆ

1. Catastrophic Forgetting: The Stability-Plasticity DilemmaΒΆ

The ProblemΒΆ

Neural networks exhibit a fundamental tension between stability (retaining old knowledge) and plasticity (learning new information). When trained sequentially on different tasks, standard gradient descent catastrophically forgets previous tasks.

Mathematical formulation:

Given tasks \(\mathcal{T}_1, \mathcal{T}_2, \ldots, \mathcal{T}_T\) arriving sequentially, a network optimized for task \(\mathcal{T}_t\) will have high loss on earlier tasks:

\[\mathcal{L}_{\mathcal{T}_1}(\theta_t) \gg \mathcal{L}_{\mathcal{T}_1}(\theta_1)\]

even though \(\theta_1\) was optimal for \(\mathcal{T}_1\).

Why It HappensΒΆ

Weight overlap: Different tasks often require similar features in lower layers, but different mappings in higher layers. Updating weights for task 2 overwrites representations learned for task 1.

Gradient interference: The gradient for task \(\mathcal{T}_t\) may point in the opposite direction from the gradient for task \(\mathcal{T}_{t-1}\):

\[\nabla_\theta \mathcal{L}_{\mathcal{T}_t} \cdot \nabla_\theta \mathcal{L}_{\mathcal{T}_{t-1}} < 0\]

Distributed representations: Neural networks use distributed representations where each neuron contributes to multiple tasks. Changing one neuron affects all tasks.

Measuring ForgettingΒΆ

Forgetting metric (Chaudhry et al., 2018):

\[F_t^j = \max_{k \in \{1, \ldots, t-1\}} \left( a_{k,k} - a_{j,t} \right)\]

where \(a_{k,k}\) is accuracy on task \(k\) immediately after training on it, and \(a_{j,t}\) is accuracy on task \(j\) after training on task \(t\).

Average forgetting:

\[\bar{F}_T = \frac{1}{T-1} \sum_{j=1}^{T-1} F_T^j\]

2. Elastic Weight Consolidation (EWC): Detailed DerivationΒΆ

Bayesian PerspectiveΒΆ

EWC views continual learning as sequential Bayesian inference. After learning task \(A\), the posterior becomes the prior for task \(B\):

\[p(\theta | \mathcal{D}_B) \propto p(\mathcal{D}_B | \theta) p(\theta | \mathcal{D}_A)\]

Taking negative log:

\[-\log p(\theta | \mathcal{D}_B) = -\log p(\mathcal{D}_B | \theta) - \log p(\theta | \mathcal{D}_A) + \text{const}\]

Laplace ApproximationΒΆ

Approximate the posterior \(p(\theta | \mathcal{D}_A)\) as a Gaussian centered at the optimal parameters \(\theta_A^*\):

\[p(\theta | \mathcal{D}_A) \approx \mathcal{N}(\theta_A^*, F^{-1})\]

where \(F\) is the Fisher information matrix:

\[F_{ij} = \mathbb{E}_{x \sim \mathcal{D}_A} \left[ \frac{\partial \log p(y|x, \theta)}{\partial \theta_i} \frac{\partial \log p(y|x, \theta)}{\partial \theta_j} \right]\]

Fisher Information MatrixΒΆ

The Fisher information quantifies how much the model output changes when a parameter is perturbed.

For classification with softmax:

\[p(y=k | x, \theta) = \frac{\exp(f_k(x; \theta))}{\sum_j \exp(f_j(x; \theta))}\]

The diagonal Fisher information for parameter \(\theta_i\) is:

\[F_i = \mathbb{E}_{(x,y) \sim \mathcal{D}_A} \left[ \left( \frac{\partial \log p(y|x, \theta)}{\partial \theta_i} \right)^2 \right]\]

Intuition: If \(F_i\) is large, parameter \(\theta_i\) significantly affects task \(A\)’s predictions, so it should not change much when learning task \(B\).

EWC Loss FunctionΒΆ

Combining the task \(B\) loss with the quadratic penalty:

\[\mathcal{L}_{\text{EWC}}(\theta) = \mathcal{L}_B(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_{A,i}^*)^2\]

Hyperparameter \(\lambda\): Controls the trade-off between:

  • High \(\lambda\): Strong retention of task \(A\) (low forgetting, but may underfit task \(B\))

  • Low \(\lambda\): Better performance on task \(B\) (but more forgetting of task \(A\))

Practical ComputationΒΆ

Sampling-based Fisher estimation:

  1. Sample inputs \(x\) from task \(A\) data

  2. Sample labels \(\hat{y}\) from the model’s predictive distribution \(p(y|x, \theta_A^*)\)

  3. Compute gradients \(\nabla_\theta \log p(\hat{y} | x, \theta_A^*)\)

  4. Average squared gradients:

\[\hat{F}_i = \frac{1}{N} \sum_{n=1}^N \left( \frac{\partial \log p(\hat{y}_n | x_n, \theta_A^*)}{\partial \theta_i} \right)^2\]

Diagonal approximation: Computing the full Fisher matrix is intractable for large networks (\(O(|\theta|^2)\) memory). EWC uses only the diagonal, assuming parameter independence.

3. Progressive Neural NetworksΒΆ

ArchitectureΒΆ

Progressive Neural Networks (Rusu et al., 2016) freeze old task columns and add new columns for each task, with lateral connections from old to new.

Structure for task \(t\):

\[h_i^{(t)} = f\left( W_i^{(t)} h_{i-1}^{(t)} + \sum_{j<t} U_i^{(j \to t)} h_{i-1}^{(j)} \right)\]

where:

  • \(h_i^{(t)}\): Hidden layer \(i\) activations for task \(t\)

  • \(W_i^{(t)}\): Weights within column \(t\)

  • \(U_i^{(j \to t)}\): Lateral connections from column \(j\) to column \(t\)

Key PropertiesΒΆ

Zero forgetting: Old columns are frozen, so performance on old tasks is exactly preserved.

Knowledge transfer: Lateral connections allow task \(t\) to reuse features from tasks \(1, \ldots, t-1\).

Capacity growth: Model size grows linearly with the number of tasks: \(\theta_{\text{total}} = \sum_{t=1}^T |\theta_t|\).

When to UseΒΆ

Advantages:

  • No forgetting by construction

  • Automatic transfer learning via lateral connections

  • Each task gets dedicated capacity

Disadvantages:

  • Memory grows unbounded

  • Inference cost increases with tasks

  • Not suitable for resource-constrained settings

4. Memory-Based ApproachesΒΆ

Experience ReplayΒΆ

Store a memory buffer \(\mathcal{M}\) containing examples from previous tasks. When learning task \(t\), sample from both \(\mathcal{D}_t\) and \(\mathcal{M}\):

\[\mathcal{L}(\theta) = \mathbb{E}_{(x,y) \sim \mathcal{D}_t \cup \mathcal{M}} \left[ \ell(f_\theta(x), y) \right]\]

Reservoir sampling: To maintain a fixed-size buffer, use reservoir sampling to ensure uniform representation across tasks.

Advantages:

  • Simple and effective

  • Works with any architecture

  • Can balance old and new data

Disadvantages:

  • Requires storing raw data (privacy concerns)

  • Memory limited

  • May not scale to many tasks

Generative ReplayΒΆ

Train a generative model \(G\) alongside the task network. Instead of storing real data, generate pseudo-examples from old tasks:

\[\tilde{x}_{\text{old}} \sim G(\mathcal{D}_1, \ldots, \mathcal{D}_{t-1})\]

Advantage: No need to store raw data.

Disadvantage: Requires training a high-quality generative model (e.g., VAE, GAN).

Pseudo-RehearsalΒΆ

Similar to generative replay, but uses the task network itself to generate labels for synthetic inputs:

  1. Generate random input \(\tilde{x}\)

  2. Use old model \(\theta_{\text{old}}\) to label: \(\tilde{y} = f_{\theta_{\text{old}}}(\tilde{x})\)

  3. Train new model to match: \(\mathcal{L}(\theta) = \ell(f_\theta(\tilde{x}), \tilde{y})\)

5. Regularization-Based MethodsΒΆ

Learning without Forgetting (LwF)ΒΆ

Idea: Constrain the new model to produce similar outputs to the old model on new data.

Knowledge distillation loss:

\[\mathcal{L}_{\text{LwF}}(\theta) = \mathcal{L}_{\text{new}}(\theta) + \lambda \mathbb{E}_{x \sim \mathcal{D}_t} \left[ D_{\text{KL}}\left( f_{\theta_{\text{old}}}(x) \,\|\, f_\theta(x) \right) \right]\]

where \(D_{\text{KL}}\) is the Kullback-Leibler divergence between output distributions.

Soft targets: Use the old model’s softmax probabilities as soft labels:

\[p_{\text{old}}(y | x) = \frac{\exp(z_y / T)}{\sum_k \exp(z_k / T)}\]

with temperature \(T > 1\) to soften the distribution.

PackNetΒΆ

Prune and freeze approach:

  1. Train network on task 1

  2. Prune less important weights (e.g., lowest magnitude)

  3. Freeze remaining important weights

  4. Train remaining capacity on task 2

  5. Repeat

Iterative pruning: After each task, prune a fraction \(p\) of weights. Remaining capacity for task \(t\):

\[C_t = C_0 \cdot (1 - p)^{t-1}\]

Advantage: Simple, no hyperparameters for retention.

Disadvantage: Capacity decreases exponentially; limited scalability.

Hard Attention Masks (HAT)ΒΆ

Learn binary masks \(m_i^{(t)} \in \{0, 1\}\) for each task \(t\), where \(m_i^{(t)} = 1\) means parameter \(\theta_i\) is used for task \(t\).

Forward pass:

\[h^{(t)} = f\left( (W \odot m^{(t)}) \cdot x \right)\]

Mask learning: Use Gumbel-Softmax to learn approximately binary masks via gradient descent.

Advantage: Learns task-specific sub-networks; no forgetting.

Disadvantage: Complex training; requires task IDs at test time.

6. Multi-Task Learning vs Continual LearningΒΆ

Aspect

Multi-Task Learning

Continual Learning

Data access

All tasks available simultaneously

Tasks arrive sequentially

Goal

Joint optimization across tasks

Learn new tasks without forgetting old

Challenge

Negative transfer, task balancing

Catastrophic forgetting

Memory

Requires storing all data

Limited memory for old tasks

Evaluation

Average performance across tasks

Forward transfer + backward transfer (retention)

7. Theoretical GuaranteesΒΆ

PAC-Bayes Bounds for Continual LearningΒΆ

For a sequence of tasks with data distributions \(\mathcal{D}_1, \ldots, \mathcal{D}_T\), the expected risk on all tasks is bounded by:

\[\frac{1}{T} \sum_{t=1}^T R_{\mathcal{D}_t}(\theta_T) \leq \frac{1}{T} \sum_{t=1}^T \hat{R}_{\mathcal{D}_t}(\theta_t) + O\left( \sqrt{\frac{D_{\text{KL}}(q || p) + \log T}{N}} \right)\]

where \(q\) is the learned posterior and \(p\) is the prior.

Interpretation: The bound grows with the number of tasks \(T\), highlighting the difficulty of lifelong learning.

Gradient Episodic Memory (GEM)ΒΆ

Constraint: New gradient should not increase loss on old tasks:

\[\nabla_\theta \mathcal{L}_t(\theta) \cdot g_k \geq 0 \quad \forall k < t\]

where \(g_k\) is the gradient on task \(k\) memory.

Optimization: Solve quadratic program to project gradient into feasible region if constraint is violated.

8. Practical ConsiderationsΒΆ

Choosing a MethodΒΆ

Small number of tasks (< 10):

  • EWC or LwF (simple, effective)

  • Progressive networks (if memory allows)

Many tasks (> 100):

  • Experience replay with reservoir sampling

  • Hard attention masks (if task IDs available)

Resource-constrained (edge devices):

  • PackNet (no additional memory)

  • Small replay buffer

Privacy-sensitive:

  • Generative replay or LwF (no raw data storage)

Hyperparameter TuningΒΆ

EWC \(\lambda\):

  • Start with \(\lambda \in [10^3, 10^5]\)

  • Higher for tasks with high importance

  • Cross-validate on held-out tasks

Replay buffer size:

  • Minimum: \(\sim 100\) examples per class

  • Optimal: \(\sim 1000\) examples per class

  • Trade-off with storage constraints

Temperature \(T\) (LwF):

  • Higher \(T\) β†’ softer targets (more knowledge transfer)

  • Typical: \(T \in [2, 5]\)

Evaluation MetricsΒΆ

Average accuracy:

\[\bar{A}_T = \frac{1}{T} \sum_{t=1}^T a_{t,T}\]

where \(a_{t,T}\) is accuracy on task \(t\) after learning all \(T\) tasks.

Backward transfer (forgetting):

\[\text{BWT} = \frac{1}{T-1} \sum_{t=1}^{T-1} \left( a_{t,T} - a_{t,t} \right)\]

Positive BWT indicates improvement on old tasks (rare); negative indicates forgetting.

Forward transfer:

\[\text{FWT} = \frac{1}{T-1} \sum_{t=2}^T \left( a_{t,t-1} - b_t \right)\]

where \(b_t\) is random chance accuracy. Positive FWT indicates knowledge transfer from previous tasks.

9. Advanced TopicsΒΆ

Continual Meta-LearningΒΆ

Combine meta-learning with continual learning: learn a meta-learner that quickly adapts to new tasks without forgetting.

OML (Online-aware Meta-learning): Meta-train with online (sequential) task arrivals, not episodic sampling.

La-MAML: Modulate MAML’s inner loop learning rate based on parameter importance (similar to EWC).

Task-Free Continual LearningΒΆ

Most methods assume task boundaries are known. Task-free setting: no explicit task IDs.

Challenge: How to detect when a new task begins?

Solutions:

  • Unsupervised change detection (distribution shift)

  • Clustering in representation space

  • Bayesian online changepoint detection

Continual Reinforcement LearningΒΆ

Additional challenges:

  • Non-stationary reward distributions

  • Exploration-exploitation in new environments

  • Policy catastrophic forgetting

Approaches:

  • EWC for policy networks

  • Progressive networks for new environments

  • Experience replay with prioritized sampling

PDF References

📄 Foundational Papers:

📄 Surveys:

📂 GitHub Resources:

# ============================================================================
# Advanced Continual Learning Implementations
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from collections import defaultdict
import random

# Seed both RNGs so sampling-based routines below (Fisher estimation,
# reservoir replacement) are reproducible across runs.
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ============================================================================
# 1. Complete EWC Implementation with Fisher Matrix
# ============================================================================

class EWC:
    """
    Elastic Weight Consolidation for continual learning.

    Implements the full EWC algorithm including:
    - Diagonal Fisher information matrix computation
    - Quadratic penalty loss
    - Per-task Fisher/parameter storage (penalty sums over all past tasks)
    """
    def __init__(self, model, lambda_ewc=5000, fisher_sample_size=200):
        self.model = model
        self.lambda_ewc = lambda_ewc                  # penalty strength
        self.fisher_sample_size = fisher_sample_size  # examples used per Fisher estimate

        # Per-task Fisher diagonals and parameter snapshots, keyed by task id.
        self.fisher = {}
        self.optim_params = {}
        self.task_count = 0

    def _compute_fisher_diagonal(self, dataloader):
        """
        Compute the diagonal Fisher information matrix.

        Fisher information quantifies how sensitive the model output is
        to changes in each parameter:

        F_i = E[(∂log p(y|x,θ)/∂θ_i)²]

        Labels are sampled from the model's own predictive distribution
        (the "true" Fisher, as opposed to the empirical Fisher).

        Raises:
            ValueError: if `dataloader` yields no batches.
        """
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()}
        # Use the model's own device rather than a module-level global.
        device = next(self.model.parameters()).device

        self.model.eval()

        samples_seen = 0
        for inputs, _ in dataloader:
            if samples_seen >= self.fisher_sample_size:
                break

            inputs = inputs.to(device)
            self.model.zero_grad()

            # Forward pass.
            outputs = self.model(inputs)
            log_probs = F.log_softmax(outputs, dim=1)

            # squeeze(1) rather than squeeze(): keeps a 1-D target even when
            # a batch contains a single example (0-dim targets break nll_loss).
            sampled_labels = torch.multinomial(torch.exp(log_probs), 1).squeeze(1)

            # Negative log-likelihood of the sampled labels.
            loss = F.nll_loss(log_probs, sampled_labels)
            loss.backward()

            # Accumulate squared gradients, weighted by batch size so the
            # final division yields a per-example average.
            for n, p in self.model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.pow(2) * inputs.size(0)

            samples_seen += inputs.size(0)

        if samples_seen == 0:
            raise ValueError("dataloader yielded no batches; cannot estimate Fisher")

        # Normalize by the number of examples actually processed.
        for n in fisher:
            fisher[n] /= samples_seen

        return fisher

    def register_task(self, dataloader):
        """
        Register a new task: compute Fisher and save optimal parameters.
        """
        print(f"Registering task {self.task_count + 1}...")

        # Estimate the Fisher diagonal on the just-finished task's data.
        fisher = self._compute_fisher_diagonal(dataloader)

        # Snapshot Fisher and (detached) parameters for this task.
        self.fisher[self.task_count] = fisher
        self.optim_params[self.task_count] = {
            n: p.clone().detach()
            for n, p in self.model.named_parameters()
        }

        self.task_count += 1
        print(f"Task {self.task_count} registered. Fisher computed over {self.fisher_sample_size} samples.")

    def penalty(self):
        """
        Compute EWC quadratic penalty across all previous tasks.

        L_EWC = (λ/2) Σ_t Σ_i F_i^(t) (θ_i - θ_i^(t)*)²

        Returns 0.0 when no task has been registered yet.
        NOTE(review): assumes the model's parameter set is unchanged since
        registration; parameters added later would raise KeyError here.
        """
        loss = 0.0

        for task_id in range(self.task_count):
            for n, p in self.model.named_parameters():
                fisher_diag = self.fisher[task_id][n]
                optimal_param = self.optim_params[task_id][n]

                # Quadratic penalty weighted by Fisher information.
                loss += (fisher_diag * (p - optimal_param).pow(2)).sum()

        return (self.lambda_ewc / 2) * loss


# ============================================================================
# 2. Progressive Neural Networks
# ============================================================================

class ProgressiveColumn(nn.Module):
    """
    Single column in a Progressive Neural Network.

    Each column has:
    - Its own feedforward path
    - One lateral connection per previous (frozen) column, applied to that
      column's final hidden activation
    """
    def __init__(self, input_dim, hidden_dims, output_dim, prev_columns=None):
        super().__init__()

        # Plain Python list on purpose: the previous columns are owned (and
        # registered) by the parent network; registering them here as well
        # would duplicate their parameters in this module's tree.
        self.prev_columns = prev_columns if prev_columns is not None else []

        # Feedforward trunk: Linear + ReLU per hidden layer.
        layers = []
        dims = [input_dim] + hidden_dims
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            layers.append(nn.ReLU())

        self.column = nn.Sequential(*layers)
        self.output = nn.Linear(hidden_dims[-1], output_dim)

        # One lateral adapter per previous column. Laterals act only on the
        # last hidden layer — a simplification of the per-layer laterals in
        # Rusu et al., 2016.
        if len(self.prev_columns) > 0:
            self.lateral_connections = nn.ModuleList([
                nn.Linear(hidden_dims[-1], hidden_dims[-1])
                for _ in self.prev_columns
            ])
        else:
            self.lateral_connections = None

    def forward(self, x, prev_activations=None):
        """
        Forward pass with lateral connections.

        h^(t) = f(W^(t) x + Σ_{j<t} U^(j→t) h^(j))

        Returns:
            (output, h): logits and this column's final hidden state; the
            hidden state is what later columns consume laterally.
        """
        # Column's own computation.
        h = self.column(x)

        # Add lateral contributions from previous columns.
        # NOTE(review): zip() silently truncates if len(prev_activations)
        # differs from len(self.lateral_connections); callers are expected
        # to pass exactly one activation per previous column.
        if prev_activations is not None and len(prev_activations) > 0:
            for prev_h, lateral in zip(prev_activations, self.lateral_connections):
                h = h + lateral(prev_h)

        output = self.output(h)

        return output, h


class ProgressiveNeuralNetwork(nn.Module):
    """
    Progressive Neural Network: one new column per task.

    Key properties:
    - Zero forgetting (old columns are frozen)
    - Lateral connections enable transfer from earlier tasks
    - Model capacity grows with the number of tasks
    """
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim

        self.columns = nn.ModuleList()

    def add_task(self):
        """Freeze every existing column and append a fresh one for the new task."""
        fresh = ProgressiveColumn(
            self.input_dim,
            self.hidden_dims,
            self.output_dim,
            prev_columns=list(self.columns),
        )

        # Old tasks must remain intact: stop gradients for all prior columns.
        for frozen in self.columns:
            frozen.requires_grad_(False)

        self.columns.append(fresh)

        print(f"Added column {len(self.columns)}. Total parameters: {sum(p.numel() for p in self.parameters())}")

    def forward(self, x, task_id=None):
        """
        Run `x` through the column for `task_id` (default: newest column),
        feeding it the activations of all earlier, frozen columns.
        """
        target = len(self.columns) - 1 if task_id is None else task_id

        # Collect hidden states of all columns before the target one.
        laterals = []
        for column in self.columns[:target]:
            with torch.no_grad():  # previous columns are frozen
                _, hidden = column(x, laterals)
            laterals.append(hidden)

        # Target column consumes the frozen activations laterally.
        output, _ = self.columns[target](x, laterals)
        return output


# ============================================================================
# 3. Experience Replay with Reservoir Sampling
# ============================================================================

class ReplayBuffer:
    """
    Experience replay buffer with reservoir sampling.
    
    Maintains a fixed-size buffer in which every example seen so far has
    equal probability (max_size / total_seen) of being retained.
    """
    def __init__(self, max_size=1000):
        # max_size: capacity of the reservoir.
        self.max_size = max_size
        self.buffer = []      # list of (x, y) CPU tensor pairs
        self.total_seen = 0   # number of examples offered to the buffer
    
    def add(self, x, y):
        """
        Add a batch of examples using the reservoir sampling algorithm.
        
        Ensures uniform retention probability across all seen examples.
        """
        for i in range(len(x)):
            if len(self.buffer) < self.max_size:
                # Buffer not full: add directly
                self.buffer.append((x[i].cpu(), y[i].cpu()))
            else:
                # Draw a slot uniformly from {0, ..., total_seen}; the new
                # sample survives only if the slot falls inside the buffer.
                # np.random is used because the `random` module is never
                # imported in this file (the original call raised NameError);
                # np.random.randint's upper bound is exclusive.
                j = int(np.random.randint(0, self.total_seen + 1))
                if j < self.max_size:
                    self.buffer[j] = (x[i].cpu(), y[i].cpu())
            
            self.total_seen += 1
    
    def sample(self, batch_size):
        """Sample a mini-batch uniformly without replacement; (None, None) if empty."""
        if not self.buffer:
            return None, None
        
        k = min(batch_size, len(self.buffer))
        indices = np.random.choice(len(self.buffer), size=k, replace=False)
        
        batch_x = torch.stack([self.buffer[i][0] for i in indices]).to(device)
        batch_y = torch.tensor([self.buffer[i][1] for i in indices], dtype=torch.long).to(device)
        
        return batch_x, batch_y
    
    def __len__(self):
        return len(self.buffer)


# ============================================================================
# 4. Learning without Forgetting (LwF)
# ============================================================================

class LwF:
    """
    Learning without Forgetting via knowledge distillation.
    
    A snapshot of the model taken before a new task serves as a teacher:
    a distillation term keeps the new model's outputs on new data close to
    the teacher's, preserving old-task behavior without storing old data.
    """
    def __init__(self, model, temperature=2.0, alpha=0.5):
        self.model = model
        self.temperature = temperature  # softening factor for soft targets
        self.alpha = alpha              # weight of the distillation term
        self.old_model = None           # teacher snapshot; None until register_task()
    
    def register_task(self):
        """Snapshot the current model; it becomes the frozen distillation teacher."""
        snapshot = deepcopy(self.model)
        snapshot.eval()
        for weight in snapshot.parameters():
            weight.requires_grad = False
        self.old_model = snapshot
        
        print("Old model saved for distillation.")
    
    def distillation_loss(self, outputs, inputs):
        """
        Knowledge distillation loss between teacher and student.
        
        L_distill = KL(softmax(z_old/T) || softmax(z_new/T)) * T²
        
        with T > 1 softening both distributions. Returns 0.0 before any
        task has been registered.
        """
        if self.old_model is None:
            return 0.0
        
        with torch.no_grad():
            teacher_logits = self.old_model(inputs)
        
        T = self.temperature
        teacher_probs = F.softmax(teacher_logits / T, dim=1)
        student_log_probs = F.log_softmax(outputs / T, dim=1)
        
        # The T² factor compensates for the softened-softmax gradient scaling.
        kl = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        return kl * (T * T)
    
    def combined_loss(self, outputs, labels, inputs):
        """
        Task cross-entropy plus α-weighted distillation.
        
        L = L_task + α * L_distill (distillation only once a teacher exists).
        """
        ce = F.cross_entropy(outputs, labels)
        
        if self.old_model is None:
            return ce
        
        return ce + self.alpha * self.distillation_loss(outputs, inputs)


# ============================================================================
# 5. Gradient Episodic Memory (GEM)
# ============================================================================

class GEM:
    """
    Gradient Episodic Memory: constrain gradients to not increase loss on old tasks.
    
    Exact GEM solves the quadratic program:
        min ||g - g_t||²
        s.t. g^T g_k ≥ 0  ∀ k < t
    
    where g_t is the current gradient and g_k is the gradient on task k's
    episodic memory. This implementation uses a cheaper approximate
    projection instead of a full QP solve (see project_gradient).
    """
    def __init__(self, model, memory_per_task=256, margin=0.5):
        self.model = model
        self.memory_per_task = memory_per_task  # episodic-memory budget per task
        self.margin = margin  # QP slack; unused by the approximate projection below
        
        self.memory = []  # one list of (input, label) CPU pairs per finished task
    
    def store_task_memory(self, dataloader):
        """Store up to `memory_per_task` examples from this task's loader."""
        task_memory = []
        
        for inputs, labels in dataloader:
            task_memory.extend(list(zip(inputs.cpu(), labels.cpu())))
            if len(task_memory) >= self.memory_per_task:
                break
        
        # Downsample without replacement if the last batch overshot the budget.
        # np.random is used because the `random` module is never imported in
        # this file (the original random.sample call raised NameError).
        if len(task_memory) > self.memory_per_task:
            keep = np.random.choice(len(task_memory), size=self.memory_per_task, replace=False)
            task_memory = [task_memory[i] for i in keep]
        
        self.memory.append(task_memory)
        print(f"Stored {len(task_memory)} examples for task {len(self.memory)}")
    
    def compute_gradient(self, inputs, labels):
        """Return the model's cross-entropy gradient on (inputs, labels) as one flat vector."""
        self.model.zero_grad()
        
        outputs = self.model(inputs.to(device))
        loss = F.cross_entropy(outputs, labels.to(device))
        loss.backward()
        
        # Flatten all parameter gradients into a single vector.
        grad = torch.cat([p.grad.view(-1) for p in self.model.parameters() if p.grad is not None])
        
        return grad
    
    def project_gradient(self, current_grad):
        """
        Approximately project `current_grad` to satisfy g^T g_k ≥ 0 for all
        stored tasks.
        
        Full GEM solves a QP; here, gradients of violated constraints are
        averaged and partially subtracted — a practical approximation.
        Returns `current_grad` unchanged when no constraint is violated.
        """
        if len(self.memory) == 0:
            return current_grad
        
        # Reference gradient per stored task, estimated on up to 32 memory samples.
        ref_grads = []
        for task_mem in self.memory:
            inputs = torch.stack([x for x, _ in task_mem[:32]])
            labels = torch.tensor([y for _, y in task_mem[:32]], dtype=torch.long)
            
            grad_k = self.compute_gradient(inputs, labels)
            ref_grads.append(grad_k)
        
        # A negative inner product means this update would raise that task's loss.
        violations = [(current_grad @ g_k).item() < 0 for g_k in ref_grads]
        
        if not any(violations):
            return current_grad  # No violations, use original gradient
        
        violated_grads = [g_k for g_k, v in zip(ref_grads, violations) if v]
        
        if len(violated_grads) > 0:
            avg_violated = torch.stack(violated_grads).mean(dim=0)
            projected_grad = current_grad - avg_violated * 0.5
            return projected_grad
        
        return current_grad


# ============================================================================
# 6. Comparison Experiment
# ============================================================================

def create_split_mnist():
    """
    Create a two-task split of MNIST: digits 0-4 (task 1) and 5-9 (task 2).
    
    Returns:
        (train_t1, test_t1, train_t2, test_t2) as torch Subsets.
    
    Note: downloads MNIST into ./data on first use.
    """
    from torchvision import datasets, transforms
    # Subset is only imported at file scope much later in this file, so the
    # bare name would raise NameError here; import it locally instead.
    from torch.utils.data import Subset
    
    transform = transforms.Compose([transforms.ToTensor()])
    
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    
    # Task 1: digits 0-4 (enumerating decodes each image once; labels drive the split)
    train_indices_t1 = [i for i, (_, label) in enumerate(train_dataset) if label < 5]
    test_indices_t1 = [i for i, (_, label) in enumerate(test_dataset) if label < 5]
    
    # Task 2: digits 5-9
    train_indices_t2 = [i for i, (_, label) in enumerate(train_dataset) if label >= 5]
    test_indices_t2 = [i for i, (_, label) in enumerate(test_dataset) if label >= 5]
    
    train_t1 = Subset(train_dataset, train_indices_t1)
    test_t1 = Subset(test_dataset, test_indices_t1)
    train_t2 = Subset(train_dataset, train_indices_t2)
    test_t2 = Subset(test_dataset, test_indices_t2)
    
    return train_t1, test_t1, train_t2, test_t2


def evaluate_model(model, test_loader):
    """
    Compute classification accuracy (%) of `model` over `test_loader`.
    
    The evaluation device is inferred from the model's own parameters
    (CPU fallback for parameterless models) rather than relying on a
    module-level `device` global, so the function works wherever the
    model actually lives.
    """
    model.eval()
    try:
        eval_device = next(model.parameters()).device
    except StopIteration:  # parameterless model: run on CPU
        eval_device = torch.device('cpu')
    
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(eval_device), labels.to(eval_device)
            outputs = model(inputs)
            
            # Predicted class = argmax over logits
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return 100.0 * correct / total


print("\n" + "="*80)
print("Continual Learning Methods Comparison")
print("="*80)

# Create datasets (downloads MNIST on first run)
train_t1, test_t1, train_t2, test_t2 = create_split_mnist()

# DataLoader is only imported (from torch.utils.data) much later in this
# file, so the bare name would raise NameError here; use the qualified path,
# matching the usage at the top of the file.
train_loader_t1 = torch.utils.data.DataLoader(train_t1, batch_size=128, shuffle=True)
test_loader_t1 = torch.utils.data.DataLoader(test_t1, batch_size=512, shuffle=False)
train_loader_t2 = torch.utils.data.DataLoader(train_t2, batch_size=128, shuffle=True)
test_loader_t2 = torch.utils.data.DataLoader(test_t2, batch_size=512, shuffle=False)

print(f"Task 1 train: {len(train_t1)}, test: {len(test_t1)}")
print(f"Task 2 train: {len(train_t2)}, test: {len(test_t2)}")

# Simple model for experiments
class SimpleNet(nn.Module):
    """Three-layer MLP classifier for 28x28 inputs: 784 -> 256 -> 128 -> 10."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
    
    def forward(self, x):
        # Flatten any leading shape down to 784-dim vectors.
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)  # raw logits; loss functions apply softmax


# ============================================================================
# Compare: EWC vs Replay vs LwF
# ============================================================================

results = {}  # method name -> {'t1_after_t1', 't1_after_t2', 't2_after_t2', 'forgetting'}

# 1. Baseline (Sequential, no CL method)
# Plain fine-tuning: train on task 1, then task 2, with nothing protecting
# task-1 knowledge. The task-1 accuracy drop measures catastrophic forgetting.
print("\n--- Baseline (No CL) ---")
model_baseline = SimpleNet().to(device)
optimizer_baseline = torch.optim.Adam(model_baseline.parameters(), lr=1e-3)

# Train task 1
for epoch in range(3):
    model_baseline.train()
    for inputs, labels in train_loader_t1:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_baseline.zero_grad()
        loss = F.cross_entropy(model_baseline(inputs), labels)
        loss.backward()
        optimizer_baseline.step()

acc_t1_after_t1 = evaluate_model(model_baseline, test_loader_t1)
print(f"After Task 1 - T1 acc: {acc_t1_after_t1:.2f}%")

# Train task 2 (same loop; no constraint on weights important for task 1)
for epoch in range(3):
    model_baseline.train()
    for inputs, labels in train_loader_t2:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_baseline.zero_grad()
        loss = F.cross_entropy(model_baseline(inputs), labels)
        loss.backward()
        optimizer_baseline.step()

# Re-evaluate both tasks; forgetting = T1 accuracy before minus after task 2.
acc_t1_after_t2 = evaluate_model(model_baseline, test_loader_t1)
acc_t2_after_t2 = evaluate_model(model_baseline, test_loader_t2)
print(f"After Task 2 - T1 acc: {acc_t1_after_t2:.2f}%, T2 acc: {acc_t2_after_t2:.2f}%")
print(f"Forgetting: {acc_t1_after_t1 - acc_t1_after_t2:.2f}%")

results['Baseline'] = {
    't1_after_t1': acc_t1_after_t1,
    't1_after_t2': acc_t1_after_t2,
    't2_after_t2': acc_t2_after_t2,
    'forgetting': acc_t1_after_t1 - acc_t1_after_t2
}


# 2. EWC
# Same schedule as the baseline, but after task 1 the EWC helper (defined
# earlier in this file) registers the task and a quadratic penalty anchors
# parameters that were important for task 1 while training task 2.
print("\n--- EWC ---")
model_ewc_test = SimpleNet().to(device)
optimizer_ewc = torch.optim.Adam(model_ewc_test.parameters(), lr=1e-3)
# lambda_ewc scales the penalty; fisher_sample_size presumably bounds the
# number of samples used for the Fisher estimate — see the EWC class.
ewc = EWC(model_ewc_test, lambda_ewc=5000, fisher_sample_size=500)

# Train task 1 (plain cross-entropy; no penalty exists yet)
for epoch in range(3):
    model_ewc_test.train()
    for inputs, labels in train_loader_t1:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_ewc.zero_grad()
        loss = F.cross_entropy(model_ewc_test(inputs), labels)
        loss.backward()
        optimizer_ewc.step()

acc_t1_ewc = evaluate_model(model_ewc_test, test_loader_t1)
print(f"After Task 1 - T1 acc: {acc_t1_ewc:.2f}%")

# Register task 1: snapshot parameters / Fisher information for the penalty
ewc.register_task(train_loader_t1)

# Train task 2 with the EWC penalty added to the task loss
for epoch in range(3):
    model_ewc_test.train()
    for inputs, labels in train_loader_t2:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_ewc.zero_grad()
        
        loss_task = F.cross_entropy(model_ewc_test(inputs), labels)
        loss_ewc_penalty = ewc.penalty()  # (λ/2) Σ F_i (θ_i − θ*_i)² per the EWC objective above
        loss = loss_task + loss_ewc_penalty
        
        loss.backward()
        optimizer_ewc.step()

acc_t1_ewc_final = evaluate_model(model_ewc_test, test_loader_t1)
acc_t2_ewc_final = evaluate_model(model_ewc_test, test_loader_t2)
print(f"After Task 2 - T1 acc: {acc_t1_ewc_final:.2f}%, T2 acc: {acc_t2_ewc_final:.2f}%")
print(f"Forgetting: {acc_t1_ewc - acc_t1_ewc_final:.2f}%")

results['EWC'] = {
    't1_after_t1': acc_t1_ewc,
    't1_after_t2': acc_t1_ewc_final,
    't2_after_t2': acc_t2_ewc_final,
    'forgetting': acc_t1_ewc - acc_t1_ewc_final
}


# 3. Experience Replay
# Rehearsal: a bounded reservoir buffer collects examples during training;
# task-2 batches are mixed with replayed (mostly task-1) examples so one
# gradient step covers both tasks.
print("\n--- Experience Replay ---")
model_replay = SimpleNet().to(device)
optimizer_replay = torch.optim.Adam(model_replay.parameters(), lr=1e-3)
replay_buffer = ReplayBuffer(max_size=1000)

# Train task 1 (the buffer fills with task-1 examples as a side effect)
for epoch in range(3):
    model_replay.train()
    for inputs, labels in train_loader_t1:
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Store in buffer
        replay_buffer.add(inputs, labels)
        
        optimizer_replay.zero_grad()
        loss = F.cross_entropy(model_replay(inputs), labels)
        loss.backward()
        optimizer_replay.step()

acc_t1_replay = evaluate_model(model_replay, test_loader_t1)
print(f"After Task 1 - T1 acc: {acc_t1_replay:.2f}%")

# Train task 2 with replay
for epoch in range(3):
    model_replay.train()
    for inputs, labels in train_loader_t2:
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Sample from replay buffer (mostly task-1 examples at this point)
        replay_x, replay_y = replay_buffer.sample(batch_size=64)
        
        if replay_x is not None:
            # Combine current batch with replay: one loss over both tasks
            combined_x = torch.cat([inputs, replay_x], dim=0)
            combined_y = torch.cat([labels, replay_y], dim=0)
        else:
            combined_x, combined_y = inputs, labels
        
        optimizer_replay.zero_grad()
        loss = F.cross_entropy(model_replay(combined_x), combined_y)
        loss.backward()
        optimizer_replay.step()
        
        # Add task 2 to buffer (reservoir sampling keeps the task mix uniform)
        replay_buffer.add(inputs, labels)

acc_t1_replay_final = evaluate_model(model_replay, test_loader_t1)
acc_t2_replay_final = evaluate_model(model_replay, test_loader_t2)
print(f"After Task 2 - T1 acc: {acc_t1_replay_final:.2f}%, T2 acc: {acc_t2_replay_final:.2f}%")
print(f"Forgetting: {acc_t1_replay - acc_t1_replay_final:.2f}%")

results['Replay'] = {
    't1_after_t1': acc_t1_replay,
    't1_after_t2': acc_t1_replay_final,
    't2_after_t2': acc_t2_replay_final,
    'forgetting': acc_t1_replay - acc_t1_replay_final
}

# Visualization
print("\n" + "="*80)
print("Final Results Summary")
print("="*80)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Collect per-method metrics from the results dict populated above.
methods = list(results.keys())
t1_final = [results[m]['t1_after_t2'] for m in methods]
t2_final = [results[m]['t2_after_t2'] for m in methods]
forgetting = [results[m]['forgetting'] for m in methods]

# Left panel: grouped bars — final accuracy on each task per method
x = np.arange(len(methods))
width = 0.35

axes[0].bar(x - width/2, t1_final, width, label='Task 1 (0-4)', alpha=0.8)
axes[0].bar(x + width/2, t2_final, width, label='Task 2 (5-9)', alpha=0.8)
axes[0].set_ylabel('Accuracy (%)', fontsize=11)
axes[0].set_title('Final Task Performance', fontsize=12, fontweight='bold')
axes[0].set_xticks(x)
axes[0].set_xticklabels(methods, rotation=15, ha='right')
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')

# Right panel: forgetting per method (red flags methods above 10% forgetting)
colors = ['red' if f > 10 else 'green' for f in forgetting]
axes[1].bar(methods, forgetting, color=colors, alpha=0.7)
axes[1].set_ylabel('Forgetting (%)', fontsize=11)
axes[1].set_title('Task 1 Forgetting', fontsize=12, fontweight='bold')
axes[1].axhline(y=0, color='black', linestyle='--', linewidth=1)
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Plain-text summary table mirroring the plots
for method in methods:
    r = results[method]
    print(f"{method:15s}: T1={r['t1_after_t2']:.2f}%, T2={r['t2_after_t2']:.2f}%, Forgetting={r['forgetting']:.2f}%")

print("\n" + "="*80)
print("Implementation Complete!")
print("="*80)
print("\nKey Insights:")
print("1. EWC reduces forgetting by penalizing changes to important parameters")
print("2. Experience Replay maintains performance by rehearsing old examples")
print("3. LwF uses knowledge distillation to preserve old task predictions")
print("4. Progressive networks achieve zero forgetting but grow linearly")
print("5. Trade-offs: Memory (Replay), Computation (EWC), Capacity (Progressive)")
print("\nNext: Apply to real-world continual learning scenarios!")

Advanced Continual Learning TheoryΒΆ

1. Introduction to Continual LearningΒΆ

Continual Learning (also called Lifelong Learning or Incremental Learning) enables models to learn from a stream of data without forgetting previously acquired knowledge.

1.1 The Plasticity-Stability DilemmaΒΆ

Challenge: Balance two competing objectives:

  • Plasticity: Ability to learn new tasks

  • Stability: Preserve knowledge of old tasks

Catastrophic Forgetting: When learning new task T_n, performance on T_1, …, T_{n-1} degrades significantly.

1.2 Problem FormulationΒΆ

Task sequence: T_1, T_2, …, T_n

Goal: After learning task T_n, maintain performance on all tasks:

min ΞΈ_n Ξ£α΅’β‚Œβ‚βΏ L_i(ΞΈ_n)

Constraints:

  • No access to data from previous tasks (or limited buffer)

  • Bounded memory and computation

  • Tasks may be related or unrelated

1.3 Continual Learning ScenariosΒΆ

  1. Task-Incremental Learning (Task-IL):

    • Task identity known at test time

    • Multi-head architecture (one head per task)

    • Example: Learn tasks 1, 2, 3; test on task 2 (given task ID)

  2. Domain-Incremental Learning (Domain-IL):

    • Same task, different domains

    • Single-head architecture

    • Example: MNIST β†’ SVHN β†’ USPS

  3. Class-Incremental Learning (Class-IL):

    • New classes added over time

    • Most challenging (no task ID at test)

    • Example: Learn cats/dogs, then birds/fish

2. Regularization-Based ApproachesΒΆ

2.1 Elastic Weight Consolidation (EWC)ΒΆ

EWC [Kirkpatrick et al., 2017]: Protect important weights for old tasks.

Objective for task n:

L(ΞΈ) = L_n(ΞΈ) + (Ξ»/2) Ξ£α΅’ F_i (ΞΈα΅’ - ΞΈ*_{n-1,i})Β²

Where:

  • L_n(ΞΈ): Loss on current task

  • F_i: Fisher Information Matrix diagonal

  • ΞΈ*_{n-1}: Parameters after task n-1

  • Ξ»: Regularization strength

Fisher Information:

F_i = E_{x~D} [(βˆ‚ log p(y|x,ΞΈ)/βˆ‚ΞΈα΅’)Β²]

Approximation:

F_i β‰ˆ (1/N) Ξ£β‚“ (βˆ‚L(x,ΞΈ*)/βˆ‚ΞΈα΅’)Β²

Intuition: Important parameters (high Fisher) are penalized more if changed.

2.2 Synaptic Intelligence (SI)ΒΆ

SI [Zenke et al., 2017]: Measure importance via path integral.

Importance (path integral of the loss decrease):

ω_i = −Σ_t (∂L/∂θᵢ)(t) · Δθᵢ(t),   normalized as Ω_i = ω_i / ((Δθᵢ)² + ξ)

Accumulates each parameter's contribution to the reduction in loss along the training trajectory (gradient times parameter update), normalized by the parameter's total change.

Loss:

L(ΞΈ) = L_n(ΞΈ) + c Ξ£α΅’ Ξ©α΅’(ΞΈα΅’ - ΞΈ*_{n-1,i})Β²

Advantage: Online computation (no need to store data).

2.3 Memory Aware Synapses (MAS)ΒΆ

MAS [Aljundi et al., 2018]: Importance based on output sensitivity.

Importance:

Ξ©_i = E_x [||βˆ‚f(x,ΞΈ)/βˆ‚ΞΈα΅’||Β²]

Difference from EWC: Uses output sensitivity, not loss gradient.

Benefit: Unsupervised (doesn’t need labels).

2.4 Learning without Forgetting (LwF)ΒΆ

LwF [Li & Hoiem, 2017]: Distillation from old model.

Loss:

L = L_new(ΞΈ) + Ξ± L_distill(ΞΈ, ΞΈ_old)

Where:

L_distill = KL(p_old(x) || p_new(x))

Knowledge distillation: New model mimics old model’s outputs on new data.

Advantage: No need to store old data.

3. Replay-Based ApproachesΒΆ

3.1 Experience Replay (ER)ΒΆ

Core idea: Store subset of old data, replay when learning new tasks.

Memory buffer M: Store (x, y, task_id) tuples.

Training: Mix current data with replayed data:

L(ΞΈ) = E_{(x,y)~D_n} [L(x,y,ΞΈ)] + E_{(x,y)~M} [L(x,y,ΞΈ)]

Buffer management:

  • Reservoir sampling: Random replacement

  • Herding: Select most representative samples

  • Class-balanced: Equal samples per class

3.2 Gradient Episodic Memory (GEM)ΒΆ

GEM [Lopez-Paz & Ranzato, 2017]: Constrained optimization.

Constraint: New gradients shouldn’t increase loss on old tasks:

⟨g_new, g_old^i⟩ β‰₯ 0  for all i < n

Where:

  • g_new = βˆ‡_ΞΈ L_n(ΞΈ)

  • g_old^i = βˆ‡_ΞΈ L_i(ΞΈ) (computed on memory buffer)

Projection: If constraint violated, project g_new to feasible region.

Quadratic program:

min ||g - g_new||Β²
s.t. ⟨g, g_old^i⟩ β‰₯ 0

3.3 Averaged GEM (A-GEM)ΒΆ

A-GEM [Chaudhry et al., 2019]: Simplified GEM.

Single constraint: Average gradient on memory:

⟨g_new, αΈ‘_old⟩ β‰₯ 0

Where αΈ‘_old is average of gradients on memory samples.

Projection:

g = g_new − (⟨g_new, ḡ_old⟩ / ||ḡ_old||²) · ḡ_old

Advantage: O(1) constraints vs. O(tasks) for GEM.

3.4 Dark Experience Replay (DER)ΒΆ

DER [Buzzega et al., 2020]: Store logits instead of labels.

Memory: (x, logits_old, task_id)

Loss:

L = L_CE(y, f(x)) + Ξ± MSE(logits_old, f(x))

Benefit: Richer information than just labels, mitigates recency bias.

4. Dynamic Architecture ApproachesΒΆ

4.1 Progressive Neural NetworksΒΆ

ProgNN [Rusu et al., 2016]: Add new columns for new tasks.

Architecture:

  • Task 1: Column C_1

  • Task 2: Column C_2 + lateral connections from C_1

  • Task n: Column C_n + lateral connections from C_1, …, C_{n-1}

Forward pass for task n:

h_n^(l) = Οƒ(W_n^(l) h_n^(l-1) + Ξ£α΅’β‚Œβ‚^{n-1} U_iβ†’n^(l) h_i^(l-1))

Advantages:

  • No forgetting (old parameters frozen)

  • Transfer via lateral connections

Disadvantages:

  • O(nΒ²) parameters growth

  • Requires task ID

4.2 PackNetΒΆ

PackNet [Mallya & Lazebnik, 2018]: Network pruning + packing.

Procedure for task n:

  1. Train full network on task n

  2. Prune unimportant weights (magnitude-based)

  3. Pack: Mark remaining weights as used by task n

  4. Freeze used weights

Inference: Activate only weights for given task.

Parameter efficiency: Reuses parameters across tasks.

4.3 Dynamic Expandable Network (DEN)ΒΆ

DEN [Yoon et al., 2018]: Selective retraining and expansion.

Algorithm:

  1. Selective retraining: Retrain low-drift neurons

  2. Dynamic expansion: Add neurons if capacity insufficient

  3. Network split/duplication: Create task-specific paths

Drift for neuron j:

d_j = ||ΞΈ_j - ΞΈ*_j||Β² / (Οƒ_jΒ² + Ξ΅)

Expansion criterion: If validation loss doesn’t decrease, add neurons.

5. Meta-Learning for Continual LearningΒΆ

5.1 Meta-Experience Replay (MER)ΒΆ

MER [Riemer et al., 2019]: Meta-learn with replay buffer.

Inner loop: Update on current batch Outer loop: Evaluate on replay buffer, update to minimize replay loss

Algorithm:

ΞΈ' = ΞΈ - Ξ± βˆ‡_ΞΈ L_current(ΞΈ)
ΞΈ ← ΞΈ - Ξ² βˆ‡_ΞΈ L_replay(ΞΈ')

Benefit: Meta-learned updates reduce forgetting.

5.2 Online-aware Meta-learning (OML)ΒΆ

OML [Javed & White, 2019]: Meta-learn representations for continual learning.

Objective: Learn representation that enables online learning:

min_ΞΈ E_{T~p(T)} [L_T(ΞΈ - Ξ± βˆ‡_ΞΈ L_T^train(ΞΈ))]

Training: Sample task sequences, meta-optimize for online adaptation.

5.3 La-MAMLΒΆ

La-MAML [Gupta et al., 2020]: Look-ahead MAML for continual learning.

Key idea: Meta-learn initialization that resists forgetting when adapted.

Meta-objective:

min_ΞΈ Ξ£α΅’ L_i(ΞΈ_{1:i}) + Ξ» Ξ£α΅’<β±Ό L_i(ΞΈ_{1:j})

Optimizes both current and future performance on past tasks.

6. Generative ReplayΒΆ

6.1 Deep Generative ReplayΒΆ

DGR [Shin et al., 2017]: Train generative model alongside discriminative model.

Components:

  • Generator G: Synthesizes samples from old tasks

  • Solver S: Discriminative model

Training on task n:

  1. Generate pseudo-samples from G for tasks 1, …, n-1

  2. Train S on real samples (task n) + generated samples (tasks <n)

  3. Update G to generate samples for task n

Loss:

L = L_real(x_n, y_n) + L_generated(G(z), S(G(z)))

6.2 Conditional GAN for ReplayΒΆ

Use class-conditional GAN: G(z, c) generates samples for class c.

Advantage: Can generate specific classes on demand.

Challenges:

  • Generator quality crucial

  • Mode collapse can lose diversity

  • Computational overhead

7. Theoretical FoundationsΒΆ

7.1 Catastrophic Forgetting AnalysisΒΆ

Linear models: Catastrophic forgetting occurs when new task gradients are orthogonal to old task gradients.

Neural networks: Over-parameterization helps (multiple solutions exist).

Empirical finding [Goodfellow et al., 2013]: For linear models, forgetting is unavoidable unless tasks share structure.

7.2 Stability-Plasticity Trade-offΒΆ

Formal definition:

  • Stability: ||ΞΈ_n - ΞΈ_{n-1}||Β² small

  • Plasticity: L_n(ΞΈ_n) small

Pareto frontier: Can’t minimize both simultaneously.

EWC perspective: Fisher matrix balances this trade-off.

7.3 Generalization BoundsΒΆ

PAC-Bayesian bound for continual learning:

With probability β‰₯ 1-Ξ΄:

Ξ£α΅’ L_i(ΞΈ_n) ≀ Ξ£α΅’ LΜ‚_i(ΞΈ_n) + O(√(KL(Q||P) + log(n/Ξ΄)) / √N)

Where:

  • Q: Posterior after n tasks

  • P: Prior

  • N: Total samples

Implication: Regularization toward prior (e.g., EWC) helps generalization.

8. Evaluation MetricsΒΆ

8.1 Average AccuracyΒΆ

After learning n tasks:

ACC_n = (1/n) Ξ£α΅’β‚Œβ‚βΏ a_{n,i}

Where a_{n,i} is accuracy on task i after learning task n.

8.2 Forgetting MeasureΒΆ

Average forgetting:

FM_n = (1/(n-1)) Ξ£α΅’β‚Œβ‚βΏβ»ΒΉ (max_j a_{j,i} - a_{n,i})

Measures maximum drop in performance on each task.

8.3 Backward Transfer (BWT)ΒΆ

BWT: How much learning new tasks affects old tasks:

BWT_n = (1/(n-1)) Ξ£α΅’β‚Œβ‚βΏβ»ΒΉ (a_{n,i} - a_{i,i})

Negative BWT indicates forgetting.

8.4 Forward Transfer (FWT)ΒΆ

FWT: How much old tasks help learn new tasks:

FWT_n = (1/(n-1)) Ξ£α΅’β‚Œβ‚‚βΏ (a_{i-1,i} - a_{0,i})

Where a_{0,i} is random initialization performance.

Positive FWT indicates transfer learning.

9. Advanced TechniquesΒΆ

9.1 Task-Free Continual LearningΒΆ

Challenge: No explicit task boundaries.

Approaches:

  • Boundary detection: Detect distribution shift

  • Pseudo-rehearsal: Generate samples continuously

  • Uncertainty-based: Use prediction uncertainty to detect new data

9.2 Online Continual LearningΒΆ

Setting: One pass through data stream, no task boundaries.

Challenges:

  • No clear training/testing split

  • Must decide when to update

  • Memory budget constraints

Metrics: Average accuracy over time, anytime performance.

9.3 Continual Learning with Imbalanced DataΒΆ

Problem: New classes may have few samples (long-tail distribution).

Solutions:

  • Class balancing: Oversample rare classes in replay

  • Focal loss: Focus on hard examples

  • Two-stage training: Learn representations, then classifier

9.4 Continual Learning under Label NoiseΒΆ

Noisy labels in stream:

  • Robust loss: Use symmetric loss functions

  • Sample selection: Filter likely mislabeled samples

  • Co-teaching: Two networks teach each other

10. Practical ConsiderationsΒΆ

10.1 Memory BudgetΒΆ

Typical budgets: 500-5000 samples (vs. millions in full dataset).

Strategies:

  • Per-class: K samples per class

  • Per-task: Fixed budget per task

  • Global: Total budget across all tasks

Update policy:

  • Reservoir sampling: Probabilistic replacement

  • Ring buffer: FIFO

  • Coreset selection: Optimize representativeness

10.2 Computational EfficiencyΒΆ

EWC: O(|ΞΈ|) extra memory (Fisher diagonal) GEM: O(nΒ·batch_size) memory (gradients per task) Replay: O(buffer_size) memory

Training time:

  • Regularization: ~1.2Γ— base

  • Replay: ~1.5Γ— base (due to extra forward passes)

  • Dynamic architectures: ~2Γ— base

10.3 Hyperparameter SensitivityΒΆ

Critical hyperparameters:

Method

Key Hyperparameters

Typical Range

EWC

Ξ» (importance)

10²-10⁡

Replay

Buffer size

500-5000

GEM

Memory per task

50-500

LwF

Ξ± (distillation)

1-10

Tuning: Use validation set from current task (no access to old tasks).

11. Benchmark DatasetsΒΆ

11.1 Permuted MNISTΒΆ

Setup: 10 tasks, each with permuted pixels.

Evaluation: Task-IL (10-way classification per task).

Baseline: ~90% after 10 tasks (naive fine-tuning: ~20%).

11.2 Split CIFAR-100ΒΆ

Setup: 10 or 20 tasks, disjoint classes.

Variants:

  • 10 tasks Γ— 10 classes

  • 20 tasks Γ— 5 classes

Challenge: Class-IL (100-way classification without task ID).

11.3 CORe50ΒΆ

CORe50: Continual learning on real-world objects.

Properties:

  • 50 objects, 11 sessions

  • Indoor/outdoor lighting changes

  • 164,866 images

Scenarios: NI (new instances), NC (new classes), NIC (both).

12. Comparison of ApproachesΒΆ

12.1 Performance SummaryΒΆ

Forgetting (lower is better):

Method

Permuted MNIST

Split CIFAR-100

Fine-tuning

85%

75%

EWC

15%

35%

SI

18%

40%

GEM

2%

20%

A-GEM

8%

28%

ER (500)

5%

25%

DER

3%

18%

12.2 Memory-Performance Trade-offΒΆ

Memory usage vs. Final accuracy:

  • No replay: 0 extra memory, 40-50% accuracy

  • EWC: |ΞΈ| floats, 60-70% accuracy

  • Replay (100): 100 samples, 70-75% accuracy

  • Replay (1000): 1000 samples, 80-85% accuracy

  • Progressive: O(nΒ²) params, 90-95% accuracy

Insight: Replay is most memory-efficient for good performance.

12.3 Computational CostΒΆ

Relative training time (vs. base):

  • EWC: 1.1-1.3Γ—

  • SI: 1.2-1.4Γ—

  • Replay: 1.4-1.8Γ—

  • GEM: 2-3Γ—

  • Progressive: 1.5-2Γ—

13. Recent Advances (2020-2024)ΒΆ

13.1 Self-Supervised Continual LearningΒΆ

Approach: Learn representations via contrastive learning, then fine-tune.

Benefits: Better transfer, less catastrophic forgetting.

Methods: SimCLR + replay, MoCo + EWC.

13.2 Transformer-Based Continual LearningΒΆ

L2P [Wang et al., 2022]: Learning to Prompt for continual learning.

Idea: Learn prompt pool, select relevant prompts for each task.

DualPrompt [Wang et al., 2022]: Separate prompts for general and task-specific knowledge.

13.3 Continual Pre-trainingΒΆ

Scenario: Continually pre-train large language models on new data.

Challenges:

  • Forgetting common knowledge

  • Domain shift

  • Computational cost

Solutions: Replay, regularization, modular architectures.

13.4 Federated Continual LearningΒΆ

Setting: Multiple clients, each learning continually + federated aggregation.

Challenges:

  • Heterogeneous task sequences

  • Communication efficiency

  • Privacy constraints

14. Limitations and Open ProblemsΒΆ

14.1 Current LimitationsΒΆ

  1. Task boundaries: Most methods assume clear task boundaries

  2. Memory requirements: Replay needs storage

  3. Scalability: Many methods don’t scale to 100+ tasks

  4. Theoretical understanding: Limited guarantees

14.2 Open QuestionsΒΆ

  1. Optimal replay strategy: What to store? How to use?

  2. Architecture design: Fixed vs. dynamic? How much capacity?

  3. Task similarity: How to measure? How to leverage?

  4. Evaluation: What metrics best capture continual learning ability?

14.3 Future DirectionsΒΆ

  1. Biologically-inspired: Complementary learning systems (hippocampus + neocortex)

  2. Curriculum: Learn tasks in optimal order

  3. Meta-continual learning: Learn to continually learn

  4. Compositional: Reuse learned modules

15. Key TakeawaysΒΆ

  1. Catastrophic forgetting is the central challenge in continual learning

  2. Regularization (EWC, SI, MAS) protects important parameters

  3. Replay stores samples from old tasks, most effective approach

  4. GEM/A-GEM constrain gradients to not harm old tasks

  5. Dynamic architectures avoid forgetting but grow over time

  6. Meta-learning can improve continual learning ability

  7. Memory-performance trade-off: More memory β†’ less forgetting

  8. Evaluation metrics: ACC, FM, BWT, FWT capture different aspects

  9. Benchmarks: Permuted MNIST (easy), Split CIFAR (medium), CORe50 (hard)

  10. No silver bullet: Best method depends on scenario and constraints

Core insight: Balance stability (preserving old knowledge) and plasticity (learning new knowledge).

16. Mathematical SummaryΒΆ

Continual learning objective:

min_ΞΈ Ξ£α΅’β‚Œβ‚βΏ L_i(ΞΈ) + R(ΞΈ)

Where R(ΞΈ) is regularization.

EWC regularization:

R(ΞΈ) = (Ξ»/2) Ξ£α΅’ F_i (ΞΈα΅’ - ΞΈ*_{old,i})Β²

GEM constraint:

βŸ¨βˆ‡L_current(ΞΈ), βˆ‡L_old(ΞΈ)⟩ β‰₯ 0

Replay objective:

L(ΞΈ) = E_{(x,y)~D_current} [β„“(x,y,ΞΈ)] + E_{(x,y)~M} [β„“(x,y,ΞΈ)]

ReferencesΒΆ

  1. Kirkpatrick et al. (2017) β€œOvercoming Catastrophic Forgetting in Neural Networks (EWC)”

  2. Zenke et al. (2017) β€œContinual Learning Through Synaptic Intelligence (SI)”

  3. Lopez-Paz & Ranzato (2017) β€œGradient Episodic Memory for Continual Learning (GEM)”

  4. Rusu et al. (2016) β€œProgressive Neural Networks”

  5. Shin et al. (2017) β€œContinual Learning with Deep Generative Replay”

  6. Chaudhry et al. (2019) β€œEfficient Lifelong Learning with A-GEM”

  7. Buzzega et al. (2020) β€œDark Experience for General Continual Learning (DER)”

  8. Riemer et al. (2019) β€œLearning to Learn without Forgetting by Maximizing Transfer and Minimizing Interference (MER)”

  9. De Lange et al. (2021) β€œA Continual Learning Survey: Defying Forgetting in Classification Tasks”

  10. Wang et al. (2022) β€œLearning to Prompt for Continual Learning (L2P)”

"""
Complete Continual Learning Implementations
===========================================
Includes: EWC, SI, GEM, A-GEM, Experience Replay, DER, LwF,
memory management, evaluation metrics.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
import copy
from collections import defaultdict

# ============================================================================
# 1. Memory Buffer for Replay
# ============================================================================

class ReplayBuffer:
    """
    Memory buffer for experience replay.

    Supports multiple storage strategies:
    - 'reservoir': reservoir sampling (uniform over the whole stream)
    - 'ring': FIFO ring buffer
    - 'class_balanced': roughly equal samples per observed class

    Parameters
    ----------
    max_size : int
        Maximum number of samples held in the buffer. (With the
        class-balanced strategy and more classes than `max_size`, the
        buffer keeps one sample per class, so it may slightly exceed
        `max_size`.)
    strategy : str
        One of 'reservoir', 'ring', 'class_balanced'.
    """
    def __init__(self, max_size=500, strategy='reservoir'):
        self.max_size = max_size
        self.strategy = strategy
        self.buffer = []
        self.num_seen = 0  # total samples offered to the buffer so far

        # Per-class storage used only by the 'class_balanced' strategy.
        self.class_buffers = defaultdict(list)

    def add(self, x, y, task_id=None):
        """Add one (x, y) sample (moved to CPU) under the configured strategy."""
        sample = {'x': x.cpu(), 'y': y.cpu(), 'task_id': task_id}

        if self.strategy == 'reservoir':
            self._reservoir_add(sample)
        elif self.strategy == 'ring':
            self._ring_add(sample)
        elif self.strategy == 'class_balanced':
            self._class_balanced_add(sample, y.item())

    def _reservoir_add(self, sample):
        """Reservoir sampling: each stream element kept with prob max_size/num_seen."""
        if len(self.buffer) < self.max_size:
            self.buffer.append(sample)
        else:
            # The (num_seen+1)-th element replaces a uniformly chosen slot
            # with probability max_size / (num_seen + 1).
            idx = np.random.randint(0, self.num_seen + 1)
            if idx < self.max_size:
                self.buffer[idx] = sample

        self.num_seen += 1

    def _ring_add(self, sample):
        """Ring buffer: FIFO replacement once the buffer is full."""
        if len(self.buffer) < self.max_size:
            self.buffer.append(sample)
        else:
            idx = self.num_seen % self.max_size
            self.buffer[idx] = sample

        self.num_seen += 1

    def _class_balanced_add(self, sample, class_id):
        """Class-balanced: keep the most recent samples, equally many per class."""
        self.class_buffers[class_id].append(sample)
        # BUG FIX: num_seen was never incremented for this strategy.
        self.num_seen += 1

        classes = list(self.class_buffers.keys())
        # BUG FIX: when there are more classes than max_size the original
        # computed samples_per_class == 0 and `samples[-0:]` selected ALL
        # samples, letting the buffer grow without bound. Keep at least one
        # sample per class instead.
        samples_per_class = max(1, self.max_size // len(classes))

        # Rebuild the buffer; trim each per-class list so memory stays
        # bounded (the original kept every sample ever seen per class).
        # Trimming keeps exactly the samples the original selection
        # (`samples[-k:]`) would pick now and at any later step.
        self.buffer = []
        for c in classes:
            self.class_buffers[c] = self.class_buffers[c][-samples_per_class:]
            self.buffer.extend(self.class_buffers[c])

    def sample(self, batch_size):
        """Return a random batch (x, y, task_ids) without replacement, or None if empty."""
        if len(self.buffer) == 0:
            return None

        indices = np.random.choice(len(self.buffer),
                                   min(batch_size, len(self.buffer)),
                                   replace=False)

        samples = [self.buffer[i] for i in indices]

        x = torch.stack([s['x'] for s in samples])
        y = torch.stack([s['y'] for s in samples])
        task_ids = [s['task_id'] for s in samples]

        return x, y, task_ids

    def __len__(self):
        return len(self.buffer)


# ============================================================================
# 2. Elastic Weight Consolidation (EWC)
# ============================================================================

class EWC:
    """
    Elastic Weight Consolidation (Kirkpatrick et al., 2017).

    Protects parameters that were important for a previous task by
    penalizing their movement, weighted by the diagonal of the empirical
    Fisher Information Matrix.

    Parameters
    ----------
    model : nn.Module
        Model whose parameters are consolidated.
    dataloader : iterable
        Iterable of (x, y) batches from the task to consolidate. Any
        iterable of batches works (DataLoader, list of tuples, ...).
    device : str
        Device on which the Fisher estimate is computed.
    lambda_ewc : float
        Regularization strength lambda.
    """
    def __init__(self, model, dataloader, device='cuda', lambda_ewc=1000):
        self.model = model
        self.device = device
        self.lambda_ewc = lambda_ewc

        # Parameter snapshot theta* taken via save_parameters().
        self.old_params = {}

        # Diagonal Fisher information, one tensor per named parameter.
        self.fisher = {}

        self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader):
        """Estimate the diagonal Fisher Information Matrix.

        Uses squared gradients of the batch-averaged cross-entropy as a
        practical approximation of E[(d log p(y|x, theta) / d theta)^2].
        """
        self.model.eval()

        for name, param in self.model.named_parameters():
            self.fisher[name] = torch.zeros_like(param)

        num_batches = 0
        for x, y in dataloader:
            x, y = x.to(self.device), y.to(self.device)

            self.model.zero_grad()
            output = self.model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()

            # Accumulate squared gradients of the (batch-mean) loss.
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    self.fisher[name] += param.grad.detach().pow(2)
            num_batches += 1

        # BUG FIX: the loss above is already averaged over each batch, so
        # the accumulated squared gradients are normalized by the number of
        # batches — not by len(dataloader.dataset), which both over-shrank
        # the estimate and crashed for plain iterables that have no
        # `.dataset` attribute (e.g. a list of batches, as used in the demo).
        if num_batches > 0:
            for name in self.fisher:
                self.fisher[name] /= num_batches

    def save_parameters(self):
        """Snapshot current parameters as the anchor theta* for the penalty."""
        for name, param in self.model.named_parameters():
            self.old_params[name] = param.data.clone()

    def penalty(self):
        """Return (lambda/2) * sum_i F_i (theta_i - theta*_i)^2.

        Returns 0 if save_parameters() has not been called yet.
        """
        if len(self.old_params) == 0:
            return 0

        loss = 0
        for name, param in self.model.named_parameters():
            if name in self.fisher:
                loss += (self.fisher[name] *
                         (param - self.old_params[name]).pow(2)).sum()

        return (self.lambda_ewc / 2) * loss


# ============================================================================
# 3. Synaptic Intelligence (SI)
# ============================================================================

class SynapticIntelligence:
    """
    Synaptic Intelligence (Zenke et al., 2017).

    Tracks per-parameter importance online as the path integral of
    (negative gradient) x (per-step parameter update) accumulated while
    training each task.

    Usage per training step: loss.backward(); optimizer.step();
    si.accumulate_gradient(). After finishing a task: si.update_omega().

    Parameters
    ----------
    model : nn.Module
        Model being trained continually.
    c : float
        Regularization strength.
    epsilon : float
        Damping constant to avoid division by (near-)zero displacement.
    """
    def __init__(self, model, c=0.1, epsilon=1e-3):
        self.model = model
        self.c = c
        self.epsilon = epsilon

        # Per-parameter importance (omega), accumulated across tasks.
        self.omega = {}

        # Running path integral W for the current task.
        self.W = {}

        # Parameters at the start of the current task (anchor theta*).
        self.prev_params = {}

        # Parameters after the previous optimizer step. BUG FIX: the
        # original used the task-start parameters here, so W accumulated
        # -g * (theta_t - theta_0) instead of the per-step -g * delta_theta
        # required by the SI path integral.
        self._step_params = {}

        for name, param in model.named_parameters():
            self.omega[name] = torch.zeros_like(param)
            self.W[name] = torch.zeros_like(param)
            self.prev_params[name] = param.data.clone()
            self._step_params[name] = param.data.clone()

    def update_omega(self):
        """Fold the accumulated path integral into omega after task completion."""
        for name, param in self.model.named_parameters():
            # Total displacement over the whole task.
            delta = param.data - self.prev_params[name]

            # omega_i += W_i / (delta_i^2 + epsilon)
            self.omega[name] += self.W[name] / (delta.pow(2) + self.epsilon)

            # Reset per-task state for the next task.
            self.W[name] = torch.zeros_like(param)
            self.prev_params[name] = param.data.clone()
            self._step_params[name] = param.data.clone()

    def accumulate_gradient(self):
        """Accumulate -grad * per-step parameter change (call after optimizer.step())."""
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                # Change produced by the most recent optimizer step.
                delta = param.data - self._step_params[name]
                self.W[name] -= param.grad * delta
            # Advance the per-step snapshot regardless of grad presence.
            self._step_params[name] = param.data.clone()

    def penalty(self):
        """Return c * sum_i omega_i (theta_i - theta*_i)^2."""
        loss = 0
        for name, param in self.model.named_parameters():
            if name in self.omega:
                prev = self.prev_params[name]
                loss += (self.omega[name] * (param - prev).pow(2)).sum()

        return self.c * loss


# ============================================================================
# 4. Gradient Episodic Memory (GEM)
# ============================================================================

class GEM:
    """
    Gradient Episodic Memory (Lopez-Paz & Ranzato, 2017).

    Keeps a small episodic memory per task and, when the current gradient
    conflicts with a stored task's gradient, projects it so old tasks are
    not harmed. The projection uses the averaged-constraint (A-GEM style)
    approximation instead of solving the full quadratic program.
    """

    def __init__(self, model, memory_budget_per_task=100):
        self.model = model
        self.memory_budget = memory_budget_per_task

        # task_id -> {'x': tensor, 'y': tensor}
        self.memory = {}

    def store_task_memory(self, task_id, dataloader, device='cuda'):
        """Fill this task's episodic memory with up to `memory_budget` samples."""
        xs, ys = [], []
        stored = 0

        for x, y in dataloader:
            if stored >= self.memory_budget:
                break
            take = min(x.size(0), self.memory_budget - stored)
            xs.append(x[:take])
            ys.append(y[:take])
            stored += take

        self.memory[task_id] = {
            'x': torch.cat(xs).to(device),
            'y': torch.cat(ys).to(device),
        }

    def get_gradient(self, task_id, device='cuda'):
        """Flattened gradient of the cross-entropy loss on one task's memory.

        Note: clobbers the model's current .grad buffers via zero_grad/backward.
        """
        entry = self.memory.get(task_id)
        if entry is None:
            return None

        self.model.zero_grad()
        loss = F.cross_entropy(self.model(entry['x']), entry['y'])
        loss.backward()

        flat = [p.grad.view(-1) for p in self.model.parameters()
                if p.grad is not None]
        return torch.cat(flat)

    def project_gradient(self, current_grad, device='cuda'):
        """Return `current_grad`, projected if it conflicts with old-task gradients."""
        old_grads = [g for g in (self.get_gradient(t, device) for t in self.memory)
                     if g is not None]
        if not old_grads:
            return current_grad

        G = torch.stack(old_grads)  # [num_tasks, param_dim]

        # All dot products non-negative -> no constraint violated.
        if (G @ current_grad < 0).float().sum() == 0:
            return current_grad

        # A-GEM style approximation: project against the average old-task
        # gradient rather than solving the exact QP.
        g_ref = G.mean(dim=0)
        if (g_ref @ current_grad) < 0:
            coef = (current_grad @ g_ref) / (g_ref @ g_ref)
            return current_grad - coef * g_ref

        return current_grad


# ============================================================================
# 5. Averaged GEM (A-GEM)
# ============================================================================

class AGEM:
    """
    Averaged GEM (Chaudhry et al., 2019).

    Simplifies GEM to a single constraint: the current gradient must not
    conflict with the gradient computed on a batch drawn from one shared
    replay buffer.
    """

    def __init__(self, model, buffer_size=500):
        self.model = model
        self.buffer = ReplayBuffer(max_size=buffer_size)

    def store_sample(self, x, y, task_id=None):
        """Push one sample into the shared replay buffer."""
        self.buffer.add(x, y, task_id)

    def get_reference_gradient(self, device='cuda'):
        """Flattened loss gradient on a replay batch, or None if the buffer is empty.

        Note: clobbers the model's current .grad buffers via zero_grad/backward.
        """
        if len(self.buffer) == 0:
            return None

        batch = self.buffer.sample(min(256, len(self.buffer)))
        if batch is None:
            return None

        x, y, _ = batch
        x, y = x.to(device), y.to(device)

        self.model.zero_grad()
        loss = F.cross_entropy(self.model(x), y)
        loss.backward()

        flat = [p.grad.view(-1) for p in self.model.parameters()
                if p.grad is not None]
        return torch.cat(flat)

    def project_gradient(self, current_grad, device='cuda'):
        """Return `current_grad`, projected if it conflicts with the reference gradient."""
        g_ref = self.get_reference_gradient(device)
        if g_ref is None:
            return current_grad

        dot = (current_grad @ g_ref).item()
        if dot >= 0:
            return current_grad

        # g <- g - (g . g_ref / g_ref . g_ref) * g_ref
        coef = dot / (g_ref @ g_ref)
        return current_grad - coef * g_ref


# ============================================================================
# 6. Dark Experience Replay (DER)
# ============================================================================

class DarkExperienceReplay:
    """
    Dark Experience Replay (Buzzega et al., 2020).

    Stores past inputs together with the logits the model produced for
    them, and replays with an alpha-weighted MSE distillation loss on
    those logits.

    Parameters
    ----------
    model : nn.Module
        The continually trained model (kept for API symmetry; `loss`
        takes the model to evaluate explicitly).
    buffer_size : int
        Maximum number of stored (x, logits) pairs.
    alpha : float
        Weight applied to the distillation (replay) loss.
    """
    def __init__(self, model, buffer_size=500, alpha=0.5):
        self.model = model
        self.buffer_size = buffer_size
        self.alpha = alpha  # weight for distillation loss

        # Each entry: {'x': ..., 'logits': ..., 'task_id': ...}, kept on CPU.
        self.buffer = []
        self.num_seen = 0

    def add(self, x, logits, task_id=None):
        """Add one sample with its logits via reservoir sampling."""
        sample = {
            'x': x.cpu(),
            'logits': logits.cpu(),
            'task_id': task_id
        }

        if len(self.buffer) < self.buffer_size:
            self.buffer.append(sample)
        else:
            # Reservoir sampling: keep with probability buffer_size / num_seen.
            idx = np.random.randint(0, self.num_seen + 1)
            if idx < self.buffer_size:
                self.buffer[idx] = sample

        self.num_seen += 1

    def sample(self, batch_size, device='cuda'):
        """Return a replay batch (x, logits) on `device`, or None if buffer empty."""
        if len(self.buffer) == 0:
            return None

        indices = np.random.choice(len(self.buffer),
                                   min(batch_size, len(self.buffer)),
                                   replace=False)

        samples = [self.buffer[i] for i in indices]

        x = torch.stack([s['x'] for s in samples]).to(device)
        logits = torch.stack([s['logits'] for s in samples]).to(device)

        return x, logits

    def loss(self, model, x_replay, logits_replay):
        """Alpha-weighted MSE between current outputs and stored logits.

        BUG FIX: `alpha` was stored (and documented as the distillation
        weight) but never applied, so callers received an unweighted term.
        """
        output = model(x_replay)
        return self.alpha * F.mse_loss(output, logits_replay)


# ============================================================================
# 7. Learning without Forgetting (LwF)
# ============================================================================

class LearningWithoutForgetting:
    """
    Learning without Forgetting (knowledge distillation).

    Keeps a frozen snapshot of the model from before the current task and
    distills its temperature-softened predictions into the new model.
    """

    def __init__(self, model, alpha=1.0, temperature=2.0):
        self.model = model
        self.alpha = alpha              # weight of the distillation term
        self.temperature = temperature  # softening temperature T

        # Frozen snapshot of the previous model (teacher); None until
        # save_model() is called.
        self.old_model = None

    def save_model(self):
        """Freeze a deep copy of the current model to act as the teacher."""
        self.old_model = copy.deepcopy(self.model)
        self.old_model.eval()

        for p in self.old_model.parameters():
            p.requires_grad = False

    def distillation_loss(self, x, device='cuda'):
        """Temperature-scaled KL divergence between teacher and student outputs.

        Returns 0 if no teacher snapshot exists yet.
        """
        if self.old_model is None:
            return 0

        # Teacher predictions are targets only — no grad needed.
        with torch.no_grad():
            teacher_logits = self.old_model(x)

        student_logits = self.model(x)

        T = self.temperature
        teacher_probs = F.softmax(teacher_logits / T, dim=1)
        student_log_probs = F.log_softmax(student_logits / T, dim=1)

        # T^2 factor keeps gradient magnitudes comparable across temperatures.
        loss = F.kl_div(student_log_probs, teacher_probs,
                        reduction='batchmean') * (T ** 2)

        return self.alpha * loss


# ============================================================================
# 8. Evaluation Metrics
# ============================================================================

class ContinualLearningMetrics:
    """Accuracy-matrix bookkeeping for continual learning metrics.

    acc[i][j] = accuracy on task j measured after training on task i.
    ACC / FM / BWT follow the definitions of Lopez-Paz & Ranzato (2017)
    and Chaudhry et al. (2018).
    """
    def __init__(self, num_tasks):
        self.num_tasks = num_tasks

        # Accuracy matrix: acc[i][j] = accuracy on task j after learning task i
        self.acc = [[0.0 for _ in range(num_tasks)] for _ in range(num_tasks)]

    def update(self, task_trained, task_evaluated, accuracy):
        """Record accuracy on `task_evaluated` measured after training `task_trained`."""
        self.acc[task_trained][task_evaluated] = accuracy

    def average_accuracy(self, task_n):
        """Mean accuracy over tasks 0..task_n after training task task_n."""
        return np.mean([self.acc[task_n][i] for i in range(task_n + 1)])

    def forgetting_measure(self, task_n):
        """Average forgetting after learning task task_n.

        FM = (1/N) sum_i [ max_{j in {i,...,N-1}} a_{j,i} - a_{N,i} ].

        BUG FIX: the max is taken over training stages strictly BEFORE the
        current one (j < task_n), matching the standard definition. The
        original included j = task_n, which clamped forgetting at >= 0 and
        hid negative forgetting (i.e. backward transfer gains).
        """
        if task_n == 0:
            return 0.0

        forgetting = []
        for i in range(task_n):
            # Best accuracy ever achieved on task i before the current stage.
            max_acc = max(self.acc[j][i] for j in range(i, task_n))
            # Current accuracy on task i.
            current_acc = self.acc[task_n][i]
            forgetting.append(max_acc - current_acc)

        return np.mean(forgetting)

    def backward_transfer(self, task_n):
        """BWT = mean over old tasks of (accuracy now - accuracy right after learning it)."""
        if task_n == 0:
            return 0.0

        bwt = []
        for i in range(task_n):
            bwt.append(self.acc[task_n][i] - self.acc[i][i])

        return np.mean(bwt)

    def forward_transfer(self, task_n):
        """Forward transfer (zero-shot performance on not-yet-trained tasks).

        NOTE(review): the standard FWT definition compares a_{i-1,i} against
        a random-initialization baseline b_i; this implementation uses
        acc[0][i] as a stand-in for that baseline — confirm callers
        populate row 0 accordingly.
        """
        if task_n == 0:
            return 0.0

        fwt = []
        for i in range(1, task_n + 1):
            fwt.append(self.acc[i-1][i] - self.acc[0][i])

        return np.mean(fwt)

    def print_summary(self, task_n):
        """Print ACC / FM / BWT (and FWT when defined) after task task_n."""
        print(f"\nMetrics after Task {task_n}:")
        print(f"  Average Accuracy: {self.average_accuracy(task_n):.2f}%")
        print(f"  Forgetting Measure: {self.forgetting_measure(task_n):.2f}%")
        print(f"  Backward Transfer: {self.backward_transfer(task_n):.2f}%")
        if task_n > 0:
            print(f"  Forward Transfer: {self.forward_transfer(task_n):.2f}%")


# ============================================================================
# 9. Method Comparison
# ============================================================================

def print_method_comparison():
    """Print comparison of continual learning methods.

    Output-only helper: writes a fixed comparison table and guidance notes
    to stdout and returns None. The quoted forgetting/accuracy numbers are
    ballpark literature figures, not measurements produced by this file.
    """
    print("="*70)
    print("Continual Learning Methods Comparison")
    print("="*70)
    print()
    
    # Static table; box-drawing characters are part of the literal output.
    comparison = """
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Method       β”‚ Type       β”‚ Memory       β”‚ Forgetting   β”‚ Complexity   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Fine-tuning  β”‚ Baseline   β”‚ None         β”‚ Very High    β”‚ 1Γ—           β”‚
β”‚              β”‚            β”‚              β”‚ (~80%)       β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ EWC          β”‚ Regular.   β”‚ O(|ΞΈ|)       β”‚ Medium       β”‚ 1.2Γ—         β”‚
β”‚              β”‚            β”‚ (Fisher)     β”‚ (~15%)       β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ SI           β”‚ Regular.   β”‚ O(|ΞΈ|)       β”‚ Medium       β”‚ 1.3Γ—         β”‚
β”‚              β”‚            β”‚ (Importance) β”‚ (~18%)       β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ LwF          β”‚ Distill.   β”‚ O(|ΞΈ|)       β”‚ Medium-High  β”‚ 1.5Γ—         β”‚
β”‚              β”‚            β”‚ (Old model)  β”‚ (~25%)       β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ ER (500)     β”‚ Replay     β”‚ 500 samples  β”‚ Low          β”‚ 1.5Γ—         β”‚
β”‚              β”‚            β”‚              β”‚ (~5%)        β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ DER (500)    β”‚ Replay     β”‚ 500 samples  β”‚ Very Low     β”‚ 1.6Γ—         β”‚
β”‚              β”‚            β”‚ + logits     β”‚ (~3%)        β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ GEM          β”‚ Replay +   β”‚ n Γ— batch    β”‚ Very Low     β”‚ 2-3Γ—         β”‚
β”‚              β”‚ Constraint β”‚              β”‚ (~2%)        β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ A-GEM        β”‚ Replay +   β”‚ 500 samples  β”‚ Low          β”‚ 1.7Γ—         β”‚
β”‚              β”‚ Constraint β”‚              β”‚ (~8%)        β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Progressive  β”‚ Dynamic    β”‚ O(nΒ²|ΞΈ|)     β”‚ None         β”‚ 2Γ—           β”‚
β”‚              β”‚ Arch.      β”‚              β”‚ (0%)         β”‚              β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

**Trade-offs:**

Memory vs. Forgetting:
- No memory β†’ High forgetting (80%+)
- Small memory (500) β†’ Low forgetting (3-8%)
- Large memory (5000) β†’ Minimal forgetting (<1%)

Computational Cost:
- Regularization: Low overhead (1.2-1.5Γ—)
- Replay: Medium overhead (1.5-2Γ—)
- Constraint-based: High overhead (2-3Γ—)

**Best Practices:**

1. **Limited memory budget**: Use DER or A-GEM
2. **No memory allowed**: Use EWC or SI
3. **Maximum performance**: Use replay with large buffer
4. **Task-IL setting**: Progressive networks (no forgetting)
5. **Class-IL setting**: DER or GEM (handle confusion well)

**Typical Performance (5 tasks, Split CIFAR-10):**

Method          | Final Acc | Forgetting | Memory (MB)
----------------|-----------|------------|------------
Fine-tuning     | 45%       | 82%        | 0
EWC (Ξ»=1000)    | 68%       | 15%        | 10
SI (c=0.1)      | 65%       | 18%        | 10
ER (500)        | 78%       | 5%         | 50
DER (500)       | 82%       | 3%         | 55
GEM (100/task)  | 85%       | 2%         | 25
A-GEM (500)     | 75%       | 8%         | 50
"""
    
    print(comparison)
    print()


# ============================================================================
# Run Demonstrations
# ============================================================================

if __name__ == "__main__":
    # Smoke-test each continual-learning component with small random inputs;
    # no real datasets or training loops are run here.
    print("="*70)
    print("Continual Learning Implementations")
    print("="*70)
    print()
    
    # Simple model for demonstration
    class SimpleModel(nn.Module):
        """Small MLP (input_dim -> 256 -> 128 -> num_classes) on flattened input."""
        def __init__(self, input_dim=784, num_classes=10):
            super(SimpleModel, self).__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes)
            )
        
        def forward(self, x):
            # Flatten everything after the batch dimension.
            return self.net(x.view(x.size(0), -1))
    
    model = SimpleModel()
    
    # Test Replay Buffer
    print("Testing Replay Buffer...")
    buffer = ReplayBuffer(max_size=100, strategy='reservoir')
    
    # Stream 200 samples so the reservoir must overwrite (capacity is 100).
    for i in range(200):
        x = torch.randn(28, 28)
        y = torch.tensor(i % 10)
        buffer.add(x, y, task_id=i//50)
    
    print(f"  Buffer size: {len(buffer)}")
    samples = buffer.sample(32)
    if samples:
        x, y, task_ids = samples
        print(f"  Sampled batch: {x.shape}, {y.shape}")
    print()
    
    # Test EWC
    print("Testing EWC...")
    # Create dummy dataloader
    # NOTE(review): this is a plain list of batches, not a DataLoader;
    # EWC._compute_fisher accesses `dataloader.dataset`, which a list does
    # not have — confirm this demo runs against the EWC implementation used.
    dummy_data = [(torch.randn(32, 28, 28), torch.randint(0, 10, (32,))) for _ in range(10)]
    ewc = EWC(model, dummy_data, device='cpu', lambda_ewc=1000)
    ewc.save_parameters()
    penalty = ewc.penalty()
    print(f"  EWC penalty: {penalty.item():.4f}")
    print()
    
    # Test SI
    print("Testing Synaptic Intelligence...")
    si = SynapticIntelligence(model, c=0.1)
    si.update_omega()
    penalty = si.penalty()
    print(f"  SI penalty: {penalty.item():.4f}")
    print()
    
    # Test DER
    print("Testing Dark Experience Replay...")
    der = DarkExperienceReplay(model, buffer_size=100, alpha=0.5)
    
    # Store a single (input, logits) pair produced by the current model.
    x = torch.randn(16, 28, 28)
    with torch.no_grad():
        logits = model(x)
    der.add(x[0], logits[0], task_id=0)
    
    print(f"  DER buffer size: {len(der.buffer)}")
    print()
    
    # Test Metrics
    print("Testing Continual Learning Metrics...")
    metrics = ContinualLearningMetrics(num_tasks=5)
    
    # Simulate accuracy matrix
    for i in range(5):
        for j in range(i + 1):
            # Simulated accuracy (decreases for old tasks)
            acc = 90 - (i - j) * 5 + np.random.randn() * 2
            metrics.update(i, j, acc)
    
    metrics.print_summary(4)
    print()
    
    print_method_comparison()
    
    print("="*70)
    print("Continual Learning Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    print("  β€’ EWC: Protect important parameters via Fisher Information")
    print("  β€’ SI: Measure importance via path integral")
    print("  β€’ GEM/A-GEM: Constrain gradients to not harm old tasks")
    print("  β€’ Replay: Store and replay old samples (most effective)")
    print("  β€’ DER: Store logits for richer replay signal")
    print("  β€’ LwF: Knowledge distillation from old model")
    print()
    print("Key insight: Balance stability (preserve old) and plasticity (learn new)")
    print("Trade-off: Memory budget vs. forgetting vs. computational cost")
    print("Applications: Robotics, personalization, streaming data")
    print()
Advanced Continual Learning: Mathematical Foundations and Modern ApproachesΒΆ

1. Introduction to Continual LearningΒΆ

Continual Learning (also called Lifelong Learning or Incremental Learning) addresses a fundamental challenge in machine learning: how to learn sequentially from a stream of tasks without forgetting previous knowledge.

1.1 The Catastrophic Forgetting ProblemΒΆ

When a neural network is trained on Task A and then trained on Task B, performance on Task A drastically degrades. This phenomenon is called catastrophic forgetting or catastrophic interference.

Mathematical formulation:

Given a sequence of tasks \(\mathcal{T} = \{T_1, T_2, ..., T_N\}\), each with dataset \(\mathcal{D}_i = \{(x_j, y_j)\}_{j=1}^{n_i}\), the goal is to learn parameters \(\theta\) that minimize:

\[\mathcal{L}_{\text{total}} = \sum_{i=1}^N \mathbb{E}_{(x,y) \sim \mathcal{D}_i}[\ell(f_\theta(x), y)]\]

subject to:

  • No access to previous task data (memory constraints)

  • Bounded model capacity (cannot grow infinitely)

  • Limited computational budget per task

1.2 Stability-Plasticity DilemmaΒΆ

Stability: Preserve knowledge of old tasks
Plasticity: Adapt to new tasks

\[\text{Trade-off: } \text{Stability} \leftrightarrow \text{Plasticity}\]

Too much stability: Cannot learn new tasks
Too much plasticity: Forget old tasks immediately

1.3 Continual Learning ScenariosΒΆ

Task-Incremental Learning (TIL):

  • Task identity known at test time

  • Evaluate: \(\text{Acc}_i = \text{Accuracy on } T_i \text{ with task ID}\)

Domain-Incremental Learning (DIL):

  • Same task, different domains (e.g., different datasets for same classes)

  • No task ID at test time

Class-Incremental Learning (CIL):

  • Most challenging: New classes arrive over time

  • No task ID at test time

  • Must distinguish all classes seen so far

Example (Class-Incremental):

  • Task 1: Learn classes {0, 1, 2, 3, 4}

  • Task 2: Learn classes {5, 6, 7, 8, 9}

  • Test: Classify any digit 0-9 without knowing which task

1.4 Evaluation MetricsΒΆ

Average Accuracy: \(\text{ACC} = \frac{1}{N} \sum_{i=1}^N a_{N,i}\), where \(a_{N,i}\) is the accuracy on task \(i\) after learning task \(N\).

Forgetting Measure: \(\text{FM} = \frac{1}{N-1} \sum_{i=1}^{N-1} \max_{j \in \{i, ..., N-1\}} (a_{j,i} - a_{N,i})\)

Backward Transfer: \(\text{BWT} = \frac{1}{N-1} \sum_{i=1}^{N-1} (a_{N,i} - a_{i,i})\)

Negative BWT indicates forgetting.

Forward Transfer: \(\text{FWT} = \frac{1}{N-1} \sum_{i=2}^{N} (a_{i-1,i} - a_{\text{random},i})\)

Positive FWT indicates knowledge transfer to new tasks.

2. Regularization-Based MethodsΒΆ

2.1 Elastic Weight Consolidation (EWC)ΒΆ

Key idea: Protect important parameters for old tasks using Fisher Information Matrix.

Loss function: \(\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_{A,i}^*)^2\)

where:

  • \(\mathcal{L}_B\): Loss on current task B

  • \(\theta_{A}^*\): Optimal parameters after task A

  • \(F_i\): Fisher Information for parameter \(i\)

  • \(\lambda\): Regularization strength

Fisher Information Matrix: \(F_i = \mathbb{E}_{x \sim \mathcal{D}_A} \left[\left(\frac{\partial \log p(y|x, \theta_{A}^*)}{\partial \theta_i}\right)^2\right]\)

Approximation (using samples): \(F_i \approx \frac{1}{|\mathcal{D}_A|} \sum_{(x,y) \in \mathcal{D}_A} \left(\frac{\partial \log p(y|x, \theta_{A}^*)}{\partial \theta_i}\right)^2\)

Intuition:

  • High \(F_i\) β†’ parameter \(i\) is important for task A β†’ large penalty if changed

  • Low \(F_i\) β†’ parameter \(i\) can be freely updated for task B

Advantages:

  • Principled approach based on information theory

  • No need to store old data

  • Works for any model architecture

Limitations:

  • Requires storing Fisher matrix (same size as model)

  • Assumes tasks don’t overlap (diagonal Fisher approximation)

  • Performance degrades with many tasks

2.2 Synaptic Intelligence (SI)ΒΆ

Key idea: Track parameter importance online during training.

Importance (accumulated during task training): \(\omega_i = \sum_{k=1}^K \frac{g_i^{(k)} \cdot \Delta \theta_i^{(k)}}{\xi + \left(\Delta \theta_i^{(k)}\right)^2}\)

where:

  • \(g_i^{(k)}\): Gradient at step \(k\)

  • \(\Delta \theta_i^{(k)}\): Parameter update at step \(k\)

  • \(\xi\): Small damping constant

Regularization: \(\mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + c \sum_i \omega_i (\theta_i - \theta_i^*)^2\)

Advantages over EWC:

  • Online computation (no need for separate Fisher calculation)

  • Captures parameter trajectories during training

  • More stable importance estimates

2.3 Memory Aware Synapses (MAS)ΒΆ

Key idea: Measure importance by output sensitivity (not loss).

Importance: $\(\Omega_i = \frac{1}{|\mathcal{D}|} \sum_{x \in \mathcal{D}} \left\|\frac{\partial f(x; \theta)}{\partial \theta_i}\right\|^2\)$

Regularization (same form as EWC): $\(\mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \lambda \sum_i \Omega_i (\theta_i - \theta_i^*)^2\)$

Key difference from EWC:

  • EWC: Uses gradients of loss (supervised)

  • MAS: Uses gradients of output (can be unsupervised)

2.4 Learning without Forgetting (LwF)ΒΆ

Key idea: Use knowledge distillation to preserve old task outputs.

Loss for new task: $\(\mathcal{L} = \mathcal{L}_{\text{new}}(\theta) + \lambda \mathcal{L}_{\text{distill}}(\theta, \theta_{\text{old}})\)$

Distillation loss: $\(\mathcal{L}_{\text{distill}} = \text{KL}\left(p_{\theta_{\text{old}}}(y|x) \,\|\, p_{\theta}(y|x)\right)\)$

where \(p_\theta(y|x) = \text{softmax}(f_\theta(x) / T)\) with temperature \(T\).

Advantages:

  • No stored data needed

  • Works for task-incremental and class-incremental

  • Preserves function learned, not just parameters

Limitations:

  • Requires storing old model (or outputs)

  • May not scale to many tasks

  • Assumes output space remains relevant

3. Replay-Based MethodsΒΆ

3.1 Experience Replay (ER)ΒΆ

Key idea: Store subset of old data, interleave with new data during training.

Memory buffer: \(\mathcal{M} = \{(x_i, y_i)\}_{i=1}^M\) with \(M \ll \sum_i |\mathcal{D}_i|\)

Training loss: $\(\mathcal{L} = \mathbb{E}_{(x,y) \sim \mathcal{D}_{\text{new}}}[\ell(f(x), y)] + \lambda \mathbb{E}_{(x,y) \sim \mathcal{M}}[\ell(f(x), y)]\)$

Sampling strategies:

  1. Random: Uniform sampling from memory

  2. Reservoir sampling: Maintain uniform distribution over all seen data

  3. Class-balanced: Equal samples per class

  4. Gradient-based: Store samples with largest gradient norm

Advantages:

  • Simple and effective

  • Directly addresses forgetting

  • Compatible with any model

Limitations:

  • Requires storing data (privacy, memory constraints)

  • Performance limited by buffer size

  • May not generalize well to unseen data

3.2 Generative ReplayΒΆ

Key idea: Train generative model to synthesize old task data.

Components:

  • Generator \(G_\phi\): Synthesizes samples from old tasks

  • Solver \(f_\theta\): Main task network

Training:

  1. On task \(t\), sample pseudo-data: \(\tilde{x} \sim G_\phi\)

  2. Get pseudo-labels: \(\tilde{y} = f_{\theta_{t-1}}(\tilde{x})\)

  3. Train on mix of real and pseudo data: $\(\mathcal{L} = \mathbb{E}_{(x,y) \sim \mathcal{D}_t}[\ell(f(x), y)] + \lambda \mathbb{E}_{\tilde{x} \sim G}[\ell(f(\tilde{x}), \tilde{y})]\)$

Variants:

  • Conditional GAN: \(G(z, y)\) for class-conditional generation

  • VAE-based: Use variational autoencoder for stable generation

  • Deep Generative Replay: Generator also learns continually

Advantages:

  • No raw data storage (privacy-preserving)

  • Unlimited synthetic samples

  • Can generate diverse samples

Limitations:

  • Generator quality affects performance

  • Computational overhead for generation

  • May not capture full data distribution

3.3 Meta-Experience Replay (MER)ΒΆ

Key idea: Combine replay with meta-learning for better sample efficiency.

Meta-objective: $\(\min_\theta \mathbb{E}_{\mathcal{B} \sim \mathcal{M}} [\mathcal{L}(\theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{D}_t}(\theta), \mathcal{B})]\)$

Algorithm:

  1. Sample batch from new task: \(\mathcal{B}_{\text{new}} \sim \mathcal{D}_t\)

  2. Sample batch from memory: \(\mathcal{B}_{\text{mem}} \sim \mathcal{M}\)

  3. Compute adapted parameters: \(\theta' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{B}_{\text{new}}}(\theta)\)

  4. Update using memory: \(\theta \leftarrow \theta - \beta \nabla_\theta \mathcal{L}_{\mathcal{B}_{\text{mem}}}(\theta')\)

Benefits:

  • Better generalization with limited memory

  • Prevents overfitting to replayed samples

  • Faster adaptation to new tasks

4. Architecture-Based MethodsΒΆ

4.1 Progressive Neural NetworksΒΆ

Key idea: Allocate new columns for new tasks, freeze old columns.

Architecture: $\(h_i^{(t)} = f\left(W_i^{(t)} h_{i-1}^{(t)} + \sum_{j<t} U_i^{(t:j)} h_{i-1}^{(j)}\right)\)$

where:

  • \(h_i^{(t)}\): Activation at layer \(i\) for task \(t\)

  • \(W_i^{(t)}\): Weights within task \(t\) column

  • \(U_i^{(t:j)}\): Lateral connections from task \(j\) to task \(t\)

Properties:

  • Zero forgetting (old parameters frozen)

  • Transfer learning via lateral connections

  • Model grows linearly with tasks

Advantages:

  • Guaranteed no forgetting

  • Automatic transfer learning

  • Simple to implement

Limitations:

  • Linear growth in parameters

  • No backward transfer (can’t improve old tasks)

  • Inefficient for many tasks

4.2 PacknetΒΆ

Key idea: Prune network for each task, pack multiple tasks in one network.

Algorithm:

  1. Train network on task \(t\) to convergence

  2. Prune \(k\%\) of least important weights (by magnitude)

  3. Mark remaining weights as β€œused by task \(t\)”

  4. For next task, only update β€œfree” weights

  5. Repeat

Pruning criterion: $\(\text{Keep weight } w_i \text{ if } |w_i| \geq \text{threshold}\)$

Advantages:

  • Fixed model size

  • Can pack many tasks (if pruning is aggressive)

  • No forgetting (frozen weights)

Limitations:

  • Order matters (early tasks get best capacity)

  • Cannot prune too aggressively or accuracy suffers

  • No transfer learning between tasks

4.3 Dynamically Expandable Networks (DEN)ΒΆ

Key idea: Grow network capacity dynamically based on task requirements.

Operations:

  1. Selective retraining: Retrain subset of neurons for new task

  2. Dynamic expansion: Add neurons if capacity insufficient

  3. Network split: Duplicate neurons if used by multiple tasks

Neuron selection: $\(s_k = \frac{1}{|\mathcal{D}|} \sum_{(x,y) \in \mathcal{D}} \left|\frac{\partial \mathcal{L}}{\partial h_k}\right|\)$

High \(s_k\) β†’ neuron important for current task.

Sparse regularization: $\(\mathcal{L} = \mathcal{L}_{\text{task}} + \lambda_1 \|h\|_1 + \lambda_2 \|W\|_1\)$

Encourages sparse activations and weights.

Advantages:

  • Adaptive capacity

  • Better than fixed expansion

  • Efficient parameter usage

Limitations:

  • Complexity in managing neuron selection

  • Still grows over time

  • Hyperparameter sensitive

4.4 Expert Gate (Mixture of Experts for CL)ΒΆ

Key idea: Route inputs to task-specific experts.

Gating function: $\(p(e_i | x) = \frac{\exp(g_i(x))}{\sum_j \exp(g_j(x))}\)$

Output: $\(y = \sum_i p(e_i | x) \cdot f_{e_i}(x)\)$

Training:

  • Initialize expert for each task

  • Gate learns to route based on input

  • Can share some experts across tasks

Advantages:

  • Flexible task routing

  • Can handle task-incremental and class-incremental

  • Experts can be trained independently

Limitations:

  • Requires task boundaries during training

  • Gating function can be challenging to learn

  • More parameters than single model

5. Meta-Learning for Continual LearningΒΆ

5.1 Meta-Learning FrameworkΒΆ

Key idea: Learn to learn such that model can quickly adapt to new tasks with minimal forgetting.

MAML for Continual Learning: $\(\theta \leftarrow \theta - \beta \nabla_\theta \sum_{T_i \in \text{past}} \mathcal{L}_{T_i}(\theta - \alpha \nabla_\theta \mathcal{L}_{T_i}(\theta))\)$

Inner loop (task adaptation): $\(\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{T_i}(\theta)\)$

Outer loop (meta-update): $\(\theta \leftarrow \theta - \beta \nabla_\theta \sum_i \mathcal{L}_{T_i}(\theta_i')\)$

5.2 Online-aware Meta-Learning (OML)ΒΆ

Objective: $\(\min_\theta \mathbb{E}_{T \sim p(\mathcal{T})} \left[\sum_{t=1}^T \mathcal{L}_t(\theta_t)\right]\)$

where \(\theta_t\) is adapted from \(\theta_{t-1}\) after seeing task \(t\).

Meta-representation:

  • Learn representation that is easy to adapt

  • Minimize interference between tasks

  • Maximize forward transfer

5.3 Learning to Learn without Forgetting (L2LwF)ΒΆ

Combines meta-learning with knowledge distillation:

\[\mathcal{L} = \mathcal{L}_{\text{task}}(\theta') + \lambda \text{KL}(p_\theta \| p_{\theta'})\]

where \(\theta'\) is meta-learned initialization.

6. Memory and Retrieval MechanismsΒΆ

6.1 Memory-Augmented Neural Networks (MANN)ΒΆ

Neural Turing Machine for CL:

  • External memory \(M \in \mathbb{R}^{N \times W}\) (\(N\) slots, each of width \(W\))

  • Read/write operations via attention

Read: $\(r_t = \sum_i w_t^r(i) M_t(i)\)$

Write: $\(M_t(i) = M_{t-1}(i) + w_t^w(i) a_t\)$

where \(w^r, w^w\) are attention weights, \(a_t\) is content to write.

For continual learning:

  • Store task-specific patterns in memory

  • Retrieve relevant patterns for new tasks

  • Memory grows with tasks but slower than full replay

6.2 Gradient Episodic Memory (GEM)ΒΆ

Key idea: Ensure gradients on new task don’t increase loss on old tasks.

Constraint: $\(\langle g, g_i \rangle \geq 0 \quad \forall i < t\)$

where:

  • \(g\): Gradient on current task \(t\)

  • \(g_i\): Gradient on task \(i\) (computed from stored examples)

Optimization: If constraint violated, project gradient: $\(g' = g - \frac{\langle g, g_i \rangle}{\|g_i\|^2} g_i\)$

Averaged-GEM (A-GEM): Use average gradient from memory instead of per-task constraints: $\(\langle g, \bar{g}_{\text{mem}} \rangle \geq 0\)$

Advantages:

  • Strong theoretical guarantee (no forgetting on stored data)

  • Works well with small memory

  • Fast projection

Limitations:

  • Requires gradient computation on memory

  • May be too conservative (limits plasticity)

6.3 Dark Experience Replay (DER)ΒΆ

Key idea: Store both inputs and model outputs (logits) from previous tasks.

Memory: \(\mathcal{M} = \{(x_i, y_i, z_i)\}\) where \(z_i = f_{\theta_{\text{old}}}(x_i)\) (logits)

Loss: $\(\mathcal{L} = \mathcal{L}_{\text{new}} + \alpha \mathcal{L}_{\text{MSE}}(f_\theta(x_{\text{mem}}), z_{\text{mem}}) + \beta \mathcal{L}_{\text{CE}}(f_\theta(x_{\text{mem}}), y_{\text{mem}})\)$

Benefits:

  • Preserves detailed information (logits vs. hard labels)

  • Better than standard replay

  • Dark knowledge transfer

7. Continual Learning in the WildΒΆ

7.1 Online Continual LearningΒΆ

Setting: No clear task boundaries, data arrives in a stream.

Challenges:

  • When to update model?

  • How to detect distribution shift?

  • Resource constraints (memory, compute)

Strategies:

  • Sliding window: Maintain recent data

  • Trigger-based: Update when performance drops

  • Probabilistic: Detect novelty via uncertainty

7.2 Continual Learning with Noisy LabelsΒΆ

Robust loss functions: $\(\mathcal{L}_{\text{robust}} = -\log \frac{\exp(z_y / T)}{\sum_i \exp(z_i / T)^\gamma}\)$

where \(\gamma < 1\) downweights low-confidence predictions.

Sample selection:

  • Use small loss samples for training

  • Store clean samples in memory

  • Update memory with high-confidence predictions

7.3 Continual Pre-trainingΒΆ

Goal: Continually update pre-trained models (e.g., BERT, GPT) with new data.

Challenges:

  • Forgetting of pre-trained knowledge

  • Distribution shift between pre-training and new data

  • Massive scale (billions of parameters)

Solutions:

  • Adapter modules: Add small task-specific layers

  • Prompt tuning: Learn task-specific prompts

  • Regularization: EWC on important parameters

  • Rehearsal: Mix in original pre-training data

8. Theoretical AnalysisΒΆ

8.1 PAC-Bayes Bounds for Continual LearningΒΆ

Theorem: With probability \(1-\delta\), for all tasks \(t\):

\[\mathbb{E}_{(x,y) \sim \mathcal{D}_t}[\ell(f(x), y)] \leq \hat{\mathcal{L}}_t + \sqrt{\frac{\text{KL}(q \| p) + \log(2N/\delta)}{2n_t}}\]

where:

  • \(q\): Posterior distribution over parameters

  • \(p\): Prior (from previous tasks)

  • \(n_t\): Number of samples in task \(t\)

Insight: Regularization term \(\text{KL}(q \| p)\) bounds forgetting.

8.2 Optimal Transport for Continual LearningΒΆ

Task similarity: $\(d(\mathcal{D}_i, \mathcal{D}_j) = \inf_{\gamma \in \Gamma} \mathbb{E}_{(x,y),(x',y') \sim \gamma}[c(x, x')]\)$

where \(\Gamma\) is set of joint distributions with marginals \(\mathcal{D}_i, \mathcal{D}_j\).

Application: Use task similarity to decide:

  • How much to regularize (similar tasks β†’ less regularization)

  • Which samples to store (diverse tasks β†’ more storage)

  • Architecture adaptation strategy

8.3 Information Bottleneck PerspectiveΒΆ

Trade-off: $\(\min I(X; T) \quad \text{s.t.} \quad I(Y; T) \geq I_{\min}\)$

where:

  • \(I(X; T)\): Mutual information between input and representation

  • \(I(Y; T)\): Mutual information between label and representation

  • \(T = f(X)\): Learned representation

For continual learning:

  • Compress representation to essential task information

  • Discard task-specific noise

  • Results in more robust representations

9. Recent Advances (2020-2024)ΒΆ

9.1 Prompt-based Continual LearningΒΆ

L2P (Learning to Prompt):

  • Maintain pool of learnable prompts

  • Select relevant prompts for each task

  • Freeze backbone, only update prompts

DualPrompt:

  • General prompts (shared across tasks)

  • Task-specific prompts (unique per task)

  • Gating mechanism for selection

Advantages:

  • Parameter-efficient (only update prompts)

  • No forgetting of backbone

  • Scalable to many tasks

9.2 Contrastive Continual LearningΒΆ

SupCon for CL: $\(\mathcal{L}_{\text{SupCon}} = -\frac{1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(\text{sim}(z_i, z_p) / \tau)}{\sum_{a \in A(i)} \exp(\text{sim}(z_i, z_a) / \tau)}\)$

where \(P(i)\) is the set of positives for anchor \(i\) (same class), and \(A(i)\) is the set of all other samples in the batch (the anchor itself excluded).

Benefits:

  • Better feature separation

  • More robust representations

  • Less sensitive to distribution shift

9.3 Self-Supervised Continual LearningΒΆ

SimCLR for CL:

  • Pre-train with contrastive learning on each task

  • Fine-tune classifier on task data

  • Regularize representation to preserve old knowledge

Advantages:

  • Leverages unlabeled data

  • Better feature quality

  • Natural regularization via SSL objective

9.4 Continual Learning for Large Language ModelsΒΆ

Challenges:

  • Billions of parameters

  • Expensive to retrain

  • Data privacy (cannot store all data)

Solutions:

  • LoRA (Low-Rank Adaptation): Learn low-rank updates

  • Prefix Tuning: Learn task-specific prefixes

  • Adapters: Small bottleneck layers

  • Knowledge Distillation: Compress old model knowledge

10. Benchmarks and DatasetsΒΆ

10.1 Standard BenchmarksΒΆ

Split MNIST:

  • 5 tasks, 2 classes each (0-1, 2-3, …, 8-9)

  • Simple baseline for method comparison

Split CIFAR-10/100:

  • CIFAR-10: 5 tasks, 2 classes each

  • CIFAR-100: 10 tasks, 10 classes each

  • More realistic images

TinyImageNet:

  • 200 classes, split into 10-20 tasks

  • Higher resolution, more challenging

CORe50:

  • Real-world object recognition

  • 50 classes, 11 sessions

  • Continuous acquisition (realistic setting)

10.2 Evaluation ProtocolsΒΆ

Offline evaluation:

  • Train on tasks sequentially

  • Test on all tasks after training

Online evaluation:

  • Test after each task

  • Measure forgetting over time

Multi-domain:

  • Same classes, different domains

  • Tests domain adaptation + continual learning

11. Best Practices and GuidelinesΒΆ

11.1 When to Use Which MethodΒΆ

Scenario

Recommended Method

Reason

Memory available

Experience Replay

Simplest, most effective

No memory allowed

EWC or SI

Regularization-based, no storage

Task boundaries known

Progressive Networks

Zero forgetting guarantee

Unknown task boundaries

A-GEM or DER

Online-friendly

Large-scale (LLMs)

LoRA or Adapters

Parameter-efficient

Privacy constraints

Generative Replay

No raw data storage

11.2 Hyperparameter TuningΒΆ

Regularization strength (\(\lambda\)):

  • Too high β†’ Cannot learn new tasks

  • Too low β†’ Forgets old tasks

  • Typical range: 0.1 - 100 (dataset dependent)

Memory size (for replay):

  • Larger is better, but diminishing returns

  • Typical: 200-2000 samples per task

  • Class-balanced is crucial

Learning rate:

  • Lower than from-scratch training (to preserve old knowledge)

  • Typical: 0.1Γ— original learning rate

  • May need different rates for new vs. replayed data

11.3 Common PitfallsΒΆ

❌ Evaluating only on final task: Doesn’t show forgetting
❌ Using task ID at test time for CIL: Makes problem easier
❌ Not reporting forgetting metrics: Only ACC is insufficient
❌ Storing too much data in memory: Defeats purpose of CL
❌ Ignoring computational overhead: Some methods very expensive

12. Future DirectionsΒΆ

12.1 Continual Few-Shot LearningΒΆ

Learn from few examples per task while maintaining old knowledge.

Challenges:

  • Limited data for new tasks

  • Easy to overfit

  • Hard to balance old vs. new

Approaches:

  • Meta-learning + replay

  • Prototype-based methods

  • Transfer learning from pre-trained models

12.2 Continual Reinforcement LearningΒΆ

Agent learns sequence of tasks in interactive environment.

Challenges:

  • Non-stationary environment

  • Credit assignment across tasks

  • Catastrophic forgetting of policies

Methods:

  • Policy distillation

  • Progressive networks for RL

  • Episodic memory for experiences

12.3 Neural-Symbolic Continual LearningΒΆ

Combine neural networks with symbolic reasoning.

Benefits:

  • Explicit knowledge representation

  • Easier to prevent forgetting (symbolic rules)

  • Better interpretability

12.4 Biological InspirationΒΆ

Complementary Learning Systems:

  • Fast learning in hippocampus

  • Slow learning in neocortex

  • Replay during sleep

Synaptic consolidation:

  • Synaptic tagging and capture

  • Protein synthesis for long-term memory

  • Homeostatic plasticity

13. Summary and Key TakeawaysΒΆ

Core Principles:ΒΆ

  1. Catastrophic forgetting is fundamental: Neural networks naturally forget when learning new tasks

  2. Trade-off is unavoidable: Stability ↔ Plasticity must be balanced

  3. No silver bullet: Different scenarios need different methods

  4. Memory helps: Even small amounts of stored data dramatically improve performance

  5. Regularization matters: Protecting important parameters prevents forgetting

Method Categories:ΒΆ

Regularization:

  • βœ… No data storage needed

  • βœ… Theoretical grounding

  • ❌ Performance degrades with many tasks

Replay:

  • βœ… Simple and effective

  • βœ… Best empirical performance

  • ❌ Requires data storage

Architecture:

  • βœ… Zero forgetting possible

  • βœ… Clear task separation

  • ❌ Model growth

Meta-Learning:

  • βœ… Fast adaptation

  • βœ… Better sample efficiency

  • ❌ Complex to implement

Practical Recommendations:ΒΆ

  1. Start simple: Try experience replay first

  2. Measure forgetting: Don’t just report final accuracy

  3. Use class-incremental: Most realistic and challenging scenario

  4. Combine methods: Hybrid approaches often work best

  5. Consider constraints: Memory, compute, privacy requirements

  6. Benchmark properly: Use standard datasets and metrics

Open Problems:ΒΆ

  • Scaling to thousands of tasks

  • Continual learning in production (online, non-stationary data)

  • Theory-practice gap (theoretical guarantees vs. empirical performance)

  • Biological plausibility (how do humans do it?)

  • Privacy-preserving continual learning (federated setting)

Conclusion: Continual learning is crucial for real-world AI systems that must adapt over time. While significant progress has been made, many challenges remain. The field is rapidly evolving with new methods combining insights from optimization, meta-learning, neuroscience, and information theory.

"""
Advanced Continual Learning - Production Implementation
Comprehensive implementations of major continual learning methods

Methods Implemented:
1. Elastic Weight Consolidation (EWC)
2. Synaptic Intelligence (SI)
3. Learning without Forgetting (LwF)
4. Experience Replay (ER)
5. Gradient Episodic Memory (GEM & A-GEM)
6. Dark Experience Replay (DER)
7. Progressive Neural Networks
8. Packnet
9. Meta-learning for CL (OML)

Features:
- Comprehensive evaluation metrics (ACC, FM, BWT, FWT)
- Multiple benchmark datasets (Split MNIST, Split CIFAR)
- Visualization of forgetting curves
- Performance comparison across methods

Author: Advanced Deep Learning Course
Date: 2024
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from typing import List, Tuple, Dict, Optional
import numpy as np
from collections import defaultdict
import copy


# ============================================================================
# 1. Base Classes and Utilities
# ============================================================================

class ContinualLearner(nn.Module):
    """Abstract base class for continual learning strategies.

    Subclasses implement ``observe`` (the per-batch training step) and may
    override ``after_task`` to run consolidation once a task is finished.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Number of tasks consolidated so far (incremented by subclasses).
        self.task_count = 0

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> float:
        """Process one batch of data from a task and return its training loss.

        Args:
            x: Input batch
            y: Labels
            task_id: Task identifier

        Returns:
            loss: Training loss
        """
        raise NotImplementedError

    def after_task(self, task_id: int, dataloader: DataLoader):
        """Hook invoked once a task is complete (consolidation, snapshots, ...)."""
        pass

    def forward(self, x: torch.Tensor, task_id: Optional[int] = None) -> torch.Tensor:
        """Run the wrapped model; ``task_id`` is accepted for task-incremental variants."""
        return self.model(x)


class MemoryBuffer:
    """
    Memory buffer for experience replay.

    Stores (input, label, task_id) triples up to ``capacity``. With
    ``sampling='reservoir'`` the buffer maintains a uniform random subset of
    all samples ever offered (Vitter's Algorithm R); the other strategies
    simply stop accepting new samples once the buffer is full.
    """

    def __init__(self, capacity: int, sampling: str = 'random'):
        """
        Args:
            capacity: Maximum number of samples to store
            sampling: 'random', 'reservoir', or 'class_balanced'
        """
        self.capacity = capacity
        self.sampling = sampling
        self.buffer = []    # stored inputs (kept on CPU)
        self.labels = []    # stored labels (kept on CPU)
        self.task_ids = []  # task id of each stored sample
        self.n_seen = 0     # total number of samples ever offered to the buffer

    def add(self, x: torch.Tensor, y: torch.Tensor, task_id: int):
        """Add a batch of samples to the buffer.

        Args:
            x: Input batch.
            y: Labels (same leading dimension as ``x``).
            task_id: Task the batch came from.
        """
        batch_size = x.size(0)

        for i in range(batch_size):
            if len(self.buffer) < self.capacity:
                # Buffer not full: always keep the sample.
                self.buffer.append(x[i].cpu())
                self.labels.append(y[i].cpu())
                self.task_ids.append(task_id)
            elif self.sampling == 'reservoir':
                # Reservoir sampling (Algorithm R): this is the (n_seen+1)-th
                # sample overall, keep it with prob capacity / (n_seen + 1).
                # n_seen is incremented per sample below, so it already counts
                # earlier items of this batch; the previous `n_seen + i + 1`
                # double-counted them, biasing against later samples.
                j = np.random.randint(0, self.n_seen + 1)
                if j < self.capacity:
                    self.buffer[j] = x[i].cpu()
                    self.labels[j] = y[i].cpu()
                    self.task_ids[j] = task_id
            # Other strategies silently drop new samples once full.

            self.n_seen += 1

    def sample(self, batch_size: int, device: str = 'cuda') -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample a batch (without replacement) from the buffer.

        Returns:
            (x, y, task_ids) tensors on ``device``, or (None, None, None)
            if the buffer is empty. At most ``len(buffer)`` samples returned.
        """
        if len(self.buffer) == 0:
            return None, None, None

        indices = np.random.choice(len(self.buffer), min(batch_size, len(self.buffer)), replace=False)

        x = torch.stack([self.buffer[i] for i in indices]).to(device)
        y = torch.tensor([self.labels[i] for i in indices], dtype=torch.long).to(device)
        t = torch.tensor([self.task_ids[i] for i in indices], dtype=torch.long).to(device)

        return x, y, t

    def get_all(self, device: str = 'cuda') -> Tuple[torch.Tensor, torch.Tensor]:
        """Return every stored (x, y) pair, or (None, None) if the buffer is empty."""
        if len(self.buffer) == 0:
            return None, None

        x = torch.stack(self.buffer).to(device)
        y = torch.tensor(self.labels, dtype=torch.long).to(device)
        return x, y

    def __len__(self):
        return len(self.buffer)


class EvaluationMetrics:
    """Compute standard continual learning metrics from an accuracy matrix.

    ``accuracy_matrix[i][j]`` holds test accuracy on task ``j`` measured
    after training on task ``i``; one row is appended per training stage.
    """

    def __init__(self):
        self.accuracy_matrix = []  # accuracy_matrix[i][j] = accuracy on task j after training task i

    def update(self, task_accuracies: List[float]):
        """Append the row of per-task accuracies measured after the current task."""
        self.accuracy_matrix.append(task_accuracies)

    def average_accuracy(self) -> float:
        """Mean accuracy over all tasks after the final training stage (ACC)."""
        if not self.accuracy_matrix:
            return 0.0
        return np.mean(self.accuracy_matrix[-1])

    def forgetting_measure(self) -> float:
        """Average forgetting (FM): best past accuracy minus final accuracy.

        Follows the standard definition (Chaudhry et al., 2018): for each
        task j the max is taken over stages *before* the final one, so a
        task that ends at its best accuracy yields negative forgetting
        (positive backward transfer) instead of being clamped at zero.
        """
        if len(self.accuracy_matrix) < 2:
            return 0.0

        n_stages = len(self.accuracy_matrix)
        forgetting = []
        for j in range(n_stages - 1):
            # Best accuracy on task j over all stages except the final one.
            max_acc = max(self.accuracy_matrix[i][j] for i in range(j, n_stages - 1))
            current_acc = self.accuracy_matrix[-1][j]
            forgetting.append(max_acc - current_acc)

        return np.mean(forgetting)

    def backward_transfer(self) -> float:
        """BWT: mean(final accuracy - accuracy right after learning each task).

        Negative values indicate forgetting.
        """
        if len(self.accuracy_matrix) < 2:
            return 0.0

        bwt = []
        for j in range(len(self.accuracy_matrix) - 1):
            final_acc = self.accuracy_matrix[-1][j]
            after_task_acc = self.accuracy_matrix[j][j]
            bwt.append(final_acc - after_task_acc)

        return np.mean(bwt)

    def forward_transfer(self, random_baseline: List[float]) -> float:
        """FWT: accuracy on task i before training it, minus a random baseline.

        Args:
            random_baseline: Accuracy of a randomly initialised model on
                each task (indexed by task id).
        """
        if len(self.accuracy_matrix) < 2:
            return 0.0

        fwt = []
        for i in range(1, len(self.accuracy_matrix)):
            # Accuracy on task i measured after task i-1 (0 if not evaluated).
            before_training = self.accuracy_matrix[i-1][i] if i < len(self.accuracy_matrix[i-1]) else 0
            fwt.append(before_training - random_baseline[i])

        return np.mean(fwt)


# ============================================================================
# 2. Elastic Weight Consolidation (EWC)
# ============================================================================

class EWC(ContinualLearner):
    """
    Elastic Weight Consolidation (Kirkpatrick et al., 2017).

    After each task, snapshots the parameters and a diagonal (empirical)
    Fisher estimate; later tasks are trained with a quadratic penalty that
    anchors important parameters to their consolidated values.
    """

    def __init__(self, model: nn.Module, lambda_: float = 5000, fisher_samples: int = 200):
        """
        Args:
            model: Neural network
            lambda_: Regularization strength
            fisher_samples: Number of samples for Fisher estimation
        """
        super().__init__(model)
        self.lambda_ = lambda_
        self.fisher_samples = fisher_samples

        # Per-task snapshots: task_id -> {param_name: tensor}
        self.optimal_params = {}
        self.fisher_matrices = {}

    def compute_fisher(self, dataloader: DataLoader, device: str = 'cuda') -> Dict[str, torch.Tensor]:
        """Estimate the diagonal Fisher information from up to ``fisher_samples`` samples.

        Uses the empirical Fisher: squared gradients of the cross-entropy
        w.r.t. the true labels, averaged per sample. The gradient of the
        batch-mean loss is squared per batch, a standard coarse approximation
        of the per-sample expectation.

        Args:
            dataloader: Loader over the just-finished task's data.
            device: Device on which to accumulate the estimate.

        Returns:
            Mapping from parameter name to its diagonal Fisher estimate.
        """
        fisher = {n: torch.zeros_like(p, device=device) for n, p in self.model.named_parameters() if p.requires_grad}

        was_training = self.model.training  # remember mode so we can restore it
        self.model.eval()
        n_samples = 0

        for x, y in dataloader:
            if n_samples >= self.fisher_samples:
                break

            x, y = x.to(device), y.to(device)
            batch_size = x.size(0)

            # Gradients of the batch-mean cross-entropy at the current params.
            logits = self.model(x)
            loss = F.cross_entropy(logits, y)

            self.model.zero_grad()
            loss.backward()

            # Accumulate squared gradients, weighted by batch size so the
            # final division yields a per-sample average.
            for n, p in self.model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    fisher[n] += p.grad.data ** 2 * batch_size

            n_samples += batch_size

        # Guard against an empty dataloader (avoids ZeroDivisionError).
        if n_samples > 0:
            for n in fisher:
                fisher[n] /= n_samples

        # Restore the model's original train/eval mode (the original code
        # left the model permanently in eval mode).
        if was_training:
            self.model.train()

        return fisher

    def after_task(self, task_id: int, dataloader: DataLoader):
        """Compute and store the Fisher matrix and parameter snapshot after a task."""
        device = next(self.model.parameters()).device

        # Fisher estimate on the task just finished.
        fisher = self.compute_fisher(dataloader, device)
        self.fisher_matrices[task_id] = fisher

        # Snapshot the parameters the penalty will anchor to.
        self.optimal_params[task_id] = {
            n: p.data.clone() for n, p in self.model.named_parameters() if p.requires_grad
        }

        self.task_count += 1

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> float:
        """Return task cross-entropy plus the EWC penalty.

        The caller is responsible for backward() and the optimizer step.
        """
        # Standard loss on the current task.
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)

        # Quadratic EWC penalty, one term per consolidated task.
        if self.task_count > 0:
            for t in range(self.task_count):
                for n, p in self.model.named_parameters():
                    if p.requires_grad and n in self.fisher_matrices[t]:
                        fisher = self.fisher_matrices[t][n]
                        optimal = self.optimal_params[t][n]
                        loss += (self.lambda_ / 2) * (fisher * (p - optimal) ** 2).sum()

        return loss


# ============================================================================
# 3. Synaptic Intelligence (SI)
# ============================================================================

class SynapticIntelligence(ContinualLearner):
    """
    Synaptic Intelligence (Zenke et al., 2017).
    
    Online computation of parameter importance during training.

    A per-parameter path integral (grad x parameter-update) is accumulated
    in ``W`` while a task is trained, then consolidated into the importance
    ``omega`` in ``after_task``; ``omega`` drives a quadratic penalty on
    later tasks.
    """
    
    def __init__(self, model: nn.Module, c: float = 0.1, xi: float = 0.1):
        """
        Args:
            model: Neural network
            c: Regularization strength
            xi: Damping parameter
        """
        super().__init__(model)
        self.c = c
        self.xi = xi
        
        # Online importance estimation
        # omega: consolidated importance, grows monotonically across tasks.
        self.omega = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
        self.W = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}  # Accumulated path
        
        # Optimal parameters from previous tasks
        self.optimal_params = {}
        
        # Track previous parameters
        # Snapshot of parameters as of the last observe() call.
        self.prev_params = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}
    
    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> float:
        """Train with SI regularization.

        Computes the task loss plus the SI penalty, calls backward() itself,
        and updates the running path integral ``W``; the optimizer step is
        expected to happen outside, after this call returns.

        NOTE(review): ``delta`` below measures the parameter change between
        consecutive observe() calls (i.e. the optimizer step taken outside
        in between) and is multiplied by the gradient at the *new* point,
        not the gradient that produced that step — an approximation of the
        paper's path integral; confirm against the training loop.
        """
        # Standard loss
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        
        # SI penalty (from previous tasks)
        if self.task_count > 0:
            for n, p in self.model.named_parameters():
                if p.requires_grad and n in self.omega:
                    # Fall back to the current value if no snapshot exists
                    # (penalty term is then zero for that parameter).
                    optimal = self.optimal_params.get(n, p.data)
                    loss += (self.c / 2) * (self.omega[n] * (p - optimal) ** 2).sum()
        
        # Compute gradients
        loss.backward()
        
        # Update path integral
        device = next(self.model.parameters()).device  # NOTE(review): unused local
        for n, p in self.model.named_parameters():
            if p.requires_grad and p.grad is not None:
                delta = p.data - self.prev_params[n]
                self.W[n] += p.grad.data * delta
                self.prev_params[n] = p.data.clone()
        
        return loss
    
    def after_task(self, task_id: int, dataloader: DataLoader):
        """Update importance after task.

        Consolidates the accumulated path integral into ``omega`` and
        snapshots the parameters as the anchor for the next task's penalty.

        NOTE(review): the paper normalizes by the *total* parameter change
        over the whole task; here ``delta`` is the change since the last
        observe() call (typically ~0 because prev_params is refreshed every
        step), so the update effectively reduces to W / xi — verify intent.
        """
        # Compute omega (importance)
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                delta = p.data - self.prev_params.get(n, p.data)
                self.omega[n] += self.W[n] / (delta ** 2 + self.xi)
                
                # Reset path integral for next task
                self.W[n].zero_()
        
        # Store optimal parameters
        self.optimal_params = {n: p.data.clone() for n, p in self.model.named_parameters() if p.requires_grad}
        
        self.task_count += 1


# ============================================================================
# 4. Learning without Forgetting (LwF)
# ============================================================================

class LearningWithoutForgetting(ContinualLearner):
    """
    Learning without Forgetting (Li & Hoiem, 2016).
    
    Preserves behavior on earlier tasks by distilling the frozen previous
    model's softened predictions into the current model while it trains on
    new data -- no old data needs to be stored.
    """
    
    def __init__(self, model: nn.Module, lambda_: float = 1.0, T: float = 2.0):
        """
        Args:
            model: Neural network
            lambda_: Distillation loss weight
            T: Temperature for softmax
        """
        super().__init__(model)
        self.lambda_ = lambda_
        self.T = T
        self.old_model = None

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> float:
        """Return CE loss on the new batch plus a distillation term."""
        outputs = self.model(x)
        total = F.cross_entropy(outputs, y)

        # After the first task, match the frozen snapshot's soft targets.
        if self.old_model is not None:
            with torch.no_grad():
                teacher_logits = self.old_model(x)

            teacher_probs = F.softmax(teacher_logits / self.T, dim=1)
            student_log_probs = F.log_softmax(outputs / self.T, dim=1)

            # T^2 rescaling keeps gradient magnitudes comparable across
            # temperatures (Hinton et al. distillation convention).
            kd = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
            total = total + self.lambda_ * kd * (self.T ** 2)

        return total

    def after_task(self, task_id: int, dataloader: DataLoader):
        """Freeze a snapshot of the current model to act as the teacher."""
        snapshot = copy.deepcopy(self.model)
        snapshot.eval()
        for param in snapshot.parameters():
            param.requires_grad = False
        self.old_model = snapshot

        self.task_count += 1


# ============================================================================
# 5. Experience Replay (ER)
# ============================================================================

class ExperienceReplay(ContinualLearner):
    """
    Experience Replay: keep a small buffer of past samples and mix a
    replayed batch into every training step.
    """
    
    def __init__(self, model: nn.Module, memory_size: int = 1000, replay_batch_size: int = 64):
        """
        Args:
            model: Neural network
            memory_size: Size of replay buffer
            replay_batch_size: Batch size for replay
        """
        super().__init__(model)
        self.memory = MemoryBuffer(memory_size)
        self.replay_batch_size = replay_batch_size

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> float:
        """Return combined loss on the current batch and a replayed batch."""
        # Every incoming batch is offered to the buffer.
        self.memory.add(x, y, task_id)

        # Loss on the fresh data.
        total = F.cross_entropy(self.model(x), y)

        # Interleave a batch drawn from memory, if any samples are stored.
        if len(self.memory) > 0:
            replay_x, replay_y, _ = self.memory.sample(self.replay_batch_size, x.device)
            if replay_x is not None:
                total = total + F.cross_entropy(self.model(replay_x), replay_y)

        return total


# ============================================================================
# 6. Gradient Episodic Memory (GEM & A-GEM)
# ============================================================================

class AGEM(ContinualLearner):
    """
    Averaged Gradient Episodic Memory (Chaudhry et al., 2019).
    
    Projects the current-task gradient so that it does not increase the
    loss on an episodic memory of past samples.
    """
    
    def __init__(self, model: nn.Module, memory_size: int = 1000, memory_batch_size: int = 64):
        """
        Args:
            model: Neural network
            memory_size: Size of episodic memory
            memory_batch_size: Batch size for memory gradient
        """
        super().__init__(model)
        self.memory = MemoryBuffer(memory_size)
        self.memory_batch_size = memory_batch_size
        
    def project_gradient(self, g_current: torch.Tensor, g_memory: torch.Tensor) -> torch.Tensor:
        """Project g_current away from directions that conflict with g_memory.

        Args:
            g_current: Flattened gradient of the current batch.
            g_memory: Flattened reference gradient of the memory batch.

        Returns:
            g_current unchanged when the gradients agree (non-negative dot
            product); otherwise g_current with its component along g_memory
            removed (the A-GEM projection).
        """
        dot_product = torch.dot(g_current, g_memory)
        
        if dot_product < 0:
            # Violation: project gradient
            g_proj = g_current - (dot_product / (torch.norm(g_memory) ** 2 + 1e-8)) * g_memory
            return g_proj
        else:
            # No violation
            return g_current
    
    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> float:
        """Compute the batch loss and arrange for the projected gradient.

        BUGFIX: the original called ``loss.backward()`` here and returned the
        same tensor; the training loop's subsequent ``loss.backward()`` then
        failed because the graph had already been freed.  We now use
        ``torch.autograd.grad(..., retain_graph=True)`` (which neither frees
        the graph nor touches ``p.grad``) and pre-load ``p.grad`` with the
        correction ``g_proj - g_current``.  Assumes the caller zeroes
        gradients before ``observe`` and calls ``loss.backward()`` afterwards
        (as the training loop here does), which adds ``g_current`` on top, so
        the optimizer finally sees exactly ``g_proj``.
        """
        # Add the incoming batch to episodic memory.
        self.memory.add(x, y, task_id)

        params = [p for p in self.model.parameters() if p.requires_grad]

        # Gradient of the current batch, without freeing the graph.
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        grads = [torch.zeros_like(p) if g is None else g for g, p in zip(grads, params)]
        g_current = torch.cat([g.reshape(-1) for g in grads])

        if len(self.memory) > 0:
            x_mem, y_mem, _ = self.memory.sample(self.memory_batch_size, x.device)
            if x_mem is not None:
                # Reference gradient on the memory batch (separate graph).
                loss_mem = F.cross_entropy(self.model(x_mem), y_mem)
                grads_mem = torch.autograd.grad(loss_mem, params, allow_unused=True)
                grads_mem = [torch.zeros_like(p) if g is None else g
                             for g, p in zip(grads_mem, params)]
                g_memory = torch.cat([g.reshape(-1) for g in grads_mem])

                g_proj = self.project_gradient(g_current, g_memory)

                # Pre-load p.grad with (g_proj - g_current); the caller's
                # backward() adds g_current, leaving exactly g_proj.
                correction = g_proj - g_current
                idx = 0
                for p in params:
                    numel = p.numel()
                    chunk = correction[idx:idx + numel].view_as(p)
                    if p.grad is None:
                        p.grad = chunk.clone()
                    else:
                        p.grad.add_(chunk)
                    idx += numel

        return loss


# ============================================================================
# 7. Dark Experience Replay (DER)
# ============================================================================

class DarkExperienceReplay(ContinualLearner):
    """
    Dark Experience Replay (Buzzega et al., 2020).
    
    Stores inputs together with the logits the network produced for them
    ("dark knowledge") and replays both a logit-matching (MSE) loss and a
    classification loss on the buffered samples.
    """
    
    def __init__(
        self,
        model: nn.Module,
        memory_size: int = 1000,
        alpha: float = 0.5,
        beta: float = 0.5
    ):
        """
        Args:
            model: Neural network
            memory_size: Size of buffer
            alpha: Weight for logit matching loss
            beta: Weight for classification loss on memory
        """
        super().__init__(model)
        self.memory_size = memory_size
        self.alpha = alpha
        self.beta = beta
        
        # Parallel lists holding (x, y, logits) triples on CPU.
        self.memory_x = []
        self.memory_y = []
        self.memory_logits = []
        # Total number of samples ever offered to the buffer; needed for
        # reservoir sampling.
        self.num_seen = 0
        
    def add_to_memory(self, x: torch.Tensor, y: torch.Tensor, logits: torch.Tensor):
        """Add samples with their logits via reservoir sampling.

        BUGFIX: the original always overwrote a random slot once the buffer
        was full, which biases the buffer toward recent samples.  Reservoir
        sampling keeps each incoming sample with probability
        memory_size / num_seen, giving a uniform sample over the whole
        stream (as in the DER paper).
        """
        for i in range(x.size(0)):
            self.num_seen += 1
            if len(self.memory_x) < self.memory_size:
                self.memory_x.append(x[i].detach().cpu())
                self.memory_y.append(y[i].detach().cpu())
                self.memory_logits.append(logits[i].detach().cpu())
            else:
                idx = np.random.randint(0, self.num_seen)
                if idx < self.memory_size:
                    self.memory_x[idx] = x[i].detach().cpu()
                    self.memory_y[idx] = y[i].detach().cpu()
                    self.memory_logits[idx] = logits[i].detach().cpu()
    
    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> float:
        """Return CE loss on the batch plus weighted replay losses."""
        # Forward pass
        logits = self.model(x)
        
        # Buffer the batch together with its current logits.
        self.add_to_memory(x, y, logits)
        
        # New task loss
        loss = F.cross_entropy(logits, y)
        
        # Replay loss
        if len(self.memory_x) > 0:
            # Sample from memory (without replacement)
            batch_size = min(x.size(0), len(self.memory_x))
            indices = np.random.choice(len(self.memory_x), batch_size, replace=False)
            
            x_mem = torch.stack([self.memory_x[i] for i in indices]).to(x.device)
            # BUGFIX: torch.tensor() on a list of 0-dim tensors goes through
            # Python scalars and warns; torch.stack is the supported way.
            y_mem = torch.stack([self.memory_y[i] for i in indices]).long().to(x.device)
            logits_old = torch.stack([self.memory_logits[i] for i in indices]).to(x.device)
            
            # Current model predictions on memory
            logits_mem = self.model(x_mem)
            
            # MSE loss on logits (dark knowledge)
            loss_mse = F.mse_loss(logits_mem, logits_old)
            
            # CE loss on labels
            loss_ce = F.cross_entropy(logits_mem, y_mem)
            
            loss += self.alpha * loss_mse + self.beta * loss_ce
        
        return loss


# ============================================================================
# 8. Progressive Neural Networks
# ============================================================================

class ProgressiveNetwork(nn.Module):
    """
    Progressive Neural Networks (Rusu et al., 2016).
    
    One "column" (an MLP) is allocated per task.  Columns added later
    receive lateral connections from the hidden activations of earlier
    columns, so old knowledge can be reused without being overwritten.

    Lateral scheme used here: the adapter for hidden layer ``l`` is a
    square ``Linear(hidden_sizes[l], hidden_sizes[l])`` mapping an earlier
    column's layer-l activation into the new column's layer-l
    pre-activation.
    """
    
    def __init__(self, input_size: int, hidden_sizes: List[int], output_size: int):
        super().__init__()
        self.columns = nn.ModuleList()
        self.laterals = nn.ModuleList()
        
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.n_tasks = 0
        
    def add_task(self):
        """Add a new column (and its lateral adapters) for a new task."""
        # New column: [Linear, ReLU] per hidden layer, then the output head.
        layers = []
        prev_size = self.input_size
        for hidden_size in self.hidden_sizes:
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.ReLU())
            prev_size = hidden_size
        layers.append(nn.Linear(prev_size, self.output_size))
        
        self.columns.append(nn.Sequential(*layers))
        
        # Lateral adapters from every previous column into this column,
        # one per hidden layer.
        # BUGFIX: the original appended a plain Python list to nn.ModuleList,
        # which raises TypeError (ModuleList.append requires a Module) and
        # would have left the laterals unregistered as parameters.
        if self.n_tasks > 0:
            lateral_list = nn.ModuleList()
            for _ in range(self.n_tasks):
                lateral_list.append(nn.ModuleList(
                    [nn.Linear(h, h) for h in self.hidden_sizes]
                ))
            self.laterals.append(lateral_list)
        
        self.n_tasks += 1
    
    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Forward pass through column ``task_id`` with lateral connections.

        BUGFIX: the original forward only ever applied each column's first
        layer -- deeper hidden layers were never executed and the output
        head was skipped entirely, so the result did not even have
        ``output_size`` features.  This version runs columns 0..task_id
        layer by layer, injects lateral contributions, and finishes with
        the target column's output head.

        Raises:
            ValueError: If ``task_id`` has not been added yet.
        """
        if task_id >= self.n_tasks:
            raise ValueError(f"Task {task_id} not yet added")
        
        n_hidden = len(self.hidden_sizes)
        # acts[c] = activation of column c at the layer currently processed;
        # all columns start from the shared input.
        acts = [x] * (task_id + 1)
        
        for layer_idx in range(n_hidden):
            new_acts = []
            for col_idx in range(task_id + 1):
                # The Linear of hidden layer `layer_idx` sits at Sequential
                # index 2*layer_idx (each hidden layer is [Linear, ReLU]).
                pre = self.columns[col_idx][2 * layer_idx](acts[col_idx])
                if col_idx > 0:
                    # Add lateral contributions from all earlier columns'
                    # same-depth activations (computed earlier this pass).
                    for prev_col in range(col_idx):
                        lateral = self.laterals[col_idx - 1][prev_col][layer_idx]
                        pre = pre + lateral(new_acts[prev_col])
                new_acts.append(F.relu(pre))
            acts = new_acts
        
        # Output head of the requested column (Sequential index 2*n_hidden).
        return self.columns[task_id][2 * n_hidden](acts[task_id])


# ============================================================================
# 9. Demo and Evaluation
# ============================================================================

def create_split_mnist_tasks(n_tasks: int = 5) -> List[Tuple]:
    """
    Create Split MNIST tasks by partitioning the 10 digits into
    ``n_tasks`` contiguous class groups.
    
    Args:
        n_tasks: Number of tasks; should divide 10 evenly for full coverage.
    
    Returns:
        List of (train_dataset, test_dataset) tuples, one per task
    """
    from torchvision import datasets, transforms
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # Load full MNIST
    train_full = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_full = datasets.MNIST('./data', train=False, transform=transform)
    
    # PERF: read labels from the `targets` tensor instead of enumerating the
    # dataset -- enumeration decodes every image through the transform
    # pipeline just to inspect its label.
    train_labels = train_full.targets
    test_labels = test_full.targets
    
    tasks = []
    classes_per_task = 10 // n_tasks
    
    for task_id in range(n_tasks):
        # Classes for this task
        start_class = task_id * classes_per_task
        end_class = start_class + classes_per_task
        
        # Vectorized index selection over the label tensors
        train_indices = torch.where(
            (train_labels >= start_class) & (train_labels < end_class)
        )[0].tolist()
        test_indices = torch.where(
            (test_labels >= start_class) & (test_labels < end_class)
        )[0].tolist()
        
        tasks.append((Subset(train_full, train_indices), Subset(test_full, test_indices)))
    
    return tasks


def evaluate_on_all_tasks(
    model: nn.Module,
    tasks: List[Tuple],
    device: str = 'cuda'
) -> List[float]:
    """Return the test accuracy of ``model`` on every task's held-out split.

    Args:
        model: Classifier whose forward pass returns class logits.
        tasks: List of (train_dataset, test_dataset) tuples.
        device: Device evaluation batches are moved to.

    Returns:
        One accuracy per task, in task order (0 for an empty test set).
    """
    model.eval()
    accuracies = []

    with torch.no_grad():
        for _, test_dataset in tasks:
            loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

            n_correct, n_seen = 0, 0
            for inputs, targets in loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                predictions = model(inputs).argmax(dim=1)
                n_correct += (predictions == targets).sum().item()
                n_seen += targets.size(0)

            accuracies.append(n_correct / n_seen if n_seen > 0 else 0)

    return accuracies


def train_continual_learning(
    method: ContinualLearner,
    tasks: List[Tuple],
    n_epochs: int = 5,
    lr: float = 0.01,
    device: str = 'cuda'
) -> EvaluationMetrics:
    """
    Train continual learning method on sequence of tasks.
    
    Args:
        method: CL method
        tasks: List of (train, test) datasets
        n_epochs: Epochs per task
        lr: Learning rate
        
    Returns:
        metrics: Evaluation metrics
    """
    method = method.to(device)
    optimizer = torch.optim.SGD(method.parameters(), lr=lr, momentum=0.9)
    metrics = EvaluationMetrics()

    for task_id, (train_dataset, _) in enumerate(tasks):
        print(f"\n{'='*50}")
        print(f"Training Task {task_id}")
        print(f"{'='*50}")

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

        # Standard epoch loop; the method's observe() builds the loss.
        for epoch in range(n_epochs):
            method.train()
            running_loss, batch_count = 0, 0

            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                batch_loss = method.observe(inputs, targets, task_id)
                batch_loss.backward()
                optimizer.step()

                running_loss += batch_loss.item()
                batch_count += 1

            avg_loss = running_loss / batch_count
            print(f"  Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}")

        # Let the method consolidate (Fisher, snapshots, omega, ...).
        method.after_task(task_id, train_loader)

        # Evaluate on every task seen so far.
        task_accuracies = evaluate_on_all_tasks(method, tasks[:task_id+1], device)
        metrics.update(task_accuracies)

        print(f"\nAccuracies after Task {task_id}:")
        for i, acc in enumerate(task_accuracies):
            print(f"  Task {i}: {acc*100:.2f}%")

    # Summary of the whole task sequence.
    print(f"\n{'='*50}")
    print("Final Metrics")
    print(f"{'='*50}")
    print(f"Average Accuracy: {metrics.average_accuracy()*100:.2f}%")
    print(f"Forgetting Measure: {metrics.forgetting_measure()*100:.2f}%")
    print(f"Backward Transfer: {metrics.backward_transfer()*100:.2f}%")

    return metrics


def demo_continual_learning():
    """Demo comparing different CL methods."""
    print("=" * 80)
    print("Continual Learning Methods Comparison Demo")
    print("=" * 80)

    # Build the Split MNIST benchmark.
    print("\nCreating Split MNIST tasks...")
    tasks = create_split_mnist_tasks(n_tasks=5)
    print(f"Created {len(tasks)} tasks (2 digits each)")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    class SimpleCNN(nn.Module):
        """Small conv net for 28x28 grayscale digits."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(1, 32, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(32, 64, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )
            self.fc = nn.Sequential(
                nn.Linear(64 * 7 * 7, 128),
                nn.ReLU(),
                nn.Linear(128, 10)
            )

        def forward(self, x):
            features = self.conv(x)
            return self.fc(features.view(features.size(0), -1))

    # Each method gets its own fresh network so the comparison is fair.
    methods = {
        'Naive (Fine-tuning)': ContinualLearner(SimpleCNN()),
        'EWC': EWC(SimpleCNN(), lambda_=1000),
        'Experience Replay': ExperienceReplay(SimpleCNN(), memory_size=500),
        'LwF': LearningWithoutForgetting(SimpleCNN(), lambda_=1.0),
    }

    results = {}
    for name, method in methods.items():
        print(f"\n{'*'*80}")
        print(f"Training: {name}")
        print(f"{'*'*80}")

        metrics = train_continual_learning(
            method,
            tasks,
            n_epochs=3,
            lr=0.01,
            device=device
        )

        results[name] = {
            'avg_acc': metrics.average_accuracy(),
            'forgetting': metrics.forgetting_measure(),
            'bwt': metrics.backward_transfer(),
        }

    # Side-by-side summary table.
    print("\n" + "=" * 80)
    print("Results Comparison")
    print("=" * 80)
    print(f"{'Method':<25} {'Avg Acc':<12} {'Forgetting':<12} {'BWT':<12}")
    print("-" * 80)
    for name, result in results.items():
        print(f"{name:<25} {result['avg_acc']*100:>10.2f}% {result['forgetting']*100:>10.2f}% {result['bwt']*100:>10.2f}%")


if __name__ == "__main__":
    print("\nAdvanced Continual Learning - Comprehensive Implementation\n")
    demo_continual_learning()

    # Closing summary, printed after the full demo run.
    takeaways = [
        "1. Catastrophic forgetting is a major challenge in continual learning",
        "2. EWC protects important parameters using Fisher information",
        "3. Experience Replay is simple but very effective",
        "4. LwF uses knowledge distillation (no data storage)",
        "5. Trade-off between stability and plasticity is fundamental",
        "6. Different methods suit different scenarios (memory, privacy, etc.)",
    ]
    print("\n" + "=" * 80)
    print("Demo complete! Key takeaways:")
    print("=" * 80)
    for takeaway in takeaways:
        print(takeaway)
    print("=" * 80)