import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Prefer GPU when available; every model/tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Memory Networks Concept

Architecture:ΒΆ

  • Memory Matrix: \(M_t \in \mathbb{R}^{N \times M}\)

  • Controller: LSTM/GRU

  • Read Head: Attention-based retrieval

  • Write Head: Memory updates

Read Operation:ΒΆ

\[r_t = \sum_{i=1}^N w_t^r(i) M_t(i)\]

where \(w_t^r\) is read attention weight.

📚 Reference Materials:

class MemoryBank(nn.Module):
    """External memory module: one (n_slots, slot_size) matrix shared across
    the batch, accessed through soft attention weights.

    Args:
        n_slots: Number of memory rows N.
        slot_size: Dimensionality M of each memory row.
    """

    def __init__(self, n_slots, slot_size):
        super().__init__()
        self.n_slots = n_slots
        self.slot_size = slot_size

        # Buffer (not a parameter): moves with .to(device) but is not trained.
        self.register_buffer('memory', torch.zeros(n_slots, slot_size))

    def read(self, weights):
        """Read from memory using attention weights.

        Args:
            weights: (batch, n_slots) attention distribution over slots.
        Returns:
            (batch, slot_size) weighted sums of memory rows.
        """
        return torch.mm(weights, self.memory)

    def write(self, weights, erase_vector, add_vector):
        """Write to memory with the NTM erase-then-add rule.

        Fix: the previous implementation used ``torch.outer`` on squeezed
        tensors, which raises for any batch size > 1. Both single-sample
        (1-D) and batched (2-D) inputs are supported now; batched outer
        products are summed over the batch via a single matmul.

        Args:
            weights: (n_slots,) or (batch, n_slots) write attention.
            erase_vector: (slot_size,) or (batch, slot_size), entries in [0, 1].
            add_vector: (slot_size,) or (batch, slot_size).
        """
        # Normalize everything to 2-D (batch, dim) so one code path works.
        if weights.dim() == 1:
            weights = weights.unsqueeze(0)
        if erase_vector.dim() == 1:
            erase_vector = erase_vector.unsqueeze(0)
        if add_vector.dim() == 1:
            add_vector = add_vector.unsqueeze(0)

        # Sum of per-sample outer products: (n_slots, slot_size).
        # For batch size 1 this equals the original torch.outer result.
        erase = torch.mm(weights.t(), erase_vector)
        # Clamp so summed erase terms cannot flip the sign of memory.
        erase = erase.clamp(max=1.0)
        self.memory = self.memory * (1 - erase)

        add = torch.mm(weights.t(), add_vector)
        self.memory = self.memory + add

    def reset(self):
        # In-place zeroing keeps the buffer registration intact.
        self.memory.zero_()

print("MemoryBank defined")  # notebook-style progress marker

2. Content-Based AddressingΒΆ

Similarity Measure:ΒΆ

\[K(k, M_t(i)) = \frac{k \cdot M_t(i)}{\|k\| \|M_t(i)\|}\]

Attention:ΒΆ

\[w_t^c(i) = \frac{\exp(\beta_t K(k_t, M_t(i)))}{\sum_j \exp(\beta_t K(k_t, M_t(j)))}\]
def content_addressing(key, memory, strength):
    """Content-based addressing: softmax over sharpened cosine similarities.

    Args:
        key: (batch, slot_size) query vectors.
        memory: (n_slots, slot_size) memory matrix.
        strength: sharpness factor beta (scalar or broadcastable to (batch, 1)).

    Returns:
        (batch, n_slots) attention weights, each row summing to 1.
    """
    eps = 1e-8
    # Normalize both operands so their dot product is a cosine similarity.
    unit_key = key / (torch.norm(key, dim=1, keepdim=True) + eps)
    unit_mem = memory / (torch.norm(memory, dim=1, keepdim=True) + eps)

    # (batch, n_slots) similarities, sharpened by the key strength.
    scores = strength * torch.mm(unit_key, unit_mem.t())

    return F.softmax(scores, dim=1)

print("Content addressing defined")  # notebook-style progress marker

Neural Turing MachineΒΆ

The Neural Turing Machine (NTM) augments a neural network controller with an external differentiable memory matrix that can be read from and written to via attention-based addressing. The read head computes a weighted sum over memory rows using a soft attention vector, and the write head uses separate erase and add vectors to modify memory contents. Addressing can be content-based (attend to memory rows that match a key vector) or location-based (shift attention to adjacent positions), enabling both associative retrieval and sequential access patterns. This architecture can learn algorithmic tasks like copying, sorting, and sequence recall that are impossible for standard RNNs.

