import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Select GPU when available; every model/tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Contrastive Learning FrameworkΒΆ
GoalΒΆ
Learn representations where:
Similar samples are close
Dissimilar samples are far
NT-Xent LossΒΆ
where \(\text{sim}(u, v) = u^T v / (\|u\| \|v\|)\) (cosine similarity).
Reference Materials:
deep_learning_chatgpt.pdf - Deep Learning Chatgpt
Data AugmentationΒΆ
Data augmentation is the cornerstone of contrastive self-supervised learning. Each training image is transformed by two different random augmentations (e.g., random crop, color jitter, Gaussian blur, horizontal flip) to create a positive pair. The contrastive objective then pulls these two views together in embedding space while pushing apart views from different images. The choice and strength of augmentations critically determines what invariances the model learns: augmentations should remove information that is irrelevant for downstream tasks while preserving semantically meaningful content. Too weak augmentations lead to trivial solutions; too strong augmentations destroy useful information.
class ContrastiveTransform:
    """Wrap a stochastic transform so each call yields two independent views.

    Applying the same random transform twice to one input produces a
    positive pair for the contrastive objective.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Two separate calls -> two different random augmentations of x.
        first_view = self.base_transform(x)
        second_view = self.base_transform(x)
        return [first_view, second_view]
# Strong augmentation
# SimCLR-style pipeline for 28x28 MNIST: random crop, occasional blur,
# color jitter, then tensor conversion and [-1, 1] normalization.
# NOTE(review): ColorJitter with saturation/hue assumes 3-channel input;
# MNIST is single-channel — confirm this runs without error on PIL 'L' images.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
# Load data
# Each dataset item is ([view1, view2], label); the label is ignored by the
# contrastive objective and only kept for the DataLoader tuple shape.
mnist = datasets.MNIST('./data', train=True, download=True,
                       transform=ContrastiveTransform(train_transform))
dataloader = torch.utils.data.DataLoader(mnist, batch_size=256, shuffle=True, num_workers=2)
print(f"Dataset size: {len(mnist)}")
SimCLR ImplementationΒΆ
SimCLR (Simple Framework for Contrastive Learning of Visual Representations) consists of an encoder network (e.g., ResNet) followed by a small projection head MLP. For a batch of \(N\) images, each is augmented twice to produce \(2N\) views. The NT-Xent loss (Normalized Temperature-scaled Cross-Entropy) for a positive pair \((i, j)\) is: \(\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k \neq i} \exp(\text{sim}(z_i, z_k) / \tau)}\), where \(\text{sim}\) is cosine similarity and \(\tau\) is a temperature parameter. The projection head is discarded after pre-training β representations from the encoder backbone are used for downstream tasks, a design choice that significantly improves transfer performance.
class Encoder(nn.Module):
    """CNN backbone plus projection head producing L2-normalized embeddings.

    Input is (B, 1, 28, 28); output is (B, out_dim) unit-length vectors,
    so dot products between outputs are cosine similarities.
    """

    def __init__(self, out_dim=128):
        super().__init__()
        # Backbone: two conv/pool stages (28 -> 14 -> 7 spatially), then FC.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 256),
            nn.ReLU(),
        )
        # Projection head (discarded after pre-training in the SimCLR recipe).
        self.projection = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim),
        )

    def forward(self, x):
        features = self.encoder(x)
        projected = self.projection(features)
        # Normalize so the NT-Xent loss operates on cosine similarities.
        return F.normalize(projected, dim=1)
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent (normalized temperature-scaled cross entropy) loss.

    Args:
        z_i, z_j: [batch, dim] embeddings of the two views. Expected to be
            L2-normalized already (the Encoder normalizes its output).
        temperature: Softmax temperature tau.

    Returns:
        Scalar loss averaged over all 2*batch anchors.
    """
    n = z_i.size(0)
    # Stack both views: rows 0..n-1 are view i, rows n..2n-1 are view j.
    embeddings = torch.cat([z_i, z_j], dim=0)
    logits = embeddings @ embeddings.T / temperature
    # Mask self-similarity so an anchor never counts itself as a candidate.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=embeddings.device)
    logits = logits.masked_fill(self_mask, -9e15)
    # Row k's positive is its partner view, i.e. index (k + n) mod 2n.
    targets = (torch.arange(2 * n, device=embeddings.device) + n) % (2 * n)
    return F.cross_entropy(logits, targets)
# Test
# Smoke test: one forward pass per view on random data to confirm that the
# shapes line up and the loss evaluates to a finite scalar.
model = Encoder().to(device)
x1 = torch.randn(8, 1, 28, 28).to(device)
x2 = torch.randn(8, 1, 28, 28).to(device)
z1 = model(x1)
z2 = model(x2)
loss = nt_xent_loss(z1, z2)
print(f"Loss: {loss.item():.4f}")
Train SimCLRΒΆ
SimCLR training requires large batch sizes (the original paper uses 4096) because each sample needs many negative pairs within the batch for the contrastive loss to be informative. The loss is symmetric: both views in a positive pair serve as anchor and positive, doubling the effective training signal. The temperature \(\tau\) controls the sharpness of the similarity distribution β lower values make the model focus harder on distinguishing the closest negatives. Training for hundreds of epochs on unlabeled data produces representations that, when evaluated with a simple linear classifier, approach supervised learning performance.
def train_simclr(model, dataloader, n_epochs=10, temperature=0.5):
    """Pre-train ``model`` with the SimCLR NT-Xent objective.

    Args:
        model: Encoder returning normalized embeddings for a batch of images.
        dataloader: Yields ([view1, view2], labels); labels are ignored.
        n_epochs: Number of passes over the data.
        temperature: NT-Xent temperature tau.

    Returns:
        List of per-iteration loss values (one entry per batch).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    history = []
    for epoch in range(n_epochs):
        running = 0.0
        for views, _ in dataloader:
            view_a = views[0].to(device)
            view_b = views[1].to(device)
            # Embed both augmented views and contrast them.
            loss = nt_xent_loss(model(view_a), model(view_b), temperature)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
            history.append(loss.item())
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {running/len(dataloader):.4f}")
    return history
# Pre-train a fresh encoder with SimCLR and plot the loss curve.
model = Encoder().to(device)
losses = train_simclr(model, dataloader, n_epochs=10)
plt.figure(figsize=(10, 5))
plt.plot(losses, alpha=0.5)
# 20-point moving average smooths the noisy per-iteration loss.
plt.plot(np.convolve(losses, np.ones(20)/20, mode='valid'), linewidth=2, label='Moving Avg')
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('NT-Xent Loss', fontsize=11)
plt.title('SimCLR Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
MoCo (Momentum Contrast)ΒΆ
MoCo addresses SimCLRβs reliance on large batch sizes by maintaining a momentum-updated queue of negative embeddings. The query encoder is updated by gradient descent as usual, but the key encoder is updated as a slow-moving average: \(\theta_k \leftarrow m \theta_k + (1 - m) \theta_q\) with momentum \(m \approx 0.999\). The queue stores key embeddings from recent mini-batches, providing a large and consistent pool of negatives without requiring enormous batch sizes. This design decouples the number of negatives from the batch size, making MoCo practical on standard hardware while matching or exceeding SimCLRβs representation quality.
class MoCo(nn.Module):
    """Momentum Contrast (MoCo).

    Keeps two encoders: the query encoder trained by backprop and a key
    encoder updated as a slow exponential moving average of it. A fixed-size
    queue of past key embeddings supplies negatives, decoupling the number
    of negatives from the batch size.
    """

    def __init__(self, encoder_q, encoder_k, dim=128, K=4096, m=0.999, T=0.07):
        """
        Args:
            encoder_q: Query encoder (updated by gradients).
            encoder_k: Key encoder (updated only via the momentum rule).
            dim: Embedding dimension; must match the encoders' output.
            K: Queue length (number of stored negatives).
            m: Momentum coefficient for the key-encoder EMA.
            T: Softmax temperature for the logits.
        """
        super().__init__()
        self.K = K
        self.m = m
        self.T = T
        self.encoder_q = encoder_q
        self.encoder_k = encoder_k
        # Start the key encoder as an exact copy of the query encoder and
        # freeze it: it is only ever changed by the momentum update.
        for p_q, p_k in zip(self.encoder_q.parameters(),
                            self.encoder_k.parameters()):
            p_k.data.copy_(p_q.data)
            p_k.requires_grad = False
        # Negative queue: one column per stored key, normalized at init.
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        # Write pointer into the queue (circular buffer).
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """EMA step: theta_k <- m * theta_k + (1 - m) * theta_q."""
        for p_q, p_k in zip(self.encoder_q.parameters(),
                            self.encoder_k.parameters()):
            p_k.data = p_k.data * self.m + p_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Write the newest keys at the pointer, wrapping around if needed."""
        count = keys.shape[0]
        write_at = int(self.queue_ptr)
        end = write_at + count
        if end <= self.K:
            # Contiguous write: fits before the end of the buffer.
            self.queue[:, write_at:end] = keys.T
        else:
            # Split write: fill to the end, then continue from the front.
            head = self.K - write_at
            self.queue[:, write_at:] = keys[:head].T
            self.queue[:, :end % self.K] = keys[head:].T
        self.queue_ptr[0] = end % self.K

    def forward(self, im_q, im_k):
        """Return (logits, labels); the positive logit sits at column 0."""
        q = self.encoder_q(im_q)
        # Keys carry no gradients; update the key encoder first so the keys
        # reflect the latest EMA weights.
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
        # One positive similarity per query (its own key) ...
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # ... and K negatives taken from a detached snapshot of the queue.
        queue_snapshot = self.queue.clone().detach()
        l_neg = torch.einsum('nc,ck->nk', [q, queue_snapshot])
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        # Cross-entropy target 0 selects the positive column.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        self._dequeue_and_enqueue(k)
        return logits, labels
