import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
# Run on GPU when available; every tensor/module below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using: {device}")
## 1. The Attention Complexity Problem

### Standard Attention

Complexity:
- Time: \(O(n^2 d)\) where \(n\) = sequence length
- Space: \(O(n^2)\) for the attention matrix

Problem: quadratic scaling prevents long sequences!

### Solutions
| Method    | Complexity       | Idea                 |
|-----------|------------------|----------------------|
| Linformer | \(O(nk)\)        | Low-rank projection  |
| Performer | \(O(nd^2)\)      | Kernel approximation |
| Reformer  | \(O(n\log n)\)   | LSH attention        |
| Sparse    | \(O(n\sqrt{n})\) | Fixed patterns       |
Reference Materials:
transformer.pdf - Transformer
# Benchmark standard attention
def standard_attention(Q, K, V):
    """Vanilla scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V.

    Args:
        Q, K, V: (..., seq_len, d_k) query/key/value tensors.
    Returns:
        Tensor of the same shape as V's leading dims with d_k features.
    """
    d_k = Q.size(-1)
    logits = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
    weights = F.softmax(logits, dim=-1)
    return torch.matmul(weights, V)
# Measure complexity
# Time standard attention at several sequence lengths (10-run average)
# to visualize the quadratic growth in n.
seq_lengths = [128, 256, 512, 1024]
times = []
d_model = 64
batch_size = 4
for n in seq_lengths:
    # Fresh random Q/K/V per length, moved to the selected device.
    Q = torch.randn(batch_size, n, d_model).to(device)
    K = torch.randn(batch_size, n, d_model).to(device)
    V = torch.randn(batch_size, n, d_model).to(device)
    start = time.time()
    for _ in range(10):
        _ = standard_attention(Q, K, V)
    # NOTE(review): on CUDA, kernel launches are asynchronous; without
    # torch.cuda.synchronize() these timings may understate the GPU work.
    times.append((time.time() - start) / 10)
