Transformers from Scratch: Self-Attention & Positional Encoding
Build the core transformer building blocks in pure PyTorch and demystify what's inside BERT, GPT, and every modern LLM.
1. The Problem with RNNs
Recurrent Neural Networks (RNNs/LSTMs) process sequences one token at a time:
h1 = f(x1, h0)
h2 = f(x2, h1)   ← can't compute until h1 is ready
h3 = f(x3, h2)   ← can't compute until h2 is ready
...
Problems:
Sequential: Can't parallelize across tokens → slow on GPUs
Vanishing gradients: Long-range dependencies are hard to learn
Fixed bottleneck: The entire sequence must be compressed into one hidden state
Transformers fix all three with self-attention: every token attends to every other token simultaneously.
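A minimal sketch of the contrast (toy dimensions and a made-up recurrence, not a real RNN cell): the recurrent loop must run step by step, while attention gets every token-pair score from a single matrix multiply.

```python
import torch

torch.manual_seed(0)
seq_len, d = 8, 16
x = torch.randn(seq_len, d)

# RNN-style: each hidden state depends on the previous one,
# so the loop cannot be parallelized across time steps.
W = torch.randn(d, d) * 0.1
h = torch.zeros(d)
for t in range(seq_len):
    h = torch.tanh(W @ h + x[t])

# Attention-style: all token-pair scores in one op,
# with no dependency between positions.
scores = x @ x.T  # (seq_len, seq_len)
print(h.shape, scores.shape)
```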
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
torch.manual_seed(42)
print(f"PyTorch version: {torch.__version__}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
2. Scaled Dot-Product Attention
The core equation:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
\]

Where:
Q (Query): "What am I looking for?"
K (Key): "What do I contain?"
V (Value): "What do I return if you match me?"
\(\sqrt{d_k}\): Scaling factor to prevent the softmax from saturating in high dimensions
The result: each token gets a weighted average of all values, where weights reflect how relevant each other token is.
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        Q: Query tensor (..., seq_len_q, d_k)
        K: Key tensor (..., seq_len_k, d_k)
        V: Value tensor (..., seq_len_k, d_v)
        mask: optional tensor broadcastable to (..., seq_len_q, seq_len_k);
              positions where mask == 0 are blocked

    Returns:
        output: (..., seq_len_q, d_v)
        attn_weights: (..., seq_len_q, seq_len_k)
    """
    d_k = Q.size(-1)
    # Attention scores: (..., seq_len_q, seq_len_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Apply mask (e.g., for causal/autoregressive attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax to get attention weights
    attn_weights = F.softmax(scores, dim=-1)
    # Weighted sum of values
    output = torch.matmul(attn_weights, V)
    return output, attn_weights
# Test with random tensors
batch_size = 2
seq_len = 6
d_k = 64
d_v = 64
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)
output, attn_weights = scaled_dot_product_attention(Q, K, V)
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
print(f"Output shape: {output.shape} (same as Q)")
print(f"Attention weights: {attn_weights.shape} (seq_len × seq_len matrix)")
print(f"Attn weights sum: {attn_weights[0].sum(dim=-1)} (rows sum to 1.0)")
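A quick standalone experiment motivating the \(\sqrt{d_k}\) factor: for unit-variance queries and keys, a dot product of d_k terms has variance roughly d_k, so raw scores grow with dimension and would push the softmax toward one-hot outputs; dividing by \(\sqrt{d_k}\) restores roughly unit variance.

```python
import torch

torch.manual_seed(0)
for d_k in [16, 256]:
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    raw = (q * k).sum(dim=-1)       # 10,000 sample dot products
    scaled = raw / d_k ** 0.5       # the sqrt(d_k) correction
    print(f"d_k={d_k:4d}  raw std ~ {raw.std():5.2f}  scaled std ~ {scaled.std():.2f}")
```

The raw standard deviation tracks \(\sqrt{d_k}\) (about 4 and 16 here), while the scaled version stays near 1 regardless of dimension.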
3. Multi-Head Attention
Instead of one attention, run h parallel attention heads each with smaller dimension d_model/h:
Input (batch, seq, d_model)
        ↓ split into h heads
Head 1: Q1, K1, V1 → Attention1 → output1
Head 2: Q2, K2, V2 → Attention2 → output2   (in parallel!)
...
Head h: Qh, Kh, Vh → Attentionh → outputh
        ↓ concatenate
Concat → Linear → final output (batch, seq, d_model)
Why? Each head can attend to different types of relationships (syntax, semantics, coreference, etc.).
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        """
        d_model: total embedding dimension
        n_heads: number of parallel attention heads
        d_k = d_model // n_heads (dimension per head)
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Linear projections for Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x):
        """Reshape from (batch, seq, d_model) to (batch, n_heads, seq, d_k)."""
        batch, seq, _ = x.shape
        x = x.view(batch, seq, self.n_heads, self.d_k)
        return x.transpose(1, 2)  # (batch, n_heads, seq, d_k)

    def forward(self, x, mask=None):
        batch, seq, _ = x.shape
        # Project to Q, K, V, then split into heads
        Q = self.split_heads(self.W_q(x))  # (batch, n_heads, seq, d_k)
        K = self.split_heads(self.W_k(x))
        V = self.split_heads(self.W_v(x))
        # Attention on all heads in parallel
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        # Merge heads: (batch, n_heads, seq, d_k) → (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq, self.d_model)
        # Final linear projection
        return self.W_o(attn_output), attn_weights
# Test
d_model = 128
n_heads = 8
mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads)
x = torch.randn(2, seq_len, d_model)
out, weights = mha(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape} (same shape as input!)")
print(f"Attn weights: {weights.shape} (batch, heads, seq, seq)")
print(f"d_k per head: {d_model // n_heads} (d_model / n_heads)")
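The split/merge reshapes are the easiest part to get wrong. This standalone check (mirroring split_heads and the merge in forward) confirms the round trip is lossless: no values ever cross between tokens.

```python
import torch

batch, seq, d_model, n_heads = 2, 6, 128, 8
d_k = d_model // n_heads
x = torch.randn(batch, seq, d_model)

# split: (batch, seq, d_model) → (batch, n_heads, seq, d_k)
heads = x.view(batch, seq, n_heads, d_k).transpose(1, 2)
# merge (as in forward): back to (batch, seq, d_model)
merged = heads.transpose(1, 2).contiguous().view(batch, seq, d_model)

print(torch.equal(x, merged))  # True: split/merge is a lossless round trip
```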
4. Positional Encoding
Self-attention is permutation invariant: it has no notion of token order. Positional encoding injects position information using sine and cosine functions:

\[
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
\]

The period of each sine/cosine increases exponentially with the dimension index, giving each position a unique "fingerprint" at every scale.
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Build positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # Exponentially spaced frequencies
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices: cos
        # Register as buffer (not a parameter, so it is not learned)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_seq_len, d_model)

    def forward(self, x):
        """x: (batch, seq_len, d_model)"""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
# Visualize positional encodings
d_model_vis = 64
max_len_vis = 50
pe_layer = PositionalEncoding(d_model=d_model_vis, max_seq_len=max_len_vis, dropout=0.0)
pe_matrix = pe_layer.pe[0].numpy() # (max_len, d_model)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Heatmap
im = axes[0].imshow(pe_matrix[:40, :32], cmap='RdBu', aspect='auto', vmin=-1, vmax=1)
axes[0].set_xlabel('Embedding dimension')
axes[0].set_ylabel('Position (token index)')
axes[0].set_title('Positional Encoding: Sine/Cosine Heatmap')
plt.colorbar(im, ax=axes[0])
# Individual curves
for dim in [0, 4, 10, 20]:
axes[1].plot(pe_matrix[:40, dim], label=f'dim {dim}')
axes[1].set_xlabel('Position (token index)')
axes[1].set_ylabel('PE value')
axes[1].set_title('PE values for different dimensions\n(lower dims: higher freq, higher dims: lower freq)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
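One property worth verifying (a standalone sketch that rebuilds the same pe matrix as the class above): per frequency, sin(a)sin(b) + cos(a)cos(b) = cos(a - b), so the dot product between two PE rows depends only on their relative offset, never on absolute position. This is what lets attention pick up relative-position structure.

```python
import math
import torch

d_model, max_len = 64, 100
position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float()
                     * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# Same offset (5) at two very different absolute positions
# gives the same dot product:
print(f"{(pe[10] @ pe[15]).item():.4f}  vs  {(pe[60] @ pe[65]).item():.4f}")
```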
5. Layer Norm + Residual Connections
Residual connections (x = x + sublayer(x)) mitigate vanishing gradients in deep networks: gradients can flow directly through the skip connection.
Layer Norm normalizes across the feature dimension (not the batch), making training stable across varying sequence lengths.
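A toy gradient check of the residual claim: even with a sublayer that contributes nothing (deliberately zeroed out here), the gradient still reaches the input at full strength through the identity path.

```python
import torch

x = torch.randn(5, requires_grad=True)
sublayer = lambda t: 0.0 * torch.tanh(t)  # degenerate sublayer: outputs all zeros

y = x + sublayer(x)   # residual form: x = x + sublayer(x)
y.sum().backward()
print(x.grad)         # all ones: the gradient flowed through the skip path
```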
class FeedForward(nn.Module):
    """Position-wise feed-forward network."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
# Demonstrate LayerNorm behavior (contrast with BatchNorm, which would
# normalize each feature across the batch using batch statistics)
x_demo = torch.randn(4, 10, 32)  # (batch=4, seq=10, features=32)
layer_norm = nn.LayerNorm(32)
ln_out = layer_norm(x_demo)
print(f"LayerNorm output: mean={ln_out[0,0].mean():.4f}, std={ln_out[0,0].std():.4f}")
print("  Normalized across features for EACH token independently")
print("  Works the same during training and inference")
print("  Works for variable sequence lengths (no batch statistics)")
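To make the normalization concrete, here is nn.LayerNorm reproduced by hand: per-token mean and (biased) variance over the feature dimension, then the learned affine transform (gamma=1, beta=0 at initialization).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 10, 32)
ln = nn.LayerNorm(32)

mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance, as LayerNorm uses
manual = (x - mean) / torch.sqrt(var + ln.eps)
manual = manual * ln.weight + ln.bias              # affine: gamma * x_hat + beta

print(torch.allclose(manual, ln(x), atol=1e-5))  # True
```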
6. The TransformerBlock
class TransformerBlock(nn.Module):
    """
    One transformer encoder block:
        x  → LayerNorm → MultiHeadAttention → + (residual) → x'
        x' → LayerNorm → FeedForward        → + (residual) → output

    Note: this "Pre-LN" variant (LN before each sublayer) is more stable
    than the original "Post-LN" arrangement.
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with Pre-LN + residual
        attn_out, attn_weights = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attn_out)  # residual connection
        # Feed-forward with Pre-LN + residual
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)  # residual connection
        return x, attn_weights
# Test the block
block = TransformerBlock(d_model=128, n_heads=8, d_ff=512, dropout=0.1)
x_in = torch.randn(2, 10, 128) # (batch=2, seq=10, d_model=128)
x_out, weights = block(x_in)
print(f"Input shape: {x_in.shape}")
print(f"Output shape: {x_out.shape} (same; the transformer block preserves shape)")
print(f"\nBlock parameters: {sum(p.numel() for p in block.parameters()):,}")
7. Mini Transformer Encoder
class TransformerEncoder(nn.Module):
    """
    Full transformer encoder: embedding + positional encoding + N transformer blocks.
    """
    def __init__(self, vocab_size, d_model, n_heads, d_ff, n_layers,
                 max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)  # final layer norm
        self.d_model = d_model

    def forward(self, token_ids, mask=None):
        """
        token_ids: (batch, seq_len) integer token IDs
        returns:   (batch, seq_len, d_model) contextualized embeddings
        """
        # Scale embeddings by sqrt(d_model) as in the original paper
        x = self.embedding(token_ids) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x, _ = layer(x, mask)
        return self.norm(x)
# Build a mini transformer encoder
mini_encoder = TransformerEncoder(
    vocab_size=1000,
    d_model=128,
    n_heads=8,
    d_ff=512,
    n_layers=4,
    max_seq_len=64,
)
# Test with random token IDs
token_ids = torch.randint(0, 1000, (2, 20)) # batch=2, seq=20
encoder_output = mini_encoder(token_ids)
print(f"Token IDs shape: {token_ids.shape}")
print(f"Encoder output shape: {encoder_output.shape}")
print(f"\nMini encoder params: {sum(p.numel() for p in mini_encoder.parameters()):,}")
8. Parameter Count vs BERT
def count_params(model):
    return sum(p.numel() for p in model.parameters())
configs = {
'Mini (ours)': dict(vocab_size=1000, d_model=128, n_heads=8, d_ff=512, n_layers=4),
'BERT-Tiny': dict(vocab_size=30522, d_model=128, n_heads=2, d_ff=512, n_layers=2),
'BERT-Small': dict(vocab_size=30522, d_model=512, n_heads=8, d_ff=2048, n_layers=4),
'BERT-Base': dict(vocab_size=30522, d_model=768, n_heads=12, d_ff=3072, n_layers=12),
'BERT-Large': dict(vocab_size=30522, d_model=1024, n_heads=16, d_ff=4096, n_layers=24),
}
print(f"{'Config':<15} {'d_model':>8} {'n_heads':>8} {'n_layers':>9} {'Parameters':>15}")
print('-' * 60)
for name, config in configs.items():
model = TransformerEncoder(**config, max_seq_len=512)
n = count_params(model)
print(f"{name:<15} {config['d_model']:>8} {config['n_heads']:>8} {config['n_layers']:>9} {n:>15,}")
print("\nNote: BERT-Base has ~110M parameters in the original paper;")
print("      our version counts slightly fewer due to its simplified architecture")
print("\nQuadratic scaling: doubling d_model gives ~4x more attention parameters")
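The scaling claim can be checked with back-of-envelope arithmetic (a standalone sketch matching our bias-free implementation; LayerNorm weights are ignored): per block, attention has 4 · d_model² weights (W_q, W_k, W_v, W_o) and the FFN has 2 · d_model · d_ff, and both terms are quadratic in d_model when d_ff scales with d_model (here d_ff = 4 · d_model, as in the configs above).

```python
def block_params(d_model, d_ff):
    attn = 4 * d_model * d_model   # W_q, W_k, W_v, W_o
    ffn = 2 * d_model * d_ff       # linear1 + linear2
    return attn + ffn

for d_model in [128, 256, 512]:
    print(f"d_model={d_model:4d}: {block_params(d_model, 4 * d_model):>11,} params/block")
```

Each doubling of d_model multiplies the per-block count by exactly 4.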
Exercises
1. Causal mask: Implement a causal (autoregressive) mask for the decoder. The mask should be a lower-triangular matrix of 1s so that position i can only attend to positions 0..i. Test that attention weights are zero for future positions.
2. Attention visualization: Pass the sentence "The cat sat on the mat" (as integer IDs) through your encoder. Plot the attention weight heatmap for each head in the first layer. Do different heads attend to different patterns?
3. Scaling law: Build models of increasing size (vary n_layers from 2 to 12). Plot parameter count vs n_layers. Is it linear? Quadratic? Where is the bottleneck?
4. Pre-LN vs Post-LN: Implement a Post-LN TransformerBlock (LN after each sublayer, not before). Train both on a toy task (e.g., a copy task: predict the same sequence as the input). Which converges faster? Which is more stable?
5. Learned positional encoding: Replace the sinusoidal PositionalEncoding with a learned nn.Embedding(max_seq_len, d_model). What are the tradeoffs? (Hint: think about sequences longer than max_seq_len at inference time.)