# Test
# Smoke test: one MoCo forward pass on random data to check output shapes
# (logits: [batch, 1 + K], labels: [batch]).
encoder_q = Encoder().to(device)
encoder_k = Encoder().to(device)
moco = MoCo(encoder_q, encoder_k, K=1024).to(device)
x_q = torch.randn(8, 1, 28, 28).to(device)
x_k = torch.randn(8, 1, 28, 28).to(device)
logits, labels = moco(x_q, x_k)
print(f"Logits: {logits.shape}, Labels: {labels.shape}")
Train MoCoΒΆ
MoCo training follows a similar loop to SimCLR but replaces the in-batch negatives with the queue. At each step, the query encoder processes one augmented view, the momentum encoder processes the other, and the contrastive loss is computed against the queue of recent key embeddings. After each step, the oldest entries in the queue are dequeued and the new key embeddings are enqueued. The slow momentum update of the key encoder ensures that queue entries remain approximately consistent, avoiding the representation drift that would occur if old and new keys came from very different encoders.
def train_moco(moco, dataloader, n_epochs=10):
    """Train a MoCo model; only the query encoder receives gradients.

    Args:
        moco: MoCo module returning (logits, labels) per forward pass.
        dataloader: Yields ([view1, view2], labels); labels are ignored.
        n_epochs: Number of passes over the data.

    Returns:
        List of per-iteration loss values (one entry per batch).
    """
    # SGD on the query encoder only; the key encoder follows via momentum.
    optimizer = torch.optim.SGD(moco.encoder_q.parameters(), lr=0.03,
                                momentum=0.9, weight_decay=1e-4)
    history = []
    for epoch in range(n_epochs):
        running = 0.0
        for views, _ in dataloader:
            query_view = views[0].to(device)
            key_view = views[1].to(device)
            logits, labels = moco(query_view, key_view)
            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
            history.append(loss.item())
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {running/len(dataloader):.4f}")
    return history
