import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from copy import deepcopy
# Select the GPU when available; every model and batch below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Catastrophic Forgetting
Problem
Neural networks forget previous tasks when learning new ones.
EWC Solution
where \(F_i\) is the Fisher information: \(F_i = \mathbb{E}\left[\left(\frac{\partial \log p(y|x; \theta)}{\partial \theta_i}\right)^2\right]\).
Reference Materials:
bayesian_inference_deep_learning.pdf - Bayesian Inference Deep Learning
class SimpleNet(nn.Module):
    """Three-layer fully connected classifier for flattened 28x28 MNIST digits."""

    def __init__(self):
        super().__init__()
        # 784 -> 256 -> 128 -> 10
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10); input is flattened first."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)
# Load MNIST
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_mnist = datasets.MNIST('./data', train=False, download=True, transform=transform)
# Create task splits: Task A (0-4), Task B (5-9)
# Split on the raw label tensor: enumerating the dataset would decode and
# transform every one of the 60k images just to read its label.
indices_A = (mnist.targets < 5).nonzero(as_tuple=True)[0].tolist()
indices_B = (mnist.targets >= 5).nonzero(as_tuple=True)[0].tolist()
dataset_A = torch.utils.data.Subset(mnist, indices_A)
dataset_B = torch.utils.data.Subset(mnist, indices_B)
loader_A = torch.utils.data.DataLoader(dataset_A, batch_size=128, shuffle=True)
loader_B = torch.utils.data.DataLoader(dataset_B, batch_size=128, shuffle=True)
print(f"Task A: {len(dataset_A)}, Task B: {len(dataset_B)}")
Compute Fisher Information
Elastic Weight Consolidation (EWC) prevents catastrophic forgetting by penalizing changes to parameters that were important for previous tasks. The importance of each parameter is estimated by the diagonal of the Fisher information matrix, which approximates the curvature of the loss landscape around the current parameters: \(F_i = \mathbb{E}\left[\left(\frac{\partial \log p(y|x; \theta)}{\partial \theta_i}\right)^2\right]\). Parameters with high Fisher information are critical for the current task and should be changed minimally when learning new tasks. Computing the Fisher matrix requires a forward-backward pass over a representative sample of the current task's data.
def compute_fisher(model, data_loader, n_samples=1000):
    """Estimate the diagonal of the Fisher information matrix.

    Accumulates squared gradients of the log-likelihood of labels sampled from
    the model's own predictive distribution (the "true" Fisher, rather than
    the empirical Fisher that would use the dataset labels).

    Args:
        model: classifier whose named parameters are scored.
        data_loader: yields (x, y) batches; the dataset labels y are ignored.
        n_samples: stop after roughly this many examples have been seen.

    Returns:
        dict mapping parameter name -> tensor of per-element Fisher estimates.

    Note: gradients come from the batch-mean NLL, so this is a batch-level
    approximation of the per-sample squared-gradient average.
    """
    model.eval()
    # Run on the model's own device instead of relying on the global `device`.
    dev = next(model.parameters()).device
    fisher = {name: torch.zeros_like(param) for name, param in model.named_parameters()}
    count = 0
    for x, _ in data_loader:
        if count >= n_samples:
            break
        x = x.to(dev)
        # Forward
        output = model(x)
        log_probs = F.log_softmax(output, dim=1)
        # Sample labels from the predictive distribution. squeeze(1) (not a
        # bare squeeze()) keeps the target 1-D even when the batch size is 1.
        sampled_y = torch.multinomial(torch.exp(log_probs), 1).squeeze(1)
        # Compute gradients
        loss = F.nll_loss(log_probs, sampled_y)
        model.zero_grad()
        loss.backward()
        # Accumulate squared gradients, weighted by batch size
        for name, param in model.named_parameters():
            if param.grad is not None:
                fisher[name] += param.grad.pow(2) * len(x)
        count += len(x)
    # Normalize to a per-sample average
    for name in fisher:
        fisher[name] /= count
    return fisher
print("Fisher computation function defined")
EWC Training
When learning a new task, EWC adds a quadratic penalty to the loss for each parameter, weighted by its Fisher information from the previous task: \(\mathcal{L}_{\text{EWC}} = \mathcal{L}_{\text{new task}} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta^*_i)^2\), where \(\theta^*\) are the parameters after training on the previous task. The hyperparameter \(\lambda\) controls the strength of the consolidation: too low and the model forgets; too high and it cannot learn the new task. This penalty acts as an anisotropic regularizer, allowing free movement in directions that do not affect the previous task while constraining important directions.
def ewc_loss(model, fisher, old_params, lambda_ewc=1000):
    """Quadratic EWC penalty: (lambda/2) * sum_i F_i * (theta_i - theta_i*)^2.

    Parameters absent from `fisher` contribute nothing to the penalty.
    """
    penalty = sum(
        (fisher[name] * (param - old_params[name]).pow(2)).sum()
        for name, param in model.named_parameters()
        if name in fisher
    )
    return lambda_ewc / 2 * penalty
def train_with_ewc(model, loader, fisher=None, old_params=None, n_epochs=5, lambda_ewc=1000):
    """Train `model` on `loader` with Adam, optionally adding an EWC penalty.

    Args:
        model: network to optimize (updated in place).
        loader: (x, y) batches for the current task.
        fisher: per-parameter Fisher estimates from the previous task, or None
            for plain training with no consolidation.
        old_params: parameter snapshot from the previous task (required when
            `fisher` is given).
        n_epochs: number of passes over the loader.
        lambda_ewc: strength of the consolidation penalty.
    """
    # Train on the model's own device instead of relying on the global `device`.
    dev = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            output = model(x)
            loss_task = F.cross_entropy(output, y)
            # Add EWC penalty when consolidating a previous task
            if fisher is not None:
                loss_ewc = ewc_loss(model, fisher, old_params, lambda_ewc)
                loss = loss_task + loss_ewc
            else:
                loss = loss_task
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(loader):.4f}")
print("EWC training function defined")
Baseline: Sequential Training
The baseline for continual learning is naive sequential training: train on Task 1 until convergence, then train on Task 2 from the same model without any forgetting prevention. This baseline exhibits catastrophic forgetting — performance on Task 1 drops precipitously as the model's parameters are overwritten to accommodate Task 2. Measuring the accuracy on all previous tasks after each new task quantifies the severity of forgetting and provides the benchmark against which EWC and other continual learning methods are evaluated.
def _task_accuracy(model, test_mnist, keep_label):
    """Accuracy (%) of `model` on the subset of `test_mnist` whose labels satisfy `keep_label`."""
    dev = next(model.parameters()).device
    indices = [i for i, (_, label) in enumerate(test_mnist) if keep_label(label)]
    subset = torch.utils.data.Subset(test_mnist, indices)
    loader = torch.utils.data.DataLoader(subset, batch_size=1000)
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return 100 * correct / total

def evaluate_tasks(model, test_mnist):
    """Evaluate on both tasks.

    Returns:
        (accuracy_task_A, accuracy_task_B) in percent, where Task A is
        digits 0-4 and Task B is digits 5-9.
    """
    model.eval()
    # Task A: 0-4, Task B: 5-9 — same evaluation loop, different label split.
    acc_A = _task_accuracy(model, test_mnist, lambda label: label < 5)
    acc_B = _task_accuracy(model, test_mnist, lambda label: label >= 5)
    return acc_A, acc_B
# Baseline: Sequential without EWC
# Train on Task A, then Task B, with no forgetting prevention (fisher=None),
# and measure how much Task A accuracy drops after learning Task B.
model_baseline = SimpleNet().to(device)
print("Training Task A...")
train_with_ewc(model_baseline, loader_A, n_epochs=5)
# Accuracy on both tasks right after Task A (Task B is near chance here).
acc_A_after_A, acc_B_after_A = evaluate_tasks(model_baseline, test_mnist)
print(f"After Task A - A: {acc_A_after_A:.2f}%, B: {acc_B_after_A:.2f}%")
print("\nTraining Task B...")
train_with_ewc(model_baseline, loader_B, n_epochs=5)
acc_A_after_B, acc_B_after_B = evaluate_tasks(model_baseline, test_mnist)
print(f"After Task B - A: {acc_A_after_B:.2f}%, B: {acc_B_after_B:.2f}%")
# Forgetting = drop in Task A accuracy caused by training on Task B.
print(f"Forgetting on A: {acc_A_after_A - acc_A_after_B:.2f}%")
EWC Training
With EWC enabled, we train on the sequence of tasks while applying the Fisher-weighted penalty. After each task, we compute the new Fisher information matrix and store the current parameters as the reference point. For a sequence of tasks, the penalty accumulates contributions from all previous tasks, ensuring that parameters important for any past task are protected. Monitoring accuracy on all tasks simultaneously during training reveals how effectively EWC preserves previous knowledge while accommodating new learning.
model_ewc = SimpleNet().to(device)
print("Training Task A...")
# Task A is learned normally; consolidation only applies from Task B onward.
train_with_ewc(model_ewc, loader_A, n_epochs=5)
acc_A_ewc, _ = evaluate_tasks(model_ewc, test_mnist)
print(f"After Task A - A: {acc_A_ewc:.2f}%")
# Compute Fisher and save parameters
# Snapshot parameter importances (Fisher) and values before moving to Task B;
# these anchor the quadratic EWC penalty during Task B training.
print("\nComputing Fisher information...")
fisher = compute_fisher(model_ewc, loader_A, n_samples=2000)
old_params = {name: param.clone() for name, param in model_ewc.named_parameters()}
print("Training Task B with EWC...")
train_with_ewc(model_ewc, loader_B, fisher, old_params, n_epochs=5, lambda_ewc=5000)
acc_A_ewc_final, acc_B_ewc_final = evaluate_tasks(model_ewc, test_mnist)
print(f"After Task B - A: {acc_A_ewc_final:.2f}%, B: {acc_B_ewc_final:.2f}%")
print(f"Forgetting on A: {acc_A_ewc - acc_A_ewc_final:.2f}%")
Compare Results
Plotting the accuracy on each task over the course of sequential learning produces a forgetting matrix: the diagonal shows peak performance on each task, and off-diagonal entries show how much that performance degrades after subsequent tasks. EWC should show dramatically less off-diagonal degradation compared to naive sequential training, with only modest reduction in new-task learning speed. Quantitative metrics include average accuracy (across all tasks at the end) and backward transfer (average change in performance on old tasks after learning new ones).
# Comparison
# Grouped bar chart: per-task accuracy after the full A->B sequence,
# baseline vs EWC. EWC should retain much more Task A accuracy.
methods = ['Baseline', 'EWC']
task_A = [acc_A_after_B, acc_A_ewc_final]
task_B = [acc_B_after_B, acc_B_ewc_final]
x = np.arange(len(methods))
width = 0.35
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(x - width/2, task_A, width, label='Task A (0-4)', alpha=0.8)
ax.bar(x + width/2, task_B, width, label='Task B (5-9)', alpha=0.8)
ax.set_ylabel('Accuracy (%)', fontsize=11)
ax.set_title('Continual Learning Comparison', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(methods)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
# Forgetting = accuracy drop on Task A caused by training on Task B.
print(f"\nBaseline - Forgetting: {acc_A_after_A - acc_A_after_B:.2f}%")
print(f"EWC - Forgetting: {acc_A_ewc - acc_A_ewc_final:.2f}%")
Summary
Continual Learning:
Problem: Catastrophic forgetting when learning sequentially
EWC Solution:
Compute Fisher information after Task A
Penalize changes to important parameters
Balance new task learning with retention
Key Components:
Fisher matrix: Parameter importance
Quadratic penalty: Soft constraint
Lambda: Retention vs plasticity tradeoff
Other Approaches:
Rehearsal: Store subset of old data
Progressive networks: Add new columns
PackNet: Prune and freeze
LwF: Knowledge distillation
Applications:
Lifelong learning systems
Robotics (new environments)
Personalization (user adaptation)
Edge devices (limited memory)
Advanced Continual Learning TheoryΒΆ
1. Catastrophic Forgetting: The Stability-Plasticity DilemmaΒΆ
The ProblemΒΆ
Neural networks exhibit a fundamental tension between stability (retaining old knowledge) and plasticity (learning new information). When trained sequentially on different tasks, standard gradient descent catastrophically forgets previous tasks.
Mathematical formulation:
Given tasks \(\mathcal{T}_1, \mathcal{T}_2, \ldots, \mathcal{T}_T\) arriving sequentially, a network optimized for task \(\mathcal{T}_t\) will have high loss on earlier tasks:
even though \(\theta_1\) was optimal for \(\mathcal{T}_1\).
Why It HappensΒΆ
Weight overlap: Different tasks often require similar features in lower layers, but different mappings in higher layers. Updating weights for task 2 overwrites representations learned for task 1.
Gradient interference: The gradient for task \(\mathcal{T}_t\) may point in the opposite direction from the gradient for task \(\mathcal{T}_{t-1}\):
Distributed representations: Neural networks use distributed representations where each neuron contributes to multiple tasks. Changing one neuron affects all tasks.
Measuring Forgetting
Forgetting metric (Chaudhry et al., 2018): \(f_k^t = \max_{l \in \{k, \ldots, t-1\}} \left(a_{k,l} - a_{k,t}\right)\),
where \(a_{k,k}\) is accuracy on task \(k\) immediately after training on it, and \(a_{j,t}\) is accuracy on task \(j\) after training on task \(t\).
Average forgetting: \(F^t = \frac{1}{t-1} \sum_{k=1}^{t-1} f_k^t\).
2. Elastic Weight Consolidation (EWC): Detailed DerivationΒΆ
Bayesian PerspectiveΒΆ
EWC views continual learning as sequential Bayesian inference. After learning task \(A\), the posterior becomes the prior for task \(B\):
Taking negative log:
Laplace ApproximationΒΆ
Approximate the posterior \(p(\theta | \mathcal{D}_A)\) as a Gaussian centered at the optimal parameters \(\theta_A^*\):
where \(F\) is the Fisher information matrix:
Fisher Information MatrixΒΆ
The Fisher information quantifies how much the model output changes when a parameter is perturbed.
For classification with softmax:
The diagonal Fisher information for parameter \(\theta_i\) is:
Intuition: If \(F_i\) is large, parameter \(\theta_i\) significantly affects task \(A\)'s predictions, so it should not change much when learning task \(B\).
EWC Loss FunctionΒΆ
Combining the task \(B\) loss with the quadratic penalty:
Hyperparameter \(\lambda\): Controls the trade-off between:
High \(\lambda\): Strong retention of task \(A\) (low forgetting, but may underfit task \(B\))
Low \(\lambda\): Better performance on task \(B\) (but more forgetting of task \(A\))
Practical ComputationΒΆ
Sampling-based Fisher estimation:
Sample inputs \(x\) from task \(A\) data
Sample labels \(\hat{y}\) from the model's predictive distribution \(p(y|x, \theta_A^*)\)
Compute gradients \(\nabla_\theta \log p(\hat{y} | x, \theta_A^*)\)
Average squared gradients:
Diagonal approximation: Computing the full Fisher matrix is intractable for large networks (\(O(|\theta|^2)\) memory). EWC uses only the diagonal, assuming parameter independence.
3. Progressive Neural NetworksΒΆ
ArchitectureΒΆ
Progressive Neural Networks (Rusu et al., 2016) freeze old task columns and add new columns for each task, with lateral connections from old to new.
Structure for task \(t\):
where:
\(h_i^{(t)}\): Hidden layer \(i\) activations for task \(t\)
\(W_i^{(t)}\): Weights within column \(t\)
\(U_i^{(j \to t)}\): Lateral connections from column \(j\) to column \(t\)
Key PropertiesΒΆ
Zero forgetting: Old columns are frozen, so performance on old tasks is exactly preserved.
Knowledge transfer: Lateral connections allow task \(t\) to reuse features from tasks \(1, \ldots, t-1\).
Capacity growth: Model size grows linearly with the number of tasks: \(\theta_{\text{total}} = \sum_{t=1}^T |\theta_t|\).
When to UseΒΆ
Advantages:
No forgetting by construction
Automatic transfer learning via lateral connections
Each task gets dedicated capacity
Disadvantages:
Memory grows unbounded
Inference cost increases with tasks
Not suitable for resource-constrained settings
4. Memory-Based ApproachesΒΆ
Experience ReplayΒΆ
Store a memory buffer \(\mathcal{M}\) containing examples from previous tasks. When learning task \(t\), sample from both \(\mathcal{D}_t\) and \(\mathcal{M}\):
Reservoir sampling: To maintain a fixed-size buffer, use reservoir sampling to ensure uniform representation across tasks.
Advantages:
Simple and effective
Works with any architecture
Can balance old and new data
Disadvantages:
Requires storing raw data (privacy concerns)
Memory limited
May not scale to many tasks
Generative ReplayΒΆ
Train a generative model \(G\) alongside the task network. Instead of storing real data, generate pseudo-examples from old tasks:
Advantage: No need to store raw data.
Disadvantage: Requires training a high-quality generative model (e.g., VAE, GAN).
Pseudo-RehearsalΒΆ
Similar to generative replay, but uses the task network itself to generate labels for synthetic inputs:
Generate random input \(\tilde{x}\)
Use old model \(\theta_{\text{old}}\) to label: \(\tilde{y} = f_{\theta_{\text{old}}}(\tilde{x})\)
Train new model to match: \(\mathcal{L}(\theta) = \ell(f_\theta(\tilde{x}), \tilde{y})\)
5. Regularization-Based MethodsΒΆ
Learning without Forgetting (LwF)ΒΆ
Idea: Constrain the new model to produce similar outputs to the old model on new data.
Knowledge distillation loss:
where \(D_{\text{KL}}\) is the Kullback-Leibler divergence between output distributions.
Soft targets: Use the old modelβs softmax probabilities as soft labels:
with temperature \(T > 1\) to soften the distribution.
PackNetΒΆ
Prune and freeze approach:
Train network on task 1
Prune less important weights (e.g., lowest magnitude)
Freeze remaining important weights
Train remaining capacity on task 2
Repeat
Iterative pruning: After each task, prune a fraction \(p\) of weights. Remaining capacity for task \(t\):
Advantage: Simple, no hyperparameters for retention.
Disadvantage: Capacity decreases exponentially; limited scalability.
Hard Attention Masks (HAT)ΒΆ
Learn binary masks \(m_i^{(t)} \in \{0, 1\}\) for each task \(t\), where \(m_i^{(t)} = 1\) means parameter \(\theta_i\) is used for task \(t\).
Forward pass:
Mask learning: Use Gumbel-Softmax to learn approximately binary masks via gradient descent.
Advantage: Learns task-specific sub-networks; no forgetting.
Disadvantage: Complex training; requires task IDs at test time.
6. Multi-Task Learning vs Continual LearningΒΆ
Aspect |
Multi-Task Learning |
Continual Learning |
|---|---|---|
Data access |
All tasks available simultaneously |
Tasks arrive sequentially |
Goal |
Joint optimization across tasks |
Learn new tasks without forgetting old |
Challenge |
Negative transfer, task balancing |
Catastrophic forgetting |
Memory |
Requires storing all data |
Limited memory for old tasks |
Evaluation |
Average performance across tasks |
Forward transfer + backward transfer (retention) |
7. Theoretical GuaranteesΒΆ
PAC-Bayes Bounds for Continual LearningΒΆ
For a sequence of tasks with data distributions \(\mathcal{D}_1, \ldots, \mathcal{D}_T\), the expected risk on all tasks is bounded by:
where \(q\) is the learned posterior and \(p\) is the prior.
Interpretation: The bound grows with the number of tasks \(T\), highlighting the difficulty of lifelong learning.
Gradient Episodic Memory (GEM)ΒΆ
Constraint: New gradient should not increase loss on old tasks:
where \(g_k\) is the gradient on task \(k\) memory.
Optimization: Solve quadratic program to project gradient into feasible region if constraint is violated.
8. Practical ConsiderationsΒΆ
Choosing a MethodΒΆ
Small number of tasks (< 10):
EWC or LwF (simple, effective)
Progressive networks (if memory allows)
Many tasks (> 100):
Experience replay with reservoir sampling
Hard attention masks (if task IDs available)
Resource-constrained (edge devices):
PackNet (no additional memory)
Small replay buffer
Privacy-sensitive:
Generative replay or LwF (no raw data storage)
Hyperparameter TuningΒΆ
EWC \(\lambda\):
Start with \(\lambda \in [10^3, 10^5]\)
Higher for tasks with high importance
Cross-validate on held-out tasks
Replay buffer size:
Minimum: \(\sim 100\) examples per class
Optimal: \(\sim 1000\) examples per class
Trade-off with storage constraints
Temperature \(T\) (LwF):
Higher \(T\) → softer targets (more knowledge transfer)
Typical: \(T \in [2, 5]\)
Evaluation MetricsΒΆ
Average accuracy:
where \(a_{t,T}\) is accuracy on task \(t\) after learning all \(T\) tasks.
Backward transfer (forgetting):
Positive BWT indicates improvement on old tasks (rare); negative indicates forgetting.
Forward transfer:
where \(b_t\) is random chance accuracy. Positive FWT indicates knowledge transfer from previous tasks.
9. Advanced TopicsΒΆ
Continual Meta-LearningΒΆ
Combine meta-learning with continual learning: learn a meta-learner that quickly adapts to new tasks without forgetting.
OML (Online-aware Meta-learning): Meta-train with online (sequential) task arrivals, not episodic sampling.
La-MAML: Modulate MAML's inner loop learning rate based on parameter importance (similar to EWC).
Task-Free Continual LearningΒΆ
Most methods assume task boundaries are known. Task-free setting: no explicit task IDs.
Challenge: How to detect when a new task begins?
Solutions:
Unsupervised change detection (distribution shift)
Clustering in representation space
Bayesian online changepoint detection
Continual Reinforcement LearningΒΆ
Additional challenges:
Non-stationary reward distributions
Exploration-exploitation in new environments
Policy catastrophic forgetting
Approaches:
EWC for policy networks
Progressive networks for new environments
Experience replay with prioritized sampling
PDF ReferencesΒΆ
Foundational Papers:
Kirkpatrick et al. (2017): "Overcoming catastrophic forgetting in neural networks" (EWC)
Rusu et al. (2016): "Progressive Neural Networks"
Li & Hoiem (2018): "Learning without Forgetting"
Lopez-Paz & Ranzato (2017): "Gradient Episodic Memory for Continual Learning" (GEM)
Chaudhry et al. (2019): "On Tiny Episodic Memories in Continual Learning"
Surveys:
Parisi et al. (2019): "Continual Lifelong Learning with Neural Networks: A Review"
De Lange et al. (2021): "A Continual Learning Survey: Defying Forgetting in Classification Tasks"
π GitHub Resources:
# ============================================================================
# Advanced Continual Learning Implementations
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from collections import defaultdict
import random
# Fix RNG seeds so the experiments below are reproducible.
torch.manual_seed(42)
np.random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# ============================================================================
# 1. Complete EWC Implementation with Fisher Matrix
# ============================================================================
class EWC:
    """
    Elastic Weight Consolidation for continual learning.

    Implements the full EWC algorithm including:
    - Diagonal Fisher information matrix computation
    - Quadratic penalty loss accumulated over all registered tasks
    """

    def __init__(self, model, lambda_ewc=5000, fisher_sample_size=200):
        """
        Args:
            model: the network being trained continually.
            lambda_ewc: penalty strength (retention vs plasticity trade-off).
            fisher_sample_size: approximate number of examples used per
                Fisher estimate.
        """
        self.model = model
        self.lambda_ewc = lambda_ewc
        self.fisher_sample_size = fisher_sample_size
        # Per-task Fisher estimates and parameter snapshots, keyed by task id.
        self.fisher = {}
        self.optim_params = {}
        self.task_count = 0

    def _compute_fisher_diagonal(self, dataloader):
        """
        Compute the diagonal Fisher information matrix.

        F_i = E[(d log p(y|x,theta) / d theta_i)^2], with labels sampled from
        the model's own predictive distribution. Dataset labels are ignored.
        """
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()}
        self.model.eval()
        # Use the model's own device instead of relying on a global `device`.
        dev = next(self.model.parameters()).device
        samples_seen = 0
        for inputs, _ in dataloader:
            if samples_seen >= self.fisher_sample_size:
                break
            inputs = inputs.to(dev)
            self.model.zero_grad()
            outputs = self.model(inputs)
            log_probs = F.log_softmax(outputs, dim=1)
            # squeeze(1) keeps a 1-D target even when the batch has one example
            # (a bare squeeze() would produce a 0-dim tensor and break nll_loss).
            sampled_labels = torch.multinomial(torch.exp(log_probs), 1).squeeze(1)
            loss = F.nll_loss(log_probs, sampled_labels)
            loss.backward()
            # Accumulate squared gradients (diagonal Fisher approximation)
            for n, p in self.model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.pow(2) * inputs.size(0)
            samples_seen += inputs.size(0)
        # Normalize to a per-sample average
        for n in fisher:
            fisher[n] /= samples_seen
        return fisher

    def register_task(self, dataloader):
        """
        Register a new task: compute Fisher and snapshot current parameters.
        """
        print(f"Registering task {self.task_count + 1}...")
        fisher = self._compute_fisher_diagonal(dataloader)
        # Save Fisher and a detached copy of the parameters for this task.
        self.fisher[self.task_count] = fisher
        self.optim_params[self.task_count] = {
            n: p.clone().detach()
            for n, p in self.model.named_parameters()
        }
        self.task_count += 1
        print(f"Task {self.task_count} registered. Fisher computed over {self.fisher_sample_size} samples.")

    def penalty(self):
        """
        Compute the EWC quadratic penalty summed across all registered tasks.

        L_EWC = (lambda/2) * sum_t sum_i F_i^(t) (theta_i - theta_i^(t)*)^2
        """
        loss = 0.0
        for task_id in range(self.task_count):
            for n, p in self.model.named_parameters():
                fisher_diag = self.fisher[task_id][n]
                optimal_param = self.optim_params[task_id][n]
                # Quadratic penalty weighted by Fisher information
                loss += (fisher_diag * (p - optimal_param).pow(2)).sum()
        return (self.lambda_ewc / 2) * loss
