Vision Transformers (ViT)

Learning Objectives:

  • Understand patch embeddings and positional encoding

  • Implement self-attention for images

  • Train ViT on image classification

  • Compare with CNNs

Prerequisites: Transformers, self-attention, deep learning

Time: 90 minutes

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
sns.set_style('whitegrid')

1. From Images to Sequences

Challenge

Transformers expect sequences as input, but images are 2D grids.

ViT Solution: Patch Embedding

  1. Split image into fixed-size patches (e.g., 16×16)

  2. Flatten each patch into a vector

  3. Linearly project to embedding dimension

  4. Add positional encoding

  5. Prepend [CLS] token for classification

Mathematical Formulation

Image \(x \in \mathbb{R}^{H \times W \times C}\) with patch size \(P\):

  • Number of patches: \(N = \frac{HW}{P^2}\)

  • Patch sequence: \(x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}\)

  • Embedded patches: \(z_0 = [x_{cls}; x_p^1 E; x_p^2 E; \ldots; x_p^N E] + E_{pos}\)

where \(E \in \mathbb{R}^{(P^2 \cdot C) \times D}\) is learnable projection.

1.5. Patch Embedding: Mathematical Details

Why Patches?

Computational Complexity:

  • Per-pixel attention: \(O((HW)^2 \cdot D)\) - prohibitive!

  • Patch-based: \(O(N^2 \cdot D)\) where \(N = HW/P^2\) - tractable!

For a 224×224 image with \(P=16\):

  • Pixels: \(224^2 = 50{,}176\) → Attention: \(50{,}176^2 \approx 2.5\) billion

  • Patches: \((224/16)^2 = 196\) → Attention: \(196^2 \approx 38{,}000\)

~65,000× reduction in complexity! ✅
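The counts above are easy to verify; a quick sketch reproducing the arithmetic:

```python
# Attention score counts for per-pixel vs. patch tokens (224x224, P=16).
H = W = 224
P = 16

pixel_tokens = H * W                      # 50,176 if every pixel is a token
patch_tokens = (H // P) * (W // P)        # 196 patch tokens

pixel_pairs = pixel_tokens ** 2           # ~2.5 billion pairwise scores
patch_pairs = patch_tokens ** 2           # 38,416 pairwise scores

print(f"pixels:  {pixel_tokens:,} tokens -> {pixel_pairs:,} scores")
print(f"patches: {patch_tokens:,} tokens -> {patch_pairs:,} scores")
print(f"reduction: {pixel_pairs // patch_pairs:,}x")  # 65,536x
```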

Patch Size Analysis

Trade-offs:

| Patch Size | # Patches | Pros | Cons |
|------------|-----------|------|------|
| 4×4 | \(56^2 = 3{,}136\) | Fine-grained details | High computation, overfitting |
| 8×8 | \(28^2 = 784\) | Good detail | Moderate computation |
| 16×16 | \(14^2 = 196\) | Standard, efficient | May miss fine details |
| 32×32 | \(7^2 = 49\) | Very fast | Loses too much detail |

Standard choice: \(P=16\) balances performance and efficiency

Positional Encoding Options

1. Learnable 1D (ViT default): \(E_{pos} \in \mathbb{R}^{(N+1) \times D}\) (learned parameters, one row per patch token plus the [CLS] token)

2. Sinusoidal 2D:
\[PE(i, j, 2k) = \sin(i / 10000^{2k/D}), \qquad PE(i, j, 2k+1) = \cos(j / 10000^{2k/D})\]

where \((i, j)\) is the 2D patch position

3. Relative Positional Encoding: encode the relative offset between patches directly in the attention scores:
\[\text{Attn}(i, j) = \frac{\exp(q_i^T k_j / \sqrt{d} + r_{i-j})}{\sum_{k'} \exp(q_i^T k_{k'} / \sqrt{d} + r_{i-k'})}\]
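As a concrete illustration, here is a minimal sketch of option 2 following the sinusoidal formulas above (the function name and grid size are illustrative):

```python
import numpy as np

def sinusoidal_pos_2d(n_rows, n_cols, D):
    """Sinusoidal 2D encoding: even dims encode the row i with sin,
    odd dims encode the column j with cos, per the formulas above."""
    pe = np.zeros((n_rows, n_cols, D))
    for i in range(n_rows):
        for j in range(n_cols):
            for k in range(D // 2):
                freq = 10000.0 ** (2 * k / D)
                pe[i, j, 2 * k] = np.sin(i / freq)
                pe[i, j, 2 * k + 1] = np.cos(j / freq)
    return pe.reshape(n_rows * n_cols, D)  # flatten grid to an (N, D) sequence

pe = sinusoidal_pos_2d(14, 14, 64)  # 14x14 patch grid, as in ViT-B/16
print(pe.shape)  # (196, 64)
```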

Class Token ([CLS])

Purpose: Global image representation for classification

Alternative: Could use average pooling over all patches: \(\text{avg} = \frac{1}{N}\sum_{i=1}^N z_i\)

Why [CLS] token?

  • Learned aggregation (vs. fixed averaging)

  • Can attend to relevant patches

  • Proven effective in BERT, adopted for ViT

Mathematical Form:
\[z_0 = [\mathbf{x}_{cls}; E_1; E_2; \ldots; E_N] + [E_{pos}^{cls}; E_{pos}^1; \ldots; E_{pos}^N]\]

where \(\mathbf{x}_{cls}\) is a learned embedding
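To make the two readouts concrete, here is a shape-level NumPy sketch (names are illustrative) contrasting the [CLS] readout with average pooling:

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, D, num_classes = 2, 196, 192, 10
tokens = rng.standard_normal((B, N + 1, D))    # [CLS] at index 0, then N patches
W_head = rng.standard_normal((D, num_classes)) # stand-in classification head

# Readout 1: learned [CLS] token (ViT default)
cls_logits = tokens[:, 0] @ W_head

# Readout 2: fixed average pooling over the patch tokens
avg_logits = tokens[:, 1:].mean(axis=1) @ W_head

print(cls_logits.shape, avg_logits.shape)  # (2, 10) (2, 10)
```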

class PatchEmbedding(nn.Module):
    """Split image into patches and embed them."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # Convolutional projection (equivalent to linear on flattened patches)
        self.proj = nn.Conv2d(in_channels, embed_dim, 
                             kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)  # (B, embed_dim, n_patches_h, n_patches_w)
        x = x.flatten(2)  # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        return x

# Test
patch_embed = PatchEmbedding(img_size=224, patch_size=16, embed_dim=768)
img = torch.randn(2, 3, 224, 224)
patches = patch_embed(img)
print(f"Image shape: {img.shape}")
print(f"Patch embeddings: {patches.shape}")  # (2, 196, 768)
print(f"Number of patches: {patch_embed.n_patches}")

2. Multi-Head Self-Attention: Complete Theory

Attention Mechanism

Scaled Dot-Product Attention:
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

where:

  • \(Q \in \mathbb{R}^{N \times d_k}\): Queries

  • \(K \in \mathbb{R}^{N \times d_k}\): Keys

  • \(V \in \mathbb{R}^{N \times d_v}\): Values

  • \(d_k\): Key/Query dimension

Why scale by \(\sqrt{d_k}\)?

As \(d_k\) increases, dot products grow in magnitude: for independent zero-mean components, \(\text{Var}(q \cdot k) = d_k \cdot \text{Var}(q_i) \cdot \text{Var}(k_i)\), so the scores' standard deviation grows as \(\sqrt{d_k}\).

Scaling by \(\sqrt{d_k}\) keeps the score variance constant, preventing saturation of the softmax (where gradients → 0).
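The variance claim can be checked empirically; a quick NumPy sketch (sample sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 10_000
for d_k in [16, 64, 256]:
    q = rng.standard_normal((n_samples, d_k))  # zero-mean, unit-variance components
    k = rng.standard_normal((n_samples, d_k))
    scores = (q * k).sum(axis=-1)              # raw dot products q . k
    scaled = scores / np.sqrt(d_k)
    # Var(q.k) grows like d_k; dividing by sqrt(d_k) keeps variance near 1
    print(f"d_k={d_k:4d}  Var(q.k)~{scores.var():7.1f}  Var(scaled)~{scaled.var():.2f}")
```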

Multi-Head Mechanism

Single head limitation: Limited to one representation subspace

Solution: Learn \(h\) different attention patterns in parallel

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\]

where each head is:
\[\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]

Parameter matrices:

  • \(W_i^Q, W_i^K \in \mathbb{R}^{D \times d_k}\) where \(d_k = D/h\)

  • \(W_i^V \in \mathbb{R}^{D \times d_v}\) where \(d_v = D/h\)

  • \(W^O \in \mathbb{R}^{D \times D}\)

Key insight: Each head can focus on different aspects:

  • Head 1: Local patterns (nearby patches)

  • Head 2: Global structure

  • Head 3: Specific objects

  • etc.

Complexity Analysis

Computational cost:

| Component | Complexity | Explanation |
|-----------|------------|-------------|
| \(QK^T\) | \(O(N^2 \cdot d_k)\) | Pairwise similarities |
| Softmax | \(O(N^2)\) | Normalize each row |
| Attention \(\times\) V | \(O(N^2 \cdot d_v)\) | Weighted sum |
| Total | \(\mathbf{O(N^2 \cdot D)}\) | \(N\) = seq length, \(D\) = embed dim |

Memory: \(O(h \cdot N^2)\) for attention maps

Quadratic in sequence length! Major bottleneck for long sequences.
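To put the \(O(h \cdot N^2)\) memory figure in perspective, here is a back-of-the-envelope calculation for ViT-B/16 at 224×224 (float32, one image):

```python
h, N, depth = 12, 197, 12        # ViT-B/16: 12 heads, 196 patches + [CLS], 12 layers
bytes_per_float32 = 4

# Each layer stores h attention maps of shape (N, N)
per_layer = h * N * N * bytes_per_float32
print(f"per layer: {per_layer / 1e6:.2f} MB")           # ~1.86 MB
print(f"all layers: {depth * per_layer / 1e6:.1f} MB")  # ~22.4 MB
```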

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        B, N, C = x.shape
        
        # Generate Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Attention: Q @ K^T / sqrt(d_k)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        
        # Weighted sum of values
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        
        return x, attn

# Test
mha = MultiHeadAttention(embed_dim=768, num_heads=12)
x = torch.randn(2, 197, 768)  # 196 patches + 1 CLS token
out, attn = mha(x)
print(f"Output shape: {out.shape}")
print(f"Attention shape: {attn.shape}")  # (B, num_heads, N, N)
# Visualize attention patterns

def visualize_attention_mechanism():
    """
    Demonstrate how self-attention creates patch relationships
    Shows query-key similarity and attention weights
    """
    
    # Simulate simple 3x3 grid of patches (9 patches)
    N = 9
    D = 64
    
    # Create simple patterns in patches (seed torch so results are reproducible)
    torch.manual_seed(42)
    patches = torch.randn(1, N, D)
    
    # Simple attention (single head for visualization)
    Q = patches @ torch.randn(D, D)
    K = patches @ torch.randn(D, D)
    V = patches @ torch.randn(D, D)
    
    # Compute attention scores
    scores = (Q @ K.transpose(-2, -1)) / np.sqrt(D)
    attn_weights = F.softmax(scores, dim=-1)
    
    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    
    # Plot 1: Patch grid
    ax = axes[0]
    grid = np.arange(9).reshape(3, 3)
    im = ax.imshow(grid, cmap='tab10', vmin=0, vmax=9)
    ax.set_title('Patch Layout (3×3 grid)', fontweight='bold', fontsize=13)
    ax.set_xticks(range(3))
    ax.set_yticks(range(3))
    ax.grid(False)
    
    # Add patch numbers
    for i in range(3):
        for j in range(3):
            ax.text(j, i, str(grid[i, j]), ha='center', va='center',
                   fontsize=16, fontweight='bold', color='white')
    
    plt.colorbar(im, ax=ax, label='Patch ID')
    
    # Plot 2: Attention map (focus on center patch)
    ax = axes[1]
    center_patch = 4  # Center of 3x3 grid
    attn_from_center = attn_weights[0, center_patch].detach().numpy().reshape(3, 3)
    
    im = ax.imshow(attn_from_center, cmap='YlOrRd', vmin=0, vmax=attn_from_center.max())
    ax.set_title(f'Attention Weights from Patch {center_patch} (center)', 
                 fontweight='bold', fontsize=13)
    ax.set_xticks(range(3))
    ax.set_yticks(range(3))
    
    # Add values
    for i in range(3):
        for j in range(3):
            ax.text(j, i, f'{attn_from_center[i, j]:.3f}', 
                   ha='center', va='center', fontsize=11)
    
    plt.colorbar(im, ax=ax, label='Attention Weight')
    
    # Plot 3: Full attention matrix
    ax = axes[2]
    attn_matrix = attn_weights[0].detach().numpy()
    im = ax.imshow(attn_matrix, cmap='viridis', aspect='auto')
    ax.set_xlabel('Key Patches', fontsize=12)
    ax.set_ylabel('Query Patches', fontsize=12)
    ax.set_title('Full Attention Matrix (9×9)', fontweight='bold', fontsize=13)
    plt.colorbar(im, ax=ax, label='Attention')
    
    # Add grid
    ax.set_xticks(np.arange(N))
    ax.set_yticks(np.arange(N))
    ax.grid(True, which='both', color='white', linewidth=0.5, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\n" + "="*70)
    print("SELF-ATTENTION INTERPRETATION")
    print("="*70)
    print(f"Attention from Patch {center_patch} (center):")
    print(f"  • Attention to itself: {attn_from_center.flat[center_patch]:.3f}")
    print("  • With random projections, weight spreads across all patches;")
    print("    a trained model learns which patches to emphasize")
    print("\nKey Properties:")
    print("  • Each row sums to 1.0 (softmax normalization)")
    print("  • All patches attend to all others (global receptive field)")
    print("  • Unlike CNNs, no spatial locality bias (learned, not hardcoded)")
    print("="*70)

visualize_attention_mechanism()

Transformer Encoder Block

Each encoder block applies multi-head self-attention followed by a position-wise feed-forward network, with layer normalization and residual connections wrapping both sub-layers. Self-attention computes \(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\), allowing every patch to attend to every other patch regardless of spatial distance. The feed-forward network (typically two linear layers with GELU activation) then processes each position independently, giving the model capacity to learn complex per-token transformations. Stacking multiple encoder blocks enables the network to build increasingly abstract representations of the image.

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        # MLP
        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        # Multi-head attention with residual
        attn_out, attn_weights = self.attn(self.norm1(x))
        x = x + attn_out
        
        # MLP with residual
        x = x + self.mlp(self.norm2(x))
        
        return x, attn_weights

# Test
block = TransformerBlock()
x = torch.randn(2, 197, 768)
out, attn = block(x)
print(f"Block output shape: {out.shape}")

Complete Vision Transformer

The full ViT model chains together the patch embedding, positional encoding, a stack of Transformer encoder blocks, and a classification head. An image of size \(H \times W\) is divided into \(N = HW / P^2\) non-overlapping patches of size \(P \times P\), each linearly projected to a \(d\)-dimensional embedding. A learnable [CLS] token is prepended to the sequence; after passing through all encoder layers, its representation is fed to the MLP head for classification. This architecture achieves state-of-the-art results on image benchmarks when pre-trained on large datasets, demonstrating that the inductive biases of convolutions are not strictly necessary for visual understanding.

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=10,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches
        
        # CLS token and positional embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        
        # Transformer encoder
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        # Initialize
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)
        
        # Add CLS token
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        
        # Add positional embedding
        x = x + self.pos_embed
        x = self.dropout(x)
        
        # Transformer blocks
        attentions = []
        for block in self.blocks:
            x, attn = block(x)
            attentions.append(attn)
        
        x = self.norm(x)
        
        # Classification head (use CLS token)
        logits = self.head(x[:, 0])
        
        return logits, attentions

# Create smaller ViT for MNIST (ViT-Tiny)
model = VisionTransformer(
    img_size=28,
    patch_size=4,
    in_channels=1,
    num_classes=10,
    embed_dim=192,
    depth=6,
    num_heads=3,
    mlp_ratio=2.0
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test forward pass
test_img = torch.randn(4, 1, 28, 28).to(device)
logits, attns = model(test_img)
print(f"Logits shape: {logits.shape}")
print(f"Number of attention maps: {len(attns)}")

Training on MNIST

We train the Vision Transformer on MNIST as a lightweight proof-of-concept: the \(28 \times 28\) grayscale images are small enough for fast iteration while still exercising the full ViT pipeline. Each image is split into a \(7 \times 7\) grid of \(4 \times 4\) patches, embedded, and processed through the Transformer stack. Even with a small model and limited data, ViT can reach strong accuracy on MNIST, showing that attention-based architectures generalize even to modest tasks. The training loop uses cross-entropy loss and the AdamW optimizer with a cosine annealing learning-rate schedule.

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)

# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

def train_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        logits, _ = model(images)
        loss = F.cross_entropy(logits, labels)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pred = logits.argmax(1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
    
    return total_loss / len(loader), correct / total

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits, _ = model(images)
        pred = logits.argmax(1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
    
    return correct / total

# Training loop
n_epochs = 10
history = {'train_loss': [], 'train_acc': [], 'test_acc': []}

for epoch in range(n_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer)
    test_acc = evaluate(model, test_loader)
    scheduler.step()
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_acc'].append(test_acc)
    
    print(f"Epoch {epoch+1}/{n_epochs}: Loss={train_loss:.4f}, "
          f"Train Acc={train_acc:.4f}, Test Acc={test_acc:.4f}")

print(f"\nFinal test accuracy: {history['test_acc'][-1]:.4f}")
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(history['train_loss'], linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training Loss', fontsize=13)
ax1.grid(True, alpha=0.3)

ax2.plot(history['train_acc'], label='Train', linewidth=2)
ax2.plot(history['test_acc'], label='Test', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy', fontsize=12)
ax2.set_title('Classification Accuracy', fontsize=13)
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Visualizing Attention Maps

One of the most appealing properties of Vision Transformers is that attention weights provide a built-in interpretability mechanism. By extracting the attention matrix from one or more heads, we can visualize which patches each token attends to, revealing what the model considers important for its prediction. Early layers tend to show local attention patterns similar to convolutional receptive fields, while deeper layers exhibit long-range attention that spans the entire image. These attention maps are a practical debugging tool and help build trust in the model's predictions for downstream applications.

model.eval()

# Get one test image
img, label = test_dataset[0]
img_input = img.unsqueeze(0).to(device)

with torch.no_grad():
    logits, attentions = model(img_input)
    pred = logits.argmax(1).item()

# Visualize attention from last layer, first head, CLS token
last_attn = attentions[-1][0, 0, 0, 1:].reshape(7, 7).cpu().numpy()  # 28/4 = 7 patches

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Original image
axes[0].imshow(img.squeeze(), cmap='gray')
axes[0].set_title(f'Input Image (Label: {label})', fontsize=13)
axes[0].axis('off')

# Attention map
im = axes[1].imshow(last_attn, cmap='hot', interpolation='nearest')
axes[1].set_title('CLS Token Attention', fontsize=13)
axes[1].axis('off')
plt.colorbar(im, ax=axes[1], fraction=0.046)

# Overlay
from scipy.ndimage import zoom
attn_resized = zoom(last_attn, (28/7, 28/7), order=1)
axes[2].imshow(img.squeeze(), cmap='gray')
axes[2].imshow(attn_resized, cmap='hot', alpha=0.5)
axes[2].set_title(f'Attention Overlay (Pred: {pred})', fontsize=13)
axes[2].axis('off')

plt.tight_layout()
plt.show()

print("Attention highlights regions CLS token focuses on for classification")

Summary

Key Innovations of ViT:

  1. Patch embedding - treat image as sequence of patches

  2. No inductive bias - pure attention, no convolutions

  3. CLS token - learnable classification token

  4. Positional encoding - spatial relationships

ViT vs CNNs:

| Aspect | CNN | ViT |
|--------|-----|-----|
| Inductive bias | Strong (locality, translation equivariance) | Weak |
| Data efficiency | Good (small datasets) | Poor (needs large data) |
| Global context | Limited (receptive field) | Excellent (self-attention) |
| Scalability | Moderate | Excellent |
| Interpretability | Feature maps | Attention maps |

When to Use ViT:

  • Large-scale datasets (ImageNet-21k, JFT-300M)

  • Transfer learning (pre-trained models)

  • Long-range dependencies matter

  • Interpretability via attention

Variants:

  • DeiT - Data-efficient training

  • Swin Transformer - Hierarchical, shifted windows

  • BEiT - Masked image modeling (BERT for vision)

  • MAE - Masked autoencoders

Next Steps:

  • 18_bert_deep_dive.ipynb - Language transformers

  • 20_efficient_transformers.ipynb - Attention optimization

  • CNN comparison in Phase 6 neural networks

6. CNNs vs Vision Transformers: Deep Comparison

Architectural Differences

| Aspect | CNNs | Vision Transformers |
|--------|------|---------------------|
| Receptive field | Local → global (gradual) | Global from layer 1 |
| Inductive bias | Strong (locality, translation eq.) | Weak (learned from data) |
| Data requirement | Low (works with ImageNet) | High (needs JFT-300M or strong aug.) |
| Computation | \(O(N)\) per layer | \(O(N^2)\) per layer |
| Parameters | More efficient | More parameters needed |

Mathematical Comparison

CNN Convolution:
\[y_{i,j} = \sum_{m,n} w_{m,n} \cdot x_{i+m,\, j+n}\]

  • Local connectivity: Only neighbors in kernel

  • Weight sharing: Same \(w\) across spatial locations

  • Translation equivariance: \(f(T(x)) = T(f(x))\)

ViT Self-Attention:
\[y_i = \sum_{j=1}^N \text{softmax}_j\!\left(\frac{q_i k_j^T}{\sqrt{d}}\right) v_j\]

  • Global connectivity: All patches communicate

  • Input-dependent weights: Attention patterns vary with the content of each image, unlike fixed convolution kernels

  • Permutation equivariance: Attention has no notion of spatial order, so positional encoding must supply it
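The translation-equivariance property listed for convolution can be checked numerically; here is a 1D NumPy sketch (using circular shifts and comparing away from the boundary, where the property holds exactly):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(32)   # 1D "image"
w = rng.standard_normal(3)    # 1D kernel

conv = lambda s: np.convolve(s, w, mode='valid')

shift = 5
out_then_shift = np.roll(conv(x), shift)   # f(x), then shift: T(f(x))
shift_then_out = conv(np.roll(x, shift))   # shift, then f(x): f(T(x))

# Equivariance f(T(x)) = T(f(x)) holds away from the wrap-around region
print(np.allclose(out_then_shift[6:-6], shift_then_out[6:-6]))  # True
```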

Inductive Bias Analysis

CNNs have strong built-in assumptions:

  1. Locality: Features composed from nearby pixels

    • Good: Works well with limited data

    • Bad: May miss long-range dependencies

  2. Translation Equivariance: Shift input → shift output

    • Good: Parameter efficient

    • Bad: Rigid structure

ViTs have minimal assumptions:

  1. Flexibility: Learn structure from data

    • Good: Can discover optimal patterns

    • Bad: Requires massive data

  2. Global attention: All patches interact

    • Good: Capture long-range dependencies easily

    • Bad: Quadratic complexity

Performance Characteristics

Data Regime:

Small Data (<100K images):
  CNN >>> ViT (by large margin)
  
Medium Data (ImageNet, 1.3M):
  CNN ≈ ViT (ViT needs strong augmentation)
  
Large Data (JFT-300M, 300M+):
  ViT > CNN (ViT scales better)

Computational Efficiency:

| Model | ImageNet Top-1 | Parameters | FLOPs | Throughput |
|-------|----------------|------------|-------|------------|
| ResNet-50 | 76.5% | 25M | 4.1G | Fast ⚑⚑⚑ |
| ViT-B/16 | 77.9% | 86M | 17.6G | Medium ⚑⚑ |
| ViT-L/16 | 76.5% | 307M | 61.6G | Slow ⚑ |

When to Use Each

Use CNNs when:

  • ✅ Limited training data

  • ✅ Need fast inference

  • ✅ Working with small images

  • ✅ Resource-constrained deployment

  • ✅ Strong spatial structure (e.g., medical imaging)

Use ViTs when:

  • ✅ Access to large datasets

  • ✅ Pre-trained models available

  • ✅ Need best possible accuracy

  • ✅ Computational resources available

  • ✅ Transfer learning to diverse tasks

Hybrid Approaches

Best of both worlds:

  1. Early Convolutions + Late Transformers

    • Use CNNs for early feature extraction

    • ViT for high-level reasoning

    • Example: Convolutional Vision Transformer (CvT)

  2. Local Attention Windows

    • Restrict attention to spatial neighborhoods

    • Reduces complexity from \(O(N^2)\) to \(O(N)\)

    • Example: Swin Transformer

  3. Pyramidal Architecture

    • Multi-scale feature hierarchy (like CNNs)

    • Transformer at each scale

    • Example: Pyramid Vision Transformer (PVT)
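The window-restriction idea (item 2) amounts to a reshape before attention; here is a minimal sketch of Swin-style window partitioning (the function name and sizes are illustrative, not the Swin implementation):

```python
import numpy as np

def window_partition(x, ws):
    """Split an (H, W, C) patch-token grid into non-overlapping ws x ws windows.
    Attention within each window costs O(ws^4) per window, so the total cost
    is linear in the number of tokens for a fixed window size."""
    H, W, C = x.shape
    x = x.reshape(H // ws, ws, W // ws, ws, C)
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, ws * ws, C)

tokens = np.zeros((14, 14, 192))      # 14x14 patch grid
windows = window_partition(tokens, 7)
print(windows.shape)  # (4, 49, 192): 4 windows of 49 tokens each
```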

Recent Advances

Efficient Transformers:

  • Swin Transformer: Shifted windows, \(O(N)\) complexity

  • Linformer: Linear attention approximation

  • Performer: Kernel-based attention, \(O(N)\)

Better ViT Training:

  • DeiT: Data-efficient training with distillation

  • CaiT: Class-attention in later layers

  • BEiT: BERT-style pre-training for images

Theoretical Understanding

Why do ViTs work?

  1. Expressiveness: Can represent any function CNNs can (universal approximators)

  2. Optimization: Self-attention provides multiple gradient paths

  3. Scaling: Performance improves log-linearly with compute

  4. Transfer: Learn general visual representations

Open Questions:

  • Why do ViTs need so much data initially?

  • What visual features do attention heads learn?

  • Can we design better inductive biases?