# Train MoCo with a larger queue and plot the loss curve.
encoder_q = Encoder().to(device)
encoder_k = Encoder().to(device)
moco = MoCo(encoder_q, encoder_k, K=2048).to(device)
losses_moco = train_moco(moco, dataloader, n_epochs=10)
plt.figure(figsize=(10, 5))
plt.plot(losses_moco, alpha=0.5)
# 20-point moving average smooths the noisy per-iteration loss.
plt.plot(np.convolve(losses_moco, np.ones(20)/20, mode='valid'), linewidth=2, label='Moving Avg')
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('Loss', fontsize=11)
plt.title('MoCo Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Linear EvaluationΒΆ
The standard evaluation protocol for self-supervised representations is linear evaluation: freeze the pre-trained encoder and train only a linear classifier (single fully connected layer) on labeled data. High linear evaluation accuracy indicates that the learned representations are linearly separable by class, meaning the contrastive pre-training has organized the feature space in a semantically meaningful way. Comparing linear evaluation accuracy between SimCLR, MoCo, and a randomly initialized baseline quantifies how much useful structure each method has extracted from unlabeled data.
# Freeze encoder, train linear classifier
# Linear evaluation protocol: the pre-trained encoder is frozen and only a
# single linear layer is trained on top of its features.
model.eval()
for param in model.parameters():
    param.requires_grad = False
# Linear classifier
# NOTE(review): `model(x)` returns the 128-d *projection head* output; the
# standard protocol evaluates the pre-projection encoder features instead —
# confirm which is intended here.
classifier = nn.Linear(128, 10).to(device)
# Normal MNIST for evaluation
# Plain (non-contrastive) transform: single view per image, same normalization.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
mnist_eval = datasets.MNIST('./data', train=True, download=True, transform=eval_transform)
eval_loader = torch.utils.data.DataLoader(mnist_eval, batch_size=256, shuffle=True)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
for epoch in range(5):
    correct = 0
    total = 0
    for x, y in eval_loader:
        x, y = x.to(device), y.to(device)
        # Features under no_grad: the encoder stays frozen.
        with torch.no_grad():
            features = model(x)
        logits = classifier(features)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
    # NOTE(review): this is accuracy on the training split measured while the
    # classifier is still being updated; a held-out test set would be stricter.
    print(f"Epoch {epoch+1}, Accuracy: {100*correct/total:.2f}%")
SummaryΒΆ
SimCLR:ΒΆ
Large batch sizes
Strong augmentations
NT-Xent loss
Projection head
MoCo:ΒΆ
Momentum encoder \(\theta_k \leftarrow m\theta_k + (1-m)\theta_q\)
Queue for negatives
More memory efficient
Key Insights:ΒΆ
Data augmentation crucial
Large negative sets improve quality
Projection head helps
Temperature controls hardness
Applications:ΒΆ
Pre-training for downstream tasks
Transfer learning
Few-shot learning
Anomaly detection
Extensions:ΒΆ
SimCLRv2 (bigger models, distillation)
MoCov2/v3 (improvements)
BYOL (no negatives)
SwAV (clustering)
Next Steps:ΒΆ
05_sentence_transformer_intro.ipynb - Text embeddings
Study BYOL and Barlow Twins
Explore multi-modal contrastive learning (CLIP)
Advanced Contrastive Learning TheoryΒΆ
1. Introduction to Contrastive LearningΒΆ
Contrastive learning is a self-supervised learning paradigm that learns representations by contrasting positive pairs against negative pairs.
1.1 Core IdeaΒΆ
Goal: Learn embeddings where:
Similar samples (positives) are close in embedding space
Dissimilar samples (negatives) are far apart
Mathematical formulation:
Maximize similarity(f(x), f(xβΊ))
Minimize similarity(f(x), f(xβ»))
Where:
f: Encoder network
xβΊ: Positive sample (augmentation of x, or semantically similar)
xβ»: Negative samples (different from x)
1.2 Why Contrastive Learning?ΒΆ
Advantages:
Self-supervised: No labels needed (use augmentations)
Generalization: Learns task-agnostic representations
Data efficiency: Pre-training on unlabeled data, fine-tune with few labels
SOTA performance: Rivals supervised learning on many tasks
Key insight: Augmentations of the same image should have similar representations, different images should not.
2. Contrastive Loss FunctionsΒΆ
2.1 InfoNCE Loss (NT-Xent)ΒΆ
Noise Contrastive Estimation [Oord et al., 2018]:
For anchor x with positive xβΊ and N-1 negatives {xβ»α΅’}:
L = -log[exp(sim(z, zβΊ)/Ο) / Ξ£β±Ό exp(sim(z, zβ±Ό)/Ο)]
Where:
z = f(x): Embedding of anchor
zβΊ = f(xβΊ): Embedding of positive
zβ±Ό: Embeddings of all samples in batch (positives + negatives)
sim(u, v) = uα΅v / (||u|| ||v||): Cosine similarity
Ο: Temperature parameter
Temperature τ:
Low τ → Sharp distribution (high confidence)
High τ → Smooth distribution (less confident)
Typical: τ = 0.07-0.5
Batch size importance: Large batches provide more negatives β better training.
2.2 Triplet LossΒΆ
Formulation:
L = max(0, ||f(a) - f(p)||Β² - ||f(a) - f(n)||Β² + margin)
Where:
a: Anchor
p: Positive
n: Negative
margin: Minimum separation between positive and negative
Hard negative mining: Select negatives closer to anchor for harder learning.
2.3 N-Pair LossΒΆ
Extension of triplet to multiple negatives:
L = log(1 + Ξ£α΅’ exp(f(a)α΅f(nα΅’) - f(a)α΅f(p)))
Benefits: Better gradient signal from multiple negatives.
2.4 Supervised Contrastive LossΒΆ
SupCon [Khosla et al., 2020]: Use label information.
L = -Ξ£α΅’ββ^|P(i)| log[exp(zα΅’α΅z / Ο) / Ξ£βββ^{2N} 1[aβ i] exp(zα΅’α΅zβ / Ο)]
Where P(i) are all positives with same label as i.
Advantage: Multiple positives per anchor (all samples with same label).
3. SimCLR FrameworkΒΆ
SimCLR [Chen et al., 2020] is a simple framework for contrastive learning.
3.1 ArchitectureΒΆ
Pipeline:
Data augmentation: Generate two views xΜα΅’, xΜβ±Ό from x
Encoder f(Β·): CNN backbone (e.g., ResNet-50) β hα΅’ = f(xΜα΅’)
Projection head g(Β·): MLP β zα΅’ = g(hα΅’)
Contrastive loss: Apply NT-Xent on {zα΅’}
Key components:
Augmentation composition: Random crop, color jitter, Gaussian blur
Large batch size: 256-8192 (more negatives)
Projection head: 2-layer MLP (improves representation quality)
Temperature: Ο = 0.07-0.5
3.2 AugmentationsΒΆ
Strong augmentations crucial:
Random cropping + resizing (0.08-1.0 scale)
Color distortion (brightness, contrast, saturation, hue)
Gaussian blur (kernel size 10% of image)
Random horizontal flip
Composition: Apply multiple augmentations sequentially.
Finding: Stronger augmentations β better representations (up to a point).
3.3 Training AlgorithmΒΆ
For each minibatch of N samples:
1. Sample two augmentations for each: 2N samples total
2. Encode: zβ, ..., zββ = g(f(augment(xβ)), ..., g(f(augment(xββ)))
3. For each pair (i, j) from same original image:
- Positive: (zα΅’, zβ±Ό)
- Negatives: All other 2(N-1) samples
- Compute InfoNCE loss
4. Update encoder f and projection g via gradient descent
Batch size scaling: Linear scaling rule for learning rate (lr = base_lr Γ batch_size / 256).
3.4 Mathematical AnalysisΒΆ
Alignment: Pull positive pairs together:
l_align = E_p(x,xβΊ) [||f(x) - f(xβΊ)||Β²]
Uniformity: Spread features uniformly on hypersphere:
l_uniform = log E_x,y~p [exp(-||f(x) - f(y)||Β²)]
SimCLR optimizes: Balance between alignment and uniformity.
4. MoCo (Momentum Contrast)ΒΆ
MoCo [He et al., 2020] uses a memory bank and momentum encoder for efficient contrastive learning.
4.1 ArchitectureΒΆ
Two encoders:
Query encoder f_q: Updated via backprop
Key encoder f_k: Updated via momentum
Queue: Stores encoded keys (negatives) from previous batches.
4.2 Momentum UpdateΒΆ
θ_k ← m θ_k + (1-m) θ_q
Where:
ΞΈ_q: Parameters of query encoder
ΞΈ_k: Parameters of key encoder
m: Momentum coefficient (e.g., 0.999)
Benefit: Key encoder evolves slowly β consistent keys in queue.
4.3 Queue MechanismΒΆ
Idea: Decouple batch size from number of negatives.
Operation:
Encode queries: q = f_q(x_query)
Encode keys: k = f_k(x_key)
Enqueue current keys
Dequeue oldest keys
Compute contrastive loss: q vs. current k (positive) and queue (negatives)
Queue size: Typically 65,536 (much larger than batch size).
4.4 MoCo v2 and v3ΒΆ
MoCo v2 improvements:
MLP projection head (from SimCLR)
Stronger augmentations
Cosine learning rate schedule
MoCo v3 [Chen et al., 2021]:
Symmetric loss (both directions: qβk and kβq)
Vision Transformer (ViT) backbone
Prediction head (from BYOL)
Formula (MoCo v3):
L = L(qβ, kβ) + L(qβ, kβ)
Where qβ, qβ are queries from two views, kβ, kβ are keys.
5. BYOL (Bootstrap Your Own Latent)ΒΆ
BYOL [Grill et al., 2020]: No negative pairs needed!
5.1 ArchitectureΒΆ
Two networks:
Online network: Encoder f_ΞΈ + Projector g_ΞΈ + Predictor h_ΞΈ
Target network: Encoder f_ΞΎ + Projector g_ΞΎ (no predictor)
Asymmetry: Only online network has predictor.
5.2 AlgorithmΒΆ
For two augmented views xβ, xβ:
Online network:
yβ = f_ΞΈ(xβ)
zβ = g_ΞΈ(yβ)
pβ = h_ΞΈ(zβ)
Target network:
yβ = f_ΞΎ(xβ)
zβ = g_ΞΎ(yβ)
Loss (mean squared error):
L = ||pβ - sg(zβ)||Β² + ||pβ - sg(zβ)||Β²
Where sg(Β·) is stop-gradient operator.
Target network update (exponential moving average):
ξ ← τξ + (1-τ)θ
Typical: τ = 0.996-0.999.
5.3 Why No Negatives?ΒΆ
Predictor asymmetry prevents collapse:
Online network learns to predict target
Target network evolves slowly
Without predictor, would collapse to constant
Stop-gradient crucial: Prevents trivial solution.
5.4 Theoretical InsightΒΆ
BYOL optimizes:
min_ΞΈ E_x,Tβ,Tβ [||h_ΞΈ(g_ΞΈ(f_ΞΈ(Tβ(x)))) - g_ΞΎ(f_ΞΎ(Tβ(x)))||Β²]
Implicit regularization: Predictor prevents representation collapse.
6. SwAV (Swapped Assignments between Views)ΒΆ
SwAV [Caron et al., 2020]: Clustering-based contrastive learning.
6.1 Core IdeaΒΆ
Swap prediction: Predict cluster assignment of one view from another.
No negatives: Uses prototype vectors (cluster centers).
6.2 AlgorithmΒΆ
Encode two views: zβ = f(xβ), zβ = f(xβ)
Cluster assignment via Sinkhorn-Knopp:
Compute similarity to prototypes: C = {cβ, β¦, c_K}
Soft assignments: qβ = softmax(zβα΅C/Ο)
Swap prediction loss:
L = -qβα΅ log pβ - qβα΅ log pβ
Where pβ = softmax(zβα΅C/Ο)
6.3 Sinkhorn-Knopp AlgorithmΒΆ
Optimal transport: Enforce equal cluster sizes.
Iterations:
Repeat T times:
Q β Q / sum(Q, dim=0) # Normalize columns
Q β Q / sum(Q, dim=1) # Normalize rows
Ensures each prototype receives equal weight.
6.4 Multi-Crop StrategyΒΆ
Efficiency trick: Use crops of different sizes.
2 global views (224Γ224): Standard crops
V local views (96Γ96): Smaller crops
Total: 2+V views, but local views only passed through encoder (cheaper).
Loss: Predict global from local and vice versa.
7. DINO (Self-Distillation with No Labels)ΒΆ
DINO [Caron et al., 2021]: Self-supervised learning via self-distillation.
7.1 ArchitectureΒΆ
Student-teacher framework:
Student network: Updated via gradients
Teacher network: EMA of student
Both networks: Vision Transformer (ViT) or CNN.
7.2 AlgorithmΒΆ
Multi-crop strategy: Global crops (2Γ) + local crops (β₯8Γ).
Cross-entropy loss:
L = -Ξ£α΅’ P_t(xα΅’) log P_s(xα΅’)
Where:
P_t: Teacher output (softmax over prototypes, high temperature Ο_t)
P_s: Student output (softmax, low temperature Ο_s)
Centering: Prevent collapse by subtracting mean:
P_t = softmax((g_t(x) - c) / Ο_t)
Where c is EMA of batch means.
7.3 Key InnovationsΒΆ
Temperature sharpening: Ο_s < Ο_t (student sharper than teacher)
Centering: Prevent one prototype dominating
Multi-crop: Global views for teacher, global+local for student
No contrastive loss: Uses self-distillation instead
Finding: Works exceptionally well with Vision Transformers.
8. Barlow TwinsΒΆ
Barlow Twins [Zbontar et al., 2021]: Redundancy reduction objective.
8.1 ObjectiveΒΆ
Cross-correlation matrix between embeddings:
C_ij = Ξ£_b z_A^b_i z_B^b_j / β(Ξ£_b (z_A^b_i)Β²) β(Ξ£_b (z_B^b_j)Β²)
Where z_A, z_B are embeddings from two views.
Loss:
L = Ξ£α΅’ (1 - C_ii)Β² + Ξ» Ξ£α΅’ Ξ£β±Όβ α΅’ C_ijΒ²
βββββββββββ ββββββββββββββ
Invariance Redundancy reduction
Goal: Diagonal of C should be 1 (invariance), off-diagonal 0 (decorrelation).
8.2 InterpretationΒΆ
Barlow's redundancy reduction principle: Remove redundant information.
Comparison to other methods:
SimCLR: Contrastive (push away negatives)
Barlow Twins: Decorrelation (remove redundancy)
No negatives needed: Only uses positive pairs.
9. VICReg (Variance-Invariance-Covariance Regularization)ΒΆ
VICReg [Bardes et al., 2022]: Explicit regularization to prevent collapse.
9.1 Loss ComponentsΒΆ
Total loss:
L = Ξ» L_invariance + ΞΌ L_variance + Ξ½ L_covariance
1. Invariance: Pull positive pairs together:
L_invariance = E[||z_A - z_B||Β²]
2. Variance: Prevent collapse to constant:
L_variance = E[max(0, Ξ³ - βVar(z_d))]
Where z_d is dimension d, Ξ³ is target variance (e.g., 1).
3. Covariance: Decorrelate dimensions:
L_covariance = Ξ£_{dβ d'} [Cov(z_d, z_d')]Β²
9.2 AdvantagesΒΆ
No batch normalization needed: Variance term prevents collapse.
No asymmetric networks: Both views treated equally.
Stable training: Explicit constraints easier to optimize.
10. Theoretical AnalysisΒΆ
10.1 Contrastive Loss as Metric LearningΒΆ
InfoNCE approximates mutual information:
I(X; Y) β₯ log(K) - L_InfoNCE
Where K is number of negatives.
Insight: Minimizing contrastive loss β maximizing mutual information between views.
10.2 Uniformity-Alignment FrameworkΒΆ
Wang & Isola, 2020: Decompose contrastive learning.
Alignment (positive pairs):
L_align = E[(f(x) - f(xβΊ))Β²]
Uniformity (distribution on hypersphere):
L_uniform = log E[e^{-||f(x)-f(y)||Β²}]
Good representations: Low alignment loss + low uniformity loss.
10.3 Collapse and PreventionΒΆ
Dimensional collapse: Representations live in low-dimensional subspace.
Complete collapse: All representations identical.
Prevention mechanisms:
Contrastive: Negative pairs push apart
Batch normalization: Normalizes per dimension
Prediction head (BYOL): Asymmetry
Variance regularization (VICReg): Explicit constraint
Decorrelation (Barlow Twins): Cross-correlation penalty
10.4 Sample ComplexityΒΆ
Theorem [Arora et al., 2019]: For SimCLR, sample complexity is:
O(k/Ρ² · log(d))
Where:
k: Number of augmentations
Ξ΅: Desired accuracy
d: Embedding dimension
Implication: More augmentations β fewer samples needed.
11. Practical ConsiderationsΒΆ
11.1 HyperparametersΒΆ
Critical hyperparameters:
Parameter |
Typical Range |
Notes |
|---|---|---|
Batch size |
256-8192 |
Larger better (more negatives) |
Temperature Ο |
0.07-0.5 |
Lower β harder negatives |
Learning rate |
0.3-1.0 |
Linear scaling with batch size |
Projection dim |
128-2048 |
128 often sufficient |
Hidden dim |
2048-4096 |
MLP projection head |
Epochs |
100-1000 |
Longer training helps |
11.2 AugmentationsΒΆ
Effectiveness ranking (SimCLR):
Random crop + resize (most important)
Color jitter
Gaussian blur
Random flip
Grayscale
Composition matters: Multiple augmentations crucial.
11.3 Backbone ArchitectureΒΆ
CNNs: ResNet-50, ResNet-101
Standard for ImageNet
Well-understood
Vision Transformers (ViT):
Better with large datasets
DINO, MoCo v3 show strong results
Requires careful tuning
11.4 Training StabilityΒΆ
Common issues:
Collapse: All representations identical
Solution: Check variance, adjust batch norm, use regularization
Slow convergence: Insufficient batch size or learning rate
Solution: Increase batch size, tune LR
Gradient explosion: Large temperature or learning rate
Solution: Gradient clipping, lower Ο or LR
Monitoring: Track alignment and uniformity metrics during training.
12. Downstream EvaluationΒΆ
12.1 Linear Evaluation ProtocolΒΆ
Standard benchmark:
Freeze pre-trained encoder f
Train linear classifier on top: W Β· f(x)
Evaluate on test set
Metrics: Top-1 and Top-5 accuracy on ImageNet.
12.2 Transfer LearningΒΆ
Fine-tuning:
Initialize with pre-trained weights
Fine-tune entire network on target task
Lower learning rate for pre-trained layers
Typical speedup: 2-10Γ fewer labels needed vs. training from scratch.
12.3 Few-Shot LearningΒΆ
k-NN evaluation: Classify test samples via nearest neighbors in embedding space.
Procedure:
Encode train set: {f(xα΅’)}
Encode test sample: f(x_test)
Find k nearest neighbors
Majority vote for prediction
Performance: Good k-NN accuracy indicates quality representations.
13. Comparison of MethodsΒΆ
13.1 Performance Summary (ImageNet Top-1 Linear Eval)ΒΆ
Method |
Backbone |
Epochs |
Batch Size |
Accuracy |
|---|---|---|---|---|
SimCLR |
ResNet-50 |
1000 |
4096 |
69.3% |
MoCo v2 |
ResNet-50 |
800 |
256 |
71.1% |
BYOL |
ResNet-50 |
1000 |
4096 |
74.3% |
SwAV |
ResNet-50 |
800 |
4096 |
75.3% |
Barlow Twins |
ResNet-50 |
1000 |
2048 |
73.2% |
VICReg |
ResNet-50 |
1000 |
2048 |
73.2% |
DINO |
ViT-B/16 |
300 |
1024 |
78.2% |
Trends:
Clustering methods (SwAV, DINO) perform well
ViT backbones outperform CNNs
Negative-free methods (BYOL, SwAV) competitive
13.2 Computational CostΒΆ
Training time (100 epochs, ImageNet, 8 GPUs):
SimCLR: ~24 hours
MoCo v2: ~18 hours (queue more efficient)
BYOL: ~20 hours
SwAV: ~15 hours (multi-crop efficient)
Memory: Large batch sizes require distributed training (gradient accumulation or multi-GPU).
14. Advanced TopicsΒΆ
14.1 Region-Level Contrastive LearningΒΆ
Dense contrastive: Learn pixel/region representations.
DenseCL [Wang et al., 2021]:
Extract dense features from feature maps
Contrastive loss on corresponding regions
Applications: Detection, segmentation.
14.2 Graph Contrastive LearningΒΆ
Graph-level: Contrast graph augmentations.
Node-level: Contrast node embeddings.
Augmentations: Node/edge dropping, subgraph sampling, feature masking.
14.3 Temporal Contrastive LearningΒΆ
Video: Contrast different clips from same video.
Time series: Contrast different windows.
Augmentations: Temporal cropping, speed variation, frame dropping.
14.4 Multimodal Contrastive LearningΒΆ
CLIP [Radford et al., 2021]: Image-text contrastive learning.
Contrastive loss:
L = -log[exp(sim(I, TβΊ)/Ο) / Ξ£β±Ό exp(sim(I, Tβ±Ό)/Ο)]
Where I is image, TβΊ is matching text, Tβ±Ό are all texts in batch.
Applications: Zero-shot classification, image retrieval, generation.
15. Recent Advances (2022-2024)ΒΆ
15.1 Masked Image Modeling (MIM)ΒΆ
MAE [He et al., 2022]: Mask patches, reconstruct pixels.
SimMIM [Xie et al., 2022]: Simpler masking strategy.
Relation to contrastive: MIM is generative, contrastive is discriminative.
15.2 Joint Embedding Predictive Architecture (JEPA)ΒΆ
I-JEPA [Assran et al., 2023]: Predict representations of masked regions.
Difference from MAE: Predict in latent space, not pixel space.
15.3 Contrastive Language-Image Pre-training at ScaleΒΆ
CLIP variants: OpenCLIP, ALIGN, BASIC
Billions of image-text pairs
Zero-shot transfer to many tasks
15.4 Self-Supervised Vision TransformersΒΆ
DINOv2 [Oquab et al., 2023]:
Scaled to 142M images
Strong zero-shot and fine-tuning performance
Dense prediction capabilities
16. Limitations and Future DirectionsΒΆ
16.1 Current LimitationsΒΆ
Computational cost: Requires large batches, long training
Augmentation sensitivity: Performance depends on augmentation choice
Domain-specific: What works for images may not for other modalities
Theoretical gaps: Why do negative-free methods work?
16.2 Open QuestionsΒΆ
Optimal augmentations: Can we learn augmentation policies?
Small batch training: How to train with limited compute?
Generalization bounds: Formal guarantees?
Multimodal learning: Unified framework for vision, language, audio?
16.3 Future DirectionsΒΆ
Efficient training: Reduce batch size/epoch requirements
Automated augmentation: Neural augmentation search
Unified frameworks: Combine contrastive + generative
Theoretical understanding: Why these methods work
17. Key TakeawaysΒΆ
Contrastive learning learns representations by pulling positives together, pushing negatives apart
InfoNCE loss is the most common objective, requires large batch sizes
SimCLR: Strong augmentations + large batches + projection head
MoCo: Memory queue decouples batch size from number of negatives
BYOL: No negatives needed, uses predictor + momentum encoder
SwAV: Clustering-based, no negatives, multi-crop efficient
DINO: Self-distillation, works well with ViT
Barlow Twins/VICReg: Redundancy reduction, explicit regularization
Augmentations crucial: Random crop most important
Linear evaluation: Standard protocol for comparing methods
Core insight: Learning invariance to data augmentations produces powerful, general-purpose representations.
ReferencesΒΆ
Chen et al. (2020) βA Simple Framework for Contrastive Learning of Visual Representations (SimCLR)β
He et al. (2020) βMomentum Contrast for Unsupervised Visual Representation Learning (MoCo)β
Grill et al. (2020) βBootstrap Your Own Latent: A New Approach to Self-Supervised Learning (BYOL)β
Caron et al. (2020) βUnsupervised Learning of Visual Features by Contrasting Cluster Assignments (SwAV)β
Caron et al. (2021) βEmerging Properties in Self-Supervised Vision Transformers (DINO)β
Zbontar et al. (2021) βBarlow Twins: Self-Supervised Learning via Redundancy Reductionβ
Bardes et al. (2022) βVICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learningβ
Wang & Isola (2020) βUnderstanding Contrastive Representation Learning through Alignment and Uniformityβ
Radford et al. (2021) βLearning Transferable Visual Models From Natural Language Supervision (CLIP)β
He et al. (2022) βMasked Autoencoders Are Scalable Vision Learners (MAE)β
"""
Complete Contrastive Learning Implementations
=============================================
Includes: SimCLR, MoCo, BYOL, Barlow Twins, VICReg, NT-Xent loss,
data augmentations, evaluation protocols.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from collections import OrderedDict
import copy
# ============================================================================
# 1. Data Augmentations
# ============================================================================
class ContrastiveAugmentation:
    """SimCLR-style augmentation: one call yields two random views.

    Pipeline: random resized crop, horizontal flip, color jitter,
    random grayscale, Gaussian blur, tensor conversion, and ImageNet
    mean/std normalization.
    """

    def __init__(self, image_size=32, s=0.5):
        """
        Args:
            image_size: Target spatial size after the random crop.
            s: Strength multiplier for the color distortion.
        """
        # Color distortion scaled by the strength parameter s.
        jitter = transforms.ColorJitter(
            brightness=0.8 * s,
            contrast=0.8 * s,
            saturation=0.8 * s,
            hue=0.2 * s,
        )
        pipeline = [
            transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        self.train_transform = transforms.Compose(pipeline)

    def __call__(self, x):
        """Return a positive pair: the same image augmented twice."""
        return self.train_transform(x), self.train_transform(x)
# ============================================================================
# 2. Projection Head
# ============================================================================
class ProjectionHead(nn.Module):
    """Two-layer MLP projection head: Linear -> BN -> ReLU -> Linear.

    Maps encoder features into the low-dimensional space where the
    contrastive loss is applied.

    Args:
        input_dim: Dimensionality of the encoder output.
        hidden_dim: Width of the hidden layer.
        output_dim: Dimensionality of the projected embedding.
    """
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        # Kept as a single nn.Sequential named `net` so state_dict keys
        # stay compatible with existing checkpoints.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project features [batch, input_dim] to [batch, output_dim]."""
        return self.net(x)
# ============================================================================
# 3. NT-Xent (InfoNCE) Loss
# ============================================================================
class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy (InfoNCE) loss.

    For a batch of N positive pairs, builds a 2N x 2N cosine-similarity
    matrix and treats each sample's augmented partner as the positive
    class among the remaining 2N - 1 candidates:

        L = -log[exp(sim(z_i, z_j)/tau) / sum_k exp(sim(z_i, z_k)/tau)]

    Args:
        temperature: Softmax temperature tau (lower = sharper).
    """
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Compute the loss for two batches of paired embeddings.

        Args:
            z_i, z_j: Embeddings of the two views, each [batch, dim].
        Returns:
            Scalar NT-Xent loss averaged over all 2 * batch samples.
        """
        n = z_i.shape[0]
        # L2-normalize so dot products become cosine similarities.
        z = torch.cat([F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)], dim=0)
        logits = (z @ z.T) / self.temperature  # [2n, 2n]
        # Exclude each sample's similarity with itself from the softmax.
        self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
        logits = logits.masked_fill(self_mask, -9e15)
        # Row i's positive partner sits at column (i + n) mod 2n.
        targets = (torch.arange(2 * n, device=z.device) + n) % (2 * n)
        # Cross-entropy over rows == -(positive logit - logsumexp), averaged.
        return F.cross_entropy(logits, targets)