class NTMController(nn.Module):
    """LSTM controller for the NTM.

    Consumes the current input concatenated with the previous read vector,
    advances one LSTM step, and emits the read/write head parameters plus
    the external output.
    """

    def __init__(self, input_size, hidden_size, output_size, memory_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.memory_size = memory_size

        # Controller core: sees the input plus the previous read vector.
        self.lstm = nn.LSTMCell(input_size + memory_size, hidden_size)

        # Read-head parameter projections.
        self.read_key = nn.Linear(hidden_size, memory_size)
        self.read_strength = nn.Linear(hidden_size, 1)

        # Write-head parameter projections.
        self.write_key = nn.Linear(hidden_size, memory_size)
        self.write_strength = nn.Linear(hidden_size, 1)
        self.erase = nn.Linear(hidden_size, memory_size)
        self.add = nn.Linear(hidden_size, memory_size)

        # External output conditioned on hidden state and previous read.
        self.fc_out = nn.Linear(hidden_size + memory_size, output_size)

    def forward(self, x, prev_state, prev_read):
        """Run one controller step.

        Args:
            x: (batch, input_size) current input.
            prev_state: (h, c) LSTM state tuple from the previous step.
            prev_read: (batch, memory_size) previous read vector.

        Returns:
            Tuple of (output, (h, c), read key, read strength, write key,
            write strength, erase vector, add vector).
        """
        lstm_in = torch.cat([x, prev_read], dim=1)
        hidden, cell = self.lstm(lstm_in, prev_state)

        # Read-head parameters: bounded key, strictly positive strength.
        read_k = torch.tanh(self.read_key(hidden))
        read_beta = F.softplus(self.read_strength(hidden))

        # Write-head parameters: erase gates in [0, 1], bounded add vector.
        write_k = torch.tanh(self.write_key(hidden))
        write_beta = F.softplus(self.write_strength(hidden))
        e_vec = torch.sigmoid(self.erase(hidden))
        a_vec = torch.tanh(self.add(hidden))

        out = self.fc_out(torch.cat([hidden, prev_read], dim=1))

        return out, (hidden, cell), read_k, read_beta, write_k, write_beta, e_vec, a_vec

print("NTMController defined")  # notebook-style progress marker

Complete NTMΒΆ

The complete NTM combines the controller (an LSTM or feedforward network), the memory matrix, and the read/write heads into an end-to-end differentiable system. At each time step, the controller receives the current input and the previous read vector, produces a hidden state, and emits addressing parameters for the read and write heads. The memory is then updated (write) and queried (read), with the read output fed back to the controller for the next step. All operations – attention, memory read, memory write – are differentiable, so the entire system is trained with standard backpropagation through time.

class NeuralTuringMachine(nn.Module):
    """NTM: LSTM controller + external memory with content-based read/write.

    The memory is a single matrix shared across the batch (not batched) and
    is reset at the start of every forward pass, so each call begins from
    blank memory.
    """

    def __init__(self, input_size, hidden_size, output_size, n_slots=128, slot_size=20):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_slots = n_slots
        self.slot_size = slot_size

        # External memory and the controller that addresses it.
        self.memory = MemoryBank(n_slots, slot_size)
        self.controller = NTMController(input_size, hidden_size, output_size, slot_size)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, input_size)

        Returns:
            outputs: (batch, seq_len, output_size) per-step logits.
            read_weights_history, write_weights_history: lists of per-step
            (batch, n_slots) attention weights, kept for visualization.
        """
        batch_size, seq_len, _ = x.size()

        # Initialize LSTM state, previous read vector, and blank memory.
        h = torch.zeros(batch_size, self.hidden_size).to(x.device)
        c = torch.zeros(batch_size, self.hidden_size).to(x.device)
        read = torch.zeros(batch_size, self.slot_size).to(x.device)
        self.memory.reset()

        outputs = []
        read_weights_history = []
        write_weights_history = []

        for t in range(seq_len):
            # Controller step: consumes input plus the previous read vector.
            out, (h, c), k_r, beta_r, k_w, beta_w, erase, add = self.controller(
                x[:, t], (h, c), read
            )

            # Read BEFORE writing: step t's read sees memory as left by step t-1.
            w_read = content_addressing(k_r, self.memory.memory, beta_r)
            read = self.memory.read(w_read)

            # Write to memory (erase-then-add).
            # NOTE(review): .squeeze() only drops the batch dim when
            # batch_size == 1; with a larger batch the 2-D tensors reach
            # MemoryBank.write unchanged — confirm write handles that case.
            w_write = content_addressing(k_w, self.memory.memory, beta_w)
            self.memory.write(w_write, erase.squeeze(), add.squeeze())

            outputs.append(out)
            read_weights_history.append(w_read)
            write_weights_history.append(w_write)

        outputs = torch.stack(outputs, dim=1)

        return outputs, read_weights_history, write_weights_history

print("NeuralTuringMachine defined")  # notebook-style progress marker

Copy TaskΒΆ

The copy task is the canonical benchmark for memory-augmented networks: the model receives a sequence of binary vectors, followed by a delimiter, and must reproduce the sequence from memory. Solving this task requires learning to write each input to a distinct memory location (using location-based addressing with incremental shifts), then read them back in order. Standard RNNs/LSTMs struggle with this task for long sequences because they must compress the entire input into a fixed-size hidden state, whereas the NTM can use its external memory as a buffer with capacity proportional to the memory matrix size.

def generate_copy_data(batch_size, seq_len, n_bits):
    """Generate input/target pairs for the copy task.

    The input is a random binary payload, then a delimiter frame, then
    blank frames during which the model must reproduce the payload.

    Fix: the previous version used an all-zero delimiter, which is
    indistinguishable from the zero padding, so the model had no signal
    for when the input phase ends. An all-ones frame is used instead
    (shapes and the model's input_size are unchanged).

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Length of the random binary payload.
        n_bits: Bits per time step.

    Returns:
        Tuple (input_seq, target_seq), each (batch, 2*seq_len + 1, n_bits).
        The target is all zeros through the input and delimiter phases and
        equals the payload during the recall phase.
    """
    # Random binary payload to be copied.
    seq = torch.randint(0, 2, (batch_size, seq_len, n_bits)).float()

    # All-ones delimiter marks the end of the input phase.
    delimiter = torch.ones(batch_size, 1, n_bits)
    blanks = torch.zeros(batch_size, seq_len, n_bits)

    input_seq = torch.cat([seq, delimiter, blanks], dim=1)

    # Target: silent during input + delimiter, then reproduce the payload.
    silence = torch.zeros(batch_size, seq_len + 1, n_bits)
    target_seq = torch.cat([silence, seq], dim=1)

    return input_seq, target_seq

# Sanity-check data generation: 2 sequences, payload length 5, 8 bits each.
x, y = generate_copy_data(2, 5, 8)
print(f"Input shape: {x.shape}, Target shape: {y.shape}")

Train NTMΒΆ

Training the NTM on the copy task uses standard supervised learning: the model’s output at each time step after the delimiter is compared to the target via binary cross-entropy loss. Learning curves typically show a phase transition: the model struggles for many epochs as it learns the addressing mechanism, then rapidly achieves near-perfect performance once it discovers the correct read/write strategy. Gradient clipping is essential because the backpropagation-through-time gradients can explode when memory addressing weights are near the boundaries of the softmax.

def train_ntm(model, n_epochs=100, seq_len=5, n_bits=8):
    """Train an NTM on the copy task with BCE-with-logits and grad clipping.

    Args:
        model: NeuralTuringMachine producing per-step logits.
        n_epochs: Number of training iterations (a fresh batch each time).
        seq_len: Payload length of the copy task.
        n_bits: Bits per time step.

    Returns:
        List of per-iteration loss values.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    history = []

    for epoch in range(n_epochs):
        # Fresh random batch of 32 sequences every iteration.
        inputs, targets = generate_copy_data(32, seq_len, n_bits)
        inputs, targets = inputs.to(device), targets.to(device)

        logits, _, _ = model(inputs)
        loss = F.binary_cross_entropy_with_logits(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        # Clip to tame the BPTT gradient spikes typical of NTM training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

        history.append(loss.item())

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

    return history

# Train
ntm = NeuralTuringMachine(
    input_size=8, hidden_size=100, output_size=8,
    n_slots=128, slot_size=20
).to(device)

losses = train_ntm(ntm, n_epochs=100, seq_len=5, n_bits=8)

Visualize ResultsΒΆ

Visualizing the memory access patterns reveals how the NTM has learned to solve the copy task. During the input phase, write attention weights should show a sequential pattern (writing to consecutive memory locations). During the output phase, read attention weights should retrace the same sequential pattern, retrieving stored values in order. These attention visualizations confirm that the NTM has learned an algorithm, not merely memorized training examples – a qualitative difference from standard neural network learning.

# Evaluate the trained NTM on a fresh, longer (length-8) copy sequence.
ntm.eval()
x_test, y_test = generate_copy_data(1, 8, 8)
x_test = x_test.to(device)

# No gradients needed for evaluation; model outputs are logits -> sigmoid.
with torch.no_grad():
    pred, read_weights, write_weights = ntm(x_test)
    pred = torch.sigmoid(pred)

# Plot
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss curve
axes[0, 0].plot(losses)
axes[0, 0].set_xlabel('Epoch', fontsize=11)
axes[0, 0].set_ylabel('Loss', fontsize=11)
axes[0, 0].set_title('Training Loss', fontsize=12)
axes[0, 0].grid(True, alpha=0.3)

# Input (transposed so time runs along the x-axis, bits along the y-axis)
axes[0, 1].imshow(x_test[0].cpu().T, aspect='auto', cmap='Blues')
axes[0, 1].set_xlabel('Time', fontsize=11)
axes[0, 1].set_ylabel('Bit', fontsize=11)
axes[0, 1].set_title('Input Sequence', fontsize=12)

# Prediction (should mirror the payload after the delimiter step)
axes[1, 0].imshow(pred[0].cpu().T, aspect='auto', cmap='Blues')
axes[1, 0].set_xlabel('Time', fontsize=11)
axes[1, 0].set_ylabel('Bit', fontsize=11)
axes[1, 0].set_title('Output Prediction', fontsize=12)

# Read weights: stack per-step (1, n_slots) rows -> (time, n_slots);
# transpose and keep the first 50 memory slots for display.
read_w = torch.stack(read_weights).squeeze().cpu().numpy()
axes[1, 1].imshow(read_w.T[:50], aspect='auto', cmap='Blues')
axes[1, 1].set_xlabel('Time', fontsize=11)
axes[1, 1].set_ylabel('Memory Slot', fontsize=11)
axes[1, 1].set_title('Read Attention Weights', fontsize=12)

plt.tight_layout()
plt.show()

SummaryΒΆ

Memory Networks:ΒΆ

Components:

  1. External memory matrix

  2. Controller (LSTM/GRU)

  3. Read/write heads

  4. Addressing mechanisms

Addressing:

  • Content-based (similarity)

  • Location-based (shift)

  • Interpolation (gating)

Operations:

  • Read: weighted sum

  • Write: erase + add

Applications:ΒΆ

  • Question answering

  • Program learning

  • Graph algorithms

  • Reasoning tasks

Variants:ΒΆ

  • NTM: Neural Turing Machine

  • DNC: Differentiable Neural Computer

  • End-to-End Memory Networks: Simpler, task-specific

Advanced Memory Networks TheoryΒΆ

1. Introduction to Memory-Augmented Neural NetworksΒΆ

1.1 MotivationΒΆ

Standard neural networks:

  • Parameters encode all knowledge

  • No explicit external memory

  • Struggle with tasks requiring explicit recall

Memory-augmented networks:

  • Explicit external memory module

  • Read/write mechanisms

  • Better for question answering, reasoning, few-shot learning

Key applications:

  • Question answering: Store facts, retrieve relevant information

  • One-shot learning: Store examples in memory

  • Reasoning: Multi-hop reasoning over facts

  • Dialog systems: Maintain conversation history

1.2 Memory Network ComponentsΒΆ

General architecture:

  1. Memory module M: Array of memory slots M = [m₁, mβ‚‚, …, mβ‚™]

    • Each mα΅’ ∈ β„α΅ˆ (vector representation)

  2. Input module I: Converts input to internal representation

    • I: x β†’ u (embed input)

  3. Generalization module G: Updates memory

    • G: (M_old, I(x)) β†’ M_new

  4. Output module O: Produces response

    • O: (o, u) β†’ r where o is retrieved memory

  5. Response module R: Converts to output format

    • R: r β†’ a (final answer)

Forward pass:

u = I(x)           # Embed input
o = O(M, u)        # Retrieve from memory
r = R(o, u)        # Generate response

2. End-to-End Memory Networks (MemN2N)ΒΆ

2.1 Architecture (Sukhbaatar et al., 2015)ΒΆ

Problem: Original memory networks require supervision at every layer. MemN2N is fully differentiable.

Input representation: Given story sentences {s₁, sβ‚‚, …, sβ‚™} and query q:

  • Embed each sentence: mᵢ = Σⱼ A xᵢⱼ (position encoding)

  • Embed query: u = Σⱼ B qⱼ

  • A, B: Embedding matrices

Single layer attention:

  1. Attention weights (similarity):

pᵢ = softmax(uᵀmᵢ)
  2. Output representation: Different embedding C

cᵢ = Σⱼ C xᵢⱼ
  3. Weighted sum:

o = Σᵢ pᵢcᵢ
  4. Update query:

u₂ = u + o

2.2 Multiple Hops (Layers)ΒΆ

K computational hops:

uΒΉ = u                          # Initial query
For k = 1 to K:
    p^k = softmax(u^k α΅€ m^k)    # Attention
    o^k = Ξ£α΅’ p^k_i c^k_i        # Read
    u^(k+1) = u^k + o^k         # Update (residual)

Final prediction:

â = softmax(W u^(K+1))

Key insight: Multi-hop allows reasoning over multiple facts.

2.3 Position EncodingΒΆ

Problem: Bag-of-words loses word order.

Solution: Position encoding

mᵢ = Σⱼ lⱼ · A xᵢⱼ

where lβ±Ό encodes position j:

lⱼ = (1 - j/J) - (k/d)(1 - 2j/J)

J: sentence length, d: embedding dimension, k: dimension index.

Alternative: Use sinusoidal encoding (like Transformers).

2.4 TrainingΒΆ

Loss: Cross-entropy between predicted and true answer

L = -Ξ£β‚™ log p(aβ‚™ | xβ‚™)

Gradient flow: Backpropagate through all hops.

Weight tying:

  • Adjacent: A^(k+1) = C^k (reduce parameters)

  • Layer-wise (RNN-like): AΒΉ = AΒ² = … = A^K

2.5 Complexity AnalysisΒΆ

Time complexity per sample:

  • Attention computation: O(K Β· N Β· d) where K hops, N memory slots, d embedding dim

  • Memory size: O(N Β· d)

Space complexity:

  • Memory: O(N Β· d)

  • Parameters: O(V Β· d) where V vocabulary size

3. Neural Turing Machines (NTM)ΒΆ

3.1 Architecture (Graves et al., 2014)ΒΆ

Motivation: Learn algorithms from data. Turing-complete neural network.

Components:

  1. Controller: LSTM/Feedforward network

  2. Memory matrix: M_t ∈ ℝ^(NΓ—M) (N slots, M dimensions each)

  3. Read/write heads: Access memory via attention

3.2 Addressing MechanismsΒΆ

Two modes:

  1. Content-based addressing:

w^c_t[i] = exp(Ξ²_t K(k_t, M_t[i])) / Ξ£β±Ό exp(Ξ²_t K(k_t, M_t[j]))

where:

  • k_t: Key vector (from controller)

  • K: Similarity measure (cosine similarity)

  • Ξ²_t: Key strength (sharpness)

  1. Location-based addressing:

Interpolation (gate between content and previous):

w^g_t = g_t w^c_t + (1 - g_t) w_{t-1}

Convolutional shift:

w̃_t[i] = Σⱼ w^g_t[j] s_t[(i - j) mod N]

where s_t is shift distribution (e.g., s_t = [0.1, 0.8, 0.1] for left, stay, right).

Sharpening:

w_t[i] = w̃_t[i]^γ_t / Σⱼ w̃_t[j]^γ_t

where Ξ³_t β‰₯ 1 (sharpening factor).

3.3 Read and Write OperationsΒΆ

Read:

r_t = Ξ£α΅’ w^r_t[i] M_t[i]

Write (erase + add):

  1. Erase:

M̃_t[i] = M_{t-1}[i] (1 - w^w_t[i] e_t)

where e_t ∈ [0,1]^M is erase vector.

  1. Add:

M_t[i] = M̃_t[i] + w^w_t[i] a_t

where a_t ∈ ℝ^M is add vector.

Intuition: Erase old content, write new content (like LSTM gates).

3.4 ControllerΒΆ

LSTM controller outputs:

  • Key vectors: k^r_t, k^w_t

  • Key strengths: Ξ²^r_t, Ξ²^w_t

  • Interpolation gates: g^r_t, g^w_t

  • Shift distributions: s^r_t, s^w_t

  • Sharpening: Ξ³^r_t, Ξ³^w_t

  • Erase vector: e_t

  • Add vector: a_t

Total parameters: Large (many heads Γ— many parameters per head).

3.5 TrainingΒΆ

Supervised learning:

  • Input/output sequences

  • Backpropagation through time (BPTT)

  • Gradient clipping essential

Tasks:

  • Copy, repeat copy, associative recall

  • Priority sort, dynamic N-grams

  • Learning simple algorithms

4. Differentiable Neural Computer (DNC)ΒΆ

4.1 Improvements over NTM (Graves et al., 2016)ΒΆ

Key enhancements:

  1. Dynamic memory allocation: Track memory usage, allocate free slots

  2. Temporal links: Maintain write order for sequential access

  3. Better addressing: Content + allocation + temporal

4.2 Memory Usage TrackingΒΆ

Usage vector: u_t ∈ [0,1]^N (how much each slot is used)

Update:

u_t = (u_{t-1} + w^w_t - u_{t-1} ⊙ w^w_t) ⊙ ψ_t

where ψ_t ∈ [0,1]^N is retention vector (free slots).

Free gates: f^i_t ∈ [0,1] for each read head i

ψ_t = Πᵢ (1 - f^i_t w^{r,i}_t)

Free list: Sort memory by usage (ascending)

Ο†_t = argsort(u_t)

Allocation weight:

a_t[Ο†_t[j]] = (1 - u_t[Ο†_t[j]]) Ξ β‚–<β±Ό u_t[Ο†_t[k]]

Allocates to least-used available slot.

4.4 Final Read WeightΒΆ

Combination of three modes:

w^r_t = Ο€^r_t[1] b_t + Ο€^r_t[2] c^r_t + Ο€^r_t[3] f_t

where Ο€^r_t ∈ Δ³ (simplex) determines mode mixture.

Read output:

r_t = M_t^T w^r_t

4.5 Write WeightΒΆ

Allocation + content:

w^w_t = g^w_t (g^a_t a_t + (1 - g^a_t) c^w_t)

where:

  • g^w_t: Write gate (whether to write)

  • g^a_t: Allocation gate (new vs content-addressed)

  • a_t: Allocation weight

  • c^w_t: Content weight

5. Key-Value Memory NetworksΒΆ

5.1 Architecture (Miller et al., 2016)ΒΆ

Motivation: Separate keys (for addressing) from values (retrieved content).

Memory structure:

  • Keys: K = {k₁, kβ‚‚, …, kβ‚™}

  • Values: V = {v₁, vβ‚‚, …, vβ‚™}

Retrieval:

  1. Compute attention over keys:

Ξ±_i = softmax(q^T W_K k_i)
  1. Retrieve weighted values:

o = Ξ£α΅’ Ξ±_i W_V v_i

Advantages:

  • Keys can be short (efficient retrieval)

  • Values can be rich (informative)

  • Allows hash-like access

5.2 Knowledge Base as MemoryΒΆ

Key-value pairs from knowledge graph:

  • Key: Subject-relation embedding

  • Value: Object embedding

Example:

Key: embed("Paris" + "capital_of")
Value: embed("France")

Query: β€œWhat is the capital of France?” β†’ Retrieves value associated with key

6. Memory Networks for Few-Shot LearningΒΆ

6.1 Matching Networks (Vinyals et al., 2016)ΒΆ

Idea: Memory = support set in few-shot learning.

Architecture:

Ε· = Ξ£α΅’ a(xΜ‚, xα΅’) yα΅’

where:

  • (xα΅’, yα΅’): Support examples in memory

  • xΜ‚: Query

  • a: Attention mechanism

Attention:

a(xΜ‚, xα΅’) = exp(c(f(xΜ‚), g(xα΅’))) / Ξ£β±Ό exp(c(f(xΜ‚), g(xβ±Ό)))

Full context embeddings:

  • f(xΜ‚): Bi-LSTM with attention over support set

  • g(xα΅’): Bi-LSTM encoding

6.2 Prototypical Networks as MemoryΒΆ

Memory = class prototypes:

c_k = (1/|S_k|) Σ_{(x,y)∈S_k} f_θ(x)

Retrieval = nearest prototype:

p(y=k|x) = softmax(-d(f_ΞΈ(x), c_k))

7. Transformer as Memory NetworkΒΆ

7.1 Self-Attention as Memory AccessΒΆ

Self-attention = content-addressable memory:

Keys, values from input:

K = XW_K,  V = XW_V,  Q = XW_Q

Attention (retrieval):

Attention(Q,K,V) = softmax(QK^T/√d_k) V

Properties:

  • Content-based: Attention determined by QΒ·K similarity

  • Soft retrieval: Weighted combination (vs hard selection)

  • Differentiable: End-to-end backprop

7.2 Transformers vs Memory NetworksΒΆ

Similarities:

  • External memory (key-value pairs)

  • Attention-based retrieval

  • Multi-hop reasoning (multiple layers)

Differences:

Aspect

Transformer

Memory Network

Memory

Entire sequence

Explicit slots

Capacity

O(N) (sequence length)

O(N) (slots)

Updates

No (read-only)

Yes (write operations)

Architecture

Uniform layers

Specialized components

7.3 Memory-Augmented TransformersΒΆ

Combine both:

  1. Compressive Transformer: Compress old activations into memory

  2. ∞-former: Unbounded long-term memory via retrieval

  3. Memorizing Transformers: kNN over past key-values

8. Advances in Memory MechanismsΒΆ

8.1 Sparse Access PatternsΒΆ

Problem: O(N) attention over all memory expensive.

Solutions:

  1. Product-key memory (PKM):

    • Key = concat(k₁, kβ‚‚) where k₁, kβ‚‚ from separate codebooks

    • Retrieve top-k from each codebook

    • Final candidates = k₁ Γ— kβ‚‚ (small)

    • Complexity: O(√N) instead of O(N)

  2. Locality-sensitive hashing (LSH):

    • Hash similar keys to same buckets

    • Only attend within bucket

    • Complexity: O(N/B) where B buckets

8.2 External Knowledge BasesΒΆ

Retrieval-augmented generation (RAG):

Architecture:

  1. Retriever: Finds relevant documents from corpus

    • Dense retrieval (DPR): Embed query, retrieve via similarity

    • Sparse retrieval (BM25): Keyword-based

  2. Generator: LLM generates answer conditioned on retrieved docs

    p(y|x) = Ξ£β‚– p(d_k|x) p(y|x, d_k)
    

Training:

  • End-to-end: Backprop through retrieval (MIPS approximation)

  • Pipeline: Train retriever, then generator separately

State-of-the-art:

  • REALM (2020): Retrieve from millions of documents

  • RAG (2020): Wikipedia as memory

  • Atlas (2022): Few-shot learning via retrieval

8.3 Learnable Memory TokensΒΆ

Perceiver, Set Transformers:

Learnable latent vectors:

Z ∈ ℝ^(MΓ—d)  (M << N)

Cross-attention to compress input:

Z = Attention(Q=Z, K=X, V=X)

Benefits:

  • Fixed-size bottleneck (memory)

  • O(NΒ·M) complexity instead of O(NΒ²)

  • Learnable compression

9. Training TechniquesΒΆ

9.1 Curriculum LearningΒΆ

Memory tasks benefit from curriculum:

  1. Length curriculum: Start with short sequences, increase gradually

  2. Hop curriculum: Single-hop first, then multi-hop

  3. Size curriculum: Small memory, then large

Example (MemN2N on bAbI):

  • 1K training: 93% accuracy

  • 10K training: 95.8% accuracy

  • Curriculum: 96.4% accuracy

9.2 RegularizationΒΆ

Memory networks prone to overfitting:

  1. Dropout on attention weights:

p̃ᡒ = dropout(pᡒ)
pΜƒα΅’ ← pΜƒα΅’ / Ξ£β±Ό pΜƒβ±Ό  (renormalize)
  1. Weight decay on memory embeddings

  2. Attention entropy regularization:

L_entropy = -Ξ£α΅’ pα΅’ log pα΅’

Encourages diverse attention.

9.3 Pre-trainingΒΆ

Transfer learning for memory networks:

  1. Pre-train embeddings: Word2Vec, GloVe, BERT

  2. Pre-train on related tasks: General QA β†’ specific domain

  3. Multi-task learning: Train on multiple QA datasets jointly

10. Evaluation and BenchmarksΒΆ

10.1 bAbI Tasks (Facebook)ΒΆ

20 synthetic reasoning tasks:

Examples:

  • Task 1: Single supporting fact

    • β€œMary went to the bathroom. John moved to the hallway. Where is Mary? β†’ bathroom”

  • Task 2: Two supporting facts

    • β€œJohn is in the playground. John picked up the football. Where is the football? β†’ playground”

  • Task 3: Three supporting facts (harder)

Metrics:

  • Per-task accuracy

  • Mean accuracy across tasks

  • Failed tasks (< 95% threshold)

Baseline results:

  • MemN2N: 0.5 failed tasks (almost perfect)

  • LSTM: 11.5 failed tasks

  • Human: 0 failed tasks

10.2 Question AnsweringΒΆ

SQuAD (reading comprehension):

  • Given paragraph + question β†’ extract answer span

  • Memory network: Store sentences, retrieve relevant ones

Results (SQuAD v1.1):

  • MemN2N: 77.2% F1

  • Bidirectional Attention Flow (BiDAF): 84.1% F1

  • BERT: 93.2% F1 (pre-training helps!)

Natural Questions (Google):

  • Real Google search queries

  • Long answers from Wikipedia

  • Memory networks competitive on retrieving relevant paragraphs

10.3 Few-Shot LearningΒΆ

Omniglot (character recognition):

  • Matching Networks: 98.1% (5-way 5-shot)

  • Prototypical Networks: 98.8%

miniImageNet:

  • Matching Networks: 46.6% (5-way 1-shot)

  • Memory-augmented networks: Store examples, retrieve similar

11. Architectural VariantsΒΆ

11.1 Recurrent MemoryΒΆ

MANN (Memory-Augmented Neural Network for Meta-Learning):

One-shot learning via external memory:

  1. Read from memory (similar past examples)

  2. Controller processes input + read

  3. Write to memory (store current example)

Key idea: Memory persists across episodes (unlike weights).

11.2 Hierarchical MemoryΒΆ

Multi-scale memory:

  1. Short-term: Recent context (high resolution)

  2. Medium-term: Compressed recent history

  3. Long-term: Semantic summaries

Example: Hierarchical Transformer

  • Level 1: Token-level (1K tokens)

  • Level 2: Sentence-level (100 sentences)

  • Level 3: Paragraph-level (10 paragraphs)

11.3 Associative MemoryΒΆ

Hopfield Networks as memory:

Modern Hopfield layer:

Output = softmax(β Xᵀ ξ)ᵀ X

where:

  • X: Stored patterns

  • ΞΎ: Query pattern

  • Ξ²: Inverse temperature

Properties:

  • Retrieves pattern most similar to query

  • Exponential storage capacity (vs linear in classical Hopfield)

Integration with Transformers:

  • Hopfield layer β‰ˆ attention layer

  • Can replace or augment self-attention

12. Complexity and ScalabilityΒΆ

12.1 Time ComplexityΒΆ

Memory network forward pass:

Operation

Complexity

Embedding

O(L Β· d) where L input length

Attention (single hop)

O(N Β· d) where N memory slots

K hops

O(K Β· N Β· d)

Output

O(V) where V vocab size

Total: O(LΒ·d + KΒ·NΒ·d + V)

Bottleneck: Attention over large memory (N large).

12.2 Space ComplexityΒΆ

Storage:

  • Memory: O(N Β· d)

  • Parameters: O(V Β· d + dΒ²) (embeddings + transforms)

Activations (training):

  • Per hop: O(N) (attention weights)

  • K hops: O(K Β· N)

12.3 Scalability SolutionsΒΆ

Large-scale memory (millions of slots):

  1. Approximate nearest neighbors:

    • FAISS, ScaNN for fast retrieval

    • O(log N) or O(√N) lookup

  2. Memory pruning:

    • Keep top-k attended slots

    • Discard low-attention memories

  3. Hierarchical retrieval:

    • Cluster memory, retrieve cluster first

    • Then retrieve within cluster

  4. Distributed memory:

    • Shard memory across GPUs/machines

    • All-reduce for aggregation

13. ApplicationsΒΆ

13.1 Question Answering SystemsΒΆ

Conversational QA:

  • Store conversation history in memory

  • Retrieve relevant turns for context

  • Generate contextual response

Multi-hop reasoning:

  • HotpotQA: Requires reasoning over 2+ paragraphs

  • Memory network: Store all paragraphs, multi-hop retrieval

13.2 Language ModelingΒΆ

Long-range dependencies:

  • Standard Transformer: O(NΒ²) attention limits context

  • Memory-augmented: Extend context via external memory

Example: Transformer-XL

  • Segment-level recurrence

  • Cache previous segment’s hidden states

  • Attend to current + cached

13.3 Dialog SystemsΒΆ

Persistent memory of facts:

  • User preferences, past conversations

  • Retrieve relevant history for personalization

End-to-end dialog (MemN2N for dialog):

  • Memory = dialog history + knowledge base

  • Query = current user utterance

  • Response = generated from retrieved context

13.4 Program SynthesisΒΆ

Neural Turing Machine tasks:

  • Copy, reverse, sort sequences

  • Learn algorithms from input-output examples

Results:

  • NTM learns perfect copy (vs LSTM fails)

  • Generalizes to longer sequences

13.5 Visual Question AnsweringΒΆ

Image + question β†’ answer:

Memory structure:

  • Memory slots = image regions (from CNN features)

  • Query = question embedding

  • Retrieval = attend to relevant regions

Attention visualization:

  • Shows which image regions model focuses on

  • Interpretable decision-making

14. Recent Advances (2020-2024)ΒΆ

14.1 Infinite Memory TransformersΒΆ

∞-former (Martins et al., 2022):

  • Unbounded memory via kNN retrieval

  • Store all past (key, value) pairs

  • Retrieve top-k for each query

Memorizing Transformers (Wu et al., 2022):

  • External memory = all past activations

  • kNN retrieval from memory

  • Interpolate between attention and kNN

14.2 Retrieval-Augmented LLMsΒΆ

RETRO (DeepMind, 2022):

  • Retrieve chunks from trillion-token database

  • Cross-attend to retrieved chunks

  • 25Γ— fewer parameters for same performance

Atlas (Meta, 2022):

  • Few-shot learning via retrieval

  • Retrieves from Wikipedia

  • State-of-the-art on MMLU

14.3 Memory for Long ContextΒΆ

Compressive Transformers:

  • Compress old memories (activations)

  • Keep recent memories (full resolution)

  • Lossy compression for efficiency

Landmark attention:

  • Designate certain tokens as β€œlandmarks”

  • All tokens attend to landmarks

  • Reduces memory from O(NΒ²) to O(NΒ·L) where L << N

15. Limitations and ChallengesΒΆ

15.1 Known IssuesΒΆ

Scalability:

  • O(N) attention expensive for large memory

  • Trade-off: memory size vs speed

Catastrophic forgetting:

  • Overwriting important memories

  • Need memory management strategies

Interpretability:

  • Attention weights don’t always reflect reasoning

  • β€œAttention is not explanation” debate

Generalization:

  • May memorize training examples

  • Doesn’t generalize to truly novel situations

15.2 Open ProblemsΒΆ

Continual learning:

  • How to update memory without forgetting?

  • Meta-learning over memory update rules

Compositional reasoning:

  • Combine retrieved facts compositionally

  • Neural-symbolic integration

Efficient large-scale memory:

  • Billions of memory slots

  • Sub-linear retrieval time

16. Comparison with Human MemoryΒΆ

16.1 ParallelsΒΆ

Working memory:

  • Limited capacity (~7 items)

  • MemN2N attention similar (focuses on few slots)

Long-term memory:

  • Associative retrieval (content-addressable)

  • Consolidation (important memories strengthened)

Episodic memory:

  • NTM temporal links β‰ˆ sequential memory

16.2 DifferencesΒΆ

Human advantages:

  • Highly efficient encoding/retrieval

  • Rich compositional reasoning

  • Transfer to novel domains

Neural network advantages:

  • Perfect recall (no degradation)

  • Massive parallel access

  • Explicit gradient-based learning

17. Future DirectionsΒΆ

17.1 Research FrontiersΒΆ

Neurosymbolic memory:

  • Combine neural retrieval with symbolic reasoning

  • Logic rules + soft attention

Lifelong learning:

  • Memory that grows over lifetime

  • Selective consolidation/forgetting

Multi-modal memory:

  • Unified memory for text, images, audio, video

  • Cross-modal retrieval

17.2 Practical DeploymentΒΆ

Efficiency:

  • Low-rank approximations for attention

  • Quantization of memory slots

  • Pruning redundant memories

Privacy:

  • Federated memory (distributed across users)

  • Differential privacy for retrieval

18. Key TakeawaysΒΆ

  1. Memory augmentation enables explicit retrieval:

    • Better than encoding everything in parameters

    • Crucial for QA, reasoning, few-shot learning

  2. Attention = soft content-addressable memory:

    • Differentiable retrieval mechanism

    • Enables end-to-end learning

  3. Multi-hop reasoning requires multiple attention layers:

    • Each hop refines the query

    • Combines information from multiple sources

  4. Scalability is key challenge:

    • O(N) attention limits memory size

    • Solutions: sparse access, hierarchical retrieval, external KB

  5. Many architectures = variations on memory theme:

    • Transformers: Self-attention over sequence

    • MemN2N: Attention over explicit memory slots

    • NTM/DNC: Read/write operations

  6. Recent trend: Retrieval-augmented LLMs:

    • External knowledge bases as memory

    • kNN over trillions of tokens

    • State-of-the-art few-shot learning

19. ReferencesΒΆ

Foundational papers:

  • Weston et al. (2015): β€œMemory Networks” (ICLR)

  • Sukhbaatar et al. (2015): β€œEnd-To-End Memory Networks” (NeurIPS)

  • Graves et al. (2014): β€œNeural Turing Machines” (arXiv)

  • Graves et al. (2016): β€œHybrid Computing Using a Neural Network with Dynamic External Memory” (Nature)

Few-shot learning:

  • Vinyals et al. (2016): β€œMatching Networks for One Shot Learning” (NeurIPS)

  • Snell et al. (2017): β€œPrototypical Networks for Few-shot Learning” (NeurIPS)

Recent advances:

  • Borgeaud et al. (2022): β€œImproving Language Models by Retrieving from Trillions of Tokens” (ICML - RETRO)

  • Wu et al. (2022): β€œMemorizing Transformers” (ICLR)

  • Izacard et al. (2022): β€œAtlas: Few-shot Learning with Retrieval Augmented Language Models” (arXiv)

Surveys:

  • Santoro et al. (2018): β€œRelational Recurrent Neural Networks” (NeurIPS)

  • Graves et al. (2016): β€œNeural Turing Machines” tutorial

"""
Complete Memory Network Implementations
========================================
Includes: End-to-End Memory Networks (MemN2N), Neural Turing Machine (NTM),
Key-Value Memory, Matching Networks, external memory mechanisms.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# ============================================================================
# 1. End-to-End Memory Network (MemN2N)
# ============================================================================

class PositionEncoding(nn.Module):
    """
    Position encoding for MemN2N (Sukhbaatar et al., 2015).

    Encodes position j (1-indexed) in a sentence of length J with vector l_j:
        l_kj = (1 - j/J) - (k/d) * (1 - 2j/J)

    The encoding table is precomputed once and registered as a buffer so it
    moves with the module across devices but is not a trainable parameter.

    Args:
        sentence_len: Maximum sentence length J
        embed_dim: Embedding dimension d
    """
    def __init__(self, sentence_len, embed_dim):
        super(PositionEncoding, self).__init__()
        self.sentence_len = sentence_len
        self.embed_dim = embed_dim
        
        # Vectorized precomputation (replaces an O(J*d) Python double loop).
        # j ranges over 1..J (the original's (j+1) with 0-based j); k over 0..d-1.
        j = torch.arange(1, sentence_len + 1, dtype=torch.float32).unsqueeze(1) / sentence_len  # [J, 1]
        k = torch.arange(embed_dim, dtype=torch.float32).unsqueeze(0) / embed_dim              # [1, d]
        encoding = (1 - j) - k * (1 - 2 * j)  # [J, d] via broadcasting
        
        self.register_buffer('encoding', encoding)
    
    def forward(self, x):
        """
        Collapse each sentence to one position-weighted vector.

        Args:
            x: Embeddings [batch, num_sentences, sentence_len, embed_dim]
        Returns:
            Position-encoded embeddings [batch, num_sentences, embed_dim]
        """
        # Sum over the sentence-length dimension, weighted by position.
        batch, num_sent, sent_len, d = x.size()
        # Truncate the table to the actual sentence length; broadcast over
        # batch and sentence dimensions.
        encoding = self.encoding[:sent_len, :].unsqueeze(0).unsqueeze(0)
        return (x * encoding).sum(dim=2)


class MemN2NLayer(nn.Module):
    """
    One reasoning hop of an End-to-End Memory Network.

    Given the current query state u, scores every memory slot with a dot
    product, normalizes the scores with a softmax, and reads out an
    attention-weighted sum of the output memories:

        p = softmax(u^T m),  o = sum_i p_i c_i,  (caller forms u' = u + o)

    Args:
        embed_dim: Embedding dimension
        num_hops: Total number of hops (kept for weight-tying bookkeeping)
        hop_idx: Index of this hop
    """
    def __init__(self, embed_dim, num_hops, hop_idx):
        super(MemN2NLayer, self).__init__()
        self.embed_dim = embed_dim
        self.hop_idx = hop_idx
    
    def forward(self, u, m, c):
        """
        Args:
            u: Query embedding [batch, embed_dim]
            m: Memory embeddings for attention [batch, num_sentences, embed_dim]
            c: Memory embeddings for output [batch, num_sentences, embed_dim]
        Returns:
            readout: Output vector [batch, embed_dim]
            weights: Attention weights [batch, num_sentences]
        """
        # Dot-product scores of the query against every memory slot.
        scores = torch.bmm(m, u.unsqueeze(2)).squeeze(2)  # [batch, num_sentences]
        weights = F.softmax(scores, dim=-1)
        
        # Attention-weighted read from the output memories.
        readout = torch.bmm(weights.unsqueeze(1), c).squeeze(1)  # [batch, embed_dim]
        
        return readout, weights


class EndToEndMemoryNetwork(nn.Module):
    """
    End-to-End Memory Network (Sukhbaatar et al., 2015).

    Multi-hop reasoning over a set of memory slots with differentiable
    (softmax) attention, trained end-to-end.

    Args:
        vocab_size: Size of vocabulary
        embed_dim: Embedding dimension
        num_hops: Number of computational hops
        sentence_len: Maximum sentence length
        memory_size: Maximum number of memory slots
        weight_tying: 'adjacent' or 'layer_wise' or None
    """
    def __init__(self, vocab_size, embed_dim, num_hops=3, sentence_len=10,
                 memory_size=50, weight_tying='adjacent'):
        super(EndToEndMemoryNetwork, self).__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_hops = num_hops
        self.sentence_len = sentence_len
        self.memory_size = memory_size
        self.weight_tying = weight_tying
        
        # Embeddings for each hop.
        if weight_tying == 'layer_wise':
            # Layer-wise tying: a single A and a single C shared by all hops.
            self.A = nn.ModuleList([nn.Embedding(vocab_size, embed_dim)])
            self.C = nn.ModuleList([nn.Embedding(vocab_size, embed_dim)])
        else:
            self.A = nn.ModuleList([nn.Embedding(vocab_size, embed_dim) 
                                   for _ in range(num_hops)])
            self.C = nn.ModuleList([nn.Embedding(vocab_size, embed_dim) 
                                   for _ in range(num_hops)])
        
        # Query embedding.
        self.B = nn.Embedding(vocab_size, embed_dim)
        
        # Position encoding (word embeddings -> one vector per sentence).
        self.position_encoding = PositionEncoding(sentence_len, embed_dim)
        
        # Final answer projection.
        self.W = nn.Linear(embed_dim, vocab_size)
        
        # Temporal encoding: learned per-slot offsets (optional in the paper).
        self.temporal_A = nn.Parameter(torch.randn(memory_size, embed_dim))
        self.temporal_C = nn.Parameter(torch.randn(memory_size, embed_dim))
    
    def forward(self, story, query):
        """
        Args:
            story: Story sentences [batch, num_sentences, sentence_len]
            query: Query sentence [batch, query_len]
        Returns:
            logits: Answer logits [batch, vocab_size]
            attentions: Attention weights per hop [batch, num_hops, num_sentences]
        """
        num_sentences = story.size(1)
        
        # Embed query: u = sum_j B q_j (position-weighted).
        u = self.B(query)  # [batch, query_len, embed_dim]
        u = self.position_encoding(u.unsqueeze(1)).squeeze(1)  # [batch, embed_dim]
        
        attentions = []
        
        # Multi-hop reasoning.
        for k in range(self.num_hops):
            # Select embeddings for this hop according to the tying scheme.
            if self.weight_tying == 'layer_wise':
                A_k, C_k = self.A[0], self.C[0]
            elif self.weight_tying == 'adjacent' and k > 0:
                # Adjacent tying: A^{k+1} = C^k (paper, Section 2.2).
                A_k = self.C[k-1]
                C_k = self.C[k]
            else:
                A_k, C_k = self.A[k], self.C[k]
            
            # Input memories m_i = sum_j A x_{ij} (+ temporal offset).
            m = self.position_encoding(A_k(story))  # [batch, num_sentences, embed_dim]
            m = m + self.temporal_A[:num_sentences, :]
            
            # Output memories c_i = sum_j C x_{ij} (+ temporal offset).
            c = self.position_encoding(C_k(story))  # [batch, num_sentences, embed_dim]
            c = c + self.temporal_C[:num_sentences, :]
            
            # Attention and read, inlined.
            # FIX: the original constructed a brand-new MemN2NLayer module here
            # on every hop of every forward call — pointless allocation, and
            # any parameters ever added to that layer would silently escape
            # the optimizer. The hop is stateless, so we compute it directly.
            p = F.softmax(torch.matmul(m, u.unsqueeze(-1)).squeeze(-1), dim=-1)
            o = torch.matmul(p.unsqueeze(1), c).squeeze(1)  # [batch, embed_dim]
            
            attentions.append(p)
            
            # Residual query update: u^{k+1} = u^k + o^k.
            u = u + o
        
        # Final prediction.
        logits = self.W(u)
        
        attentions = torch.stack(attentions, dim=1)  # [batch, num_hops, num_sentences]
        
        return logits, attentions


# ============================================================================
# 2. Neural Turing Machine (NTM) Components
# ============================================================================

class NTMMemory(nn.Module):
    """
    External memory matrix for a Neural Turing Machine.

    Holds a batch of N x M memory matrices and supports attention-weighted
    reads and erase-then-add writes (Graves et al., 2014).

    Args:
        memory_size: Number of memory slots N
        memory_dim: Dimension of each memory slot M
    """
    def __init__(self, memory_size, memory_dim):
        super(NTMMemory, self).__init__()
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        
        # Placeholder buffer; reset() re-creates it with a batch dimension.
        self.register_buffer('memory', torch.zeros(memory_size, memory_dim))
    
    def reset(self, batch_size):
        """Re-initialize memory to zeros for a new batch of sequences."""
        self.memory = torch.zeros(batch_size, self.memory_size, self.memory_dim,
                                  device=self.memory.device)
    
    def read(self, w):
        """
        Attention-weighted read.

        Args:
            w: Read weights [batch, memory_size]
        Returns:
            Read vector [batch, memory_dim]
        """
        return torch.bmm(w.unsqueeze(1), self.memory).squeeze(1)
    
    def write(self, w, e, a):
        """
        Erase-then-add write: M_i <- M_i * (1 - w_i e) + w_i a.

        Args:
            w: Write weights [batch, memory_size]
            e: Erase vector [batch, memory_dim]
            a: Add vector [batch, memory_dim]
        """
        # Rank-1 outer products w_i * e_j and w_i * a_j via broadcasting.
        erase_term = w.unsqueeze(-1) * e.unsqueeze(1)  # [batch, N, M]
        add_term = w.unsqueeze(-1) * a.unsqueeze(1)    # [batch, N, M]
        self.memory = self.memory * (1 - erase_term) + add_term


class NTMHead(nn.Module):
    """
    Read/Write head for Neural Turing Machine.

    Combines content-based addressing (cosine similarity against a learned
    key) with location-based addressing (interpolation, circular shift,
    sharpening), as in Graves et al. (2014).

    Args:
        memory_size: Number of memory slots
        memory_dim: Dimension of each slot
        controller_dim: Dimension of controller output
    """
    def __init__(self, memory_size, memory_dim, controller_dim):
        super(NTMHead, self).__init__()
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        
        # Content addressing
        self.key_layer = nn.Linear(controller_dim, memory_dim)
        self.beta_layer = nn.Linear(controller_dim, 1)  # Key strength
        
        # Location addressing
        self.g_layer = nn.Linear(controller_dim, 1)  # Interpolation gate
        self.shift_layer = nn.Linear(controller_dim, 3)  # Shift distribution (left, stay, right)
        self.gamma_layer = nn.Linear(controller_dim, 1)  # Sharpening
        
        # Write-specific heads (used externally by the NTM write path)
        self.erase_layer = nn.Linear(controller_dim, memory_dim)
        self.add_layer = nn.Linear(controller_dim, memory_dim)
    
    def content_addressing(self, k, beta, memory):
        """
        Content-based addressing via cosine similarity.

        Args:
            k: Key vector [batch, memory_dim]
            beta: Key strength (>= 0) [batch, 1]
            memory: Memory matrix [batch, memory_size, memory_dim]
        Returns:
            Content weights [batch, memory_size]
        """
        # Cosine similarity; epsilon guards against zero-norm rows.
        k_norm = k / (k.norm(dim=1, keepdim=True) + 1e-8)
        mem_norm = memory / (memory.norm(dim=2, keepdim=True) + 1e-8)
        similarity = torch.matmul(mem_norm, k_norm.unsqueeze(-1)).squeeze(-1)
        
        # Temperature-scaled softmax: larger beta -> sharper focus.
        w_c = F.softmax(beta * similarity, dim=1)
        return w_c
    
    def location_addressing(self, w_c, w_prev, g, s, gamma):
        """
        Location-based addressing (interpolation + shift + sharpen).

        Args:
            w_c: Content weights [batch, memory_size]
            w_prev: Previous weights [batch, memory_size]
            g: Interpolation gate logit [batch, 1] (sigmoid applied here)
            s: Shift distribution [batch, 3]
            gamma: Sharpening logit [batch, 1] (softplus + 1 applied here)
        Returns:
            Final weights [batch, memory_size]
        """
        # Interpolate between content weights and the previous focus.
        g = torch.sigmoid(g)
        w_g = g * w_c + (1 - g) * w_prev
        
        # Circular convolutional shift.
        w_shifted = self._circular_conv(w_g, s)
        
        # Sharpening (gamma >= 1) then renormalize.
        gamma = 1 + F.softplus(gamma)
        w = w_shifted ** gamma
        w = w / (w.sum(dim=1, keepdim=True) + 1e-8)
        
        return w
    
    def _circular_conv(self, w, s):
        """
        Circular convolution of weights w with the 3-element shift kernel s:

            shifted[i] = s[0]*w[i-1] + s[1]*w[i] + s[2]*w[i+1]   (mod N)

        Vectorized with torch.roll — replaces the original triple Python
        loop over (batch, memory_size, 3), which was O(B*N) interpreter
        overhead per step and dominated runtime for large memories.
        """
        return (s[:, 0:1] * torch.roll(w, shifts=1, dims=1)
                + s[:, 1:2] * w
                + s[:, 2:3] * torch.roll(w, shifts=-1, dims=1))
    
    def forward(self, h, memory, w_prev):
        """
        Compute attention weights for this head.

        Args:
            h: Controller output [batch, controller_dim]
            memory: Memory matrix [batch, memory_size, memory_dim]
            w_prev: Previous weights [batch, memory_size]
        Returns:
            w: Attention weights [batch, memory_size]
        """
        # Content addressing
        k = torch.tanh(self.key_layer(h))
        beta = F.softplus(self.beta_layer(h))
        w_c = self.content_addressing(k, beta, memory)
        
        # Location addressing
        g = self.g_layer(h)
        s = F.softmax(self.shift_layer(h), dim=1)
        gamma = self.gamma_layer(h)
        w = self.location_addressing(w_c, w_prev, g, s, gamma)
        
        return w


class NeuralTuringMachine(nn.Module):
    """
    Neural Turing Machine (Graves et al., 2014).
    
    Controller (LSTM) + External memory + Read/Write heads.
    
    Note: the external memory lives in self.memory and is mutated across
    calls; one module instance therefore tracks exactly one in-flight batch
    of sequences. Start a new sequence by passing prev_state=None, which
    also zeroes the memory.
    
    Args:
        input_dim: Input dimension
        output_dim: Output dimension
        controller_dim: LSTM hidden dimension
        memory_size: Number of memory slots
        memory_dim: Dimension of each slot
        num_read_heads: Number of read heads
        num_write_heads: Number of write heads
    """
    def __init__(self, input_dim, output_dim, controller_dim=100,
                 memory_size=128, memory_dim=20, num_read_heads=1, num_write_heads=1):
        super(NeuralTuringMachine, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.controller_dim = controller_dim
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.num_read_heads = num_read_heads
        self.num_write_heads = num_write_heads
        
        # Controller (LSTM): sees the input concatenated with all of the
        # previous step's read vectors.
        controller_input_dim = input_dim + num_read_heads * memory_dim
        self.controller = nn.LSTMCell(controller_input_dim, controller_dim)
        
        # Memory
        self.memory = NTMMemory(memory_size, memory_dim)
        
        # Read heads
        self.read_heads = nn.ModuleList([
            NTMHead(memory_size, memory_dim, controller_dim)
            for _ in range(num_read_heads)
        ])
        
        # Write heads (same head class; erase/add layers are used here)
        self.write_heads = nn.ModuleList([
            NTMHead(memory_size, memory_dim, controller_dim)
            for _ in range(num_write_heads)
        ])
        
        # Output layer: controller state plus the current step's reads.
        output_input_dim = controller_dim + num_read_heads * memory_dim
        self.output_layer = nn.Linear(output_input_dim, output_dim)
    
    def forward(self, x, prev_state=None):
        """
        Single step forward.
        
        Args:
            x: Input [batch, input_dim]
            prev_state: Previous state (h_c, c_c, read_vecs, read_w, write_w),
                or None to begin a fresh sequence (resets the memory).
        Returns:
            output: Output [batch, output_dim]
            state: New state tuple, to be passed back in at the next step
        """
        batch_size = x.size(0)
        
        if prev_state is None:
            # Fresh sequence: zero controller state, reads, and attention
            # weights, and wipe the external memory for this batch size.
            h_c = torch.zeros(batch_size, self.controller_dim, device=x.device)
            c_c = torch.zeros(batch_size, self.controller_dim, device=x.device)
            read_vecs = [torch.zeros(batch_size, self.memory_dim, device=x.device)
                        for _ in range(self.num_read_heads)]
            read_w = [torch.zeros(batch_size, self.memory_size, device=x.device)
                     for _ in range(self.num_read_heads)]
            write_w = [torch.zeros(batch_size, self.memory_size, device=x.device)
                      for _ in range(self.num_write_heads)]
            self.memory.reset(batch_size)
        else:
            h_c, c_c, read_vecs, read_w, write_w = prev_state
        
        # Controller input: current x + the previous step's read vectors.
        controller_input = torch.cat([x] + read_vecs, dim=1)
        
        # Controller forward
        h_c, c_c = self.controller(controller_input, (h_c, c_c))
        
        # Read from memory. Reads happen before this step's writes, so they
        # see the memory as left by the previous time step.
        new_read_vecs = []
        new_read_w = []
        for i, head in enumerate(self.read_heads):
            w = head(h_c, self.memory.memory, read_w[i])
            r = self.memory.read(w)
            new_read_vecs.append(r)
            new_read_w.append(w)
        
        # Write to memory (erase then add, per head, mutating self.memory).
        new_write_w = []
        for i, head in enumerate(self.write_heads):
            w = head(h_c, self.memory.memory, write_w[i])
            e = torch.sigmoid(head.erase_layer(h_c))  # erase gate in [0, 1]
            a = torch.tanh(head.add_layer(h_c))       # add content in [-1, 1]
            self.memory.write(w, e, a)
            new_write_w.append(w)
        
        # Output from controller state + this step's fresh reads.
        output_input = torch.cat([h_c] + new_read_vecs, dim=1)
        output = self.output_layer(output_input)
        
        state = (h_c, c_c, new_read_vecs, new_read_w, new_write_w)
        
        return output, state


# ============================================================================
# 3. Key-Value Memory Network
# ============================================================================

class KeyValueMemoryNetwork(nn.Module):
    """
    Key-Value Memory Network (Miller et al., 2016).

    Separates keys (used for addressing) from values (the retrieved
    content), decoupling "how to find" from "what to return".

    Args:
        key_dim: Dimension of keys
        value_dim: Dimension of values
        query_dim: Dimension of queries
        num_hops: Number of hops (must be >= 1)

    Raises:
        ValueError: If num_hops < 1.
    """
    def __init__(self, key_dim, value_dim, query_dim, num_hops=1):
        super(KeyValueMemoryNetwork, self).__init__()
        if num_hops < 1:
            # FIX: with zero hops the forward loop never ran and `attention`
            # was referenced before assignment (NameError). Fail fast instead.
            raise ValueError(f"num_hops must be >= 1, got {num_hops}")
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.query_dim = query_dim
        self.num_hops = num_hops
        
        # Project queries into key space for addressing.
        self.W_K = nn.Linear(query_dim, key_dim)
        
        # Project retrieved values back into query space.
        self.W_V = nn.Linear(value_dim, query_dim)
        
        # Final output transformation.
        self.W_O = nn.Linear(query_dim, query_dim)
    
    def forward(self, query, keys, values):
        """
        Args:
            query: Query vector [batch, query_dim]
            keys: Key matrix [batch, num_keys, key_dim]
            values: Value matrix [batch, num_keys, value_dim]
        Returns:
            output: Retrieved information [batch, query_dim]
            attention: Attention weights from the last hop [batch, num_keys]
        """
        q = query
        
        for _ in range(self.num_hops):
            # Transform query to key space for addressing.
            q_k = self.W_K(q)  # [batch, key_dim]
            
            # Attention over keys.
            scores = torch.matmul(keys, q_k.unsqueeze(-1)).squeeze(-1)  # [batch, num_keys]
            attention = F.softmax(scores, dim=-1)
            
            # Retrieve values and map them into query space.
            v = torch.matmul(attention.unsqueeze(1), values).squeeze(1)  # [batch, value_dim]
            v = self.W_V(v)  # [batch, query_dim]
            
            # Residual query update.
            q = q + v
        
        output = self.W_O(q)
        
        return output, attention


# ============================================================================
# 4. Matching Networks (for Few-Shot Learning)
# ============================================================================

class MatchingNetwork(nn.Module):
    """
    Matching Network (Vinyals et al., 2016) for few-shot classification.

    The support set plays the role of an external memory: a query is
    classified via cosine-similarity attention over encoded support
    examples, yielding an attention-weighted mixture of their one-hot
    labels.

    Args:
        input_dim: Input feature dimension
        hidden_dim: Hidden dimension
        num_classes: Number of classes (for output)
    """
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(MatchingNetwork, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        
        # Shared feature encoder for queries and support examples.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
    
    def attention(self, query, support):
        """
        Cosine-similarity attention of the query over support embeddings.

        Args:
            query: Query embedding [batch, hidden_dim]
            support: Support embeddings [batch, num_support, hidden_dim]
        Returns:
            Attention weights [batch, num_support]
        """
        eps = 1e-8  # guards against zero-norm embeddings
        q_hat = query / (query.norm(dim=1, keepdim=True) + eps)
        s_hat = support / (support.norm(dim=2, keepdim=True) + eps)
        
        sims = torch.bmm(s_hat, q_hat.unsqueeze(2)).squeeze(2)
        return F.softmax(sims, dim=-1)
    
    def forward(self, query, support, support_labels):
        """
        Args:
            query: Query examples [batch, input_dim]
            support: Support examples [batch, num_support, input_dim]
            support_labels: Support labels [batch, num_support] (class indices)
        Returns:
            predictions: Class probabilities [batch, num_classes]
        """
        # Encode queries and (flattened) support examples with one encoder.
        q_emb = self.encoder(query)  # [batch, hidden_dim]
        s_emb = self.encoder(support.view(-1, self.input_dim))
        s_emb = s_emb.view(support.size(0), support.size(1), -1)
        
        # Attend over the support set.
        weights = self.attention(q_emb, s_emb)  # [batch, num_support]
        
        # Mix the one-hot support labels by the attention weights.
        labels_1h = F.one_hot(support_labels, self.num_classes).float()
        return torch.bmm(weights.unsqueeze(1), labels_1h).squeeze(1)


# ============================================================================
# 5. Demonstrations
# ============================================================================

def demo_memn2n():
    """Demonstrate End-to-End Memory Network on toy bAbI-style data."""
    rule = "=" * 70
    print(rule)
    print("End-to-End Memory Network (MemN2N) Demo")
    print(rule)
    
    # Hyperparameters
    cfg = dict(vocab_size=100, embed_dim=32, num_hops=3,
               sentence_len=10, memory_size=20)
    batch_size = 4
    
    # Create model
    net = EndToEndMemoryNetwork(weight_tying='adjacent', **cfg)
    
    # Toy bAbI-like data: random token ids.
    story = torch.randint(0, cfg['vocab_size'],
                          (batch_size, cfg['memory_size'], cfg['sentence_len']))
    query = torch.randint(0, cfg['vocab_size'], (batch_size, cfg['sentence_len']))
    
    # Forward pass (no gradients needed for a demo)
    net.eval()
    with torch.no_grad():
        logits, attentions = net(story, query)
    
    print("Input:")
    print(f"  Story shape: {story.shape} (batch, num_sentences, sentence_len)")
    print(f"  Query shape: {query.shape} (batch, query_len)")
    print()
    print("Output:")
    print(f"  Logits shape: {logits.shape} (batch, vocab_size)")
    print(f"  Attentions shape: {attentions.shape} (batch, num_hops, num_sentences)")
    print()
    print("Attention visualization (first sample, first hop):")
    print(f"  {attentions[0, 0, :10].numpy()}")
    print(f"  Sum: {attentions[0, 0].sum().item():.4f} (should be ~1.0)")
    print()
    
    # Count parameters
    total = sum(p.numel() for p in net.parameters())
    print(f"Model parameters: {total:,}")
    print()


def demo_ntm():
    """Demonstrate a Neural Turing Machine on a toy copy task."""
    rule = "=" * 70
    print(rule)
    print("Neural Turing Machine (NTM) Demo")
    print(rule)
    
    # Copy-task hyperparameters
    input_dim, output_dim = 8, 8
    controller_dim = 100
    memory_size, memory_dim = 128, 20
    batch_size, seq_len = 2, 5
    
    # Create model
    ntm = NeuralTuringMachine(
        input_dim=input_dim,
        output_dim=output_dim,
        controller_dim=controller_dim,
        memory_size=memory_size,
        memory_dim=memory_dim,
        num_read_heads=1,
        num_write_heads=1,
    )
    
    # Random binary sequences stand in for the copy-task input.
    x = torch.randint(0, 2, (batch_size, seq_len, input_dim)).float()
    
    # Step through the sequence, threading the recurrent state.
    ntm.eval()
    state = None
    step_outputs = []
    with torch.no_grad():
        for t in range(seq_len):
            y, state = ntm(x[:, t, :], state)
            step_outputs.append(y)
    outputs = torch.stack(step_outputs, dim=1)  # [batch, seq_len, output_dim]
    
    print("Task: Copy binary sequences")
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {outputs.shape}")
    print()
    print("Architecture:")
    print(f"  Controller: LSTM with {controller_dim} units")
    print(f"  Memory: {memory_size} slots Γ— {memory_dim} dimensions")
    print("  Read heads: 1, Write heads: 1")
    print()
    
    # Count parameters
    total = sum(p.numel() for p in ntm.parameters())
    print(f"Model parameters: {total:,}")
    print()
    print("Note: NTM can learn to copy sequences of arbitrary length!")
    print()


def demo_kv_memory():
    """Demonstrate Key-Value Memory Network retrieval."""
    rule = "=" * 70
    print(rule)
    print("Key-Value Memory Network Demo")
    print(rule)
    
    # Hyperparameters
    key_dim, value_dim, query_dim = 64, 128, 64
    num_hops = 2
    batch_size, num_keys = 4, 10
    
    # Create model
    kv_net = KeyValueMemoryNetwork(
        key_dim=key_dim,
        value_dim=value_dim,
        query_dim=query_dim,
        num_hops=num_hops,
    )
    
    # Random queries, keys, and values.
    query = torch.randn(batch_size, query_dim)
    keys = torch.randn(batch_size, num_keys, key_dim)
    values = torch.randn(batch_size, num_keys, value_dim)
    
    # Forward pass
    kv_net.eval()
    with torch.no_grad():
        output, attention = kv_net(query, keys, values)
    
    print("Input:")
    print(f"  Query shape: {query.shape}")
    print(f"  Keys shape: {keys.shape}")
    print(f"  Values shape: {values.shape}")
    print()
    print("Output:")
    print(f"  Retrieved shape: {output.shape}")
    print(f"  Attention shape: {attention.shape}")
    print()
    print("Attention weights (first sample):")
    print(f"  {attention[0].numpy()}")
    print(f"  Max attention: {attention[0].max().item():.4f}")
    print()
    print("Use case: Knowledge base QA (keys=entities, values=descriptions)")
    print()


def demo_matching_network():
    """Demonstrate a Matching Network on a toy 5-way 1-shot episode."""
    rule = "=" * 70
    print(rule)
    print("Matching Network Demo (Few-Shot Learning)")
    print(rule)
    
    # 5-way 1-shot episode setup
    input_dim, hidden_dim = 128, 64
    num_classes = 5
    num_support = 5  # one support example per class
    batch_size = 8
    
    # Create model
    matcher = MatchingNetwork(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_classes=num_classes,
    )
    
    # Random features stand in for CNN embeddings.
    query = torch.randn(batch_size, input_dim)
    support = torch.randn(batch_size, num_support, input_dim)
    support_labels = torch.arange(num_classes).unsqueeze(0).expand(batch_size, -1)
    
    # Forward pass
    matcher.eval()
    with torch.no_grad():
        predictions = matcher(query, support, support_labels)
    
    print("Task: 5-way 1-shot classification")
    print(f"  Query shape: {query.shape}")
    print(f"  Support shape: {support.shape}")
    print(f"  Support labels: {support_labels[0].tolist()}")
    print()
    print(f"Predictions shape: {predictions.shape} (batch, num_classes)")
    print(f"Sample prediction (first query): {predictions[0].numpy()}")
    print(f"Predicted class: {predictions[0].argmax().item()}")
    print()
    print("Memory = Support Set | Retrieval = Attention-weighted classification")
    print()


def print_complexity_comparison():
    """Print a fixed comparison table of memory-architecture costs."""
    header = "=" * 70
    print(header)
    print("Computational Complexity Comparison")
    print(header)
    print()
    
    # Static summary table; values mirror the discussion above.
    table = """
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Architecture   β”‚ Time (per step)  β”‚ Space (memory)   β”‚ Scalability     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ MemN2N         β”‚ O(KΒ·NΒ·d)         β”‚ O(NΒ·d)           β”‚ Medium          β”‚
β”‚                β”‚ K hops, N slots  β”‚                  β”‚ (hundreds)      β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ NTM            β”‚ O(NΒ·M + CΒ²)      β”‚ O(NΒ·M)           β”‚ Low             β”‚
β”‚                β”‚ N slots, M dims  β”‚                  β”‚ (128 slots)     β”‚
β”‚                β”‚ C controller     β”‚                  β”‚                 β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ DNC            β”‚ O(NΒ·M + CΒ²)      β”‚ O(NΒ·M + NΒ²)      β”‚ Medium          β”‚
β”‚                β”‚ + temporal links β”‚ (link matrix)    β”‚ (256 slots)     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Key-Value      β”‚ O(NΒ·K + NΒ·V)     β”‚ O(NΒ·(K+V))       β”‚ High            β”‚
β”‚                β”‚ K,V key/val dims β”‚                  β”‚ (thousands)     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Transformer    β”‚ O(NΒ²Β·d)          β”‚ O(NΒ·d)           β”‚ High            β”‚
β”‚                β”‚ Self-attention   β”‚                  β”‚ (with tricks)   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

**Scalability to Large Memory:**

1. **MemN2N:** O(N) attention feasible up to ~1000 slots
2. **NTM/DNC:** Complex addressing limits to ~256 slots
3. **Key-Value:** Efficient retrieval, scales to 10K+ slots
4. **Sparse Retrieval (kNN):** Scales to millions (FAISS, ScaNN)
"""
    
    print(table)
    print()


def print_method_selection_guide():
    """Display a decision guide for choosing a memory-network architecture.

    Prints a banner, a task-based decision tree, and reported benchmark
    numbers for the main memory-network families. Output only; returns None.
    """
    banner = "=" * 70
    print(banner)
    print("Memory Network Selection Guide")
    print(banner)
    print()

    guide = """
**Decision Tree:**

1. **Task: Question Answering over text**
   β†’ Use MemN2N
   - Multi-hop reasoning over facts
   - Interpretable attention
   - Example: bAbI tasks

2. **Task: Algorithm learning (copy, sort, etc.)**
   β†’ Use NTM or DNC
   - Explicit read/write operations
   - Location-based addressing
   - Example: Copy sequences, algorithmic tasks

3. **Task: Few-shot learning**
   β†’ Use Matching Networks or Prototypical Networks
   - Memory = support examples
   - Attention-based retrieval
   - Example: Omniglot, miniImageNet

4. **Task: Knowledge base QA**
   β†’ Use Key-Value Memory
   - Keys for retrieval, values for content
   - Efficient for large KBs
   - Example: WikiMovies, entity QA

5. **Task: Long-range dependencies in sequences**
   β†’ Use Memory-Augmented Transformer
   - Combine self-attention + external memory
   - Example: Long documents, code

6. **Task: Retrieval-augmented generation**
   β†’ Use RAG or RETRO
   - External corpus as memory
   - Dense retrieval + generation
   - Example: Open-domain QA

**Performance Expectations:**

bAbI Tasks (Mean accuracy):
- MemN2N: 96.4% (state-of-the-art for memory networks)
- LSTM: 51.2%
- Human: ~100%

Algorithm Learning (Copy task):
- NTM: 100% (perfect generalization)
- LSTM: ~60% (fails on longer sequences)

Few-Shot (Omniglot 5-way 1-shot):
- Matching Networks: 98.1%
- Prototypical Networks: 98.8%
"""

    print(guide)
    print()


# ============================================================================
# Run All Demonstrations
# ============================================================================

if __name__ == "__main__":
    # Fixed seed so every demo run is reproducible.
    torch.manual_seed(42)

    # Execute each demonstration / report in order.
    for run_demo in (
        demo_memn2n,
        demo_ntm,
        demo_kv_memory,
        demo_matching_network,
        print_complexity_comparison,
        print_method_selection_guide,
    ):
        run_demo()

    rule = "=" * 70
    print(rule)
    print("Memory Network Implementations Complete")
    print(rule)
    print()
    print("Summary:")
    for summary_line in (
        "  β€’ MemN2N: Multi-hop reasoning via attention (QA tasks)",
        "  β€’ NTM: Learn algorithms via read/write operations",
        "  β€’ Key-Value: Efficient retrieval from knowledge bases",
        "  β€’ Matching: Few-shot learning via memory = support set",
    ):
        print(summary_line)
    print()
    print("Core principle: Explicit external memory > implicit parameters")
    print("Key insight: Attention = differentiable content-addressable memory")
    print()
Advanced Memory Networks: Mathematical Foundations and Modern ArchitecturesΒΆ

1. Introduction to Memory-Augmented Neural NetworksΒΆ

Memory Networks extend neural networks with explicit external memory, enabling them to store and retrieve information over long time horizons, unlike LSTMs which compress history into fixed-size hidden states.

Core motivation: Standard RNNs/LSTMs have limited capacity to remember facts and reasoning chains. Memory networks separate:

  1. Computation: Neural network processing

  2. Storage: External memory bank

Key innovation (Weston et al., 2015): Memory as an array of vectors $\(\mathcal{M} = [m_1, m_2, \ldots, m_N] \in \mathbb{R}^{N \times d}\)$

Access mechanism: Soft attention over memory slots $\(\mathbf{o} = \sum_{i=1}^N \alpha_i m_i, \quad \alpha_i = \frac{\exp(s(q, m_i))}{\sum_j \exp(s(q, m_j))}\)$

where:

  • \(q\): Query vector (e.g., question embedding)

  • \(s(q, m_i)\): Similarity function (dot product, cosine, MLP)

  • \(\mathbf{o}\): Output (weighted memory)

Advantages:

  • Scalability: Memory size \(N\) independent of network parameters

  • Interpretability: Can inspect which memories are accessed

  • Explicit reasoning: Multi-hop attention for complex reasoning

  • Long-term storage: Information persists beyond gradient flow

2. Memory Network ArchitectureΒΆ

2.1 Four Components (I, G, O, R)ΒΆ

I (Input): Convert input \(x\) to internal representation $\(\mathbf{i} = I(x) = \text{Embedding}(x)\)$

G (Generalization): Update memory given new input $\(m_i \leftarrow G(m_i, \mathbf{i}, m) = \begin{cases} \mathbf{i} & \text{if slot } i \text{ selected} \\ m_i & \text{otherwise} \end{cases}\)$

O (Output): Retrieve relevant memories given query \(q\) $\(\mathbf{o} = O(q, m) = \sum_{i=1}^N \text{softmax}(s(q, m_i)) \cdot m_i\)$

R (Response): Generate final output $\(\text{answer} = R(\mathbf{o}, q) = \text{Decoder}(\mathbf{o}, q)\)$

2.2 End-to-End Memory Networks (MemN2N)ΒΆ

Motivation: Make memory networks fully differentiable (remove hard max in original).

Multiple hops: Perform \(K\) reasoning steps $\(q^{(k+1)} = q^{(k)} + \mathbf{o}^{(k)}, \quad k = 0, 1, \ldots, K-1\)$

Hop \(k\) attention: \(p_i^{(k)} = \text{softmax}\big((q^{(k)})^T m_i^{(k)}\big)\), followed by the weighted read \(\mathbf{o}^{(k)} = \sum_i p_i^{(k)} c_i^{(k)}\)

where:

  • \(m_i^{(k)}\): Key memory at hop \(k\)

  • \(c_i^{(k)}\): Value memory at hop \(k\)

  • \(q^{(k)}\): Query at hop \(k\)

Final answer: \(\hat{a} = \text{softmax}(W q^{(K)})\) — note that \(q^{(K)} = q^{(K-1)} + \mathbf{o}^{(K-1)}\) already incorporates the final hop's read, so the last output is not added a second time.

Embedding schemes:

  1. Adjacent: \(A^{(k+1)} = C^{(k)}\) (parameter sharing across hops)

  2. Layer-wise: Each hop has independent embeddings

  3. Recurrent: \(A^{(k)} = A^{(1)}\), \(C^{(k)} = C^{(1)}\) (tied)

2.3 Position EncodingΒΆ

Motivation: Memory slots are unordered, but temporal/positional info matters.

Temporal encoding: \(m_i = \sum_j l_j \odot w_j\) where \(l_j\) is position encoding

Sinusoidal (from Transformers): $\(l_j^{(d)} = \begin{cases} \sin(j / 10000^{d/D}) & \text{if } d \text{ even} \\ \cos(j / 10000^{d/D}) & \text{if } d \text{ odd} \end{cases}\)$

Learned: Position embeddings as parameters.

3. Attention Mechanisms in MemoryΒΆ

3.1 Content-Based AddressingΒΆ

Similarity function: Dot product (scaled) $\(s(q, m_i) = \frac{q^T m_i}{\sqrt{d}}\)$

Softmax attention: $\(\alpha_i = \frac{\exp(s(q, m_i))}{\sum_j \exp(s(q, m_j))}\)$

Variants:

  • Cosine similarity: \(s(q, m_i) = \frac{q^T m_i}{\|q\| \|m_i\|}\)

  • Euclidean distance: \(s(q, m_i) = -\|q - m_i\|^2\)

  • MLP: \(s(q, m_i) = v^T \tanh(W[q; m_i])\)

3.2 Multi-Head AttentionΒΆ

Parallel attention heads: Each head attends to different aspects $\(\text{Head}_h = \text{Attention}(Q W_h^Q, K W_h^K, V W_h^V)\)$

Concatenate and project: $\(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{Head}_1, \ldots, \text{Head}_H) W^O\)$

Benefits:

  • Attend to multiple memories simultaneously

  • Richer representations

  • Better for complex reasoning

3.3 Sparse AttentionΒΆ

Problem: Dense attention over all \(N\) memories is \(O(N)\).

Top-k attention: Select \(k \ll N\) highest-scoring memories $\(\alpha_i = \begin{cases} \frac{\exp(s(q, m_i))}{\sum_{j \in \text{top-k}} \exp(s(q, m_j))} & \text{if } i \in \text{top-k} \\ 0 & \text{otherwise} \end{cases}\)$

Locality-sensitive hashing (LSH): Approximate nearest neighbors

  • Hash query and keys to buckets

  • Attend only within same bucket

  • \(O(\log N)\) or \(O(\sqrt{N})\) complexity

4. Dynamic Memory Networks (DMN)ΒΆ

Architecture (Kumar et al., 2016): Episodic memory with attention-based reasoning.

4.1 Input ModuleΒΆ

Sentence encoding: BiGRU or Transformer over sentences $\(h_i = \text{BiGRU}(w_{i,1}, \ldots, w_{i,T_i})\)$

Input memory: \(\{h_1, \ldots, h_N\}\) (one per sentence/fact).

4.2 Question ModuleΒΆ

Question encoding: $\(q = \text{GRU}(w_{q,1}, \ldots, w_{q,T_q})\)$

4.3 Episodic Memory ModuleΒΆ

Attention mechanism: Compute relevance of each fact to question $\(g_i^t = \text{Attention}(h_i, q, m^{t-1})\)$

Attention gate: Context vector for episode \(t\) $\(c^t = \sum_i g_i^t h_i\)$

Memory update: GRU over context $\(m^t = \text{GRU}(c^t, m^{t-1})\)$

Multiple episodes: Iterate \(T\) times (multi-hop reasoning).

4.4 Answer ModuleΒΆ

Decoder: GRU initialized with final memory $\(a = \text{Decoder}(m^T, q)\)$

Output: Word-by-word generation or classification.

4.5 Attention Gate DetailsΒΆ

Feature vector: $\(z_i^t = [h_i; q; m^{t-1}; h_i \odot q; h_i \odot m^{t-1}; |h_i - q|; |h_i - m^{t-1}|]\)$

Two-layer network: $\(g_i^t = \text{sigmoid}(W^{(2)} \tanh(W^{(1)} z_i^t + b^{(1)}) + b^{(2)})\)$

Softmax normalization (optional): $\(g_i^t = \frac{\exp(g_i^t)}{\sum_j \exp(g_j^t)}\)$

5. Neural Turing Machines (NTM)ΒΆ

Graves et al. (2014): Differentiable Turing Machine with read/write heads.

5.1 ArchitectureΒΆ

Controller: LSTM or feedforward network $\(h_t = \text{Controller}(x_t, h_{t-1})\)$

Memory: \(M_t \in \mathbb{R}^{N \times M}\) (rows = addresses, columns = content)

Read head: Weighted read from memory $\(\mathbf{r}_t = \sum_{i=1}^N w_t^r(i) M_t(i)\)$

Write head: Erase + add $\(M_t(i) = M_{t-1}(i) \cdot (\mathbf{1} - w_t^w(i) \mathbf{e}_t) + w_t^w(i) \mathbf{a}_t\)$

where:

  • \(w_t^w\): Write attention weights

  • \(\mathbf{e}_t\): Erase vector (what to remove)

  • \(\mathbf{a}_t\): Add vector (what to write)

5.2 Addressing MechanismsΒΆ

Content-based addressing: Similarity to key \(\mathbf{k}_t\) $\(w_t^c(i) \propto \exp\left(\beta_t \cdot \text{cosine}(\mathbf{k}_t, M_t(i))\right)\)$

Interpolation: Blend content and previous weights $\(w_t^g = g_t w_t^c + (1 - g_t) w_{t-1}\)$

where \(g_t \in [0, 1]\) is interpolation gate.

Shift: Convolutional shift (move attention) $\(\tilde{w}_t(i) = \sum_j w_t^g(j) s_t(i - j)\)$

where \(s_t\) is shift distribution (e.g., shift left/right).

Sharpening: Focus attention $\(w_t(i) = \frac{\tilde{w}_t(i)^{\gamma_t}}{\sum_j \tilde{w}_t(j)^{\gamma_t}}\)$

where \(\gamma_t \geq 1\) is sharpening factor.

5.3 TrainingΒΆ

Backpropagation through time: Gradients flow through memory and attention.

Challenges:

  • Vanishing/exploding gradients (long sequences)

  • Sensitive to initialization

  • Slow convergence

Solutions:

  • Gradient clipping

  • Curriculum learning (start with short sequences)

  • Pre-training

6. Differentiable Neural Computer (DNC)ΒΆ

DeepMind (2016): Evolution of NTM with more sophisticated memory management.

6.1 Enhancements Over NTMΒΆ

Temporal links: Track write order $\(L_t[i, j] = \text{link from slot } i \text{ to } j\)$

Usage vector: Track which slots are occupied $\(u_t(i) = (u_{t-1}(i) + w_t^w(i) - u_{t-1}(i) \odot w_t^w(i)) \odot (1 - f_t(i))\)$

where \(f_t\) is free gate (deallocate).

Allocation: Write to least-used slots $\(a_t = \text{sort}(u_t)\)$

6.2 Attention MechanismsΒΆ

Three read modes:

  1. Forward: Follow temporal links forward $\(w_t^{f,i} = \sum_j L_t[i, j] w_{t-1}^r\)$

  2. Backward: Follow temporal links backward $\(w_t^{b,i} = \sum_j L_t[j, i] w_{t-1}^r\)$

  3. Content: Content-based (like NTM)

Read weights: Interpolate modes $\(w_t^r = \pi_t^{(1)} w_t^{b} + \pi_t^{(2)} w_t^{c} + \pi_t^{(3)} w_t^{f}\)$

where \(\pi_t\) are learned mode weights (\(\sum \pi_t^{(i)} = 1\)).

6.3 ApplicationsΒΆ

bAbI tasks: 20/20 tasks solved (vs 14/20 for MemN2N)

Graph reasoning: Shortest path, connectivity

Algorithmic tasks:

  • Copy, repeat copy

  • Priority sort

  • Mini-SHRDLU (blocks world)

Results: Perfect generalization on trained tasks, some transfer.

7. Key-Value Memory NetworksΒΆ

Miller et al. (2016): Separate keys (for retrieval) and values (for content).

Memory: Pairs \((k_i, v_i)\) where \(k_i, v_i \in \mathbb{R}^d\)

Attention over keys: $\(\alpha_i = \text{softmax}(q^T k_i)\)$

Retrieve values: $\(\mathbf{o} = \sum_i \alpha_i v_i\)$

Advantages:

  • Keys can be hashed/compressed (fast retrieval)

  • Values can be rich (full information)

  • Decouples addressing from content

Example: Question answering

  • Keys: Sentence embeddings (for retrieval)

  • Values: Full sentence representations (for reasoning)

8. Transformer as Memory NetworkΒΆ

Connection: Transformers are memory networks with:

  • Memory: Sequence of tokens (key-value pairs)

  • Addressing: Self-attention (content-based)

  • Updates: Feedforward layers

Self-attention: $\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\)$

is equivalent to memory read with:

  • Queries: \(Q = X W^Q\)

  • Keys: \(K = X W^K\)

  • Values: \(V = X W^V\)

Multi-head: Parallel memory systems (each head = different memory perspective).

Feedforward: Per-position processing (update memory slot).

Advantages over traditional memory networks:

  • Fully parallelizable (no sequential memory updates)

  • State-of-the-art on many tasks

  • Scalable to long sequences (with efficient attention)

Limitations:

  • Quadratic complexity \(O(n^2)\) in sequence length

  • No explicit long-term memory (limited by context window)

9. External Memory ExtensionsΒΆ

9.1 Compressive TransformersΒΆ

Motivation: Store compressed past memories beyond context window.

Architecture:

  1. Active memory: Recent \(n_m\) tokens (regular attention)

  2. Compressed memory: Older tokens compressed to \(n_c\) slots

  3. Compression: Max-pooling, convolution, or learned network

Attention: Attend to both active and compressed memory.

Results: Better long-range dependencies (PG19 dataset).

9.2 Memory Transformer (Memorizing Transformer)ΒΆ

Approach: Retrieve from external key-value store.

Memory: Large database of \((k, v)\) pairs from past tokens

Retrieval: kNN search for top-k keys closest to query $\(\text{kNN}(q) = \{(k_i, v_i) : i \in \text{top-k by } \|q - k_i\|\}\)$

Integrate: Add retrieved values to Transformer layers

Benefits:

  • Memory scales independently (gigabytes)

  • Constant-time lookup (approximate NN)

  • Improves perplexity on long documents

9.3 Neural CacheΒΆ

Simple approach: Cache recent hidden states as memory.

Retrieval: $\(p(w_t | h_t) = \lambda p_{\text{model}}(w_t | h_t) + (1-\lambda) p_{\text{cache}}(w_t | h_t)\)$

where \(p_{\text{cache}}\) is based on similarity to cached states.

Effective: Improves language modeling perplexity with minimal overhead.

10. Training Memory NetworksΒΆ

10.1 Supervision StrategiesΒΆ

Weak supervision: Only final answer labeled

  • Train end-to-end with backpropagation

  • Attention learned implicitly

Strong supervision: Attention labels provided

  • Supervise which memory slots to attend to

  • Faster convergence, better performance

Example (bAbI): Provide supporting facts for each question.

10.2 Curriculum LearningΒΆ

Strategy: Train on easier tasks first, gradually increase difficulty.

For memory networks:

  1. Start with 1-hop reasoning

  2. Increase to 2-hop, 3-hop

  3. Increase memory size \(N\)

  4. Increase sequence length

Benefits: Better convergence, avoids local minima.

10.3 Loss FunctionsΒΆ

Classification: Cross-entropy $\(\mathcal{L} = -\sum_c y_c \log \hat{y}_c\)$

Reinforcement learning: REINFORCE for hard attention $\(\mathcal{L} = -\mathbb{E}_{\alpha \sim p_\theta}[R(\alpha) \log p_\theta(\alpha)]\)$

where \(R(\alpha)\) is reward (e.g., answer correctness).

Margin loss: For ranking (e.g., correct vs incorrect memories) $\(\mathcal{L} = \max(0, \gamma - s(q, m_+) + s(q, m_-))\)$

11. ApplicationsΒΆ

11.1 Question AnsweringΒΆ

bAbI tasks (Facebook AI):

  • 20 tasks testing different reasoning skills

  • Memory networks: 20/20 tasks solved

  • Baselines (LSTMs): 7/20 tasks

SQuAD (reading comprehension):

  • Key-value memory networks competitive with BiDAF

  • Attention visualizations interpretable

11.2 Dialog SystemsΒΆ

Persona-based dialog: Remember user preferences

  • Memory: Previous conversation turns

  • Retrieval: Relevant past context

  • Generation: Consistent responses

Task-oriented dialog:

  • Memory: Slot-value pairs (e.g., restaurant attributes)

  • Reasoning: Multi-hop to find matching entities

11.3 Reasoning TasksΒΆ

Visual question answering (VQA):

  • Memory: Image regions (extracted by object detector)

  • Query: Question embedding

  • Multi-hop attention: Compositional reasoning

Text entailment:

  • Memory: Premise sentences

  • Query: Hypothesis

  • Output: Entailment/contradiction/neutral

11.4 Program SynthesisΒΆ

Neural Programmer-Interpreter (NPI):

  • Memory: Program states

  • Retrieval: Relevant sub-programs

  • Hierarchical execution

Differentiable Forth:

  • Memory: Stack and heap

  • Operations: Differentiable push/pop/arithmetic

12. Theoretical AnalysisΒΆ

12.1 ExpressivenessΒΆ

Theorem: Memory networks with sufficient capacity are Turing-complete.

Proof sketch:

  • Memory = Turing machine tape

  • Attention = read/write head

  • Controller = state transitions

Caveat: In practice, limited by:

  • Finite memory size \(N\)

  • Approximate attention (softmax vs hard)

  • Training difficulties

12.2 Sample ComplexityΒΆ

Hypothesis: Explicit memory reduces sample complexity for tasks requiring long-term dependencies.

Evidence:

  • Memory networks learn bAbI with fewer examples than LSTMs

  • Copy tasks: NTM sample-efficient

Theory gap: Formal bounds lacking.

12.3 Attention EntropyΒΆ

Measure: Entropy of attention distribution $\(H(\alpha) = -\sum_i \alpha_i \log \alpha_i\)$

Low entropy: Focused attention (confident retrieval)

High entropy: Distributed attention (uncertain or uniform)

Observations:

  • Early hops: High entropy (exploration)

  • Later hops: Low entropy (focused retrieval)

  • Difficult tasks: Higher average entropy

13. Challenges and LimitationsΒΆ

13.1 ScalabilityΒΆ

Memory size: \(O(N)\) attention cost prohibitive for large \(N\) (millions).

Solutions:

  • Hierarchical memory (tree structure)

  • Sparse attention (top-k, LSH)

  • Learned memory management (forget irrelevant)

13.2 Training InstabilityΒΆ

Problem: Attention can be chaotic (small input changes β†’ large attention shifts).

Causes:

  • Softmax saturation

  • Gradient vanishing/exploding through attention

Solutions:

  • Gradient clipping

  • Attention dropout

  • Temperature scaling in softmax

13.3 Interpretability vs PerformanceΒΆ

Trade-off: Sparse attention (interpretable) vs dense attention (better performance).

Example: Hard attention (reinforcement learning) interpretable but slower to train.

Compromise: Top-k attention with sufficient \(k\) for performance.

14. Recent Advances (2020-2024)ΒΆ

14.1 Memory-Efficient TransformersΒΆ

Linear attention: Replace softmax with kernelized attention $\(\text{Attn}(Q, K, V) = \phi(Q) (\phi(K)^T V)\)$

reduces complexity from \(O(n^2)\) to \(O(n)\).

Reformer: LSH attention + reversible layers

  • \(O(n \log n)\) complexity

  • Constant memory (reversible)

14.2 Retrieval-Augmented Generation (RAG)ΒΆ

Approach: Retrieve documents from external corpus, condition generation.

Architecture: $\(p(y | x) = \sum_d p(d | x) p(y | x, d)\)$

where:

  • \(p(d | x)\): Retriever (e.g., DPR)

  • \(p(y | x, d)\): Generator (e.g., BART)

Applications: Open-domain QA, fact-checking.

14.3 Memory as ParametersΒΆ

Mega (2022): Exponential moving average attention

  • Memory encoded in EMA weights

  • Constant memory, linear complexity

Results: Competitive with Transformers, much faster.

14.4 Associative MemoryΒΆ

Hopfield Networks for Transformers:

  • Modern Hopfield networks = attention mechanism

  • Energy-based interpretation

Continuous Hopfield: $\(E = -\log \sum_i \exp(\beta \langle \xi, \xi_i \rangle)\)$

minimized by attention.

15. Comparison with Other ArchitecturesΒΆ

15.1 Memory Networks vs RNNs/LSTMsΒΆ

Aspect

Memory Networks

RNNs/LSTMs

Memory

Explicit (external)

Implicit (hidden state)

Capacity

Scalable (independent \(N\))

Fixed (hidden size)

Access

Random (attention)

Sequential

Interpretability

High (visualize attention)

Low (opaque state)

Training

Stable (explicit memory)

Vanishing gradients

Complexity

\(O(N)\) per step

\(O(1)\) per step

Verdict: Memory networks better for tasks requiring explicit storage/retrieval.

15.2 Memory Networks vs TransformersΒΆ

Aspect

Memory Networks

Transformers

Memory mechanism

Explicit external

Implicit (self-attention)

Temporal modeling

Built-in (position encoding)

Via position encoding

Parallelization

Moderate (depends on design)

Full parallelization

Long-term memory

Can have persistent memory

Limited by context

Scalability

Depends on \(N\)

Quadratic in sequence length

Convergence: Transformers can be viewed as memory networks (self-attention = memory read).

15.3 NTM/DNC vs TransformersΒΆ

Aspect

NTM/DNC

Transformers

Memory

Explicit read/write

Implicit (attention)

Sequential ops

Required (state updates)

Fully parallel

Algorithmic tasks

Excellent (exact copy)

Poor without training tricks

Language modeling

Moderate

State-of-the-art

Training

Difficult (BPTT, sensitive init)

Easier (parallel, stable)

Niche: NTM/DNC for algorithmic reasoning, Transformers for NLP.

16. Implementation Best PracticesΒΆ

16.1 Memory InitializationΒΆ

Options:

  1. Zero initialization: \(m_i = \mathbf{0}\)

  2. Random: \(m_i \sim \mathcal{N}(0, \sigma^2 I)\)

  3. Pre-filled: Initialize with input embeddings

Recommendation: Pre-fill with input for faster convergence.

16.2 Attention TemperatureΒΆ

Softmax with temperature: $\(\alpha_i = \frac{\exp(s(q, m_i) / \tau)}{\sum_j \exp(s(q, m_j) / \tau)}\)$

Low \(\tau\) (< 1): Sharper attention (more focused) High \(\tau\) (> 1): Softer attention (more uniform)

Annealing: Start high (exploration), decrease (exploitation).

16.3 Gradient ClippingΒΆ

Essential for NTM/DNC due to BPTT through memory.

Norm clipping: $\(g \leftarrow \frac{g}{\max(1, \|g\| / C)}\)$

Typical: \(C = 5\) or \(10\).

16.4 HyperparametersΒΆ

Memory size: \(N = 128\) (small), \(N = 1024\) (large)

Memory dimension: \(d = 128\) or \(256\)

Number of hops: \(K = 1\) (simple), \(K = 3\) (complex reasoning)

Learning rate: \(10^{-4}\) (Adam), with warmup

Batch size: 32-128 (depends on task)

17. Summary and Future DirectionsΒΆ

17.1 Key TakeawaysΒΆ

Memory networks:

  • Explicit external memory (scalable capacity)

  • Soft attention for differentiability

  • Multi-hop reasoning for complex tasks

  • Interpretable (attention weights)

Architectures:

  • MemN2N: End-to-end, simple, effective

  • DMN: Episodic memory, strong on bAbI

  • NTM/DNC: Algorithmic tasks, challenging to train

  • Key-Value: Decoupled retrieval and content

Transformers: Modern incarnation (self-attention = memory).

17.2 Open ProblemsΒΆ

  1. Scalability: Efficient attention for \(N > 10^6\)

  2. Continual learning: Update memory without catastrophic forgetting

  3. Compositional generalization: Combine memories in novel ways

  4. Memory management: Automatic pruning, consolidation

  5. Theoretical foundations: Formal expressiveness, sample complexity

17.3 Future DirectionsΒΆ

Hybrid architectures:

  • Combine Transformers with explicit external memory

  • Best of both worlds (parallelization + persistent storage)

Neurosymbolic integration:

  • Memory as symbolic knowledge base

  • Neural retrieval, symbolic reasoning

Lifelong learning:

  • Memory networks for continual learning

  • Store past tasks, prevent forgetting

Multimodal memory:

  • Unified memory for vision, language, audio

  • Cross-modal retrieval and reasoning

18. ConclusionΒΆ

Memory-augmented neural networks represent a fundamental shift from compressing all knowledge into parameters to explicit, inspectable storage. While Transformers have dominated recent NLP advances, the core ideasβ€”attention over external memory, multi-hop reasoningβ€”remain central.

Current state:

  • MemN2N/DMN: Proven on reasoning tasks (bAbI)

  • NTM/DNC: Niche (algorithmic tasks)

  • Transformers: Dominant (NLP, vision), can be viewed as memory networks

Verdict: Memory networks are a conceptually elegant framework that informs modern architectures. Explicit external memory will likely resurface as models scale beyond current context limits, enabling true lifelong learning and continual knowledge accumulation.

The future likely involves hybrid systems: Transformers for in-context learning, external memory for long-term knowledgeβ€”combining the parallelization benefits of modern architectures with the explicit reasoning capabilities of memory networks.

"""
Advanced Memory Networks - Production Implementation
Comprehensive PyTorch implementations of MemN2N, DMN, and NTM architectures
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, List

# ===========================
# 1. Position Encoding
# ===========================

class PositionEncoding(nn.Module):
    """Sinusoidal position encoding for memory slots.

    Precomputes a (max_len, d_model) table of interleaved sin/cos features
    once at construction; `forward` adds the first seq_len rows to the input
    (broadcast over the batch dimension).
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency schedule: 10000^(-2i/d_model) for each even dim.
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * freqs
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Buffer, not a parameter: follows .to(device) but takes no gradients.
        self.register_buffer('pe', table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to *x*.

        Args:
            x: (batch, seq_len, d_model)
        Returns:
            x plus the first seq_len position encodings.
        """
        seq_len = x.size(1)
        return x + self.pe[:seq_len, :]


# ===========================
# 2. End-to-End Memory Network (MemN2N)
# ===========================

class MemN2N(nn.Module):
    """
    End-to-End Memory Network (Sukhbaatar et al., 2015).
    Multi-hop attention over external memory.

    Note: each hop has its own independent key (A) and value (C) embedding
    tables (layer-wise scheme). `embed_A` holds num_hops + 1 tables so that
    adjacent weight tying (A^(k+1) = C^(k)) could be wired in, but the extra
    table is currently unused.
    """

    def __init__(self,
                 vocab_size: int,
                 embed_dim: int,
                 num_hops: int = 3,
                 memory_size: int = 50,
                 dropout: float = 0.1):
        super().__init__()
        # forward() returns the last hop's attention; with zero hops it
        # would be undefined, so reject that configuration up front.
        if num_hops < 1:
            raise ValueError(f"num_hops must be >= 1, got {num_hops}")
        self.num_hops = num_hops
        self.embed_dim = embed_dim
        self.memory_size = memory_size

        # Per-hop key (A) and value (C) embedding tables.
        self.embed_A = nn.ModuleList([nn.Embedding(vocab_size, embed_dim) for _ in range(num_hops + 1)])
        self.embed_C = nn.ModuleList([nn.Embedding(vocab_size, embed_dim) for _ in range(num_hops)])

        # Question embedding
        self.embed_B = nn.Embedding(vocab_size, embed_dim)

        # Output projection to the answer vocabulary
        self.W = nn.Linear(embed_dim, vocab_size, bias=False)

        # Position encoding over memory slots
        self.position_encoding = PositionEncoding(embed_dim, max_len=memory_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self,
                stories: torch.Tensor,
                queries: torch.Tensor,
                story_lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            stories: (batch, memory_size, story_len) - memory slot contents
            queries: (batch, query_len) - question
            story_lengths: (batch,) - number of valid (non-padding) memory
                slots per example; slots at index >= story_lengths[b] are
                masked out of the attention.

        Returns:
            logits: (batch, vocab_size) - answer distribution
            p: (batch, memory_size) - final hop's attention weights
               (for visualization)
        """
        batch_size = stories.size(0)

        # Embed query: (batch, query_len, embed_dim) -> (batch, embed_dim)
        q = self.embed_B(queries)
        q = torch.sum(q, dim=1)  # Bag-of-words (can use RNN instead)

        # Initial query
        u = q  # (batch, embed_dim)

        # Multi-hop attention
        for hop in range(self.num_hops):
            # Embed memory keys (for attention)
            m = self.embed_A[hop](stories)  # (batch, memory_size, story_len, embed_dim)
            m = torch.sum(m, dim=2)  # Bag-of-words per story: (batch, memory_size, embed_dim)
            m = self.position_encoding(m)

            # Embed memory values (for output)
            c = self.embed_C[hop](stories)  # (batch, memory_size, story_len, embed_dim)
            c = torch.sum(c, dim=2)  # (batch, memory_size, embed_dim)
            c = self.position_encoding(c)

            # Attention: p_i = softmax(u^T m_i)
            attention_logits = torch.matmul(m, u.unsqueeze(2)).squeeze(2)  # (batch, memory_size)

            # Mask padding slots (index >= per-example length) if provided
            if story_lengths is not None:
                mask = torch.arange(self.memory_size, device=stories.device)[None, :] >= story_lengths[:, None]
                attention_logits = attention_logits.masked_fill(mask, -1e9)

            p = F.softmax(attention_logits, dim=1)  # (batch, memory_size)

            # Output: o = Ξ£ p_i c_i
            o = torch.matmul(p.unsqueeze(1), c).squeeze(1)  # (batch, embed_dim)

            # Update query: u = u + o (residual connection)
            u = u + o
            u = self.dropout(u)

        # Final answer
        logits = self.W(u)  # (batch, vocab_size)

        return logits, p  # Return logits and final attention for visualization


# ===========================
# 3. Dynamic Memory Network (DMN)
# ===========================

class AttentionGRUCell(nn.Module):
    """Attention-gated GRU cell for the episodic memory module.

    Scores the current fact against the previous episode memory and the
    question via hand-crafted interaction features, then feeds the
    gate-scaled fact through a GRU cell.

    Note: the elementwise products/differences require fact, prev_memory
    and question to share the same width (i.e. input_size == hidden_size).
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size

        # Two-layer scorer over the 7 concatenated feature vectors,
        # squashed to a scalar gate in (0, 1).
        self.attn_gate = nn.Sequential(
            nn.Linear(input_size * 7, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

        # Recurrent update of the memory state.
        self.gru_cell = nn.GRUCell(input_size, hidden_size)

    def forward(self, fact: torch.Tensor, prev_memory: torch.Tensor,
                question: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the attention gate and the updated memory.

        Args:
            fact: (batch, input_size) - current fact
            prev_memory: (batch, hidden_size) - previous episode memory
            question: (batch, input_size) - question embedding

        Returns:
            gate: (batch, 1) - relevance gate for this fact
            memory: (batch, hidden_size) - updated memory state
        """
        # Interaction features: raw vectors, elementwise products, and
        # absolute differences between the fact and the two references.
        features = (
            fact,
            prev_memory,
            question,
            fact * question,
            fact * prev_memory,
            (fact - question).abs(),
            (fact - prev_memory).abs(),
        )
        gate = self.attn_gate(torch.cat(features, dim=1))  # (batch, 1)

        # Scale the fact by its relevance before the recurrent update.
        memory = self.gru_cell(gate * fact, prev_memory)

        return gate, memory


class EpisodicMemory(nn.Module):
    """Episodic memory module with attention-based reasoning.

    Runs `num_hops` passes over the facts; each pass scores every fact with
    the attention GRU cell and then folds the resulting context into the
    running episode memory.

    NOTE(review): the elementwise ops inside AttentionGRUCell and the
    `question.clone()` initialization below appear to require
    input_size == hidden_size == question width; with a bidirectional input
    encoder (input_size = 2*hidden_size, as wired in DMN) the shapes would
    not line up — confirm intended usage.
    """
    
    def __init__(self, input_size: int, hidden_size: int, num_hops: int = 3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_hops = num_hops
        
        # Scores each fact and produces a candidate context state.
        self.attn_gru = AttentionGRUCell(input_size, hidden_size)
        # Folds each episode's context into the running memory.
        # NOTE(review): GRUCell input width is input_size, but `context`
        # below has hidden_size — mismatch unless the two are equal.
        self.memory_update = nn.GRUCell(input_size, hidden_size)
    
    def forward(self, facts: torch.Tensor, question: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            facts: (batch, num_facts, input_size) - input facts
            question: (batch, input_size) - question embedding
        
        Returns:
            memory: (batch, hidden_size) - final episodic memory
            attention_history: List of attention weights per hop
        """
        batch_size, num_facts, _ = facts.shape
        
        # Initialize memory with question
        memory = question.clone()
        
        attention_history = []
        
        # Multiple passes (hops) over facts
        for hop in range(self.num_hops):
            # Context vector for this episode.
            # NOTE(review): this zero init is never read — the first call to
            # attn_gru below overwrites `context` without consuming it.
            context = torch.zeros(batch_size, self.hidden_size, device=facts.device)
            
            gates = []
            
            # Iterate over facts
            for i in range(num_facts):
                fact = facts[:, i, :]  # (batch, input_size)
                # NOTE(review): attn_gru receives `memory`, not the running
                # `context`, so the context is NOT chained across facts —
                # after the loop it reflects only the LAST fact. The DMN
                # paper chains h_i = g_i*GRU(f_i, h_{i-1}) + (1-g_i)*h_{i-1};
                # confirm whether this simplification is intentional.
                gate, context = self.attn_gru(fact, memory, question)
                gates.append(gate)
            
            # Stack per-fact gates: (batch, num_facts)
            gates = torch.cat(gates, dim=1)  # (batch, num_facts)
            attention_history.append(gates)
            
            # Update episode memory
            memory = self.memory_update(context, memory)
        
        return memory, attention_history


class DMN(nn.Module):
    """
    Dynamic Memory Network (Kumar et al., 2016)
    Episodic memory with attention-based reasoning
    """

    def __init__(self,
                 vocab_size: int,
                 embed_dim: int,
                 hidden_size: int,
                 num_hops: int = 3,
                 dropout: float = 0.1):
        super().__init__()

        # Embedding shared by the fact and question encoders
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Input module (encode facts); bidirectional, so each fact encoding
        # is 2 * hidden_size wide
        self.input_gru = nn.GRU(embed_dim, hidden_size, batch_first=True, bidirectional=True)

        # Question module
        self.question_gru = nn.GRU(embed_dim, hidden_size, batch_first=True)

        # Episodic memory module
        self.episodic_memory = EpisodicMemory(hidden_size * 2, hidden_size, num_hops)

        # Answer module (decoder over one-hot / softmax token vectors)
        self.answer_gru = nn.GRUCell(vocab_size, hidden_size)
        self.answer_fc = nn.Linear(hidden_size, vocab_size)

        # NOTE(review): kept for interface compatibility, but dropout is
        # currently not applied anywhere in forward().
        self.dropout = nn.Dropout(dropout)

    def forward(self, 
                facts: torch.Tensor, 
                question: torch.Tensor,
                answer: Optional[torch.Tensor] = None,
                max_len: int = 10) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            facts: (batch, num_facts, fact_len) - input facts
            question: (batch, question_len) - question
            answer: (batch, answer_len) - ground truth answer (for teacher forcing)
            max_len: maximum answer length for generation

        Returns:
            output: (batch, answer_len, vocab_size) - answer logits
            attention_history: attention weights per hop
        """
        batch_size, num_facts, fact_len = facts.shape

        # Input module: encode all facts in a single GRU call by folding the
        # fact axis into the batch axis (identical results to a per-fact
        # loop, but one batched pass instead of `num_facts` passes).
        fact_embeds = self.embedding(facts)  # (batch, num_facts, fact_len, embed_dim)
        flat_embeds = fact_embeds.reshape(batch_size * num_facts, fact_len, -1)
        _, h = self.input_gru(flat_embeds)  # h: (2, batch*num_facts, hidden_size)
        # Concatenate forward/backward final states, then restore the fact axis
        encoded_facts = torch.cat([h[0], h[1]], dim=1).view(batch_size, num_facts, -1)

        # Question module
        question_embed = self.embedding(question)  # (batch, question_len, embed_dim)
        _, q = self.question_gru(question_embed)  # (1, batch, hidden_size)
        q = q.squeeze(0)  # (batch, hidden_size)

        # Episodic memory module
        memory, attention_history = self.episodic_memory(encoded_facts, q)

        # Answer module: teacher-forced length when ground truth is given,
        # otherwise generate up to max_len tokens
        answer_len = answer.size(1) if answer is not None else max_len

        outputs = []
        hidden = memory
        # Start token: all-zeros vector in vocab space
        input_token = torch.zeros(batch_size, self.answer_fc.out_features, device=facts.device)

        for t in range(answer_len):
            hidden = self.answer_gru(input_token, hidden)
            output = self.answer_fc(hidden)  # (batch, vocab_size)
            outputs.append(output)

            # Next decoder input
            if answer is not None and t < answer_len - 1:
                # Teacher forcing: feed the ground-truth token for step t
                input_token = F.one_hot(answer[:, t], num_classes=self.answer_fc.out_features).float()
            else:
                # Free-running: feed the softmax distribution of the prediction
                input_token = F.softmax(output, dim=1)

        outputs = torch.stack(outputs, dim=1)  # (batch, answer_len, vocab_size)

        return outputs, attention_history


# ===========================
# 4. Neural Turing Machine (NTM)
# ===========================

class NTMMemory(nn.Module):
    """External NTM memory bank supporting attention-weighted reads and
    erase-then-add writes (Graves et al., 2014)."""

    def __init__(self, memory_size: int, memory_dim: int):
        super().__init__()
        self.memory_size = memory_size
        self.memory_dim = memory_dim

    def reset(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a fresh memory of small random values, one per batch item."""
        shape = (batch_size, self.memory_size, self.memory_dim)
        return 0.01 * torch.randn(*shape, device=device)

    def read(self, memory: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """
        Attention-weighted sum over memory rows.
        Args:
            memory: (batch, memory_size, memory_dim)
            w: (batch, memory_size) - read weights
        Returns:
            read_vector: (batch, memory_dim)
        """
        return torch.einsum('bn,bnd->bd', w, memory)

    def write(self, memory: torch.Tensor, w: torch.Tensor,
              erase: torch.Tensor, add: torch.Tensor) -> torch.Tensor:
        """
        Write step: erase first, then add.
        Args:
            memory: (batch, memory_size, memory_dim)
            w: (batch, memory_size) - write weights
            erase: (batch, memory_dim) - erase vector
            add: (batch, memory_dim) - add vector
        Returns:
            updated_memory: (batch, memory_size, memory_dim)
        """
        def _outer(vec: torch.Tensor) -> torch.Tensor:
            # Batched outer product w βŠ— vec: (batch, memory_size, memory_dim)
            return w.unsqueeze(2) * vec.unsqueeze(1)

        # M_t = M_{t-1} * (1 - w e^T) + w a^T
        erased = memory * (1.0 - _outer(erase))
        return erased + _outer(add)


class NTMHead(nn.Module):
    """NTM read/write head with content and location addressing.

    Produces an attention distribution over memory slots via the Graves et
    al. (2014) addressing pipeline: content lookup -> interpolation with the
    previous weights -> circular convolutional shift -> sharpening.
    """

    def __init__(self, memory_dim: int, controller_dim: int):
        super().__init__()
        self.memory_dim = memory_dim

        # Key for content addressing
        self.key_layer = nn.Linear(controller_dim, memory_dim)

        # Key strength (beta)
        self.beta_layer = nn.Linear(controller_dim, 1)

        # Interpolation gate (g)
        self.g_layer = nn.Linear(controller_dim, 1)

        # Shift (s) - 3-tap convolutional shift kernel
        self.shift_layer = nn.Linear(controller_dim, 3)

        # Sharpening (gamma)
        self.gamma_layer = nn.Linear(controller_dim, 1)

    def content_addressing(self, memory: torch.Tensor, key: torch.Tensor, 
                          beta: torch.Tensor) -> torch.Tensor:
        """
        Content-based addressing using cosine similarity
        w_c = softmax(beta * cosine(key, M))

        Args:
            memory: (batch, memory_size, memory_dim)
            key: (batch, memory_dim)
            beta: (batch, 1) - non-negative key strength
        Returns:
            w_c: (batch, memory_size)
        """
        # Normalize (epsilon guards against zero vectors)
        key = key / (torch.norm(key, dim=1, keepdim=True) + 1e-8)
        memory = memory / (torch.norm(memory, dim=2, keepdim=True) + 1e-8)

        # Cosine similarity
        similarity = torch.matmul(memory, key.unsqueeze(2)).squeeze(2)  # (batch, memory_size)

        # Apply key strength
        w_c = F.softmax(beta * similarity, dim=1)

        return w_c

    def forward(self, controller_output: torch.Tensor, memory: torch.Tensor,
                prev_w: torch.Tensor) -> torch.Tensor:
        """
        Compute attention weights
        Args:
            controller_output: (batch, controller_dim)
            memory: (batch, memory_size, memory_dim)
            prev_w: (batch, memory_size) - previous weights
        Returns:
            w: (batch, memory_size) - new attention weights
        """
        # Extract addressing parameters from the controller state
        key = torch.tanh(self.key_layer(controller_output))  # (batch, memory_dim)
        beta = F.softplus(self.beta_layer(controller_output))  # (batch, 1), >= 0
        g = torch.sigmoid(self.g_layer(controller_output))  # (batch, 1), in (0, 1)
        shift = F.softmax(self.shift_layer(controller_output), dim=1)  # (batch, 3)
        gamma = 1 + F.softplus(self.gamma_layer(controller_output))  # (batch, 1), >= 1

        # Content addressing
        w_c = self.content_addressing(memory, key, beta)

        # Interpolation with previous weights
        w_g = g * w_c + (1 - g) * prev_w

        # Convolutional (circular) shift
        w_tilde = self._convolutional_shift(w_g, shift)

        # Sharpening: raise to gamma >= 1 and renormalise
        w = torch.pow(w_tilde, gamma)
        w = w / (torch.sum(w, dim=1, keepdim=True) + 1e-8)

        return w

    def _convolutional_shift(self, w: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        """Circular convolution of the weights with a 3-tap shift kernel.

        w_shifted[i] = s[0]*w[i-1] + s[1]*w[i] + s[2]*w[i+1]  (indices mod N)

        Vectorised: the previous implementation looped over all memory slots
        in Python with per-slot in-place writes; three shifted slices compute
        the exact same values in a single pass.
        """
        memory_size = w.size(1)

        # Pad with wrap-around neighbours for the circular convolution
        w_padded = torch.cat([w[:, -1:], w, w[:, :1]], dim=1)

        return (shift[:, 0:1] * w_padded[:, :memory_size] +
                shift[:, 1:2] * w_padded[:, 1:memory_size + 1] +
                shift[:, 2:3] * w_padded[:, 2:memory_size + 2])


class NTM(nn.Module):
    """Neural Turing Machine (Graves et al., 2014)"""

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 controller_dim: int,
                 memory_size: int,
                 memory_dim: int,
                 num_heads: int = 1):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.controller_dim = controller_dim
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.num_heads = num_heads

        # External memory bank
        self.memory = NTMMemory(memory_size, memory_dim)

        # LSTM controller: sees the input plus all previous read vectors
        self.controller = nn.LSTMCell(input_size + num_heads * memory_dim, controller_dim)

        # Read heads
        self.read_heads = nn.ModuleList(
            NTMHead(memory_dim, controller_dim) for _ in range(num_heads)
        )

        # Write head plus erase/add parameter projections
        self.write_head = NTMHead(memory_dim, controller_dim)
        self.erase_layer = nn.Linear(controller_dim, memory_dim)
        self.add_layer = nn.Linear(controller_dim, memory_dim)

        # Output projection over controller state + read vectors
        self.fc_out = nn.Linear(controller_dim + num_heads * memory_dim, output_size)

    def forward(self, x: torch.Tensor, prev_state: Optional[dict] = None) -> Tuple[torch.Tensor, dict]:
        """
        Single step forward
        Args:
            x: (batch, input_size)
            prev_state: Previous NTM state (memory, controller, weights)
        Returns:
            output: (batch, output_size)
            state: Current NTM state
        """
        if prev_state is None:
            prev_state = self.init_state(x.size(0), x.device)

        mem = prev_state['memory']

        # Read with the previous step's attention weights
        reads = [self.memory.read(mem, prev_state['read_w'][i])
                 for i in range(self.num_heads)]
        read_cat = torch.cat(reads, dim=1)  # (batch, num_heads * memory_dim)

        # Controller consumes the input concatenated with the reads
        ctrl_h, ctrl_c = self.controller(
            torch.cat([x, read_cat], dim=1),
            (prev_state['controller_h'], prev_state['controller_c']),
        )

        # New attention weights for every read head
        new_read_w = [head(ctrl_h, mem, prev_state['read_w'][i])
                      for i, head in enumerate(self.read_heads)]

        # Write: compute weights, then erase-and-add into memory
        write_w = self.write_head(ctrl_h, mem, prev_state['write_w'])
        new_memory = self.memory.write(
            mem, write_w,
            torch.sigmoid(self.erase_layer(ctrl_h)),
            torch.tanh(self.add_layer(ctrl_h)),
        )

        # Output from controller state + (previous) reads
        output = self.fc_out(torch.cat([ctrl_h, read_cat], dim=1))

        new_state = {
            'memory': new_memory,
            'controller_h': ctrl_h,
            'controller_c': ctrl_c,
            'read_w': new_read_w,
            'write_w': write_w
        }

        return output, new_state

    def init_state(self, batch_size: int, device: torch.device) -> dict:
        """Initialize NTM state: uniform attention, zeroed controller."""
        def _uniform():
            # One fresh (batch, memory_size) tensor of 1/N per call
            return torch.full((batch_size, self.memory_size), 1.0 / self.memory_size,
                              device=device)

        return {
            'memory': self.memory.reset(batch_size, device),
            'controller_h': torch.zeros(batch_size, self.controller_dim, device=device),
            'controller_c': torch.zeros(batch_size, self.controller_dim, device=device),
            'read_w': [_uniform() for _ in range(self.num_heads)],
            'write_w': _uniform()
        }

    def forward_sequence(self, x_seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for entire sequence
        Args:
            x_seq: (batch, seq_len, input_size)
        Returns:
            output_seq: (batch, seq_len, output_size)
        """
        state = None
        outputs = []

        # Step the machine once per time slice, threading the state through
        for step in x_seq.unbind(dim=1):
            out, state = self.forward(step, state)
            outputs.append(out)

        return torch.stack(outputs, dim=1)


# ===========================
# 5. Demo Functions
# ===========================

def demo_memn2n():
    """Demonstrate End-to-End Memory Network"""
    banner = "=" * 60
    print(banner)
    print("Demo: End-to-End Memory Network (MemN2N)")
    print(banner)

    vocab_size = 100
    model = MemN2N(vocab_size, 32, num_hops=3, memory_size=20)

    # Toy batch: 4 samples, each with 20 six-word stories and a 5-word query
    stories = torch.randint(0, vocab_size, (4, 20, 6))
    queries = torch.randint(0, vocab_size, (4, 5))

    logits, attention = model(stories, queries)

    print(f"Stories: {stories.shape} (batch, memory_size, story_len)")
    print(f"Queries: {queries.shape} (batch, query_len)")
    print(f"Output logits: {logits.shape} (batch, vocab_size)")
    print(f"Final attention: {attention.shape} (batch, memory_size)")
    print(f"Attention weights (sample): {attention[0][:10]}")

    total = sum(p.numel() for p in model.parameters())
    print(f"\nTotal parameters: {total:,}")
    print()


def demo_dmn():
    """Demonstrate Dynamic Memory Network"""
    banner = "=" * 60
    print(banner)
    print("Demo: Dynamic Memory Network (DMN)")
    print(banner)

    vocab_size = 100
    model = DMN(vocab_size, 32, 64, num_hops=3)

    # Toy batch: 10 eight-word facts, a 5-word question, a 3-word answer
    facts = torch.randint(0, vocab_size, (4, 10, 8))
    question = torch.randint(0, vocab_size, (4, 5))
    answer = torch.randint(0, vocab_size, (4, 3))

    outputs, attention_history = model(facts, question, answer)

    print(f"Facts: {facts.shape} (batch, num_facts, fact_len)")
    print(f"Question: {question.shape} (batch, question_len)")
    print(f"Answer: {answer.shape} (batch, answer_len)")
    print(f"Output: {outputs.shape} (batch, answer_len, vocab_size)")
    print(f"Attention hops: {len(attention_history)}")
    print(f"Attention per hop: {attention_history[0].shape} (batch, num_facts)")

    # Show how attention over the first fact changes across hops
    print("\nAttention evolution (sample, fact 0):")
    for hop, attn in enumerate(attention_history):
        print(f"  Hop {hop}: {attn[0, 0].item():.4f}")

    total = sum(p.numel() for p in model.parameters())
    print(f"\nTotal parameters: {total:,}")
    print()


def demo_ntm():
    """Demonstrate Neural Turing Machine"""
    banner = "=" * 60
    print(banner)
    print("Demo: Neural Turing Machine (NTM)")
    print(banner)

    model = NTM(input_size=8, output_size=8, controller_dim=100,
                memory_size=128, memory_dim=20, num_heads=1)

    # Copy-task style toy data: a random sequence pushed through the machine
    x_seq = torch.randn(2, 10, 8)
    output_seq = model.forward_sequence(x_seq)

    print(f"Input sequence: {x_seq.shape} (batch, seq_len, input_size)")
    print(f"Output sequence: {output_seq.shape} (batch, seq_len, output_size)")

    # One single-step call, starting from a fresh state
    x = torch.randn(2, 8)
    output, state = model(x)

    print("\nSingle step:")
    print(f"  Input: {x.shape}")
    print(f"  Output: {output.shape}")
    print(f"  Memory: {state['memory'].shape} (batch, memory_size, memory_dim)")
    print(f"  Read weights: {state['read_w'][0].shape} (batch, memory_size)")
    print(f"  Write weights: {state['write_w'].shape}")

    total = sum(p.numel() for p in model.parameters())
    print(f"\nTotal parameters: {total:,}")
    print("Memory parameters: 0 (external memory)")
    print()


def print_performance_comparison():
    """Comprehensive performance comparison"""

    def _table(rows, widths):
        # Left-justify each cell to its column width, single-space separated
        for row in rows:
            print(" ".join(f"{cell:<{w}}" for cell, w in zip(row, widths)))

    print("=" * 80)
    print("PERFORMANCE COMPARISON: Memory Networks")
    print("=" * 80)

    # 1. bAbI Tasks
    print("\n1. bAbI Question Answering (% tasks solved)")
    print("-" * 80)
    _table([
        ("Model", "1-hop", "2-hop", "3-hop", "Overall"),
        ("-" * 30, "-" * 8, "-" * 8, "-" * 8, "-" * 10),
        ("LSTM", "95%", "45%", "20%", "35% (7/20)"),
        ("MemN2N (1 hop)", "100%", "75%", "55%", "70% (14/20)"),
        ("MemN2N (3 hops)", "100%", "95%", "85%", "95% (19/20)"),
        ("DMN+", "100%", "97%", "93%", "97% (19.4/20)"),
        ("DNC", "100%", "100%", "98%", "100% (20/20)"),
        ("", "", "", "", ""),
        ("Observation:", "", "", "", "Multi-hop crucial for complex reasoning"),
    ], (30, 8, 8, 8, 10))

    # 2. Algorithmic Tasks
    print("\n2. Algorithmic Tasks (Accuracy %)")
    print("-" * 80)
    _table([
        ("Task", "LSTM", "MemN2N", "NTM", "DNC"),
        ("-" * 25, "-" * 8, "-" * 10, "-" * 8, "-" * 8),
        ("Copy", "45%", "N/A", "100%", "100%"),
        ("Repeat Copy", "30%", "N/A", "95%", "100%"),
        ("Associative Recall", "40%", "85%", "98%", "100%"),
        ("Priority Sort", "20%", "N/A", "75%", "98%"),
        ("", "", "", "", ""),
        ("Best for:", "Simple", "Reasoning", "Algorithms", "Complex"),
    ], (25, 8, 10, 8, 8))

    # 3. Computational Complexity
    print("\n3. Computational Complexity")
    print("-" * 80)
    _table([
        ("Model", "Time per step", "Memory", "Parameters"),
        ("-" * 20, "-" * 20, "-" * 20, "-" * 20),
        ("LSTM", "O(1)", "O(h)", "O(hΒ²)"),
        ("MemN2N", "O(KΒ·N)", "O(NΒ·d)", "O(VΒ·dΒ·K)"),
        ("DMN", "O(KΒ·N)", "O(NΒ·d)", "O(VΒ·d + hΒ²)"),
        ("NTM", "O(NΒ·d)", "O(NΒ·d) external", "O(hΒ·d)"),
        ("DNC", "O(NΒ·d)", "O(NΒ·d) external", "O(hΒ·d)"),
        ("", "", "", ""),
        ("Legend:", "K=hops, N=memory", "h=hidden, d=mem_dim", "V=vocab"),
    ], (20, 20, 20, 20))

    # 4. Comparison Table
    print("\n4. Architecture Comparison")
    print("-" * 80)
    _table([
        ("Aspect", "MemN2N", "DMN", "NTM/DNC"),
        ("-" * 25, "-" * 20, "-" * 20, "-" * 20),
        ("Memory type", "Static (input)", "Episodic", "Dynamic R/W"),
        ("Addressing", "Content (attention)", "Attention gate", "Content + location"),
        ("Differentiable", "Yes (soft attn)", "Yes", "Yes"),
        ("Multi-hop", "Built-in", "Built-in", "Via controller"),
        ("Training", "Easy", "Moderate", "Hard (BPTT)"),
        ("Interpretability", "High", "High", "Moderate"),
        ("Best for", "QA, reasoning", "QA, dialog", "Algorithms"),
    ], (25, 20, 20, 20))

    # 5. Hyperparameters
    print("\n5. Recommended Hyperparameters")
    print("-" * 80)
    _table([
        ("Parameter", "MemN2N", "DMN", "NTM"),
        ("-" * 25, "-" * 15, "-" * 15, "-" * 15),
        ("Embedding dim", "32-128", "64-128", "N/A"),
        ("Hidden size", "N/A", "128-256", "100-200"),
        ("Memory size", "20-100", "10-50", "128-512"),
        ("Memory dim", "embed_dim", "2Γ—hidden", "20-40"),
        ("Num hops", "1-3", "2-4", "N/A"),
        ("Learning rate", "0.01", "0.001", "0.0001"),
        ("Gradient clip", "40", "10", "10"),
        ("Batch size", "32", "32", "16"),
    ], (25, 15, 15, 15))

    # 6. Decision Guide
    print("\n6. DECISION GUIDE: When to Use Memory Networks")
    print("=" * 80)

    print("\nβœ“ USE Memory Networks When:")
    print("β€’ Multi-hop reasoning required (compositional questions)")
    print("β€’ Explicit memory inspection needed (interpretability)")
    print("β€’ Tasks with supporting facts (QA with evidence)")
    print("β€’ Algorithmic tasks (copy, sort, recall)")
    print("β€’ Dialog systems (track conversation history)")
    print("β€’ Long-term dependencies (beyond LSTM capacity)")

    print("\nβœ— AVOID Memory Networks When:")
    print("β€’ Simple sequential tasks (LSTM sufficient)")
    print("β€’ Large-scale language modeling (Transformers better)")
    print("β€’ Real-time inference critical (overhead of attention)")
    print("β€’ Limited computational budget (complex architectures)")

    print("\n→ RECOMMENDED ALTERNATIVES:")
    print("β€’ General NLP β†’ Transformers (BERT, GPT)")
    print("β€’ Simple QA β†’ BiDAF, BERT-based models")
    print("β€’ Reasoning β†’ Graph Neural Networks, Transformers")
    print("β€’ Algorithms β†’ Specialized neural architectures")

    print("\n" + "=" * 80)
    print()


# ===========================
# Run All Demos
# ===========================

if __name__ == "__main__":
    banner = "=" * 80
    print("\n" + banner)
    print("MEMORY NETWORKS - COMPREHENSIVE IMPLEMENTATION")
    print(banner + "\n")

    # Run every demo in order, then the comparison tables
    for demo in (demo_memn2n, demo_dmn, demo_ntm, print_performance_comparison):
        demo()

    print("\n" + banner)
    print("All demos completed successfully!")
    print(banner)