Vision Transformers (ViT)
Learning Objectives:
Understand patch embeddings and positional encoding
Implement self-attention for images
Train ViT on image classification
Compare with CNNs
Prerequisites: Transformers, self-attention, deep learning
Time: 90 minutes
Reference Materials:
transformer.pdf - Transformer architecture and attention mechanisms
cnn_beyond.pdf - Beyond CNNs including Vision Transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
sns.set_style('whitegrid')
1. From Images to Sequences
Challenge
Transformers expect sequences as input, but images are 2D grids.
ViT Solution: Patch Embedding
Split image into fixed-size patches (e.g., 16×16)
Flatten each patch into a vector
Linearly project to embedding dimension
Add positional encoding
Prepend [CLS] token for classification
Mathematical Formulation
Image \(x \in \mathbb{R}^{H \times W \times C}\) with patch size \(P\):
Number of patches: \(N = \frac{HW}{P^2}\)
Patch sequence: \(x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}\)
Embedded patches: \(z_0 = [x_{cls}; x_p^1 E; x_p^2 E; \ldots; x_p^N E] + E_{pos}\)
where \(E \in \mathbb{R}^{(P^2 \cdot C) \times D}\) is learnable projection.
1.5. Patch Embedding: Mathematical Details
Why Patches?
Computational Complexity:
Per-pixel attention: \(O((HW)^2 \cdot D)\) - prohibitive!
Patch-based: \(O(N^2 \cdot D)\) where \(N = HW/P^2\) - tractable!
For a 224×224 image with \(P=16\):
Pixels: \(224^2 = 50{,}176\) → Attention: \(50{,}176^2 \approx 2.5\) billion
Patches: \((224/16)^2 = 196\) → Attention: \(196^2 \approx 38{,}000\)
~65,000× reduction in complexity!
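The claimed reduction can be checked with a few lines of arithmetic; a quick sketch using only the constants from the formulas above:

```python
# Attention cost grows quadratically in sequence length, so patching a
# 224x224 image with P=16 shrinks the score count drastically.
H = W = 224
P = 16

n_pixels = H * W                    # 50,176 tokens if attending per pixel
n_patches = (H // P) * (W // P)     # 196 tokens with 16x16 patches

pixel_pairs = n_pixels ** 2         # ~2.5 billion pairwise scores
patch_pairs = n_patches ** 2        # ~38 thousand pairwise scores

print(f"Pixels:  {n_pixels:,} tokens -> {pixel_pairs:,} attention scores")
print(f"Patches: {n_patches:,} tokens -> {patch_pairs:,} attention scores")
print(f"Reduction: ~{pixel_pairs // patch_pairs:,}x")   # ~65,536x
```

The exact factor is \((P^2)^2 = 16^4 = 65{,}536\), which is where the "~65,000×" figure comes from.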
Patch Size Analysis
Trade-offs:

| Patch Size | # Patches | Pros | Cons |
|---|---|---|---|
| 4×4 | \(56^2 = 3{,}136\) | Fine-grained details | High computation, overfitting |
| 8×8 | \(28^2 = 784\) | Good detail | Moderate computation |
| 16×16 | \(14^2 = 196\) | Standard, efficient | May miss fine details |
| 32×32 | \(7^2 = 49\) | Very fast | Loses too much detail |

Standard choice: \(P=16\) balances performance and efficiency
Positional Encoding Options
1. Learnable 1D (ViT default)
\[E_{pos} \in \mathbb{R}^{(N+1) \times D} \text{ (learned parameters, one row per patch plus the [CLS] token)}\]
2. Sinusoidal 2D
\[PE(i, j, 2k) = \sin(i / 10000^{2k/D}), \qquad PE(i, j, 2k+1) = \cos(j / 10000^{2k/D})\]
where \((i, j)\) is the 2D patch position
3. Relative Positional Encoding
Encode the relative distance between patches directly in attention:
\[\text{Attn}(i, j) = \frac{\exp(q_i^T k_j / \sqrt{d} + r_{i-j})}{\sum_k \exp(q_i^T k_k / \sqrt{d} + r_{i-k})}\]
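Option 2 (sinusoidal 2D) can be sketched directly. The even/odd channel split between row and column frequencies follows the formulas above; the grid size and embedding dimension below are arbitrary illustration values:

```python
import numpy as np

def sinusoidal_2d(n_rows, n_cols, D):
    """2D sinusoidal encoding per the formulas above: even channels
    encode the row index with sin, odd channels the column index with
    cos (a sketch; other channel splits are also used in practice)."""
    k = np.arange(D // 2)                    # frequency index
    freq = 1.0 / (10000 ** (2 * k / D))      # shape (D/2,)
    pe = np.zeros((n_rows, n_cols, D))
    rows = np.arange(n_rows)[:, None, None]  # (n_rows, 1, 1)
    cols = np.arange(n_cols)[None, :, None]  # (1, n_cols, 1)
    pe[:, :, 0::2] = np.sin(rows * freq)     # PE(i, j, 2k)
    pe[:, :, 1::2] = np.cos(cols * freq)     # PE(i, j, 2k+1)
    return pe.reshape(n_rows * n_cols, D)    # flatten grid to (N, D)

pe = sinusoidal_2d(14, 14, 64)   # the 14x14 grid of a 224/16 image
print(pe.shape)                  # (196, 64)
```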
Class Token ([CLS])
Purpose: Global image representation for classification
Alternative: Could use average pooling over all patches:
\[\text{avg} = \frac{1}{N}\sum_{i=1}^N z_i\]
Why [CLS] token?
Learned aggregation (vs. fixed averaging)
Can attend to relevant patches
Proven effective in BERT, adopted for ViT
Mathematical Form:
\[z_0 = [\mathbf{x}_{cls}; E_1; E_2; \ldots; E_N] + [E_{pos}^{cls}; E_{pos}^1; \ldots; E_{pos}^N]\]
where \(\mathbf{x}_{cls}\) is a learned embedding
class PatchEmbedding(nn.Module):
"""Split image into patches and embed them."""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# Convolutional projection (equivalent to linear on flattened patches)
self.proj = nn.Conv2d(in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
# x: (B, C, H, W)
x = self.proj(x) # (B, embed_dim, n_patches_h, n_patches_w)
x = x.flatten(2) # (B, embed_dim, n_patches)
x = x.transpose(1, 2) # (B, n_patches, embed_dim)
return x
# Test
patch_embed = PatchEmbedding(img_size=224, patch_size=16, embed_dim=768)
img = torch.randn(2, 3, 224, 224)
patches = patch_embed(img)
print(f"Image shape: {img.shape}")
print(f"Patch embeddings: {patches.shape}") # (2, 196, 768)
print(f"Number of patches: {patch_embed.n_patches}")
2. Multi-Head Self-Attention: Complete Theory
Attention Mechanism
Scaled Dot-Product Attention:
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
where:
\(Q \in \mathbb{R}^{N \times d_k}\): Queries
\(K \in \mathbb{R}^{N \times d_k}\): Keys
\(V \in \mathbb{R}^{N \times d_v}\): Values
\(d_k\): Key/Query dimension
Why scale by \(\sqrt{d_k}\)?
As \(d_k\) increases, dot products grow in magnitude:
\[\text{Var}(q \cdot k) = d_k \cdot \text{Var}(q) \cdot \text{Var}(k)\]
Scaling prevents saturation of the softmax (gradients → 0)
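This variance growth is easy to verify empirically. The sketch below draws random unit-variance queries and keys and measures the variance of raw and scaled dot products (the sample size is an arbitrary choice):

```python
import torch

# Empirical check: dot products of random q, k with unit-variance entries
# have variance ~ d_k, which motivates dividing scores by sqrt(d_k).
torch.manual_seed(0)
for d_k in (16, 64, 256):
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    dots = (q * k).sum(dim=-1)          # 10,000 raw dot products
    scaled = dots / d_k ** 0.5          # scaled scores
    print(f"d_k={d_k:4d}  Var(q.k)={dots.var().item():8.1f}  "
          f"Var(q.k/sqrt(d_k))={scaled.var().item():5.2f}")
```

The raw variance tracks \(d_k\) while the scaled variance stays near 1, keeping softmax inputs in a well-behaved range at any head dimension.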
Multi-Head Mechanism
Single-head limitation: restricted to one representation subspace
Solution: Learn \(h\) different attention patterns in parallel:
\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\]
where:
\[\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]
Parameter matrices:
\(W_i^Q, W_i^K \in \mathbb{R}^{D \times d_k}\) where \(d_k = D/h\)
\(W_i^V \in \mathbb{R}^{D \times d_v}\) where \(d_v = D/h\)
\(W^O \in \mathbb{R}^{D \times D}\)
Key insight: Each head can focus on different aspects:
Head 1: Local patterns (nearby patches)
Head 2: Global structure
Head 3: Specific objects
etc.
Complexity Analysis
Computational cost:

| Component | Complexity | Explanation |
|---|---|---|
| \(QK^T\) | \(O(N^2 \cdot d_k)\) | Pairwise similarities |
| Softmax | \(O(N^2)\) | Normalize each row |
| Attention \(\times\) V | \(O(N^2 \cdot d_v)\) | Weighted sum |
| Total | \(\mathbf{O(N^2 \cdot D)}\) | \(N\) = seq length, \(D\) = embed dim |

Memory: \(O(h \cdot N^2)\) for attention maps
Quadratic in sequence length! Major bottleneck for long sequences.
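As a back-of-envelope check of the \(O(h \cdot N^2)\) memory term, here is the float32 attention-map footprint of a ViT-B/16 forward pass at 224×224 (batch size 1; all other activations are ignored in this sketch):

```python
# ViT-B/16 attention-map memory, per the O(h * N^2) term above.
depth, heads = 12, 12            # ViT-Base configuration
N = 14 * 14 + 1                  # 196 patches + [CLS] = 197 tokens
bytes_per_float = 4              # float32

per_layer = heads * N * N * bytes_per_float
total = depth * per_layer
print(f"One layer:  {per_layer / 1e6:.1f} MB of attention maps")
print(f"All layers: {total / 1e6:.1f} MB (batch size 1)")
```

Roughly 22 MB at this resolution; because the cost is quadratic in \(N\), doubling the input side length to 448 multiplies it by about 16.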
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim=768, num_heads=12, dropout=0.1):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim
self.qkv = nn.Linear(embed_dim, embed_dim * 3)
self.proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B, N, C = x.shape
# Generate Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, num_heads, N, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]
# Attention: Q @ K^T / sqrt(d_k)
attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn = F.softmax(attn, dim=-1)
attn = self.dropout(attn)
# Weighted sum of values
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.dropout(x)
return x, attn
# Test
mha = MultiHeadAttention(embed_dim=768, num_heads=12)
x = torch.randn(2, 197, 768) # 196 patches + 1 CLS token
out, attn = mha(x)
print(f"Output shape: {out.shape}")
print(f"Attention shape: {attn.shape}") # (B, num_heads, N, N)
# Visualize attention patterns
def visualize_attention_mechanism():
"""
Demonstrate how self-attention creates patch relationships
Shows query-key similarity and attention weights
"""
# Simulate simple 3x3 grid of patches (9 patches)
N = 9
D = 64
# Create simple patterns in patches
np.random.seed(42)
patches = torch.randn(1, N, D)
# Simple attention (single head for visualization)
Q = patches @ torch.randn(D, D)
K = patches @ torch.randn(D, D)
V = patches @ torch.randn(D, D)
# Compute attention scores
scores = (Q @ K.transpose(-2, -1)) / np.sqrt(D)
attn_weights = F.softmax(scores, dim=-1)
# Visualize
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# Plot 1: Patch grid
ax = axes[0]
grid = np.arange(9).reshape(3, 3)
im = ax.imshow(grid, cmap='tab10', vmin=0, vmax=9)
ax.set_title('Patch Layout (3×3 grid)', fontweight='bold', fontsize=13)
ax.set_xticks(range(3))
ax.set_yticks(range(3))
ax.grid(False)
# Add patch numbers
for i in range(3):
for j in range(3):
ax.text(j, i, str(grid[i, j]), ha='center', va='center',
fontsize=16, fontweight='bold', color='white')
plt.colorbar(im, ax=ax, label='Patch ID')
# Plot 2: Attention map (focus on center patch)
ax = axes[1]
center_patch = 4 # Center of 3x3 grid
attn_from_center = attn_weights[0, center_patch].detach().numpy().reshape(3, 3)
im = ax.imshow(attn_from_center, cmap='YlOrRd', vmin=0, vmax=attn_from_center.max())
ax.set_title(f'Attention Weights from Patch {center_patch} (center)',
fontweight='bold', fontsize=13)
ax.set_xticks(range(3))
ax.set_yticks(range(3))
# Add values
for i in range(3):
for j in range(3):
ax.text(j, i, f'{attn_from_center[i, j]:.3f}',
ha='center', va='center', fontsize=11)
plt.colorbar(im, ax=ax, label='Attention Weight')
# Plot 3: Full attention matrix
ax = axes[2]
attn_matrix = attn_weights[0].detach().numpy()
im = ax.imshow(attn_matrix, cmap='viridis', aspect='auto')
ax.set_xlabel('Key Patches', fontsize=12)
ax.set_ylabel('Query Patches', fontsize=12)
ax.set_title('Full Attention Matrix (9×9)', fontweight='bold', fontsize=13)
plt.colorbar(im, ax=ax, label='Attention')
# Add grid
ax.set_xticks(np.arange(N))
ax.set_yticks(np.arange(N))
ax.grid(True, which='both', color='white', linewidth=0.5, alpha=0.3)
plt.tight_layout()
plt.show()
print("\n" + "="*70)
print("SELF-ATTENTION INTERPRETATION")
print("="*70)
print(f"Attention from Patch {center_patch} (center):")
print(f" β’ Highest attention to itself: {attn_from_center.flat[center_patch]:.3f}")
print(f" β’ Neighbors get significant weight")
print(f" β’ Distant patches get lower weight")
print("\nKey Properties:")
print(f" β’ Each row sums to 1.0 (softmax normalization)")
print(f" β’ All patches attend to all others (global receptive field)")
print(f" β’ Unlike CNNs, no spatial locality bias (learned, not hardcoded)")
print("="*70)
visualize_attention_mechanism()
Transformer Encoder Block
Each encoder block applies multi-head self-attention followed by a position-wise feed-forward network, with layer normalization and residual connections wrapping both sub-layers. Self-attention computes \(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\), allowing every patch to attend to every other patch regardless of spatial distance. The feed-forward network (typically two linear layers with GELU activation) then processes each position independently, giving the model capacity to learn complex per-token transformations. Stacking multiple encoder blocks enables the network to build increasingly abstract representations of the image.
class TransformerBlock(nn.Module):
def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
self.norm2 = nn.LayerNorm(embed_dim)
# MLP
mlp_hidden = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_hidden),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden, embed_dim),
nn.Dropout(dropout)
)
def forward(self, x):
# Multi-head attention with residual
attn_out, attn_weights = self.attn(self.norm1(x))
x = x + attn_out
# MLP with residual
x = x + self.mlp(self.norm2(x))
return x, attn_weights
# Test
block = TransformerBlock()
x = torch.randn(2, 197, 768)
out, attn = block(x)
print(f"Block output shape: {out.shape}")
Complete Vision Transformer
The full ViT model chains together the patch embedding, positional encoding, a stack of Transformer encoder blocks, and a classification head. An image of size \(H \times W\) is divided into \(N = HW / P^2\) non-overlapping patches of size \(P \times P\), each linearly projected to a \(d\)-dimensional embedding. A learnable [CLS] token is prepended to the sequence; after passing through all encoder layers, its representation is fed to the MLP head for classification. This architecture achieves state-of-the-art results on image benchmarks when pre-trained on large datasets, demonstrating that the inductive biases of convolutions are not strictly necessary for visual understanding.
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=10,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1):
super().__init__()
# Patch embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
n_patches = self.patch_embed.n_patches
# CLS token and positional embedding
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
self.dropout = nn.Dropout(dropout)
# Transformer encoder
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
# Initialize
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
B = x.shape[0]
# Patch embedding
x = self.patch_embed(x)
# Add CLS token
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_token, x], dim=1)
# Add positional embedding
x = x + self.pos_embed
x = self.dropout(x)
# Transformer blocks
attentions = []
for block in self.blocks:
x, attn = block(x)
attentions.append(attn)
x = self.norm(x)
# Classification head (use CLS token)
logits = self.head(x[:, 0])
return logits, attentions
# Create smaller ViT for MNIST (ViT-Tiny)
model = VisionTransformer(
img_size=28,
patch_size=4,
in_channels=1,
num_classes=10,
embed_dim=192,
depth=6,
num_heads=3,
mlp_ratio=2.0
).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Test forward pass
test_img = torch.randn(4, 1, 28, 28).to(device)
logits, attns = model(test_img)
print(f"Logits shape: {logits.shape}")
print(f"Number of attention maps: {len(attns)}")
Training on MNIST
We train the Vision Transformer on MNIST as a lightweight proof of concept: the \(28 \times 28\) grayscale images are small enough for fast iteration while still exercising the full ViT pipeline. Each image is split into a \(7 \times 7\) grid of \(4 \times 4\)-pixel patches, embedded, and processed through the Transformer stack. Even with a small model and limited data, ViT can reach strong accuracy on MNIST, showing that attention-based architectures generalize well even to modest tasks. The training loop below uses cross-entropy loss with the AdamW optimizer and a cosine-annealing learning-rate schedule.
# Data loading
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)
# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
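The original ViT recipe additionally warms the learning rate up linearly before decaying it; if desired, warm-up can be folded into the schedule with `LambdaLR`. A sketch (the `warmup_epochs=2` value is an arbitrary choice for this short run, not taken from the paper):

```python
import math
import torch

def warmup_cosine(warmup_epochs, total_epochs):
    """Return a LambdaLR multiplier: linear warm-up, then cosine decay."""
    def fn(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs          # linear ramp
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return fn

# Demo on a dummy parameter so the schedule can be inspected in isolation
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=3e-4, weight_decay=0.01)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warmup_cosine(2, 10))

lrs = []
for _ in range(10):
    lrs.append(opt.param_groups[0]['lr'])
    opt.step()          # normally the forward/backward pass happens here
    sched.step()
print(" ".join(f"{lr:.2e}" for lr in lrs))
```

The learning rate climbs to the base value of 3e-4 over the first two epochs, then decays along a cosine curve.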
def train_epoch(model, loader, optimizer):
model.train()
total_loss = 0
correct = 0
total = 0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
logits, _ = model(images)
loss = F.cross_entropy(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
pred = logits.argmax(1)
correct += (pred == labels).sum().item()
total += labels.size(0)
return total_loss / len(loader), correct / total
@torch.no_grad()
def evaluate(model, loader):
model.eval()
correct = 0
total = 0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
logits, _ = model(images)
pred = logits.argmax(1)
correct += (pred == labels).sum().item()
total += labels.size(0)
return correct / total
# Training loop
n_epochs = 10
history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
for epoch in range(n_epochs):
train_loss, train_acc = train_epoch(model, train_loader, optimizer)
test_acc = evaluate(model, test_loader)
scheduler.step()
history['train_loss'].append(train_loss)
history['train_acc'].append(train_acc)
history['test_acc'].append(test_acc)
print(f"Epoch {epoch+1}/{n_epochs}: Loss={train_loss:.4f}, "
f"Train Acc={train_acc:.4f}, Test Acc={test_acc:.4f}")
print(f"\nFinal test accuracy: {history['test_acc'][-1]:.4f}")
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
ax1.plot(history['train_loss'], linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training Loss', fontsize=13)
ax1.grid(True, alpha=0.3)
ax2.plot(history['train_acc'], label='Train', linewidth=2)
ax2.plot(history['test_acc'], label='Test', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy', fontsize=12)
ax2.set_title('Classification Accuracy', fontsize=13)
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Visualizing Attention Maps
One of the most appealing properties of Vision Transformers is that attention weights provide a built-in interpretability mechanism. By extracting the attention matrix from one or more heads, we can visualize which patches each token attends to, revealing what the model considers important for its prediction. Early layers tend to show local attention patterns similar to convolutional receptive fields, while deeper layers exhibit long-range attention that spans the entire image. These attention maps are a practical debugging tool and help build trust in the model's predictions for downstream applications.
model.eval()
# Get one test image
img, label = test_dataset[0]
img_input = img.unsqueeze(0).to(device)
with torch.no_grad():
logits, attentions = model(img_input)
pred = logits.argmax(1).item()
# Visualize attention from last layer, first head, CLS token
last_attn = attentions[-1][0, 0, 0, 1:].reshape(7, 7).cpu().numpy() # 28/4 = 7 patches
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# Original image
axes[0].imshow(img.squeeze(), cmap='gray')
axes[0].set_title(f'Input Image (Label: {label})', fontsize=13)
axes[0].axis('off')
# Attention map
im = axes[1].imshow(last_attn, cmap='hot', interpolation='nearest')
axes[1].set_title('CLS Token Attention', fontsize=13)
axes[1].axis('off')
plt.colorbar(im, ax=axes[1], fraction=0.046)
# Overlay
from scipy.ndimage import zoom
attn_resized = zoom(last_attn, (28/7, 28/7), order=1)
axes[2].imshow(img.squeeze(), cmap='gray')
axes[2].imshow(attn_resized, cmap='hot', alpha=0.5)
axes[2].set_title(f'Attention Overlay (Pred: {pred})', fontsize=13)
axes[2].axis('off')
plt.tight_layout()
plt.show()
print("Attention highlights regions CLS token focuses on for classification")
Summary
Key Innovations of ViT:
Patch embedding - treat image as sequence of patches
No inductive bias - pure attention, no convolutions
CLS token - learnable classification token
Positional encoding - spatial relationships
ViT vs CNNs:

| Aspect | CNN | ViT |
|---|---|---|
| Inductive bias | Strong (locality, translation equivariance) | Weak |
| Data efficiency | Good (small datasets) | Poor (needs large data) |
| Global context | Limited (receptive field) | Excellent (self-attention) |
| Scalability | Moderate | Excellent |
| Interpretability | Feature maps | Attention maps |
When to Use ViT:
Large-scale datasets (ImageNet-21k, JFT-300M)
Transfer learning (pre-trained models)
Long-range dependencies matter
Interpretability via attention
Variants:
DeiT - Data-efficient training
Swin Transformer - Hierarchical, shifted windows
BEiT - Masked image modeling (BERT for vision)
MAE - Masked autoencoders
Next Steps:
18_bert_deep_dive.ipynb - Language transformers
20_efficient_transformers.ipynb - Attention optimization
CNN comparison in Phase 6 neural networks
6. CNNs vs Vision Transformers: Deep Comparison
Architectural Differences

| Aspect | CNNs | Vision Transformers |
|---|---|---|
| Receptive Field | Local → Global (gradual) | Global from layer 1 |
| Inductive Bias | Strong (locality, translation eq.) | Weak (learned from data) |
| Data Requirement | Low (works with ImageNet) | High (needs JFT-300M or strong aug) |
| Computation | \(O(N)\) per layer | \(O(N^2)\) per layer |
| Parameters | More efficient | More parameters needed |
Mathematical Comparison
CNN Convolution:
\[y_{i,j} = \sum_{m,n} w_{m,n} \cdot x_{i+m, j+n}\]
Local connectivity: Only neighbors within the kernel
Weight sharing: Same \(w\) across spatial locations
Translation equivariance: \(f(T(x)) = T(f(x))\)
ViT Self-Attention:
\[y_i = \sum_{j=1}^N \text{softmax}\left(\frac{q_i k_j^T}{\sqrt{d}}\right) v_j\]
Global connectivity: All patches communicate
Input-dependent weights: the projections are shared, but attention weights differ per position and per image
Permutation equivariance: spatial order must be injected via positional encoding
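The two equivariance claims can be verified numerically. The sketch below checks that a convolution commutes with (circular) shifts, while self-attention without positional encoding commutes with token permutations:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Translation equivariance of convolution: shifting the input
# (circularly, to avoid border effects) shifts the output identically.
conv = nn.Conv2d(1, 4, kernel_size=3, padding=1, padding_mode='circular')
x = torch.randn(1, 1, 8, 8)
shifted = torch.roll(x, shifts=2, dims=-1)
print(torch.allclose(conv(shifted), torch.roll(conv(x), 2, dims=-1), atol=1e-5))

# Permutation equivariance of self-attention (no positional encoding):
# permuting the token order just permutes the output rows.
attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
tokens = torch.randn(1, 9, 16)
perm = torch.randperm(9)
out, _ = attn(tokens, tokens, tokens)
out_p, _ = attn(tokens[:, perm], tokens[:, perm], tokens[:, perm])
print(torch.allclose(out_p, out[:, perm], atol=1e-5))
```

Both checks print `True`: the convolution carries spatial structure by construction, while attention only recovers it from the positional encoding added to the tokens.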
Inductive Bias Analysis
CNNs have strong built-in assumptions:
Locality: Features composed from nearby pixels
Good: Works well with limited data
Bad: May miss long-range dependencies
Translation Equivariance: Shift input → shift output
Good: Parameter efficient
Bad: Rigid structure
ViTs have minimal assumptions:
Flexibility: Learn structure from data
Good: Can discover optimal patterns
Bad: Requires massive data
Global attention: All patches interact
Good: Capture long-range dependencies easily
Bad: Quadratic complexity
Performance Characteristics
Data Regime:
Small Data (<100K images):
CNN >>> ViT (by large margin)
Medium Data (ImageNet, 1.3M):
CNN ≈ ViT (ViT needs strong augmentation)
Large Data (JFT-300M, 300M+):
ViT > CNN (ViT scales better)
Computational Efficiency:

| Model | ImageNet Top-1 | Parameters | FLOPs | Throughput |
|---|---|---|---|---|
| ResNet-50 | 76.5% | 25M | 4.1G | Fast |
| ViT-B/16 | 77.9% | 86M | 17.6G | Medium |
| ViT-L/16 | 76.5% | 307M | 61.6G | Slow |
When to Use Each
Use CNNs when:
Limited training data
Need fast inference
Working with small images
Resource-constrained deployment
Strong spatial structure (e.g., medical imaging)
Use ViTs when:
Access to large datasets
Pre-trained models available
Need best possible accuracy
Computational resources available
Transfer learning to diverse tasks
Hybrid Approaches
Best of both worlds:
Early Convolutions + Late Transformers
Use CNNs for early feature extraction
ViT for high-level reasoning
Example: Convolutional Vision Transformer (CvT)
Local Attention Windows
Restrict attention to spatial neighborhoods
Reduces complexity from \(O(N^2)\) to \(O(N)\)
Example: Swin Transformer
Pyramidal Architecture
Multi-scale feature hierarchy (like CNNs)
Transformer at each scale
Example: Pyramid Vision Transformer (PVT)
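The windowed-attention idea behind Swin can be sketched in a few lines: tokens are grouped into non-overlapping windows and attention is restricted to each window, shrinking the score count from \(N^2\) to \(N \cdot w^2\). This is a simplified partition without Swin's shifted offsets; `window=7` mirrors Swin's default:

```python
import torch

def window_partition(x, window):
    """Reshape (B, H, W, C) into (B * num_windows, window*window, C)
    so attention can run independently inside each window."""
    B, H, W, C = x.shape
    x = x.view(B, H // window, window, W // window, window, C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window * window, C)
    return x

x = torch.randn(1, 14, 14, 192)       # a 14x14 patch grid
wins = window_partition(x, 7)         # 4 windows of 49 tokens each
print(wins.shape)                     # torch.Size([4, 49, 192])

# Attention score counts: full vs windowed
full = (14 * 14) ** 2                 # 38,416
windowed = wins.shape[0] * 49 ** 2    # 9,604
print(full, windowed)
```

Here the saving is only 4×, but it grows linearly with the number of windows, which is what makes the approach scale to high-resolution inputs.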
Recent Advances
Efficient Transformers:
Swin Transformer: Shifted windows, \(O(N)\) complexity
Linformer: Linear attention approximation
Performer: Kernel-based attention, \(O(N)\)
Better ViT Training:
DeiT: Data-efficient training with distillation
CaiT: Class-attention in later layers
BEiT: BERT-style pre-training for images
Theoretical Understanding
Why do ViTs work?
Expressiveness: Can represent any function CNNs can (universal approximators)
Optimization: Self-attention provides multiple gradient paths
Scaling: Performance improves log-linearly with compute
Transfer: Learn general visual representations
Open Questions:
Why do ViTs need so much data initially?
What visual features do attention heads learn?
Can we design better inductive biases?