# ============================================================================
# 4. SimCLR
# ============================================================================
class SimCLR(nn.Module):
    """SimCLR: contrastive learning with encoder + projection head.

    The encoder f produces representations h; a small MLP g projects
    them to z, where the NT-Xent loss pulls the two views of each image
    together. Downstream tasks use h (see ``get_representation``), not z.

    Args:
        encoder: Backbone mapping [B, 3, 32, 32] -> [B, D].
        projection_dim: Output dimension of the projection head.
        temperature: NT-Xent temperature.
    """
    def __init__(self, encoder, projection_dim=128, temperature=0.5):
        super().__init__()
        self.encoder = encoder
        # Probe the encoder once to discover its feature dimension D.
        with torch.no_grad():
            probe = torch.randn(1, 3, 32, 32)
            feat_dim = encoder(probe).shape[1]
        self.projection = ProjectionHead(
            input_dim=feat_dim,
            hidden_dim=feat_dim,
            output_dim=projection_dim,
        )
        self.criterion = NTXentLoss(temperature=temperature)

    def forward(self, x_i, x_j):
        """Return the contrastive loss for two augmented batches.

        Args:
            x_i, x_j: Two views of the same images, [batch, C, H, W].
        Returns:
            Scalar NT-Xent loss.
        """
        z_i = self.projection(self.encoder(x_i))
        z_j = self.projection(self.encoder(x_j))
        return self.criterion(z_i, z_j)

    @torch.no_grad()
    def get_representation(self, x):
        """Encoder features for downstream evaluation (no gradients)."""
        return self.encoder(x)