# ============================================================================
# 2. Progressive Neural Networks
# ============================================================================
class ProgressiveColumn(nn.Module):
    """
    Single column in a Progressive Neural Network.

    Each column has:
    - Its own feedforward path
    - Lateral connections mixing in the last-hidden-layer activations of
      every previous column
    """

    def __init__(self, input_dim, hidden_dims, output_dim, prev_columns=None):
        """
        Args:
            input_dim: size of the input features.
            hidden_dims: sizes of the hidden layers.
            output_dim: number of output logits.
            prev_columns: previously trained columns. Kept in a plain list on
                purpose so their parameters are NOT re-registered (and thus
                not re-trained) under this module.
        """
        super().__init__()
        self.prev_columns = prev_columns if prev_columns is not None else []
        # Feedforward path: Linear + ReLU per hidden layer
        layers = []
        dims = [input_dim] + hidden_dims
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            layers.append(nn.ReLU())
        self.column = nn.Sequential(*layers)
        self.output = nn.Linear(hidden_dims[-1], output_dim)
        # One lateral adapter per previous column (applied at the last hidden
        # layer only in this implementation).
        if len(self.prev_columns) > 0:
            self.lateral_connections = nn.ModuleList([
                nn.Linear(hidden_dims[-1], hidden_dims[-1])
                for _ in self.prev_columns
            ])
        else:
            self.lateral_connections = None

    def forward(self, x, prev_activations=None):
        """
        Forward pass with lateral connections.

        h^(t) = f(W^(t) x) + sum_{j<t} U^(j->t) h^(j)

        Args:
            x: input batch.
            prev_activations: hidden activations of previous columns, in order.

        Returns:
            (output logits, hidden activations) — the hidden state feeds the
            lateral connections of later columns.
        """
        # Column's own computation
        h = self.column(x)
        # Mix in lateral contributions. Guarding on lateral_connections avoids
        # a crash (zip over None) if activations are passed to a first column.
        if prev_activations and self.lateral_connections is not None:
            for prev_h, lateral in zip(prev_activations, self.lateral_connections):
                h = h + lateral(prev_h)
        output = self.output(h)
        return output, h
class ProgressiveNeuralNetwork(nn.Module):
    """
    Progressive Neural Network: one new column per task.

    Key properties:
    - Zero forgetting (old columns frozen)
    - Lateral connections enable transfer
    - Model capacity grows with tasks
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.columns = nn.ModuleList()

    def add_task(self):
        """Add a new column for a new task."""
        # Freeze every existing column; only the new column will train.
        for frozen in self.columns:
            for weight in frozen.parameters():
                weight.requires_grad = False
        fresh = ProgressiveColumn(
            self.input_dim,
            self.hidden_dims,
            self.output_dim,
            prev_columns=list(self.columns),
        )
        self.columns.append(fresh)
        print(f"Added column {len(self.columns)}. Total parameters: {sum(p.numel() for p in self.parameters())}")

    def forward(self, x, task_id=None):
        """
        Forward pass through the progressive network.

        If task_id is None, use the latest column.
        """
        target = len(self.columns) - 1 if task_id is None else task_id
        # Collect activations from every earlier (frozen) column in order.
        cached = []
        for idx in range(target):
            with torch.no_grad():
                _, hidden = self.columns[idx](x, cached)
            cached.append(hidden)
        # The target column mixes those in through its lateral connections.
        logits, _ = self.columns[target](x, cached)
        return logits
# ============================================================================
# 3. Experience Replay with Reservoir Sampling
# ============================================================================
class ReplayBuffer:
    """
    Experience replay buffer with reservoir sampling.

    At any point the buffer holds a uniform random subset (of size at most
    max_size) of every example ever added, regardless of task order.
    """

    def __init__(self, max_size=1000):
        """
        Args:
            max_size: maximum number of (x, y) examples kept in memory.
        """
        self.max_size = max_size
        self.buffer = []     # list of (x, y) CPU tensor pairs
        self.total_seen = 0  # number of examples ever offered to the buffer

    def add(self, x, y):
        """
        Add a batch of examples using the reservoir sampling algorithm.

        Ensures uniform probability (max_size/n) for each of n seen examples.
        """
        for i in range(len(x)):
            if len(self.buffer) < self.max_size:
                # Buffer not full: add directly
                self.buffer.append((x[i].cpu(), y[i].cpu()))
            else:
                # The current item is the (total_seen)-th (0-based) example;
                # keep it with probability max_size / (total_seen + 1).
                j = random.randint(0, self.total_seen)
                if j < self.max_size:
                    self.buffer[j] = (x[i].cpu(), y[i].cpu())
            self.total_seen += 1

    def sample(self, batch_size):
        """Sample a mini-batch from the buffer; (None, None) when empty."""
        if len(self.buffer) == 0:
            return None, None
        indices = random.sample(range(len(self.buffer)), min(batch_size, len(self.buffer)))
        batch_x = torch.stack([self.buffer[i][0] for i in indices]).to(device)
        # Stack the stored 0-dim label tensors directly; building a Python
        # list and calling torch.tensor on it triggers a copy-construct warning.
        batch_y = torch.stack([self.buffer[i][1] for i in indices]).long().to(device)
        return batch_x, batch_y

    def __len__(self):
        return len(self.buffer)
# ============================================================================
# 4. Learning without Forgetting (LwF)
# ============================================================================
class LwF:
    """
    Learning without Forgetting via knowledge distillation.

    The current model is encouraged to reproduce the frozen old model's
    output distribution on new-task inputs.
    """

    def __init__(self, model, temperature=2.0, alpha=0.5):
        self.model = model
        self.temperature = temperature  # softening factor for the targets
        self.alpha = alpha              # weight of the distillation term
        self.old_model = None

    def register_task(self):
        """Snapshot the current model as the frozen teacher for the next task."""
        teacher = deepcopy(self.model)
        teacher.eval()
        for weight in teacher.parameters():
            weight.requires_grad = False
        self.old_model = teacher
        print("Old model saved for distillation.")

    def distillation_loss(self, outputs, inputs):
        """
        Knowledge-distillation loss between old and new outputs.

        L_distill = KL(softmax(z_old/T) || softmax(z_new/T)) * T^2,
        where T > 1 softens both distributions.
        """
        if self.old_model is None:
            return 0.0
        with torch.no_grad():
            teacher_logits = self.old_model(inputs)
        T = self.temperature
        soft_targets = F.softmax(teacher_logits / T, dim=1)
        soft_predictions = F.log_softmax(outputs / T, dim=1)
        # T^2 compensates for the gradient scaling introduced by softening.
        return F.kl_div(soft_predictions, soft_targets, reduction='batchmean') * (T ** 2)

    def combined_loss(self, outputs, labels, inputs):
        """
        Task loss plus weighted distillation: L = L_task + alpha * L_distill.
        """
        task_loss = F.cross_entropy(outputs, labels)
        if self.old_model is None:
            return task_loss
        return task_loss + self.alpha * self.distillation_loss(outputs, inputs)
# ============================================================================
# 5. Gradient Episodic Memory (GEM)
# ============================================================================
class GEM:
    """
    Gradient Episodic Memory: constrain updates so they do not increase the
    loss on stored memories of previous tasks.

    Conceptually solves the quadratic program
        min ||g - g_t||^2   s.t.  g^T g_k >= 0  for all k < t
    where g_t is the current-task gradient and g_k the gradient on task k's
    memory. This implementation uses a cheap projection heuristic rather than
    a full QP solver.
    """

    def __init__(self, model, memory_per_task=256, margin=0.5):
        """
        Args:
            model: network being trained.
            memory_per_task: number of examples stored per task.
            margin: reserved for margin-based GEM variants (unused by the
                simple projection below).
        """
        self.model = model
        self.memory_per_task = memory_per_task
        self.margin = margin
        self.memory = []  # one list of (input, label) pairs per task

    def store_task_memory(self, dataloader):
        """Store a random subset of examples for the current task."""
        task_memory = []
        for inputs, labels in dataloader:
            task_memory.extend(list(zip(inputs.cpu(), labels.cpu())))
            if len(task_memory) >= self.memory_per_task:
                break
        # Downsample if the last batch overshot the quota
        if len(task_memory) > self.memory_per_task:
            task_memory = random.sample(task_memory, self.memory_per_task)
        self.memory.append(task_memory)
        print(f"Stored {len(task_memory)} examples for task {len(self.memory)}")

    def compute_gradient(self, inputs, labels):
        """Return the model's loss gradient on (inputs, labels) as a flat vector."""
        # Run on the model's own device rather than a global `device`.
        dev = next(self.model.parameters()).device
        self.model.zero_grad()
        outputs = self.model(inputs.to(dev))
        loss = F.cross_entropy(outputs, labels.to(dev))
        loss.backward()
        # Concatenate per-parameter gradients into one flat vector
        grad = torch.cat([p.grad.view(-1) for p in self.model.parameters() if p.grad is not None])
        return grad

    def project_gradient(self, current_grad):
        """
        Project the gradient so it does not conflict with old-task gradients.

        A conflict is a negative inner product g^T g_k. Full GEM solves a QP;
        here we subtract half the mean of the violating gradients as a
        practical approximation.
        """
        if len(self.memory) == 0:
            return current_grad
        # Reference gradients from a small sample of each task's memory
        ref_grads = []
        for task_mem in self.memory:
            inputs = torch.stack([x for x, _ in task_mem[:32]])
            # Stack the stored 0-dim labels directly (avoids a tensor-list copy)
            labels = torch.stack([y for _, y in task_mem[:32]]).long()
            ref_grads.append(self.compute_gradient(inputs, labels))
        # Check which old tasks the current gradient conflicts with
        violations = [(current_grad @ g_k).item() < 0 for g_k in ref_grads]
        if not any(violations):
            return current_grad  # No violations, use original gradient
        violated_grads = [g_k for g_k, v in zip(ref_grads, violations) if v]
        if len(violated_grads) > 0:
            avg_violated = torch.stack(violated_grads).mean(dim=0)
            return current_grad - avg_violated * 0.5
        return current_grad
# ============================================================================
# 6. Comparison Experiment
# ============================================================================
def create_split_mnist():
    """Create a two-task split of MNIST: task 1 = digits 0-4, task 2 = digits 5-9.

    Returns:
        (train_t1, test_t1, train_t2, test_t2): torch Subset views of the
        MNIST train/test sets, one (train, test) pair per task.
    """
    from torchvision import datasets, transforms
    # Import Subset locally so the helper is self-contained: the module-level
    # `from torch.utils.data import ...` in this file lives in a later cell.
    from torch.utils.data import Subset

    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

    def split_indices(dataset):
        # Read labels from MNIST's `.targets` tensor instead of iterating the
        # dataset: `enumerate(dataset)` would decode and transform every image
        # just to look at its label.
        labels = dataset.targets
        first = [i for i, y in enumerate(labels) if int(y) < 5]
        second = [i for i, y in enumerate(labels) if int(y) >= 5]
        return first, second

    # Task 1: digits 0-4; Task 2: digits 5-9.
    train_idx_t1, train_idx_t2 = split_indices(train_dataset)
    test_idx_t1, test_idx_t2 = split_indices(test_dataset)

    return (Subset(train_dataset, train_idx_t1),
            Subset(test_dataset, test_idx_t1),
            Subset(train_dataset, train_idx_t2),
            Subset(test_dataset, test_idx_t2))
