import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time

# Pick the fastest available backend: CUDA when present, otherwise CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Using: {device}")

1. Attention Complexity Problem

Standard Attention

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

Complexity:

  • Time: \(O(n^2 d)\) where \(n\) = sequence length

  • Space: \(O(n^2)\) for attention matrix

Problem: Quadratic scaling prevents long sequences!

SolutionsΒΆ

Method

Complexity

Idea

Linformer

\(O(nk)\)

Low-rank projection

Performer

\(O(nd^2)\)

Kernel approximation

Reformer

\(O(n\log n)\)

LSH attention

Sparse

\(O(n\sqrt{n})\)

Fixed patterns

πŸ“š Reference Materials:

# Reference implementation used for the complexity benchmark below.
def standard_attention(Q, K, V):
    """Vanilla scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V.

    Q, K, V: (..., n, d_k) tensors. Returns a tensor shaped like V.
    Materializes the full (n, n) attention matrix -> O(n^2) time/memory.
    """
    d_k = Q.size(-1)
    logits = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
    weights = F.softmax(logits, dim=-1)
    return torch.matmul(weights, V)

# Measure empirical runtime scaling of standard attention.
# Fixes vs. original: run under no_grad (no autograd graph), add a warm-up
# iteration, and synchronize CUDA before reading the clock — CUDA kernels
# launch asynchronously, so time.time() alone under-measures GPU work.
seq_lengths = [128, 256, 512, 1024]
times = []

d_model = 64
batch_size = 4

with torch.no_grad():
    for n in seq_lengths:
        Q = torch.randn(batch_size, n, d_model, device=device)
        K = torch.randn(batch_size, n, d_model, device=device)
        V = torch.randn(batch_size, n, d_model, device=device)

        # Warm-up so lazy CUDA init / kernel selection is excluded from timing.
        _ = standard_attention(Q, K, V)
        if device.type == 'cuda':
            torch.cuda.synchronize()

        start = time.time()
        for _ in range(10):
            _ = standard_attention(Q, K, V)
        if device.type == 'cuda':
            torch.cuda.synchronize()  # wait for async kernels to finish
        times.append((time.time() - start) / 10)

plt.figure(figsize=(8, 5))
plt.plot(seq_lengths, times, marker='o', label='Standard Attention')
plt.xlabel('Sequence Length', fontsize=11)
plt.ylabel('Time (s)', fontsize=11)
plt.title('Attention Complexity', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

2. Linformer: Low-Rank AttentionΒΆ

Key InsightΒΆ

Attention matrix is often low-rank!

MethodΒΆ

Project \(K, V\) to lower dimension \(k \ll n\):

\[\tilde{K} = EK, \quad \tilde{V} = FV\]

where \(E, F \in \mathbb{R}^{k \times n}\).

\[\text{Linformer}(Q, K, V) = \text{softmax}\left(\frac{Q\tilde{K}^T}{\sqrt{d_k}}\right)\tilde{V}\]

Complexity: \(O(nkd)\) instead of \(O(n^2d)\)!

class LinformerAttention(nn.Module):
    """Multi-head attention with Linformer low-rank K/V projections.

    Keys and values of sequence length T are projected down to a fixed
    length k with learned matrices E, F, so attention costs O(T*k*d)
    instead of O(T^2*d).

    Args:
        d_model: embedding dimension (must be divisible by n_heads).
        n_heads: number of attention heads.
        max_len: maximum supported sequence length (sizes E and F).
        k: projected sequence length (k << max_len).
    """

    def __init__(self, d_model, n_heads, max_len=512, k=64):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.k = k
        self.max_len = max_len

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Low-rank projections along the sequence axis (max_len -> k).
        self.E = nn.Linear(max_len, k, bias=False)
        self.F = nn.Linear(max_len, k, bias=False)

    def forward(self, x):
        """Project K/V to length k, then run standard attention.

        Args:
            x: (B, T, d_model) with T <= max_len.
        Returns:
            (B, T, d_model) tensor.
        Raises:
            ValueError: if T exceeds max_len.
        """
        B, T, C = x.size()
        if T > self.max_len:
            raise ValueError(f"sequence length {T} exceeds max_len {self.max_len}")

        qkv = self.qkv(x)
        q, k, v = qkv.split(self.d_model, dim=2)

        # Split into heads: (B, n_heads, T, d_k)
        q = q.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Project K, V along the sequence dimension. E/F are sized for
        # max_len; slice to the first T columns so any T <= max_len works
        # (the original applied nn.Linear(max_len, k) directly and crashed
        # with a shape error whenever T != max_len).
        k = F.linear(k.transpose(-2, -1), self.E.weight[:, :T]).transpose(-2, -1)  # (B, h, k, d_k)
        v = F.linear(v.transpose(-2, -1), self.F.weight[:, :T]).transpose(-2, -1)  # (B, h, k, d_k)

        # Standard scaled dot-product attention over the shortened K/V.
        scores = q @ k.transpose(-2, -1) / np.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        out = attn @ v

        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)

# Smoke test: run Linformer attention on a 128-token batch.
linformer = LinformerAttention(d_model=64, n_heads=4, max_len=512, k=64).to(device)
x = torch.randn(2, 128, 64, device=device)
out = linformer(x)
print(f"Output shape: {out.shape}")

3. Performer: Kernel ApproximationΒΆ

Key InsightΒΆ

Rewrite attention without explicit softmax:

\[\text{Attention}(Q, K, V)_i = \frac{\sum_j \exp(q_i^T k_j / \sqrt{d}) v_j}{\sum_j \exp(q_i^T k_j / \sqrt{d})}\]

Kernel TrickΒΆ

Approximate: \(\exp(q^T k) \approx \phi(q)^T \phi(k)\)

using random features \(\phi: \mathbb{R}^d \to \mathbb{R}^r\).

Fast ComputationΒΆ

\[\text{Attention}(Q, K, V) = \frac{\phi(Q)(\phi(K)^T V)}{\phi(Q)(\phi(K)^T \mathbf{1})}\]

Complexity: \(O(nrd)\) with \(r \ll n\)!