# Plot wall-clock time vs. sequence length.
plt.figure(figsize=(8, 5))
plt.plot(seq_lengths, times, marker='o', label='Standard Attention')
plt.xlabel('Sequence Length', fontsize=11)
plt.ylabel('Time (s)', fontsize=11)
plt.title('Attention Complexity', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
2. Linformer: Low-Rank AttentionΒΆ
Key InsightΒΆ
Attention matrix is often low-rank!
MethodΒΆ
Project \(K, V\) to lower dimension \(k \ll n\):

\[\tilde{K} = E K, \qquad \tilde{V} = F V\]

where \(E, F \in \mathbb{R}^{k \times n}\).
Complexity: \(O(nkd)\) instead of \(O(n^2d)\)!
class LinformerAttention(nn.Module):
    """Linformer self-attention: O(n*k*d) instead of O(n^2*d).

    Keys and values are projected along the *sequence* axis down to k
    positions via learned matrices E and F before attention is computed.

    Fix vs. the original version: E/F were applied as full
    ``nn.Linear(max_len, k)`` layers, which raises a shape error whenever
    the actual sequence length T != max_len (e.g. T=128 with max_len=512).
    The projection weights are now sliced to the first T columns — the
    standard Linformer implementation trick — so any T <= max_len works.
    """

    def __init__(self, d_model, n_heads, max_len=512, k=64):
        """
        Args:
            d_model: embedding dimension
            n_heads: number of attention heads
            max_len: maximum supported sequence length
            k: compressed sequence length, k << max_len
        """
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.k = k
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Low-rank projections along the sequence dimension
        self.E = nn.Linear(max_len, k, bias=False)
        self.F = nn.Linear(max_len, k, bias=False)

    def forward(self, x):
        """
        Args:
            x: (B, T, d_model) with T <= max_len
        Returns:
            (B, T, d_model)
        Raises:
            ValueError: if T exceeds max_len.
        """
        B, T, C = x.size()
        if T > self.E.in_features:
            raise ValueError(
                f"sequence length {T} exceeds max_len {self.E.in_features}"
            )
        qkv = self.qkv(x)
        q, k, v = qkv.split(self.d_model, dim=2)
        # Reshape to (B, heads, T, d_k)
        q = q.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        # Project K, V from T positions down to k positions, slicing the
        # learned projection weights to the actual sequence length.
        E_w = self.E.weight[:, :T]  # (k, T)
        F_w = self.F.weight[:, :T]  # (k, T)
        k = F.linear(k.transpose(-2, -1), E_w).transpose(-2, -1)  # (B, h, k, d_k)
        v = F.linear(v.transpose(-2, -1), F_w).transpose(-2, -1)  # (B, h, k, d_k)
        # Attention over the compressed keys/values: scores are (B, h, T, k).
        scores = q @ k.transpose(-2, -1) / np.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
# Test
# Smoke-test: the output must keep the input's (batch, seq, d_model) shape.
linformer = LinformerAttention(d_model=64, n_heads=4, max_len=512, k=64).to(device)
x = torch.randn(2, 128, 64).to(device)
out = linformer(x)
print(f"Output shape: {out.shape}")
3. Performer: Kernel ApproximationΒΆ
Key InsightΒΆ
Rewrite attention without explicit softmax:
Kernel TrickΒΆ
Approximate: \(\exp(q^\top k) \approx \phi(q)^\top \phi(k)\)
using random features \(\phi: \mathbb{R}^d \to \mathbb{R}^r\).
Fast ComputationΒΆ
Complexity: \(O(nrd)\) with \(r \ll n\)!
class PerformerAttention(nn.Module):
    """Performer-style attention with positive random features.

    Approximates softmax attention in O(n * r * d) by replacing
    exp(q^T k) with phi(q)^T phi(k), where phi projects onto r random
    features, so the T x T attention matrix is never formed.

    Fix vs. the original version: the random projection ``omega`` is
    drawn ONCE at construction and registered as a buffer. The old code
    sampled a fresh omega inside every ``kernel_feature_map`` call, so
    queries and keys were projected with *different* matrices — then
    phi(q)^T phi(k) does not approximate the kernel at all — and the
    module produced different outputs on every call.
    """

    def __init__(self, d_model, n_heads, n_features=64):
        """
        Args:
            d_model: embedding dimension
            n_heads: number of attention heads
            n_features: number of random features r
        """
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.n_features = n_features
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Shared random projection for the feature map; a buffer follows
        # .to(device) / state_dict but receives no gradients.
        self.register_buffer(
            'omega', torch.randn(self.d_k, n_features) / np.sqrt(self.d_k)
        )

    def kernel_feature_map(self, x):
        """Positive random features: exp(x @ omega), max-shifted for stability.

        Args:
            x: (B, H, T, d_k)
        Returns:
            (B, H, T, n_features) strictly positive features.
        """
        x_proj = x @ self.omega
        # Subtract the per-row max so exp() cannot overflow.
        return torch.exp(x_proj - torch.max(x_proj, dim=-1, keepdim=True)[0])

    def forward(self, x):
        """Apply linear-complexity attention to x: (B, T, d_model)."""
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.split(self.d_model, dim=2)
        q = q.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        # Feature maps (same omega for q and k — essential for the kernel approx.)
        q_prime = self.kernel_feature_map(q)
        k_prime = self.kernel_feature_map(k)
        # Linear attention: phi(Q) (phi(K)^T V), never materializing T x T.
        kv = k_prime.transpose(-2, -1) @ v
        z = q_prime @ kv
        # Normalization: phi(Q) (phi(K)^T 1)
        normalizer = q_prime @ k_prime.sum(dim=-2, keepdim=True).transpose(-2, -1)
        out = z / (normalizer + 1e-6)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
# Test
# Smoke-test: the output must keep the input's (batch, seq, d_model) shape.
performer = PerformerAttention(d_model=64, n_heads=4, n_features=64).to(device)
x = torch.randn(2, 128, 64).to(device)
out = performer(x)
print(f"Output shape: {out.shape}")
ComparisonΒΆ
Benchmarking the efficient Transformer variants against standard self-attention reveals the core trade-off: reduced computational complexity (from \(O(n^2)\) to \(O(n \log n)\) or \(O(n)\)) at the cost of some approximation error in the attention pattern. Linear attention methods (e.g., Performer) replace the softmax kernel with a feature map approximation, achieving \(O(n)\) time and memory. Sparse attention methods (e.g., Longformer, BigBird) restrict each token to attend to a fixed local window plus selected global tokens. Comparing wall-clock time, memory usage, and downstream task accuracy across sequence lengths helps practitioners choose the right variant for their use case — exact attention for short sequences, efficient variants for long-document or genomic applications.
# Benchmark all methods
# Compare wall-clock time of standard vs. Linformer vs. Performer attention
# across increasing sequence lengths (single forward pass each).
seq_lengths = [128, 256, 512, 1024, 2048]
d_model = 64
batch_size = 2
times_standard = []
times_linformer = []
times_performer = []
for n in seq_lengths:
    x = torch.randn(batch_size, n, d_model).to(device)
    # Standard (skip if too long): the O(n^2) attention matrix gets expensive.
    if n <= 1024:
        Q = K = V = x
        start = time.time()
        _ = standard_attention(Q, K, V)
        times_standard.append(time.time() - start)
    else:
        times_standard.append(None)  # placeholder keeps list aligned with seq_lengths
    # Linformer (a fresh module per length; no warm-up run is performed)
    linf = LinformerAttention(d_model, n_heads=4, max_len=max(seq_lengths), k=64).to(device)
    start = time.time()
    _ = linf(x)
    times_linformer.append(time.time() - start)
    # Performer
    perf = PerformerAttention(d_model, n_heads=4, n_features=64).to(device)
    start = time.time()
    _ = perf(x)
    times_performer.append(time.time() - start)
plt.figure(figsize=(10, 6))
# Drop the None placeholders; this assumes Nones occur only at the tail,
# so the x-axis prefix still lines up with the measured lengths.
valid_std = [t for t in times_standard if t is not None]
plt.plot(seq_lengths[:len(valid_std)], valid_std, marker='o', label='Standard')
plt.plot(seq_lengths, times_linformer, marker='s', label='Linformer')
plt.plot(seq_lengths, times_performer, marker='^', label='Performer')
plt.xlabel('Sequence Length', fontsize=11)
plt.ylabel('Time (s)', fontsize=11)
plt.title('Efficient Attention Comparison', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.yscale('log')
plt.show()
SummaryΒΆ
Problem:ΒΆ
Standard attention: \(O(n^2d)\) complexity
Solutions:ΒΆ
| Method     | Complexity     | Tradeoff             |
|------------|----------------|----------------------|
| Linformer  | \(O(nkd)\)     | Low-rank assumption  |
| Performer  | \(O(nrd)\)     | Kernel approximation |
| Reformer   | \(O(n\log n)\) | LSH bucketing        |
| Longformer | \(O(n)\)       | Sparse patterns      |
Linformer:ΒΆ
Project \(K, V\) to dimension \(k\)
Good when attention is low-rank
Performer:ΒΆ
Kernel approximation via random features
Linear complexity in sequence length
Unbiased approximation
Applications:ΒΆ
Long documents (16k+ tokens)
DNA sequences
Video processing
Audio generation
Next Steps:ΒΆ
13_gpt_architecture.ipynb - Full transformer
Explore FlashAttention (memory-efficient)
Test on long-range tasks
Advanced Efficient Transformers TheoryΒΆ
1. The Quadratic Complexity ProblemΒΆ
1.1 Standard Self-AttentionΒΆ
Standard Transformer self-attention has O(NΒ²) time and memory complexity:
Attention formula:
Attention(Q, K, V) = softmax(QK^T / βd_k) V
Computational breakdown:
QK^T: Matrix multiplication (NΓd_k) @ (d_kΓN) = O(NΒ²d_k)
Softmax: O(NΒ²) (requires materializing NΓN attention matrix)
Multiply by V: (NΓN) @ (NΓd_v) = O(NΒ²d_v)
Total: O(NΒ²d) time, O(NΒ²) memory for attention matrix
Why this is problematic:
- N=1024 tokens: 1M attention scores
- N=4096 tokens: 16M attention scores (16× larger!)
- GPU memory limits: can't fit long sequences
1.2 Bottleneck AnalysisΒΆ
The attention matrix A = softmax(QK^T/βd_k) β β^(NΓN) is the bottleneck:
Must materialize all NΒ² attention scores
Cannot be computed incrementally (softmax couples all scores)
Memory scales quadratically with sequence length
2. Linear AttentionΒΆ
2.1 Kernel Trick ReformulationΒΆ
Key insight: Reorder operations to avoid materializing NΓN matrix.
Standard attention:
Attention(Q, K, V) = softmax(QK^T) V
= (exp(QK^T) / Ξ£_j exp(q_i^T k_j)) V
Kernel trick: Define similarity kernel k(q, k) = Ο(q)^T Ο(k)
Linear attention:
Attention(Q, K, V) = Ο(Q)(Ο(K)^T V) / (Ο(Q)(Ο(K)^T 1_N))
Complexity:
Ο(K)^T V: (dβΓN) @ (NΓd_v) = O(Ndβd_v)
Ο(Q) Γ result: (NΓdβ) @ (dβΓd_v) = O(Ndβd_v)
Total: O(Ndβd_v) (linear in N!)
Feature map choices:
Identity: Ο(x) = x (no softmax, loses attention properties)
ReLU: Ο(x) = ReLU(x) (non-negative, fast)
ELU+1: Ο(x) = elu(x) + 1 (smooth, non-negative)
Theoretical tradeoff:
Pro: O(N) complexity, constant memory
Con: Loses softmax normalization (changes attention semantics)
Con: No bounded attention weights (can have exploding values)
2.2 Causal Linear AttentionΒΆ
For autoregressive generation, need causal masking. Can be computed incrementally:
Recurrent form:
S_t = S_(t-1) + Ο(k_t) β v_t (running sum, O(d'd_v) per step)
z_t = z_(t-1) + Ο(k_t) (normalization, O(d') per step)
y_t = Ο(q_t)^T S_t / (Ο(q_t)^T z_t)
Advantage: O(1) per token in generation (vs O(N) for standard attention)
3. Performer (FAVOR+)ΒΆ
3.1 Positive Random FeaturesΒΆ
Problem with linear attention: Feature maps Ο(x) donβt approximate softmax well.
Performer solution: Use random Fourier features to approximate exp(q^T k):
FAVOR+ (Fast Attention Via positive Orthogonal Random features):
exp(q^T k / βd) β E[Ο(q)^T Ο(k)]
Feature map:
Ο(x) = (1/βm) [exp(w_1^T x - ||x||Β²/2), ..., exp(w_m^T x - ||x||Β²/2)]
where w_i ~ N(0, I) are random Gaussian projections.
Orthogonal random features (ORF): Sample w_1, β¦, w_m from orthogonalized Gaussian matrix β lower variance estimator.
Renormalization:
Ο(x) = h(x) [f_1(x), ..., f_m(x)]
h(x) = exp(-||x||Β²/2)
f_i(x) = exp(w_i^T x)
Approximation quality:
Theorem: For m random features, approximation error is O(1/βm)
Typical m=256 gives good approximation with O(NΒ·256Β·d) = O(Nd) complexity
3.2 Complexity AnalysisΒΆ
Standard attention: O(NΒ²d)
Performer: O(Nmd) where m << N
Typical settings: m=256, N=4096 β 16Γ speedup
Memory:
Standard: O(NΒ²) for attention matrix
Performer: O(md) for S = Ο(K)^T V (constant in N!)
3.3 Theoretical GuaranteesΒΆ
Approximation bound (Choromanski et al., 2020):
|softmax(QK^T/βd)V - Ο(Q)(Ο(K)^TV)| β€ Ξ΅
with probability β₯ 1-Ξ΄ if m β₯ O(log(N/Ξ΄)/Ρ²)
Practical implications:
m=256: Ξ΅ β 0.1 (10% error acceptable for many tasks)
Higher m: Better approximation but higher cost
4. ReformerΒΆ
4.1 Locality-Sensitive Hashing (LSH) AttentionΒΆ
Key insight: Most attention weights are near-zero. Only attend to similar keys.
LSH basics: Hash function h such that similar items collide:
P(h(x) = h(y)) β similarity(x, y)
Angular LSH for cosine similarity:
h(x) = sign(r^T x) where r ~ N(0, I)
Multi-round hashing: Use b hash functions [h_1, β¦, h_b], concatenate to get bucket:
bucket(x) = [h_1(x), ..., h_b(x)]
LSH attention algorithm:
Hash queries and keys into buckets
Sort by bucket (queries and keys together)
Only compute attention within buckets (and adjacent for smoothness)
Complexity: O(N log N) (dominated by sorting)
Attention sparsity:
Each query attends to ~N/n_buckets keys (+ neighbors)
Typical: n_buckets = N/32 β 32Γ reduction in computations
4.2 Reversible LayersΒΆ
Problem: Storing activations for backprop requires O(NΒ·LΒ·d) memory (L layers).
Reversible residual network (RevNet): Split activations into two parts [x1, x2], compute:
y1 = x1 + F(x2)
y2 = x2 + G(y1)
Inverse (for backprop):
x2 = y2 - G(y1)
x1 = y1 - F(x2)
Advantage: Donβt need to store intermediate activations! Can recompute during backprop.
Reformer layers:
Y1 = X1 + Attention(X2)
Y2 = X2 + FeedForward(Y1)
Memory savings:
Standard Transformer: O(NΒ·LΒ·d) activation memory
Reformer: O(NΒ·d) (only store inputs, recompute rest)
4.3 Chunked Feed-ForwardΒΆ
Split sequence into chunks, process feed-forward chunk-by-chunk to save memory:
for chunk in chunks(X):
chunk = chunk + FeedForward(chunk)
No cross-chunk dependencies in FFN β parallelizable.
5. LongformerΒΆ
5.1 Sliding Window AttentionΒΆ
Local attention: Each token attends to w neighbors on each side.
Complexity: O(NΒ·w) where w << N (typically w=512 even for N=4096)
Pattern:
Attention mask for position i: [i-w, ..., i, ..., i+w]
Advantages:
Linear complexity in sequence length
Captures local context (most dependencies are local)
Can use different window sizes per layer
Implementation: Sparse attention matrix A where A[i,j] = 0 if |i-j| > w.
5.2 Dilated Sliding WindowΒΆ
Problem: Sliding window misses long-range dependencies.
Dilated attention: Skip every d tokens in window.
Effective receptive field:
Layer 0: window w, dilation 1 β range w
Layer 1: window w, dilation 2 β range 2w
Layer 2: window w, dilation 4 β range 4w
Exponential growth in receptive field!
Pattern for dilation d:
Attend to positions: [..., i-dΒ·w, ..., i-d, i, i+d, ..., i+dΒ·w]
5.3 Global AttentionΒΆ
Task-specific tokens: Some tokens (e.g., [CLS]) need to attend to all tokens.
Hybrid pattern:
Most tokens: Local attention O(NΒ·w)
Few global tokens (g << N): Full attention O(gΒ·N)
Total: O(NΒ·w + gΒ·N) β O(NΒ·w) if g is small
Example (document QA):
Question tokens: Global attention (attend to all doc tokens)
Document tokens: Local attention (sliding window)
5.4 Longformer Attention PatternsΒΆ
Three types combined:
Sliding window: All tokens, w=512
Dilated window: Selected layers, exponentially increasing dilation
Global: Task tokens ([CLS], question tokens)
Complexity analysis:
Standard: O(NΒ²)
Longformer: O(NΒ·w + gΒ·N) where w=512, g << N
For N=4096, w=512, g=10: ~8Γ speedup
6. SynthesizerΒΆ
6.1 Learned Attention PatternsΒΆ
Key insight: Maybe we donβt need data-dependent attention QK^T. Learn attention patterns directly.
Dense Synthesizer:
A = softmax(W_A) where W_A β β^(NΓN) is learned
Output = A V
Problem: W_A has O(NΒ²) parameters (not generalizable to different lengths).
Factored Synthesizer:
A[i,j] = softmax(w_i^T w_j) where w_i β β^k, k << N
Parameters: O(NΒ·k) instead of O(NΒ²)
Random Synthesizer:
A = softmax(R) where R is fixed random matrix
Surprisingly effective baseline!
6.2 Hybrid SynthesizerΒΆ
Combine learned and data-dependent:
A = Ξ±Β·softmax(QK^T) + (1-Ξ±)Β·softmax(W_A)
Findings (Tay et al., 2020):
Synthesizer can match standard Transformer on some tasks
QK^T attention not always necessary
But task-dependent: summarization benefits from data-dependent attention
7. Sparse Attention PatternsΒΆ
7.1 Fixed PatternsΒΆ
Strided attention (Sparse Transformer):
Attend to every k-th position
One head: Local, another head: Strided
Combine heads to capture local + long-range
Pattern:
Head 1 (local): [i-w, ..., i, ..., i+w]
Head 2 (strided): [0, k, 2k, ..., βi/kβΒ·k]
Complexity: O(NΒ·βN) for 2D patterns (images)
7.2 Content-Based Sparse AttentionΒΆ
Routing Transformer: Use k-means clustering to route queries to relevant keys:
1. Cluster K into k clusters
2. Assign each q_i to nearest cluster
3. Attend only within cluster
Complexity: O(NΒ·k) where k is avg cluster size.
8. Big BirdΒΆ
8.1 Attention PatternΒΆ
Three components:
Random attention: Each token attends to r random tokens
Window attention: Sliding window of size w
Global tokens: g tokens that attend to/from all positions
Pattern for token i:
Window: [i-w/2, β¦, i, β¦, i+w/2]
Random: r random positions
Global: [0, β¦, g-1] (designated global tokens)
Complexity: O(NΒ·(w + r + g))
8.2 Theoretical JustificationΒΆ
Graph connectivity: Attention as graph where edge (i,j) means i attends to j.
Theorem (Zaheer et al., 2020): Big Bird attention graph is expander graph with:
Diameter: O(1) (shortest path between any two nodes is constant)
Spectral gap: Ξ©(1) (good mixing properties)
Implications:
Any token can reach any other in O(1) hops
Information flow is efficient despite sparsity
Theoretical guarantees lacking in pure sliding window
9. Flash AttentionΒΆ
9.1 IO-Aware AttentionΒΆ
Problem: Not just FLOP count, but memory transfers!
GPU memory hierarchy:
HBM (main memory): Large (40GB), slow (1.5 TB/s)
SRAM (on-chip): Small (20MB), fast (19 TB/s)
Registers: Tiny, fastest
Standard attention memory transfers:
Load Q, K from HBM β compute QK^T β store to HBM
Load QK^T from HBM β softmax β store to HBM
Load softmax, V from HBM β multiply β store output
Total HBM accesses: O(NΒ² + Nd) reads/writes of attention matrix.
9.2 Tiling and RecomputationΒΆ
Flash Attention algorithm:
Tile: Split Q, K, V into blocks that fit in SRAM
Fused kernel: Compute attention for each block without HBM write
Online softmax: Compute softmax incrementally (safe softmax with running max)
Online softmax (numerically stable):
m_i = max(m_{i-1}, max(q_i^T K))
β_i = e^{m_{i-1} - m_i} β_{i-1} + Ξ£_j e^{q_i^T k_j - m_i}
O_i = e^{m_{i-1} - m_i} O_{i-1} + Ξ£_j e^{q_i^T k_j - m_i} v_j
Backward pass: Recompute attention on-the-fly instead of storing.
Complexity:
FLOPs: O(NΒ²d) (same as standard)
HBM accesses: O(NΒ²dΒ²/M) where M = SRAM size
Speedup: 2-4Γ due to better memory utilization
9.3 Flash Attention 2ΒΆ
Improvements (Dao et al., 2023):
Better parallelization: Parallelize over sequence length, not batch
Reduced non-matmul ops: Minimize pointer arithmetic
Forward-backward pass fusion: Recompute even more in backward
Results:
2Γ faster than Flash Attention 1
73% of A100 theoretical max (vs 35% standard attention)
10. Complexity ComparisonΒΆ
10.1 Time ComplexityΒΆ
Method Time Params Context
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
Standard Attention O(NΒ²d) O(dΒ²) Full
Linear Attention O(NdΒ²) O(dΒ²) Full
Performer O(Nmd) O(dΒ²) Full (approx)
Reformer (LSH) O(N log NΒ·d) O(dΒ²) Sparse (hashed)
Longformer O(NΒ·wΒ·d) O(dΒ²) Local + Global
Synthesizer O(NdΒ²) O(Nd) or O(dΒ²) Learned
Sparse (fixed) O(NΒ·sΒ·d) O(dΒ²) Fixed sparse
Big Bird O(NΒ·(w+r)Β·d) O(dΒ²) Local + Random + Global
Flash Attention O(NΒ²d) O(dΒ²) Full (IO-optimized)
where:
N: sequence length
d: model dimension
m: number of random features (Performer)
w: window size (Longformer, Big Bird)
s: sparsity (number of attended positions)
r: random attention positions (Big Bird)
10.2 Memory ComplexityΒΆ
Method Memory Notes
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
Standard Attention O(NΒ²) Attention matrix
Linear Attention O(Nd) No attention matrix
Performer O(md) Feature map activations
Reformer O(NΒ·d) Reversible layers
Longformer O(NΒ·w) Sparse attention
Flash Attention O(NΒ²) But computed in tiles
11. Evaluation MetricsΒΆ
11.1 Efficiency MetricsΒΆ
Time:
Wall-clock time (seconds)
FLOPs (theoretical)
Throughput (tokens/sec)
Memory:
Peak memory usage
HBM bandwidth utilization
Maximum sequence length
Quality:
Perplexity (language modeling)
Downstream task accuracy
Attention pattern visualization
11.2 Approximation QualityΒΆ
For methods that approximate attention (Linear, Performer):
Attention distance:
||A_approx - A_exact||_F / ||A_exact||_F
Output distance:
||Output_approx - Output_exact||_2 / ||Output_exact||_2
Typical results (Performer):
Attention distance: 10-20%
Output distance: 1-5%
Task accuracy: Within 1% of standard Transformer
12. When to Use Each MethodΒΆ
12.1 Decision GuideΒΆ
Use Standard Attention when:
Sequence length < 1024
Need exact attention (theoretical guarantees)
Compute/memory not bottleneck
Baseline for research
Use Linear Attention when:
Need O(N) complexity
Can tolerate non-softmax attention
Autoregressive generation (recurrent form)
Extremely long sequences (N > 16k)
Use Performer when:
Need softmax approximation
Want provable error bounds
Bidirectional context
Research applications
Use Reformer when:
Limited GPU memory (reversible layers)
Natural language (locality assumption valid)
Sequence length 4k-16k
Can tolerate LSH approximation
Use Longformer when:
Document-level tasks (N = 4k-16k)
Strong locality in data
Need task-specific global tokens
Production document understanding
Use Flash Attention when:
Exact attention required
Have modern GPUs (A100, H100)
Memory bandwidth bottleneck
Want easy drop-in replacement
Use Big Bird when:
Need theoretical guarantees (expander graph)
Sparse attention acceptable
Long documents (N > 4k)
Graph-theoretic properties important
12.2 Combination StrategiesΒΆ
Multi-scale:
Lower layers: Local attention (cheap)
Higher layers: Global attention (expensive but fewer layers)
Hybrid:
Some heads: Linear attention (efficient)
Some heads: Standard attention (expressive)
Adaptive:
Easy examples: Sparse attention
Hard examples: Dense attention
Learn routing with RL
13. State-of-the-Art ResultsΒΆ
13.1 Language ModelingΒΆ
Long-Range Arena benchmark (Tay et al., 2020):
Tasks: Text, ListOps, Retrieval, Image, Pathfinder
Sequence lengths: 1k-16k
Results:
Model Avg Accuracy Speed vs Transformer
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
Transformer 58.5% 1.0Γ
Linear Attention 53.9% 2.1Γ
Performer 61.4% 1.8Γ
Reformer 56.1% 1.3Γ
Longformer 62.8% 2.3Γ
Big Bird 64.2% 2.5Γ
13.2 Long Document UnderstandingΒΆ
Longformer on long documents:
WikiHop (N=4096): 75.8% accuracy (vs 71.2% truncated)
TriviaQA (N=4096): 75.2% F1 (vs 72.5% truncated)
Big Bird on text:
ArXiv summarization (N=4096): 46.3 ROUGE-1
PubMed QA (N=3072): 70.5% accuracy
13.3 Image GenerationΒΆ
Sparse Transformers (Child et al., 2019):
ImageNet 64Γ64 (N=4096): 2.80 bits/dim
CIFAR-10 (N=1024): 2.80 bits/dim
Linear Transformers (Katharopoulos et al., 2020):
CIFAR-10 classification: 94.2% (vs 95.1% standard)
3Γ faster training
14. Implementation ConsiderationsΒΆ
14.1 Software OptimizationsΒΆ
Kernel fusion:
Fuse softmax + mask + dropout into single kernel
Reduces memory bandwidth
Mixed precision:
Compute in FP16, accumulate in FP32
2Γ speedup on modern GPUs
Gradient checkpointing:
Trade compute for memory
Recompute activations in backward pass
14.2 Hardware ConsiderationsΒΆ
Tensor cores (A100, H100):
Optimized for matrix multiplication
Flash Attention leverages better than standard
Memory hierarchy:
Minimize HBM <-> SRAM transfers
Tile sizes match hardware (e.g., 64Γ64 for Flash)
Parallelism:
Batch parallelism: Standard
Sequence parallelism: Newer (Megatron-LM)
Pipeline parallelism: Across layers
15. Recent Advances (2023-2024)ΒΆ
15.1 Flash Attention 2ΒΆ
2Γ faster than Flash Attention 1
Better parallelization, reduced overhead
15.2 Paged Attention (vLLM)ΒΆ
Virtual memory for KV cache
Enables longer context in serving
No memory fragmentation
15.3 Multi-Query Attention (MQA)ΒΆ
Share K, V across heads (only Q is per-head)
10Γ faster decoding
Slight quality drop, but worth it for serving
15.4 Grouped-Query Attention (GQA)ΒΆ
Middle ground: Group heads, share K, V per group
Llama 2 uses GQA
Better quality than MQA, still fast
16. Research DirectionsΒΆ
16.1 Open ProblemsΒΆ
Softmax-free attention: Can we replace softmax entirely?
Adaptive sparsity: Learn attention patterns, not hand-designed
Theoretical understanding: Why do these approximations work?
Long-range dependencies: Better than O(N) for truly long sequences
Causal approximations: Most work on bidirectional, autoregressive harder
16.2 Emerging ApproachesΒΆ
State Space Models (S4, Mamba):
Not attention at all, but recurrent
O(N log N) or O(N) complexity
Competitive with Transformers on long sequences
Attention-free models:
AFT (Attention Free Transformer)
FNet (Fourier Transform instead of attention)
Limited success, but interesting
Hybrid architectures:
Local: CNN or MLP
Global: Sparse attention
Best of both worlds
17. Key TakeawaysΒΆ
No free lunch: Every efficient attention method trades something (approximation quality, implementation complexity, hardware efficiency).
Task-dependent: Longformer for documents, Linear for streaming, Flash for exact efficient.
Complexity isnβt everything: Flash Attention is O(NΒ²) but faster than O(N) methods due to better hardware utilization.
Locality matters: Most real-world data has strong locality β sparse methods work well.
Softmax may not be necessary: Synthesizer, Linear attention show alternative attention mechanisms can work.
Hardware co-design: Flash Attention shows importance of optimizing for memory hierarchy, not just FLOPs.
Combination is powerful: Hybrid methods (local + global, exact + approximate) often best.
18. ReferencesΒΆ
Foundational papers:
Linear Attention: Katharopoulos et al. (2020)
Performer: Choromanski et al. (2020)
Reformer: Kitaev et al. (2020)
Longformer: Beltagy et al. (2020)
Synthesizer: Tay et al. (2020)
Sparse Transformer: Child et al. (2019)
Big Bird: Zaheer et al. (2020)
Flash Attention: Dao et al. (2022)
Flash Attention 2: Dao (2023)
Surveys:
Tay et al. (2020): βEfficient Transformers: A Surveyβ
Lin et al. (2021): βA Survey of Transformersβ
# Advanced Efficient Transformers Implementations
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import Optional, Tuple
# ============================================================================
# 1. Linear Attention
# ============================================================================
class LinearAttention(nn.Module):
    """
    Kernelized (linear-time) attention: O(N) complexity.

    Instead of forming the N x N softmax matrix, a non-negative feature
    map phi is applied to queries and keys and the products are reordered
    as phi(Q) (phi(K)^T V), which only ever builds d' x d' intermediates.
    """

    def __init__(self, d_model, n_heads, feature_map='elu'):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.feature_map = feature_map

    def apply_feature_map(self, x):
        """Elementwise phi(x); 'elu' and 'relu' keep features non-negative."""
        if self.feature_map == 'elu':
            return F.elu(x) + 1
        if self.feature_map == 'relu':
            return F.relu(x)
        # identity fallback
        return x

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: (batch, seq_len) padding mask (currently unused)
        Returns:
            output: (batch, seq_len, d_model)
        """
        batch, seq_len, dim = x.shape
        heads, hdim = self.n_heads, self.head_dim
        # Project, apply phi to Q/K, then lay out as (B, H, N, D/H).
        queries = self.apply_feature_map(
            self.q_proj(x).view(batch, seq_len, heads, hdim)
        ).transpose(1, 2)
        keys = self.apply_feature_map(
            self.k_proj(x).view(batch, seq_len, heads, hdim)
        ).transpose(1, 2)
        values = self.v_proj(x).view(batch, seq_len, heads, hdim).transpose(1, 2)
        # Reordered product: phi(K)^T V is only (B, H, D/H, D/H).
        context = keys.transpose(-2, -1) @ values
        # Per-feature key totals for the softmax-free normalization.
        key_totals = keys.sum(dim=2, keepdim=True)  # (B, H, 1, D/H)
        numerator = queries @ context               # (B, H, N, D/H)
        denominator = queries @ key_totals.transpose(-2, -1)  # (B, H, N, 1)
        attended = numerator / (denominator + 1e-6)
        # Merge heads back to (B, N, d_model) and project out.
        attended = attended.transpose(1, 2).contiguous().view(batch, seq_len, dim)
        return self.out_proj(attended)
class CausalLinearAttention(nn.Module):
    """
    Causal (autoregressive) linear attention.

    Scans left-to-right keeping two running summaries — an outer-product
    accumulator phi(k)^T v and a key-feature sum — so each position only
    attends to itself and earlier positions. This is the recurrent
    O(1)-per-token formulation of linear attention.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Left-to-right linear attention over x: (B, N, d_model)."""
        batch, steps, width = x.shape
        heads, hdim = self.n_heads, self.head_dim
        # phi(x) = elu(x) + 1 keeps features strictly positive.
        queries = (F.elu(self.q_proj(x)) + 1).view(batch, steps, heads, hdim).transpose(1, 2)
        keys = (F.elu(self.k_proj(x)) + 1).view(batch, steps, heads, hdim).transpose(1, 2)
        values = self.v_proj(x).view(batch, steps, heads, hdim).transpose(1, 2)
        # Running prefix sums — the "state" of the recurrence.
        kv_state = torch.zeros(batch, heads, hdim, hdim,
                               device=x.device, dtype=x.dtype)
        key_state = torch.zeros(batch, heads, hdim,
                                device=x.device, dtype=x.dtype)
        per_step = []
        for t in range(steps):
            key_t = keys[:, :, t:t+1, :]    # (B, H, 1, D/H)
            val_t = values[:, :, t:t+1, :]
            # Fold position t into both accumulators.
            kv_state = kv_state + key_t.transpose(-2, -1) @ val_t
            key_state = key_state + key_t.squeeze(2)
            # Read out position t against everything seen so far.
            qry_t = queries[:, :, t:t+1, :]
            unnormalized = qry_t @ kv_state
            scale = qry_t @ key_state.unsqueeze(-1)
            per_step.append(unnormalized / (scale + 1e-6))
        combined = torch.cat(per_step, dim=2)  # (B, H, N, D/H)
        combined = combined.transpose(1, 2).contiguous().view(batch, steps, width)
        return self.out_proj(combined)
# ============================================================================
# 2. Performer (FAVOR+)
# ============================================================================
class PerformerAttention(nn.Module):
    """
    Performer attention using positive random features (FAVOR+).
    Approximates softmax attention with provable error bounds.

    Fixes vs. the original version:
    - ``torch.qr`` (deprecated and removed in recent PyTorch) replaced
      with ``torch.linalg.qr``.
    - The orthogonal construction now stacks ceil(m/d) orthogonalized
      d x d Gaussian blocks, so ``random_features`` always has shape
      (num_features, head_dim). The old code returned at most head_dim
      rows whenever num_features > head_dim (the common case, e.g.
      m=256 vs d=64), silently shrinking the feature count while still
      normalizing by sqrt(num_features).
    """

    def __init__(self, d_model, n_heads, num_features=256,
                 orthogonal=True, redraw_features=False):
        """
        Args:
            d_model: embedding dimension
            n_heads: number of attention heads
            num_features: number of random features m
            orthogonal: use orthogonal random features (lower variance)
            redraw_features: resample the projection on each training forward
        """
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.num_features = num_features
        self.orthogonal = orthogonal
        self.redraw_features = redraw_features
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Projection matrix registered as a buffer so it follows .to(device).
        self.register_buffer('random_features', self._create_random_features())

    def _create_random_features(self):
        """Create the (num_features, head_dim) random projection matrix."""
        d, m = self.head_dim, self.num_features
        if self.orthogonal:
            # Orthogonal random features: orthogonalize independent d x d
            # Gaussian blocks and stack rows until we have m of them.
            blocks = []
            n_rows = 0
            while n_rows < m:
                q, _ = torch.linalg.qr(torch.randn(d, d))
                blocks.append(q.T)
                n_rows += d
            return torch.cat(blocks, dim=0)[:m]
        # i.i.d. Gaussian features
        return torch.randn(m, d) / math.sqrt(d)

    def apply_kernel_feature_map(self, x):
        """
        Apply FAVOR+ feature map: phi(x) = exp(w^T x - ||x||^2/2) / sqrt(m)
        Args:
            x: (B, H, N, D/H)
        Returns:
            features: (B, H, N, num_features)
        """
        # The -||x||^2/2 term makes exp(w^T x - ||x||^2/2) an unbiased
        # positive estimator of the softmax kernel.
        x_norm_sq = (x ** 2).sum(dim=-1, keepdim=True) / 2  # (B, H, N, 1)
        # w^T x for all random features at once.
        projections = torch.matmul(x, self.random_features.T)  # (B, H, N, m)
        return torch.exp(projections - x_norm_sq) / math.sqrt(self.num_features)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: unused, kept for interface compatibility
        Returns:
            output: (batch, seq_len, d_model)
        """
        B, N, D = x.shape
        # Optionally resample projections each training step (variance reduction).
        if self.training and self.redraw_features:
            self.random_features = self._create_random_features().to(x.device)
        # Project to Q, K, V and lay out as (B, H, N, D/H).
        Q = self.q_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        # Kernel feature maps: (B, H, N, m)
        Q_prime = self.apply_kernel_feature_map(Q)
        K_prime = self.apply_kernel_feature_map(K)
        # Linear attention: phi(Q) (phi(K)^T V), never forming N x N.
        KV = torch.matmul(K_prime.transpose(-2, -1), V)  # (B, H, m, D/H)
        output = torch.matmul(Q_prime, KV)               # (B, H, N, D/H)
        # Row-wise normalization: phi(Q) (phi(K)^T 1_N)
        K_sum = K_prime.sum(dim=2, keepdim=True)         # (B, H, 1, m)
        normalizer = torch.matmul(Q_prime, K_sum.transpose(-2, -1))
        output = output / (normalizer + 1e-6)
        # Merge heads and project out.
        output = output.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(output)
# ============================================================================
# 3. Longformer (Sliding Window + Global)
# ============================================================================
class LongformerAttention(nn.Module):
    """
    Longformer-style attention: local sliding window + optional global tokens.

    Conceptually O(N*w) where w is the window size.  NOTE: this reference
    implementation still materializes the dense (N, N) score matrix, so its
    actual memory cost is O(N^2); only the attended pattern is sparse.
    """

    def __init__(self, d_model, n_heads, window_size=512):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.window_size = window_size
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def _get_local_attention_mask(self, seq_len, device):
        """
        Build the sliding-window attention mask.

        Returns:
            mask: (seq_len, seq_len) BOOL tensor, True where position i may
            attend to position j, i.e. |i - j| <= window_size // 2.

        Fix vs. previous version: the mask is boolean — the old float mask
        raised a RuntimeError when OR-ed (`|`) with the boolean global mask in
        forward().  Also built with vectorized index arithmetic instead of a
        Python loop over positions.
        """
        half = self.window_size // 2
        idx = torch.arange(seq_len, device=device)
        return (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= half

    def forward(self, x, global_mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            global_mask: optional (batch, seq_len); truthy entries mark global
                tokens that attend to all positions and are attended by all.

        Returns:
            output: (batch, seq_len, d_model)
        """
        B, N, D = x.shape
        # Project and split into heads: (B, H, N, D/H).
        Q = self.q_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        # Allowed-position mask: local window, optionally OR-ed with the
        # global-token rows and columns.
        local_mask = self._get_local_attention_mask(N, x.device)
        if global_mask is not None:
            gm = global_mask.bool().unsqueeze(1).unsqueeze(2)  # (B, 1, 1, N)
            # Global tokens attend to all (rows) and all attend to them (cols).
            global_attn_mask = gm | gm.transpose(-2, -1)       # (B, 1, N, N)
            attention_mask = local_mask.unsqueeze(0) | global_attn_mask
        else:
            attention_mask = local_mask.unsqueeze(0)           # (1, N, N)
        # Scaled dot-product scores; the mask broadcasts over batch/head dims.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(~attention_mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        # Rows that are fully masked softmax to NaN; zero them out.
        attn_weights = torch.nan_to_num(attn_weights)
        output = torch.matmul(attn_weights, V)
        # Merge heads and project.
        output = output.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(output)
# ============================================================================
# 4. Flash Attention (Simplified Conceptual Version)
# ============================================================================
class FlashAttention(nn.Module):
    """
    Simplified Flash Attention concept (pure PyTorch, not a CUDA kernel).

    Demonstrates the two core ideas: processing K/V in tiles, and an online
    softmax that keeps a running maximum and running sum so the full (N, N)
    score matrix never has to be normalized in one shot.
    """

    def __init__(self, d_model, n_heads, block_size=64):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.block_size = block_size
        # Q/K/V/output projections (same creation order as before, so RNG
        # initialization is unchanged).
        for name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            setattr(self, name, nn.Linear(d_model, d_model))

    def online_softmax_attention(self, Q, K, V):
        """
        Block-wise attention with a streaming ("online") softmax.

        K and V are consumed in tiles of block_size; previously accumulated
        output and normalizer are rescaled whenever the running maximum grows.
        Mathematically identical to softmax(QK^T / sqrt(d)) V.
        """
        batch, heads, seq_len, dim = Q.shape
        acc = torch.zeros_like(Q)                                   # unnormalized output
        run_max = Q.new_full((batch, heads, seq_len, 1), float('-inf'))
        run_sum = Q.new_zeros(batch, heads, seq_len, 1)
        for start in range(0, seq_len, self.block_size):
            stop = start + self.block_size
            k_blk = K[:, :, start:stop, :]
            v_blk = V[:, :, start:stop, :]
            # Scores of every query against this tile of keys.
            blk_scores = (Q @ k_blk.transpose(-2, -1)) / math.sqrt(dim)
            # Fold the tile's maximum into the running maximum, then rescale
            # everything accumulated so far into the new reference frame.
            new_max = torch.maximum(run_max, blk_scores.amax(dim=-1, keepdim=True))
            rescale = torch.exp(run_max - new_max)
            acc = acc * rescale
            run_sum = run_sum * rescale
            # Contribution of this tile.
            weights = torch.exp(blk_scores - new_max)
            acc = acc + weights @ v_blk
            run_sum = run_sum + weights.sum(dim=-1, keepdim=True)
            run_max = new_max
        # Deferred normalization — equivalent to the full softmax denominator.
        return acc / run_sum

    def forward(self, x):
        """Project to heads, run the tiled attention, merge heads back."""
        batch, seq_len, _ = x.shape

        def split_heads(t):
            return t.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))
        ctx = self.online_softmax_attention(q, k, v)
        ctx = ctx.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        return self.out_proj(ctx)
# ============================================================================
# Demonstrations
# ============================================================================
# Smoke-test / showcase script: exercises each attention variant defined in
# this file on small random inputs and prints complexity and usage notes.
# All numbers in the printed tables are illustrative, not measured.
print("=" * 70)
print("Efficient Transformers - Advanced Implementations")
print("=" * 70)
# 1. Complexity comparison
# (Static summary table only — nothing is executed here.)
print("\n1. Complexity Analysis:")
print(" ββββββββββββββββββββββ¬ββββββββββββ¬ββββββββββββ¬ββββββββββββββ")
print(" β Method β Time β Memory β Exact? β")
print(" ββββββββββββββββββββββΌββββββββββββΌββββββββββββΌββββββββββββββ€")
print(" β Standard Attention β O(NΒ²d) β O(NΒ²) β Yes β")
print(" β Linear Attention β O(NdΒ²) β O(Nd) β No (approx) β")
print(" β Performer β O(Nmd) β O(md) β No (approx) β")
print(" β Longformer β O(Nwd) β O(Nw) β Yes (sparse)β")
print(" β Flash Attention β O(NΒ²d) β O(NΒ²)* β Yes β")
print(" ββββββββββββββββββββββ΄ββββββββββββ΄ββββββββββββ΄ββββββββββββββ")
print(" * Flash: Same asymptotic memory but tiled (fits in SRAM)")
# 2. Linear attention demo
# NOTE: LinearAttention is defined earlier in this file (outside this chunk).
print("\n2. Linear Attention:")
linear_attn = LinearAttention(d_model=256, n_heads=8, feature_map='elu')
x_test = torch.randn(2, 100, 256)
out_linear = linear_attn(x_test)
print(f" Input: {x_test.shape}")
print(f" Output: {out_linear.shape}")
print(f" Feature map: Ο(x) = elu(x) + 1 (non-negative)")
print(f" Complexity: O(NΒ·dΒ²) = O(100Β·256Β²) = 6.5M ops")
print(f" vs Standard O(NΒ²Β·d) = O(100Β²Β·256) = 2.5M ops")
print(f" Note: Linear wins when N >> d")
# 3. Performer demo
# Reuses x_test from section 2 (shape (2, 100, 256)).
print("\n3. Performer (FAVOR+):")
performer = PerformerAttention(d_model=256, n_heads=8, num_features=256,
                               orthogonal=True)
out_performer = performer(x_test)
print(f" Input: {x_test.shape}")
print(f" Output: {out_performer.shape}")
print(f" Random features: {256} (orthogonal)")
print(f" Approximates: exp(q^T k / βd_k)")
print(f" Complexity: O(NΒ·mΒ·d) = O(100Β·256Β·256) = 6.5M ops")
print(f" Error bound: O(1/βm) = O(1/16) β 6%")
# 4. Causal linear attention
# NOTE: CausalLinearAttention is defined earlier in this file.
print("\n4. Causal Linear Attention:")
causal_attn = CausalLinearAttention(d_model=256, n_heads=8)
x_causal = torch.randn(1, 50, 256)
out_causal = causal_attn(x_causal)
print(f" Input: {x_causal.shape}")
print(f" Output: {out_causal.shape}")
print(f" Mechanism: Cumulative sums S_t = S_(t-1) + k_t β v_t")
print(f" Generation cost: O(dΒ²) per token (vs O(Nd) standard)")
print(f" Speedup at N=4096: ~16Γ faster decoding")
# 5. Longformer demo
# Marks token 0 as global (CLS-style) and runs windowed attention.
print("\n5. Longformer Attention:")
longformer = LongformerAttention(d_model=256, n_heads=8, window_size=16)
x_long = torch.randn(1, 64, 256)
global_mask = torch.zeros(1, 64, dtype=torch.bool)
global_mask[0, 0] = True # First token is global (e.g., [CLS])
out_long = longformer(x_long, global_mask)
print(f" Input: {x_long.shape}")
print(f" Output: {out_long.shape}")
print(f" Window size: 16")
print(f" Global tokens: 1 (position 0)")
print(f" Complexity: O(NΒ·w) = O(64Β·16) = 1,024 attention scores")
print(f" vs Standard: O(NΒ²) = O(64Β²) = 4,096 scores (4Γ reduction)")
# 6. Flash attention concept
print("\n6. Flash Attention (Conceptual):")
flash_attn = FlashAttention(d_model=256, n_heads=8, block_size=32)
x_flash = torch.randn(1, 128, 256)
out_flash = flash_attn(x_flash)
print(f" Input: {x_flash.shape}")
print(f" Output: {out_flash.shape}")
print(f" Block size: 32 (tiles)")
print(f" Key idea: Online softmax with running max/sum")
print(f" Benefit: Reduced HBM β SRAM transfers")
print(f" Speedup: 2-4Γ on A100 (memory-bound workloads)")
# 7. Attention pattern visualization
# ASCII sketches of which (i, j) pairs each pattern allows.
print("\n7. Attention Patterns:")
print(" Standard Attention:")
print(" β β β β β β β β (dense, all-to-all)")
print(" β β β β β β β β ")
print(" β β β β β β β β ")
print("\n Longformer (window=3):")
print(" β β β Β· Β· Β· Β· Β· (sparse, local window)")
print(" β β β β Β· Β· Β· Β·")
print(" Β· β β β β Β· Β· Β·")
print("\n Longformer + Global (token 0):")
print(" β β β β β β β β (global token attends all)")
print(" β β β β Β· Β· Β· Β·")
print(" β Β· β β β Β· Β· Β·")
# 8. When to use guide
print("\n8. Method Selection Guide:")
print(" Sequence length N < 1024:")
print(" β Use standard attention (exact, well-optimized)")
print("\n 1024 < N < 4096:")
print(" β Longformer (good local/global tradeoff)")
print(" β Flash Attention (exact + fast on modern GPUs)")
print("\n N > 4096:")
print(" β Linear Attention (extreme lengths, streaming)")
print(" β Reformer (if memory constrained)")
print("\n Need exact attention:")
print(" β Flash Attention (best of both worlds)")
print("\n Can tolerate approximation:")
print(" β Performer (provable bounds, research)")
print(" β Linear (fastest, autoregressive)")
# 9. Approximation quality
print("\n9. Approximation Quality:")
# Compare linear vs standard attention on small example
# NOTE(review): the two modules have independently initialized weights, so
# this "relative error" mixes approximation error with weight mismatch —
# it is a rough illustration, not a controlled comparison.
x_small = torch.randn(1, 32, 64)
std_attn = nn.MultiheadAttention(64, 4, batch_first=True)
lin_attn = LinearAttention(64, 4, feature_map='elu')
std_out, _ = std_attn(x_small, x_small, x_small)
lin_out = lin_attn(x_small)
relative_error = (std_out - lin_out).norm() / std_out.norm()
print(f" Standard vs Linear attention:")
print(f" Input: {x_small.shape}")
print(f" Output difference (relative): {relative_error.item():.4f}")
print(f" Typical range: 5-20% for linear approximations")
print(f" Task accuracy drop: <1% on many benchmarks")
# 10. Memory savings
# Attention-matrix memory for standard O(N^2) vs linear O(N*d), d = 64.
print("\n10. Memory Savings:")
seq_lengths = [512, 1024, 2048, 4096, 8192]
print(" Attention Matrix Memory (MB, float32):")
print(" βββββββββββ¬βββββββββββ¬βββββββββββ¬βββββββββββ")
print(" β Seq Len β Standard β Linear β Savings β")
print(" βββββββββββΌβββββββββββΌβββββββββββΌβββββββββββ€")
for N in seq_lengths:
    std_mem = N * N * 4 / 1e6 # float32 = 4 bytes
    lin_mem = N * 64 * 4 / 1e6 # assume d=64
    savings = f"{std_mem / lin_mem:.1f}Γ"
    print(f" β {N:7} β {std_mem:8.2f} β {lin_mem:8.2f} β {savings:8} β")
print(" βββββββββββ΄βββββββββββ΄βββββββββββ΄βββββββββββ")
print("\n" + "=" * 70)