def evaluate_model(model, test_loader):
    """Return classification accuracy (in percent) of `model` on `test_loader`."""
    model.eval()
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            preds = model(batch_x).argmax(dim=1)
            n_total += batch_y.size(0)
            n_correct += (preds == batch_y).sum().item()
    return 100.0 * n_correct / n_total
# ----------------------------------------------------------------------------
# Experiment driver: build the two-task split-MNIST benchmark and its loaders.
# NOTE(review): uses bare `DataLoader`, which in this file is only imported in
# a later cell (`from torch.utils.data import ...`) — confirm execution order.
# ----------------------------------------------------------------------------
print("\n" + "="*80)
print("Continual Learning Methods Comparison")
print("="*80)
# Create datasets
train_t1, test_t1, train_t2, test_t2 = create_split_mnist()
# Small train batches with shuffling; large, ordered eval batches.
train_loader_t1 = DataLoader(train_t1, batch_size=128, shuffle=True)
test_loader_t1 = DataLoader(test_t1, batch_size=512, shuffle=False)
train_loader_t2 = DataLoader(train_t2, batch_size=128, shuffle=True)
test_loader_t2 = DataLoader(test_t2, batch_size=512, shuffle=False)
print(f"Task 1 train: {len(train_t1)}, test: {len(test_t1)}")
print(f"Task 2 train: {len(train_t2)}, test: {len(test_t2)}")
# Simple model for experiments
class SimpleNet(nn.Module):
    """Fully connected 784-256-128-10 MNIST classifier; returns raw logits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        # Flatten images to 784-vectors, then two ReLU layers + linear head.
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)
# ============================================================================
# Compare: EWC vs Replay vs LwF
# ============================================================================
results = {}
# 1. Baseline (Sequential, no CL method): train task 1, then task 2, with no
# protection against forgetting — the lower bound for every other method.
print("\n--- Baseline (No CL) ---")
model_baseline = SimpleNet().to(device)
optimizer_baseline = torch.optim.Adam(model_baseline.parameters(), lr=1e-3)
# Train task 1 (digits 0-4): 3 epochs of plain cross-entropy.
for epoch in range(3):
    model_baseline.train()
    for inputs, labels in train_loader_t1:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_baseline.zero_grad()
        loss = F.cross_entropy(model_baseline(inputs), labels)
        loss.backward()
        optimizer_baseline.step()
acc_t1_after_t1 = evaluate_model(model_baseline, test_loader_t1)
print(f"After Task 1 - T1 acc: {acc_t1_after_t1:.2f}%")
# Train task 2 (digits 5-9); re-measuring task 1 afterwards exposes
# catastrophic forgetting.
for epoch in range(3):
    model_baseline.train()
    for inputs, labels in train_loader_t2:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_baseline.zero_grad()
        loss = F.cross_entropy(model_baseline(inputs), labels)
        loss.backward()
        optimizer_baseline.step()
acc_t1_after_t2 = evaluate_model(model_baseline, test_loader_t1)
acc_t2_after_t2 = evaluate_model(model_baseline, test_loader_t2)
print(f"After Task 2 - T1 acc: {acc_t1_after_t2:.2f}%, T2 acc: {acc_t2_after_t2:.2f}%")
print(f"Forgetting: {acc_t1_after_t1 - acc_t1_after_t2:.2f}%")
# Forgetting = accuracy drop on task 1 caused by learning task 2.
results['Baseline'] = {
    't1_after_t1': acc_t1_after_t1,
    't1_after_t2': acc_t1_after_t2,
    't2_after_t2': acc_t2_after_t2,
    'forgetting': acc_t1_after_t1 - acc_t1_after_t2
}
# 2. EWC: same schedule as the baseline, but after task 1 the Fisher
# information is captured and a quadratic penalty anchors important
# parameters while task 2 is trained.
print("\n--- EWC ---")
model_ewc_test = SimpleNet().to(device)
optimizer_ewc = torch.optim.Adam(model_ewc_test.parameters(), lr=1e-3)
# NOTE(review): this constructor call (lambda_ewc, fisher_sample_size) does
# not match the EWC class defined later in this file, which takes
# (model, dataloader, device, lambda_ewc) and has no register_task() —
# presumably an earlier notebook cell defines the EWC variant used here;
# confirm execution order.
ewc = EWC(model_ewc_test, lambda_ewc=5000, fisher_sample_size=500)
# Train task 1 normally (no penalty yet — nothing to protect).
for epoch in range(3):
    model_ewc_test.train()
    for inputs, labels in train_loader_t1:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_ewc.zero_grad()
        loss = F.cross_entropy(model_ewc_test(inputs), labels)
        loss.backward()
        optimizer_ewc.step()
acc_t1_ewc = evaluate_model(model_ewc_test, test_loader_t1)
print(f"After Task 1 - T1 acc: {acc_t1_ewc:.2f}%")
# Register task 1: snapshot parameters and Fisher information.
ewc.register_task(train_loader_t1)
# Train task 2 with the EWC penalty added to the task loss.
for epoch in range(3):
    model_ewc_test.train()
    for inputs, labels in train_loader_t2:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_ewc.zero_grad()
        loss_task = F.cross_entropy(model_ewc_test(inputs), labels)
        loss_ewc_penalty = ewc.penalty()
        loss = loss_task + loss_ewc_penalty
        loss.backward()
        optimizer_ewc.step()
acc_t1_ewc_final = evaluate_model(model_ewc_test, test_loader_t1)
acc_t2_ewc_final = evaluate_model(model_ewc_test, test_loader_t2)
print(f"After Task 2 - T1 acc: {acc_t1_ewc_final:.2f}%, T2 acc: {acc_t2_ewc_final:.2f}%")
print(f"Forgetting: {acc_t1_ewc - acc_t1_ewc_final:.2f}%")
results['EWC'] = {
    't1_after_t1': acc_t1_ewc,
    't1_after_t2': acc_t1_ewc_final,
    't2_after_t2': acc_t2_ewc_final,
    'forgetting': acc_t1_ewc - acc_t1_ewc_final
}
# 3. Experience Replay: store examples from task 1 and mix them into every
# task-2 batch so old classes keep receiving gradient signal.
print("\n--- Experience Replay ---")
model_replay = SimpleNet().to(device)
optimizer_replay = torch.optim.Adam(model_replay.parameters(), lr=1e-3)
replay_buffer = ReplayBuffer(max_size=1000)
# Train task 1, filling the buffer with each batch as we go.
for epoch in range(3):
    model_replay.train()
    for inputs, labels in train_loader_t1:
        inputs, labels = inputs.to(device), labels.to(device)
        # Store in buffer
        replay_buffer.add(inputs, labels)
        optimizer_replay.zero_grad()
        loss = F.cross_entropy(model_replay(inputs), labels)
        loss.backward()
        optimizer_replay.step()
acc_t1_replay = evaluate_model(model_replay, test_loader_t1)
print(f"After Task 1 - T1 acc: {acc_t1_replay:.2f}%")
# Train task 2 with replay: each step optimizes current batch + replayed
# old examples.
# NOTE(review): the two-value unpack of sample() below does not match the
# ReplayBuffer class defined later in this file, whose sample() returns
# (x, y, task_ids) — presumably an earlier notebook cell defines the buffer
# used here; confirm execution order.
for epoch in range(3):
    model_replay.train()
    for inputs, labels in train_loader_t2:
        inputs, labels = inputs.to(device), labels.to(device)
        # Sample from replay buffer
        replay_x, replay_y = replay_buffer.sample(batch_size=64)
        if replay_x is not None:
            # Combine current batch with replay
            combined_x = torch.cat([inputs, replay_x], dim=0)
            combined_y = torch.cat([labels, replay_y], dim=0)
        else:
            combined_x, combined_y = inputs, labels
        optimizer_replay.zero_grad()
        loss = F.cross_entropy(model_replay(combined_x), combined_y)
        loss.backward()
        optimizer_replay.step()
        # Add task 2 to buffer
        replay_buffer.add(inputs, labels)
acc_t1_replay_final = evaluate_model(model_replay, test_loader_t1)
acc_t2_replay_final = evaluate_model(model_replay, test_loader_t2)
print(f"After Task 2 - T1 acc: {acc_t1_replay_final:.2f}%, T2 acc: {acc_t2_replay_final:.2f}%")
print(f"Forgetting: {acc_t1_replay - acc_t1_replay_final:.2f}%")
results['Replay'] = {
    't1_after_t1': acc_t1_replay,
    't1_after_t2': acc_t1_replay_final,
    't2_after_t2': acc_t2_replay_final,
    'forgetting': acc_t1_replay - acc_t1_replay_final
}
# Visualization: final per-task accuracy and task-1 forgetting per method.
print("\n" + "="*80)
print("Final Results Summary")
print("="*80)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
methods = list(results.keys())
t1_final = [results[m]['t1_after_t2'] for m in methods]
t2_final = [results[m]['t2_after_t2'] for m in methods]
forgetting = [results[m]['forgetting'] for m in methods]
# Left panel: grouped bars of final accuracy on both tasks.
x = np.arange(len(methods))
width = 0.35
axes[0].bar(x - width/2, t1_final, width, label='Task 1 (0-4)', alpha=0.8)
axes[0].bar(x + width/2, t2_final, width, label='Task 2 (5-9)', alpha=0.8)
axes[0].set_ylabel('Accuracy (%)', fontsize=11)
axes[0].set_title('Final Task Performance', fontsize=12, fontweight='bold')
axes[0].set_xticks(x)
axes[0].set_xticklabels(methods, rotation=15, ha='right')
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')
# Right panel: forgetting bars, red when above a 10-point threshold.
colors = ['red' if f > 10 else 'green' for f in forgetting]
axes[1].bar(methods, forgetting, color=colors, alpha=0.7)
axes[1].set_ylabel('Forgetting (%)', fontsize=11)
axes[1].set_title('Task 1 Forgetting', fontsize=12, fontweight='bold')
axes[1].axhline(y=0, color='black', linestyle='--', linewidth=1)
axes[1].grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
# Plain-text summary of the same numbers.
for method in methods:
    r = results[method]
    print(f"{method:15s}: T1={r['t1_after_t2']:.2f}%, T2={r['t2_after_t2']:.2f}%, Forgetting={r['forgetting']:.2f}%")
print("\n" + "="*80)
print("Implementation Complete!")
print("="*80)
print("\nKey Insights:")
print("1. EWC reduces forgetting by penalizing changes to important parameters")
print("2. Experience Replay maintains performance by rehearsing old examples")
print("3. LwF uses knowledge distillation to preserve old task predictions")
print("4. Progressive networks achieve zero forgetting but grow linearly")
print("5. Trade-offs: Memory (Replay), Computation (EWC), Capacity (Progressive)")
print("\nNext: Apply to real-world continual learning scenarios!")
Advanced Continual Learning TheoryΒΆ
1. Introduction to Continual LearningΒΆ
Continual Learning (also called Lifelong Learning or Incremental Learning) enables models to learn from a stream of data without forgetting previously acquired knowledge.
1.1 The Plasticity-Stability DilemmaΒΆ
Challenge: Balance two competing objectives:
Plasticity: Ability to learn new tasks
Stability: Preserve knowledge of old tasks
Catastrophic Forgetting: When learning new task T_n, performance on T_1, β¦, T_{n-1} degrades significantly.
1.2 Problem FormulationΒΆ
Task sequence: T_1, T_2, β¦, T_n
Goal: After learning task T_n, maintain performance on all tasks:
min_{θ_n} Σᵢ₌₁ⁿ L_i(θ_n)
Constraints:
No access to data from previous tasks (or limited buffer)
Bounded memory and computation
Tasks may be related or unrelated
1.3 Continual Learning ScenariosΒΆ
Task-Incremental Learning (Task-IL):
Task identity known at test time
Multi-head architecture (one head per task)
Example: Learn tasks 1, 2, 3; test on task 2 (given task ID)
Domain-Incremental Learning (Domain-IL):
Same task, different domains
Single-head architecture
Example: MNIST β SVHN β USPS
Class-Incremental Learning (Class-IL):
New classes added over time
Most challenging (no task ID at test)
Example: Learn cats/dogs, then birds/fish
2. Regularization-Based ApproachesΒΆ
2.1 Elastic Weight Consolidation (EWC)ΒΆ
EWC [Kirkpatrick et al., 2017]: Protect important weights for old tasks.
Objective for task n:
L(θ) = L_n(θ) + (λ/2) Σᵢ F_i (θᵢ - θ*_{n-1,i})²
Where:
L_n(θ): Loss on current task
F_i: Fisher Information Matrix diagonal
θ*_{n-1}: Parameters after task n-1
λ: Regularization strength
Fisher Information:
F_i = E_{x~D} [(∂ log p(y|x,θ)/∂θᵢ)²]
Approximation:
F_i ≈ (1/N) Σₙ (∂L(xₙ, θ*)/∂θᵢ)²
Intuition: Important parameters (high Fisher) are penalized more if changed.
2.2 Synaptic Intelligence (SI)ΒΆ
SI [Zenke et al., 2017]: Measure importance via path integral.
Importance:
Ω_i = Σ_{t=1}^T (∂L/∂θᵢ)² · Δθᵢ
Accumulates gradient magnitudes weighted by parameter changes.
Loss:
L(θ) = L_n(θ) + c Σᵢ Ω_i (θᵢ - θ*_{n-1,i})²
Advantage: Online computation (no need to store data).
2.3 Memory Aware Synapses (MAS)ΒΆ
MAS [Aljundi et al., 2018]: Importance based on output sensitivity.
Importance:
Ω_i = E_x [||∂f(x,θ)/∂θᵢ||²]
Difference from EWC: Uses output sensitivity, not loss gradient.
Benefit: Unsupervised (doesnβt need labels).
2.4 Learning without Forgetting (LwF)ΒΆ
LwF [Li & Hoiem, 2017]: Distillation from old model.
Loss:
L = L_new(ΞΈ) + Ξ± L_distill(ΞΈ, ΞΈ_old)
Where:
L_distill = KL(p_old(x) || p_new(x))
Knowledge distillation: New model mimics old modelβs outputs on new data.
Advantage: No need to store old data.
3. Replay-Based ApproachesΒΆ
3.1 Experience Replay (ER)ΒΆ
Core idea: Store subset of old data, replay when learning new tasks.
Memory buffer M: Store (x, y, task_id) tuples.
Training: Mix current data with replayed data:
L(ΞΈ) = E_{(x,y)~D_n} [L(x,y,ΞΈ)] + E_{(x,y)~M} [L(x,y,ΞΈ)]
Buffer management:
Reservoir sampling: Random replacement
Herding: Select most representative samples
Class-balanced: Equal samples per class
3.2 Gradient Episodic Memory (GEM)ΒΆ
GEM [Lopez-Paz & Ranzato, 2017]: Constrained optimization.
Constraint: New gradients shouldn't increase loss on old tasks:
⟨g_new, g_old^i⟩ ≥ 0 for all i < n
Where:
g_new = ∇_θ L_n(θ)
g_old^i = ∇_θ L_i(θ) (computed on memory buffer)
Projection: If constraint violated, project g_new to feasible region.
Quadratic program:
min_g ||g - g_new||²
s.t. ⟨g, g_old^i⟩ ≥ 0 for all i < n
3.3 Averaged GEM (A-GEM)ΒΆ
A-GEM [Chaudhry et al., 2019]: Simplified GEM.
Single constraint: Average gradient on memory:
⟨g_new, ḡ_old⟩ ≥ 0
Where ḡ_old is the average of gradients on memory samples.
Projection (applied when the constraint is violated):
g = g_new - (⟨g_new, ḡ_old⟩ / ||ḡ_old||²) · ḡ_old
Advantage: O(1) constraints vs. O(tasks) for GEM.
3.4 Dark Experience Replay (DER)ΒΆ
DER [Buzzega et al., 2020]: Store logits instead of labels.
Memory: (x, logits_old, task_id)
Loss:
L = L_CE(y, f(x)) + Ξ± MSE(logits_old, f(x))
Benefit: Richer information than just labels, mitigates recency bias.
4. Dynamic Architecture ApproachesΒΆ
4.1 Progressive Neural NetworksΒΆ
ProgNN [Rusu et al., 2016]: Add new columns for new tasks.
Architecture:
Task 1: Column C_1
Task 2: Column C_2 + lateral connections from C_1
Task n: Column C_n + lateral connections from C_1, β¦, C_{n-1}
Forward pass for task n:
h_n^(l) = Ο(W_n^(l) h_n^(l-1) + Ξ£α΅’ββ^{n-1} U_iβn^(l) h_i^(l-1))
Advantages:
No forgetting (old parameters frozen)
Transfer via lateral connections
Disadvantages:
O(nΒ²) parameters growth
Requires task ID
4.2 PackNetΒΆ
PackNet [Mallya & Lazebnik, 2018]: Network pruning + packing.
Procedure for task n:
Train full network on task n
Prune unimportant weights (magnitude-based)
Pack: Mark remaining weights as used by task n
Freeze used weights
Inference: Activate only weights for given task.
Parameter efficiency: Reuses parameters across tasks.
4.3 Dynamic Expandable Network (DEN)ΒΆ
DEN [Yoon et al., 2018]: Selective retraining and expansion.
Algorithm:
Selective retraining: Retrain low-drift neurons
Dynamic expansion: Add neurons if capacity insufficient
Network split/duplication: Create task-specific paths
Drift for neuron j:
d_j = ||ΞΈ_j - ΞΈ*_j||Β² / (Ο_jΒ² + Ξ΅)
Expansion criterion: If validation loss doesnβt decrease, add neurons.
5. Meta-Learning for Continual LearningΒΆ
5.1 Meta-Experience Replay (MER)ΒΆ
MER [Riemer et al., 2019]: Meta-learn with replay buffer.
Inner loop: Update on current batch Outer loop: Evaluate on replay buffer, update to minimize replay loss
Algorithm:
ΞΈ' = ΞΈ - Ξ± β_ΞΈ L_current(ΞΈ)
ΞΈ β ΞΈ - Ξ² β_ΞΈ L_replay(ΞΈ')
Benefit: Meta-learned updates reduce forgetting.
5.2 Online-aware Meta-learning (OML)ΒΆ
OML [Javed & White, 2019]: Meta-learn representations for continual learning.
Objective: Learn representation that enables online learning:
min_ΞΈ E_{T~p(T)} [L_T(ΞΈ - Ξ± β_ΞΈ L_T^train(ΞΈ))]
Training: Sample task sequences, meta-optimize for online adaptation.
5.3 La-MAMLΒΆ
La-MAML [Gupta et al., 2020]: Look-ahead MAML for continual learning.
Key idea: Meta-learn initialization that resists forgetting when adapted.
Meta-objective:
min_ΞΈ Ξ£α΅’ L_i(ΞΈ_{1:i}) + Ξ» Ξ£α΅’<β±Ό L_i(ΞΈ_{1:j})
Optimizes both current and future performance on past tasks.
6. Generative ReplayΒΆ
6.1 Deep Generative ReplayΒΆ
DGR [Shin et al., 2017]: Train generative model alongside discriminative model.
Components:
Generator G: Synthesizes samples from old tasks
Solver S: Discriminative model
Training on task n:
Generate pseudo-samples from G for tasks 1, β¦, n-1
Train S on real samples (task n) + generated samples (tasks <n)
Update G to generate samples for task n
Loss:
L = L_real(x_n, y_n) + L_generated(G(z), S(G(z)))
6.2 Conditional GAN for ReplayΒΆ
Use class-conditional GAN: G(z, c) generates samples for class c.
Advantage: Can generate specific classes on demand.
Challenges:
Generator quality crucial
Mode collapse can lose diversity
Computational overhead
7. Theoretical FoundationsΒΆ
7.1 Catastrophic Forgetting AnalysisΒΆ
Linear models: Catastrophic forgetting occurs when new task gradients are orthogonal to old task gradients.
Neural networks: Over-parameterization helps (multiple solutions exist).
Theorem [Goodfellow et al., 2013]: For linear models, forgetting is unavoidable unless tasks share structure.
7.2 Stability-Plasticity Trade-offΒΆ
Formal definition:
Stability: ||ΞΈ_n - ΞΈ_{n-1}||Β² small
Plasticity: L_n(ΞΈ_n) small
Pareto frontier: Canβt minimize both simultaneously.
EWC perspective: Fisher matrix balances this trade-off.
7.3 Generalization BoundsΒΆ
PAC-Bayesian bound for continual learning:
With probability β₯ 1-Ξ΄:
Ξ£α΅’ L_i(ΞΈ_n) β€ Ξ£α΅’ LΜ_i(ΞΈ_n) + O(β(KL(Q||P) + log(n/Ξ΄)) / βN)
Where:
Q: Posterior after n tasks
P: Prior
N: Total samples
Implication: Regularization toward prior (e.g., EWC) helps generalization.
8. Evaluation MetricsΒΆ
8.1 Average AccuracyΒΆ
After learning n tasks:
ACC_n = (1/n) Σᵢ₌₁ⁿ a_{n,i}
Where a_{n,i} is accuracy on task i after learning task n.
8.2 Forgetting MeasureΒΆ
Average forgetting:
FM_n = (1/(n-1)) Σᵢ₌₁ⁿ⁻¹ (max_{j&lt;n} a_{j,i} - a_{n,i})
Measures maximum drop in performance on each task.
8.3 Backward Transfer (BWT)ΒΆ
BWT: How much learning new tasks affects old tasks:
BWT_n = (1/(n-1)) Σᵢ₌₁ⁿ⁻¹ (a_{n,i} - a_{i,i})
Negative BWT indicates forgetting.
8.4 Forward Transfer (FWT)ΒΆ
FWT: How much old tasks help learn new tasks:
FWT_n = (1/(n-1)) Σᵢ₌₂ⁿ (a_{i-1,i} - a_{0,i})
Where a_{0,i} is random initialization performance.
Positive FWT indicates transfer learning.
9. Advanced TechniquesΒΆ
9.1 Task-Free Continual LearningΒΆ
Challenge: No explicit task boundaries.
Approaches:
Boundary detection: Detect distribution shift
Pseudo-rehearsal: Generate samples continuously
Uncertainty-based: Use prediction uncertainty to detect new data
9.2 Online Continual LearningΒΆ
Setting: One pass through data stream, no task boundaries.
Challenges:
No clear training/testing split
Must decide when to update
Memory budget constraints
Metrics: Average accuracy over time, anytime performance.
9.3 Continual Learning with Imbalanced DataΒΆ
Problem: New classes may have few samples (long-tail distribution).
Solutions:
Class balancing: Oversample rare classes in replay
Focal loss: Focus on hard examples
Two-stage training: Learn representations, then classifier
9.4 Continual Learning under Label NoiseΒΆ
Noisy labels in stream:
Robust loss: Use symmetric loss functions
Sample selection: Filter likely mislabeled samples
Co-teaching: Two networks teach each other
10. Practical ConsiderationsΒΆ
10.1 Memory BudgetΒΆ
Typical budgets: 500-5000 samples (vs. millions in full dataset).
Strategies:
Per-class: K samples per class
Per-task: Fixed budget per task
Global: Total budget across all tasks
Update policy:
Reservoir sampling: Probabilistic replacement
Ring buffer: FIFO
Coreset selection: Optimize representativeness
10.2 Computational EfficiencyΒΆ
EWC: O(|ΞΈ|) extra memory (Fisher diagonal) GEM: O(nΒ·batch_size) memory (gradients per task) Replay: O(buffer_size) memory
Training time:
Regularization: ~1.2Γ base
Replay: ~1.5Γ base (due to extra forward passes)
Dynamic architectures: ~2Γ base
10.3 Hyperparameter SensitivityΒΆ
Critical hyperparameters:
| Method | Key Hyperparameters | Typical Range |
|---|---|---|
| EWC | λ (importance) | 10²-10⁵ |
| Replay | Buffer size | 500-5000 |
| GEM | Memory per task | 50-500 |
| LwF | α (distillation) | 1-10 |
Tuning: Use validation set from current task (no access to old tasks).
11. Benchmark DatasetsΒΆ
11.1 Permuted MNISTΒΆ
Setup: 10 tasks, each with permuted pixels.
Evaluation: Task-IL (10-way classification per task).
Baseline: ~90% after 10 tasks (naive fine-tuning: ~20%).
11.2 Split CIFAR-100ΒΆ
Setup: 10 or 20 tasks, disjoint classes.
Variants:
10 tasks Γ 10 classes
20 tasks Γ 5 classes
Challenge: Class-IL (100-way classification without task ID).
11.3 CORe50ΒΆ
CORe50: Continual learning on real-world objects.
Properties:
50 objects, 11 sessions
Indoor/outdoor lighting changes
164,866 images
Scenarios: NI (new instances), NC (new classes), NIC (both).
12. Comparison of ApproachesΒΆ
12.1 Performance SummaryΒΆ
Forgetting (lower is better):
| Method | Permuted MNIST | Split CIFAR-100 |
|---|---|---|
| Fine-tuning | 85% | 75% |
| EWC | 15% | 35% |
| SI | 18% | 40% |
| GEM | 2% | 20% |
| A-GEM | 8% | 28% |
| ER (500) | 5% | 25% |
| DER | 3% | 18% |
12.2 Memory-Performance Trade-offΒΆ
Memory usage vs. Final accuracy:
No replay: 0 extra memory, 40-50% accuracy
EWC: |ΞΈ| floats, 60-70% accuracy
Replay (100): 100 samples, 70-75% accuracy
Replay (1000): 1000 samples, 80-85% accuracy
Progressive: O(nΒ²) params, 90-95% accuracy
Insight: Replay is most memory-efficient for good performance.
12.3 Computational CostΒΆ
Relative training time (vs. base):
EWC: 1.1-1.3Γ
SI: 1.2-1.4Γ
Replay: 1.4-1.8Γ
GEM: 2-3Γ
Progressive: 1.5-2Γ
13. Recent Advances (2020-2024)ΒΆ
13.1 Self-Supervised Continual LearningΒΆ
Approach: Learn representations via contrastive learning, then fine-tune.
Benefits: Better transfer, less catastrophic forgetting.
Methods: SimCLR + replay, MoCo + EWC.
13.2 Transformer-Based Continual LearningΒΆ
L2P [Wang et al., 2022]: Learning to Prompt for continual learning.
Idea: Learn prompt pool, select relevant prompts for each task.
DualPrompt [Wang et al., 2022]: Separate prompts for general and task-specific knowledge.
13.3 Continual Pre-trainingΒΆ
Scenario: Continually pre-train large language models on new data.
Challenges:
Forgetting common knowledge
Domain shift
Computational cost
Solutions: Replay, regularization, modular architectures.
13.4 Federated Continual LearningΒΆ
Setting: Multiple clients, each learning continually + federated aggregation.
Challenges:
Heterogeneous task sequences
Communication efficiency
Privacy constraints
14. Limitations and Open ProblemsΒΆ
14.1 Current LimitationsΒΆ
Task boundaries: Most methods assume clear task boundaries
Memory requirements: Replay needs storage
Scalability: Many methods donβt scale to 100+ tasks
Theoretical understanding: Limited guarantees
14.2 Open QuestionsΒΆ
Optimal replay strategy: What to store? How to use?
Architecture design: Fixed vs. dynamic? How much capacity?
Task similarity: How to measure? How to leverage?
Evaluation: What metrics best capture continual learning ability?
14.3 Future DirectionsΒΆ
Biologically-inspired: Complementary learning systems (hippocampus + neocortex)
Curriculum: Learn tasks in optimal order
Meta-continual learning: Learn to continually learn
Compositional: Reuse learned modules
15. Key TakeawaysΒΆ
Catastrophic forgetting is the central challenge in continual learning
Regularization (EWC, SI, MAS) protects important parameters
Replay stores samples from old tasks, most effective approach
GEM/A-GEM constrain gradients to not harm old tasks
Dynamic architectures avoid forgetting but grow over time
Meta-learning can improve continual learning ability
Memory-performance trade-off: More memory β less forgetting
Evaluation metrics: ACC, FM, BWT, FWT capture different aspects
Benchmarks: Permuted MNIST (easy), Split CIFAR (medium), CORe50 (hard)
No silver bullet: Best method depends on scenario and constraints
Core insight: Balance stability (preserving old knowledge) and plasticity (learning new knowledge).
16. Mathematical SummaryΒΆ
Continual learning objective:
min_ΞΈ Ξ£α΅’βββΏ L_i(ΞΈ) + R(ΞΈ)
Where R(ΞΈ) is regularization.
EWC regularization:
R(ΞΈ) = (Ξ»/2) Ξ£α΅’ F_i (ΞΈα΅’ - ΞΈ*_{old,i})Β²
GEM constraint:
β¨βL_current(ΞΈ), βL_old(ΞΈ)β© β₯ 0
Replay objective:
L(ΞΈ) = E_{(x,y)~D_current} [β(x,y,ΞΈ)] + E_{(x,y)~M} [β(x,y,ΞΈ)]
ReferencesΒΆ
Kirkpatrick et al. (2017) βOvercoming Catastrophic Forgetting in Neural Networks (EWC)β
Zenke et al. (2017) βContinual Learning Through Synaptic Intelligence (SI)β
Lopez-Paz & Ranzato (2017) βGradient Episodic Memory for Continual Learning (GEM)β
Rusu et al. (2016) βProgressive Neural Networksβ
Shin et al. (2017) βContinual Learning with Deep Generative Replayβ
Chaudhry et al. (2019) βEfficient Lifelong Learning with A-GEMβ
Buzzega et al. (2020) βDark Experience for General Continual Learning (DER)β
Riemer et al. (2019) βLearning to Learn without Forgetting by Maximizing Transfer and Minimizing Interference (MER)β
De Lange et al. (2021) βA Continual Learning Survey: Defying Forgetting in Classification Tasksβ
Wang et al. (2022) βLearning to Prompt for Continual Learning (L2P)β
"""
Complete Continual Learning Implementations
===========================================
Includes: EWC, SI, GEM, A-GEM, Experience Replay, DER, LwF,
memory management, evaluation metrics.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
import copy
from collections import defaultdict
# ============================================================================
# 1. Memory Buffer for Replay
# ============================================================================
class ReplayBuffer:
    """
    Memory buffer for experience replay.

    Supports multiple storage strategies:
    - 'reservoir':      reservoir sampling (uniform over the stream)
    - 'ring':           FIFO replacement
    - 'class_balanced': equal number of samples per class

    Note: `add` stores whatever tensors it is given; the class-balanced
    strategy additionally calls ``y.item()`` and therefore expects scalar
    (per-sample) labels.
    """

    def __init__(self, max_size=500, strategy='reservoir'):
        self.max_size = max_size
        self.strategy = strategy
        self.buffer = []    # flat list of sample dicts
        self.num_seen = 0   # samples offered so far (reservoir/ring)
        # Per-class storage, used only by the 'class_balanced' strategy.
        self.class_buffers = defaultdict(list)

    def add(self, x, y, task_id=None):
        """Add a sample to the buffer (tensors are moved to CPU)."""
        sample = {'x': x.cpu(), 'y': y.cpu(), 'task_id': task_id}
        if self.strategy == 'reservoir':
            self._reservoir_add(sample)
        elif self.strategy == 'ring':
            self._ring_add(sample)
        elif self.strategy == 'class_balanced':
            self._class_balanced_add(sample, y.item())

    def _reservoir_add(self, sample):
        """Reservoir sampling: keep item N with probability max_size / N."""
        if len(self.buffer) < self.max_size:
            self.buffer.append(sample)
        else:
            # This is item number num_seen + 1; replace a uniformly random
            # slot with probability max_size / (num_seen + 1).
            idx = np.random.randint(0, self.num_seen + 1)
            if idx < self.max_size:
                self.buffer[idx] = sample
        self.num_seen += 1

    def _ring_add(self, sample):
        """Ring buffer: FIFO replacement."""
        if len(self.buffer) < self.max_size:
            self.buffer.append(sample)
        else:
            idx = self.num_seen % self.max_size
            self.buffer[idx] = sample
        self.num_seen += 1

    def _class_balanced_add(self, sample, class_id):
        """Class-balanced: equal samples per class."""
        self.class_buffers[class_id].append(sample)
        # Bug fix (memory leak): the per-class history used to grow without
        # bound. Only the most recent samples_per_class <= max_size entries
        # can ever be selected below, so trimming each class list to its last
        # max_size entries bounds memory without changing behavior.
        if len(self.class_buffers[class_id]) > self.max_size:
            self.class_buffers[class_id] = self.class_buffers[class_id][-self.max_size:]
        # Rebuild the flat buffer with an equal share per class.
        self.buffer = []
        classes = list(self.class_buffers.keys())
        samples_per_class = self.max_size // len(classes)
        for c in classes:
            samples = self.class_buffers[c]
            # Take the most recent samples_per_class samples.
            selected = samples[-samples_per_class:] if len(samples) > samples_per_class else samples
            self.buffer.extend(selected)

    def sample(self, batch_size):
        """Sample a batch without replacement.

        Returns (x, y, task_ids) with stacked tensors, or None if the
        buffer is empty.
        """
        if len(self.buffer) == 0:
            return None
        indices = np.random.choice(len(self.buffer),
                                   min(batch_size, len(self.buffer)),
                                   replace=False)
        samples = [self.buffer[i] for i in indices]
        x = torch.stack([s['x'] for s in samples])
        y = torch.stack([s['y'] for s in samples])
        task_ids = [s['task_id'] for s in samples]
        return x, y, task_ids

    def __len__(self):
        return len(self.buffer)
