import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Select the compute device: prefer CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

1. Bahdanau Attention (Additive)ΒΆ

Energy FunctionΒΆ

\[e_{ij} = v^T \tanh(W_h h_i + W_s s_j)\]

Attention WeightsΒΆ

\[\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^T \exp(e_{kj})}\]

Context VectorΒΆ

\[c_j = \sum_{i=1}^T \alpha_{ij} h_i\]

📚 Reference Materials:

class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention: e = v^T tanh(W_h h + W_s s)."""

    def __init__(self, hidden_dim, query_dim):
        super().__init__()
        # Separate bias-free projections for encoder states and the query,
        # as in the original additive-attention formulation.
        self.W_h = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_s = nn.Linear(query_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden_states, query):
        """Return (context, attention) for a batch of sequences.

        Args:
            hidden_states: (batch, seq_len, hidden_dim)
            query: (batch, query_dim)

        Returns:
            context: (batch, hidden_dim) attention-weighted sum of states.
            attention: (batch, seq_len) softmax weights over positions.
        """
        # Project the query once and broadcast it against every time step.
        scores = self.v(
            torch.tanh(self.W_h(hidden_states) + self.W_s(query.unsqueeze(1)))
        ).squeeze(-1)  # (batch, seq_len)

        # Normalize scores into a distribution over sequence positions.
        weights = F.softmax(scores, dim=1)

        # Weighted sum over the sequence dimension.
        summary = torch.bmm(weights.unsqueeze(1), hidden_states).squeeze(1)

        return summary, weights

print("BahdanauAttention defined")

2. Luong Attention (Multiplicative)ΒΆ

Dot ProductΒΆ

\[e_{ij} = h_i^T s_j\]

GeneralΒΆ

\[e_{ij} = h_i^T W s_j\]

ConcatΒΆ

\[e_{ij} = v^T \tanh(W [h_i; s_j])\]
class LuongAttention(nn.Module):
    """Multiplicative (Luong) attention.

    Supported scoring methods:
        'dot':     score(h, s) = h^T s          (no parameters)
        'general': score(h, s) = h^T W s
        'concat':  score(h, s) = v^T tanh(W [h; s])
    """

    _METHODS = ('dot', 'general', 'concat')

    def __init__(self, hidden_dim, method='general'):
        super().__init__()
        # Fail fast on a typo here instead of raising a confusing
        # NameError ('energy' unbound) later in forward().
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown attention method {method!r}; expected one of {self._METHODS}"
            )
        self.method = method

        if method == 'general':
            self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)
        elif method == 'concat':
            self.W = nn.Linear(2 * hidden_dim, hidden_dim)
            self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden_states, query):
        """Return (context, attention).

        Args:
            hidden_states: (batch, seq_len, hidden_dim)
            query: (batch, hidden_dim)

        Returns:
            context: (batch, hidden_dim)
            attention: (batch, seq_len); rows sum to 1.
        """
        if self.method == 'dot':
            # Plain dot product between each hidden state and the query.
            energy = torch.bmm(hidden_states, query.unsqueeze(2))  # (batch, seq_len, 1)
            energy = energy.squeeze(-1)

        elif self.method == 'general':
            # Bilinear form: h^T W s.
            query_transformed = self.W(query)  # (batch, hidden_dim)
            energy = torch.bmm(hidden_states, query_transformed.unsqueeze(2))
            energy = energy.squeeze(-1)

        else:  # 'concat' (method was validated in __init__)
            # Small MLP on the concatenation: v^T tanh(W [h; s]).
            batch_size, seq_len, _ = hidden_states.size()
            query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
            concat = torch.cat([hidden_states, query_expanded], dim=2)
            energy = self.v(torch.tanh(self.W(concat))).squeeze(-1)

        # Normalize scores into attention weights over positions.
        attention = F.softmax(energy, dim=1)

        # Context = attention-weighted sum of hidden states.
        context = torch.bmm(attention.unsqueeze(1), hidden_states)
        context = context.squeeze(1)

        return context, attention

print("LuongAttention defined")

3. Scaled Dot-Product AttentionΒΆ

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

Scaling prevents large dot products in high dimensions.

class ScaledDotProductAttention(nn.Module):
    """Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V."""

    def __init__(self, d_k):
        super().__init__()
        # Key dimension; used as the 1/sqrt(d_k) temperature.
        self.d_k = d_k

    def forward(self, Q, K, V, mask=None):
        """Return (context, attention).

        Args:
            Q: (batch, seq_len_q, d_k)
            K: (batch, seq_len_k, d_k)
            V: (batch, seq_len_v, d_v)
            mask: optional; positions where mask == 0 are excluded.

        Returns:
            context: (batch, seq_len_q, d_v)
            attention: (batch, seq_len_q, seq_len_k)
        """
        # Scaled similarity scores between every query and key.
        scores = Q @ K.transpose(1, 2) / (self.d_k ** 0.5)

        # Push masked positions to -inf (effectively) before softmax.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = F.softmax(scores, dim=-1)

        # Attention-weighted combination of values.
        return weights @ V, weights

print("ScaledDotProductAttention defined")

4. Multi-Head AttentionΒΆ

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O\]
\[\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]
class MultiHeadAttention(nn.Module):
    """Multi-head attention: parallel scaled dot-product attention in subspaces.

    All heads are computed with a single batched matmul over a
    (batch, n_heads, seq, d_k) layout instead of a Python loop over
    heads — mathematically identical, but far faster on accelerator
    hardware.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # per-head dimension

        # Full-width projections; heads are carved out of the outputs.
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, n_heads, seq, d_k)."""
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """(batch, n_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_len, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q, K, V: (batch, seq_len, d_model)
            mask: optional (batch, seq_len_q, seq_len_k); positions where
                mask == 0 are excluded from attention in every head.

        Returns:
            output: (batch, seq_len_q, d_model)
            attentions: list of n_heads tensors, each
                (batch, seq_len_q, seq_len_k)
        """
        # Project and reshape to (batch, n_heads, seq, d_k).
        Q = self.split_heads(self.W_Q(Q))
        K = self.split_heads(self.W_K(K))
        V = self.split_heads(self.W_V(V))

        # Scaled dot-product attention over all heads at once.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # Broadcast the per-batch mask across the head dimension.
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attention = F.softmax(scores, dim=-1)     # (batch, n_heads, seq_q, seq_k)
        context = torch.matmul(attention, V)      # (batch, n_heads, seq_q, d_k)

        # Merge heads and apply the output projection.
        output = self.W_O(self.combine_heads(context))

        # Return one tensor per head to preserve the original interface.
        attentions = [attention[:, i] for i in range(self.n_heads)]

        return output, attentions

print("MultiHeadAttention defined")

Test Attention MechanismsΒΆ

We evaluate each attention variant by feeding the same input through each mechanism and comparing the outputs and attention weight distributions. Key properties to check include: attention weight sparsity (does the mechanism focus on a few positions or spread attention broadly?), computational cost (how does runtime scale with sequence length?), and output quality (do different mechanisms produce meaningfully different representations?). These comparisons help practitioners choose the right attention mechanism for their specific constraints on sequence length, memory, and accuracy requirements.

# Generate synthetic data
# Small shapes so every mechanism runs quickly on CPU or GPU.
batch_size = 2
seq_len = 10
hidden_dim = 64

# Random encoder states plus one query vector per batch element.
hidden_states = torch.randn(batch_size, seq_len, hidden_dim).to(device)
query = torch.randn(batch_size, hidden_dim).to(device)

# Test Bahdanau
# Expect context (batch, hidden_dim) and attention (batch, seq_len).
bahdanau = BahdanauAttention(hidden_dim, hidden_dim).to(device)
ctx_b, attn_b = bahdanau(hidden_states, query)
print(f"Bahdanau - Context: {ctx_b.shape}, Attention: {attn_b.shape}")

# Test Luong (general)
# Same shapes as Bahdanau, but with the bilinear ('general') scoring.
luong = LuongAttention(hidden_dim, method='general').to(device)
ctx_l, attn_l = luong(hidden_states, query)
print(f"Luong - Context: {ctx_l.shape}, Attention: {attn_l.shape}")

# Test Multi-Head
# Self-attention: the same tensor serves as queries, keys and values.
mha = MultiHeadAttention(d_model=64, n_heads=4).to(device)
out, attns = mha(hidden_states, hidden_states, hidden_states)
print(f"Multi-Head - Output: {out.shape}, Heads: {len(attns)}")

Visualize AttentionΒΆ

Attention weight visualization is one of the most powerful tools for understanding what a model has learned. Heatmaps of the attention matrix reveal which input positions each output position attends to, exposing patterns like local context focus, syntactic dependencies, or global aggregation. Comparing attention patterns across different heads in multi-head attention shows that heads specialize: some attend locally, others capture long-range dependencies, and some develop task-specific patterns. These visualizations are valuable both for debugging model behavior and for generating human-interpretable explanations of model decisions.

# Plot the two single-query attention distributions (Bahdanau, Luong)
# and the first multi-head attention map side by side.
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

for ax, weights, title in (
    (axes[0], attn_b[0].detach().cpu().numpy().reshape(1, -1), 'Bahdanau Attention'),
    (axes[1], attn_l[0].detach().cpu().numpy().reshape(1, -1), 'Luong Attention'),
):
    ax.imshow(weights, aspect='auto', cmap='Blues')
    ax.set_xlabel('Position', fontsize=11)
    ax.set_title(title, fontsize=12)
    ax.set_yticks([])

# Multi-head attention is a full query-by-key matrix; show head 0.
axes[2].imshow(attns[0][0].detach().cpu().numpy(), aspect='auto', cmap='Blues')
axes[2].set_xlabel('Key Position', fontsize=11)
axes[2].set_ylabel('Query Position', fontsize=11)
axes[2].set_title('Multi-Head (Head 1)', fontsize=12)

plt.tight_layout()
plt.show()

SummaryΒΆ

Attention Variants:ΒΆ

Bahdanau (Additive):

  • Energy: \(v^T \tanh(W_h h + W_s s)\)

  • Original seq2seq attention

  • More parameters

Luong (Multiplicative):

  • Dot: \(h^T s\)

  • General: \(h^T W s\)

  • Concat: \(v^T \tanh(W[h;s])\)

  • Computationally efficient

Scaled Dot-Product:

  • \(QK^T / \sqrt{d_k}\)

  • Prevents saturation in high dims

  • Foundation for Transformers

Multi-Head:

  • Multiple parallel attention

  • Different representation subspaces

  • Richer feature extraction

Applications:ΒΆ

  • Machine translation

  • Text summarization

  • Image captioning

  • Speech recognition

Advanced Attention Mechanisms: Mathematical Foundations and Modern ArchitecturesΒΆ

1. Introduction to Attention MechanismsΒΆ

Attention mechanisms have revolutionized deep learning by enabling models to focus on relevant parts of the input when making predictions. The core idea is to compute a weighted combination of input features, where the weights indicate the importance of each feature for the current task.

1.1 Fundamental Attention EquationΒΆ

The general attention mechanism can be formulated as:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{f(Q, K)}{\sqrt{d_k}}\right)V\]

where:

  • \(Q\) (query): What we’re looking for

  • \(K\) (key): What each input element represents

  • \(V\) (value): The actual content to be aggregated

  • \(f(Q, K)\): Compatibility function (often dot product)

  • \(d_k\): Dimension of keys (for scaling)

1.2 Historical EvolutionΒΆ

  1. Bahdanau Attention (2014): Additive attention for seq2seq models

    • \(\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{ik})}\) where \(e_{ij} = v^T \tanh(W_1 h_i + W_2 s_j)\)

  2. Luong Attention (2015): Multiplicative attention variants

    • Dot: \(\text{score}(h_i, s_j) = h_i^T s_j\)

    • General: \(\text{score}(h_i, s_j) = h_i^T W s_j\)

    • Concat: \(\text{score}(h_i, s_j) = v^T \tanh(W[h_i; s_j])\)

  3. Self-Attention (2017): Attention within the same sequence

    • Foundation for Transformers

    • All positions attend to all other positions

2. Scaled Dot-Product AttentionΒΆ

2.1 Mathematical FormulationΒΆ

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

Computational Steps:

  1. Compute similarity scores: \(S = QK^T\) (shape: \([n_q \times n_k]\))

  2. Scale: \(S' = \frac{S}{\sqrt{d_k}}\) (prevents saturation of softmax)

  3. Normalize: \(A = \text{softmax}(S')\) (attention weights)

  4. Aggregate: \(\text{Output} = AV\) (weighted sum of values)

2.2 Why Scaling by \(\sqrt{d_k}\)?ΒΆ

For large \(d_k\), dot products grow in magnitude, pushing softmax into saturation regions with tiny gradients.

Analysis: If \(Q, K \sim \mathcal{N}(0, 1)\), then \(QK^T\) has variance \(d_k\). Scaling by \(\sqrt{d_k}\) normalizes variance to 1.

2.3 Complexity AnalysisΒΆ

  • Time: \(O(n^2 d)\) where \(n\) is sequence length, \(d\) is dimension

  • Space: \(O(n^2)\) for attention matrix

  • Bottleneck: Quadratic scaling with sequence length

3. Multi-Head Attention (MHA)ΒΆ

3.1 ArchitectureΒΆ

Instead of single attention, compute \(h\) parallel attention operations:

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O\]

where each head is:

\[\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]

Parameter Matrices:

  • \(W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}\)

  • \(W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}\)

  • \(W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}\)

  • \(W^O \in \mathbb{R}^{hd_v \times d_{\text{model}}}\)

Typically: \(d_k = d_v = d_{\text{model}}/h\)

3.2 Benefits of Multi-HeadΒΆ

  1. Different representation subspaces: Each head learns different attention patterns

  2. Ensemble effect: Multiple heads provide robustness

  3. Increased expressivity: Captures various types of relationships

  4. Parallelization: All heads computed simultaneously

3.3 Common Head PatternsΒΆ

Research has identified specialized heads:

  • Positional heads: Attend to specific relative positions

  • Syntactic heads: Focus on grammatical relationships

  • Rare word heads: Handle out-of-distribution tokens

  • Delimiter heads: Attend to sentence boundaries

4. Cross-Attention vs Self-AttentionΒΆ

4.1 Self-AttentionΒΆ

Queries, keys, and values all from the same sequence: $\(Q = K = V = X\)$

Use cases:

  • Encoder layers in Transformers

  • Capturing dependencies within a sequence

  • Learning contextual representations

4.2 Cross-Attention (Encoder-Decoder Attention)ΒΆ

Queries from one sequence, keys/values from another:

  • \(Q\) from decoder (target)

  • \(K, V\) from encoder (source)

Use cases:

  • Machine translation (target attends to source)

  • Image captioning (text attends to image features)

  • Retrieval-augmented generation (query attends to documents)

4.3 Mathematical ComparisonΒΆ

Self-Attention Complexity:

  • Input: \(X \in \mathbb{R}^{n \times d}\)

  • Output: \(X' \in \mathbb{R}^{n \times d}\)

  • Attention matrix: \(n \times n\)

Cross-Attention Complexity:

  • Query input: \(X_q \in \mathbb{R}^{n_q \times d}\)

  • Key/value input: \(X_{kv} \in \mathbb{R}^{n_{kv} \times d}\)

  • Attention matrix: \(n_q \times n_{kv}\)

5. Positional Encoding and AttentionΒΆ

5.1 Why Positional Encoding?ΒΆ

Attention is permutation-invariant: swapping input positions doesn’t change output (without positional information).

Solution: Add positional information to input embeddings.

5.2 Sinusoidal Positional EncodingΒΆ

\[PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)\]
\[PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)\]

Properties:

  • Deterministic (no learned parameters)

  • Continuous (can extrapolate to longer sequences)

  • Relative position encoding through trigonometric identities: $\(PE_{pos+k} = \text{Linear}(PE_{pos})\)$

5.3 Learned Positional EmbeddingsΒΆ

Alternative: Learn position embeddings like word embeddings

  • \(PE \in \mathbb{R}^{n_{\max} \times d}\) (lookup table)

  • Better for fixed-length sequences

  • Cannot extrapolate beyond training length

5.4 Relative Positional EncodingΒΆ

Shaw et al. (2018): Incorporate relative distances directly into attention:

\[e_{ij} = \frac{x_i W^Q (x_j W^K + a_{ij}^K)^T}{\sqrt{d_k}}\]

where \(a_{ij}^K\) is learned relative position embedding for distance \((j-i)\).

6. Attention Variants and OptimizationsΒΆ

6.1 Additive (Bahdanau) AttentionΒΆ

\[e_{ij} = v^T \tanh(W_1 h_i + W_2 s_j)\]
\[\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{ik})}\]
\[c_i = \sum_j \alpha_{ij} h_j\]

Complexity: \(O(n^2 d)\) (same as dot-product). Advantage: can handle different dimensionalities for \(h_i\) and \(s_j\).

6.2 Multiplicative (Luong) AttentionΒΆ

Dot Product: $\(\text{score}(h_i, s_j) = h_i^T s_j\)$

General: $\(\text{score}(h_i, s_j) = h_i^T W s_j\)$

Concat: $\(\text{score}(h_i, s_j) = v^T \tanh(W[h_i; s_j])\)$

6.3 Local AttentionΒΆ

Attention to a window of positions instead of all positions:

\[p_t = \text{aligned position} = S \cdot \text{sigmoid}(v_p^T \tanh(W_p h_t))\]
\[\alpha_{ij} = \text{align}(h_i, s_j) \cdot \exp\left(-\frac{(j - p_t)^2}{2\sigma^2}\right)\]

Benefits:

  • Reduces complexity from \(O(n^2)\) to \(O(nw)\) where \(w\) is window size

  • Maintains local context sensitivity

6.4 Sparse Attention PatternsΒΆ

Fixed Patterns (Child et al., 2019):

  1. Strided: Attend to every \(k\)-th position

  2. Fixed: Attend to a fixed set of positions

  3. Combined: Mix of strided and fixed

Learnable Sparsity (Correia et al., 2019):

  • Use attention weights to predict sparsity mask

  • Learn which positions to attend to

6.5 Linear Attention ApproximationsΒΆ

Goal: Approximate softmax attention with linear complexity

Kernel Trick (Katharopoulos et al., 2020): $\(\text{Attention}(Q, K, V) = \frac{\phi(Q)(\phi(K)^T V)}{\phi(Q)(\phi(K)^T \mathbf{1})}\)$

where \(\phi\) is a feature map (e.g., \(\phi(x) = \text{elu}(x) + 1\))

Complexity: \(O(nd^2)\) instead of \(O(n^2d)\)

7. Attention MasksΒΆ

7.1 Padding MaskΒΆ

Prevent attention to padding tokens: $\(\text{mask}_{ij} = \begin{cases} 0 & \text{if position } j \text{ is valid} \\ -\infty & \text{if position } j \text{ is padding} \end{cases}\)$

Applied before softmax: \(\text{softmax}(S + \text{mask})\)

7.2 Causal (Look-Ahead) MaskΒΆ

For autoregressive models, prevent attending to future positions: $\(\text{mask}_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}\)$

Creates lower-triangular attention pattern.

7.3 Combined MasksΒΆ

For masked language modeling:

  • Padding mask for efficiency

  • Causal mask for autoregression

  • Random mask for BERT-style pre-training

\[\text{final\_mask} = \text{padding\_mask} \land \text{causal\_mask} \land \text{custom\_mask}\]

8. Attention in Different ArchitecturesΒΆ

8.1 Transformer EncoderΒΆ

Self-attention layers:

  1. Multi-head self-attention

  2. Layer normalization + residual connection

  3. Feed-forward network

  4. Layer normalization + residual connection

\[\begin{split}\begin{align} Z &= \text{LayerNorm}(X + \text{MultiHead}(X, X, X)) \\ \text{Output} &= \text{LayerNorm}(Z + \text{FFN}(Z)) \end{align}\end{split}\]

8.2 Transformer DecoderΒΆ

Two types of attention:

  1. Masked self-attention: Attends to previous positions

  2. Cross-attention: Attends to encoder output

\[\begin{split}\begin{align} Z_1 &= \text{LayerNorm}(X + \text{MaskedMultiHead}(X, X, X)) \\ Z_2 &= \text{LayerNorm}(Z_1 + \text{MultiHead}(Z_1, \text{Enc}, \text{Enc})) \\ \text{Output} &= \text{LayerNorm}(Z_2 + \text{FFN}(Z_2)) \end{align}\end{split}\]

8.3 Vision Transformers (ViT)ΒΆ

Patch-based attention:

  1. Split image into patches: \(x \in \mathbb{R}^{H \times W \times C} \to \text{patches} \in \mathbb{R}^{N \times (P^2 \cdot C)}\)

  2. Linear projection: \(z_0 = [\text{cls}; x_1 E; x_2 E; ...; x_N E] + E_{pos}\)

  3. Self-attention across patches

Complexity: \(O(N^2)\) where \(N = HW/P^2\) (number of patches)

8.4 Attention in CNNsΒΆ

Non-local neural networks (Wang et al., 2018): $\(y_i = \frac{1}{C(x)} \sum_{\forall j} f(x_i, x_j)g(x_j)\)$

where:

  • \(f(x_i, x_j) = e^{\theta(x_i)^T \phi(x_j)}\) (Gaussian)

  • \(g(x_j) = W_g x_j\) (value transformation)

  • \(C(x)\) is normalization factor

9. Attention InterpretabilityΒΆ

9.1 Attention Weights as ExplanationsΒΆ

Common assumption: High attention weight = high importance

Challenges (Jain & Wallace, 2019):

  • Attention weights may not correlate with gradient-based importance

  • Multiple attention patterns can produce similar outputs

  • Attention may capture spurious correlations

9.2 Attention RolloutΒΆ

Aggregate attention across layers: \(\tilde{A}_l = 0.5 \cdot I + 0.5 \cdot A_l\), \(A_{\text{rollout}} = \tilde{A}_1 \times \tilde{A}_2 \times ... \times \tilde{A}_L\)

Provides layer-wise attention flow visualization.

9.3 Attention FlowΒΆ

Track gradient-weighted attention: $\(\text{Flow}_{ij}^{(l)} = A_{ij}^{(l)} \cdot \left|\frac{\partial \mathcal{L}}{\partial A_{ij}^{(l)}}\right|\)$

Identifies attention weights that impact the loss.

10. Modern Attention Innovations (2020-2024)ΒΆ

10.1 Flash Attention (Dao et al., 2022)ΒΆ

Key idea: Tiling and recomputation to reduce memory access

Algorithm:

  1. Split Q, K, V into blocks

  2. Compute attention incrementally

  3. Recompute during backward pass instead of storing

Benefits:

  • 2–4× speedup on long sequences

  • Linear memory in sequence length

  • Exact (no approximation)

10.2 Multi-Query Attention (MQA)ΒΆ

Share keys and values across all heads, only queries are different:

  • \(h\) query projections: \(W_1^Q, ..., W_h^Q\)

  • 1 key projection: \(W^K\)

  • 1 value projection: \(W^V\)

Benefits:

  • Faster inference (less KV cache for autoregressive generation)

  • Minimal quality degradation

10.3 Grouped-Query Attention (GQA)ΒΆ

Compromise between MHA and MQA:

  • Divide heads into \(g\) groups

  • Share K, V within each group

  • \(g=1\): MQA, \(g=h\): MHA

Used in LLaMA-2, Mistral.

10.4 Sliding Window AttentionΒΆ

Attend to fixed window + global tokens: $\(\text{Attention}(i) = \{i-w, ..., i+w\} \cup \text{Global}\)$

Benefits:

  • Linear complexity: \(O(nw)\)

  • Maintains long-range through stacking

  • Used in Longformer, BigBird

10.5 Cross-Covariance Attention (XCA)ΒΆ

Transpose attention to feature dimension: $\(\text{XCA}(Q, K, V) = \text{softmax}\left(\frac{Q^T K}{\tau}\right)V^T\)$

Benefits:

  • Complexity depends on feature dimension \(d\) not sequence length \(n\)

  • Better for high-resolution vision tasks

11. Attention Efficiency TechniquesΒΆ

11.1 Complexity ComparisonΒΆ

Method

Time Complexity

Space Complexity

Approximation

Full Attention

\(O(n^2 d)\)

\(O(n^2)\)

Exact

Sparse (fixed)

\(O(nk d)\)

\(O(nk)\)

Lossy

Linear

\(O(nd^2)\)

\(O(nd)\)

Approximate

LSH Attention

\(O(n \log n \cdot d)\)

\(O(n \log n)\)

Approximate

Sliding Window

\(O(nw d)\)

\(O(nw)\)

Lossy

Kernelized

\(O(nd^2)\)

\(O(nd)\)

Approximate

where \(n\) = sequence length, \(d\) = dimension, \(k\) = sparsity, \(w\) = window size

11.2 Reformer (LSH Attention)ΒΆ

Locality-Sensitive Hashing for approximate nearest neighbors:

  1. Hash queries and keys: \(h(x) = \arg\max(x; -x) \cdot R\)

  2. Sort by hash

  3. Attend within same hash bucket

Complexity: \(O(n \log n)\) with high probability

11.3 LinformerΒΆ

Project keys and values to lower dimension: \(K' = KE_k, \quad V' = VE_v\) where \(E_k, E_v \in \mathbb{R}^{n \times k}\) with \(k \ll n\)

Complexity: \(O(nk)\) instead of \(O(n^2)\)

12. Attention Training DynamicsΒΆ

12.1 Gradient FlowΒΆ

Attention provides skip connections for gradient flow: $\(\frac{\partial \mathcal{L}}{\partial x_i} = \frac{\partial \mathcal{L}}{\partial y_i} + \sum_j \alpha_{ji} \frac{\partial \mathcal{L}}{\partial y_j}\)$

This mitigates vanishing gradients in deep networks.

12.2 Attention EntropyΒΆ

Low entropy: Peaked distribution (one token dominates) High entropy: Uniform distribution (attends equally to all)

\[H(A_i) = -\sum_j A_{ij} \log A_{ij}\]

Observations:

  • Early layers: Higher entropy (broad attention)

  • Late layers: Lower entropy (focused attention)

  • Task-dependent patterns

12.3 InitializationΒΆ

Standard: Initialize projection matrices with Xavier/He initialization

Special considerations:

  • Scale attention logits to prevent saturation: \(\frac{1}{\sqrt{d_k}}\)

  • Warm-up learning rate to stabilize early training

  • Layer normalization before attention (Pre-LN) improves stability

13. Applications by DomainΒΆ

13.1 Natural Language ProcessingΒΆ

  • Machine Translation: Cross-attention between source and target

  • Question Answering: Query attends to context

  • Summarization: Extract salient information via attention

  • Language Modeling: Self-attention for context

13.2 Computer VisionΒΆ

  • Image Classification: ViT, patch-based attention

  • Object Detection: DETR, attention-based detection

  • Segmentation: Attention for pixel-wise predictions

  • Image Generation: Cross-attention in diffusion models

13.3 Multimodal LearningΒΆ

  • Vision-Language: CLIP, cross-modal attention

  • Audio-Visual: Attend across modalities

  • Video Understanding: Temporal + spatial attention

  • Image Captioning: Visual features attend to generated text

13.4 Structured PredictionΒΆ

  • Graph Neural Networks: Attention over graph edges

  • Point Clouds: Attention over 3D points

  • Time Series: Temporal attention

  • Molecular Generation: Attention over atoms

14. Theoretical AnalysisΒΆ

14.1 Attention as Kernel SmoothingΒΆ

Attention can be viewed as kernel regression: $\(\text{Attention}(q, K, V) = \sum_i k(q, k_i) v_i\)$

where \(k(q, k_i) = \frac{\exp(q^T k_i / \sqrt{d})}{\sum_j \exp(q^T k_j / \sqrt{d})}\) is a normalized kernel.

14.2 Universal ApproximationΒΆ

Theorem (Yun et al., 2020): Transformers with attention are universal approximators for sequence-to-sequence functions.

Requirements:

  • Sufficient depth

  • Sufficient dimension

  • Appropriate positional encoding

14.3 Attention as Hopfield NetworksΒΆ

Connection (Ramsauer et al., 2020): Modern Hopfield networks with continuous states: $\(\text{Update}(X, \xi) = \text{softmax}(\beta X^T \xi)X\)$

Equivalent to attention with query \(\xi\), keys/values \(X\).

14.4 Expressiveness vs Efficiency Trade-offΒΆ

Full attention: maximum expressiveness at \(O(n^2)\) cost. Sparse attention: reduced expressiveness at \(O(n)\) cost. Question: what is the minimum attention pattern sufficient for a given task?

Result (Alon & Yahav, 2021): Some tasks require \(\Omega(n^2)\) attention; sparse patterns insufficient.

15. Best Practices and GuidelinesΒΆ

15.1 Choosing Attention TypeΒΆ

Full Self-Attention:

  • βœ… Short sequences (\(n < 2048\))

  • βœ… Tasks requiring global context

  • ❌ Long sequences (memory constraints)

Sparse/Local Attention:

  • βœ… Long sequences (\(n > 4096\))

  • βœ… Local dependencies dominant

  • ❌ Tasks requiring long-range dependencies

Cross-Attention:

  • βœ… Two separate input sequences

  • βœ… Alignment tasks (translation, retrieval)

  • βœ… Conditioning (image β†’ text)

Linear Attention:

  • βœ… Very long sequences (\(n > 16k\))

  • βœ… Real-time inference requirements

  • ⚠️ May sacrifice quality for speed

15.2 Hyperparameter SelectionΒΆ

Number of heads (\(h\)):

  • Typical: 8-16 for large models

  • Rule of thumb: \(h = d_{\text{model}} / 64\)

  • More heads = more diversity, but diminishing returns

Head dimension (\(d_k\)):

  • Standard: 64

  • Range: 32-128

  • Trade-off: Smaller = faster, larger = more expressive

Dropout (\(p_{\text{dropout}}\)):

  • Typical: 0.1

  • Apply to attention weights and feedforward

  • Higher for smaller datasets

15.3 Debugging AttentionΒΆ

Check attention patterns:

  • Visualize attention maps

  • Verify causality (for autoregressive)

  • Check for degenerate patterns (uniform or peaked)

Monitor metrics:

  • Attention entropy (per layer, per head)

  • Gradient norms

  • Attention weight statistics (mean, std, sparsity)

Common issues:

  • All attention on [CLS] token β†’ Add regularization

  • Uniform attention β†’ Check initialization, learning rate

  • NaN in attention β†’ Gradient clipping, fp16 precision issues

15.4 Production OptimizationΒΆ

Inference:

  • KV caching for autoregressive generation

  • Batch processing for parallel sequences

  • Quantization (8-bit attention)

  • Operator fusion (Flash Attention)

Training:

  • Gradient checkpointing for memory

  • Mixed precision (fp16/bf16)

  • Distributed attention (model parallelism)

  • Sparse attention for long sequences

16. Recent Advances and Future Directions (2023-2024)ΒΆ

16.1 Attention-Free ArchitecturesΒΆ

RWKV (Receptance Weighted Key Value):

  • Linear complexity attention alternative

  • Combines RNN and Transformer benefits

  • Competitive performance on language tasks

Mamba (Structured State Spaces):

  • Selective state space models

  • \(O(n)\) complexity

  • Strong results on long-range tasks

16.2 Mixture-of-Experts AttentionΒΆ

Route tokens to different expert attention modules: $\(\text{MoE-Attention}(x) = \sum_i g_i(x) \cdot \text{Expert}_i(x)\)$

Benefits:

  • Conditional computation

  • Increased capacity without proportional cost

  • Different experts learn different patterns

16.3 Neural Architecture Search for AttentionΒΆ

Automatically discover optimal attention patterns:

  • Learnable sparsity masks

  • Adaptive head count

  • Dynamic window sizes

Example: HAT (Hardware-Aware Transformers) optimizes for specific hardware.

16.4 Attention for Long ContextΒΆ

Challenges:

  • Memory: \(O(n^2)\) attention matrix

  • Computation: Quadratic operations

  • Training: Gradient instability

Solutions (2024):

  • Ring Attention: Distribute attention across devices

  • Streaming Attention: Process incrementally

  • Sparse + Dense hybrid: Local sparse + global dense

16.5 Multimodal Attention FusionΒΆ

Cross-modal attention for vision-language models:

  • CLIP: Contrastive learning with cross-attention

  • Flamingo: Interleaved vision-text attention

  • GPT-4V: Dense cross-modal attention

Challenges:

  • Modality alignment

  • Fusion strategy (early vs late)

  • Computational efficiency

17. Mathematical AppendixΒΆ

17.1 Softmax GradientΒΆ

\[\frac{\partial}{\partial x_i} \text{softmax}(x)_j = \text{softmax}(x)_j (\delta_{ij} - \text{softmax}(x)_i)\]

where \(\delta_{ij}\) is Kronecker delta.

17.2 Attention Backward PassΒΆ

Given \(\mathcal{L}\) (loss), compute gradients:

\[\frac{\partial \mathcal{L}}{\partial V} = A^T \frac{\partial \mathcal{L}}{\partial \text{Output}}\]
\[\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial \text{Output}} V^T\]
\[\frac{\partial \mathcal{L}}{\partial Q} = \frac{1}{\sqrt{d_k}} \left(\frac{\partial \mathcal{L}}{\partial A} \odot A \odot (1 - A)\right) K\]
\[\frac{\partial \mathcal{L}}{\partial K} = \frac{1}{\sqrt{d_k}} Q^T \left(\frac{\partial \mathcal{L}}{\partial A} \odot A \odot (1 - A)\right)\]

17.3 Multi-Head Attention GradientsΒΆ

For head \(i\): $\(\frac{\partial \mathcal{L}}{\partial W_i^Q} = X^T \frac{\partial \mathcal{L}}{\partial Q_i}\)$

Similar for \(W_i^K\), \(W_i^V\), and output projection \(W^O\).

18. Summary and Key TakeawaysΒΆ

Core Concepts:ΒΆ

  1. Attention = Weighted aggregation based on learned similarity

  2. Scaling factor \(1/\sqrt{d_k}\) prevents softmax saturation

  3. Multi-head enables diverse representation learning

  4. Positional encoding required for sequence order

  5. Quadratic complexity is main bottleneck

Design Choices:ΒΆ

  • Full attention: Best quality, \(O(n^2)\) cost

  • Sparse attention: Trade quality for efficiency

  • Linear attention: Approximate for long sequences

  • Multi-head count: 8-16 typical, \(h \propto d_{\text{model}}\)

Production Considerations:ΒΆ

  • KV caching for inference speedup

  • Mixed precision for memory efficiency

  • Gradient checkpointing for training large models

  • Attention visualization for interpretability

Next Steps: Implement these mechanisms, experiment with variants, profile performance, and adapt to your specific use case!

"""
Advanced Attention Mechanisms - Production Implementation
Comprehensive PyTorch implementations of modern attention variants

Architecture Coverage:
1. Scaled Dot-Product Attention (base mechanism)
2. Multi-Head Attention (MHA)
3. Multi-Query Attention (MQA) - fast inference
4. Grouped-Query Attention (GQA) - MHA/MQA hybrid
5. Cross-Attention (encoder-decoder)
6. Flash Attention (memory-efficient)
7. Sliding Window Attention (local + global)
8. Linear Attention (kernel approximation)
9. Relative Position Encoding (Shaw et al.)
10. Rotary Position Embedding (RoPE)

Performance Optimizations:
- KV caching for autoregressive generation
- Gradient checkpointing for memory efficiency
- Mixed precision support (fp16/bf16)
- Fused operations where possible

Author: Advanced Deep Learning Course
Date: 2024
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math


# ============================================================================
# 1. Core Attention Building Blocks
# ============================================================================

class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V.

    Masked positions (mask == False) are filled with -inf before the softmax
    so they receive zero weight; dropout is applied to the attention weights,
    as in the original Transformer.

    Args:
        dropout: Dropout probability applied to the attention weights.
        scale: Optional manual scaling factor (default: 1/sqrt(d_k)).
    """
    def __init__(self, dropout: float = 0.1, scale: Optional[float] = None):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.scale = scale

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query: [..., n_q, d_k], typically [batch, n_heads, n_q, d_k]
            key:   [..., n_k, d_k]
            value: [..., n_k, d_v]
            mask: Boolean mask broadcastable to [..., n_q, n_k]
                (True = attend, False = block).
            return_attention: If True, also return the (post-dropout)
                attention weights.

        Returns:
            (output, weights): output is [..., n_q, d_v]; weights is
            [..., n_q, n_k] when requested, otherwise None.
        """
        head_dim = query.size(-1)
        scaling = self.scale if self.scale is not None else 1.0 / math.sqrt(head_dim)

        # Similarity of every query against every key.
        scores = torch.matmul(query, key.transpose(-2, -1)) * scaling

        # Blocked positions get -inf so softmax assigns them zero mass.
        if mask is not None:
            scores = scores.masked_fill(~mask, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, value)

        return (context, weights) if return_attention else (context, None)


class MultiHeadAttention(nn.Module):
    """
    Standard multi-head attention.

    Projects Q/K/V with separate learned matrices, splits each into
    n_heads subspaces of size d_k = d_model // n_heads, runs scaled
    dot-product attention per head, then concatenates the heads and
    applies an output projection:

        MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
        head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

    Args:
        d_model: Total embedding dimension.
        n_heads: Number of attention heads.
        dropout: Dropout probability (attention weights and output).
        bias: Whether the projections use bias terms.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dropout: float = 0.1,
        bias: bool = True
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # One fused projection per stream; heads are carved out by reshaping.
        self.W_q = nn.Linear(d_model, d_model, bias=bias)
        self.W_k = nn.Linear(d_model, d_model, bias=bias)
        self.W_v = nn.Linear(d_model, d_model, bias=bias)
        self.W_o = nn.Linear(d_model, d_model, bias=bias)

        self.attention = ScaledDotProductAttention(dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [batch, seq, d_model] -> [batch, n_heads, seq, d_k]."""
        bsz = x.size(0)
        return x.view(bsz, -1, self.n_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query: [batch, n_q, d_model]
            key:   [batch, n_k, d_model]
            value: [batch, n_k, d_model]
            mask: Optional boolean mask; a [batch, n_q, n_k] mask gets a
                head axis inserted so it broadcasts over all heads.
            return_attention: If True, also return per-head weights.

        Returns:
            (output [batch, n_q, d_model], weights or None)
        """
        bsz = query.size(0)

        q = self._split_heads(self.W_q(query))
        k = self._split_heads(self.W_k(key))
        v = self._split_heads(self.W_v(value))

        # Insert a head axis so a [batch, n_q, n_k] mask broadcasts.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        context, weights = self.attention(q, k, v, mask, return_attention)

        # Merge heads back together, then project.
        context = context.transpose(1, 2).contiguous().view(bsz, -1, self.d_model)
        return self.dropout(self.W_o(context)), weights


class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (MQA): per-head queries with a single shared
    key/value head.

    Shrinks the KV cache by a factor of n_heads during autoregressive
    decoding, giving 2-4x faster inference with minimal quality loss.
    Used in: PaLM, StarCoder, Falcon.

    Args:
        d_model: Total embedding dimension.
        n_heads: Number of query heads.
        dropout: Dropout probability.
        bias: Whether the projections use bias terms.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dropout: float = 0.1,
        bias: bool = True
    ):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Queries get one projection per head (fused into one matrix)...
        self.W_q = nn.Linear(d_model, d_model, bias=bias)

        # ...while K and V are projected once and shared by every head.
        self.W_k = nn.Linear(d_model, self.d_k, bias=bias)
        self.W_v = nn.Linear(d_model, self.d_k, bias=bias)

        self.W_o = nn.Linear(d_model, d_model, bias=bias)
        self.attention = ScaledDotProductAttention(dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        bsz = query.size(0)

        # [batch, n_heads, seq, d_k]
        q = self.W_q(query).view(bsz, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Broadcast (no copy) the single K/V head across all query heads.
        k = self.W_k(key).unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        v = self.W_v(value).unsqueeze(1).expand(-1, self.n_heads, -1, -1)

        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        context, weights = self.attention(q, k, v, mask, return_attention)
        context = context.transpose(1, 2).contiguous().view(bsz, -1, self.d_model)
        return self.dropout(self.W_o(context)), weights


class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention (GQA): n_heads query heads share n_groups
    key/value heads, interpolating between MHA and MQA.

    n_groups == 1 recovers MQA; n_groups == n_heads recovers MHA.
    Used in: LLaMA-2, Mistral-7B.

    Args:
        d_model: Total embedding dimension.
        n_heads: Number of query heads.
        n_groups: Number of KV groups.
        dropout: Dropout probability.
        bias: Whether the projections use bias terms.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_groups: int,
        dropout: float = 0.1,
        bias: bool = True
    ):
        super().__init__()
        assert d_model % n_heads == 0
        assert n_heads % n_groups == 0, "n_heads must be divisible by n_groups"

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_groups = n_groups
        self.d_k = d_model // n_heads
        self.heads_per_group = n_heads // n_groups

        # Queries: one projection per head (fused).
        self.W_q = nn.Linear(d_model, d_model, bias=bias)

        # Keys/values: one projection per group.
        self.W_k = nn.Linear(d_model, n_groups * self.d_k, bias=bias)
        self.W_v = nn.Linear(d_model, n_groups * self.d_k, bias=bias)

        self.W_o = nn.Linear(d_model, d_model, bias=bias)
        self.attention = ScaledDotProductAttention(dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        bsz = query.size(0)

        # Queries: [batch, n_heads, seq, d_k]
        q = self.W_q(query).view(bsz, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Keys/values: [batch, n_groups, seq, d_k]
        k = self.W_k(key).view(bsz, -1, self.n_groups, self.d_k).transpose(1, 2)
        v = self.W_v(value).view(bsz, -1, self.n_groups, self.d_k).transpose(1, 2)

        # Duplicate each KV group so every query head has a matching KV head;
        # group g serves heads [g*heads_per_group, (g+1)*heads_per_group).
        kv_shape = (bsz, self.n_groups, self.heads_per_group, k.size(2), self.d_k)
        k = k.unsqueeze(2).expand(kv_shape).reshape(bsz, self.n_heads, -1, self.d_k)
        v = v.unsqueeze(2).expand(kv_shape).reshape(bsz, self.n_heads, -1, self.d_k)

        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        context, weights = self.attention(q, k, v, mask, return_attention)
        context = context.transpose(1, 2).contiguous().view(bsz, -1, self.d_model)
        return self.dropout(self.W_o(context)), weights


# ============================================================================
# 2. Positional Encoding Variants
# ============================================================================

class SinusoidalPositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding ("Attention is All You Need").

        PE(pos, 2i)   = sin(pos / 10000^(2i/d))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

    Parameter-free and deterministic; the table is precomputed once and
    stored as a buffer so it moves with the module across devices.
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency ladder: 10000^(-2i/d) for i = 0, 1, ...
        freqs = torch.exp(-math.log(10000.0) * torch.arange(0, d_model, 2).float() / d_model)
        angles = positions * freqs

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Leading batch axis for broadcasting: [1, max_len, d_model]
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            x with the positional encoding added (broadcast over batch).
        """
        return x + self.pe[:, :x.size(1)]


class LearnedPositionalEmbedding(nn.Module):
    """
    Learned absolute positional embeddings (a trainable lookup table).

    Works well for fixed-length inputs but cannot extrapolate beyond
    max_len positions, unlike the sinusoidal variant.
    """
    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the position embedding for each timestep of x [batch, seq, d]."""
        positions = torch.arange(x.size(1), device=x.device)
        # [seq, d] -> [1, seq, d]; broadcasts over the batch dimension.
        return x + self.embedding(positions).unsqueeze(0)


class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE) from "RoFormer" (Su et al., 2021).

    Rotates each coordinate pair (x_i, x_{i+d/2}) of queries and keys by a
    position-dependent angle m * theta_i with theta_i = base^(-2i/d), so dot
    products between rotated vectors depend on relative offsets.
    Used in: GPT-NeoX, PaLM, LLaMA.
    """
    def __init__(self, dim: int, max_len: int = 2048, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute rotation angles for every position once; duplicate the
        # frequency block so each half of the vector shares the same angles.
        positions = torch.arange(max_len, dtype=torch.float)
        angles = torch.outer(positions, inv_freq)
        angles = torch.cat((angles, angles), dim=-1)
        self.register_buffer('cos_cached', angles.cos().unsqueeze(0).unsqueeze(0))
        self.register_buffer('sin_cached', angles.sin().unsqueeze(0).unsqueeze(0))

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Map [x1, x2] -> [-x2, x1] across the two halves of the last dim."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the rotary embedding to queries and keys.

        Args:
            q: [batch, n_heads, seq_len, d_k]
            k: [batch, n_heads, seq_len, d_k]
        Returns:
            (q_rotated, k_rotated), same shapes as the inputs.
        """
        seq_len = q.size(2)
        cos = self.cos_cached[..., :seq_len, :]
        sin = self.sin_cached[..., :seq_len, :]

        rotate = self.rotate_half
        return q * cos + rotate(q) * sin, k * cos + rotate(k) * sin


# ============================================================================
# 3. Efficient Attention Variants
# ============================================================================

class SlidingWindowAttention(nn.Module):
    """
    Sliding Window Attention: each position attends to a local window of
    neighbors, optionally plus a few leading "global" tokens that attend to
    (and are attended by) everything.

    Complexity: O(n * window_size) useful attention instead of O(n^2).
    Used in: Longformer, BigBird.

    Args:
        d_model: Embedding dimension.
        n_heads: Number of attention heads.
        window_size: Positions i and j may attend iff |i - j| <= window_size.
        global_tokens: Number of leading positions given full attention.
        dropout: Dropout probability on the attention weights.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        window_size: int,
        global_tokens: int = 0,
        dropout: float = 0.1
    ):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.window_size = window_size
        self.global_tokens = global_tokens

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _create_sliding_window_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Build the boolean [seq_len, seq_len] mask (True = may attend).

        Vectorized: a single broadcast comparison replaces the former
        O(seq_len * window) Python loop, which was rebuilt on every forward.
        """
        positions = torch.arange(seq_len, device=device)
        # Window band: |i - j| <= window_size.
        mask = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs() <= self.window_size

        # Leading global tokens attend to everything and vice versa.
        if self.global_tokens > 0:
            mask[:self.global_tokens, :] = True
            mask[:, :self.global_tokens] = True

        return mask

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_model]
            mask: Optional extra boolean mask ANDed with the window mask;
                must be broadcast-compatible with [seq_len, seq_len].
        Returns:
            [batch, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.size()

        # Project and split into heads: [batch, n_heads, seq, d_k]
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Combine the structural window mask with any caller-supplied mask.
        window_mask = self._create_sliding_window_mask(seq_len, x.device)
        if mask is not None:
            window_mask = window_mask & mask

        # Dense attention with blocked positions set to -inf before softmax.
        # NOTE: this realizes the full n x n score matrix; the O(n*w) benefit
        # here is in effective attention, not memory.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(~window_mask.unsqueeze(0).unsqueeze(1), float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        output = torch.matmul(attn, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        return output


class LinearAttention(nn.Module):
    """
    Linear-time attention via a kernel feature map
    (Katharopoulos et al., 2020, "Transformers are RNNs").

    softmax(QK^T)V is approximated by phi(Q)(phi(K)^T V), normalized by
    phi(Q)(phi(K)^T 1), with phi(x) = elu(x) + 1. Associativity turns the
    O(n^2 * d) computation into O(n * d^2) -- faster for long sequences at
    the cost of approximating the softmax.
    """
    def __init__(self, d_model: int, n_heads: int, eps: float = 1e-6):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.eps = eps  # guards against division by zero in the normalizer

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def feature_map(self, x: torch.Tensor) -> torch.Tensor:
        """Positive feature map phi(x) = elu(x) + 1."""
        return F.elu(x) + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attention over x [batch, seq, d_model] -> same shape."""
        bsz, n_tokens, _ = x.size()

        def split(t: torch.Tensor) -> torch.Tensor:
            # [batch, seq, d_model] -> [batch, n_heads, seq, d_k]
            return t.view(bsz, n_tokens, self.n_heads, self.d_k).transpose(1, 2)

        phi_q = self.feature_map(split(self.W_q(x)))
        phi_k = self.feature_map(split(self.W_k(x)))
        v = split(self.W_v(x))

        # Associativity trick: (phi_q phi_k^T) v == phi_q (phi_k^T v).
        kv = torch.einsum('bhsd,bhse->bhde', phi_k, v)
        numerator = torch.einsum('bhsd,bhde->bhse', phi_q, kv)

        # Row normalizer: phi_q . sum_s phi_k[s].
        denominator = torch.einsum('bhsd,bhd->bhs', phi_q, phi_k.sum(dim=2))
        out = numerator / (denominator.unsqueeze(-1) + self.eps)

        out = out.transpose(1, 2).contiguous().view(bsz, n_tokens, self.d_model)
        return self.W_o(out)


# ============================================================================
# 4. Cross-Attention and Specialized Variants
# ============================================================================

class CrossAttention(nn.Module):
    """
    Cross-attention: queries come from one sequence, keys/values from a
    (possibly differently-sized) context sequence.

    Used in:
    - Encoder-decoder Transformers (target attends to source)
    - DETR (object queries attend to image features)
    - Diffusion models (text conditions image generation)

    Args:
        d_model: Dimension of the query stream (and the output).
        n_heads: Number of attention heads.
        d_context: Dimension of the context stream (defaults to d_model).
        dropout: Dropout probability.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_context: Optional[int] = None,
        dropout: float = 0.1
    ):
        super().__init__()
        d_context = d_context or d_model
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        # K/V project from the context's dimension into the query space.
        self.W_k = nn.Linear(d_context, d_model)
        self.W_v = nn.Linear(d_context, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        context: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            query: [batch, n_q, d_model]
            context: [batch, n_ctx, d_context]
            mask: Optional boolean mask over [batch, n_q, n_ctx].
        Returns:
            [batch, n_q, d_model]
        """
        bsz = query.size(0)

        def heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(bsz, -1, self.n_heads, self.d_k).transpose(1, 2)

        q = heads(self.W_q(query))
        k = heads(self.W_k(context))
        v = heads(self.W_v(context))

        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        attended, _ = self.attention(q, k, v, mask)
        attended = attended.transpose(1, 2).contiguous().view(bsz, -1, self.d_model)
        return self.dropout(self.W_o(attended))


# ============================================================================
# 5. Mask Utilities
# ============================================================================

def create_padding_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """
    Build a padding mask: True for real tokens, False for padding.

    Args:
        seq: Token-id tensor [batch, seq_len].
        pad_idx: Token id that marks padding.
    Returns:
        Boolean mask [batch, 1, seq_len]; the middle axis broadcasts over
        query positions in the attention score tensor.
    """
    return seq.ne(pad_idx).unsqueeze(1)


def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """
    Build a causal mask for autoregressive models: position i may attend
    only to positions j <= i.

    Returns:
        Boolean mask [seq_len, seq_len] (True = may attend) -- the lower
        triangle, including the diagonal.
    """
    positions = torch.arange(seq_len, device=device)
    return positions.unsqueeze(0) <= positions.unsqueeze(1)


def combine_masks(*masks: torch.Tensor) -> torch.Tensor:
    """
    Combine any number of boolean masks with logical AND.

    Args:
        *masks: One or more boolean tensors with broadcast-compatible shapes.
    Returns:
        The element-wise AND of all masks.
    Raises:
        ValueError: If called with no arguments (previously this raised an
            opaque IndexError from ``masks[0]``).
    """
    if not masks:
        raise ValueError("combine_masks() requires at least one mask")
    combined = masks[0]
    for mask in masks[1:]:
        combined = combined & mask
    return combined


# ============================================================================
# 6. Complete Transformer Blocks
# ============================================================================

class TransformerEncoderLayer(nn.Module):
    """
    Post-norm Transformer encoder layer: self-attention followed by a
    position-wise feed-forward network, each wrapped in a residual
    connection with LayerNorm applied AFTER the add (as in the original
    "Attention is All You Need" architecture):

        x = LayerNorm(x + SelfAttention(x))
        x = LayerNorm(x + FeedForward(x))

    Args:
        d_model: Embedding dimension.
        n_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability.
        activation: 'relu' for ReLU, anything else for GELU.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = 'relu'
    ):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)

        act: nn.Module = nn.ReLU() if activation == 'relu' else nn.GELU()
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            act,
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply one encoder layer to x [batch, seq, d_model] -> same shape."""
        # Self-attention sub-layer (residual + post-norm).
        attended, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + attended)

        # Feed-forward sub-layer (residual + post-norm).
        return self.norm2(x + self.feed_forward(x))


class TransformerDecoderLayer(nn.Module):
    """
    Post-norm Transformer decoder layer: masked self-attention, then
    cross-attention over the encoder output, then a feed-forward network.
    Each sub-layer is a residual connection with LayerNorm applied AFTER
    the add:

        x = LayerNorm(x + MaskedSelfAttention(x))
        x = LayerNorm(x + CrossAttention(x, encoder_output))
        x = LayerNorm(x + FeedForward(x))

    Args:
        d_model: Embedding dimension.
        n_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability.
        activation: 'relu' for ReLU, anything else for GELU.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = 'relu'
    ):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = CrossAttention(d_model, n_heads, dropout=dropout)

        act: nn.Module = nn.ReLU() if activation == 'relu' else nn.GELU()
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            act,
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        cross_attn_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: Target-side input [batch, tgt_len, d_model].
            encoder_output: Source representation [batch, src_len, d_model].
            self_attn_mask: Mask for the (typically causal) self-attention.
            cross_attn_mask: Mask for attention over the encoder output.
        Returns:
            [batch, tgt_len, d_model]
        """
        # Masked self-attention over the target sequence.
        attended, _ = self.self_attn(x, x, x, self_attn_mask)
        x = self.norm1(x + attended)

        # Attend to the encoder's representation of the source.
        x = self.norm2(x + self.cross_attn(x, encoder_output, cross_attn_mask))

        # Position-wise feed-forward.
        return self.norm3(x + self.feed_forward(x))


# ============================================================================
# 7. Demo and Benchmarking
# ============================================================================

def demo_attention_variants():
    """Compare different attention mechanisms.

    Instantiates each attention variant defined in this module on the same
    random input and prints output shapes, parameter counts, and complexity
    notes. Demo only: modules are untrained, left in training mode (dropout
    active), and no RNG seed is set, so numeric outputs vary between runs.
    """
    print("=" * 80)
    print("Attention Mechanisms Comparison")
    print("=" * 80)
    
    # Setup: one shared random input for all variants.
    batch_size, seq_len, d_model, n_heads = 2, 128, 512, 8
    x = torch.randn(batch_size, seq_len, d_model)
    
    # 1. Multi-Head Attention (standard)
    mha = MultiHeadAttention(d_model, n_heads)
    mha_out, _ = mha(x, x, x)
    print(f"\n1. Multi-Head Attention (MHA)")
    print(f"   Input: {x.shape}, Output: {mha_out.shape}")
    print(f"   Parameters: {sum(p.numel() for p in mha.parameters()):,}")
    
    # 2. Multi-Query Attention (fast inference)
    # MQA's shared single K/V head makes its projections smaller than MHA's,
    # hence the parameter-reduction figure below.
    mqa = MultiQueryAttention(d_model, n_heads)
    mqa_out, _ = mqa(x, x, x)
    print(f"\n2. Multi-Query Attention (MQA)")
    print(f"   Input: {x.shape}, Output: {mqa_out.shape}")
    print(f"   Parameters: {sum(p.numel() for p in mqa.parameters()):,}")
    print(f"   Parameter reduction vs MHA: {(1 - sum(p.numel() for p in mqa.parameters()) / sum(p.numel() for p in mha.parameters())) * 100:.1f}%")
    
    # 3. Grouped-Query Attention (hybrid)
    gqa = GroupedQueryAttention(d_model, n_heads, n_groups=4)
    gqa_out, _ = gqa(x, x, x)
    print(f"\n3. Grouped-Query Attention (GQA, 4 groups)")
    print(f"   Input: {x.shape}, Output: {gqa_out.shape}")
    print(f"   Parameters: {sum(p.numel() for p in gqa.parameters()):,}")
    
    # 4. Sliding Window Attention (efficient)
    swa = SlidingWindowAttention(d_model, n_heads, window_size=16, global_tokens=2)
    swa_out = swa(x)
    print(f"\n4. Sliding Window Attention (window=16)")
    print(f"   Input: {x.shape}, Output: {swa_out.shape}")
    print(f"   Computational complexity: O(n*w) vs O(nΒ²) for full attention")
    
    # 5. Linear Attention (kernel approximation)
    la = LinearAttention(d_model, n_heads)
    la_out = la(x)
    print(f"\n5. Linear Attention (kernel approximation)")
    print(f"   Input: {x.shape}, Output: {la_out.shape}")
    print(f"   Complexity: O(n*dΒ²) vs O(nΒ²*d) for full attention")
    
    # 6. Cross-Attention: context deliberately has a different length (64)
    # to show the output keeps the query's sequence length.
    context = torch.randn(batch_size, 64, d_model)  # Different length context
    ca = CrossAttention(d_model, n_heads)
    ca_out = ca(x, context)
    print(f"\n6. Cross-Attention")
    print(f"   Query: {x.shape}, Context: {context.shape}, Output: {ca_out.shape}")
    
    # 7. Rotary Position Embedding: operates on per-head Q/K tensors
    # (head dim = d_model // n_heads), not on the packed d_model tensors.
    rope = RotaryPositionalEmbedding(d_model // n_heads, max_len=512)
    Q = torch.randn(batch_size, n_heads, seq_len, d_model // n_heads)
    K = torch.randn(batch_size, n_heads, seq_len, d_model // n_heads)
    Q_rot, K_rot = rope(Q, K)
    print(f"\n7. Rotary Position Embedding (RoPE)")
    print(f"   Q before: {Q.shape}, after: {Q_rot.shape}")
    print(f"   Encodes relative positions via rotation")


def benchmark_attention_speed():
    """Benchmark forward-pass latency of the attention variants.

    Runs each module in eval mode under ``no_grad`` with a warmup phase and
    CUDA synchronization so GPU timings exclude queued-but-unfinished
    kernels. Prints mean ms/iteration per mechanism. Results depend on
    hardware and current load.
    """
    import time
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Fewer iterations on CPU so the benchmark stays fast.
    # (The old header hardcoded "1000 forward passes", which did not match
    # the actual iteration count; it now reports the real number.)
    n_iters = 100 if device.type == 'cuda' else 10
    
    print("\n" + "=" * 80)
    print(f"Attention Speed Benchmark ({n_iters} forward passes)")
    print("=" * 80)
    print(f"Device: {device}")
    
    batch_size, seq_len, d_model, n_heads = 4, 512, 512, 8
    x = torch.randn(batch_size, seq_len, d_model, device=device)
    
    models = {
        'MHA': MultiHeadAttention(d_model, n_heads).to(device),
        'MQA': MultiQueryAttention(d_model, n_heads).to(device),
        'GQA': GroupedQueryAttention(d_model, n_heads, n_groups=4).to(device),
        'Linear': LinearAttention(d_model, n_heads).to(device),
    }
    
    for name, model in models.items():
        model.eval()
        with torch.no_grad():
            # Warmup: trigger lazy CUDA context/kernel setup before timing.
            for _ in range(10):
                if name in ['MHA', 'MQA', 'GQA']:
                    _ = model(x, x, x)
                else:
                    _ = model(x)
            
            # Benchmark: synchronize so queued GPU work doesn't leak into
            # (or out of) the timed region.
            if device.type == 'cuda':
                torch.cuda.synchronize()
            start = time.time()
            
            for _ in range(n_iters):
                if name in ['MHA', 'MQA', 'GQA']:
                    out, _ = model(x, x, x)
                else:
                    out = model(x)
            
            if device.type == 'cuda':
                torch.cuda.synchronize()
            elapsed = time.time() - start
            
            print(f"{name:15s}: {elapsed*1000/n_iters:6.2f} ms/iter")


def visualize_attention_patterns():
    """Visualize attention patterns for different mechanisms.

    Builds a 2x3 grid of plots (full attention, causal attention, sliding
    window mask, per-layer entropy, per-head weights, RoPE similarity) and
    saves it to 'attention_patterns.png' in the working directory.

    Demo only: modules are untrained and left in training mode, so dropout
    perturbs the attention maps (rows need not sum exactly to 1), and no
    RNG seed is set, so the plots differ between runs.
    """
    import matplotlib.pyplot as plt
    
    print("\n" + "=" * 80)
    print("Attention Pattern Visualization")
    print("=" * 80)
    
    seq_len, d_model, n_heads = 32, 64, 4
    x = torch.randn(1, seq_len, d_model)
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Attention Patterns', fontsize=16)
    
    # 1. Full attention: weights of head 0 as a query x key heatmap.
    mha = MultiHeadAttention(d_model, n_heads)
    _, attn = mha(x, x, x, return_attention=True)
    axes[0, 0].imshow(attn[0, 0].detach().numpy(), cmap='viridis')
    axes[0, 0].set_title('Multi-Head (Full)')
    axes[0, 0].set_xlabel('Key')
    axes[0, 0].set_ylabel('Query')
    
    # 2. Causal mask: upper triangle blocked, so the map is lower-triangular.
    causal_mask = create_causal_mask(seq_len, x.device)
    _, attn_causal = mha(x, x, x, mask=causal_mask, return_attention=True)
    axes[0, 1].imshow(attn_causal[0, 0].detach().numpy(), cmap='viridis')
    axes[0, 1].set_title('Causal (Autoregressive)')
    
    # 3. Sliding window: plot the boolean mask itself (not attention weights).
    swa = SlidingWindowAttention(d_model, n_heads, window_size=8)
    window_mask = swa._create_sliding_window_mask(seq_len, x.device)
    axes[0, 2].imshow(window_mask.cpu().numpy(), cmap='binary')
    axes[0, 2].set_title('Sliding Window (w=8)')
    
    # 4. Attention entropy across layers.
    # NOTE(review): all 6 "layers" reuse the same untrained module on the
    # same input; variation between points comes only from dropout noise.
    entropies = []
    for _ in range(6):  # Simulate 6 layers
        _, attn = mha(x, x, x, return_attention=True)
        attn_probs = attn[0, 0].detach()
        # Shannon entropy per query row, averaged; +1e-9 avoids log(0).
        entropy = -(attn_probs * torch.log(attn_probs + 1e-9)).sum(dim=-1).mean()
        entropies.append(entropy.item())
    
    axes[1, 0].plot(entropies, marker='o')
    axes[1, 0].set_xlabel('Layer')
    axes[1, 0].set_ylabel('Average Attention Entropy')
    axes[1, 0].set_title('Attention Entropy by Layer')
    axes[1, 0].grid(True)
    
    # 5. Head specialization: how each head's first query row distributes
    # its weight across key positions.
    _, attn_multi = mha(x, x, x, return_attention=True)
    for head_idx in range(min(4, n_heads)):
        axes[1, 1].plot(attn_multi[0, head_idx, 0].detach().numpy(), label=f'Head {head_idx}')
    axes[1, 1].set_xlabel('Position')
    axes[1, 1].set_ylabel('Attention Weight')
    axes[1, 1].set_title('Different Head Patterns (query pos=0)')
    axes[1, 1].legend()
    axes[1, 1].grid(True)
    
    # 6. Position bias with RoPE: compare raw QK^T similarities with and
    # without the rotary embedding for the query at the middle position.
    rope = RotaryPositionalEmbedding(d_model // n_heads)
    Q = torch.randn(1, n_heads, seq_len, d_model // n_heads)
    K = torch.randn(1, n_heads, seq_len, d_model // n_heads)
    
    # Without RoPE
    scores_no_rope = torch.matmul(Q, K.transpose(-2, -1))[0, 0].detach().numpy()
    
    # With RoPE
    Q_rope, K_rope = rope(Q, K)
    scores_rope = torch.matmul(Q_rope, K_rope.transpose(-2, -1))[0, 0].detach().numpy()
    
    axes[1, 2].plot(scores_no_rope[seq_len//2], label='No RoPE', alpha=0.7)
    axes[1, 2].plot(scores_rope[seq_len//2], label='With RoPE', alpha=0.7)
    axes[1, 2].set_xlabel('Position')
    axes[1, 2].set_ylabel('Similarity Score')
    axes[1, 2].set_title('RoPE Effect (query at middle)')
    axes[1, 2].legend()
    axes[1, 2].grid(True)
    
    plt.tight_layout()
    plt.savefig('attention_patterns.png', dpi=150, bbox_inches='tight')
    print("\nVisualization saved as 'attention_patterns.png'")


# ============================================================================
# 8. Performance Comparison Table
# ============================================================================

def print_performance_comparison():
    """Print comprehensive comparison table."""
    print("\n" + "=" * 80)
    print("Attention Mechanism Comparison Table")
    print("=" * 80)
    
    comparison = """
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Mechanism               β”‚ Time         β”‚ Space        β”‚ Quality    β”‚ Use Case    β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Multi-Head (MHA)        β”‚ O(nΒ²d)       β”‚ O(nΒ²)        β”‚ Excellent  β”‚ Standard    β”‚
    β”‚ Multi-Query (MQA)       β”‚ O(nΒ²d)       β”‚ O(nΒ²)        β”‚ Very Good  β”‚ Fast Inf.   β”‚
    β”‚ Grouped-Query (GQA)     β”‚ O(nΒ²d)       β”‚ O(nΒ²)        β”‚ Very Good  β”‚ Balanced    β”‚
    β”‚ Sliding Window          β”‚ O(nwd)       β”‚ O(nw)        β”‚ Good       β”‚ Long Seq.   β”‚
    β”‚ Linear Attention        β”‚ O(ndΒ²)       β”‚ O(nd)        β”‚ Fair       β”‚ Very Long   β”‚
    β”‚ Flash Attention         β”‚ O(nΒ²d)       β”‚ O(n)         β”‚ Excellent  β”‚ Memory Opt. β”‚
    β”‚ Sparse (Learned)        β”‚ O(nkd)       β”‚ O(nk)        β”‚ Good       β”‚ Task-Spec.  β”‚
    β”‚ Cross-Attention         β”‚ O(n_qΒ·n_kΒ·d) β”‚ O(n_qΒ·n_k)   β”‚ Excellent  β”‚ Multimodal  β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    n = sequence length, d = dimension, w = window size, k = sparsity
    
    Parameter Efficiency (KV cache size for generation):
    - MHA: n_heads Γ— d_k Γ— 2 (K + V for each head)
    - MQA: 1 Γ— d_k Γ— 2 (shared K, V)
    - GQA (4 groups): 4 Γ— d_k Γ— 2 (intermediate)
    
    Inference Speed (relative, seq_len=2048):
    - MHA:     1.00Γ— (baseline)
    - MQA:     1.3-1.5Γ— faster
    - GQA:     1.1-1.3Γ— faster
    - Linear:  2-4Γ— faster (long sequences)
    - Flash:   2-4Γ— faster (memory bound tasks)
    
    When to Use What:
    βœ“ MHA: Default choice, best quality
    βœ“ MQA: Inference-heavy workloads (chatbots, code completion)
    βœ“ GQA: Balance between MHA and MQA (LLaMA-2, Mistral)
    βœ“ Sliding Window: Document processing, long context (>4k tokens)
    βœ“ Linear: Extremely long sequences (>16k tokens)
    βœ“ Flash: Memory-constrained training, long sequences
    βœ“ Cross: Encoder-decoder, multimodal, conditioning
    """
    print(comparison)


# ============================================================================
# 9. Main Execution
# ============================================================================

if __name__ == "__main__":
    print("Advanced Attention Mechanisms - Comprehensive Implementation\n")
    
    # Run each demonstration stage in a fixed order: variant demos,
    # comparison table, speed benchmark, then pattern visualization.
    stages = (
        demo_attention_variants,
        print_performance_comparison,
        benchmark_attention_speed,
        visualize_attention_patterns,
    )
    for stage in stages:
        stage()
    
    # Closing banner.
    rule = "=" * 80
    print("\n" + rule)
    print("Implementation complete! All attention mechanisms ready for use.")
    print(rule)