# ============================================================================
# 5. MoCo (Momentum Contrast)
# ============================================================================
class MoCo(nn.Module):
    """
    Momentum Contrast (MoCo).

    Maintains:
        - Query encoder (updated via backprop)
        - Key encoder (EMA copy of the query side; receives no gradients)
        - FIFO queue of past key embeddings used as negatives

    Args:
        encoder: Backbone; becomes the query encoder, and a deep copy
            becomes the key encoder.
        projection_dim: Output dimension of the projection heads.
        queue_size: Number of negative keys held in the queue.
        momentum: EMA coefficient m for the key-encoder update.
        temperature: Softmax temperature for the InfoNCE logits.

    NOTE(review): no batch-shuffling for BatchNorm is performed here —
    presumably a single-device simplification; confirm before any
    multi-GPU use.
    """
    def __init__(self, encoder, projection_dim=128, queue_size=65536,
                 momentum=0.999, temperature=0.07):
        super(MoCo, self).__init__()
        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = temperature
        # Query encoder (the only branch that receives gradients)
        self.encoder_q = encoder
        # Get encoder output dimension via a dummy forward pass
        with torch.no_grad():
            dummy = torch.randn(1, 3, 32, 32)
            encoder_dim = encoder(dummy).shape[1]
        self.projection_q = ProjectionHead(
            input_dim=encoder_dim,
            output_dim=projection_dim
        )
        # Key encoder: deep copy of the query side, updated only by EMA
        self.encoder_k = copy.deepcopy(encoder)
        self.projection_k = copy.deepcopy(self.projection_q)
        for param in self.encoder_k.parameters():
            param.requires_grad = False
        for param in self.projection_k.parameters():
            param.requires_grad = False
        # Queue of L2-normalized keys stored column-wise:
        # queue[:, i] is one key vector of length projection_dim.
        self.register_buffer("queue", torch.randn(projection_dim, queue_size))
        self.queue = F.normalize(self.queue, dim=0)
        # queue_ptr marks where the next batch of keys will be written.
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
    @torch.no_grad()
    def _momentum_update(self):
        """EMA update of the key branch: k <- m*k + (1-m)*q."""
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data = param_k.data * self.momentum + \
                           param_q.data * (1.0 - self.momentum)
        for param_q, param_k in zip(self.projection_q.parameters(),
                                    self.projection_k.parameters()):
            param_k.data = param_k.data * self.momentum + \
                           param_q.data * (1.0 - self.momentum)
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue entries with new keys (circular buffer).

        Args:
            keys: New key embeddings [batch, projection_dim].
        """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        # Replace oldest keys
        if ptr + batch_size <= self.queue_size:
            self.queue[:, ptr:ptr + batch_size] = keys.T
        else:
            # Wrap around the end of the circular buffer
            remaining = self.queue_size - ptr
            self.queue[:, ptr:] = keys[:remaining].T
            self.queue[:, :batch_size - remaining] = keys[remaining:].T
        # Update pointer (modulo queue size)
        ptr = (ptr + batch_size) % self.queue_size
        self.queue_ptr[0] = ptr
    def forward(self, x_q, x_k):
        """
        Args:
            x_q: Query images [batch, C, H, W]
            x_k: Key images [batch, C, H, W]
        Returns:
            loss: InfoNCE loss with the queue providing the negatives.
        """
        # Query embeddings
        q = self.projection_q(self.encoder_q(x_q))
        q = F.normalize(q, dim=1)
        # Key embeddings (no gradient; EMA update runs first)
        with torch.no_grad():
            self._momentum_update()
            k = self.projection_k(self.encoder_k(x_k))
            k = F.normalize(k, dim=1)
        # Positive logits: [batch, 1]
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits against the detached queue: [batch, queue_size]
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        # Logits: [batch, 1+queue_size]
        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
        # Labels: the positive key is always at index 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        # Cross-entropy loss
        loss = F.cross_entropy(logits, labels)
        # Enqueue the new keys only after the loss is computed
        self._dequeue_and_enqueue(k)
        return loss
# ============================================================================
# 6. BYOL (Bootstrap Your Own Latent)
# ============================================================================
class BYOL(nn.Module):
    """
    Bootstrap Your Own Latent.

    Trains an online network (encoder -> projector -> predictor) to match
    the output of a slowly moving target network (encoder -> projector)
    on a second view of the same image. No negative pairs are used; the
    predictor asymmetry plus the EMA target prevents collapse.

    Args:
        encoder: Backbone for the online branch (target is a deep copy).
        projection_dim: Output dim of projector and predictor.
        hidden_dim: Hidden width of the projector/predictor MLPs.
        momentum: EMA coefficient for the target update.
    """
    def __init__(self, encoder, projection_dim=256, hidden_dim=4096,
                 momentum=0.996):
        super().__init__()
        self.momentum = momentum
        self.online_encoder = encoder
        # Probe once for the encoder's feature width.
        with torch.no_grad():
            feat_dim = encoder(torch.randn(1, 3, 32, 32)).shape[1]
        self.online_projector = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim)
        )
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim)
        )
        # Frozen EMA copies forming the target branch.
        self.target_encoder = copy.deepcopy(encoder)
        self.target_projector = copy.deepcopy(self.online_projector)
        for p in self.target_encoder.parameters():
            p.requires_grad = False
        for p in self.target_projector.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def _update_target(self):
        """EMA step over both target modules: t <- m*t + (1-m)*o."""
        online_params = list(self.online_encoder.parameters()) + \
                        list(self.online_projector.parameters())
        target_params = list(self.target_encoder.parameters()) + \
                        list(self.target_projector.parameters())
        for p_online, p_target in zip(online_params, target_params):
            p_target.data.mul_(self.momentum).add_(
                p_online.data, alpha=1.0 - self.momentum)

    def forward(self, x1, x2):
        """Return the symmetric BYOL regression loss for two views.

        Args:
            x1, x2: Two augmented views [batch, C, H, W].
        Returns:
            Scalar loss (mean of the two cross-view regression terms).
        """
        # Online branch: encode -> project -> predict, for both views.
        p1 = self.online_predictor(self.online_projector(self.online_encoder(x1)))
        p2 = self.online_predictor(self.online_projector(self.online_encoder(x2)))
        # Target branch (stop-gradient) after the EMA update.
        with torch.no_grad():
            self._update_target()
            t1 = self.target_projector(self.target_encoder(x1))
            t2 = self.target_projector(self.target_encoder(x2))

        def regression(pred, target):
            # 2 - 2*cos(pred, target) on L2-normalized vectors, batch mean.
            pred = F.normalize(pred, dim=1)
            target = F.normalize(target, dim=1)
            return (2 - 2 * (pred * target).sum(dim=1)).mean()

        # Symmetrize: each view's prediction regresses the other's target.
        return (regression(p1, t2) + regression(p2, t1)) / 2
# ============================================================================
# 7. Barlow Twins
# ============================================================================
class BarlowTwins(nn.Module):
    """
    Barlow Twins: redundancy reduction via the cross-correlation matrix.

    Loss:
        L = sum_i (1 - C_ii)^2 + lambda * sum_i sum_{j!=i} C_ij^2

    where C is the cross-correlation of the batch-normalized embeddings
    of the two views.

    Args:
        encoder: Backbone network.
        projection_dim: Output dimension of the 3-layer projector.
        lambd: Weight of the off-diagonal (redundancy) term.
    """
    def __init__(self, encoder, projection_dim=2048, lambd=5e-3):
        super(BarlowTwins, self).__init__()
        self.lambd = lambd
        self.encoder = encoder
        with torch.no_grad():
            dummy = torch.randn(1, 3, 32, 32)
            encoder_dim = encoder(dummy).shape[1]
        # Projection head (3 layers for Barlow Twins)
        self.projector = nn.Sequential(
            nn.Linear(encoder_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, projection_dim)
        )
        # Affine-free BN standardizes each embedding dimension so that
        # z1.T @ z2 / batch is a correlation matrix.
        self.bn = nn.BatchNorm1d(projection_dim, affine=False)

    def forward(self, x1, x2):
        """
        Args:
            x1, x2: Two augmented views [batch, C, H, W].
        Returns:
            loss: Barlow Twins loss (scalar).
        """
        # Encode, project, and standardize each view.
        z1 = self.bn(self.projector(self.encoder(x1)))
        z2 = self.bn(self.projector(self.encoder(x2)))
        batch_size = z1.shape[0]
        c = (z1.T @ z2) / batch_size  # cross-correlation, [dim, dim]
        # BUG FIX: the previous in-place ops (add_/pow_) mutated `c` while
        # it was still needed: the off-diagonal term subtracted the already
        # squared (then re-squared) diagonal, giving a wrong value, and the
        # in-place mutation of tensors saved for backward breaks autograd.
        # Out-of-place ops compute the intended terms.
        diag = torch.diagonal(c)
        on_diag = (diag - 1).pow(2).sum()
        off_diag = c.pow(2).sum() - diag.pow(2).sum()
        return on_diag + self.lambd * off_diag
# ============================================================================
# 8. VICReg (Variance-Invariance-Covariance)
# ============================================================================
class VICReg(nn.Module):
    """
    VICReg: Variance-Invariance-Covariance Regularization.

    Loss = sim_coeff * invariance (MSE between views)
         + std_coeff * variance hinge (keep per-dim std >= 1)
         + cov_coeff * covariance penalty (decorrelate dimensions)

    Args:
        encoder: Backbone network.
        projection_dim: Output dimension of the 3-layer projector.
        sim_coeff: Weight of the invariance term.
        std_coeff: Weight of the variance term.
        cov_coeff: Weight of the covariance term.
    """
    def __init__(self, encoder, projection_dim=2048,
                 sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0):
        super(VICReg, self).__init__()
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff
        self.encoder = encoder
        with torch.no_grad():
            dummy = torch.randn(1, 3, 32, 32)
            encoder_dim = encoder(dummy).shape[1]
        # Projection head
        self.projector = nn.Sequential(
            nn.Linear(encoder_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, projection_dim)
        )

    @staticmethod
    def _off_diagonal_sq_sum(m):
        """Sum of squared off-diagonal entries of a square matrix."""
        return m.pow(2).sum() - torch.diagonal(m).pow(2).sum()

    def forward(self, x1, x2):
        """
        Args:
            x1, x2: Two augmented views [batch, C, H, W].
        Returns:
            loss: VICReg loss (scalar).
        """
        z1 = self.projector(self.encoder(x1))
        z2 = self.projector(self.encoder(x2))
        # Invariance: bring the two views' embeddings together.
        sim_loss = F.mse_loss(z1, z2)
        # Variance: hinge each dimension's std at 1 to prevent collapse.
        std_z1 = torch.sqrt(z1.var(dim=0) + 1e-4)
        std_z2 = torch.sqrt(z2.var(dim=0) + 1e-4)
        std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))
        # Covariance: penalize off-diagonal covariance entries.
        z1 = z1 - z1.mean(dim=0)
        z2 = z2 - z2.mean(dim=0)
        cov_z1 = (z1.T @ z1) / (z1.shape[0] - 1)
        cov_z2 = (z2.T @ z2) / (z2.shape[0] - 1)
        # BUG FIX: the previous in-place pow_() mutated the covariance
        # matrices mid-expression, so the subtracted diagonal was raised to
        # the 4th power (wrong value: sum C_ij^2 - sum C_ii^4) and autograd
        # saw modified tensors. Out-of-place ops yield the intended
        # sum of squared off-diagonal entries.
        cov_loss = self._off_diagonal_sq_sum(cov_z1) / z1.shape[1] + \
                   self._off_diagonal_sq_sum(cov_z2) / z2.shape[1]
        # Total loss
        return self.sim_coeff * sim_loss + \
               self.std_coeff * std_loss + \
               self.cov_coeff * cov_loss
# ============================================================================
# 9. Evaluation: Linear Classifier
# ============================================================================
class LinearClassifier(nn.Module):
    """Single linear layer used for the linear-probe evaluation protocol."""
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Map features [batch, input_dim] to logits [batch, num_classes]."""
        return self.fc(x)