# ============================================================================
# 2. Elastic Weight Consolidation (EWC)
# ============================================================================
class EWC:
    """
    Elastic Weight Consolidation.

    On construction, estimates the diagonal Fisher information of `model`
    on `dataloader`. After `save_parameters()` snapshots an anchor point,
    `penalty()` returns (lambda/2) * sum_i F_i (theta_i - theta*_i)^2,
    which discourages moving parameters that were important on the task
    the Fisher was computed on.
    """

    def __init__(self, model, dataloader, device='cuda', lambda_ewc=1000):
        self.model = model
        self.device = device
        self.lambda_ewc = lambda_ewc
        self.old_params = {}   # anchor snapshot, set by save_parameters()
        self.fisher = {}       # per-parameter diagonal Fisher estimate
        self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader):
        """Estimate the diagonal Fisher as averaged squared loss gradients."""
        self.model.eval()
        self.fisher = {name: torch.zeros_like(p)
                       for name, p in self.model.named_parameters()}
        for inputs, targets in dataloader:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            self.model.zero_grad()
            F.cross_entropy(self.model(inputs), targets).backward()
            for name, p in self.model.named_parameters():
                if p.grad is not None:
                    self.fisher[name] += p.grad.pow(2)
        # Normalize by the number of examples in the dataset.
        n = len(dataloader.dataset)
        for name in self.fisher:
            self.fisher[name] /= n

    def save_parameters(self):
        """Snapshot current parameters as the anchor point theta*."""
        self.old_params = {name: p.data.clone()
                           for name, p in self.model.named_parameters()}

    def penalty(self):
        """EWC regularization term; 0 before any snapshot exists."""
        if not self.old_params:
            return 0
        total = 0
        for name, p in self.model.named_parameters():
            if name in self.fisher:
                drift = p - self.old_params[name]
                total = total + (self.fisher[name] * drift.pow(2)).sum()
        return (self.lambda_ewc / 2) * total
# ============================================================================
# 3. Synaptic Intelligence (SI)
# ============================================================================
class SynapticIntelligence:
    """
    Synaptic Intelligence (Zenke et al., 2017).

    During training, W accumulates the path integral of
    -gradient * (per-step parameter change); at a task boundary,
    update_omega() converts W into per-parameter importances omega,
    which feed a quadratic penalty around the end-of-task parameters.
    """

    def __init__(self, model, c=0.1, epsilon=1e-3):
        self.model = model
        self.c = c              # penalty strength
        self.epsilon = epsilon  # damping term, avoids division by zero
        self.omega = {}         # accumulated importance per parameter
        self.W = {}             # running path integral for the current task
        self.prev_params = {}   # parameters at the start of the current task
        self.last_params = {}   # parameters at the previous optimizer step
        for name, param in model.named_parameters():
            self.omega[name] = torch.zeros_like(param)
            self.W[name] = torch.zeros_like(param)
            self.prev_params[name] = param.data.clone()
            self.last_params[name] = param.data.clone()

    def update_omega(self):
        """Fold the task's path integral into omega; reset for the next task."""
        for name, param in self.model.named_parameters():
            # Total displacement over the whole task (Delta in the SI paper).
            delta = param.data - self.prev_params[name]
            # omega += W / (Delta^2 + xi), per Zenke et al. (2017), Eq. 5.
            self.omega[name] += self.W[name] / (delta.pow(2) + self.epsilon)
            # Reset the per-task accumulators.
            self.W[name] = torch.zeros_like(param)
            self.prev_params[name] = param.data.clone()
            self.last_params[name] = param.data.clone()

    def accumulate_gradient(self):
        """Accumulate -grad * per-step change (call after each optimizer step).

        Bug fix: the step displacement is measured against the parameters at
        the previous call (`last_params`), not against the start-of-task
        snapshot — measuring from the task start made W a gradient times a
        cumulative offset rather than the SI path integral sum_t g * delta_t.
        """
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                step = param.data - self.last_params[name]
                self.W[name] -= param.grad * step
            # Track the post-step parameters for the next call.
            self.last_params[name] = param.data.clone()

    def penalty(self):
        """SI penalty: c * sum_i omega_i * (theta_i - theta*_i)^2."""
        loss = 0
        for name, param in self.model.named_parameters():
            if name in self.omega:
                prev = self.prev_params[name]
                loss += (self.omega[name] * (param - prev).pow(2)).sum()
        return self.c * loss
