Transformers from Scratch: Self-Attention & Positional Encoding

Build the core transformer building blocks in pure PyTorch and demystify what's inside BERT, GPT, and every modern LLM.

1. The Problem with RNNs

Recurrent Neural Networks (RNNs/LSTMs) process sequences one token at a time:

h1 = f(x1, h0)
h2 = f(x2, h1)   ← can't compute until h1 is ready
h3 = f(x3, h2)   ← can't compute until h2 is ready
...

Problems:

  1. Sequential: Can't parallelize across tokens → slow on GPUs

  2. Vanishing gradients: Long-range dependencies are hard to learn

  3. Fixed bottleneck: The entire sequence must be compressed into one hidden state

Transformers fix all three with self-attention: every token attends to every other token simultaneously.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math

torch.manual_seed(42)
print(f"PyTorch version: {torch.__version__}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

2. Scaled Dot-Product Attention

The core equation:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]

Where:

  • Q (Query): "What am I looking for?"

  • K (Key): "What do I contain?"

  • V (Value): "What do I return if you match me?"

  • \(\sqrt{d_k}\): Scaling factor to prevent softmax from saturating in high dimensions

The result: each token gets a weighted average of all values, where weights reflect how relevant each other token is.

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Query tensor  (batch, seq_len_q, d_k)
        K: Key tensor    (batch, seq_len_k, d_k)
        V: Value tensor  (batch, seq_len_k, d_v)
        mask: Optional mask, broadcastable to (..., seq_len_q, seq_len_k); positions where mask == 0 are blocked
    
    Returns:
        output:   (batch, seq_len_q, d_v)
        attn_weights: (batch, seq_len_q, seq_len_k)
    """
    d_k = Q.size(-1)
    
    # Attention scores: (batch, seq_len_q, seq_len_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply mask (e.g., for causal/autoregressive attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Softmax to get attention weights
    attn_weights = F.softmax(scores, dim=-1)
    
    # Weighted sum of values
    output = torch.matmul(attn_weights, V)
    
    return output, attn_weights

# Test with random tensors
batch_size = 2
seq_len    = 6
d_k        = 64
d_v        = 64

Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)

output, attn_weights = scaled_dot_product_attention(Q, K, V)

print(f"Q shape:            {Q.shape}")
print(f"K shape:            {K.shape}")
print(f"V shape:            {V.shape}")
print(f"Output shape:       {output.shape}   (same as Q)")
print(f"Attention weights:  {attn_weights.shape}  (seq_len × seq_len matrix)")
print(f"Attn weights sum:   {attn_weights[0].sum(dim=-1)}  (rows sum to 1.0)")
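As a cross-check (an illustrative sketch, not part of the original notebook): PyTorch 2.0+ ships a fused `F.scaled_dot_product_attention` that computes the same math, so our hand-written version should agree with it numerically. This assumes a PyTorch version that provides that function.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(2, 6, 64)
K = torch.randn(2, 6, 64)
V = torch.randn(2, 6, 64)

# Reference computation: same math as our scaled_dot_product_attention above
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
ours = torch.matmul(F.softmax(scores, dim=-1), V)

# Fused built-in (PyTorch 2.0+), default scale is also 1/sqrt(d_k)
builtin = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(ours, builtin, atol=1e-4))
```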

3. Multi-Head Attention

Instead of a single attention operation, run h attention heads in parallel, each with smaller dimension d_model/h:

Input (batch, seq, d_model)
    ↓ split into h heads
Head 1: Q1, K1, V1 → Attention1 → output1
Head 2: Q2, K2, V2 → Attention2 → output2   (in parallel!)
...
Head h: Qh, Kh, Vh → Attentionh → outputh
    ↓ concatenate
Concat → Linear → final output (batch, seq, d_model)

Why? Each head can attend to different types of relationships (syntax, semantics, coreference, etc.).

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        """
        d_model: total embedding dimension
        n_heads: number of parallel attention heads
        d_k = d_model // n_heads (dimension per head)
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k     = d_model // n_heads
        
        # Linear projections for Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
    
    def split_heads(self, x):
        """Reshape from (batch, seq, d_model) to (batch, n_heads, seq, d_k)"""
        batch, seq, _ = x.shape
        x = x.view(batch, seq, self.n_heads, self.d_k)
        return x.transpose(1, 2)  # (batch, n_heads, seq, d_k)
    
    def forward(self, x, mask=None):
        batch, seq, _ = x.shape
        
        # Project to Q, K, V then split into heads
        Q = self.split_heads(self.W_q(x))  # (batch, n_heads, seq, d_k)
        K = self.split_heads(self.W_k(x))
        V = self.split_heads(self.W_v(x))
        
        # Attention on all heads in parallel
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # Merge heads: (batch, n_heads, seq, d_k) → (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq, self.d_model)
        
        # Final linear projection
        return self.W_o(attn_output), attn_weights

# Test
d_model = 128
n_heads = 8
mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads)

x = torch.randn(2, seq_len, d_model)
out, weights = mha(x)

print(f"Input shape:   {x.shape}")
print(f"Output shape:  {out.shape}  (same shape as input!)")
print(f"Attn weights:  {weights.shape}  (batch, heads, seq, seq)")
print(f"d_k per head:  {d_model // n_heads} (d_model / n_heads)")
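One detail worth checking by hand (an illustrative sketch): the parameter count of `MultiHeadAttention` is 4·d_model², independent of `n_heads`, because splitting into heads is only a reshape and adds no weights.

```python
import torch.nn as nn

d_model = 128
# MultiHeadAttention above holds four bias-free d_model x d_model projections
# (W_q, W_k, W_v, W_o); head splitting is a reshape and adds no parameters
projs = [nn.Linear(d_model, d_model, bias=False) for _ in range(4)]
n_params = sum(p.numel() for proj in projs for p in proj.parameters())
print(n_params, 4 * d_model ** 2)  # both 65536, for any n_heads
```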

4. Positional Encoding

Self-attention is permutation invariant: it has no notion of order. Positional encoding injects position information using sine and cosine functions:

\[PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)\]
\[PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)\]

The period of each sine/cosine increases exponentially, giving each position a unique "fingerprint" at every scale.

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Build positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        
        # Exponentially spaced frequencies
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices: cos
        
        # Register as buffer (not a parameter, so it is not learned)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_seq_len, d_model)
    
    def forward(self, x):
        """x: (batch, seq_len, d_model)"""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
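The `div_term` exp/log trick in the class above is just a numerically convenient way to write the power from the PE equations; a quick check (illustrative):

```python
import math
import torch

d_model = 64
two_i = torch.arange(0, d_model, 2).float()  # the even dimension indices 2i
# the exp/log form used in PositionalEncoding above
div_term = torch.exp(two_i * (-math.log(10000.0) / d_model))
# the direct form from the equations: 1 / 10000^(2i / d_model)
direct = 1.0 / (10000.0 ** (two_i / d_model))
print(torch.allclose(div_term, direct, rtol=1e-4))
```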

# Visualize positional encodings
d_model_vis = 64
max_len_vis = 50
pe_layer = PositionalEncoding(d_model=d_model_vis, max_seq_len=max_len_vis, dropout=0.0)
pe_matrix = pe_layer.pe[0].numpy()  # (max_len, d_model)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap
im = axes[0].imshow(pe_matrix[:40, :32], cmap='RdBu', aspect='auto', vmin=-1, vmax=1)
axes[0].set_xlabel('Embedding dimension')
axes[0].set_ylabel('Position (token index)')
axes[0].set_title('Positional Encoding: Sine/Cosine Heatmap')
plt.colorbar(im, ax=axes[0])

# Individual curves
for dim in [0, 4, 10, 20]:
    axes[1].plot(pe_matrix[:40, dim], label=f'dim {dim}')
axes[1].set_xlabel('Position (token index)')
axes[1].set_ylabel('PE value')
axes[1].set_title('PE values for different dimensions\n(lower dims: higher freq, higher dims: lower freq)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

5. Layer Norm + Residual Connections

Residual connections (x = x + sublayer(x)) solve vanishing gradients in deep networks: gradients can flow directly through the skip connection.

Layer Norm normalizes across the feature dimension (not batch), making training stable across varying sequence lengths.
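What LayerNorm computes can be verified by hand; a minimal sketch (at initialization the affine weight is 1 and the bias 0, so plain per-token normalization matches `nn.LayerNorm` exactly):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 5, 32)   # (batch, seq, features)
ln = nn.LayerNorm(32)       # fresh layer: weight = 1, bias = 0

# normalize each token's feature vector independently
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + ln.eps)
print(torch.allclose(ln(x), manual, atol=1e-5))
```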

class FeedForward(nn.Module):
    """Position-wise feed-forward network."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

# Demonstrate LayerNorm vs BatchNorm behavior
x_demo = torch.randn(4, 10, 32)  # (batch=4, seq=10, features=32)

layer_norm = nn.LayerNorm(32)
batch_norm = nn.BatchNorm1d(32)  # expects (batch, features, seq)

ln_out = layer_norm(x_demo)
bn_out = batch_norm(x_demo.transpose(1, 2)).transpose(1, 2)
print(f"LayerNorm output: mean={ln_out[0,0].mean():.4f}, std={ln_out[0,0].std():.4f}")
print(f"  Normalized across features for EACH token independently")
print(f"  Works the same during training and inference")
print(f"  Works for variable sequence lengths (no batch statistics)")
print(f"BatchNorm output: mean={bn_out[:,:,0].mean():.4f}, std={bn_out[:,:,0].std():.4f}")
print(f"  Normalized per feature across batch and positions (needs batch statistics)")

6. The TransformerBlock

class TransformerBlock(nn.Module):
    """
    One transformer encoder block:
      x → LayerNorm → MultiHeadAttention → + (residual) → x'
      x' → LayerNorm → FeedForward → + (residual) → output
    
    Note: "Pre-LN" variant (LN before sublayer) is more stable than original "Post-LN".
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.ffn       = FeedForward(d_model, d_ff, dropout)
        self.norm1     = nn.LayerNorm(d_model)
        self.norm2     = nn.LayerNorm(d_model)
        self.dropout   = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention with Pre-LN + residual
        attn_out, attn_weights = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attn_out)  # residual connection
        
        # Feed-forward with Pre-LN + residual
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)   # residual connection
        
        return x, attn_weights

# Test the block
block = TransformerBlock(d_model=128, n_heads=8, d_ff=512, dropout=0.1)
x_in = torch.randn(2, 10, 128)  # (batch=2, seq=10, d_model=128)
x_out, weights = block(x_in)

print(f"Input shape:  {x_in.shape}")
print(f"Output shape: {x_out.shape}  (same: the transformer block preserves shape)")
print(f"\nBlock parameters: {sum(p.numel() for p in block.parameters()):,}")

7. Mini Transformer Encoder

class TransformerEncoder(nn.Module):
    """
    Full transformer encoder: embedding + positional encoding + N transformer blocks.
    """
    def __init__(self, vocab_size, d_model, n_heads, d_ff, n_layers, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)  # final layer norm
        self.d_model = d_model
    
    def forward(self, token_ids, mask=None):
        """
        token_ids: (batch, seq_len) integer token IDs
        returns:   (batch, seq_len, d_model) contextualized embeddings
        """
        # Scale embeddings by sqrt(d_model) as in the original paper
        x = self.embedding(token_ids) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        for layer in self.layers:
            x, _ = layer(x, mask)
        
        return self.norm(x)

# Build a mini transformer encoder
mini_encoder = TransformerEncoder(
    vocab_size  = 1000,
    d_model     = 128,
    n_heads     = 8,
    d_ff        = 512,
    n_layers    = 4,
    max_seq_len = 64
)

# Test with random token IDs
token_ids = torch.randint(0, 1000, (2, 20))  # batch=2, seq=20
encoder_output = mini_encoder(token_ids)

print(f"Token IDs shape:     {token_ids.shape}")
print(f"Encoder output shape: {encoder_output.shape}")
print(f"\nMini encoder params: {sum(p.numel() for p in mini_encoder.parameters()):,}")
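The `mask` argument also supports padding masks (distinct from the causal mask of Exercise 1). A minimal sketch, assuming PAD id 0 (an assumption for illustration): a (batch, 1, 1, seq_len) mask broadcasts against the (batch, n_heads, seq_len, seq_len) score tensor, zeroing attention to padded key positions.

```python
import torch

token_ids = torch.tensor([[5, 7, 2, 0, 0],
                          [3, 9, 4, 6, 1]])            # 0 = assumed PAD id
pad_mask = (token_ids != 0).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)

scores = torch.randn(2, 8, 5, 5)                       # (batch, heads, seq_q, seq_k)
scores = scores.masked_fill(pad_mask == 0, float('-inf'))
weights = torch.softmax(scores, dim=-1)
print(weights[0, 0, 0])  # last two entries (padded keys) are exactly 0
```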

8. Parameter Count vs BERT

def count_params(model):
    return sum(p.numel() for p in model.parameters())

configs = {
    'Mini (ours)': dict(vocab_size=1000,  d_model=128,  n_heads=8,  d_ff=512,  n_layers=4),
    'BERT-Tiny':   dict(vocab_size=30522, d_model=128,  n_heads=2,  d_ff=512,  n_layers=2),
    'BERT-Small':  dict(vocab_size=30522, d_model=512,  n_heads=8,  d_ff=2048, n_layers=4),
    'BERT-Base':   dict(vocab_size=30522, d_model=768,  n_heads=12, d_ff=3072, n_layers=12),
    'BERT-Large':  dict(vocab_size=30522, d_model=1024, n_heads=16, d_ff=4096, n_layers=24),
}

print(f"{'Config':<15} {'d_model':>8} {'n_heads':>8} {'n_layers':>9} {'Parameters':>15}")
print('-' * 60)
for name, config in configs.items():
    model = TransformerEncoder(**config, max_seq_len=512)
    n = count_params(model)
    print(f"{name:<15} {config['d_model']:>8} {config['n_heads']:>8} {config['n_layers']:>9} {n:>15,}")

print("\nNote: BERT-Base has ~110M params in the original paper; our simplified")
print("      encoder counts fewer (no token-type embeddings, pooler, or attention biases)")
print("\nThe quadratic scaling: doubling d_model → 4x more attention params")
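A quick arithmetic check of that scaling claim (illustrative sketch):

```python
def attn_param_count(d_model):
    # W_q, W_k, W_v, W_o: four bias-free d_model x d_model matrices
    return 4 * d_model * d_model

print(attn_param_count(256) // attn_param_count(128))  # 4: doubling d_model -> 4x
```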

Exercises

  1. Causal mask: Implement a causal (autoregressive) mask for the decoder. The mask should be a lower-triangular matrix of 1s so that position i can only attend to positions 0..i. Test that attention weights are zero for future positions.

  2. Attention visualization: Pass the sentence "The cat sat on the mat" (as integer IDs) through your encoder. Plot the attention weight heatmap for each head in the first layer. Do different heads attend to different patterns?

  3. Scaling law: Build models of increasing size (vary n_layers from 2 to 12). Plot parameter count vs n_layers. Is it linear? Quadratic? Where is the bottleneck?

  4. Pre-LN vs Post-LN: Implement a Post-LN TransformerBlock (LN after attention, not before). Train both on a toy task (e.g., copy task: predict the same sequence as input). Which converges faster? Which is more stable?

  5. Learned positional encoding: Replace the sinusoidal PositionalEncoding with a learned nn.Embedding(max_seq_len, d_model). What are the tradeoffs? (Hint: think about sequences longer than max_seq_len at inference time.)