def linear_evaluation(encoder, train_loader, test_loader, num_classes,
                      device='cuda', epochs=100):
    """
    Linear evaluation protocol: freeze the encoder and train a linear
    probe on its (fixed) features.

    Args:
        encoder: Pretrained backbone; frozen in place (requires_grad=False,
            eval mode) as a side effect.
        train_loader: Loader yielding (images, labels) for probe training.
        test_loader: Loader used for the final accuracy measurement.
        num_classes: Number of target classes.
        device: Device string for tensors and modules.
        epochs: Number of probe-training epochs.
    Returns:
        Test-set accuracy in percent.
    """
    # Freeze the encoder so only the linear probe is trained.
    encoder.eval()
    for p in encoder.parameters():
        p.requires_grad = False
    # Discover the feature width with a dummy forward pass.
    with torch.no_grad():
        probe_input = torch.randn(1, 3, 32, 32).to(device)
        feat_dim = encoder(probe_input).shape[1]
    classifier = LinearClassifier(feat_dim, num_classes).to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    # --- Train the probe on frozen features ---
    for epoch in range(epochs):
        classifier.train()
        running_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            # Features are extracted without gradients; only the probe learns.
            with torch.no_grad():
                features = encoder(images)
            loss = criterion(classifier(features), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        if (epoch + 1) % 20 == 0:
            print(f" Epoch {epoch+1}: Loss = {running_loss/len(train_loader):.4f}")
    # --- Evaluate on the test set ---
    classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            logits = classifier(encoder(images))
            preds = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    return 100.0 * correct / total
# ============================================================================
# 10. Method Comparison
# ============================================================================
def print_method_comparison():
    """Print a side-by-side comparison of contrastive/SSL methods.

    Purely informational: writes a pre-formatted reference table, loss
    summaries, typical hyperparameters, relative cost, and usage
    guidance to stdout. Has no return value and no other side effects.
    """
    print("="*70)
    print("Contrastive Learning Methods Comparison")
    print("="*70)
    print()
    # Pre-rendered text block; printed verbatim below.
    comparison = """
ββββββββββββββββ¬βββββββββββββ¬βββββββββββββ¬βββββββββββββββ¬βββββββββββββββ
β Method β Negatives? β Asymmetry? β Batch Size β Key Feature β
ββββββββββββββββΌβββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β SimCLR β Yes (many) β No β Large (4096) β Augmentationsβ
β β β β β + projection β
ββββββββββββββββΌβββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β MoCo β Yes (queue)β Momentum β Medium (256) β Memory queue β
β β β β β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β BYOL β No β Predictor β Medium β No negatives β
β β β β β + EMA β
ββββββββββββββββΌβββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β SwAV β No β Prototypes β Medium β Clustering β
β β β β β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β Barlow Twins β No β No β Medium β Decorrelationβ
β β β β β β
ββββββββββββββββΌβββββββββββββΌβββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β VICReg β No β No β Medium β Explicit β
β β β β β regulariz. β
ββββββββββββββββ΄βββββββββββββ΄βββββββββββββ΄βββββββββββββββ΄βββββββββββββββ
**Loss Functions:**
- **InfoNCE/NT-Xent**: -log[exp(sim(z,zβΊ)/Ο) / Ξ£β±Ό exp(sim(z,zβ±Ό)/Ο)]
- **BYOL**: MSE(predictor(zβ), sg(zβ))
- **Barlow Twins**: Ξ£α΅’(1-C_ii)Β² + λΣᡒΣⱼβ α΅’ C_ijΒ²
- **VICReg**: λ·MSE + μ·variance_loss + ν·covariance_loss
**Typical Hyperparameters:**
Method | Batch Size | Temperature | Epochs | LR | Aug Strength
-------------|------------|-------------|--------|---------|-------------
SimCLR | 4096 | 0.07-0.5 | 1000 | 0.3-1.0 | High
MoCo v2 | 256 | 0.2 | 800 | 0.03 | Medium
BYOL | 4096 | N/A | 1000 | 0.2 | Medium
Barlow Twins | 2048 | N/A | 1000 | 0.2 | High
VICReg | 2048 | N/A | 1000 | 0.2 | High
**Computational Cost (relative to supervised):**
- SimCLR: ~3-5Γ (large batches)
- MoCo: ~2-3Γ (queue efficient)
- BYOL: ~3-4Γ (double forward pass)
- Barlow Twins: ~2-3Γ (simple loss)
**When to Use:**
- **SimCLR**: Have large GPU memory, want simplicity
- **MoCo**: Limited memory, want efficiency
- **BYOL**: Prefer simple training (no negative sampling)
- **Barlow Twins**: Interested in decorrelation perspective
- **VICReg**: Want explicit control over collapse prevention
"""
    print(comparison)
    print()
# ============================================================================
# Run Demonstrations
# ============================================================================
if __name__ == "__main__":
    print("="*70)
    print("Contrastive Learning Implementations")
    print("="*70)
    print()

    # Minimal CNN backbone (two conv/pool stages -> 512-d linear head),
    # used only to smoke-test each self-supervised wrapper below.
    class SimpleEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Flatten(),
                nn.Linear(128 * 8 * 8, 512)
            )

        def forward(self, x):
            return self.net(x)

    backbone = SimpleEncoder()

    # --- SimCLR ---
    print("Testing SimCLR...")
    simclr = SimCLR(backbone, projection_dim=128)
    view_a = torch.randn(8, 3, 32, 32)
    view_b = torch.randn(8, 3, 32, 32)
    print(f" SimCLR loss: {simclr(view_a, view_b).item():.4f}")
    print()

    # --- MoCo (fresh encoder; small queue for the demo) ---
    print("Testing MoCo...")
    moco = MoCo(SimpleEncoder(), projection_dim=128, queue_size=256)
    print(f" MoCo loss: {moco(view_a, view_b).item():.4f}")
    print()

    # --- BYOL ---
    print("Testing BYOL...")
    byol = BYOL(SimpleEncoder(), projection_dim=256)
    print(f" BYOL loss: {byol(view_a, view_b).item():.4f}")
    print()

    # --- Barlow Twins ---
    print("Testing Barlow Twins...")
    barlow = BarlowTwins(SimpleEncoder(), projection_dim=256)
    print(f" Barlow Twins loss: {barlow(view_a, view_b).item():.4f}")
    print()

    # --- VICReg ---
    print("Testing VICReg...")
    vicreg = VICReg(SimpleEncoder(), projection_dim=256)
    print(f" VICReg loss: {vicreg(view_a, view_b).item():.4f}")
    print()

    print_method_comparison()
    print("="*70)
    print("Contrastive Learning Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    print(" β’ SimCLR: Large batches + augmentations + NT-Xent")
    print(" β’ MoCo: Memory queue + momentum encoder")
    print(" β’ BYOL: No negatives + predictor asymmetry")
    print(" β’ Barlow Twins: Redundancy reduction via decorrelation")
    print(" β’ VICReg: Explicit variance + invariance + covariance")
    print()
    print("Key insight: Learn representations via augmentation invariance")
    print("Trade-off: Batch size vs. computational cost vs. performance")
    print("Applications: Pre-training for transfer learning, few-shot")
    print()