# ============================================================================
# 4. Gradient Episodic Memory (GEM)
# ============================================================================
class GEM:
    """
    Gradient Episodic Memory.

    Keeps a small episodic memory per task and projects the current
    gradient so it does not increase the loss on remembered tasks.
    Violations are checked per task, but the projection itself uses the
    A-GEM average-gradient approximation rather than the full QP.
    """

    def __init__(self, model, memory_budget_per_task=100):
        self.model = model
        self.memory_budget = memory_budget_per_task
        self.memory = {}   # task_id -> {'x': tensor, 'y': tensor}

    def store_task_memory(self, task_id, dataloader, device='cuda'):
        """Store up to `memory_budget` (x, y) examples for `task_id`."""
        xs, ys = [], []
        stored = 0
        for batch_x, batch_y in dataloader:
            if stored >= self.memory_budget:
                break
            take = min(batch_x.size(0), self.memory_budget - stored)
            xs.append(batch_x[:take])
            ys.append(batch_y[:take])
            stored += take
        self.memory[task_id] = {
            'x': torch.cat(xs).to(device),
            'y': torch.cat(ys).to(device),
        }

    def get_gradient(self, task_id, device='cuda'):
        """Flat loss gradient on the stored memory of `task_id`, or None."""
        entry = self.memory.get(task_id)
        if entry is None:
            return None
        self.model.zero_grad()
        F.cross_entropy(self.model(entry['x']), entry['y']).backward()
        return torch.cat([p.grad.view(-1)
                          for p in self.model.parameters()
                          if p.grad is not None])

    def project_gradient(self, current_grad, device='cuda'):
        """Project `current_grad` so old-task losses are not increased."""
        ref = [g for g in (self.get_gradient(t, device) for t in self.memory)
               if g is not None]
        if not ref:
            return current_grad
        G = torch.stack(ref)  # [num_tasks, param_dim]
        # Constraint check: <g_k, g> >= 0 for every remembered task.
        if (G @ current_grad < 0).float().sum() == 0:
            return current_grad
        # A-GEM style approximation: instead of solving the full QP,
        # project onto the half-space of the average memory gradient.
        g_ref = G.mean(dim=0)
        if (g_ref @ current_grad) < 0:
            coef = (current_grad @ g_ref) / (g_ref @ g_ref)
            return current_grad - coef * g_ref
        return current_grad
# ============================================================================
# 5. Averaged GEM (A-GEM)
# ============================================================================
class AGEM:
    """Averaged Gradient Episodic Memory (Chaudhry et al., 2019).

    Replaces GEM's per-task constraints with a single constraint against the
    average gradient computed on a random sample from one shared replay buffer.
    """

    def __init__(self, model, buffer_size=500):
        self.model = model
        self.buffer = ReplayBuffer(max_size=buffer_size)

    def store_sample(self, x, y, task_id=None):
        """Insert one sample into the replay buffer."""
        self.buffer.add(x, y, task_id)

    def get_reference_gradient(self, device='cuda'):
        """Flattened CE-loss gradient on a buffer sample, or ``None`` if empty.

        Overwrites the model's ``.grad`` buffers as a side effect.
        """
        if len(self.buffer) == 0:
            return None
        drawn = self.buffer.sample(min(256, len(self.buffer)))
        if drawn is None:
            return None
        x, y, _ = drawn
        x, y = x.to(device), y.to(device)
        self.model.zero_grad()
        F.cross_entropy(self.model(x), y).backward()
        return torch.cat([p.grad.view(-1) for p in self.model.parameters()
                          if p.grad is not None])

    def project_gradient(self, current_grad, device='cuda'):
        """Return ``current_grad``, projected if it conflicts with the reference gradient."""
        g_ref = self.get_reference_gradient(device)
        if g_ref is None:
            return current_grad
        dot = (current_grad @ g_ref).item()
        if dot >= 0:
            # Constraint satisfied: no interference with stored data.
            return current_grad
        return current_grad - (dot / (g_ref @ g_ref)) * g_ref
# ============================================================================
# 6. Dark Experience Replay (DER)
# ============================================================================
class DarkExperienceReplay:
    """Dark Experience Replay (Buzzega et al., 2020).

    Maintains a reservoir-sampled buffer of (input, logits) pairs; replayed
    inputs are matched against the stored logits with an MSE distillation
    loss, transferring the old model's "dark knowledge".
    """

    def __init__(self, model, buffer_size=500, alpha=0.5):
        self.model = model
        self.buffer_size = buffer_size
        self.alpha = alpha  # weight of the distillation term in the total loss
        self.buffer = []    # entries: {'x', 'logits', 'task_id'}, kept on CPU
        self.num_seen = 0   # total samples offered, for reservoir sampling

    def add(self, x, logits, task_id=None):
        """Reservoir-insert one (input, logits) pair into the buffer."""
        entry = {'x': x.cpu(), 'logits': logits.cpu(), 'task_id': task_id}
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(entry)
        else:
            # Reservoir sampling keeps each seen sample with equal probability.
            slot = np.random.randint(0, self.num_seen + 1)
            if slot < self.buffer_size:
                self.buffer[slot] = entry
        self.num_seen += 1

    def sample(self, batch_size, device='cuda'):
        """Draw up to ``batch_size`` distinct entries; returns (x, logits) or ``None``."""
        if not self.buffer:
            return None
        picks = np.random.choice(len(self.buffer),
                                 min(batch_size, len(self.buffer)),
                                 replace=False)
        chosen = [self.buffer[i] for i in picks]
        x = torch.stack([c['x'] for c in chosen]).to(device)
        logits = torch.stack([c['logits'] for c in chosen]).to(device)
        return x, logits

    def loss(self, model, x_replay, logits_replay):
        """MSE between the model's current outputs and the stored logits."""
        return F.mse_loss(model(x_replay), logits_replay)
# ============================================================================
# 7. Learning without Forgetting (LwF)
# ============================================================================
class LearningWithoutForgetting:
    """
    Learning without Forgetting (Li & Hoiem, 2017).

    Preserves old-task behavior via knowledge distillation: a frozen snapshot
    of the model ("teacher") is saved before training the new task, and the
    new model is penalized for drifting from the teacher's temperature-softened
    output distribution.
    """
    def __init__(self, model, alpha=1.0, temperature=2.0):
        self.model = model
        self.alpha = alpha              # weight of the distillation term
        self.temperature = temperature  # softmax temperature T
        # Frozen teacher snapshot; populated by save_model().
        self.old_model = None
    def save_model(self):
        """Save a frozen, eval-mode copy of the current model as the teacher."""
        # BUGFIX: the script imports `deepcopy` directly
        # (`from copy import deepcopy`), so the original `copy.deepcopy(...)`
        # raised NameError. A local import keeps this method self-contained.
        from copy import deepcopy
        self.old_model = deepcopy(self.model)
        self.old_model.eval()
        # Freeze old model so no gradients flow into the teacher.
        for param in self.old_model.parameters():
            param.requires_grad = False
    def distillation_loss(self, x, device='cuda'):
        """Return alpha * T^2 * KL(teacher || student) on batch ``x``.

        Returns the integer 0 (a no-op when added to a loss tensor) if no
        teacher snapshot has been saved yet.
        """
        if self.old_model is None:
            return 0
        # Teacher predictions — no gradient needed.
        with torch.no_grad():
            old_logits = self.old_model(x)
        # Student predictions (keeps the autograd graph).
        new_logits = self.model(x)
        T = self.temperature
        old_probs = F.softmax(old_logits / T, dim=1)
        new_log_probs = F.log_softmax(new_logits / T, dim=1)
        # The T^2 factor keeps gradient magnitudes comparable across
        # temperatures (Hinton et al., 2015).
        loss = F.kl_div(new_log_probs, old_probs, reduction='batchmean') * (T ** 2)
        return self.alpha * loss
# ============================================================================
# 8. Evaluation Metrics
# ============================================================================
class ContinualLearningMetrics:
    """Track the task-accuracy matrix and derive standard CL metrics.

    ``acc[i][j]`` holds the accuracy (in %) on task ``j`` measured after
    training on task ``i``; only entries with ``j <= i`` (plus optional
    zero-shot entries ``j > i``) are meaningful.
    """
    def __init__(self, num_tasks):
        self.num_tasks = num_tasks
        # Accuracy matrix: acc[i][j] = accuracy on task j after learning task i
        self.acc = [[0.0 for _ in range(num_tasks)] for _ in range(num_tasks)]
    def update(self, task_trained, task_evaluated, accuracy):
        """Record accuracy on ``task_evaluated`` measured after training ``task_trained``."""
        self.acc[task_trained][task_evaluated] = accuracy
    def average_accuracy(self, task_n):
        """ACC: mean accuracy over tasks 0..task_n after learning task_n."""
        return np.mean([self.acc[task_n][i] for i in range(task_n + 1)])
    def forgetting_measure(self, task_n):
        """FM: mean over old tasks of (best *past* accuracy - current accuracy).

        BUGFIX: the max is now taken over checkpoints j in {i, ..., task_n-1},
        i.e. strictly before the current stage, matching the standard
        definition (and this file's own formula). The original also included
        j = task_n, which clamped FM at >= 0 and hid positive backward
        transfer; FM may now be negative when old tasks improved.
        """
        if task_n == 0:
            return 0.0
        forgetting = []
        for i in range(task_n):
            # Best accuracy on task i across all previous training stages.
            max_acc = max(self.acc[j][i] for j in range(i, task_n))
            forgetting.append(max_acc - self.acc[task_n][i])
        return np.mean(forgetting)
    def backward_transfer(self, task_n):
        """BWT: mean of (final accuracy - just-after-training accuracy) on old tasks.

        Negative BWT indicates forgetting.
        """
        if task_n == 0:
            return 0.0
        return np.mean([self.acc[task_n][i] - self.acc[i][i] for i in range(task_n)])
    def forward_transfer(self, task_n):
        """FWT proxy: pre-training accuracy on each task minus a baseline.

        NOTE(review): the standard FWT subtracts a random-init baseline;
        since none is stored here, ``acc[0][i]`` (accuracy after training only
        task 0) serves as the baseline — confirm this is the intended proxy.
        """
        if task_n == 0:
            return 0.0
        return np.mean([self.acc[i - 1][i] - self.acc[0][i]
                        for i in range(1, task_n + 1)])
    def print_summary(self, task_n):
        """Print ACC, FM, BWT (and FWT when task_n > 0)."""
        print(f"\nMetrics after Task {task_n}:")
        print(f" Average Accuracy: {self.average_accuracy(task_n):.2f}%")
        print(f" Forgetting Measure: {self.forgetting_measure(task_n):.2f}%")
        print(f" Backward Transfer: {self.backward_transfer(task_n):.2f}%")
        if task_n > 0:
            print(f" Forward Transfer: {self.forward_transfer(task_n):.2f}%")
# ============================================================================
# 9. Method Comparison
# ============================================================================
def print_method_comparison():
    """Print comparison of continual learning methods.

    Prints a fixed, verbatim reference table. The forgetting percentages,
    overhead multipliers, and benchmark numbers are indicative values quoted
    from the literature — nothing in this table is measured by this script.
    """
    print("="*70)
    print("Continual Learning Methods Comparison")
    print("="*70)
    print()
    # The table below is a single literal string, printed as-is.
    comparison = """
ββββββββββββββββ¬βββββββββββββ¬βββββββββββββββ¬βββββββββββββββ¬βββββββββββββββ
β Method β Type β Memory β Forgetting β Complexity β
ββββββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β Fine-tuning β Baseline β None β Very High β 1Γ β
β β β β (~80%) β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β EWC β Regular. β O(|ΞΈ|) β Medium β 1.2Γ β
β β β (Fisher) β (~15%) β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β SI β Regular. β O(|ΞΈ|) β Medium β 1.3Γ β
β β β (Importance) β (~18%) β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β LwF β Distill. β O(|ΞΈ|) β Medium-High β 1.5Γ β
β β β (Old model) β (~25%) β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β ER (500) β Replay β 500 samples β Low β 1.5Γ β
β β β β (~5%) β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β DER (500) β Replay β 500 samples β Very Low β 1.6Γ β
β β β + logits β (~3%) β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β GEM β Replay + β n Γ batch β Very Low β 2-3Γ β
β β Constraint β β (~2%) β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β A-GEM β Replay + β 500 samples β Low β 1.7Γ β
β β Constraint β β (~8%) β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β Progressive β Dynamic β O(nΒ²|ΞΈ|) β None β 2Γ β
β β Arch. β β (0%) β β
ββββββββββββββββ΄βββββββββββββ΄βββββββββββββββ΄βββββββββββββββ΄βββββββββββββββ
**Trade-offs:**
Memory vs. Forgetting:
- No memory β High forgetting (80%+)
- Small memory (500) β Low forgetting (3-8%)
- Large memory (5000) β Minimal forgetting (<1%)
Computational Cost:
- Regularization: Low overhead (1.2-1.5Γ)
- Replay: Medium overhead (1.5-2Γ)
- Constraint-based: High overhead (2-3Γ)
**Best Practices:**
1. **Limited memory budget**: Use DER or A-GEM
2. **No memory allowed**: Use EWC or SI
3. **Maximum performance**: Use replay with large buffer
4. **Task-IL setting**: Progressive networks (no forgetting)
5. **Class-IL setting**: DER or GEM (handle confusion well)
**Typical Performance (5 tasks, Split CIFAR-10):**
Method | Final Acc | Forgetting | Memory (MB)
----------------|-----------|------------|------------
Fine-tuning | 45% | 82% | 0
EWC (Ξ»=1000) | 68% | 15% | 10
SI (c=0.1) | 65% | 18% | 10
ER (500) | 78% | 5% | 50
DER (500) | 82% | 3% | 55
GEM (100/task) | 85% | 2% | 25
A-GEM (500) | 75% | 8% | 50
"""
    print(comparison)
    print()
