import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Pick the fastest available device: CUDA if present, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

1. Contrastive Learning FrameworkΒΆ

GoalΒΆ

Learn representations where:

  • Similar samples are close

  • Dissimilar samples are far

NT-Xent LossΒΆ

\[\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{k \neq i} \exp(\text{sim}(z_i, z_k)/\tau)}\]

where \(\text{sim}(u, v) = u^T v / (\|u\| \|v\|)\) (cosine similarity).

πŸ“š Reference Materials:

Data AugmentationΒΆ

Data augmentation is the cornerstone of contrastive self-supervised learning. Each training image is transformed by two different random augmentations (e.g., random crop, color jitter, Gaussian blur, horizontal flip) to create a positive pair. The contrastive objective then pulls these two views together in embedding space while pushing apart views from different images. The choice and strength of augmentations critically determines what invariances the model learns: augmentations should remove information that is irrelevant for downstream tasks while preserving semantically meaningful content. Too weak augmentations lead to trivial solutions; too strong augmentations destroy useful information.

class ContrastiveTransform:
    """Wrap a transform so each call yields two independently augmented views."""

    def __init__(self, base_transform):
        # Applied twice per sample with fresh randomness each time.
        self.base_transform = base_transform

    def __call__(self, x):
        first_view = self.base_transform(x)
        second_view = self.base_transform(x)
        return [first_view, second_view]

# Strong augmentation
# Pipeline for each random view: resized crop, optional blur, optional
# color jitter, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.ToTensor(),
    # Map pixel values from [0, 1] to [-1, 1].
    transforms.Normalize([0.5], [0.5])
])

# Load data
# Each item becomes ([view1, view2], label); the label is unused during
# contrastive pre-training.
mnist = datasets.MNIST('./data', train=True, download=True, 
                       transform=ContrastiveTransform(train_transform))
dataloader = torch.utils.data.DataLoader(mnist, batch_size=256, shuffle=True, num_workers=2)

print(f"Dataset size: {len(mnist)}")

SimCLR ImplementationΒΆ

SimCLR (Simple Framework for Contrastive Learning of Visual Representations) consists of an encoder network (e.g., ResNet) followed by a small projection head MLP. For a batch of \(N\) images, each is augmented twice to produce \(2N\) views. The NT-Xent loss (Normalized Temperature-scaled Cross-Entropy) for a positive pair \((i, j)\) is: \(\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k \neq i} \exp(\text{sim}(z_i, z_k) / \tau)}\), where \(\text{sim}\) is cosine similarity and \(\tau\) is a temperature parameter. The projection head is discarded after pre-training – representations from the encoder backbone are used for downstream tasks, a design choice that significantly improves transfer performance.

class Encoder(nn.Module):
    """CNN backbone plus MLP projection head for contrastive pre-training.

    ``forward`` returns L2-normalized projection-head outputs, so dot
    products between outputs are cosine similarities.
    """

    def __init__(self, out_dim=128):
        super().__init__()
        # Backbone: two conv/pool stages (28 -> 14 -> 7 spatially),
        # flattened into a 256-d feature vector.
        backbone_layers = [
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 256), nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*backbone_layers)

        # Projection head: 2-layer MLP mapping features to the loss space.
        self.projection = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim),
        )

    def forward(self, x):
        features = self.encoder(x)
        projected = self.projection(features)
        # Unit-normalize so similarity computations are cosine similarity.
        return F.normalize(projected, dim=1)