class PerformerAttention(nn.Module):
    """Performer-style linear attention via positive random features.

    Approximates softmax attention as phi(Q) (phi(K)^T V), reducing the
    cost from O(T^2 d) to O(T m d) for m random features, without ever
    materializing the T x T attention matrix.

    Args:
        d_model: embedding dimension (must be divisible by n_heads).
        n_heads: number of attention heads.
        n_features: number of random features m.
    """

    def __init__(self, d_model, n_heads, n_features=64):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.n_features = n_features

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Fixed random projection for the feature map. The original code
        # resampled omega on every forward call, making outputs
        # non-deterministic across calls; a registered buffer keeps it
        # fixed and lets it follow .to(device) / state_dict().
        omega = torch.randn(self.d_k, n_features) / np.sqrt(self.d_k)
        self.register_buffer('omega', omega)

    def kernel_feature_map(self, x):
        """Positive random features phi(x).

        Args:
            x: (B, H, T, d_k) queries or keys.
        Returns:
            (B, H, T, n_features) non-negative features.
        """
        x_proj = x @ self.omega
        # Subtract the per-position max before exp for numerical stability.
        return torch.exp(x_proj - torch.max(x_proj, dim=-1, keepdim=True)[0])

    def forward(self, x):
        """Apply Performer attention: (B, T, d_model) -> (B, T, d_model)."""
        B, T, C = x.size()

        qkv = self.qkv(x)
        q, k, v = qkv.split(self.d_model, dim=2)

        q = q.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Feature maps
        q_prime = self.kernel_feature_map(q)
        k_prime = self.kernel_feature_map(k)

        # Linear attention: phi(Q) (phi(K)^T V)
        kv = k_prime.transpose(-2, -1) @ v
        z = q_prime @ kv

        # Normalization: phi(Q) (phi(K)^T 1)
        normalizer = q_prime @ k_prime.sum(dim=-2, keepdim=True).transpose(-2, -1)
        out = z / (normalizer + 1e-6)

        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)

# Smoke test: run Performer attention on a 128-token batch.
performer = PerformerAttention(d_model=64, n_heads=4, n_features=64).to(device)
x = torch.randn(2, 128, 64, device=device)
out = performer(x)
print(f"Output shape: {out.shape}")

ComparisonΒΆ

Benchmarking the efficient Transformer variants against standard self-attention reveals the core trade-off: reduced computational complexity (from \(O(n^2)\) to \(O(n \log n)\) or \(O(n)\)) at the cost of some approximation error in the attention pattern. Linear attention methods (e.g., Performer) replace the softmax kernel with a feature map approximation, achieving \(O(n)\) time and memory. Sparse attention methods (e.g., Longformer, BigBird) restrict each token to attend to a fixed local window plus selected global tokens. Comparing wall-clock time, memory usage, and downstream task accuracy across sequence lengths helps practitioners choose the right variant for their use case – exact attention for short sequences, efficient variants for long-document or genomic applications.

# Benchmark all methods.
# Fixes vs. original: modules are built once (they were loop-invariant but
# re-constructed per length), timing runs under no_grad, each timing has a
# warm-up pass, and CUDA is synchronized before reading the clock since
# kernels launch asynchronously.
seq_lengths = [128, 256, 512, 1024, 2048]
d_model = 64
batch_size = 2

times_standard = []
times_linformer = []
times_performer = []

# Construct each module once, outside the timing loop.
linf = LinformerAttention(d_model, n_heads=4, max_len=max(seq_lengths), k=64).to(device)
perf = PerformerAttention(d_model, n_heads=4, n_features=64).to(device)


def _timed(fn, *args):
    """Return wall-clock seconds for one call to fn, after a warm-up pass."""
    fn(*args)  # warm-up (excludes lazy init from the measurement)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    fn(*args)
    if device.type == 'cuda':
        torch.cuda.synchronize()  # wait for async CUDA kernels
    return time.time() - start


with torch.no_grad():
    for n in seq_lengths:
        x = torch.randn(batch_size, n, d_model, device=device)

        # Standard attention (skip beyond 1024: O(n^2) memory blows up)
        if n <= 1024:
            times_standard.append(_timed(standard_attention, x, x, x))
        else:
            times_standard.append(None)

        times_linformer.append(_timed(linf, x))
        times_performer.append(_timed(perf, x))