# ============================================================================
# Run Demonstrations
# ============================================================================
if __name__ == "__main__":
    # Smoke-test driver: instantiates each continual-learning component on
    # random data and prints a short status line per component. Relies on
    # ReplayBuffer, EWC and SynapticIntelligence defined earlier in the file.
    print("="*70)
    print("Continual Learning Implementations")
    print("="*70)
    print()
    # Simple model for demonstration
    class SimpleModel(nn.Module):
        # Minimal 784-256-128-C MLP used by all the smoke tests below.
        def __init__(self, input_dim=784, num_classes=10):
            super(SimpleModel, self).__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes)
            )
        def forward(self, x):
            # Flatten everything after the batch dimension before the MLP.
            return self.net(x.view(x.size(0), -1))
    model = SimpleModel()
    # Test Replay Buffer
    print("Testing Replay Buffer...")
    # 200 additions into a 100-slot reservoir buffer; four pseudo-tasks.
    buffer = ReplayBuffer(max_size=100, strategy='reservoir')
    for i in range(200):
        x = torch.randn(28, 28)
        y = torch.tensor(i % 10)
        buffer.add(x, y, task_id=i//50)
    print(f" Buffer size: {len(buffer)}")
    samples = buffer.sample(32)
    if samples:
        x, y, task_ids = samples
        print(f" Sampled batch: {x.shape}, {y.shape}")
    print()
    # Test EWC
    print("Testing EWC...")
    # Create dummy dataloader
    # A list of (x, y) batches stands in for a DataLoader here.
    dummy_data = [(torch.randn(32, 28, 28), torch.randint(0, 10, (32,))) for _ in range(10)]
    ewc = EWC(model, dummy_data, device='cpu', lambda_ewc=1000)
    ewc.save_parameters()
    # Penalty should be ~0 immediately after saving: params equal anchors.
    penalty = ewc.penalty()
    print(f" EWC penalty: {penalty.item():.4f}")
    print()
    # Test SI
    print("Testing Synaptic Intelligence...")
    si = SynapticIntelligence(model, c=0.1)
    si.update_omega()
    # NOTE(review): assumes si.penalty() returns a tensor here (so .item()
    # works); would fail if omega were empty and penalty returned int 0.
    penalty = si.penalty()
    print(f" SI penalty: {penalty.item():.4f}")
    print()
    # Test DER
    print("Testing Dark Experience Replay...")
    der = DarkExperienceReplay(model, buffer_size=100, alpha=0.5)
    x = torch.randn(16, 28, 28)
    with torch.no_grad():
        logits = model(x)
    # Store a single (input, logits) pair.
    der.add(x[0], logits[0], task_id=0)
    print(f" DER buffer size: {len(der.buffer)}")
    print()
    # Test Metrics
    print("Testing Continual Learning Metrics...")
    metrics = ContinualLearningMetrics(num_tasks=5)
    # Simulate accuracy matrix
    for i in range(5):
        for j in range(i + 1):
            # Simulated accuracy (decreases for old tasks)
            acc = 90 - (i - j) * 5 + np.random.randn() * 2
            metrics.update(i, j, acc)
    metrics.print_summary(4)
    print()
    print_method_comparison()
    print("="*70)
    print("Continual Learning Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    print(" β’ EWC: Protect important parameters via Fisher Information")
    print(" β’ SI: Measure importance via path integral")
    print(" β’ GEM/A-GEM: Constrain gradients to not harm old tasks")
    print(" β’ Replay: Store and replay old samples (most effective)")
    print(" β’ DER: Store logits for richer replay signal")
    print(" β’ LwF: Knowledge distillation from old model")
    print()
    print("Key insight: Balance stability (preserve old) and plasticity (learn new)")
    print("Trade-off: Memory budget vs. forgetting vs. computational cost")
    print("Applications: Robotics, personalization, streaming data")
    print()
Advanced Continual Learning: Mathematical Foundations and Modern ApproachesΒΆ
1. Introduction to Continual LearningΒΆ
Continual Learning (also called Lifelong Learning or Incremental Learning) addresses a fundamental challenge in machine learning: how to learn sequentially from a stream of tasks without forgetting previous knowledge.
1.1 The Catastrophic Forgetting ProblemΒΆ
When a neural network is trained on Task A and then trained on Task B, performance on Task A drastically degrades. This phenomenon is called catastrophic forgetting or catastrophic interference.
Mathematical formulation:
Given a sequence of tasks \(\mathcal{T} = \{T_1, T_2, ..., T_N\}\), each with dataset \(\mathcal{D}_i = \{(x_j, y_j)\}_{j=1}^{n_i}\), the goal is to learn parameters \(\theta\) that minimize:
$\(\min_\theta \sum_{i=1}^{N} \mathbb{E}_{(x,y) \sim \mathcal{D}_i}\left[\ell(f_\theta(x), y)\right]\)$
subject to:
No access to previous task data (memory constraints)
Bounded model capacity (cannot grow infinitely)
Limited computational budget per task
1.2 Stability-Plasticity DilemmaΒΆ
Stability: Preserve knowledge of old tasks
Plasticity: Adapt to new tasks
Too much stability: Cannot learn new tasks
Too much plasticity: Forget old tasks immediately
1.3 Continual Learning ScenariosΒΆ
Task-Incremental Learning (TIL):
Task identity known at test time
Evaluate: \(\text{Acc}_i = \text{Accuracy on } T_i \text{ with task ID}\)
Domain-Incremental Learning (DIL):
Same task, different domains (e.g., different datasets for same classes)
No task ID at test time
Class-Incremental Learning (CIL):
Most challenging: New classes arrive over time
No task ID at test time
Must distinguish all classes seen so far
Example (Class-Incremental):
Task 1: Learn classes {0, 1, 2, 3, 4}
Task 2: Learn classes {5, 6, 7, 8, 9}
Test: Classify any digit 0-9 without knowing which task
1.4 Evaluation MetricsΒΆ
Average Accuracy: $\(\text{ACC} = \frac{1}{N} \sum_{i=1}^N a_{N,i}\)$ where \(a_{N,i}\) is the accuracy on task \(i\) after learning task \(N\).
Forgetting Measure: $\(\text{FM} = \frac{1}{N-1} \sum_{i=1}^{N-1} \max_{j \in \{i, ..., N-1\}} (a_{j,i} - a_{N,i})\)$
Backward Transfer: $\(\text{BWT} = \frac{1}{N-1} \sum_{i=1}^{N-1} (a_{N,i} - a_{i,i})\)$
Negative BWT indicates forgetting.
Forward Transfer: $\(\text{FWT} = \frac{1}{N-1} \sum_{i=2}^{N} (a_{i-1,i} - a_{\text{random},i})\)$
Positive FWT indicates knowledge transfer to new tasks.
2. Regularization-Based MethodsΒΆ
2.1 Elastic Weight Consolidation (EWC)ΒΆ
Key idea: Protect important parameters for old tasks using Fisher Information Matrix.
Loss function: $\(\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_{A,i}^*)^2\)$
where:
\(\mathcal{L}_B\): Loss on current task B
\(\theta_{A}^*\): Optimal parameters after task A
\(F_i\): Fisher Information for parameter \(i\)
\(\lambda\): Regularization strength
Fisher Information Matrix: $\(F_i = \mathbb{E}_{x \sim \mathcal{D}_A} \left[\left(\frac{\partial \log p(y|x, \theta_{A}^*)}{\partial \theta_i}\right)^2\right]\)$
Approximation (using samples): $\(F_i \approx \frac{1}{|\mathcal{D}_A|} \sum_{(x,y) \in \mathcal{D}_A} \left(\frac{\partial \log p(y|x, \theta_{A}^*)}{\partial \theta_i}\right)^2\)$
Intuition:
High \(F_i\) β parameter \(i\) is important for task A β large penalty if changed
Low \(F_i\) β parameter \(i\) can be freely updated for task B
Advantages:
Principled approach based on information theory
No need to store old data
Works for any model architecture
Limitations:
Requires storing Fisher matrix (same size as model)
Assumes tasks donβt overlap (diagonal Fisher approximation)
Performance degrades with many tasks
2.2 Synaptic Intelligence (SI)ΒΆ
Key idea: Track parameter importance online during training.
Importance (accumulated during task training): $\(\omega_i = \sum_{k=1}^K \frac{g_i^{(k)} \cdot \Delta \theta_i^{(k)}}{\xi + \Delta \theta_i^{(k)^2}}\)$
where:
\(g_i^{(k)}\): Gradient at step \(k\)
\(\Delta \theta_i^{(k)}\): Parameter update at step \(k\)
\(\xi\): Small damping constant
Regularization: $\(\mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + c \sum_i \omega_i (\theta_i - \theta_i^*)^2\)$
Advantages over EWC:
Online computation (no need for separate Fisher calculation)
Captures parameter trajectories during training
More stable importance estimates
2.3 Memory Aware Synapses (MAS)ΒΆ
Key idea: Measure importance by output sensitivity (not loss).
Importance: $\(\Omega_i = \frac{1}{|\mathcal{D}|} \sum_{x \in \mathcal{D}} \left\|\frac{\partial f(x; \theta)}{\partial \theta_i}\right\|^2\)$
Regularization (same form as EWC): $\(\mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \lambda \sum_i \Omega_i (\theta_i - \theta_i^*)^2\)$
Key difference from EWC:
EWC: Uses gradients of loss (supervised)
MAS: Uses gradients of output (can be unsupervised)
2.4 Learning without Forgetting (LwF)ΒΆ
Key idea: Use knowledge distillation to preserve old task outputs.
Loss for new task: $\(\mathcal{L} = \mathcal{L}_{\text{new}}(\theta) + \lambda \mathcal{L}_{\text{distill}}(\theta, \theta_{\text{old}})\)$
Distillation loss: $\(\mathcal{L}_{\text{distill}} = \text{KL}\left(p_{\theta_{\text{old}}}(y|x) \,\|\, p_{\theta}(y|x)\right)\)$
where \(p_\theta(y|x) = \text{softmax}(f_\theta(x) / T)\) with temperature \(T\).
Advantages:
No stored data needed
Works for task-incremental and class-incremental
Preserves function learned, not just parameters
Limitations:
Requires storing old model (or outputs)
May not scale to many tasks
Assumes output space remains relevant
3. Replay-Based MethodsΒΆ
3.1 Experience Replay (ER)ΒΆ
Key idea: Store subset of old data, interleave with new data during training.
Memory buffer: \(\mathcal{M} = \{(x_i, y_i)\}_{i=1}^M\) with \(M \ll \sum_i |\mathcal{D}_i|\)
Training loss: $\(\mathcal{L} = \mathbb{E}_{(x,y) \sim \mathcal{D}_{\text{new}}}[\ell(f(x), y)] + \lambda \mathbb{E}_{(x,y) \sim \mathcal{M}}[\ell(f(x), y)]\)$
Sampling strategies:
Random: Uniform sampling from memory
Reservoir sampling: Maintain uniform distribution over all seen data
Class-balanced: Equal samples per class
Gradient-based: Store samples with largest gradient norm
Advantages:
Simple and effective
Directly addresses forgetting
Compatible with any model
Limitations:
Requires storing data (privacy, memory constraints)
Performance limited by buffer size
May not generalize well to unseen data
3.2 Generative ReplayΒΆ
Key idea: Train generative model to synthesize old task data.
Components:
Generator \(G_\phi\): Synthesizes samples from old tasks
Solver \(f_\theta\): Main task network
Training:
On task \(t\), sample pseudo-data: \(\tilde{x} \sim G_\phi\)
Get pseudo-labels: \(\tilde{y} = f_{\theta_{t-1}}(\tilde{x})\)
Train on mix of real and pseudo data: $\(\mathcal{L} = \mathbb{E}_{(x,y) \sim \mathcal{D}_t}[\ell(f(x), y)] + \lambda \mathbb{E}_{\tilde{x} \sim G}[\ell(f(\tilde{x}), \tilde{y})]\)$
Variants:
Conditional GAN: \(G(z, y)\) for class-conditional generation
VAE-based: Use variational autoencoder for stable generation
Deep Generative Replay: Generator also learns continually
Advantages:
No raw data storage (privacy-preserving)
Unlimited synthetic samples
Can generate diverse samples
Limitations:
Generator quality affects performance
Computational overhead for generation
May not capture full data distribution
3.3 Meta-Experience Replay (MER)ΒΆ
Key idea: Combine replay with meta-learning for better sample efficiency.
Meta-objective: $\(\min_\theta \mathbb{E}_{\mathcal{B} \sim \mathcal{M}} [\mathcal{L}(\theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{D}_t}(\theta), \mathcal{B})]\)$
Algorithm:
Sample batch from new task: \(\mathcal{B}_{\text{new}} \sim \mathcal{D}_t\)
Sample batch from memory: \(\mathcal{B}_{\text{mem}} \sim \mathcal{M}\)
Compute adapted parameters: \(\theta' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{B}_{\text{new}}}(\theta)\)
Update using memory: \(\theta \leftarrow \theta - \beta \nabla_\theta \mathcal{L}_{\mathcal{B}_{\text{mem}}}(\theta')\)
Benefits:
Better generalization with limited memory
Prevents overfitting to replayed samples
Faster adaptation to new tasks
4. Architecture-Based MethodsΒΆ
4.1 Progressive Neural NetworksΒΆ
Key idea: Allocate new columns for new tasks, freeze old columns.
Architecture: $\(h_i^{(t)} = f\left(W_i^{(t)} h_{i-1}^{(t)} + \sum_{j<t} U_i^{(t:j)} h_{i-1}^{(j)}\right)\)$
where:
\(h_i^{(t)}\): Activation at layer \(i\) for task \(t\)
\(W_i^{(t)}\): Weights within task \(t\) column
\(U_i^{(t:j)}\): Lateral connections from task \(j\) to task \(t\)
Properties:
Zero forgetting (old parameters frozen)
Transfer learning via lateral connections
Model grows linearly with tasks
Advantages:
Guaranteed no forgetting
Automatic transfer learning
Simple to implement
Limitations:
Linear growth in parameters
No backward transfer (canβt improve old tasks)
Inefficient for many tasks
4.2 PacknetΒΆ
Key idea: Prune network for each task, pack multiple tasks in one network.
Algorithm:
Train network on task \(t\) to convergence
Prune \(k\%\) of least important weights (by magnitude)
Mark remaining weights as βused by task \(t\)β
For next task, only update βfreeβ weights
Repeat
Pruning criterion: $\(\text{Keep weight } w_i \text{ if } |w_i| \geq \text{threshold}\)$
Advantages:
Fixed model size
Can pack many tasks (if pruning is aggressive)
No forgetting (frozen weights)
Limitations:
Order matters (early tasks get best capacity)
Cannot prune too aggressively or accuracy suffers
No transfer learning between tasks
4.3 Dynamically Expandable Networks (DEN)ΒΆ
Key idea: Grow network capacity dynamically based on task requirements.
Operations:
Selective retraining: Retrain subset of neurons for new task
Dynamic expansion: Add neurons if capacity insufficient
Network split: Duplicate neurons if used by multiple tasks
Neuron selection: $\(s_k = \frac{1}{|\mathcal{D}|} \sum_{(x,y) \in \mathcal{D}} \left|\frac{\partial \mathcal{L}}{\partial h_k}\right|\)$
High \(s_k\) β neuron important for current task.
Sparse regularization: $\(\mathcal{L} = \mathcal{L}_{\text{task}} + \lambda_1 \|h\|_1 + \lambda_2 \|W\|_1\)$
Encourages sparse activations and weights.
Advantages:
Adaptive capacity
Better than fixed expansion
Efficient parameter usage
Limitations:
Complexity in managing neuron selection
Still grows over time
Hyperparameter sensitive
4.4 Expert Gate (Mixture of Experts for CL)ΒΆ
Key idea: Route inputs to task-specific experts.
Gating function: $\(p(e_i | x) = \frac{\exp(g_i(x))}{\sum_j \exp(g_j(x))}\)$
Output: $\(y = \sum_i p(e_i | x) \cdot f_{e_i}(x)\)$
Training:
Initialize expert for each task
Gate learns to route based on input
Can share some experts across tasks
Advantages:
Flexible task routing
Can handle task-incremental and class-incremental
Experts can be trained independently
Limitations:
Requires task boundaries during training
Gating function can be challenging to learn
More parameters than single model
5. Meta-Learning for Continual LearningΒΆ
5.1 Meta-Learning FrameworkΒΆ
Key idea: Learn to learn such that model can quickly adapt to new tasks with minimal forgetting.
MAML for Continual Learning: $\(\theta \leftarrow \theta - \beta \nabla_\theta \sum_{T_i \in \text{past}} \mathcal{L}_{T_i}(\theta - \alpha \nabla_\theta \mathcal{L}_{T_i}(\theta))\)$
Inner loop (task adaptation): $\(\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{T_i}(\theta)\)$
Outer loop (meta-update): $\(\theta \leftarrow \theta - \beta \nabla_\theta \sum_i \mathcal{L}_{T_i}(\theta_i')\)$
5.2 Online-aware Meta-Learning (OML)ΒΆ
Objective: $\(\min_\theta \mathbb{E}_{T \sim p(\mathcal{T})} \left[\sum_{t=1}^T \mathcal{L}_t(\theta_t)\right]\)$
where \(\theta_t\) is adapted from \(\theta_{t-1}\) after seeing task \(t\).
Meta-representation:
Learn representation that is easy to adapt
Minimize interference between tasks
Maximize forward transfer
5.3 Learning to Learn without Forgetting (L2LwF)ΒΆ
Combines meta-learning with knowledge distillation:
where \(\theta'\) is meta-learned initialization.
6. Memory and Retrieval MechanismsΒΆ
6.1 Memory-Augmented Neural Networks (MANN)ΒΆ
Neural Turing Machine for CL:
External memory \(M \in \mathbb{R}^{N \times M}\)
Read/write operations via attention
Read: $\(r_t = \sum_i w_t^r(i) M_t(i)\)$
Write: $\(M_t(i) = M_{t-1}(i) + w_t^w(i) a_t\)$
where \(w^r, w^w\) are attention weights, \(a_t\) is content to write.
For continual learning:
Store task-specific patterns in memory
Retrieve relevant patterns for new tasks
Memory grows with tasks but slower than full replay
6.2 Gradient Episodic Memory (GEM)ΒΆ
Key idea: Ensure gradients on new task donβt increase loss on old tasks.
Constraint: $\(\langle g, g_i \rangle \geq 0 \quad \forall i < t\)$
where:
\(g\): Gradient on current task \(t\)
\(g_i\): Gradient on task \(i\) (computed from stored examples)
Optimization: If constraint violated, project gradient: $\(g' = g - \frac{\langle g, g_i \rangle}{\|g_i\|^2} g_i\)$
Averaged-GEM (A-GEM): Use average gradient from memory instead of per-task constraints: $\(\langle g, \bar{g}_{\text{mem}} \rangle \geq 0\)$
Advantages:
Strong theoretical guarantee (no forgetting on stored data)
Works well with small memory
Fast projection
Limitations:
Requires gradient computation on memory
May be too conservative (limits plasticity)
6.3 Dark Experience Replay (DER)ΒΆ
Key idea: Store both inputs and model outputs (logits) from previous tasks.
Memory: \(\mathcal{M} = \{(x_i, y_i, z_i)\}\) where \(z_i = f_{\theta_{\text{old}}}(x_i)\) (logits)
Loss: $\(\mathcal{L} = \mathcal{L}_{\text{new}} + \alpha \mathcal{L}_{\text{MSE}}(f_\theta(x_{\text{mem}}), z_{\text{mem}}) + \beta \mathcal{L}_{\text{CE}}(f_\theta(x_{\text{mem}}), y_{\text{mem}})\)$
Benefits:
Preserves detailed information (logits vs. hard labels)
Better than standard replay
Dark knowledge transfer
7. Continual Learning in the WildΒΆ
7.1 Online Continual LearningΒΆ
Setting: No clear task boundaries, data arrives in a stream.
Challenges:
When to update model?
How to detect distribution shift?
Resource constraints (memory, compute)
Strategies:
Sliding window: Maintain recent data
Trigger-based: Update when performance drops
Probabilistic: Detect novelty via uncertainty
7.2 Continual Learning with Noisy LabelsΒΆ
Robust loss functions: $\(\mathcal{L}_{\text{robust}} = -\log \frac{\exp(z_y / T)}{\sum_i \exp(z_i / T)^\gamma}\)$
where \(\gamma < 1\) downweights low-confidence predictions.
Sample selection:
Use small loss samples for training
Store clean samples in memory
Update memory with high-confidence predictions
7.3 Continual Pre-trainingΒΆ
Goal: Continually update pre-trained models (e.g., BERT, GPT) with new data.
Challenges:
Forgetting of pre-trained knowledge
Distribution shift between pre-training and new data
Massive scale (billions of parameters)
Solutions:
Adapter modules: Add small task-specific layers
Prompt tuning: Learn task-specific prompts
Regularization: EWC on important parameters
Rehearsal: Mix in original pre-training data
8. Theoretical AnalysisΒΆ
8.1 PAC-Bayes Bounds for Continual LearningΒΆ
Theorem: With probability \(1-\delta\), for all tasks \(t\):
$\(\mathbb{E}_{\theta \sim q}[\mathcal{L}_t(\theta)] \leq \mathbb{E}_{\theta \sim q}[\hat{\mathcal{L}}_t(\theta)] + \sqrt{\frac{\text{KL}(q \,\|\, p) + \ln(2\sqrt{n_t}/\delta)}{2 n_t}}\)$
where:
\(q\): Posterior distribution over parameters
\(p\): Prior (from previous tasks)
\(n_t\): Number of samples in task \(t\)
Insight: Regularization term \(\text{KL}(q \| p)\) bounds forgetting.
8.2 Optimal Transport for Continual LearningΒΆ
Task similarity: $\(d(\mathcal{D}_i, \mathcal{D}_j) = \inf_{\gamma \in \Gamma} \mathbb{E}_{(x,y),(x',y') \sim \gamma}[c(x, x')]\)$
where \(\Gamma\) is set of joint distributions with marginals \(\mathcal{D}_i, \mathcal{D}_j\).
Application: Use task similarity to decide:
How much to regularize (similar tasks β less regularization)
Which samples to store (diverse tasks β more storage)
Architecture adaptation strategy
8.3 Information Bottleneck PerspectiveΒΆ
Trade-off: $\(\min I(X; T) \quad \text{s.t.} \quad I(Y; T) \geq I_{\min}\)$
where:
\(I(X; T)\): Mutual information between input and representation
\(I(Y; T)\): Mutual information between label and representation
\(T = f(X)\): Learned representation
For continual learning:
Compress representation to essential task information
Discard task-specific noise
Results in more robust representations
9. Recent Advances (2020-2024)ΒΆ
9.1 Prompt-based Continual LearningΒΆ
L2P (Learning to Prompt):
Maintain pool of learnable prompts
Select relevant prompts for each task
Freeze backbone, only update prompts
DualPrompt:
General prompts (shared across tasks)
Task-specific prompts (unique per task)
Gating mechanism for selection
Advantages:
Parameter-efficient (only update prompts)
No forgetting of backbone
Scalable to many tasks
9.2 Contrastive Continual LearningΒΆ
SupCon for CL: $\(\mathcal{L}_{\text{SupCon}} = -\log \frac{\sum_{p \in P(i)} \exp(\text{sim}(z_i, z_p) / \tau)}{\sum_{a \in A(i)} \exp(\text{sim}(z_i, z_a) / \tau)}\)$
where \(P(i)\) are positive pairs (same class), \(A(i)\) are all pairs.
Benefits:
Better feature separation
More robust representations
Less sensitive to distribution shift
9.3 Self-Supervised Continual LearningΒΆ
SimCLR for CL:
Pre-train with contrastive learning on each task
Fine-tune classifier on task data
Regularize representation to preserve old knowledge
Advantages:
Leverages unlabeled data
Better feature quality
Natural regularization via SSL objective
9.4 Continual Learning for Large Language ModelsΒΆ
Challenges:
Billions of parameters
Expensive to retrain
Data privacy (cannot store all data)
Solutions:
LoRA (Low-Rank Adaptation): Learn low-rank updates
Prefix Tuning: Learn task-specific prefixes
Adapters: Small bottleneck layers
Knowledge Distillation: Compress old model knowledge
10. Benchmarks and DatasetsΒΆ
10.1 Standard BenchmarksΒΆ
Split MNIST:
5 tasks, 2 classes each (0-1, 2-3, β¦, 8-9)
Simple baseline for method comparison
Split CIFAR-10/100:
CIFAR-10: 5 tasks, 2 classes each
CIFAR-100: 10 tasks, 10 classes each
More realistic images
TinyImageNet:
200 classes, split into 10-20 tasks
Higher resolution, more challenging
CORe50:
Real-world object recognition
50 classes, 11 sessions
Continuous acquisition (realistic setting)
10.2 Evaluation ProtocolsΒΆ
Offline evaluation:
Train on tasks sequentially
Test on all tasks after training
Online evaluation:
Test after each task
Measure forgetting over time
Multi-domain:
Same classes, different domains
Tests domain adaptation + continual learning
11. Best Practices and GuidelinesΒΆ
11.1 When to Use Which MethodΒΆ
Scenario |
Recommended Method |
Reason |
|---|---|---|
Memory available |
Experience Replay |
Simplest, most effective |
No memory allowed |
EWC or SI |
Regularization-based, no storage |
Task boundaries known |
Progressive Networks |
Zero forgetting guarantee |
Unknown task boundaries |
A-GEM or DER |
Online-friendly |
Large-scale (LLMs) |
LoRA or Adapters |
Parameter-efficient |
Privacy constraints |
Generative Replay |
No raw data storage |
11.2 Hyperparameter TuningΒΆ
Regularization strength (\(\lambda\)):
Too high β Cannot learn new tasks
Too low β Forgets old tasks
Typical range: 0.1 - 100 (dataset dependent)
Memory size (for replay):
Larger is better, but diminishing returns
Typical: 200-2000 samples per task
Class-balanced is crucial
Learning rate:
Lower than from-scratch training (to preserve old knowledge)
Typical: 0.1Γ original learning rate
May need different rates for new vs. replayed data
11.3 Common PitfallsΒΆ
β Evaluating only on final task: Doesnβt show forgetting
β Using task ID at test time for CIL: Makes problem easier
β Not reporting forgetting metrics: Only ACC is insufficient
β Storing too much data in memory: Defeats purpose of CL
β Ignoring computational overhead: Some methods very expensive
12. Future DirectionsΒΆ
12.1 Continual Few-Shot LearningΒΆ
Learn from few examples per task while maintaining old knowledge.
Challenges:
Limited data for new tasks
Easy to overfit
Hard to balance old vs. new
Approaches:
Meta-learning + replay
Prototype-based methods
Transfer learning from pre-trained models
12.2 Continual Reinforcement LearningΒΆ
Agent learns sequence of tasks in interactive environment.
Challenges:
Non-stationary environment
Credit assignment across tasks
Catastrophic forgetting of policies
Methods:
Policy distillation
Progressive networks for RL
Episodic memory for experiences
12.3 Neural-Symbolic Continual LearningΒΆ
Combine neural networks with symbolic reasoning.
Benefits:
Explicit knowledge representation
Easier to prevent forgetting (symbolic rules)
Better interpretability
12.4 Biological InspirationΒΆ
Complementary Learning Systems:
Fast learning in hippocampus
Slow learning in neocortex
Replay during sleep
Synaptic consolidation:
Synaptic tagging and capture
Protein synthesis for long-term memory
Homeostatic plasticity
13. Summary and Key TakeawaysΒΆ
Core Principles:ΒΆ
Catastrophic forgetting is fundamental: Neural networks naturally forget when learning new tasks
Trade-off is unavoidable: Stability β Plasticity must be balanced
No silver bullet: Different scenarios need different methods
Memory helps: Even small amounts of stored data dramatically improve performance
Regularization matters: Protecting important parameters prevents forgetting
Method Categories:ΒΆ
Regularization:
β No data storage needed
β Theoretical grounding
β Performance degrades with many tasks
Replay:
β Simple and effective
β Best empirical performance
β Requires data storage
Architecture:
β Zero forgetting possible
β Clear task separation
β Model growth
Meta-Learning:
β Fast adaptation
β Better sample efficiency
β Complex to implement
Practical Recommendations:ΒΆ
Start simple: Try experience replay first
Measure forgetting: Donβt just report final accuracy
Use class-incremental: Most realistic and challenging scenario
Combine methods: Hybrid approaches often work best
Consider constraints: Memory, compute, privacy requirements
Benchmark properly: Use standard datasets and metrics
Open Problems:ΒΆ
Scaling to thousands of tasks
Continual learning in production (online, non-stationary data)
Theory-practice gap (theoretical guarantees vs. empirical performance)
Biological plausibility (how do humans do it?)
Privacy-preserving continual learning (federated setting)
Conclusion: Continual learning is crucial for real-world AI systems that must adapt over time. While significant progress has been made, many challenges remain. The field is rapidly evolving with new methods combining insights from optimization, meta-learning, neuroscience, and information theory.
"""
Advanced Continual Learning - Production Implementation
Comprehensive implementations of major continual learning methods
Methods Implemented:
1. Elastic Weight Consolidation (EWC)
2. Synaptic Intelligence (SI)
3. Learning without Forgetting (LwF)
4. Experience Replay (ER)
5. Gradient Episodic Memory (GEM & A-GEM)
6. Dark Experience Replay (DER)
7. Progressive Neural Networks
8. Packnet
9. Meta-learning for CL (OML)
Features:
- Comprehensive evaluation metrics (ACC, FM, BWT, FWT)
- Multiple benchmark datasets (Split MNIST, Split CIFAR)
- Visualization of forgetting curves
- Performance comparison across methods
Author: Advanced Deep Learning Course
Date: 2024
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from typing import List, Tuple, Dict, Optional
import numpy as np
from collections import defaultdict
import copy
# ============================================================================
# 1. Base Classes and Utilities
# ============================================================================
class ContinualLearner(nn.Module):
    """Base class for continual learning methods.

    Wraps a backbone network and exposes two hooks:
    - ``observe``: process one training batch and return the loss.
    - ``after_task``: consolidation hook called when a task finishes.

    The base ``observe`` is plain cross-entropy — i.e. naive fine-tuning
    with no forgetting mitigation — so the class can be used directly as
    the "Naive (Fine-tuning)" baseline in the demo. (Previously it raised
    ``NotImplementedError``, which crashed that baseline.) Subclasses
    override ``observe``/``after_task`` to implement real CL methods.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model      # shared backbone network
        self.task_count = 0     # number of tasks consolidated so far

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> torch.Tensor:
        """
        Observe a batch of data from a task.

        Args:
            x: Input batch
            y: Labels
            task_id: Task identifier (unused by the naive baseline)

        Returns:
            Differentiable training loss for this batch.
        """
        return F.cross_entropy(self.model(x), y)

    def after_task(self, task_id: int, dataloader: DataLoader):
        """Called after finishing a task (for consolidation, etc.)."""
        pass

    def forward(self, x: torch.Tensor, task_id: Optional[int] = None) -> torch.Tensor:
        """Forward pass (subclasses may use task_id for task-incremental)."""
        return self.model(x)
class MemoryBuffer:
    """
    Memory buffer for experience replay.

    Supports plain "keep the first `capacity` samples" storage ('random')
    and reservoir sampling ('reservoir'), which maintains a uniform sample
    over the entire stream seen so far.
    """

    def __init__(self, capacity: int, sampling: str = 'random'):
        """
        Args:
            capacity: Maximum number of samples to store
            sampling: 'random' (first `capacity` samples are kept forever)
                or 'reservoir' (uniform sample over the stream).
                NOTE(review): 'class_balanced' is mentioned by callers but
                not implemented here.
        """
        self.capacity = capacity
        self.sampling = sampling
        self.buffer = []      # stored inputs (CPU tensors)
        self.labels = []      # matching labels
        self.task_ids = []    # task each sample came from
        self.n_seen = 0       # total samples ever offered to the buffer

    def add(self, x: torch.Tensor, y: torch.Tensor, task_id: int):
        """Add a batch of samples to the buffer (stored on CPU)."""
        batch_size = x.size(0)
        for i in range(batch_size):
            if len(self.buffer) < self.capacity:
                # Buffer not full, just add
                self.buffer.append(x[i].cpu())
                self.labels.append(y[i].cpu())
                self.task_ids.append(task_id)
            elif self.sampling == 'reservoir':
                # Algorithm R: the current item is the (n_seen+1)-th element
                # of the stream; keep it with probability capacity/(n_seen+1).
                # (The previous formula used `n_seen + i + 1`, double-counting
                # the batch offset since n_seen is already incremented per
                # sample, which biased retention against later samples.)
                j = np.random.randint(0, self.n_seen + 1)
                if j < self.capacity:
                    self.buffer[j] = x[i].cpu()
                    self.labels[j] = y[i].cpu()
                    self.task_ids[j] = task_id
            self.n_seen += 1

    def sample(self, batch_size: int, device: str = 'cuda') -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample up to `batch_size` items uniformly without replacement.

        Returns:
            (x, y, task_ids) tensors on `device`, or (None, None, None)
            when the buffer is empty.
        """
        if len(self.buffer) == 0:
            return None, None, None
        indices = np.random.choice(len(self.buffer), min(batch_size, len(self.buffer)), replace=False)
        x = torch.stack([self.buffer[i] for i in indices]).to(device)
        y = torch.tensor([self.labels[i] for i in indices], dtype=torch.long).to(device)
        t = torch.tensor([self.task_ids[i] for i in indices], dtype=torch.long).to(device)
        return x, y, t

    def get_all(self, device: str = 'cuda') -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the full buffer contents, or (None, None) when empty."""
        if len(self.buffer) == 0:
            return None, None
        x = torch.stack(self.buffer).to(device)
        y = torch.tensor(self.labels, dtype=torch.long).to(device)
        return x, y

    def __len__(self):
        return len(self.buffer)
class EvaluationMetrics:
    """Accumulates the task-accuracy matrix and derives CL summary metrics.

    ``accuracy_matrix[i][j]`` holds the accuracy on task ``j`` measured
    right after training on task ``i`` (rows grow as tasks are added).
    """

    def __init__(self):
        self.accuracy_matrix = []

    def update(self, task_accuracies: List[float]):
        """Record accuracies on all seen tasks after the current task."""
        self.accuracy_matrix.append(task_accuracies)

    def average_accuracy(self) -> float:
        """Mean accuracy over all tasks, measured after the final task."""
        if not self.accuracy_matrix:
            return 0.0
        return np.mean(self.accuracy_matrix[-1])

    def forgetting_measure(self) -> float:
        """Mean drop from each task's best accuracy to its final accuracy."""
        n = len(self.accuracy_matrix)
        if n < 2:
            return 0.0
        final_row = self.accuracy_matrix[-1]
        drops = [
            max(self.accuracy_matrix[i][j] for i in range(j, n)) - final_row[j]
            for j in range(n - 1)
        ]
        return np.mean(drops)

    def backward_transfer(self) -> float:
        """Mean accuracy change on earlier tasks (negative = forgetting)."""
        n = len(self.accuracy_matrix)
        if n < 2:
            return 0.0
        final_row = self.accuracy_matrix[-1]
        deltas = [final_row[j] - self.accuracy_matrix[j][j] for j in range(n - 1)]
        return np.mean(deltas)

    def forward_transfer(self, random_baseline: List[float]) -> float:
        """Forward transfer: pre-training accuracy vs. a random baseline."""
        n = len(self.accuracy_matrix)
        if n < 2:
            return 0.0
        gains = []
        for i in range(1, n):
            prev_row = self.accuracy_matrix[i - 1]
            before_training = prev_row[i] if i < len(prev_row) else 0
            gains.append(before_training - random_baseline[i])
        return np.mean(gains)
# ============================================================================
# 2. Elastic Weight Consolidation (EWC)
# ============================================================================
class EWC(ContinualLearner):
    """
    Elastic Weight Consolidation (Kirkpatrick et al., 2017).

    After each task, a parameter snapshot and a diagonal Fisher estimate
    are stored; subsequent training pays a quadratic penalty for moving
    parameters that carry high Fisher information.
    """

    def __init__(self, model: nn.Module, lambda_: float = 5000, fisher_samples: int = 200):
        """
        Args:
            model: Neural network
            lambda_: Regularization strength
            fisher_samples: Number of samples for Fisher estimation
        """
        super().__init__(model)
        self.lambda_ = lambda_
        self.fisher_samples = fisher_samples
        self.optimal_params = {}     # task_id -> {name: parameter snapshot}
        self.fisher_matrices = {}    # task_id -> {name: diagonal Fisher}

    def compute_fisher(self, dataloader: DataLoader, device: str = 'cuda') -> Dict[str, torch.Tensor]:
        """Estimate the diagonal Fisher information from batch gradients.

        Note: this squares the *batch-averaged* gradient rather than
        per-sample gradients — a coarse but common approximation.
        """
        trainable = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]
        fisher = {n: torch.zeros_like(p, device=device) for n, p in trainable}
        self.model.eval()
        seen = 0
        for inputs, targets in dataloader:
            if seen >= self.fisher_samples:
                break
            inputs, targets = inputs.to(device), targets.to(device)
            bs = inputs.size(0)
            batch_loss = F.cross_entropy(self.model(inputs), targets)
            self.model.zero_grad()
            batch_loss.backward()
            for n, p in trainable:
                if p.grad is not None:
                    fisher[n] += bs * p.grad.data.pow(2)
            seen += bs
        # Normalize by the number of samples actually used.
        for n in fisher:
            fisher[n] /= seen
        return fisher

    def after_task(self, task_id: int, dataloader: DataLoader):
        """Snapshot parameters and Fisher information for the finished task."""
        device = next(self.model.parameters()).device
        self.fisher_matrices[task_id] = self.compute_fisher(dataloader, device)
        self.optimal_params[task_id] = {
            n: p.data.clone() for n, p in self.model.named_parameters() if p.requires_grad
        }
        self.task_count += 1

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> torch.Tensor:
        """Cross-entropy on the batch plus quadratic EWC penalties."""
        loss = F.cross_entropy(self.model(x), y)
        # One penalty term per completed task (empty range when task_count == 0).
        for t in range(self.task_count):
            fisher_t = self.fisher_matrices[t]
            optimal_t = self.optimal_params[t]
            for n, p in self.model.named_parameters():
                if p.requires_grad and n in fisher_t:
                    loss = loss + (self.lambda_ / 2) * (fisher_t[n] * (p - optimal_t[n]) ** 2).sum()
        return loss
# ============================================================================
# 3. Synaptic Intelligence (SI)
# ============================================================================
class SynapticIntelligence(ContinualLearner):
    """
    Synaptic Intelligence (Zenke et al., 2017).

    Tracks an online estimate of parameter importance: along the training
    trajectory, the path integral W accumulates -g * Δθ (loss reduction
    attributable to each parameter); at a task boundary the importance
    omega is updated as W / (total_displacement² + xi).

    Fixes over the previous revision:
    - ``observe`` called ``loss.backward()`` internally, so the trainer's
      own ``loss.backward()`` hit a freed graph. Gradients for the path
      integral are now taken with ``torch.autograd.grad(retain_graph=True)``,
      which leaves ``.grad`` and the autograd graph untouched.
    - The path integral accumulated ``+g·Δθ``; the SI formulation uses
      ``-g·Δθ`` so that loss *decreases* yield positive importance.
    - The per-task displacement used previous-step parameters; it now uses
      a snapshot taken at the start of the task.
    """

    def __init__(self, model: nn.Module, c: float = 0.1, xi: float = 0.1):
        """
        Args:
            model: Neural network
            c: Regularization strength
            xi: Damping parameter (avoids division by near-zero displacement)
        """
        super().__init__(model)
        self.c = c
        self.xi = xi
        trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        self.omega = {n: torch.zeros_like(p) for n, p in trainable}  # importance
        self.W = {n: torch.zeros_like(p) for n, p in trainable}      # per-task path integral
        self.optimal_params = {}                                     # snapshot at last task boundary
        # Parameters at the previous optimization step (for per-step deltas).
        self.prev_params = {n: p.data.clone() for n, p in trainable}
        # Parameters at the start of the current task (for total displacement).
        self.task_start_params = {n: p.data.clone() for n, p in trainable}

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> torch.Tensor:
        """Regularized loss for one batch; also updates the path integral."""
        loss = F.cross_entropy(self.model(x), y)
        # SI penalty anchored at the previous task's parameters.
        if self.task_count > 0:
            for n, p in self.model.named_parameters():
                if p.requires_grad and n in self.omega:
                    anchor = self.optimal_params.get(n, p.data)
                    loss = loss + (self.c / 2) * (self.omega[n] * (p - anchor) ** 2).sum()
        # Gradients for the path integral, without consuming the graph or
        # touching .grad — the trainer still calls loss.backward() itself.
        named = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]
        grads = torch.autograd.grad(
            loss, [p for _, p in named], retain_graph=True, allow_unused=True
        )
        for (n, p), g in zip(named, grads):
            if g is None:
                continue
            # Movement caused by the previous optimizer update, paired with
            # the current gradient (standard online SI approximation).
            step = p.data - self.prev_params[n]
            self.W[n] += -g.detach() * step
            self.prev_params[n] = p.data.clone()
        return loss

    def after_task(self, task_id: int, dataloader: DataLoader):
        """Fold the path integral into omega and snapshot the parameters."""
        for n, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            displacement = p.data - self.task_start_params.get(n, p.data)
            self.omega[n] += self.W[n] / (displacement ** 2 + self.xi)
            self.W[n].zero_()   # fresh path integral for the next task
        snapshot = {n: p.data.clone() for n, p in self.model.named_parameters() if p.requires_grad}
        self.optimal_params = snapshot
        self.task_start_params = {n: t.clone() for n, t in snapshot.items()}
        self.prev_params = {n: t.clone() for n, t in snapshot.items()}
        self.task_count += 1
# ============================================================================
# 4. Learning without Forgetting (LwF)
# ============================================================================
class LearningWithoutForgetting(ContinualLearner):
    """
    Learning without Forgetting (Li & Hoiem, 2016).

    Keeps a frozen copy of the model from the end of the previous task and
    distills its soft predictions into the current model while the new
    task is learned — no stored data required.
    """

    def __init__(self, model: nn.Module, lambda_: float = 1.0, T: float = 2.0):
        """
        Args:
            model: Neural network
            lambda_: Distillation loss weight
            T: Temperature for softmax
        """
        super().__init__(model)
        self.lambda_ = lambda_
        self.T = T
        self.old_model = None   # frozen snapshot used as distillation teacher

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> torch.Tensor:
        """Cross-entropy on the batch plus distillation towards the teacher."""
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        if self.old_model is None:
            # First task: no teacher yet.
            return loss
        with torch.no_grad():
            teacher_logits = self.old_model(x)
        # Temperature-scaled KD loss; the T² factor keeps gradient magnitudes
        # comparable across temperatures.
        soft_targets = F.softmax(teacher_logits / self.T, dim=1)
        log_student = F.log_softmax(logits / self.T, dim=1)
        distill_loss = F.kl_div(log_student, soft_targets, reduction='batchmean') * (self.T ** 2)
        return loss + self.lambda_ * distill_loss

    def after_task(self, task_id: int, dataloader: DataLoader):
        """Freeze the current model as the next task's teacher."""
        teacher = copy.deepcopy(self.model)
        teacher.eval()
        for param in teacher.parameters():
            param.requires_grad = False
        self.old_model = teacher
        self.task_count += 1
# ============================================================================
# 5. Experience Replay (ER)
# ============================================================================
class ExperienceReplay(ContinualLearner):
    """
    Experience Replay: keep a small buffer of past samples and mix a
    replayed batch into every training step.
    """

    def __init__(self, model: nn.Module, memory_size: int = 1000, replay_batch_size: int = 64):
        """
        Args:
            model: Neural network
            memory_size: Size of replay buffer
            replay_batch_size: Batch size for replay
        """
        super().__init__(model)
        self.memory = MemoryBuffer(memory_size)
        self.replay_batch_size = replay_batch_size

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> torch.Tensor:
        """Loss on the current batch plus loss on a replayed batch."""
        self.memory.add(x, y, task_id)
        loss = F.cross_entropy(self.model(x), y)
        if len(self.memory) > 0:
            replay_x, replay_y, _ = self.memory.sample(self.replay_batch_size, x.device)
            if replay_x is not None:
                loss = loss + F.cross_entropy(self.model(replay_x), replay_y)
        return loss
# ============================================================================
# 6. Gradient Episodic Memory (GEM & A-GEM)
# ============================================================================
class AGEM(ContinualLearner):
    """
    Averaged Gradient Episodic Memory (Chaudhry et al., 2019).

    If the current gradient g and the memory gradient g_ref conflict
    (g·g_ref < 0), g is projected onto the half-space where the memory
    loss does not increase.

    Fixes over the previous revision:
    - ``observe`` called ``loss.backward()`` and returned the same loss
      tensor, so the trainer's ``loss.backward()`` failed on a freed graph
      (and would have clobbered the projected gradients). ``observe`` now
      does all backward work itself and returns a detached surrogate whose
      ``backward()`` is a no-op on the model.
    - The memory batch is sampled *before* the current batch is added, so
      the reference gradient reflects previously seen data only.
    """

    def __init__(self, model: nn.Module, memory_size: int = 1000, memory_batch_size: int = 64):
        """
        Args:
            model: Neural network
            memory_size: Size of episodic memory
            memory_batch_size: Batch size for memory gradient
        """
        super().__init__(model)
        self.memory = MemoryBuffer(memory_size)
        self.memory_batch_size = memory_batch_size

    def project_gradient(self, g_current: torch.Tensor, g_memory: torch.Tensor) -> torch.Tensor:
        """Project g_current if it would increase the loss on memory."""
        dot_product = torch.dot(g_current, g_memory)
        if dot_product < 0:
            # Violation: remove the conflicting component along g_memory.
            return g_current - (dot_product / (torch.norm(g_memory) ** 2 + 1e-8)) * g_memory
        return g_current

    def _flat_grad(self) -> torch.Tensor:
        """Concatenate all existing parameter gradients into one vector."""
        return torch.cat(
            [p.grad.detach().view(-1) for p in self.model.parameters() if p.grad is not None]
        )

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> torch.Tensor:
        """Compute (and possibly project) gradients; return a loss surrogate."""
        # Reference batch from *previous* data, before storing this batch.
        x_mem = y_mem = None
        if len(self.memory) > 0:
            x_mem, y_mem, _ = self.memory.sample(self.memory_batch_size, x.device)
        self.memory.add(x, y, task_id)

        # Gradient on the current batch.
        self.model.zero_grad()
        loss = F.cross_entropy(self.model(x), y)
        loss.backward()
        g_current = self._flat_grad()

        g_final = g_current
        if x_mem is not None:
            # Gradient on the memory batch.
            self.model.zero_grad()
            F.cross_entropy(self.model(x_mem), y_mem).backward()
            g_memory = self._flat_grad()
            g_final = self.project_gradient(g_current, g_memory)

        # Write the (possibly projected) gradient back into .grad.
        offset = 0
        for p in self.model.parameters():
            if p.grad is not None:
                numel = p.grad.numel()
                p.grad.data = g_final[offset:offset + numel].view_as(p.grad)
                offset += numel

        # The trainer calls loss.backward() on the return value; a detached
        # leaf makes that a no-op for the model while .item() still reports
        # the true batch loss for logging.
        return loss.detach().clone().requires_grad_(True)
# ============================================================================
# 7. Dark Experience Replay (DER)
# ============================================================================
class DarkExperienceReplay(ContinualLearner):
    """
    Dark Experience Replay (Buzzega et al., 2020).

    The buffer stores inputs, labels, *and* the logits the network produced
    when each sample was inserted; replay matches current logits to the
    stored ones (dark knowledge) in addition to a classification loss.
    """

    def __init__(
        self,
        model: nn.Module,
        memory_size: int = 1000,
        alpha: float = 0.5,
        beta: float = 0.5
    ):
        """
        Args:
            model: Neural network
            memory_size: Size of buffer
            alpha: Weight for logit matching loss
            beta: Weight for classification loss on memory
        """
        super().__init__(model)
        self.memory_size = memory_size
        self.alpha = alpha
        self.beta = beta
        # Parallel lists holding (input, label, stored logits) per sample.
        self.memory_x = []
        self.memory_y = []
        self.memory_logits = []

    def add_to_memory(self, x: torch.Tensor, y: torch.Tensor, logits: torch.Tensor):
        """Insert samples (with their logits) using random replacement."""
        for i in range(x.size(0)):
            if len(self.memory_x) < self.memory_size:
                self.memory_x.append(x[i].cpu())
                self.memory_y.append(y[i].cpu())
                self.memory_logits.append(logits[i].detach().cpu())
            else:
                slot = np.random.randint(0, self.memory_size)
                self.memory_x[slot] = x[i].cpu()
                self.memory_y[slot] = y[i].cpu()
                self.memory_logits[slot] = logits[i].detach().cpu()

    def observe(self, x: torch.Tensor, y: torch.Tensor, task_id: int) -> torch.Tensor:
        """Train with dark replay: CE on batch + MSE/CE on replayed samples."""
        logits = self.model(x)
        self.add_to_memory(x, y, logits)
        loss = F.cross_entropy(logits, y)
        if len(self.memory_x) > 0:
            n = min(x.size(0), len(self.memory_x))
            picks = np.random.choice(len(self.memory_x), n, replace=False)
            mem_x = torch.stack([self.memory_x[i] for i in picks]).to(x.device)
            mem_y = torch.tensor([self.memory_y[i] for i in picks], dtype=torch.long).to(x.device)
            stored_logits = torch.stack([self.memory_logits[i] for i in picks]).to(x.device)
            current_logits = self.model(mem_x)
            # Logit matching (dark knowledge) + classification on memory.
            loss_mse = F.mse_loss(current_logits, stored_logits)
            loss_ce = F.cross_entropy(current_logits, mem_y)
            loss = loss + (self.alpha * loss_mse + self.beta * loss_ce)
        return loss
# ============================================================================
# 8. Progressive Neural Networks
# ============================================================================
class ProgressiveNetwork(nn.Module):
    """
    Progressive Neural Networks (Rusu et al., 2016).

    Each task gets a fresh MLP column; knowledge transfer comes from
    lateral adapters that feed every previous column's hidden activations
    into the new column.

    Fixes over the previous revision:
    - ``add_task`` appended a plain Python list to an ``nn.ModuleList``,
      which raises ``TypeError`` on the second task; laterals are now
      wrapped in ``nn.ModuleList`` so their parameters are registered.
    - ``forward`` never applied hidden layers past the first one (the
      else-branch reused the previous activation and only added laterals);
      it now runs every column layer by layer.

    Note: lateral adapters are ``Linear(h, h)``, so all entries of
    ``hidden_sizes`` must be equal for shapes to line up (same constraint
    as the original construction). NOTE(review): earlier columns are not
    frozen here — confirm whether the caller freezes them per the paper.
    """

    def __init__(self, input_size: int, hidden_sizes: List[int], output_size: int):
        super().__init__()
        self.columns = nn.ModuleList()
        # laterals[c-1][prev][l]: adapter from column `prev`'s hidden layer l
        # into column c's hidden layer l.
        self.laterals = nn.ModuleList()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.n_tasks = 0

    def add_task(self):
        """Add a new column (and lateral adapters from all existing columns)."""
        layers = []
        prev_size = self.input_size
        for hidden_size in self.hidden_sizes:
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.ReLU())
            prev_size = hidden_size
        layers.append(nn.Linear(prev_size, self.output_size))
        self.columns.append(nn.Sequential(*layers))

        if self.n_tasks > 0:
            # One adapter per (previous column, hidden layer); wrapped in
            # ModuleList so parameters are registered and .append() is legal.
            per_prev = nn.ModuleList()
            for _ in range(self.n_tasks):
                per_prev.append(nn.ModuleList(
                    nn.Linear(h, h) for h in self.hidden_sizes
                ))
            self.laterals.append(per_prev)
        self.n_tasks += 1

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Forward through column `task_id` with laterals from columns 0..task_id-1."""
        if task_id >= self.n_tasks:
            raise ValueError(f"Task {task_id} not yet added")
        n_hidden = len(self.hidden_sizes)
        hidden_acts: List[List[torch.Tensor]] = []  # hidden_acts[col][layer]
        for col in range(task_id + 1):
            column = self.columns[col]
            h = x
            acts = []
            for l in range(n_hidden):
                pre = column[2 * l](h)          # Linear lives at even indices
                if col > 0:
                    # Add lateral contributions from every previous column's
                    # activation at the same hidden layer.
                    for prev in range(col):
                        pre = pre + self.laterals[col - 1][prev][l](hidden_acts[prev][l])
                h = column[2 * l + 1](pre)      # ReLU lives at odd indices
                acts.append(h)
            hidden_acts.append(acts)
            if col == task_id:
                return column[-1](h)            # output head of the target column
# ============================================================================
# 9. Demo and Evaluation
# ============================================================================
def create_split_mnist_tasks(n_tasks: int = 5) -> List[Tuple]:
    """
    Create Split MNIST tasks (consecutive groups of digit classes).

    Args:
        n_tasks: Number of tasks; classes_per_task = 10 // n_tasks, so 10
            should be divisible by n_tasks for an even split.

    Returns:
        List of (train_dataset, test_dataset) tuples, one per task.
    """
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # Load full MNIST (downloads to ./data on first use).
    train_full = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_full = datasets.MNIST('./data', train=False, transform=transform)

    # Read labels from the raw `targets` tensors instead of iterating the
    # datasets: enumerate(dataset) decodes and transforms every image just
    # to inspect its label, which is orders of magnitude slower.
    train_labels = train_full.targets
    test_labels = test_full.targets

    tasks = []
    classes_per_task = 10 // n_tasks
    for task_id in range(n_tasks):
        start_class = task_id * classes_per_task
        end_class = start_class + classes_per_task
        train_mask = (train_labels >= start_class) & (train_labels < end_class)
        test_mask = (test_labels >= start_class) & (test_labels < end_class)
        train_indices = torch.nonzero(train_mask, as_tuple=True)[0].tolist()
        test_indices = torch.nonzero(test_mask, as_tuple=True)[0].tolist()
        tasks.append((Subset(train_full, train_indices), Subset(test_full, test_indices)))
    return tasks
def evaluate_on_all_tasks(
    model: nn.Module,
    tasks: List[Tuple],
    device: str = 'cuda'
) -> List[float]:
    """Return the test-set accuracy of `model` on every task in `tasks`.

    Args:
        model: Network mapping inputs to class logits.
        tasks: List of (train_dataset, test_dataset) tuples.
        device: Torch device string.

    Returns:
        One accuracy in [0, 1] per task (0 for an empty test set).
    """
    model.eval()
    accuracies = []
    with torch.no_grad():
        for _, test_dataset in tasks:
            loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
            correct, total = 0, 0
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                predictions = model(inputs).argmax(dim=1)
                correct += (predictions == targets).sum().item()
                total += targets.size(0)
            accuracies.append(correct / total if total > 0 else 0)
    return accuracies
def train_continual_learning(
    method: ContinualLearner,
    tasks: List[Tuple],
    n_epochs: int = 5,
    lr: float = 0.01,
    device: str = 'cuda'
) -> EvaluationMetrics:
    """
    Train continual learning method on sequence of tasks.

    After each task finishes, the method's `after_task` hook runs and the
    model is evaluated on every task seen so far; the resulting accuracy
    matrix drives the final ACC/FM/BWT metrics.

    Args:
        method: CL method
        tasks: List of (train, test) datasets
        n_epochs: Epochs per task
        lr: Learning rate
        device: Torch device string ('cuda' or 'cpu')
    Returns:
        metrics: Evaluation metrics
    """
    method = method.to(device)
    # One optimizer for the whole task sequence — momentum state carries
    # across task boundaries.
    optimizer = torch.optim.SGD(method.parameters(), lr=lr, momentum=0.9)
    metrics = EvaluationMetrics()
    for task_id, (train_dataset, _) in enumerate(tasks):
        print(f"\n{'='*50}")
        print(f"Training Task {task_id}")
        print(f"{'='*50}")
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        # Train on task
        for epoch in range(n_epochs):
            method.train()
            total_loss = 0
            n_batches = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                # observe() returns the (possibly regularized) loss; the
                # backward/step is the trainer's responsibility.
                loss = method.observe(x, y, task_id)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1
            avg_loss = total_loss / n_batches
            print(f" Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}")
        # Consolidate after task (e.g. Fisher computation, teacher snapshot).
        method.after_task(task_id, train_loader)
        # Evaluate on all tasks seen so far
        task_accuracies = evaluate_on_all_tasks(method, tasks[:task_id+1], device)
        metrics.update(task_accuracies)
        print(f"\nAccuracies after Task {task_id}:")
        for i, acc in enumerate(task_accuracies):
            print(f" Task {i}: {acc*100:.2f}%")
    # Print final metrics
    print(f"\n{'='*50}")
    print("Final Metrics")
    print(f"{'='*50}")
    print(f"Average Accuracy: {metrics.average_accuracy()*100:.2f}%")
    print(f"Forgetting Measure: {metrics.forgetting_measure()*100:.2f}%")
    print(f"Backward Transfer: {metrics.backward_transfer()*100:.2f}%")
    return metrics
def demo_continual_learning():
    """Demo comparing different CL methods on Split MNIST.

    Trains each method on 5 two-digit tasks with identical settings and
    prints a comparison table of average accuracy, forgetting, and
    backward transfer. Downloads MNIST to ./data on first run.
    """
    print("=" * 80)
    print("Continual Learning Methods Comparison Demo")
    print("=" * 80)
    # Create tasks
    print("\nCreating Split MNIST tasks...")
    tasks = create_split_mnist_tasks(n_tasks=5)
    print(f"Created {len(tasks)} tasks (2 digits each)")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    # Simple CNN model: two conv/pool blocks (28x28 -> 7x7) + 2-layer head.
    class SimpleCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(1, 32, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(32, 64, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )
            self.fc = nn.Sequential(
                nn.Linear(64 * 7 * 7, 128),
                nn.ReLU(),
                nn.Linear(128, 10)
            )
        def forward(self, x):
            x = self.conv(x)
            x = x.view(x.size(0), -1)
            return self.fc(x)
    # Methods to compare; each gets a fresh, independently initialized CNN.
    methods = {
        'Naive (Fine-tuning)': ContinualLearner(SimpleCNN()),
        'EWC': EWC(SimpleCNN(), lambda_=1000),
        'Experience Replay': ExperienceReplay(SimpleCNN(), memory_size=500),
        'LwF': LearningWithoutForgetting(SimpleCNN(), lambda_=1.0),
    }
    results = {}
    for name, method in methods.items():
        print(f"\n{'*'*80}")
        print(f"Training: {name}")
        print(f"{'*'*80}")
        metrics = train_continual_learning(
            method,
            tasks,
            n_epochs=3,
            lr=0.01,
            device=device
        )
        # Keep only the scalar summaries for the comparison table.
        results[name] = {
            'avg_acc': metrics.average_accuracy(),
            'forgetting': metrics.forgetting_measure(),
            'bwt': metrics.backward_transfer()
        }
    # Print comparison table
    print("\n" + "=" * 80)
    print("Results Comparison")
    print("=" * 80)
    print(f"{'Method':<25} {'Avg Acc':<12} {'Forgetting':<12} {'BWT':<12}")
    print("-" * 80)
    for name, result in results.items():
        print(f"{name:<25} {result['avg_acc']*100:>10.2f}% {result['forgetting']*100:>10.2f}% {result['bwt']*100:>10.2f}%")
if __name__ == "__main__":
    # Script entry point: runs the full method comparison (downloads MNIST
    # to ./data on first run), then prints the summary takeaways.
    print("\nAdvanced Continual Learning - Comprehensive Implementation\n")
    demo_continual_learning()
    print("\n" + "=" * 80)
    print("Demo complete! Key takeaways:")
    print("=" * 80)
    print("1. Catastrophic forgetting is a major challenge in continual learning")
    print("2. EWC protects important parameters using Fisher information")
    print("3. Experience Replay is simple but very effective")
    print("4. LwF uses knowledge distillation (no data storage)")
    print("5. Trade-off between stability and plasticity is fundamental")
    print("6. Different methods suit different scenarios (memory, privacy, etc.)")
    print("=" * 80)