def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent loss over two batches of L2-normalized embeddings.

    The two views are stacked into 2N rows; row i's positive is its
    counterpart N rows away, and every other row acts as a negative.
    """
    n = z_i.size(0)
    embeddings = torch.cat((z_i, z_j), dim=0)

    # Pairwise similarities scaled by temperature (cosine similarity,
    # assuming the inputs are unit-normalized).
    logits = embeddings @ embeddings.T / temperature

    # A sample must never match itself: mask the diagonal out of the softmax.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=embeddings.device)
    logits = logits.masked_fill(self_mask, -9e15)

    # Row i's positive sits n rows away, wrapping around at 2n.
    targets = (torch.arange(2 * n, device=embeddings.device) + n) % (2 * n)

    return F.cross_entropy(logits, targets)

# Sanity check: push two random "views" through a fresh encoder and
# confirm the loss evaluates to a finite scalar.
model = Encoder().to(device)
x1 = torch.randn(8, 1, 28, 28).to(device)
x2 = torch.randn(8, 1, 28, 28).to(device)
z1, z2 = model(x1), model(x2)
loss = nt_xent_loss(z1, z2)
print(f"Loss: {loss.item():.4f}")

Train SimCLRΒΆ

SimCLR training requires large batch sizes (the original paper uses 4096) because each sample needs many negative pairs within the batch for the contrastive loss to be informative. The loss is symmetric: both views in a positive pair serve as anchor and positive, doubling the effective training signal. The temperature \(\tau\) controls the sharpness of the similarity distribution – lower values make the model focus harder on distinguishing the closest negatives. Training for hundreds of epochs on unlabeled data produces representations that, when evaluated with a simple linear classifier, approach supervised learning performance.

def train_simclr(model, dataloader, n_epochs=10, temperature=0.5):
    """Pre-train an encoder with the NT-Xent objective.

    Args:
        model: Encoder producing normalized embeddings.
        dataloader: Yields ([view1, view2], label) batches; labels ignored.
        n_epochs: Number of passes over the data.
        temperature: NT-Xent temperature.

    Returns:
        Per-iteration loss history (list of floats).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    losses = []

    for epoch in range(n_epochs):
        running = 0.0
        for views, _ in dataloader:
            x_i = views[0].to(device)
            x_j = views[1].to(device)

            # Embed both views and contrast them.
            loss = nt_xent_loss(model(x_i), model(x_j), temperature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running += loss.item()
            losses.append(loss.item())

        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {running/len(dataloader):.4f}")

    return losses

# Pre-train a fresh encoder with SimCLR (labels never used).
model = Encoder().to(device)
losses = train_simclr(model, dataloader, n_epochs=10)

# Raw per-iteration loss plus a 20-step moving average for readability.
plt.figure(figsize=(10, 5))
plt.plot(losses, alpha=0.5)
plt.plot(np.convolve(losses, np.ones(20)/20, mode='valid'), linewidth=2, label='Moving Avg')
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('NT-Xent Loss', fontsize=11)
plt.title('SimCLR Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

MoCo (Momentum Contrast)ΒΆ

MoCo addresses SimCLR’s reliance on large batch sizes by maintaining a momentum-updated queue of negative embeddings. The query encoder is updated by gradient descent as usual, but the key encoder is updated as a slow-moving average: \(\theta_k \leftarrow m \theta_k + (1 - m) \theta_q\) with momentum \(m \approx 0.999\). The queue stores key embeddings from recent mini-batches, providing a large and consistent pool of negatives without requiring enormous batch sizes. This design decouples the number of negatives from the batch size, making MoCo practical on standard hardware while matching or exceeding SimCLR’s representation quality.

class MoCo(nn.Module):
    """Momentum Contrast.

    Pairs a gradient-trained query encoder with a momentum-updated key
    encoder, and keeps a fixed-size FIFO queue of past key embeddings
    that serve as negatives for the InfoNCE loss.

    Args:
        encoder_q: Query encoder (updated by backprop).
        encoder_k: Key encoder (updated only via the momentum rule).
        dim: Embedding dimension of queue entries.
        K: Queue length (number of stored negatives).
        m: Momentum coefficient for the key-encoder EMA update.
        T: Temperature dividing the logits.
    """
    
    def __init__(self, encoder_q, encoder_k, dim=128, K=4096, m=0.999, T=0.07):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T
        
        # Encoders
        self.encoder_q = encoder_q
        self.encoder_k = encoder_k
        
        # Initialize key encoder with query encoder; freeze its params so
        # only the momentum rule (never an optimizer) updates them.
        for param_q, param_k in zip(self.encoder_q.parameters(), 
                                    self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        
        # Queue: (dim, K) buffer whose columns are unit-norm key embeddings.
        # register_buffer makes it follow .to(device)/state_dict without
        # being a learnable parameter.
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
    
    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum update: theta_k <- m * theta_k + (1 - m) * theta_q."""
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
    
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Write the newest keys into the ring buffer at the current pointer.

        Overwrites the oldest entries; the else-branch handles a batch
        that straddles the end of the queue by wrapping around.
        """
        batch_size = keys.shape[0]
        
        ptr = int(self.queue_ptr)
        
        # Replace oldest
        if ptr + batch_size <= self.K:
            self.queue[:, ptr:ptr + batch_size] = keys.T
            ptr = (ptr + batch_size) % self.K
        else:
            # Wrap around
            self.queue[:, ptr:] = keys[:self.K - ptr].T
            self.queue[:, :(ptr + batch_size) % self.K] = keys[self.K - ptr:].T
            ptr = (ptr + batch_size) % self.K
        
        self.queue_ptr[0] = ptr
    
    def forward(self, im_q, im_k):
        """Return (logits, labels) for the InfoNCE loss.

        logits: (N, 1 + K) — the positive similarity in column 0, queue
        negatives in the remaining K columns, all divided by T.
        labels: all zeros, since the positive is always column 0.

        NOTE(review): the dot-product similarities assume both encoders
        L2-normalize their outputs (true for the Encoder class above —
        confirm if a different encoder is supplied).
        """
        # Query
        q = self.encoder_q(im_q)
        
        # Key: refresh the key encoder via momentum first, then encode
        # without tracking gradients.
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
        
        # Positive logits: per-sample dot product q.k -> (N, 1)
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        
        # Negative logits: q against every queued key -> (N, K)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        
        # Logits
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        
        # Labels: positives are 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        
        # Update queue
        self._dequeue_and_enqueue(k)
        
        return logits, labels

# Test
encoder_q = Encoder().to(device)
encoder_k = Encoder().to(device)
moco = MoCo(encoder_q, encoder_k, K=1024).to(device)

x_q = torch.randn(8, 1, 28, 28).to(device)
x_k = torch.randn(8, 1, 28, 28).to(device)
logits, labels = moco(x_q, x_k)
print(f"Logits: {logits.shape}, Labels: {labels.shape}")

Train MoCoΒΆ

MoCo training follows a similar loop to SimCLR but replaces the in-batch negatives with the queue. At each step, the query encoder processes one augmented view, the momentum encoder processes the other, and the contrastive loss is computed against the queue of recent key embeddings. After each step, the oldest entries in the queue are dequeued and the new key embeddings are enqueued. The slow momentum update of the key encoder ensures that queue entries remain approximately consistent, avoiding the representation drift that would occur if old and new keys came from very different encoders.

def train_moco(moco, dataloader, n_epochs=10):
    """Train a MoCo model; the key encoder follows via momentum updates.

    Args:
        moco: MoCo module yielding (logits, labels) per forward pass.
        dataloader: Yields ([view1, view2], label) batches; labels ignored.
        n_epochs: Number of passes over the data.

    Returns:
        Per-iteration loss history (list of floats).
    """
    # Only the query encoder is optimized — the original MoCo recipe.
    optimizer = torch.optim.SGD(moco.encoder_q.parameters(), lr=0.03, 
                                momentum=0.9, weight_decay=1e-4)
    losses = []

    for epoch in range(n_epochs):
        running = 0.0
        for views, _ in dataloader:
            im_q = views[0].to(device)
            im_k = views[1].to(device)

            logits, labels = moco(im_q, im_k)
            # Positive is always index 0, so this is standard cross-entropy.
            loss = F.cross_entropy(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running += loss.item()
            losses.append(loss.item())

        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {running/len(dataloader):.4f}")

    return losses

# Pre-train with MoCo: fresh query/key encoders and a 2048-entry queue.
encoder_q = Encoder().to(device)
encoder_k = Encoder().to(device)
moco = MoCo(encoder_q, encoder_k, K=2048).to(device)

losses_moco = train_moco(moco, dataloader, n_epochs=10)

# Raw per-iteration loss plus a 20-step moving average for readability.
plt.figure(figsize=(10, 5))
plt.plot(losses_moco, alpha=0.5)
plt.plot(np.convolve(losses_moco, np.ones(20)/20, mode='valid'), linewidth=2, label='Moving Avg')
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('Loss', fontsize=11)
plt.title('MoCo Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Linear EvaluationΒΆ

The standard evaluation protocol for self-supervised representations is linear evaluation: freeze the pre-trained encoder and train only a linear classifier (single fully connected layer) on labeled data. High linear evaluation accuracy indicates that the learned representations are linearly separable by class, meaning the contrastive pre-training has organized the feature space in a semantically meaningful way. Comparing linear evaluation accuracy between SimCLR, MoCo, and a randomly initialized baseline quantifies how much useful structure each method has extracted from unlabeled data.

# Freeze encoder, train linear classifier
model.eval()
for param in model.parameters():
    param.requires_grad = False

# Linear classifier
# NOTE(review): model(x) returns the 128-d *projection head* output (see
# Encoder.forward), so the probe is trained on projected features rather
# than the 256-d backbone features — consistent with nn.Linear(128, 10)
# here, but differs from the usual SimCLR protocol of discarding the head.
classifier = nn.Linear(128, 10).to(device)

# Normal MNIST for evaluation: a single deterministic view, no augmentation.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
mnist_eval = datasets.MNIST('./data', train=True, download=True, transform=eval_transform)
eval_loader = torch.utils.data.DataLoader(mnist_eval, batch_size=256, shuffle=True)

optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

for epoch in range(5):
    correct = 0
    total = 0
    
    for x, y in eval_loader:
        x, y = x.to(device), y.to(device)
        
        # Frozen encoder: extract features without tracking gradients.
        with torch.no_grad():
            features = model(x)
        
        logits = classifier(features)
        loss = F.cross_entropy(logits, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
    
    # NOTE(review): this is accuracy on the *training* split, accumulated
    # while the classifier is still updating — a held-out test set would
    # give the standard linear-evaluation number.
    print(f"Epoch {epoch+1}, Accuracy: {100*correct/total:.2f}%")

SummaryΒΆ

SimCLR:ΒΆ

  • Large batch sizes

  • Strong augmentations

  • NT-Xent loss

  • Projection head

MoCo:ΒΆ

  • Momentum encoder \(\theta_k \leftarrow m\theta_k + (1-m)\theta_q\)

  • Queue for negatives

  • More memory efficient

Key Insights:ΒΆ

  1. Data augmentation crucial

  2. Large negative sets improve quality

  3. Projection head helps

  4. Temperature controls hardness

Applications:ΒΆ

  • Pre-training for downstream tasks

  • Transfer learning

  • Few-shot learning

  • Anomaly detection

Extensions:ΒΆ

  • SimCLRv2 (bigger models, distillation)

  • MoCov2/v3 (improvements)

  • BYOL (no negatives)

  • SwAV (clustering)

Next Steps:ΒΆ

  • 05_sentence_transformer_intro.ipynb - Text embeddings

  • Study BYOL and Barlow Twins

  • Explore multi-modal contrastive learning (CLIP)

Advanced Contrastive Learning TheoryΒΆ

1. Introduction to Contrastive LearningΒΆ

Contrastive learning is a self-supervised learning paradigm that learns representations by contrasting positive pairs against negative pairs.

1.1 Core IdeaΒΆ

Goal: Learn embeddings where:

  • Similar samples (positives) are close in embedding space

  • Dissimilar samples (negatives) are far apart

Mathematical formulation:

Maximize similarity(f(x), f(x⁺))
Minimize similarity(f(x), f(x⁻))

Where:

  • f: Encoder network

  • x⁺: Positive sample (augmentation of x, or semantically similar)

  • x⁻: Negative samples (different from x)

1.2 Why Contrastive Learning?ΒΆ

Advantages:

  1. Self-supervised: No labels needed (use augmentations)

  2. Generalization: Learns task-agnostic representations

  3. Data efficiency: Pre-training on unlabeled data, fine-tune with few labels

  4. SOTA performance: Rivals supervised learning on many tasks

Key insight: Augmentations of the same image should have similar representations, different images should not.

2. Contrastive Loss FunctionsΒΆ

2.1 InfoNCE Loss (NT-Xent)ΒΆ

Noise Contrastive Estimation [Oord et al., 2018]:

For anchor x with positive x⁺ and N-1 negatives {x⁻ᡒ}:

L = -log[exp(sim(z, z⁺)/τ) / Σⱼ exp(sim(z, zⱼ)/τ)]

Where:

  • z = f(x): Embedding of anchor

  • z⁺ = f(x⁺): Embedding of positive

  • zβ±Ό: Embeddings of all samples in batch (positives + negatives)

  • sim(u, v) = uα΅€v / (||u|| ||v||): Cosine similarity

  • Ο„: Temperature parameter

Temperature Ο„:

  • Low Ο„ β†’ Sharp distribution (high confidence)

  • High Ο„ β†’ Smooth distribution (less confident)

  • Typical: Ο„ = 0.07-0.5

Batch size importance: Large batches provide more negatives β†’ better training.

2.2 Triplet LossΒΆ

Formulation:

L = max(0, ||f(a) - f(p)||Β² - ||f(a) - f(n)||Β² + margin)

Where:

  • a: Anchor

  • p: Positive

  • n: Negative

  • margin: Minimum separation between positive and negative

Hard negative mining: Select negatives closer to anchor for harder learning.

2.3 N-Pair LossΒΆ

Extension of triplet to multiple negatives:

L = log(1 + Σᵢ exp(f(a)ᵀf(nᵢ) - f(a)ᵀf(p)))

Benefits: Better gradient signal from multiple negatives.

2.4 Supervised Contrastive LossΒΆ

SupCon [Khosla et al., 2020]: Use label information.

L = -Ξ£α΅’β‚Œβ‚^|P(i)| log[exp(zα΅’α΅€z / Ο„) / Ξ£β‚β‚Œβ‚^{2N} 1[aβ‰ i] exp(zα΅’α΅€zₐ / Ο„)]

Where P(i) are all positives with same label as i.

Advantage: Multiple positives per anchor (all samples with same label).

3. SimCLR FrameworkΒΆ

SimCLR [Chen et al., 2020] is a simple framework for contrastive learning.

3.1 ArchitectureΒΆ

Pipeline:

  1. Data augmentation: Generate two views x̃ᡒ, x̃ⱼ from x

  2. Encoder f(·): CNN backbone (e.g., ResNet-50) → hᡒ = f(x̃ᡒ)

  3. Projection head g(Β·): MLP β†’ zα΅’ = g(hα΅’)

  4. Contrastive loss: Apply NT-Xent on {zα΅’}

Key components:

  • Augmentation composition: Random crop, color jitter, Gaussian blur

  • Large batch size: 256-8192 (more negatives)

  • Projection head: 2-layer MLP (improves representation quality)

  • Temperature: Ο„ = 0.07-0.5

3.2 AugmentationsΒΆ

Strong augmentations crucial:

  • Random cropping + resizing (0.08-1.0 scale)

  • Color distortion (brightness, contrast, saturation, hue)

  • Gaussian blur (kernel size 10% of image)

  • Random horizontal flip

Composition: Apply multiple augmentations sequentially.

Finding: Stronger augmentations β†’ better representations (up to a point).

3.3 Training AlgorithmΒΆ

For each minibatch of N samples:
  1. Sample two augmentations for each: 2N samples total
  2. Encode: z₁, ..., z₂ₙ = g(f(augment(x₁))), ..., g(f(augment(x₂ₙ)))
  3. For each pair (i, j) from same original image:
     - Positive: (zα΅’, zβ±Ό)
     - Negatives: All other 2(N-1) samples
     - Compute InfoNCE loss
  4. Update encoder f and projection g via gradient descent

Batch size scaling: Linear scaling rule for learning rate (lr = base_lr Γ— batch_size / 256).

3.4 Mathematical AnalysisΒΆ

Alignment: Pull positive pairs together:

l_align = E_p(x,x⁺) [||f(x) - f(x⁺)||²]

Uniformity: Spread features uniformly on hypersphere:

l_uniform = log E_x,y~p [exp(-||f(x) - f(y)||Β²)]

SimCLR optimizes: Balance between alignment and uniformity.

4. MoCo (Momentum Contrast)ΒΆ

MoCo [He et al., 2020] uses a memory bank and momentum encoder for efficient contrastive learning.

4.1 ArchitectureΒΆ

Two encoders:

  1. Query encoder f_q: Updated via backprop

  2. Key encoder f_k: Updated via momentum

Queue: Stores encoded keys (negatives) from previous batches.

4.2 Momentum UpdateΒΆ

θ_k ← m θ_k + (1-m) θ_q

Where:

  • ΞΈ_q: Parameters of query encoder

  • ΞΈ_k: Parameters of key encoder

  • m: Momentum coefficient (e.g., 0.999)

Benefit: Key encoder evolves slowly β†’ consistent keys in queue.

4.3 Queue MechanismΒΆ

Idea: Decouple batch size from number of negatives.

Operation:

  1. Encode queries: q = f_q(x_query)

  2. Encode keys: k = f_k(x_key)

  3. Enqueue current keys

  4. Dequeue oldest keys

  5. Compute contrastive loss: q vs. current k (positive) and queue (negatives)

Queue size: Typically 65,536 (much larger than batch size).

4.4 MoCo v2 and v3ΒΆ

MoCo v2 improvements:

  • MLP projection head (from SimCLR)

  • Stronger augmentations

  • Cosine learning rate schedule

MoCo v3 [Chen et al., 2021]:

  • Symmetric loss (both directions: qβ†’k and kβ†’q)

  • Vision Transformer (ViT) backbone

  • Prediction head (from BYOL)

Formula (MoCo v3):

L = L(q₁, k₂) + L(q₂, k₁)

Where q₁, q₂ are queries from two views, k₁, k₂ are keys.

5. BYOL (Bootstrap Your Own Latent)ΒΆ

BYOL [Grill et al., 2020]: No negative pairs needed!

5.1 ArchitectureΒΆ

Two networks:

  1. Online network: Encoder f_ΞΈ + Projector g_ΞΈ + Predictor h_ΞΈ

  2. Target network: Encoder f_ΞΎ + Projector g_ΞΎ (no predictor)

Asymmetry: Only online network has predictor.

5.2 AlgorithmΒΆ

For two augmented views x₁, xβ‚‚:

Online network:

y₁ = f_ΞΈ(x₁)
z₁ = g_ΞΈ(y₁)
p₁ = h_ΞΈ(z₁)

Target network:

yβ‚‚ = f_ΞΎ(xβ‚‚)
zβ‚‚ = g_ΞΎ(yβ‚‚)

Loss (mean squared error):

L = ||p₁ - sg(zβ‚‚)||Β² + ||pβ‚‚ - sg(z₁)||Β²

Where sg(Β·) is stop-gradient operator.

Target network update (exponential moving average):

ξ ← τξ + (1-τ)θ

Typical: Ο„ = 0.996-0.999.

5.3 Why No Negatives?ΒΆ

Predictor asymmetry prevents collapse:

  • Online network learns to predict target

  • Target network evolves slowly

  • Without predictor, would collapse to constant

Stop-gradient crucial: Prevents trivial solution.

5.4 Theoretical InsightΒΆ

BYOL optimizes:

min_ΞΈ E_x,T₁,Tβ‚‚ [||h_ΞΈ(g_ΞΈ(f_ΞΈ(T₁(x)))) - g_ΞΎ(f_ΞΎ(Tβ‚‚(x)))||Β²]

Implicit regularization: Predictor prevents representation collapse.

6. SwAV (Swapped Assignments between Views)ΒΆ

SwAV [Caron et al., 2020]: Clustering-based contrastive learning.

6.1 Core IdeaΒΆ

Swap prediction: Predict cluster assignment of one view from another.

No negatives: Uses prototype vectors (cluster centers).

6.2 AlgorithmΒΆ

  1. Encode two views: z₁ = f(x₁), zβ‚‚ = f(xβ‚‚)

  2. Cluster assignment via Sinkhorn-Knopp:

    • Compute similarity to prototypes: C = {c₁, …, c_K}

    • Soft assignments: q₁ = softmax(z₁ᵀC/τ)

  3. Swap prediction loss:

    L = -q₁ᵀ log p₂ - q₂ᵀ log p₁

    Where p₂ = softmax(z₂ᵀC/τ)

6.3 Sinkhorn-Knopp AlgorithmΒΆ

Optimal transport: Enforce equal cluster sizes.

Iterations:

Repeat T times:
  Q ← Q / sum(Q, dim=0)  # Normalize columns
  Q ← Q / sum(Q, dim=1)  # Normalize rows

Ensures each prototype receives equal weight.

6.4 Multi-Crop StrategyΒΆ

Efficiency trick: Use crops of different sizes.

  • 2 global views (224Γ—224): Standard crops

  • V local views (96Γ—96): Smaller crops

Total: 2+V views, but local views only passed through encoder (cheaper).

Loss: Predict global from local and vice versa.

7. DINO (Self-Distillation with No Labels)ΒΆ

DINO [Caron et al., 2021]: Self-supervised learning via self-distillation.

7.1 ArchitectureΒΆ

Student-teacher framework:

  • Student network: Updated via gradients

  • Teacher network: EMA of student

Both networks: Vision Transformer (ViT) or CNN.

7.2 AlgorithmΒΆ

Multi-crop strategy: Global crops (2Γ—) + local crops (β‰₯8Γ—).

Cross-entropy loss:

L = -Σᵢ P_t(xᵢ) log P_s(xᵢ)

Where:

  • P_t: Teacher output (softmax over prototypes, high temperature Ο„_t)

  • P_s: Student output (softmax, low temperature Ο„_s)

Centering: Prevent collapse by subtracting mean:

P_t = softmax((g_t(x) - c) / Ο„_t)

Where c is EMA of batch means.

7.3 Key InnovationsΒΆ

  1. Temperature sharpening: Ο„_s < Ο„_t (student sharper than teacher)

  2. Centering: Prevent one prototype dominating

  3. Multi-crop: Global views for teacher, global+local for student

  4. No contrastive loss: Uses self-distillation instead

Finding: Works exceptionally well with Vision Transformers.

8. Barlow TwinsΒΆ

Barlow Twins [Zbontar et al., 2021]: Redundancy reduction objective.

8.1 ObjectiveΒΆ

Cross-correlation matrix between embeddings:

C_ij = Σ_b z_A^b_i z_B^b_j / √(Σ_b (z_A^b_i)²) √(Σ_b (z_B^b_j)²)

Where z_A, z_B are embeddings from two views.

Loss:

L = Σᵢ (1 - C_ii)² + λ Σᵢ Σⱼ≠ᵢ C_ij²
    └───────────┘     └────────────┘
     Invariance    Redundancy reduction

Goal: Diagonal of C should be 1 (invariance), off-diagonal 0 (decorrelation).

8.2 InterpretationΒΆ

Barlow’s redundancy reduction principle: Remove redundant information.

Comparison to other methods:

  • SimCLR: Contrastive (push away negatives)

  • Barlow Twins: Decorrelation (remove redundancy)

No negatives needed: Only uses positive pairs.

9. VICReg (Variance-Invariance-Covariance Regularization)ΒΆ

VICReg [Bardes et al., 2022]: Explicit regularization to prevent collapse.

9.1 Loss ComponentsΒΆ

Total loss:

L = Ξ» L_invariance + ΞΌ L_variance + Ξ½ L_covariance

1. Invariance: Pull positive pairs together:

L_invariance = E[||z_A - z_B||Β²]

2. Variance: Prevent collapse to constant:

L_variance = E[max(0, γ - √Var(z_d))]

Where z_d is dimension d, Ξ³ is target variance (e.g., 1).

3. Covariance: Decorrelate dimensions:

L_covariance = Σ_{d≠d'} [Cov(z_d, z_d')]²

9.2 AdvantagesΒΆ

No batch normalization needed: Variance term prevents collapse.

No asymmetric networks: Both views treated equally.

Stable training: Explicit constraints easier to optimize.

10. Theoretical AnalysisΒΆ

10.1 Contrastive Loss as Metric LearningΒΆ

InfoNCE approximates mutual information:

I(X; Y) β‰₯ log(K) - L_InfoNCE

Where K is number of negatives.

Insight: Minimizing contrastive loss β†’ maximizing mutual information between views.

10.2 Uniformity-Alignment FrameworkΒΆ

Wang & Isola, 2020: Decompose contrastive learning.

Alignment (positive pairs):

L_align = E[(f(x) - f(x⁺))²]

Uniformity (distribution on hypersphere):

L_uniform = log E[e^{-||f(x)-f(y)||Β²}]

Good representations: Low alignment loss + low uniformity loss.

10.3 Collapse and PreventionΒΆ

Dimensional collapse: Representations live in low-dimensional subspace.

Complete collapse: All representations identical.

Prevention mechanisms:

  1. Contrastive: Negative pairs push apart

  2. Batch normalization: Normalizes per dimension

  3. Prediction head (BYOL): Asymmetry

  4. Variance regularization (VICReg): Explicit constraint

  5. Decorrelation (Barlow Twins): Cross-correlation penalty

10.4 Sample ComplexityΒΆ

Theorem [Arora et al., 2019]: For SimCLR, sample complexity is:

O(k/ε² · log(d))

Where:

  • k: Number of augmentations

  • ε: Desired accuracy

  • d: Embedding dimension

Implication: More augmentations β†’ fewer samples needed.

11. Practical ConsiderationsΒΆ

11.1 HyperparametersΒΆ

Critical hyperparameters:

Parameter

Typical Range

Notes

Batch size

256-8192

Larger better (more negatives)

Temperature Ο„

0.07-0.5

Lower β†’ harder negatives

Learning rate

0.3-1.0

Linear scaling with batch size

Projection dim

128-2048

128 often sufficient

Hidden dim

2048-4096

MLP projection head

Epochs

100-1000

Longer training helps

11.2 AugmentationsΒΆ

Effectiveness ranking (SimCLR):

  1. Random crop + resize (most important)

  2. Color jitter

  3. Gaussian blur

  4. Random flip

  5. Grayscale

Composition matters: Multiple augmentations crucial.

11.3 Backbone ArchitectureΒΆ

CNNs: ResNet-50, ResNet-101

  • Standard for ImageNet

  • Well-understood

Vision Transformers (ViT):

  • Better with large datasets

  • DINO, MoCo v3 show strong results

  • Requires careful tuning

11.4 Training StabilityΒΆ

Common issues:

  1. Collapse: All representations identical

    • Solution: Check variance, adjust batch norm, use regularization

  2. Slow convergence: Insufficient batch size or learning rate

    • Solution: Increase batch size, tune LR

  3. Gradient explosion: Large temperature or learning rate

    • Solution: Gradient clipping, lower Ο„ or LR

Monitoring: Track alignment and uniformity metrics during training.

12. Downstream EvaluationΒΆ

12.1 Linear Evaluation ProtocolΒΆ

Standard benchmark:

  1. Freeze pre-trained encoder f

  2. Train linear classifier on top: W Β· f(x)

  3. Evaluate on test set

Metrics: Top-1 and Top-5 accuracy on ImageNet.

12.2 Transfer LearningΒΆ

Fine-tuning:

  1. Initialize with pre-trained weights

  2. Fine-tune entire network on target task

  3. Lower learning rate for pre-trained layers

Typical speedup: 2-10Γ— fewer labels needed vs. training from scratch.

12.3 Few-Shot LearningΒΆ

k-NN evaluation: Classify test samples via nearest neighbors in embedding space.

Procedure:

  1. Encode train set: {f(xα΅’)}

  2. Encode test sample: f(x_test)

  3. Find k nearest neighbors

  4. Majority vote for prediction

Performance: Good k-NN accuracy indicates quality representations.

13. Comparison of MethodsΒΆ

13.1 Performance Summary (ImageNet Top-1 Linear Eval)ΒΆ

Method

Backbone

Epochs

Batch Size

Accuracy

SimCLR

ResNet-50

1000

4096

69.3%

MoCo v2

ResNet-50

800

256

71.1%

BYOL

ResNet-50

1000

4096

74.3%

SwAV

ResNet-50

800

4096

75.3%

Barlow Twins

ResNet-50

1000

2048

73.2%

VICReg

ResNet-50

1000

2048

73.2%

DINO

ViT-B/16

300

1024

78.2%

Trends:

  • Clustering methods (SwAV, DINO) perform well

  • ViT backbones outperform CNNs

  • Negative-free methods (BYOL, SwAV) competitive

13.2 Computational CostΒΆ

Training time (100 epochs, ImageNet, 8 GPUs):

  • SimCLR: ~24 hours

  • MoCo v2: ~18 hours (queue more efficient)

  • BYOL: ~20 hours

  • SwAV: ~15 hours (multi-crop efficient)

Memory: Large batch sizes require distributed training (gradient accumulation or multi-GPU).

14. Advanced TopicsΒΆ

14.1 Region-Level Contrastive LearningΒΆ

Dense contrastive: Learn pixel/region representations.

DenseCL [Wang et al., 2021]:

  • Extract dense features from feature maps

  • Contrastive loss on corresponding regions

Applications: Detection, segmentation.

14.2 Graph Contrastive LearningΒΆ

Graph-level: Contrast graph augmentations.

Node-level: Contrast node embeddings.

Augmentations: Node/edge dropping, subgraph sampling, feature masking.

14.3 Temporal Contrastive LearningΒΆ

Video: Contrast different clips from same video.

Time series: Contrast different windows.

Augmentations: Temporal cropping, speed variation, frame dropping.

14.4 Multimodal Contrastive LearningΒΆ

CLIP [Radford et al., 2021]: Image-text contrastive learning.

Contrastive loss:

L = -log[exp(sim(I, T⁺)/Ο„) / Ξ£β±Ό exp(sim(I, Tβ±Ό)/Ο„)]

Where I is image, T⁺ is matching text, Tⱼ are all texts in batch.

Applications: Zero-shot classification, image retrieval, generation.

15. Recent Advances (2022-2024)ΒΆ

15.1 Masked Image Modeling (MIM)ΒΆ

MAE [He et al., 2022]: Mask patches, reconstruct pixels.

SimMIM [Xie et al., 2022]: Simpler masking strategy.

Relation to contrastive: MIM is generative, contrastive is discriminative.

15.2 Joint Embedding Predictive Architecture (JEPA)ΒΆ

I-JEPA [Assran et al., 2023]: Predict representations of masked regions.

Difference from MAE: Predict in latent space, not pixel space.

15.3 Contrastive Language-Image Pre-training at ScaleΒΆ

CLIP variants: OpenCLIP, ALIGN, BASIC

  • Billions of image-text pairs

  • Zero-shot transfer to many tasks

15.4 Self-Supervised Vision TransformersΒΆ

DINOv2 [Oquab et al., 2023]:

  • Scaled to 142M images

  • Strong zero-shot and fine-tuning performance

  • Dense prediction capabilities

16. Limitations and Future DirectionsΒΆ

16.1 Current LimitationsΒΆ

  1. Computational cost: Requires large batches, long training

  2. Augmentation sensitivity: Performance depends on augmentation choice

  3. Domain-specific: What works for images may not for other modalities

  4. Theoretical gaps: Why do negative-free methods work?

16.2 Open QuestionsΒΆ

  1. Optimal augmentations: Can we learn augmentation policies?

  2. Small batch training: How to train with limited compute?

  3. Generalization bounds: Formal guarantees?

  4. Multimodal learning: Unified framework for vision, language, audio?

16.3 Future DirectionsΒΆ

  1. Efficient training: Reduce batch size/epoch requirements

  2. Automated augmentation: Neural augmentation search

  3. Unified frameworks: Combine contrastive + generative

  4. Theoretical understanding: Why these methods work

17. Key TakeawaysΒΆ

  1. Contrastive learning learns representations by pulling positives together, pushing negatives apart

  2. InfoNCE loss is the most common objective, requires large batch sizes

  3. SimCLR: Strong augmentations + large batches + projection head

  4. MoCo: Memory queue decouples batch size from number of negatives

  5. BYOL: No negatives needed, uses predictor + momentum encoder

  6. SwAV: Clustering-based, no negatives, multi-crop efficient

  7. DINO: Self-distillation, works well with ViT

  8. Barlow Twins/VICReg: Redundancy reduction, explicit regularization

  9. Augmentations crucial: Random crop most important

  10. Linear evaluation: Standard protocol for comparing methods

Core insight: Learning invariance to data augmentations produces powerful, general-purpose representations.

ReferencesΒΆ

  1. Chen et al. (2020) β€œA Simple Framework for Contrastive Learning of Visual Representations (SimCLR)”

  2. He et al. (2020) β€œMomentum Contrast for Unsupervised Visual Representation Learning (MoCo)”

  3. Grill et al. (2020) β€œBootstrap Your Own Latent: A New Approach to Self-Supervised Learning (BYOL)”

  4. Caron et al. (2020) β€œUnsupervised Learning of Visual Features by Contrasting Cluster Assignments (SwAV)”

  5. Caron et al. (2021) β€œEmerging Properties in Self-Supervised Vision Transformers (DINO)”

  6. Zbontar et al. (2021) β€œBarlow Twins: Self-Supervised Learning via Redundancy Reduction”

  7. Bardes et al. (2022) β€œVICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning”

  8. Wang & Isola (2020) β€œUnderstanding Contrastive Representation Learning through Alignment and Uniformity”

  9. Radford et al. (2021) β€œLearning Transferable Visual Models From Natural Language Supervision (CLIP)”

  10. He et al. (2022) β€œMasked Autoencoders Are Scalable Vision Learners (MAE)”

"""
Complete Contrastive Learning Implementations
=============================================
Includes: SimCLR, MoCo, BYOL, Barlow Twins, VICReg, NT-Xent loss,
data augmentations, evaluation protocols.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from collections import OrderedDict
import copy

# ============================================================================
# 1. Data Augmentations
# ============================================================================

class ContrastiveAugmentation:
    """
    SimCLR-style augmentation pipeline.
    
    Applies composition of:
    - Random crop + resize
    - Color jitter
    - Gaussian blur
    - Random horizontal flip
    - Grayscale (optional)

    Calling an instance on a (PIL) image returns TWO independently
    augmented, normalized tensors -- a positive pair for contrastive
    training.
    """
    def __init__(self, image_size=32, s=0.5):
        """
        Args:
            image_size: Target image size
            s: Strength of color distortion
        """
        # Color jitter strengths scale with s, following SimCLR:
        # brightness/contrast/saturation = 0.8*s, hue = 0.2*s.
        color_jitter = transforms.ColorJitter(
            brightness=0.8*s, contrast=0.8*s, 
            saturation=0.8*s, hue=0.2*s
        )
        
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # NOTE(review): blur is applied to EVERY image here; the SimCLR
            # paper wraps it in RandomApply(p=0.5) -- confirm intentional.
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            # ImageNet mean/std normalization (assumes 3-channel input).
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    
    def __call__(self, x):
        """Apply the stochastic pipeline twice to obtain two views of x."""
        return self.train_transform(x), self.train_transform(x)


# ============================================================================
# 2. Projection Head
# ============================================================================

class ProjectionHead(nn.Module):
    """
    Two-layer MLP projection head for contrastive learning.

    Maps encoder features to the space where the contrastive loss is
    applied: Linear -> BatchNorm1d -> ReLU -> Linear.
    """
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super(ProjectionHead, self).__init__()

        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project features x [batch, input_dim] to [batch, output_dim]."""
        return self.net(x)


# ============================================================================
# 3. NT-Xent (InfoNCE) Loss
# ============================================================================

class NTXentLoss(nn.Module):
    """
    Normalized Temperature-scaled Cross Entropy Loss (SimCLR).

    For each embedding the positive is the other view of the same image;
    the remaining 2N-2 embeddings in the batch act as negatives:

        L = -log[exp(sim(z_i, z_j)/tau) / sum_{k != i} exp(sim(z_i, z_k)/tau)]

    where sim is cosine similarity (dot product of L2-normalized vectors).

    Args:
        temperature: Temperature parameter tau; lower values sharpen the
            softmax distribution over candidates.
    """
    def __init__(self, temperature=0.5):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """
        Args:
            z_i, z_j: Embeddings from two views [batch, dim]

        Returns:
            loss: Scalar NT-Xent loss averaged over all 2*batch anchors.
        """
        batch_size = z_i.shape[0]

        # Cosine similarity = dot product of L2-normalized embeddings.
        z_i = F.normalize(z_i, dim=1)
        z_j = F.normalize(z_j, dim=1)

        # Concatenate both views: [2*batch, dim]
        z = torch.cat([z_i, z_j], dim=0)

        # Pairwise similarity matrix, temperature-scaled: [2*batch, 2*batch]
        sim = torch.mm(z, z.T) / self.temperature

        # Diagonal marks self-similarity; rolling it by batch_size marks
        # the positive pairs (i, i+batch) and (i+batch, i).
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
        pos_mask = self_mask.roll(shifts=batch_size, dims=0)

        # Exclude self-similarity from the softmax denominator by pushing
        # it to a very large negative value (exp(-9e15) ~ 0).
        sim = sim.masked_fill(self_mask, -9e15)

        # Similarity of each anchor with its positive: [2*batch, 1]
        pos_sim = sim[pos_mask].view(2 * batch_size, 1)

        # log p(positive | anchor) via numerically stable logsumexp over
        # all candidates (positive + negatives).
        log_prob = pos_sim - torch.logsumexp(sim, dim=1, keepdim=True)

        # Mean negative log-likelihood over all 2*batch anchors.
        return -log_prob.mean()


# ============================================================================
# 4. SimCLR
# ============================================================================

class SimCLR(nn.Module):
    """
    SimCLR: Simple Framework for Contrastive Learning.

    Components:
    - Encoder f (e.g., ResNet) producing representations h
    - Projection head g mapping h -> z
    - NT-Xent loss on the projected embeddings

    Args:
        encoder: Backbone network; must accept [batch, *input_shape].
        projection_dim: Output dimension of the projection head.
        temperature: NT-Xent temperature.
        input_shape: Shape (C, H, W) of a single input sample; used only to
            probe the encoder's output dimension at construction time.
            Defaults to (3, 32, 32) for backward compatibility.
    """
    def __init__(self, encoder, projection_dim=128, temperature=0.5,
                 input_shape=(3, 32, 32)):
        super(SimCLR, self).__init__()

        self.encoder = encoder

        # Probe the encoder once with a dummy batch to discover its
        # output feature dimension (previously hard-coded to 3x32x32).
        with torch.no_grad():
            dummy = torch.randn(1, *input_shape)
            encoder_dim = encoder(dummy).shape[1]

        # Projection head: encoder features -> contrastive space.
        self.projection = ProjectionHead(
            input_dim=encoder_dim,
            hidden_dim=encoder_dim,
            output_dim=projection_dim
        )

        # Contrastive objective.
        self.criterion = NTXentLoss(temperature=temperature)

    def forward(self, x_i, x_j):
        """
        Args:
            x_i, x_j: Two augmented views [batch, C, H, W]

        Returns:
            loss: NT-Xent contrastive loss over projected embeddings.
        """
        # Encode both views.
        h_i = self.encoder(x_i)
        h_j = self.encoder(x_j)

        # Project into the contrastive space.
        z_i = self.projection(h_i)
        z_j = self.projection(h_j)

        return self.criterion(z_i, z_j)

    def get_representation(self, x):
        """Return encoder features h (pre-projection) for downstream tasks."""
        with torch.no_grad():
            h = self.encoder(x)
        return h


# ============================================================================
# 5. MoCo (Momentum Contrast)
# ============================================================================

class MoCo(nn.Module):
    """
    Momentum Contrast.
    
    Uses:
    - Query encoder (updated via backprop)
    - Key encoder (updated via momentum)
    - Queue of keys (negatives)

    The queue stores keys from previous batches as negatives, decoupling
    the number of negatives from the batch size.
    """
    def __init__(self, encoder, projection_dim=128, queue_size=65536,
                 momentum=0.999, temperature=0.07):
        """
        Args:
            encoder: Backbone for the query path; a deep copy becomes the
                key encoder. Probed with a [1, 3, 32, 32] dummy input.
            projection_dim: Output dimension of both projection heads.
            queue_size: Number of negative keys held in the queue.
            momentum: EMA coefficient for the key-encoder update.
            temperature: Softmax temperature for the InfoNCE logits.
        """
        super(MoCo, self).__init__()
        
        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = temperature
        
        # Query encoder
        self.encoder_q = encoder
        
        # Get encoder output dimension
        with torch.no_grad():
            dummy = torch.randn(1, 3, 32, 32)
            encoder_dim = encoder(dummy).shape[1]
        
        self.projection_q = ProjectionHead(
            input_dim=encoder_dim,
            output_dim=projection_dim
        )
        
        # Key encoder: momentum copy of the query path; never backpropped.
        self.encoder_k = copy.deepcopy(encoder)
        self.projection_k = copy.deepcopy(self.projection_q)
        
        for param in self.encoder_k.parameters():
            param.requires_grad = False
        for param in self.projection_k.parameters():
            param.requires_grad = False
        
        # Queue of normalized keys stored column-wise: [projection_dim, queue_size].
        # register_buffer so it follows .to(device) and is checkpointed.
        self.register_buffer("queue", torch.randn(projection_dim, queue_size))
        self.queue = F.normalize(self.queue, dim=0)
        # Pointer to the next slot to overwrite (circular buffer).
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
    
    @torch.no_grad()
    def _momentum_update(self):
        """Update key encoder via momentum: k <- m*k + (1-m)*q."""
        for param_q, param_k in zip(self.encoder_q.parameters(), 
                                     self.encoder_k.parameters()):
            param_k.data = param_k.data * self.momentum + \
                           param_q.data * (1.0 - self.momentum)
        
        for param_q, param_k in zip(self.projection_q.parameters(), 
                                     self.projection_k.parameters()):
            param_k.data = param_k.data * self.momentum + \
                           param_q.data * (1.0 - self.momentum)
    
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Insert new keys into the circular queue, overwriting the oldest."""
        batch_size = keys.shape[0]
        
        ptr = int(self.queue_ptr)
        
        # Replace oldest keys
        if ptr + batch_size <= self.queue_size:
            self.queue[:, ptr:ptr + batch_size] = keys.T
        else:
            # Wrap around when the batch straddles the end of the queue.
            remaining = self.queue_size - ptr
            self.queue[:, ptr:] = keys[:remaining].T
            self.queue[:, :batch_size - remaining] = keys[remaining:].T
        
        # Update pointer
        ptr = (ptr + batch_size) % self.queue_size
        self.queue_ptr[0] = ptr
    
    def forward(self, x_q, x_k):
        """
        Args:
            x_q: Query images [batch, C, H, W]
            x_k: Key images [batch, C, H, W]
        
        Returns:
            loss: InfoNCE loss (1 positive vs. queue_size queued negatives)
        """
        # Query embeddings
        q = self.projection_q(self.encoder_q(x_q))
        q = F.normalize(q, dim=1)
        
        # Key embeddings (no gradient; key network updated by EMA first)
        with torch.no_grad():
            self._momentum_update()
            
            k = self.projection_k(self.encoder_k(x_k))
            k = F.normalize(k, dim=1)
        
        # Positive logits: [batch, 1]
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        
        # Negative logits against the queued keys: [batch, queue_size]
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        
        # Logits: [batch, 1+queue_size]
        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
        
        # Labels: positives are at index 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        
        # Cross-entropy loss
        loss = F.cross_entropy(logits, labels)
        
        # Enqueue this batch's keys AFTER computing the loss, so a key is
        # never its own negative.
        self._dequeue_and_enqueue(k)
        
        return loss


# ============================================================================
# 6. BYOL (Bootstrap Your Own Latent)
# ============================================================================

class BYOL(nn.Module):
    """
    Bootstrap Your Own Latent.

    Trains without negative pairs: an online network (encoder + projector +
    predictor) regresses onto the output of a target network (encoder +
    projector) that is an exponential-moving-average copy of the online one.
    The predictor asymmetry plus the EMA target prevent collapse.
    """
    def __init__(self, encoder, projection_dim=256, hidden_dim=4096,
                 momentum=0.996):
        super(BYOL, self).__init__()

        self.momentum = momentum

        # --- Online network (trained by backprop) ---
        self.online_encoder = encoder

        # Probe the encoder output dimension with a dummy forward pass.
        with torch.no_grad():
            dummy = torch.randn(1, 3, 32, 32)
            encoder_dim = encoder(dummy).shape[1]

        def _mlp(in_dim):
            """Two-layer MLP: Linear -> BN -> ReLU -> Linear."""
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, projection_dim)
            )

        self.online_projector = _mlp(encoder_dim)
        self.online_predictor = _mlp(projection_dim)

        # --- Target network (EMA copy, never backpropped) ---
        self.target_encoder = copy.deepcopy(encoder)
        self.target_projector = copy.deepcopy(self.online_projector)

        for p in self.target_encoder.parameters():
            p.requires_grad = False
        for p in self.target_projector.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def _update_target(self):
        """EMA update of the target network: t <- m*t + (1-m)*o."""
        param_pairs = list(zip(self.online_encoder.parameters(),
                               self.target_encoder.parameters()))
        param_pairs += list(zip(self.online_projector.parameters(),
                                self.target_projector.parameters()))
        for online_p, target_p in param_pairs:
            target_p.data = target_p.data * self.momentum + \
                            online_p.data * (1.0 - self.momentum)

    @staticmethod
    def _regression_loss(p, t):
        """Normalized MSE: 2 - 2*cos(p, t), averaged over the batch."""
        p = F.normalize(p, dim=1)
        t = F.normalize(t, dim=1)
        return (2 - 2 * (p * t).sum(dim=1)).mean()

    def forward(self, x1, x2):
        """
        Args:
            x1, x2: Two augmented views

        Returns:
            loss: Mean of the two symmetric regression terms (each in [0, 4]).
        """
        # Online branch: encode -> project -> predict, for both views.
        p1 = self.online_predictor(self.online_projector(self.online_encoder(x1)))
        p2 = self.online_predictor(self.online_projector(self.online_encoder(x2)))

        # Target branch (stop-gradient), after the EMA update.
        with torch.no_grad():
            self._update_target()

            t1 = self.target_projector(self.target_encoder(x1))
            t2 = self.target_projector(self.target_encoder(x2))

        # Symmetric loss: predict each view's target from the other view.
        loss = self._regression_loss(p1, t2) + self._regression_loss(p2, t1)

        return loss / 2


# ============================================================================
# 7. Barlow Twins
# ============================================================================

class BarlowTwins(nn.Module):
    """
    Barlow Twins: Redundancy reduction via cross-correlation.

    Loss:
    L = Σᵢ (1 - C_ii)² + λ Σᵢ Σⱼ≠ᵢ C_ij²

    where C is the cross-correlation matrix between the batch-normalized
    embeddings of the two views.

    Args:
        encoder: Backbone network; probed with a [1, 3, 32, 32] dummy input.
        projection_dim: Output dimension of the projector.
        lambd: Weight of the off-diagonal (redundancy) term.
    """
    def __init__(self, encoder, projection_dim=2048, lambd=5e-3):
        super(BarlowTwins, self).__init__()

        self.lambd = lambd

        self.encoder = encoder

        with torch.no_grad():
            dummy = torch.randn(1, 3, 32, 32)
            encoder_dim = encoder(dummy).shape[1]

        # Projection head (3 layers for Barlow Twins)
        self.projector = nn.Sequential(
            nn.Linear(encoder_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, projection_dim)
        )

        # Per-dimension batch normalization of the projections (no affine
        # params) so C is a proper correlation matrix.
        self.bn = nn.BatchNorm1d(projection_dim, affine=False)

    def forward(self, x1, x2):
        """
        Args:
            x1, x2: Two augmented views

        Returns:
            loss: Barlow Twins loss (scalar, differentiable)
        """
        # Encode and project
        z1 = self.projector(self.encoder(x1))
        z2 = self.projector(self.encoder(x2))

        # Normalize each embedding dimension over the batch
        z1 = self.bn(z1)
        z2 = self.bn(z2)

        # Cross-correlation matrix between the two views' embeddings
        batch_size = z1.shape[0]
        c = (z1.T @ z2) / batch_size  # [dim, dim]

        # BUG FIX: the previous implementation used in-place ops
        # (diagonal(c).add_(-1).pow_(2), c.pow_(2)) which (a) mutated `c`
        # through views, so the off-diagonal term wrongly included the
        # already-transformed diagonal raised to higher powers, and
        # (b) broke autograd for tensors needed in the backward pass.
        # Compute both terms out-of-place instead.
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        c_sq = c.pow(2)
        off_diag = c_sq.sum() - torch.diagonal(c_sq).sum()

        return on_diag + self.lambd * off_diag


# ============================================================================
# 8. VICReg (Variance-Invariance-Covariance)
# ============================================================================

class VICReg(nn.Module):
    """
    VICReg: Explicit regularization to prevent collapse.

    Loss = λ * invariance + μ * variance + ν * covariance

    - invariance: MSE between the two views' embeddings
    - variance: hinge keeping each embedding dimension's std above 1
    - covariance: penalizes off-diagonal covariance (decorrelation)

    Args:
        encoder: Backbone network; probed with a [1, 3, 32, 32] dummy input.
        projection_dim: Output dimension of the projector.
        sim_coeff: Weight λ of the invariance term.
        std_coeff: Weight μ of the variance term.
        cov_coeff: Weight ν of the covariance term.
    """
    def __init__(self, encoder, projection_dim=2048, 
                 sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0):
        super(VICReg, self).__init__()

        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff

        self.encoder = encoder

        with torch.no_grad():
            dummy = torch.randn(1, 3, 32, 32)
            encoder_dim = encoder(dummy).shape[1]

        # Projection head
        self.projector = nn.Sequential(
            nn.Linear(encoder_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, projection_dim)
        )

    def forward(self, x1, x2):
        """
        Args:
            x1, x2: Two augmented views

        Returns:
            loss: VICReg loss (scalar, differentiable)
        """
        # Encode and project
        z1 = self.projector(self.encoder(x1))
        z2 = self.projector(self.encoder(x2))

        # Invariance loss (MSE between the two views)
        sim_loss = F.mse_loss(z1, z2)

        # Variance loss: hinge at std = 1 per embedding dimension
        std_z1 = torch.sqrt(z1.var(dim=0) + 1e-4)
        std_z2 = torch.sqrt(z2.var(dim=0) + 1e-4)
        std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))

        # Covariance loss: center, then penalize off-diagonal covariance
        z1 = z1 - z1.mean(dim=0)
        z2 = z2 - z2.mean(dim=0)

        cov_z1 = (z1.T @ z1) / (z1.shape[0] - 1)
        cov_z2 = (z2.T @ z2) / (z2.shape[0] - 1)

        # BUG FIX: the previous implementation used in-place pow_() on the
        # covariance matrices, so torch.diagonal(...) read the ALREADY
        # squared matrix and squared it again (diag⁴ instead of diag²),
        # and the in-place mutation broke autograd. Use out-of-place ops.
        cov_sq1 = cov_z1.pow(2)
        cov_sq2 = cov_z2.pow(2)
        cov_loss = (cov_sq1.sum() - torch.diagonal(cov_sq1).sum()) / z1.shape[1] + \
                   (cov_sq2.sum() - torch.diagonal(cov_sq2).sum()) / z2.shape[1]

        # Total loss
        loss = self.sim_coeff * sim_loss + \
               self.std_coeff * std_loss + \
               self.cov_coeff * cov_loss

        return loss


# ============================================================================
# 9. Evaluation: Linear Classifier
# ============================================================================

class LinearClassifier(nn.Module):
    """Single linear layer used in the linear-evaluation protocol."""

    def __init__(self, input_dim, num_classes):
        super(LinearClassifier, self).__init__()
        # One fully-connected layer: frozen features -> class logits.
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Map features [batch, input_dim] to logits [batch, num_classes]."""
        logits = self.fc(x)
        return logits


def linear_evaluation(encoder, train_loader, test_loader, num_classes, 
                       device='cuda', epochs=100):
    """
    Linear evaluation protocol: freeze the encoder and train a linear
    classifier on top of its (fixed) features.

    Args:
        encoder: Pre-trained backbone; set to eval mode and frozen.
        train_loader: DataLoader of (images, labels) for classifier training.
        test_loader: DataLoader of (images, labels) for the final accuracy.
        num_classes: Number of target classes.
        device: Device for tensors and the classifier.
        epochs: Number of classifier training epochs.

    Returns:
        accuracy: Test-set top-1 accuracy in percent.
    """
    # Freeze the encoder: eval mode plus no gradients.
    encoder.eval()
    for p in encoder.parameters():
        p.requires_grad = False

    # Probe the encoder output dimension with a dummy forward pass.
    # NOTE(review): assumes 3x32x32 inputs -- confirm against the encoder.
    with torch.no_grad():
        probe = torch.randn(1, 3, 32, 32).to(device)
        feature_dim = encoder(probe).shape[1]

    # Linear probe on top of frozen features.
    classifier = LinearClassifier(feature_dim, num_classes).to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # Train the classifier only; encoder features are computed under no_grad.
    for epoch in range(epochs):
        classifier.train()
        running_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            with torch.no_grad():
                feats = encoder(images)

            logits = classifier(feats)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        if (epoch + 1) % 20 == 0:
            print(f"  Epoch {epoch+1}: Loss = {running_loss/len(train_loader):.4f}")

    # Final accuracy on the test set.
    classifier.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            logits = classifier(encoder(images))
            _, preds = logits.max(1)

            total += labels.size(0)
            correct += preds.eq(labels).sum().item()

    return 100.0 * correct / total


# ============================================================================
# 10. Method Comparison
# ============================================================================

def print_method_comparison():
    """
    Print a comparison table of contrastive learning methods.

    Purely informational: prints a static summary (properties, loss
    functions, hyperparameters, costs, usage guidance) to stdout and
    returns None. The table text is a fixed string literal.
    """
    print("="*70)
    print("Contrastive Learning Methods Comparison")
    print("="*70)
    print()
    
    comparison = """
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Method       β”‚ Negatives? β”‚ Asymmetry? β”‚ Batch Size   β”‚ Key Feature  β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ SimCLR       β”‚ Yes (many) β”‚ No         β”‚ Large (4096) β”‚ Augmentationsβ”‚
β”‚              β”‚            β”‚            β”‚              β”‚ + projection β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ MoCo         β”‚ Yes (queue)β”‚ Momentum   β”‚ Medium (256) β”‚ Memory queue β”‚
β”‚              β”‚            β”‚            β”‚              β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ BYOL         β”‚ No         β”‚ Predictor  β”‚ Medium       β”‚ No negatives β”‚
β”‚              β”‚            β”‚            β”‚              β”‚ + EMA        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ SwAV         β”‚ No         β”‚ Prototypes β”‚ Medium       β”‚ Clustering   β”‚
β”‚              β”‚            β”‚            β”‚              β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Barlow Twins β”‚ No         β”‚ No         β”‚ Medium       β”‚ Decorrelationβ”‚
β”‚              β”‚            β”‚            β”‚              β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ VICReg       β”‚ No         β”‚ No         β”‚ Medium       β”‚ Explicit     β”‚
β”‚              β”‚            β”‚            β”‚              β”‚ regulariz.   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

**Loss Functions:**

- **InfoNCE/NT-Xent**: -log[exp(sim(z,z⁺)/Ο„) / Ξ£β±Ό exp(sim(z,zβ±Ό)/Ο„)]
- **BYOL**: MSE(predictor(z₁), sg(zβ‚‚))
- **Barlow Twins**: Ξ£α΅’(1-C_ii)Β² + λΣᡒΣⱼ≠ᡒ C_ijΒ²
- **VICReg**: λ·MSE + μ·variance_loss + ν·covariance_loss

**Typical Hyperparameters:**

Method       | Batch Size | Temperature | Epochs | LR      | Aug Strength
-------------|------------|-------------|--------|---------|-------------
SimCLR       | 4096       | 0.07-0.5    | 1000   | 0.3-1.0 | High
MoCo v2      | 256        | 0.2         | 800    | 0.03    | Medium
BYOL         | 4096       | N/A         | 1000   | 0.2     | Medium
Barlow Twins | 2048       | N/A         | 1000   | 0.2     | High
VICReg       | 2048       | N/A         | 1000   | 0.2     | High

**Computational Cost (relative to supervised):**

- SimCLR: ~3-5Γ— (large batches)
- MoCo: ~2-3Γ— (queue efficient)
- BYOL: ~3-4Γ— (double forward pass)
- Barlow Twins: ~2-3Γ— (simple loss)

**When to Use:**

- **SimCLR**: Have large GPU memory, want simplicity
- **MoCo**: Limited memory, want efficiency
- **BYOL**: Prefer simple training (no negative sampling)
- **Barlow Twins**: Interested in decorrelation perspective
- **VICReg**: Want explicit control over collapse prevention
"""
    
    print(comparison)
    print()


# ============================================================================
# Run Demonstrations
# ============================================================================

if __name__ == "__main__":
    print("="*70)
    print("Contrastive Learning Implementations")
    print("="*70)
    print()
    
    # Simple encoder for demonstration
    class SimpleEncoder(nn.Module):
        """Small CNN encoder for 32x32 RGB (CIFAR-sized) inputs -> 512-d features."""
        def __init__(self):
            super(SimpleEncoder, self).__init__()
            self.net = nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Flatten(),
                # 128 channels at 8x8 after two 2x downsamplings of 32x32.
                nn.Linear(128 * 8 * 8, 512)
            )
        
        def forward(self, x):
            return self.net(x)
    
    encoder = SimpleEncoder()
    
    # Smoke-test each method on random data (one forward pass, no training).
    print("Testing SimCLR...")
    simclr = SimCLR(encoder, projection_dim=128)
    x1 = torch.randn(8, 3, 32, 32)
    x2 = torch.randn(8, 3, 32, 32)
    loss = simclr(x1, x2)
    print(f"  SimCLR loss: {loss.item():.4f}")
    print()
    
    print("Testing MoCo...")
    # Fresh encoder per method so the demos don't share weights.
    moco = MoCo(SimpleEncoder(), projection_dim=128, queue_size=256)
    loss = moco(x1, x2)
    print(f"  MoCo loss: {loss.item():.4f}")
    print()
    
    print("Testing BYOL...")
    byol = BYOL(SimpleEncoder(), projection_dim=256)
    loss = byol(x1, x2)
    print(f"  BYOL loss: {loss.item():.4f}")
    print()
    
    print("Testing Barlow Twins...")
    barlow = BarlowTwins(SimpleEncoder(), projection_dim=256)
    loss = barlow(x1, x2)
    print(f"  Barlow Twins loss: {loss.item():.4f}")
    print()
    
    print("Testing VICReg...")
    vicreg = VICReg(SimpleEncoder(), projection_dim=256)
    loss = vicreg(x1, x2)
    print(f"  VICReg loss: {loss.item():.4f}")
    print()
    
    print_method_comparison()
    
    print("="*70)
    print("Contrastive Learning Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    print("  β€’ SimCLR: Large batches + augmentations + NT-Xent")
    print("  β€’ MoCo: Memory queue + momentum encoder")
    print("  β€’ BYOL: No negatives + predictor asymmetry")
    print("  β€’ Barlow Twins: Redundancy reduction via decorrelation")
    print("  β€’ VICReg: Explicit variance + invariance + covariance")
    print()
    print("Key insight: Learn representations via augmentation invariance")
    print("Trade-off: Batch size vs. computational cost vs. performance")
    print("Applications: Pre-training for transfer learning, few-shot")
    print()