plt.figure(figsize=(10, 6))
valid_std = [t for t in times_standard if t is not None]
plt.plot(seq_lengths[:len(valid_std)], valid_std, marker='o', label='Standard')
plt.plot(seq_lengths, times_linformer, marker='s', label='Linformer')
plt.plot(seq_lengths, times_performer, marker='^', label='Performer')
plt.xlabel('Sequence Length', fontsize=11)
plt.ylabel('Time (s)', fontsize=11)
plt.title('Efficient Attention Comparison', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.yscale('log')
plt.show()

SummaryΒΆ

Problem:ΒΆ

Standard attention: \(O(n^2d)\) complexity

Solutions:ΒΆ

Method

Complexity

Tradeoff

Linformer

\(O(nkd)\)

Low-rank assumption

Performer

\(O(nrd)\)

Kernel approximation

Reformer

\(O(n\log n)\)

LSH bucketing

Longformer

\(O(n)\)

Sparse patterns

Linformer:ΒΆ

  • Project \(K, V\) to dimension \(k\)

  • Good when attention is low-rank

Performer:ΒΆ

  • Kernel approximation via random features

  • Linear complexity in sequence length

  • Unbiased approximation

Applications:ΒΆ

  • Long documents (16k+ tokens)

  • DNA sequences

  • Video processing

  • Audio generation

Next Steps:ΒΆ

  • 13_gpt_architecture.ipynb - Full transformer

  • Explore FlashAttention (memory-efficient)

  • Test on long-range tasks

Advanced Efficient Transformers TheoryΒΆ

1. The Quadratic Complexity ProblemΒΆ

1.1 Standard Self-AttentionΒΆ

Standard Transformer self-attention has O(N²) time and memory complexity:

Attention formula:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Computational breakdown:

  • QK^T: Matrix multiplication (NΓ—d_k) @ (d_kΓ—N) = O(NΒ²d_k)

  • Softmax: O(NΒ²) (requires materializing NΓ—N attention matrix)

  • Multiply by V: (NΓ—N) @ (NΓ—d_v) = O(NΒ²d_v)

  • Total: O(NΒ²d) time, O(NΒ²) memory for attention matrix

Why this is problematic:

  • N=1024 tokens: 1M attention scores

  • N=4096 tokens: 16M attention scores (16Γ— larger!)

  • GPU memory limits: Can’t fit long sequences

1.2 Bottleneck AnalysisΒΆ

The attention matrix A = softmax(QK^T/√d_k) ∈ ℝ^(NΓ—N) is the bottleneck:

  • Must materialize all NΒ² attention scores

  • Cannot be computed incrementally (softmax couples all scores)

  • Memory scales quadratically with sequence length

2. Linear AttentionΒΆ

2.1 Kernel Trick ReformulationΒΆ

Key insight: Reorder operations to avoid materializing NΓ—N matrix.

Standard attention:

Attention(Q, K, V) = softmax(QK^T) V
                   = (exp(QK^T) / Ξ£_j exp(q_i^T k_j)) V

Kernel trick: Define similarity kernel k(q, k) = Ο†(q)^T Ο†(k)

Linear attention:

Attention(Q, K, V) = Ο†(Q)(Ο†(K)^T V) / (Ο†(Q)(Ο†(K)^T 1_N))

Complexity:

  • Ο†(K)^T V: (d’×N) @ (NΓ—d_v) = O(Nd’d_v)

  • Ο†(Q) Γ— result: (NΓ—d’) @ (d’×d_v) = O(Nd’d_v)

  • Total: O(Nd’d_v) (linear in N!)

Feature map choices:

  1. Identity: Ο†(x) = x (no softmax, loses attention properties)

  2. ReLU: Ο†(x) = ReLU(x) (non-negative, fast)

  3. ELU+1: Ο†(x) = elu(x) + 1 (smooth, non-negative)

Theoretical tradeoff:

  • Pro: O(N) complexity, constant memory

  • Con: Loses softmax normalization (changes attention semantics)

  • Con: No bounded attention weights (can have exploding values)

2.2 Causal Linear AttentionΒΆ

For autoregressive generation, need causal masking. Can be computed incrementally:

Recurrent form:

S_t = S_(t-1) + Ο†(k_t) βŠ— v_t    (running sum, O(d'd_v) per step)
z_t = z_(t-1) + Ο†(k_t)          (normalization, O(d') per step)
y_t = Ο†(q_t)^T S_t / (Ο†(q_t)^T z_t)

Advantage: O(1) per token in generation (vs O(N) for standard attention)

3. Performer (FAVOR+)ΒΆ

3.1 Positive Random FeaturesΒΆ

Problem with linear attention: Feature maps Ο†(x) don’t approximate softmax well.

Performer solution: Use random Fourier features to approximate exp(q^T k):

FAVOR+ (Fast Attention Via positive Orthogonal Random features):

exp(q^T k / √d) ≈ E[φ(q)^T φ(k)]

Feature map:

Ο†(x) = (1/√m) [exp(w_1^T x - ||x||Β²/2), ..., exp(w_m^T x - ||x||Β²/2)]

where w_i ~ N(0, I) are random Gaussian projections.

Orthogonal random features (ORF): Sample w_1, …, w_m from orthogonalized Gaussian matrix β†’ lower variance estimator.

Renormalization:

Ο†(x) = h(x) [f_1(x), ..., f_m(x)]
h(x) = exp(-||x||Β²/2)
f_i(x) = exp(w_i^T x)

Approximation quality:

  • Theorem: For m random features, approximation error is O(1/√m)

  • Typical m=256 gives good approximation with O(NΒ·256Β·d) = O(Nd) complexity

3.2 Complexity AnalysisΒΆ

  • Standard attention: O(NΒ²d)

  • Performer: O(Nmd) where m << N

  • Typical settings: m=256, N=4096 β†’ 16Γ— speedup

Memory:

  • Standard: O(NΒ²) for attention matrix

  • Performer: O(md) for S = Ο†(K)^T V (constant in N!)

3.3 Theoretical GuaranteesΒΆ

Approximation bound (Choromanski et al., 2020):

|softmax(QK^T/√d)V - Ο†(Q)(Ο†(K)^TV)| ≀ Ξ΅

with probability β‰₯ 1-Ξ΄ if m β‰₯ O(log(N/Ξ΄)/Ρ²)

Practical implications:

  • m=256: Ξ΅ β‰ˆ 0.1 (10% error acceptable for many tasks)

  • Higher m: Better approximation but higher cost

4. ReformerΒΆ

4.1 Locality-Sensitive Hashing (LSH) AttentionΒΆ

Key insight: Most attention weights are near-zero. Only attend to similar keys.

LSH basics: Hash function h such that similar items collide:

P(h(x) = h(y)) ∝ similarity(x, y)

Angular LSH for cosine similarity:

h(x) = sign(r^T x)  where r ~ N(0, I)

Multi-round hashing: Use b hash functions [h_1, …, h_b], concatenate to get bucket:

bucket(x) = [h_1(x), ..., h_b(x)]

LSH attention algorithm:

  1. Hash queries and keys into buckets

  2. Sort by bucket (queries and keys together)

  3. Only compute attention within buckets (and adjacent for smoothness)

  4. Complexity: O(N log N) (dominated by sorting)

Attention sparsity:

  • Each query attends to ~N/n_buckets keys (+ neighbors)

  • Typical: n_buckets = N/32 β†’ 32Γ— reduction in computations

4.2 Reversible LayersΒΆ

Problem: Storing activations for backprop requires O(NΒ·LΒ·d) memory (L layers).

Reversible residual network (RevNet): Split activations into two parts [x1, x2], compute:

y1 = x1 + F(x2)
y2 = x2 + G(y1)

Inverse (for backprop):

x2 = y2 - G(y1)
x1 = y1 - F(x2)

Advantage: Don’t need to store intermediate activations! Can recompute during backprop.

Reformer layers:

Y1 = X1 + Attention(X2)
Y2 = X2 + FeedForward(Y1)

Memory savings:

  • Standard Transformer: O(NΒ·LΒ·d) activation memory

  • Reformer: O(NΒ·d) (only store inputs, recompute rest)

4.3 Chunked Feed-ForwardΒΆ

Split sequence into chunks, process feed-forward chunk-by-chunk to save memory:

for chunk in chunks(X):
    chunk = chunk + FeedForward(chunk)

No cross-chunk dependencies in FFN β†’ parallelizable.

5. LongformerΒΆ

5.1 Sliding Window AttentionΒΆ

Local attention: Each token attends to w neighbors on each side.

Complexity: O(NΒ·w) where w << N (typically w=512 even for N=4096)

Pattern:

Attention mask for position i: [i-w, ..., i, ..., i+w]

Advantages:

  • Linear complexity in sequence length

  • Captures local context (most dependencies are local)

  • Can use different window sizes per layer

Implementation: Sparse attention matrix A where A[i,j] = 0 if |i-j| > w.

5.2 Dilated Sliding WindowΒΆ

Problem: Sliding window misses long-range dependencies.

Dilated attention: Skip every d tokens in window.

Effective receptive field:

  • Layer 0: window w, dilation 1 β†’ range w

  • Layer 1: window w, dilation 2 β†’ range 2w

  • Layer 2: window w, dilation 4 β†’ range 4w

  • Exponential growth in receptive field!

Pattern for dilation d:

Attend to positions: [..., i-dΒ·w, ..., i-d, i, i+d, ..., i+dΒ·w]

5.3 Global AttentionΒΆ

Task-specific tokens: Some tokens (e.g., [CLS]) need to attend to all tokens.

Hybrid pattern:

  • Most tokens: Local attention O(NΒ·w)

  • Few global tokens (g << N): Full attention O(gΒ·N)

  • Total: O(NΒ·w + gΒ·N) β‰ˆ O(NΒ·w) if g is small

Example (document QA):

  • Question tokens: Global attention (attend to all doc tokens)

  • Document tokens: Local attention (sliding window)

5.4 Longformer Attention PatternsΒΆ

Three types combined:

  1. Sliding window: All tokens, w=512

  2. Dilated window: Selected layers, exponentially increasing dilation

  3. Global: Task tokens ([CLS], question tokens)

Complexity analysis:

  • Standard: O(NΒ²)

  • Longformer: O(NΒ·w + gΒ·N) where w=512, g << N

  • For N=4096, w=512, g=10: ~8Γ— speedup

6. SynthesizerΒΆ

6.1 Learned Attention PatternsΒΆ

Key insight: Maybe we don’t need data-dependent attention QK^T. Learn attention patterns directly.

Dense Synthesizer:

A = softmax(W_A)  where W_A ∈ ℝ^(NΓ—N) is learned
Output = A V

Problem: W_A has O(NΒ²) parameters (not generalizable to different lengths).

Factored Synthesizer:

A[i,j] = softmax(w_i^T w_j)  where w_i ∈ ℝ^k, k << N

Parameters: O(NΒ·k) instead of O(NΒ²)

Random Synthesizer:

A = softmax(R)  where R is fixed random matrix

Surprisingly effective baseline!

6.2 Hybrid SynthesizerΒΆ

Combine learned and data-dependent:

A = Ξ±Β·softmax(QK^T) + (1-Ξ±)Β·softmax(W_A)

Findings (Tay et al., 2020):

  • Synthesizer can match standard Transformer on some tasks

  • QK^T attention not always necessary

  • But task-dependent: summarization benefits from data-dependent attention

7. Sparse Attention PatternsΒΆ

7.1 Fixed PatternsΒΆ

Strided attention (Sparse Transformer):

  • Attend to every k-th position

  • One head: Local, another head: Strided

  • Combine heads to capture local + long-range

Pattern:

Head 1 (local): [i-w, ..., i, ..., i+w]
Head 2 (strided): [0, k, 2k, ..., ⌊i/kβŒ‹Β·k]

Complexity: O(N·√N) for 2D patterns (images)

7.2 Content-Based Sparse AttentionΒΆ

Routing Transformer: Use k-means clustering to route queries to relevant keys:

1. Cluster K into k clusters
2. Assign each q_i to nearest cluster
3. Attend only within cluster

Complexity: O(NΒ·k) where k is avg cluster size.

8. Big BirdΒΆ

8.1 Attention PatternΒΆ

Three components:

  1. Random attention: Each token attends to r random tokens

  2. Window attention: Sliding window of size w

  3. Global tokens: g tokens that attend to/from all positions

Pattern for token i:

  • Window: [i-w/2, …, i, …, i+w/2]

  • Random: r random positions

  • Global: [0, …, g-1] (designated global tokens)

Complexity: O(NΒ·(w + r + g))

8.2 Theoretical JustificationΒΆ

Graph connectivity: Attention as graph where edge (i,j) means i attends to j.

Theorem (Zaheer et al., 2020): Big Bird attention graph is expander graph with:

  • Diameter: O(1) (shortest path between any two nodes is constant)

  • Spectral gap: Ξ©(1) (good mixing properties)

Implications:

  • Any token can reach any other in O(1) hops

  • Information flow is efficient despite sparsity

  • Theoretical guarantees lacking in pure sliding window

9. Flash AttentionΒΆ

9.1 IO-Aware AttentionΒΆ

Problem: Not just FLOP count, but memory transfers!

GPU memory hierarchy:

  • HBM (main memory): Large (40GB), slow (1.5 TB/s)

  • SRAM (on-chip): Small (20MB), fast (19 TB/s)

  • Registers: Tiny, fastest

Standard attention memory transfers:

  1. Load Q, K from HBM β†’ compute QK^T β†’ store to HBM

  2. Load QK^T from HBM β†’ softmax β†’ store to HBM

  3. Load softmax, V from HBM β†’ multiply β†’ store output

Total HBM accesses: O(NΒ² + Nd) reads/writes of attention matrix.

9.2 Tiling and RecomputationΒΆ

Flash Attention algorithm:

  1. Tile: Split Q, K, V into blocks that fit in SRAM

  2. Fused kernel: Compute attention for each block without HBM write

  3. Online softmax: Compute softmax incrementally (safe softmax with running max)

Online softmax (numerically stable):

m_i = max(m_{i-1}, max(q_i^T K))
β„“_i = e^{m_{i-1} - m_i} β„“_{i-1} + Ξ£_j e^{q_i^T k_j - m_i}
O_i = e^{m_{i-1} - m_i} O_{i-1} + Ξ£_j e^{q_i^T k_j - m_i} v_j

Backward pass: Recompute attention on-the-fly instead of storing.

Complexity:

  • FLOPs: O(NΒ²d) (same as standard)

  • HBM accesses: O(NΒ²dΒ²/M) where M = SRAM size

  • Speedup: 2-4Γ— due to better memory utilization

9.3 Flash Attention 2ΒΆ

Improvements (Dao et al., 2023):

  1. Better parallelization: Parallelize over sequence length, not batch

  2. Reduced non-matmul ops: Minimize pointer arithmetic

  3. Forward-backward pass fusion: Recompute even more in backward

Results:

  • 2Γ— faster than Flash Attention 1

  • 73% of A100 theoretical max (vs 35% standard attention)

10. Complexity ComparisonΒΆ

10.1 Time ComplexityΒΆ

Method                 Time          Params         Context
─────────────────────────────────────────────────────────────
Standard Attention     O(N²d)        O(d²)          Full
Linear Attention       O(Nd²)        O(d²)          Full
Performer              O(Nmd)        O(d²)          Full (approx)
Reformer (LSH)         O(N log N·d)  O(d²)          Sparse (hashed)
Longformer             O(N·w·d)      O(d²)          Local + Global
Synthesizer            O(Nd²)        O(Nd) or O(d²) Learned
Sparse (fixed)         O(N·s·d)      O(d²)          Fixed sparse
Big Bird               O(N·(w+r)·d)  O(d²)          Local + Random + Global
Flash Attention        O(N²d)        O(d²)          Full (IO-optimized)

where:

  • N: sequence length

  • d: model dimension

  • m: number of random features (Performer)

  • w: window size (Longformer, Big Bird)

  • s: sparsity (number of attended positions)

  • r: random attention positions (Big Bird)

10.2 Memory ComplexityΒΆ

Method                 Memory        Notes
────────────────────────────────────────────────────────────
Standard Attention     O(NΒ²)         Attention matrix
Linear Attention       O(Nd)         No attention matrix
Performer              O(md)         Feature map activations
Reformer               O(NΒ·d)        Reversible layers
Longformer             O(NΒ·w)        Sparse attention
Flash Attention        O(NΒ²)         But computed in tiles

11. Evaluation MetricsΒΆ

11.1 Efficiency MetricsΒΆ

Time:

  • Wall-clock time (seconds)

  • FLOPs (theoretical)

  • Throughput (tokens/sec)

Memory:

  • Peak memory usage

  • HBM bandwidth utilization

  • Maximum sequence length

Quality:

  • Perplexity (language modeling)

  • Downstream task accuracy

  • Attention pattern visualization

11.2 Approximation QualityΒΆ

For methods that approximate attention (Linear, Performer):

Attention distance:

||A_approx - A_exact||_F / ||A_exact||_F

Output distance:

||Output_approx - Output_exact||_2 / ||Output_exact||_2

Typical results (Performer):

  • Attention distance: 10-20%

  • Output distance: 1-5%

  • Task accuracy: Within 1% of standard Transformer

12. When to Use Each MethodΒΆ

12.1 Decision GuideΒΆ

Use Standard Attention when:

  • Sequence length < 1024

  • Need exact attention (theoretical guarantees)

  • Compute/memory not bottleneck

  • Baseline for research

Use Linear Attention when:

  • Need O(N) complexity

  • Can tolerate non-softmax attention

  • Autoregressive generation (recurrent form)

  • Extremely long sequences (N > 16k)

Use Performer when:

  • Need softmax approximation

  • Want provable error bounds

  • Bidirectional context

  • Research applications

Use Reformer when:

  • Limited GPU memory (reversible layers)

  • Natural language (locality assumption valid)

  • Sequence length 4k-16k

  • Can tolerate LSH approximation

Use Longformer when:

  • Document-level tasks (N = 4k-16k)

  • Strong locality in data

  • Need task-specific global tokens

  • Production document understanding

Use Flash Attention when:

  • Exact attention required

  • Have modern GPUs (A100, H100)

  • Memory bandwidth bottleneck

  • Want easy drop-in replacement

Use Big Bird when:

  • Need theoretical guarantees (expander graph)

  • Sparse attention acceptable

  • Long documents (N > 4k)

  • Graph-theoretic properties important

12.2 Combination StrategiesΒΆ

Multi-scale:

  • Lower layers: Local attention (cheap)

  • Higher layers: Global attention (expensive but fewer layers)

Hybrid:

  • Some heads: Linear attention (efficient)

  • Some heads: Standard attention (expressive)

Adaptive:

  • Easy examples: Sparse attention

  • Hard examples: Dense attention

  • Learn routing with RL

13. State-of-the-Art ResultsΒΆ

13.1 Language ModelingΒΆ

Long-Range Arena benchmark (Tay et al., 2020):

  • Tasks: Text, ListOps, Retrieval, Image, Pathfinder

  • Sequence lengths: 1k-16k

Results:

Model                Avg Accuracy  Speed vs Transformer
────────────────────────────────────────────────────────
Transformer          58.5%         1.0Γ—
Linear Attention     53.9%         2.1Γ—
Performer            61.4%         1.8Γ—
Reformer             56.1%         1.3Γ—
Longformer           62.8%         2.3Γ—
Big Bird             64.2%         2.5Γ—

13.2 Long Document UnderstandingΒΆ

Longformer on long documents:

  • WikiHop (N=4096): 75.8% accuracy (vs 71.2% truncated)

  • TriviaQA (N=4096): 75.2% F1 (vs 72.5% truncated)

Big Bird on text:

  • ArXiv summarization (N=4096): 46.3 ROUGE-1

  • PubMed QA (N=3072): 70.5% accuracy

13.3 Image GenerationΒΆ

Sparse Transformers (Child et al., 2019):

  • ImageNet 64Γ—64 (N=4096): 2.80 bits/dim

  • CIFAR-10 (N=1024): 2.80 bits/dim

Linear Transformers (Katharopoulos et al., 2020):

  • CIFAR-10 classification: 94.2% (vs 95.1% standard)

  • 3Γ— faster training

14. Implementation ConsiderationsΒΆ

14.1 Software OptimizationsΒΆ

Kernel fusion:

  • Fuse softmax + mask + dropout into single kernel

  • Reduces memory bandwidth

Mixed precision:

  • Compute in FP16, accumulate in FP32

  • 2Γ— speedup on modern GPUs

Gradient checkpointing:

  • Trade compute for memory

  • Recompute activations in backward pass

14.2 Hardware ConsiderationsΒΆ

Tensor cores (A100, H100):

  • Optimized for matrix multiplication

  • Flash Attention leverages better than standard

Memory hierarchy:

  • Minimize HBM <-> SRAM transfers

  • Tile sizes match hardware (e.g., 64Γ—64 for Flash)

Parallelism:

  • Batch parallelism: Standard

  • Sequence parallelism: Newer (Megatron-LM)

  • Pipeline parallelism: Across layers

15. Recent Advances (2023-2024)ΒΆ

15.1 Flash Attention 2ΒΆ

  • 2Γ— faster than Flash Attention 1

  • Better parallelization, reduced overhead

15.2 Paged Attention (vLLM)ΒΆ

  • Virtual memory for KV cache

  • Enables longer context in serving

  • No memory fragmentation

15.3 Multi-Query Attention (MQA)ΒΆ

  • Share K, V across heads (only Q is per-head)

  • 10Γ— faster decoding

  • Slight quality drop, but worth it for serving

15.4 Grouped-Query Attention (GQA)ΒΆ

  • Middle ground: Group heads, share K, V per group

  • Llama 2 uses GQA

  • Better quality than MQA, still fast

16. Research DirectionsΒΆ

16.1 Open ProblemsΒΆ

  1. Softmax-free attention: Can we replace softmax entirely?

  2. Adaptive sparsity: Learn attention patterns, not hand-designed

  3. Theoretical understanding: Why do these approximations work?

  4. Long-range dependencies: Better than O(N) for truly long sequences

  5. Causal approximations: Most work on bidirectional, autoregressive harder

16.2 Emerging ApproachesΒΆ

State Space Models (S4, Mamba):

  • Not attention at all, but recurrent

  • O(N log N) or O(N) complexity

  • Competitive with Transformers on long sequences

Attention-free models:

  • AFT (Attention Free Transformer)

  • FNet (Fourier Transform instead of attention)

  • Limited success, but interesting

Hybrid architectures:

  • Local: CNN or MLP

  • Global: Sparse attention

  • Best of both worlds

17. Key TakeawaysΒΆ

  1. No free lunch: Every efficient attention method trades something (approximation quality, implementation complexity, hardware efficiency).

  2. Task-dependent: Longformer for documents, Linear for streaming, Flash for exact efficient.

  3. Complexity isn’t everything: Flash Attention is O(NΒ²) but faster than O(N) methods due to better hardware utilization.

  4. Locality matters: Most real-world data has strong locality β†’ sparse methods work well.

  5. Softmax may not be necessary: Synthesizer, Linear attention show alternative attention mechanisms can work.

  6. Hardware co-design: Flash Attention shows importance of optimizing for memory hierarchy, not just FLOPs.

  7. Combination is powerful: Hybrid methods (local + global, exact + approximate) often best.

18. ReferencesΒΆ

Foundational papers:

  • Linear Attention: Katharopoulos et al. (2020)

  • Performer: Choromanski et al. (2020)

  • Reformer: Kitaev et al. (2020)

  • Longformer: Beltagy et al. (2020)

  • Synthesizer: Tay et al. (2020)

  • Sparse Transformer: Child et al. (2019)

  • Big Bird: Zaheer et al. (2020)

  • Flash Attention: Dao et al. (2022)

  • Flash Attention 2: Dao (2023)

Surveys:

  • Tay et al. (2020): β€œEfficient Transformers: A Survey”

  • Lin et al. (2021): β€œA Survey of Transformers”

# Advanced Efficient Transformers Implementations

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import Optional, Tuple

# ============================================================================
# 1. Linear Attention
# ============================================================================

class LinearAttention(nn.Module):
    """
    Linear attention using kernel trick: O(N) complexity.
    Avoids materializing NxN attention matrix.

    Fix vs. original: the `mask` argument was accepted but silently
    ignored; padded positions are now excluded from the K/V aggregation.
    """

    def __init__(self, d_model, n_heads, feature_map='elu'):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # One of 'elu', 'relu', or anything else for identity.
        self.feature_map = feature_map

    def apply_feature_map(self, x):
        """Apply non-linearity phi(x) to make attention non-negative."""
        if self.feature_map == 'elu':
            return F.elu(x) + 1
        elif self.feature_map == 'relu':
            return F.relu(x)
        else:  # identity
            return x

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: (batch, seq_len) - padding mask; nonzero = keep position.
        Returns:
            output: (batch, seq_len, d_model)
        """
        B, N, D = x.shape

        # Project to Q, K, V
        Q = self.q_proj(x).view(B, N, self.n_heads, self.head_dim)
        K = self.k_proj(x).view(B, N, self.n_heads, self.head_dim)
        V = self.v_proj(x).view(B, N, self.n_heads, self.head_dim)

        # Apply feature map: phi(Q), phi(K)
        Q = self.apply_feature_map(Q)  # (B, N, H, D/H)
        K = self.apply_feature_map(K)

        # Linear attention: phi(Q) @ (phi(K)^T @ V)
        # Standard: softmax(QK^T) @ V has shape (N, N) @ (N, D/H) = (N, D/H)
        # Linear: (N, D') @ ((D', N) @ (N, D/H)) = (N, D') @ (D', D/H) = (N, D/H)

        Q = Q.transpose(1, 2)  # (B, H, N, D/H)
        K = K.transpose(1, 2)  # (B, H, N, D/H)
        V = V.transpose(1, 2)  # (B, H, N, D/H)

        if mask is not None:
            # Zero out padded positions so they contribute nothing to the
            # K^T V sums below (previously the mask was ignored).
            keep = mask.view(B, 1, N, 1).to(K.dtype)
            K = K * keep
            V = V * keep

        # Compute K^T @ V: (B, H, D/H, N) @ (B, H, N, D/H) = (B, H, D/H, D/H)
        KV = torch.matmul(K.transpose(-2, -1), V)  # (B, H, D/H, D/H)

        # Normalization: sum of K for each position
        K_sum = K.sum(dim=2, keepdim=True)  # (B, H, 1, D/H)

        # Compute Q @ KV: (B, H, N, D/H) @ (B, H, D/H, D/H) = (B, H, N, D/H)
        QKV = torch.matmul(Q, KV)

        # Normalize: divide by Q @ K_sum^T
        normalizer = torch.matmul(Q, K_sum.transpose(-2, -1))  # (B, H, N, 1)
        output = QKV / (normalizer + 1e-6)

        # Reshape and project
        output = output.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(output)


class CausalLinearAttention(nn.Module):
    """
    Causal linear attention with O(1) per-token generation.
    Maintains running sums for autoregressive decoding.
    """
    
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        """Causal linear attention (training mode)."""
        B, N, D = x.shape
        
        Q = F.elu(self.q_proj(x)) + 1
        K = F.elu(self.k_proj(x)) + 1
        V = self.v_proj(x)
        
        Q = Q.view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Causal: compute cumulative sums
        output = []
        KV_state = torch.zeros(B, self.n_heads, self.head_dim, self.head_dim, 
                              device=x.device, dtype=x.dtype)
        K_state = torch.zeros(B, self.n_heads, self.head_dim, 
                             device=x.device, dtype=x.dtype)
        
        for t in range(N):
            # Update running sums
            k_t = K[:, :, t:t+1, :]  # (B, H, 1, D/H)
            v_t = V[:, :, t:t+1, :]
            KV_state = KV_state + torch.matmul(k_t.transpose(-2, -1), v_t)
            K_state = K_state + k_t.squeeze(2)
            
            # Compute output for position t
            q_t = Q[:, :, t:t+1, :]
            o_t = torch.matmul(q_t, KV_state)
            normalizer = torch.matmul(q_t, K_state.unsqueeze(-1))
            o_t = o_t / (normalizer + 1e-6)
            output.append(o_t)
        
        output = torch.cat(output, dim=2)  # (B, H, N, D/H)
        output = output.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(output)


# ============================================================================
# 2. Performer (FAVOR+)
# ============================================================================

class PerformerAttention(nn.Module):
    """
    Performer attention via FAVOR+ positive random features.

    Approximates softmax attention softmax(QK^T/√d_k)V in O(N·m·d) time using
    the unbiased positive feature map
        φ(x) = exp(w^T x - ||x||²/2) / √m,   w ~ N(0, I).

    Fixes vs. the previous implementation:
      * orthogonal feature construction stacks QR-orthogonalized Gaussian
        blocks, so num_features > head_dim is supported (previously the QR of
        a (head_dim, num_features) matrix silently produced only head_dim
        features);
      * torch.linalg.qr replaces the deprecated torch.qr;
      * inputs are scaled by d_k^{-1/4} so the approximated kernel is the
        scaled dot-product exp(q^T k / √d_k), matching standard attention;
      * exponentials are stabilized by subtracting a per-(batch, head) max,
        which cancels exactly in the final normalization.
    """

    def __init__(self, d_model, n_heads, num_features=256,
                 orthogonal=True, redraw_features=False):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.num_features = num_features
        self.orthogonal = orthogonal
        # If True, resample the projection matrix on every training forward.
        self.redraw_features = redraw_features

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Buffer so the projection matrix moves with .to(device) / state_dict.
        self.register_buffer('random_features',
                           self._create_random_features())

    def _create_random_features(self):
        """Draw the (num_features, head_dim) random projection matrix.

        Orthogonal mode builds blocks of orthonormal rows via QR and rescales
        each row to a chi-distributed norm, so rows have the same norm
        distribution as i.i.d. Gaussian vectors (lower estimator variance,
        same expectation — Choromanski et al., 2021).
        """
        if not self.orthogonal:
            # Plain i.i.d. Gaussian features: w ~ N(0, I_{head_dim}).
            return torch.randn(self.num_features, self.head_dim)

        blocks = []
        remaining = self.num_features
        while remaining > 0:
            gaussian = torch.randn(self.head_dim, self.head_dim)
            q, _ = torch.linalg.qr(gaussian)  # q: orthonormal columns
            take = min(remaining, self.head_dim)
            blocks.append(q.T[:take])         # orthonormal rows
            remaining -= take
        features = torch.cat(blocks, dim=0)   # (num_features, head_dim)

        # Orthonormal rows have unit norm; restore Gaussian-like row norms.
        norms = torch.randn(self.num_features, self.head_dim).norm(dim=1, keepdim=True)
        return features * norms

    def apply_kernel_feature_map(self, x):
        """
        Apply the FAVOR+ feature map φ(x) = exp(w^T x - ||x||²/2) / √m.

        Args:
            x: (B, H, N, head_dim)
        Returns:
            features: (B, H, N, num_features), strictly positive
        """
        x_norm_sq = (x ** 2).sum(dim=-1, keepdim=True) / 2     # (B, H, N, 1)
        projections = torch.matmul(x, self.random_features.T)  # (B, H, N, m)

        logits = projections - x_norm_sq
        # Subtracting a per-(batch, head) constant scales φ(Q) rows and φ(K)
        # uniformly; the scale cancels in the normalizer but prevents exp()
        # overflow.
        logits = logits - logits.amax(dim=(-2, -1), keepdim=True).detach()
        return torch.exp(logits) / math.sqrt(self.num_features)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: unused; kept for interface parity with the other modules
        Returns:
            output: (batch, seq_len, d_model)
        """
        B, N, D = x.shape

        # Resampling keeps the approximation unbiased across training steps.
        if self.training and self.redraw_features:
            self.random_features = self._create_random_features().to(x.device)

        Q = self.q_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        # Scale by d_k^{-1/4}: exp(q'·k') then approximates exp(q·k / √d_k).
        scale = self.head_dim ** -0.25
        Q_prime = self.apply_kernel_feature_map(Q * scale)  # (B, H, N, m)
        K_prime = self.apply_kernel_feature_map(K * scale)

        # Linear-attention factorization with random features.
        KV = torch.matmul(K_prime.transpose(-2, -1), V)     # (B, H, m, head_dim)
        output = torch.matmul(Q_prime, KV)                  # (B, H, N, head_dim)

        # Row-wise normalizer: φ(q_i) · Σ_j φ(k_j)
        K_sum = K_prime.sum(dim=2, keepdim=True)            # (B, H, 1, m)
        normalizer = torch.matmul(Q_prime, K_sum.transpose(-2, -1))
        output = output / (normalizer + 1e-6)

        # Merge heads and project out.
        output = output.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(output)


# ============================================================================
# 3. Longformer (Sliding Window + Global)
# ============================================================================

class LongformerAttention(nn.Module):
    """
    Longformer attention: sliding-window local attention plus optional global
    tokens that attend to — and are attended by — every position.
    Effective complexity O(N·w) for window size w (this reference version
    still materializes the full score matrix for clarity).

    Fixes vs. the previous implementation:
      * masks are built as bool tensors — the old float local mask combined
        with a bool global mask via `|` raised a RuntimeError whenever
        global_mask was supplied;
      * the local window mask is built vectorized instead of an O(N²)
        Python loop.
    """

    def __init__(self, d_model, n_heads, window_size=512):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.window_size = window_size

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def _get_local_attention_mask(self, seq_len, device):
        """
        Build the sliding-window mask: position i may attend to j whenever
        |i - j| <= window_size // 2.

        Returns:
            mask: (seq_len, seq_len) bool tensor, True for allowed positions
        """
        idx = torch.arange(seq_len, device=device)
        half = self.window_size // 2
        return (idx.unsqueeze(1) - idx.unsqueeze(0)).abs() <= half

    def forward(self, x, global_mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            global_mask: optional (batch, seq_len); nonzero/True marks global
                tokens (e.g. [CLS]) that attend everywhere
        Returns:
            output: (batch, seq_len, d_model)
        """
        B, N, D = x.shape

        # Project per head: (B, H, N, head_dim)
        Q = self.q_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        # Allowed positions: local window, plus global rows/columns.
        local_mask = self._get_local_attention_mask(N, x.device)  # (N, N) bool

        if global_mask is not None:
            g = global_mask.bool().unsqueeze(1).unsqueeze(2)  # (B, 1, 1, N)
            # Global tokens attend to all (rows) and are attended by all (cols).
            global_attn_mask = g | g.transpose(-2, -1)        # (B, 1, N, N)
            attention_mask = local_mask.unsqueeze(0) | global_attn_mask
        else:
            attention_mask = local_mask.unsqueeze(0)

        # Scaled dot-product scores; disallowed positions get -inf.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(~attention_mask, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        # A fully-masked row softmaxes to NaN; zero those rows out.
        attn_weights = torch.nan_to_num(attn_weights)

        output = torch.matmul(attn_weights, V)

        # Merge heads and project out.
        output = output.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(output)


# ============================================================================
# 4. Flash Attention (Simplified Conceptual Version)
# ============================================================================

class FlashAttention(nn.Module):
    """
    Educational Flash-Attention-style module (pure PyTorch, no CUDA kernels).

    Demonstrates the two core ideas: process K/V in tiles and maintain an
    online (running-max) softmax, so the result equals dense softmax
    attention while scores are only ever formed one tile at a time.
    """

    def __init__(self, d_model, n_heads, block_size=64):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.block_size = block_size

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def online_softmax_attention(self, Q, K, V):
        """
        Tiled attention with an online softmax (streaming log-sum-exp).

        For each K/V tile, the accumulated output and normalizer are rescaled
        by exp(old_max - new_max) before folding in the tile's exp-scores —
        exactly the update rule used by real Flash Attention kernels.
        """
        B, H, N, D = Q.shape

        acc = torch.zeros_like(Q)  # unnormalized output accumulator
        running_max = torch.full((B, H, N, 1), float('-inf'), device=Q.device)
        denom = torch.zeros(B, H, N, 1, device=Q.device)  # softmax normalizer

        for start in range(0, N, self.block_size):
            stop = start + self.block_size
            k_tile = K[:, :, start:stop, :]
            v_tile = V[:, :, start:stop, :]

            # Scores of every query against this tile only.
            tile_scores = torch.matmul(Q, k_tile.transpose(-2, -1)) / math.sqrt(D)

            new_max = torch.maximum(running_max,
                                    tile_scores.max(dim=-1, keepdim=True)[0])

            # Rescale everything accumulated so far onto the new max.
            correction = torch.exp(running_max - new_max)
            acc = acc * correction
            denom = denom * correction

            # Fold in this tile's contributions.
            tile_exp = torch.exp(tile_scores - new_max)
            acc = acc + torch.matmul(tile_exp, v_tile)
            denom = denom + tile_exp.sum(dim=-1, keepdim=True)

            running_max = new_max

        # Deferred normalization: divide once at the end.
        return acc / denom

    def forward(self, x):
        """Project to per-head Q/K/V and run tiled online-softmax attention."""
        B, N, D = x.shape

        Q = self.q_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        out = self.online_softmax_attention(Q, K, V)

        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(out)


# ============================================================================
# Demonstrations
# ============================================================================

print("=" * 70)
print("Efficient Transformers - Advanced Implementations")
print("=" * 70)

# 1. Complexity comparison (big-O summary of the modules defined above)
print("\n1. Complexity Analysis:")
print("   ┌────────────────────┬───────────┬───────────┬─────────────┐")
print("   │ Method             │ Time      │ Memory    │ Exact?      │")
print("   ├────────────────────┼───────────┼───────────┼─────────────┤")
print("   │ Standard Attention │ O(N²d)    │ O(N²)     │ Yes         │")
print("   │ Linear Attention   │ O(Nd²)    │ O(Nd)     │ No (approx) │")
print("   │ Performer          │ O(Nmd)    │ O(md)     │ No (approx) │")
print("   │ Longformer         │ O(Nwd)    │ O(Nw)     │ Yes (sparse)│")
print("   │ Flash Attention    │ O(N²d)    │ O(N²)*    │ Yes         │")
print("   └────────────────────┴───────────┴───────────┴─────────────┘")
print("   * Flash: Same asymptotic memory but tiled (fits in SRAM)")

# 2. Linear attention demo
print("\n2. Linear Attention:")
linear_attn = LinearAttention(d_model=256, n_heads=8, feature_map='elu')
x_test = torch.randn(2, 100, 256)
out_linear = linear_attn(x_test)
print(f"   Input: {x_test.shape}")
print(f"   Output: {out_linear.shape}")
print(f"   Feature map: φ(x) = elu(x) + 1 (non-negative)")
print(f"   Complexity: O(N·d²) = O(100·256²) = 6.5M ops")
print(f"   vs Standard O(N²·d) = O(100²·256) = 2.5M ops")
print(f"   Note: Linear wins when N >> d")

# 3. Performer demo
print("\n3. Performer (FAVOR+):")
performer = PerformerAttention(d_model=256, n_heads=8, num_features=256,
                               orthogonal=True)
out_performer = performer(x_test)
print(f"   Input: {x_test.shape}")
print(f"   Output: {out_performer.shape}")
print(f"   Random features: {256} (orthogonal)")
print(f"   Approximates: exp(q^T k / √d_k)")
print(f"   Complexity: O(N·m·d) = O(100·256·256) = 6.5M ops")
print(f"   Error bound: O(1/√m) = O(1/16) ≈ 6%")

# 4. Causal linear attention
print("\n4. Causal Linear Attention:")
causal_attn = CausalLinearAttention(d_model=256, n_heads=8)
x_causal = torch.randn(1, 50, 256)
out_causal = causal_attn(x_causal)
print(f"   Input: {x_causal.shape}")
print(f"   Output: {out_causal.shape}")
print(f"   Mechanism: Cumulative sums S_t = S_(t-1) + k_t ⊗ v_t")
print(f"   Generation cost: O(d²) per token (vs O(Nd) standard)")
print(f"   Speedup at N=4096: ~16× faster decoding")

# 5. Longformer demo
print("\n5. Longformer Attention:")
longformer = LongformerAttention(d_model=256, n_heads=8, window_size=16)
x_long = torch.randn(1, 64, 256)
global_mask = torch.zeros(1, 64, dtype=torch.bool)
global_mask[0, 0] = True  # First token is global (e.g., [CLS])
out_long = longformer(x_long, global_mask)
print(f"   Input: {x_long.shape}")
print(f"   Output: {out_long.shape}")
print(f"   Window size: 16")
print(f"   Global tokens: 1 (position 0)")
print(f"   Complexity: O(N·w) = O(64·16) = 1,024 attention scores")
print(f"   vs Standard: O(N²) = O(64²) = 4,096 scores (4× reduction)")

# 6. Flash attention concept
print("\n6. Flash Attention (Conceptual):")
flash_attn = FlashAttention(d_model=256, n_heads=8, block_size=32)
x_flash = torch.randn(1, 128, 256)
out_flash = flash_attn(x_flash)
print(f"   Input: {x_flash.shape}")
print(f"   Output: {out_flash.shape}")
print(f"   Block size: 32 (tiles)")
print(f"   Key idea: Online softmax with running max/sum")
print(f"   Benefit: Reduced HBM ↔ SRAM transfers")
print(f"   Speedup: 2-4× on A100 (memory-bound workloads)")

# 7. Attention pattern visualization (■ = attended, · = masked)
print("\n7. Attention Patterns:")
print("   Standard Attention:")
print("     ■ ■ ■ ■ ■ ■ ■ ■  (dense, all-to-all)")
print("     ■ ■ ■ ■ ■ ■ ■ ■")
print("     ■ ■ ■ ■ ■ ■ ■ ■")
print("\n   Longformer (window=3):")
print("     ■ ■ ■ · · · · ·  (sparse, local window)")
print("     ■ ■ ■ ■ · · · ·")
print("     · ■ ■ ■ ■ · · ·")
print("\n   Longformer + Global (token 0):")
print("     ■ ■ ■ ■ ■ ■ ■ ■  (global token attends all)")
print("     ■ ■ ■ ■ · · · ·")
print("     ■ · ■ ■ ■ · · ·")

# 8. When to use guide
print("\n8. Method Selection Guide:")
print("   Sequence length N < 1024:")
print("     → Use standard attention (exact, well-optimized)")
print("\n   1024 < N < 4096:")
print("     → Longformer (good local/global tradeoff)")
print("     → Flash Attention (exact + fast on modern GPUs)")
print("\n   N > 4096:")
print("     → Linear Attention (extreme lengths, streaming)")
print("     → Reformer (if memory constrained)")
print("\n   Need exact attention:")
print("     → Flash Attention (best of both worlds)")
print("\n   Can tolerate approximation:")
print("     → Performer (provable bounds, research)")
print("     → Linear (fastest, autoregressive)")

# 9. Approximation quality
print("\n9. Approximation Quality:")
# Compare linear vs standard attention on a small example. NOTE: the two
# modules are independently randomly initialized, so this number only
# illustrates the scale of the difference — it is not a controlled comparison
# of the two mechanisms with shared weights.
x_small = torch.randn(1, 32, 64)
std_attn = nn.MultiheadAttention(64, 4, batch_first=True)
lin_attn = LinearAttention(64, 4, feature_map='elu')

with torch.no_grad():  # pure inference for printing — skip autograd bookkeeping
    std_out, _ = std_attn(x_small, x_small, x_small)
    lin_out = lin_attn(x_small)
    relative_error = (std_out - lin_out).norm() / std_out.norm()

print(f"   Standard vs Linear attention:")
print(f"     Input: {x_small.shape}")
print(f"     Output difference (relative): {relative_error.item():.4f}")
print(f"     Typical range: 5-20% for linear approximations")
print(f"     Task accuracy drop: <1% on many benchmarks")

# 10. Memory savings
print("\n10. Memory Savings:")
seq_lengths = [512, 1024, 2048, 4096, 8192]
print("   Attention Matrix Memory (MB, float32):")
print("   ┌─────────┬──────────┬──────────┬──────────┐")
print("   │ Seq Len │ Standard │ Linear   │ Savings  │")
print("   ├─────────┼──────────┼──────────┼──────────┤")
for N in seq_lengths:
    std_mem = N * N * 4 / 1e6  # float32 = 4 bytes
    lin_mem = N * 64 * 4 / 1e6  # assume d=64
    savings = f"{std_mem / lin_mem:.1f}×"
    print(f"   │ {N:7} │ {std_mem:8.2f} │ {lin_mem:8.2f} │ {savings:8} │")
print("   └─────────┴──────────┴──────────┴──────────┘")

print("\n" + "=" * 70)