import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional
torch.manual_seed(42)
np.random.seed(42)
print("✅ Imports successful!")
Before Attention: The Bottleneck Problem
Traditional Sequence Models (RNNs):
Input: [The, cat, sat, on, the, mat]
         ↓    ↓    ↓    ↓    ↓    ↓
RNN:    h1 → h2 → h3 → h4 → h5 → h6 → Final hidden state
                                            ↑
                                       Bottleneck!
Problems:
All information compressed into a fixed-size vector
Long-range dependencies are lost
Sequential processing (can't parallelize)
Attention Solution:
Input: [The, cat, sat, on, the, mat]
         ↓    ↓    ↓    ↓    ↓    ↓
        h1   h2   h3   h4   h5   h6
         ↓    ↓    ↓    ↓    ↓    ↓
              Attention
                  ↓
     Weighted sum of ALL inputs!
Benefits:
Access to all inputs directly
Learns what to focus on
Can process in parallel
2. Scaled Dot-Product Attention
The core attention mechanism:
$\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\)$
Where:
Q (Query): What am I looking for?
K (Key): What do I have to offer?
V (Value): The actual information
\(d_k\): Dimension of keys (for scaling)
Intuition
Think of it like a database lookup:
Query: "Show me documents about 'cats'"
Keys: Document titles/descriptions
Values: Actual document content
Attention weights: Similarity between query and each key
Output: Weighted average of values based on similarity
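The lookup analogy can be made concrete in a few lines of NumPy. The keys, values, and query below are invented purely for illustration:

```python
import numpy as np

# Toy "database": three keys with associated scalar values (all numbers invented)
keys = np.array([[1.0, 0.0],    # key for "cats" documents
                 [0.0, 1.0],    # key for "finance" documents
                 [0.7, 0.7]])   # key for mixed content
values = np.array([10.0, 20.0, 15.0])

query = np.array([1.0, 0.1])    # a query that mostly matches the "cats" key

scores = keys @ query                             # similarity of query to each key
weights = np.exp(scores) / np.exp(scores).sum()   # softmax -> retrieval probabilities
output = weights @ values                         # weighted average of the values

print(weights.round(3))    # highest weight lands on the first ("cats") key
print(round(float(output), 2))
```

Note the output is pulled toward the value of the best-matching key but still blends in the others, which is exactly the "soft lookup" behavior described above.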
Attention Mechanism: Mathematical Foundations
1. Why Attention? The Information Bottleneck Problem
In sequence-to-sequence tasks with RNNs/LSTMs, all source information must pass through a fixed-size context vector, typically the final hidden state:
$\(\mathbf{c} = \mathbf{h}_T\)$
Problem: For long sequences (\(T\) large), \(\mathbf{c}\) becomes an information bottleneck.
Solution (Bahdanau et al., 2015): Dynamically compute the context as a weighted sum:
$\(\mathbf{c}_t = \sum_{i=1}^{T} \alpha_{ti} \mathbf{h}_i\)$
where \(\alpha_{ti}\) measures relevance of source position \(i\) to target position \(t\).
2. Attention Mechanism Derivation
Original Additive Attention (Bahdanau)
Alignment score between query \(\mathbf{q}\) and key \(\mathbf{k}_i\):
$\(e_i = \mathbf{v}^T \tanh(\mathbf{W}_1 \mathbf{q} + \mathbf{W}_2 \mathbf{k}_i)\)$
Attention weights via softmax:
$\(\alpha_i = \frac{\exp(e_i)}{\sum_{j=1}^{T} \exp(e_j)}\)$
Context vector:
$\(\mathbf{c} = \sum_{i=1}^{T} \alpha_i \mathbf{h}_i\)$
Complexity: \(O(T \cdot d^2)\) per query due to two linear transformations.
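A minimal PyTorch sketch of the additive scoring function defined above; the dimensions and the hidden size `d_hidden` are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

class AdditiveAttention(nn.Module):
    """Bahdanau-style scoring: e_i = v^T tanh(W1 q + W2 k_i)."""
    def __init__(self, d: int, d_hidden: int = 32):
        super().__init__()
        self.W1 = nn.Linear(d, d_hidden, bias=False)
        self.W2 = nn.Linear(d, d_hidden, bias=False)
        self.v = nn.Linear(d_hidden, 1, bias=False)

    def forward(self, q, keys, values):
        # q: (batch, d); keys, values: (batch, T, d)
        e = self.v(torch.tanh(self.W1(q).unsqueeze(1) + self.W2(keys))).squeeze(-1)  # (batch, T)
        alpha = torch.softmax(e, dim=-1)                            # attention weights
        context = torch.bmm(alpha.unsqueeze(1), values).squeeze(1)  # (batch, d)
        return context, alpha

attn = AdditiveAttention(d=16)
q, h = torch.randn(2, 16), torch.randn(2, 5, 16)
context, alpha = attn(q, h, h)
print(context.shape, alpha.shape)  # torch.Size([2, 16]) torch.Size([2, 5])
```

The two linear maps `W1` and `W2` are the source of the \(O(T \cdot d^2)\) cost noted above.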
Dot-Product Attention (Luong et al., 2015)
Simplified alignment:
$\(e_i = \mathbf{q}^T \mathbf{k}_i\)$
Advantage: More efficient (\(O(T \cdot d)\)), leverages matrix multiplication hardware.
Issue: For large \(d\), dot products grow large in magnitude, pushing softmax into saturation regions where gradients vanish.
Scaled Dot-Product Attention (Vaswani et al., 2017)
Key insight: Variance of dot product scales with dimension.
If \(\mathbf{q}, \mathbf{k} \sim \mathcal{N}(0, \mathbf{I})\), then:
$\(\mathbb{E}[\mathbf{q}^T \mathbf{k}] = 0, \quad \text{Var}(\mathbf{q}^T \mathbf{k}) = d_k\)$
Solution: Scale by \(\sqrt{d_k}\) to maintain unit variance:
$\(\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}\)$
Mathematical justification:
Variance stabilization: \(\text{Var}\left(\frac{\mathbf{q}^T \mathbf{k}}{\sqrt{d_k}}\right) = 1\)
Softmax gradient: For input \(x\) to softmax: $\(\frac{\partial}{\partial x_i} \text{softmax}(x)_j = \text{softmax}(x)_j (\delta_{ij} - \text{softmax}(x)_i)\)$
When \(|x_i|\) is large, softmax saturates (\(\approx 0\) or \(1\)), killing gradients.
Empirical observation: For \(d_k = 512\), unscaled dot products have std \(\approx 22\), causing gradient vanishing.
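The \(\sqrt{d_k}\) claim is easy to verify empirically; the vectors are random, so the measured standard deviation is only approximate:

```python
import torch
torch.manual_seed(0)

d_k = 512
q = torch.randn(10000, d_k)
k = torch.randn(10000, d_k)

dots = (q * k).sum(dim=-1)      # unscaled dot products, std should be ~sqrt(512) ~ 22.6
scaled = dots / d_k ** 0.5      # scaled dot products, std should be ~1

print(f"unscaled std: {dots.std():.1f}")
print(f"scaled std:   {scaled.std():.2f}")
```

Feeding values with std around 22 into a softmax drives almost all probability mass onto one position, which is exactly the saturation regime described above.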
3. Multi-Head Attention: Ensemble of Attention Functions
Instead of a single attention function with large \(d_{model}\), use \(h\) parallel heads with smaller dimensions:
$\(\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}^O\)$
where each head:
$\(\text{head}_i = \text{Attention}(\mathbf{Q}\mathbf{W}_i^Q, \mathbf{K}\mathbf{W}_i^K, \mathbf{V}\mathbf{W}_i^V)\)$
Dimensions:
\(\mathbf{W}_i^Q, \mathbf{W}_i^K \in \mathbb{R}^{d_{model} \times d_k}\) where \(d_k = d_{model}/h\)
\(\mathbf{W}_i^V \in \mathbb{R}^{d_{model} \times d_v}\) where \(d_v = d_{model}/h\)
\(\mathbf{W}^O \in \mathbb{R}^{hd_v \times d_{model}}\)
Computational complexity: Same as single-head (\(O(T^2 d_{model})\)) due to dimension reduction.
Theoretical benefits:
Representation subspaces: Each head learns different aspects:
Head 1: Syntactic dependencies (subject-verb agreement)
Head 2: Semantic similarity
Head 3: Positional patterns
Gradient flow: Multiple paths for backpropagation → better optimization
Ensemble effect: Reduces variance, similar to model ensembles
Empirical analysis (Voita et al., 2019):
Some heads are position-based (attend to nearby tokens)
Others are syntax-based (attend to specific grammatical relations)
~20-30% of heads can be pruned without performance loss
4. Self-Attention vs. Cross-Attention
Self-Attention
All three matrices (\(\mathbf{Q}, \mathbf{K}, \mathbf{V}\)) are derived from the same input \(\mathbf{X}\):
$\(\mathbf{Q} = \mathbf{X}\mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}^K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}^V\)$
Use case: Encoder in Transformers, captures relationships within sequence.
Properties:
Attention matrix \(\mathbf{A} \in \mathbb{R}^{T \times T}\) is often sparse (most attention to nearby tokens)
Complexity: \(O(T^2 d)\) - quadratic in sequence length
Cross-Attention
Query from one sequence \(\mathbf{Y}\), keys/values from another \(\mathbf{X}\):
$\(\mathbf{Q} = \mathbf{Y}\mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}^K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}^V\)$
Use case: Decoder in Transformers, aligns target with source.
Example: Machine translation
Source: "The cat sat" → \(\mathbf{K}, \mathbf{V}\)
Target: "Le chat" → \(\mathbf{Q}\)
Attention: "Le" attends to "The", "chat" attends to "cat"
5. Causal (Masked) Self-Attention
For autoregressive generation, prevent attending to future tokens:
$\(\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} + \mathbf{M}\right)\mathbf{V}\)$
Mask \(\mathbf{M}\): $\(M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}\)$
Effect: Position \(i\) can only attend to positions \(\leq i\).
Implementation trick: Use \(-10^9\) instead of \(-\infty\) to avoid numerical issues.
6. Attention Complexity and Efficient Variants
Standard Attention Complexity
Time: \(O(T^2 d)\)
Matrix multiplication \(\mathbf{Q}\mathbf{K}^T\): \(O(T^2 d)\)
Softmax: \(O(T^2)\)
Weighted sum: \(O(T^2 d)\)
Space: \(O(T^2)\) for attention matrix
Bottleneck: Quadratic in sequence length → limits practical \(T\) to roughly 512–2048
Efficient Attention Variants
| Method | Complexity | Approximation | Use Case |
|---|---|---|---|
| Linformer | \(O(Td)\) | Low-rank projection of \(\mathbf{K}, \mathbf{V}\) | Long documents |
| Reformer | \(O(T \log T)\) | Locality-sensitive hashing | Very long sequences |
| Performer | \(O(Td)\) | Kernel approximation (random features) | Efficient transformers |
| Linear Attention | \(O(Td^2)\) | Kernel trick: \(\phi(\mathbf{q})^T \phi(\mathbf{k})\) | Real-time processing |
| Flash Attention | \(O(T^2 d)\) | Tiling + recomputation (memory-efficient) | Training speedup |
Linear Attention (Katharopoulos et al., 2020):
Replace the softmax similarity with a kernel:
$\(\text{sim}(\mathbf{q}, \mathbf{k}_i) = \phi(\mathbf{q})^T \phi(\mathbf{k}_i)\)$
where \(\phi(\mathbf{x}) = \text{elu}(\mathbf{x}) + 1\).
Key insight: Compute \(\phi(\mathbf{K})^T \mathbf{V}\) first (order matters):
Normal: \((\mathbf{Q}\mathbf{K}^T)\mathbf{V}\) requires \(O(T^2 d)\)
Linear: \(\mathbf{Q}(\mathbf{K}^T \mathbf{V})\) requires \(O(Td^2)\) if \(d \ll T\)
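The reordering is just associativity of matrix multiplication: for the unnormalized kernel scores the two orders give identical results. The ReLU-based feature map below is a stand-in for \(\phi\), chosen only for illustration:

```python
import torch
torch.manual_seed(0)

T, d = 64, 8
phi_Q = torch.relu(torch.randn(T, d)) + 1e-3  # positive "feature maps" (illustrative)
phi_K = torch.relu(torch.randn(T, d)) + 1e-3
V = torch.randn(T, d)

# O(T^2 d): materialize the T x T similarity matrix explicitly
out_quadratic = (phi_Q @ phi_K.T) @ V
# O(T d^2): contract keys with values first; the T x T matrix never exists
out_linear = phi_Q @ (phi_K.T @ V)

print(torch.allclose(out_quadratic, out_linear, atol=1e-4))  # True
```

With softmax this reordering is impossible, because the row-wise normalization couples every score in a row; that is precisely why the kernel replacement is needed.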
7. Attention as Differentiable Key-Value Memory
Perspective: Attention is a soft dictionary lookup.
Hard lookup (traditional): $\(\text{output} = \mathbf{V}[\arg\max_i \text{sim}(\mathbf{q}, \mathbf{k}_i)]\)$
Soft lookup (attention): $\(\text{output} = \sum_{i=1}^T \underbrace{\text{softmax}(\text{sim}(\mathbf{q}, \mathbf{k}_i))}_{\text{retrieval probability}} \mathbf{v}_i\)$
Advantages of soft:
Differentiable → end-to-end training
Graceful degradation (distributes among top-\(k\) instead of single match)
Interpolation between memories
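The differentiability advantage can be shown directly: gradients flow through the soft lookup to every key, while an argmax lookup cuts the gradient path. The tensors below are toy values for illustration:

```python
import torch
torch.manual_seed(0)

keys = torch.randn(5, 4, requires_grad=True)
values = torch.randn(5, 3)
query = torch.randn(4)

# Soft lookup: softmax weights keep the whole computation differentiable
weights = torch.softmax(keys @ query, dim=0)
soft_out = weights @ values
soft_out.sum().backward()
print(keys.grad.abs().sum() > 0)   # gradients reach every key

# Hard lookup: argmax is piecewise constant, so no gradient flows through it
best = torch.argmax(keys.detach() @ query)
hard_out = values[best]
print(hard_out.requires_grad)      # False
```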
Connection to Hopfield Networks: Modern continuous Hopfield networks with softmax energy have update rule equivalent to attention (Ramsauer et al., 2020).
8. Positional Encodings: Why and How
Problem: Attention is permutation-equivariant: for any permutation matrix \(\mathbf{P}\),
$\(\text{Attention}(\mathbf{P}\mathbf{X}) = \mathbf{P}\,\text{Attention}(\mathbf{X})\)$
Order doesn't matter → "cat sat mat" and "mat sat cat" get the same (permuted) representations.
Solution: Add position information to input embeddings.
Sinusoidal Positional Encoding (Vaswani et al., 2017)
$\(\text{PE}(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \quad \text{PE}(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)\)$
Properties:
Bounded: \(\text{PE} \in [-1, 1]\)
Unique: Each position has unique encoding
Relative positions: \(\text{PE}(pos + k)\) is a linear function of \(\text{PE}(pos)\): \(\text{PE}(pos + k) = \mathbf{T}_k\,\text{PE}(pos)\) for some transformation matrix \(\mathbf{T}_k\).
Extrapolation: Can generalize to longer sequences than training
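A compact NumPy implementation of these encodings (assumes an even `d_model`):

```python
import numpy as np

def sinusoidal_pe(max_len: int, d_model: int) -> np.ndarray:
    """PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d))."""
    pos = np.arange(max_len)[:, None]        # (max_len, 1)
    i = np.arange(0, d_model, 2)[None, :]    # (1, d_model/2)
    angles = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)   # even dimensions
    pe[:, 1::2] = np.cos(angles)   # odd dimensions
    return pe

pe = sinusoidal_pe(max_len=100, d_model=64)
print(pe.shape)                              # (100, 64)
print(pe.min() >= -1.0 and pe.max() <= 1.0)  # bounded in [-1, 1]
```

Because the function is defined for any `pos`, the same code produces encodings for sequences longer than anything seen in training, which is the extrapolation property listed above.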
Learned Positional Embeddings
Alternative: Learn position embeddings as parameters added to the token embeddings:
$\(\mathbf{x}_i = \mathbf{e}_i + \mathbf{p}_i, \quad \mathbf{p}_i \in \mathbb{R}^{d_{model}}, \; i \leq T_{max}\)$
Trade-off:
More flexible (can learn task-specific patterns)
Cannot extrapolate beyond \(T_{max}\)
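A sketch of the learned alternative using `nn.Embedding`; the class name and sizes are arbitrary. Note how the embedding table itself imposes the hard \(T_{max}\) limit:

```python
import torch
import torch.nn as nn

class LearnedPositionalEmbedding(nn.Module):
    """Adds a learned position vector p_i to each token embedding (sketch)."""
    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        self.pos_embed = nn.Embedding(max_len, d_model)  # one learnable vector per position

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model); fails if seq_len > max_len (no extrapolation)
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.pos_embed(positions)

pe = LearnedPositionalEmbedding(max_len=512, d_model=64)
x = torch.randn(2, 10, 64)
print(pe(x).shape)  # torch.Size([2, 10, 64])
```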
Relative Positional Encodings (Shaw et al., 2018)
Modify attention scores to encode relative distances:
$\(e_{ij} = \frac{\mathbf{q}_i^T (\mathbf{k}_j + \mathbf{r}_{i-j})}{\sqrt{d_k}}\)$
where \(\mathbf{r}_{i-j}\) is learned relative position embedding.
Advantage: Directly models "attend to the token 3 positions ahead" patterns.
9. Attention Interpretability: What Do Models Learn?
Attention Weights ≠ Explanation (Jain & Wallace, 2019)
Caution: A high attention weight doesn't necessarily mean "importance".
Counterexample: Model can:
Distribute attention uniformly
Make decision in feed-forward layers
Attention becomes post-hoc rationalization
Better measures:
Gradient-based: \(\frac{\partial \text{output}}{\partial \text{input}_i}\)
Erasure: Remove token, measure output change
Attention rollout: Multiply attention across layers
Typical Attention Patterns (Clark et al., 2019)
Positional patterns: Attend to next/previous token
Syntactic patterns: Attend to head of phrase
Rare attention: Some heads attend to special tokens ([CLS], [SEP])
Broadcasting: One token (often first) aggregates information
Observation: Lower layers learn position/syntax, higher layers learn semantics.
10. Theoretical Properties and Guarantees
Universal Approximation with Attention
Theorem (Yun et al., 2020): Transformers with \(O(\log T)\) layers can approximate any sequence-to-sequence function with bounded smoothness.
Comparison: RNNs require \(O(T)\) layers for same guarantee.
Attention as Set Function
Self-attention is a permutation-equivariant set function:
$\(\text{SelfAttn}(\pi(\mathbf{X})) = \pi(\text{SelfAttn}(\mathbf{X}))\)$
for any permutation \(\pi\).
Implication: Transformers are suitable for set-based tasks (e.g., point clouds, graphs) with positional encodings removed.
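The equivariance property can be checked numerically with a projection-free self-attention. This is a simplification for the check: real layers add learned projections, which do not change the property:

```python
import torch
torch.manual_seed(0)

def self_attn(X):
    # scaled dot-product self-attention with Q = K = V = X (no learned projections)
    d = X.size(-1)
    A = torch.softmax(X @ X.T / d ** 0.5, dim=-1)
    return A @ X

X = torch.randn(6, 8)
perm = torch.randperm(6)

permute_after = self_attn(X)[perm]    # run attention, then permute the rows
permute_before = self_attn(X[perm])   # permute the rows, then run attention

print(torch.allclose(permute_after, permute_before, atol=1e-5))  # True
```

Adding positional encodings breaks this symmetry on purpose, which is why removing them makes Transformers usable on unordered sets.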
Expressiveness vs. RNNs
Transformers can simulate RNNs (PΓ©rez et al., 2019):
With appropriate attention patterns, Transformers can compute any RNN
Converse not true: RNNs cannot efficiently simulate Transformers
Reason: \(O(1)\) depth for parallel composition vs. \(O(T)\) sequential composition.
# Advanced Attention Implementations
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple
# ============================================================
# 1. Multi-Head Attention with Full Details
# ============================================================
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention as in 'Attention Is All You Need'.
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Linear projections for Q, K, V
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# Output projection
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""Split last dimension into (num_heads, d_k)."""
batch_size, seq_len, d_model = x.shape
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2) # (batch, num_heads, seq_len, d_k)
def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of multi-head attention.
Args:
Q, K, V: (batch, seq_len, d_model)
mask: broadcastable to (batch, num_heads, seq_len, seq_len), e.g. (batch, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len)
Returns:
output: (batch, seq_len, d_model)
attention_weights: (batch, num_heads, seq_len, seq_len)
"""
batch_size = Q.size(0)
# 1. Linear projections
Q = self.W_q(Q) # (batch, seq_len, d_model)
K = self.W_k(K)
V = self.W_v(V)
# 2. Split into multiple heads
Q = self.split_heads(Q) # (batch, num_heads, seq_len, d_k)
K = self.split_heads(K)
V = self.split_heads(V)
# 3. Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# 4. Apply attention to values
context = torch.matmul(attention_weights, V) # (batch, num_heads, seq_len, d_k)
# 5. Concatenate heads
context = context.transpose(1, 2).contiguous() # (batch, seq_len, num_heads, d_k)
context = context.view(batch_size, -1, self.d_model) # (batch, seq_len, d_model)
# 6. Output projection
output = self.W_o(context)
return output, attention_weights
# ============================================================
# 2. Causal (Masked) Attention for Autoregressive Models
# ============================================================
def create_causal_mask(seq_len: int, device: str = 'cpu') -> torch.Tensor:
"""
Create causal mask for autoregressive attention.
Returns:
mask: (seq_len, seq_len) lower triangular matrix
"""
mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
return mask
# ============================================================
# 3. Linear (Efficient) Attention
# ============================================================
class LinearAttention(nn.Module):
"""
Linear complexity attention using kernel trick.
Complexity: O(T * d^2) instead of O(T^2 * d)
"""
def __init__(self, d_model: int, num_heads: int = 8):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def feature_map(self, x: torch.Tensor) -> torch.Tensor:
"""Kernel feature map: φ(x) = elu(x) + 1"""
return F.elu(x) + 1
def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
"""
Linear attention: φ(Q)(φ(K)^T V) / φ(Q)(φ(K)^T 1)
Args:
Q, K, V: (batch, seq_len, d_model)
Returns:
output: (batch, seq_len, d_model)
"""
batch_size, seq_len, _ = Q.shape
# Linear projections
Q = self.W_q(Q)
K = self.W_k(K)
V = self.W_v(V)
# Apply feature map
Q = self.feature_map(Q)  # φ(Q)
K = self.feature_map(K)  # φ(K)
# Efficient computation: (φ(K)^T V) first
# This is key: changes complexity from O(T^2 d) to O(T d^2)
KV = torch.einsum('bnd,bnm->bdm', K, V)  # (batch, d_model, d_model)
# Normalization: φ(K)^T 1
K_sum = K.sum(dim=1, keepdim=True)  # (batch, 1, d_model)
# Output: φ(Q) (φ(K)^T V) / φ(Q) (φ(K)^T 1)
numerator = torch.einsum('bnd,bdm->bnm', Q, KV)  # (batch, seq_len, d_model)
denominator = torch.einsum('bnd,bkd->bn', Q, K_sum).unsqueeze(-1)  # (batch, seq_len, 1)
output = numerator / (denominator + 1e-6)
output = self.W_o(output)
return output
# ============================================================
# 4. Demonstration and Comparison
# ============================================================
# Setup
d_model = 64
num_heads = 8
seq_len = 10
batch_size = 2
# Create sample inputs
X = torch.randn(batch_size, seq_len, d_model)
# Standard Multi-Head Attention
mha = MultiHeadAttention(d_model, num_heads)
mha_output, attention_weights = mha(X, X, X)
print("="*70)
print("MULTI-HEAD ATTENTION")
print("="*70)
print(f"Input shape: {X.shape}")
print(f"Output shape: {mha_output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print(f"  → {num_heads} heads, each with ({seq_len} × {seq_len}) attention matrix")
# Causal Attention
causal_mask = create_causal_mask(seq_len).unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
mha_causal_output, causal_attention = mha(X, X, X, mask=causal_mask)
print("\n" + "="*70)
print("CAUSAL ATTENTION")
print("="*70)
print("Causal mask (position i can only attend to j ≤ i):")
print(causal_mask[0, 0].numpy().astype(int))
print(f"\nCausal attention output shape: {mha_causal_output.shape}")
# Linear Attention
linear_attn = LinearAttention(d_model, num_heads)
linear_output = linear_attn(X, X, X)
print("\n" + "="*70)
print("LINEAR ATTENTION")
print("="*70)
print(f"Output shape: {linear_output.shape}")
print(f"Complexity: O(T·d²) = O({seq_len}·{d_model}²) = {seq_len * d_model**2:,}")
print(f"  vs Standard: O(T²·d) = O({seq_len}²·{d_model}) = {seq_len**2 * d_model:,}")
print(f"Cost ratio (standard/linear) = T/d = {seq_len / d_model:.2f}; linear attention pays off once T >> d")
# ============================================================
# 5. Visualization: Attention Patterns
# ============================================================
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Attention Mechanism Analysis', fontsize=16, fontweight='bold')
# Plot attention weights for different heads
for i in range(min(3, num_heads)):
ax = axes[0, i]
attn_map = attention_weights[0, i].detach().numpy()
sns.heatmap(attn_map, annot=False, cmap='viridis', cbar=True, ax=ax,
vmin=0, vmax=attn_map.max())
ax.set_title(f'Head {i+1} Attention Pattern')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
# Causal attention visualization
ax = axes[1, 0]
causal_map = causal_attention[0, 0].detach().numpy()
sns.heatmap(causal_map, annot=False, cmap='RdYlGn', cbar=True, ax=ax)
ax.set_title('Causal Attention (Head 1)')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
# Average attention across heads
ax = axes[1, 1]
avg_attention = attention_weights[0].mean(dim=0).detach().numpy()
sns.heatmap(avg_attention, annot=False, cmap='Blues', cbar=True, ax=ax)
ax.set_title('Average Attention Across All Heads')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
# Attention statistics
ax = axes[1, 2]
ax.axis('off')
# Compute statistics
max_attention_per_head = attention_weights[0].max(dim=-1)[0].mean(dim=-1)
entropy = -torch.sum(attention_weights[0] * torch.log(attention_weights[0] + 1e-9), dim=-1).mean(dim=-1)
stats_text = f"""
ATTENTION STATISTICS (Batch 1)
Per Head:
Max attention weight:
{', '.join([f'{x:.3f}' for x in max_attention_per_head[:4].tolist()])}...
Entropy (nats):
{', '.join([f'{x:.2f}' for x in entropy[:4].tolist()])}...
Global:
Total parameters: {sum(p.numel() for p in mha.parameters()):,}
- W_q, W_k, W_v: {d_model}×{d_model} each
- W_o: {d_model}×{d_model}
Complexity:
Time: O(T² · d) = O({seq_len}² · {d_model})
Space: O(T²) for attention matrix
Observations:
• Different heads learn different patterns
• Causal mask enforces left-to-right flow
• High entropy → diffuse attention
• Low entropy → focused attention
"""
ax.text(0.1, 0.5, stats_text, fontsize=10, verticalalignment='center',
family='monospace', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))
plt.tight_layout()
plt.show()
print("\n" + "="*70)
print("VISUALIZATION COMPLETE")
print("="*70)
print("✓ Multi-head attention learns diverse patterns")
print("✓ Causal mask prevents future information leakage")
print("✓ Linear attention provides O(T·d²) complexity alternative")
print("="*70)
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Scaled Dot-Product Attention
Args:
Q: Query matrix (batch_size, seq_len, d_k)
K: Key matrix (batch_size, seq_len, d_k)
V: Value matrix (batch_size, seq_len, d_v)
mask: Optional mask (batch_size, seq_len, seq_len)
Returns:
output: Attention output (batch_size, seq_len, d_v)
attention_weights: Attention weights (batch_size, seq_len, seq_len)
"""
d_k = Q.size(-1)
# 1. Compute attention scores: QΒ·K^T
scores = torch.matmul(Q, K.transpose(-2, -1)) # (batch, seq, seq)
# 2. Scale by sqrt(d_k) to prevent softmax saturation
scores = scores / np.sqrt(d_k)
# 3. Apply mask (optional) - set masked positions to -inf
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 4. Apply softmax to get attention weights
attention_weights = F.softmax(scores, dim=-1) # (batch, seq, seq)
# 5. Weighted sum of values
output = torch.matmul(attention_weights, V) # (batch, seq, d_v)
return output, attention_weights
# Example: Simple attention
batch_size = 1
seq_len = 4
d_k = 8
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)
output, attention_weights = scaled_dot_product_attention(Q, K, V)
print("Input shapes:")
print(f" Q: {Q.shape}")
print(f" K: {K.shape}")
print(f" V: {V.shape}")
print(f"\nOutput shapes:")
print(f" Output: {output.shape}")
print(f" Attention weights: {attention_weights.shape}")
print(f"\nAttention weights (row i = how much position i attends to each position):")
print(attention_weights[0].detach().numpy())
print(f"\n✅ Each row sums to 1: {attention_weights[0].sum(dim=-1)}")
Visualizing Attention
Attention weights form a matrix where entry \((i, j)\) tells us how much position \(i\) attends to position \(j\). Plotting this matrix as a heatmap reveals the relationships the model has learned: in a well-trained language model, you would see pronouns attending strongly to their antecedents, verbs attending to their subjects, and so on. The softmax normalization ensures each row sums to 1, making every row a probability distribution over source positions. Visualizing these patterns is one of the key advantages of attention over opaque recurrent architectures: it provides built-in interpretability.
def visualize_attention(attention_weights, words=None):
"""
Visualize attention weights as a heatmap
"""
# Get attention matrix (seq_len, seq_len)
att = attention_weights[0].detach().numpy()
# Default labels
if words is None:
words = [f"Token {i+1}" for i in range(att.shape[0])]
# Plot heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(att, annot=True, fmt='.2f', cmap='Blues',
xticklabels=words, yticklabels=words,
cbar_kws={'label': 'Attention Weight'})
plt.xlabel('Keys (attending to)')
plt.ylabel('Queries (attending from)')
plt.title('Attention Weights Heatmap')
plt.tight_layout()
plt.show()
# Visualize
visualize_attention(attention_weights, words=['The', 'cat', 'sat', 'down'])
3. Self-Attention Example: Understanding Context
How Words Inform Each Other
In self-attention, the queries, keys, and values all come from the same sequence. Each token asks "which other tokens in this sentence are relevant to me?" and aggregates their representations accordingly. For instance, in the sentence "The animal didn't cross the street because it was too tired," self-attention helps the model determine that "it" refers to "animal" rather than "street" by assigning a high attention weight between those positions. The example below constructs a small vocabulary, assigns random embeddings, and computes self-attention so you can inspect the resulting weight matrix.
class SelfAttention(nn.Module):
"""
Self-Attention Layer
Learns to create Q, K, V from the same input
"""
def __init__(self, embed_dim):
super().__init__()
# Linear projections for Q, K, V
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
self.embed_dim = embed_dim
def forward(self, x, mask=None):
"""
Args:
x: Input (batch_size, seq_len, embed_dim)
mask: Optional mask
Returns:
output: Attention output
attention_weights: Attention weights
"""
# Project to Q, K, V
Q = self.query(x)
K = self.key(x)
V = self.value(x)
# Apply attention
output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
return output, attention_weights
# Example: Sentence with ambiguous word
# "The bank can guarantee deposits will eventually cover future tuition costs"
# Word "bank" - is it financial institution or river bank?
# Simulate word embeddings
embed_dim = 16
seq_len = 5
words = ['The', 'bank', 'guarantees', 'deposits', 'money']
# Create embeddings (in real NLP, these come from word2vec, BERT, etc.)
embeddings = torch.randn(1, seq_len, embed_dim)
# Apply self-attention
self_attn = SelfAttention(embed_dim)
output, attention_weights = self_attn(embeddings)
print(f"Input shape: {embeddings.shape}")
print(f"Output shape: {output.shape}")
print(f"\nAttention pattern for 'bank' (position 1):")
print(f" Attends to 'The': {attention_weights[0, 1, 0]:.3f}")
print(f" Attends to 'bank': {attention_weights[0, 1, 1]:.3f}")
print(f" Attends to 'guarantees': {attention_weights[0, 1, 2]:.3f}")
print(f" Attends to 'deposits': {attention_weights[0, 1, 3]:.3f}")
print(f" Attends to 'money': {attention_weights[0, 1, 4]:.3f}")
# Visualize
visualize_attention(attention_weights, words=words)
4. Multi-Head Attention
Idea: Run multiple attention mechanisms in parallel, each learning different patterns!
$\(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\)$
where: $\(\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\)$
Benefits:
Different heads can focus on different aspects
One head might learn syntax, another semantics
More expressive power
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention Layer
"""
def __init__(self, embed_dim, num_heads):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Linear projections for all heads (combined)
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
# Output projection
self.out = nn.Linear(embed_dim, embed_dim)
def forward(self, x, mask=None):
"""
Args:
x: Input (batch_size, seq_len, embed_dim)
Returns:
output: Multi-head attention output
attention_weights: List of attention weights per head
"""
batch_size, seq_len, embed_dim = x.size()
# 1. Linear projections
Q = self.query(x) # (batch, seq, embed_dim)
K = self.key(x)
V = self.value(x)
# 2. Split into multiple heads
# Reshape: (batch, seq, embed_dim) → (batch, seq, num_heads, head_dim)
# Transpose: → (batch, num_heads, seq, head_dim)
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 3. Apply attention for each head
d_k = self.head_dim
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
attention_output = torch.matmul(attention_weights, V)
# 4. Concatenate heads
# Transpose: (batch, num_heads, seq, head_dim) → (batch, seq, num_heads, head_dim)
attention_output = attention_output.transpose(1, 2).contiguous()
# Reshape: (batch, seq, embed_dim)
attention_output = attention_output.view(batch_size, seq_len, embed_dim)
# 5. Final linear projection
output = self.out(attention_output)
return output, attention_weights
# Example
embed_dim = 64
num_heads = 8
seq_len = 10
batch_size = 2
mha = MultiHeadAttention(embed_dim, num_heads)
x = torch.randn(batch_size, seq_len, embed_dim)
output, attention_weights = mha(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print(f" (batch_size, num_heads, seq_len, seq_len)")
print(f"\nNumber of parameters: {sum(p.numel() for p in mha.parameters()):,}")
Visualizing Different Heads
Each attention head in a multi-head attention layer learns to focus on a different linguistic or structural pattern. One head might capture syntactic dependencies (subject-verb agreement), another might track positional proximity, and a third might specialize in coreference resolution. By visualizing the attention weight matrices of individual heads side by side, we can observe this specialization directly. The diversity of patterns across heads is what makes multi-head attention more expressive than a single attention mechanism with the same total dimensionality.
def visualize_multi_head_attention(attention_weights, num_heads=4):
"""
Visualize attention patterns from different heads
"""
fig, axes = plt.subplots(2, num_heads // 2, figsize=(15, 6))
axes = axes.ravel()
for i in range(num_heads):
att = attention_weights[0, i].detach().numpy() # (seq_len, seq_len)
sns.heatmap(att, annot=False, cmap='viridis', ax=axes[i],
cbar=True, square=True)
axes[i].set_title(f'Head {i+1}')
axes[i].set_xlabel('Key')
axes[i].set_ylabel('Query')
plt.tight_layout()
plt.show()
print("\nObservations:")
print("  • Different heads show different patterns")
print("  • Some heads might focus on local context")
print("  • Others might capture long-range dependencies")
print("  • This diversity makes the model more expressive!")
# Visualize
visualize_multi_head_attention(attention_weights, num_heads=8)
5. Masked Attention - For Autoregressive Models
Use case: Language models that generate text one word at a time
Problem: When predicting word \(i\), we can't see words \(i+1, i+2, ...\) (future words)
Solution: Mask future positions
def create_causal_mask(seq_len):
"""
Create causal (lower triangular) mask
Prevents attending to future positions
"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask
# Example
seq_len = 5
mask = create_causal_mask(seq_len)
print("Causal mask:")
print(mask.numpy())
print("\n1 = can attend, 0 = cannot attend (future positions)")
# Visualize
plt.figure(figsize=(6, 5))
sns.heatmap(mask.numpy(), annot=True, fmt='g', cmap='Blues',
cbar=False, square=True)
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Causal Mask\n(Autoregressive Generation)')
plt.tight_layout()
plt.show()
print("\n💡 Usage:")
print(" Position 0 can only see position 0")
print(" Position 1 can see positions 0-1")
print(" Position 2 can see positions 0-2")
print(" etc.")
# Apply masked attention
embed_dim = 16
seq_len = 5
words = ['Once', 'upon', 'a', 'time', '<predict>']
x = torch.randn(1, seq_len, embed_dim)
mask = create_causal_mask(seq_len).unsqueeze(0)  # (1, seq, seq), broadcastable against the (batch, seq, seq) scores
# Self-attention with mask
self_attn = SelfAttention(embed_dim)
output, attention_weights = self_attn(x, mask=mask)
print("Masked attention weights:")
visualize_attention(attention_weights, words=words)
print("\n✅ Notice: the upper triangle is all zeros (can't attend to the future)!")
6. Cross-Attention - Connecting Two Sequences
Use case: Machine translation, image captioning
Difference from self-attention:
Self-attention: Q, K, V all from same sequence
Cross-attention: Q from one sequence, K and V from another
Example: Translating English to French
Q: French words being generated
K, V: English words (source sentence)
class CrossAttention(nn.Module):
"""
Cross-Attention Layer
Query from one sequence, Key and Value from another
"""
def __init__(self, embed_dim):
super().__init__()
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
def forward(self, target, source, mask=None):
"""
Args:
target: Target sequence (batch, target_len, embed_dim)
source: Source sequence (batch, source_len, embed_dim)
Returns:
output: Cross-attention output
attention_weights: Attention from target to source
"""
# Q from target, K and V from source
Q = self.query(target)
K = self.key(source)
V = self.value(source)
output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
return output, attention_weights
# Example: Translation
embed_dim = 16
source_len = 4 # English: "I love cats"
target_len = 3 # French: "J'aime les" (generating "chats")
english_words = ['I', 'love', 'cats', '<EOS>']
french_words = ["J'aime", 'les', '<predict>']
source = torch.randn(1, source_len, embed_dim) # English embeddings
target = torch.randn(1, target_len, embed_dim) # French embeddings
cross_attn = CrossAttention(embed_dim)
output, attention_weights = cross_attn(target, source)
print(f"Source (English): {source.shape}")
print(f"Target (French): {target.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights: {attention_weights.shape}")
print(" (Each French word attends to all English words)")
# Visualize
plt.figure(figsize=(8, 6))
att = attention_weights[0].detach().numpy()
sns.heatmap(att, annot=True, fmt='.2f', cmap='Greens',
xticklabels=english_words, yticklabels=french_words,
cbar_kws={'label': 'Attention Weight'})
plt.xlabel('Source (English)')
plt.ylabel('Target (French)')
plt.title('Cross-Attention: French → English')
plt.tight_layout()
plt.show()
print("\n💡 Interpretation:")
print(" Row 0: When generating \"J'aime\", which English words to focus on?")
print(" Row 1: When generating \"les\", which English words to focus on?")
print(" Row 2: When generating next word, which English words to focus on?")
7. Practical Application: Sequence Classification
Putting Attention to Work
Attention mechanisms are not just theoretical curiosities; they are the core component of virtually every modern NLP system. In this section we build a small sequence classifier that uses an attention layer to weigh the importance of each token before making a prediction. The attention output is pooled (typically by averaging or taking the first token) and fed into a linear classifier. This architecture mirrors how models like BERT produce sentence-level predictions by attending over all token representations and then projecting the pooled output to class logits.
class AttentionClassifier(nn.Module):
"""
Simple classifier using self-attention
"""
def __init__(self, vocab_size, embed_dim, num_heads, num_classes):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.attention = MultiHeadAttention(embed_dim, num_heads)
self.fc = nn.Linear(embed_dim, num_classes)
def forward(self, x):
"""
Args:
x: Token indices (batch, seq_len)
Returns:
logits: Class logits (batch, num_classes)
"""
# Embed tokens
x = self.embedding(x) # (batch, seq, embed)
# Apply attention
x, _ = self.attention(x) # (batch, seq, embed)
# Pool: take mean across sequence
x = x.mean(dim=1) # (batch, embed)
# Classify
logits = self.fc(x) # (batch, num_classes)
return logits
# Example
vocab_size = 1000
embed_dim = 64
num_heads = 4
num_classes = 2 # Binary classification
model = AttentionClassifier(vocab_size, embed_dim, num_heads, num_classes)
# Dummy batch
batch_size = 4
seq_len = 20
x = torch.randint(0, vocab_size, (batch_size, seq_len))
logits = model(x)
print(f"Input shape: {x.shape}")
print(f"Output logits: {logits.shape}")
print(f"\nPredictions: {torch.argmax(logits, dim=1)}")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
Summary
✅ What You Learned
Attention Mechanism: Revolutionary approach to sequence modeling
Scaled Dot-Product: Core attention computation
Self-Attention: Learning relationships within a sequence
Multi-Head Attention: Multiple attention patterns in parallel
Masked Attention: For autoregressive generation
Cross-Attention: Connecting two different sequences
Applications: Classification, translation, generation
Key Formulas
Attention: $\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\)$
Multi-Head: $\(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O\)$
💡 Key Insights
Q, K, V: Think of them as query, database keys, and database values
Softmax: Ensures attention weights sum to 1 (weighted average)
Scaling: \(\sqrt{d_k}\) prevents softmax saturation
Multi-head: Different heads learn different relationships
Self-attention: All positions can attend to all positions
Cross-attention: Connect encoder and decoder
🎯 What's Next?
Next notebook: 05_transformer_architecture.ipynb
You'll learn:
Complete Transformer architecture
Positional encoding
Encoder and Decoder stacks
Training transformers
Fine-tuning pre-trained models (BERT, GPT)
Additional Resources
Incredible work! You now understand the attention mechanism that powers modern NLP models like BERT, GPT, and T5! 🎯