import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Pick GPU when available; every model/tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Memory Networks ConceptΒΆ
Architecture:ΒΆ
Memory Matrix: \(M_t \in \mathbb{R}^{N \times M}\)
Controller: LSTM/GRU
Read Head: Attention-based retrieval
Write Head: Memory updates
Read Operation:ΒΆ
where \(w_t^r\) is read attention weight.
📚 Reference Materials:
deep_nlp.pdf – Deep NLP
class MemoryBank(nn.Module):
    """External memory module shared by the NTM read/write heads.

    The memory is a single (n_slots, slot_size) matrix registered as a
    buffer, so it follows the module across devices but is not a
    learnable parameter.  Call ``reset()`` before processing a new
    sequence.
    """

    def __init__(self, n_slots, slot_size):
        super().__init__()
        self.n_slots = n_slots
        self.slot_size = slot_size
        # Non-trainable state, zero-initialized.
        self.register_buffer('memory', torch.zeros(n_slots, slot_size))

    def read(self, weights):
        """Read a weighted sum of memory rows.

        Args:
            weights: attention weights of shape (batch, n_slots).
        Returns:
            Read vectors of shape (batch, slot_size).
        """
        return torch.mm(weights, self.memory)

    def write(self, weights, erase_vector, add_vector):
        """Erase-then-add memory update (NTM-style).

        Bug fix: the original used ``torch.outer`` on squeezed tensors,
        which requires 1-D inputs and therefore crashed for any batch
        size > 1.  This version accepts 1-D (single sample) or 2-D
        (batched) inputs; batched erase terms combine multiplicatively
        and add terms are summed over the batch, which reduces exactly
        to the original behavior for a single sample.

        Args:
            weights: write weights, (n_slots,) or (batch, n_slots).
            erase_vector: erase gates in [0, 1], (slot_size,) or (batch, slot_size).
            add_vector: add contents, (slot_size,) or (batch, slot_size).
        """
        if weights.dim() == 1:
            weights = weights.unsqueeze(0)
        if erase_vector.dim() == 1:
            erase_vector = erase_vector.unsqueeze(0)
        if add_vector.dim() == 1:
            add_vector = add_vector.unsqueeze(0)
        # Per-sample outer products: (batch, n_slots, slot_size).
        erase = weights.unsqueeze(-1) * erase_vector.unsqueeze(1)
        self.memory = self.memory * torch.prod(1 - erase, dim=0)
        add = weights.unsqueeze(-1) * add_vector.unsqueeze(1)
        self.memory = self.memory + add.sum(dim=0)

    def reset(self):
        """Zero all memory slots in place."""
        self.memory.zero_()

print("MemoryBank defined")
2. Content-Based AddressingΒΆ
Similarity Measure:ΒΆ
Attention:ΒΆ
def content_addressing(key, memory, strength):
    """Content-based addressing: softmax over sharpened cosine similarities.

    Args:
        key: query vectors of shape (batch, slot_size).
        memory: memory matrix of shape (n_slots, slot_size).
        strength: positive sharpening factor beta (scalar or (batch, 1)).
    Returns:
        Attention weights (batch, n_slots); each row sums to 1.
    """
    eps = 1e-8
    # Normalize both sides so the dot product becomes cosine similarity.
    key_unit = key / (torch.norm(key, dim=1, keepdim=True) + eps)
    mem_unit = memory / (torch.norm(memory, dim=1, keepdim=True) + eps)
    cosine = torch.mm(key_unit, mem_unit.t())  # (batch, n_slots)
    # Sharpen with the key strength, then normalize to a distribution.
    return F.softmax(strength * cosine, dim=1)

print("Content addressing defined")
Neural Turing MachineΒΆ
The Neural Turing Machine (NTM) augments a neural network controller with an external differentiable memory matrix that can be read from and written to via attention-based addressing. The read head computes a weighted sum over memory rows using a soft attention vector, and the write head uses separate erase and add vectors to modify memory contents. Addressing can be content-based (attend to memory rows that match a key vector) or location-based (shift attention to adjacent positions), enabling both associative retrieval and sequential access patterns. This architecture can learn algorithmic tasks like copying, sorting, and sequence recall that are impossible for standard RNNs.
class NTMController(nn.Module):
    """LSTM controller that emits the NTM read/write head parameters."""

    def __init__(self, input_size, hidden_size, output_size, memory_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.memory_size = memory_size
        # Controller consumes the input concatenated with the last read vector.
        self.lstm = nn.LSTMCell(input_size + memory_size, hidden_size)
        # Read head: addressing key and strength.
        self.read_key = nn.Linear(hidden_size, memory_size)
        self.read_strength = nn.Linear(hidden_size, 1)
        # Write head: addressing key, strength, and erase/add vectors.
        self.write_key = nn.Linear(hidden_size, memory_size)
        self.write_strength = nn.Linear(hidden_size, 1)
        self.erase = nn.Linear(hidden_size, memory_size)
        self.add = nn.Linear(hidden_size, memory_size)
        # Output projection sees the hidden state and the last read vector.
        self.fc_out = nn.Linear(hidden_size + memory_size, output_size)

    def forward(self, x, prev_state, prev_read):
        """Run one controller step.

        Args:
            x: input of shape (batch, input_size).
            prev_state: (h, c) tuple from the previous LSTM step.
            prev_read: read vector from the previous step (batch, memory_size).
        Returns:
            Tuple of (output, (h, c), read key, read strength, write key,
            write strength, erase vector in [0,1], add vector in [-1,1]).
        """
        lstm_in = torch.cat([x, prev_read], dim=1)
        h, c = self.lstm(lstm_in, prev_state)

        # Head parameters: tanh bounds the keys, softplus keeps strengths
        # positive, sigmoid keeps the erase gates in [0, 1].
        k_read = torch.tanh(self.read_key(h))
        beta_read = F.softplus(self.read_strength(h))
        k_write = torch.tanh(self.write_key(h))
        beta_write = F.softplus(self.write_strength(h))
        erase_vec = torch.sigmoid(self.erase(h))
        add_vec = torch.tanh(self.add(h))

        output = self.fc_out(torch.cat([h, prev_read], dim=1))
        return output, (h, c), k_read, beta_read, k_write, beta_write, erase_vec, add_vec

print("NTMController defined")
Complete NTMΒΆ
The complete NTM combines the controller (an LSTM or feedforward network), the memory matrix, and the read/write heads into an end-to-end differentiable system. At each time step, the controller receives the current input and the previous read vector, produces a hidden state, and emits addressing parameters for the read and write heads. The memory is then updated (write) and queried (read), with the read output fed back to the controller for the next step. All operations — attention, memory read, memory write — are differentiable, so the entire system is trained with standard backpropagation through time.
class NeuralTuringMachine(nn.Module):
    """NTM: LSTM controller + external MemoryBank with one read and one write head.

    NOTE(review): the memory matrix is a single (n_slots, slot_size)
    buffer shared across the batch, so all samples in a batch write to
    the same memory; it is reset at the start of every forward pass.
    """

    def __init__(self, input_size, hidden_size, output_size, n_slots=128, slot_size=20):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_slots = n_slots
        self.slot_size = slot_size
        self.memory = MemoryBank(n_slots, slot_size)
        self.controller = NTMController(input_size, hidden_size, output_size, slot_size)

    def forward(self, x):
        """Run the NTM over a full input sequence.

        Args:
            x: input of shape (batch, seq_len, input_size).
        Returns:
            Tuple of (outputs (batch, seq_len, output_size),
            read-weight history, write-weight history) — the histories
            are lists of per-step (batch, n_slots) attention tensors.
        """
        batch_size, seq_len, _ = x.size()

        # Fresh controller state, read vector, and memory per sequence.
        h = torch.zeros(batch_size, self.hidden_size, device=x.device)
        c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        read = torch.zeros(batch_size, self.slot_size, device=x.device)
        self.memory.reset()

        outputs, read_history, write_history = [], [], []
        for t in range(seq_len):
            step = self.controller(x[:, t], (h, c), read)
            out, (h, c), k_r, beta_r, k_w, beta_w, erase, add = step

            # Read via content-based attention over memory rows.
            w_read = content_addressing(k_r, self.memory.memory, beta_r)
            read = self.memory.read(w_read)

            # Write: erase then add at the content-addressed location.
            w_write = content_addressing(k_w, self.memory.memory, beta_w)
            self.memory.write(w_write, erase.squeeze(), add.squeeze())

            outputs.append(out)
            read_history.append(w_read)
            write_history.append(w_write)

        return torch.stack(outputs, dim=1), read_history, write_history

print("NeuralTuringMachine defined")
Copy TaskΒΆ
The copy task is the canonical benchmark for memory-augmented networks: the model receives a sequence of binary vectors, followed by a delimiter, and must reproduce the sequence from memory. Solving this task requires learning to write each input to a distinct memory location (using location-based addressing with incremental shifts), then read them back in order. Standard RNNs/LSTMs struggle with this task for long sequences because they must compress the entire input into a fixed-size hidden state, whereas the NTM can use its external memory as a buffer with capacity proportional to the memory matrix size.
def generate_copy_data(batch_size, seq_len, n_bits):
    """Build one batch for the copy task.

    The input is [sequence, delimiter, padding] and the target is
    [padding, delimiter-slot, sequence], so the model must echo the
    sequence only after the delimiter position.

    NOTE(review): the delimiter is all-zeros and therefore identical to
    padding; the model can only rely on position. A dedicated delimiter
    bit channel is the usual fix, but it would change the tensor shapes.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: length of the random binary sequence.
        n_bits: bits per time step.
    Returns:
        (input_seq, target_seq), each of shape (batch, 2*seq_len + 1, n_bits).
    """
    pattern = torch.randint(0, 2, (batch_size, seq_len, n_bits)).float()
    delimiter = torch.zeros(batch_size, 1, n_bits)
    padding = torch.zeros(batch_size, seq_len, n_bits)
    input_seq = torch.cat([pattern, delimiter, padding], dim=1)
    target_seq = torch.cat([padding, delimiter, pattern], dim=1)
    return input_seq, target_seq

# Sanity-check the generated shapes.
x, y = generate_copy_data(2, 5, 8)
print(f"Input shape: {x.shape}, Target shape: {y.shape}")
Train NTMΒΆ
Training the NTM on the copy task uses standard supervised learning: the modelβs output at each time step after the delimiter is compared to the target via binary cross-entropy loss. Learning curves typically show a phase transition: the model struggles for many epochs as it learns the addressing mechanism, then rapidly achieves near-perfect performance once it discovers the correct read/write strategy. Gradient clipping is essential because the backpropagation-through-time gradients can explode when memory addressing weights are near the boundaries of the softmax.
def train_ntm(model, n_epochs=100, seq_len=5, n_bits=8):
    """Train an NTM on the copy task with BCE-with-logits loss.

    Args:
        model: model returning (logits, read_history, write_history).
        n_epochs: number of gradient steps; each uses a fresh random batch.
        seq_len: copy-task sequence length.
        n_bits: bits per time step.
    Returns:
        List of per-epoch loss values.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    losses = []
    for epoch in range(n_epochs):
        inputs, targets = generate_copy_data(32, seq_len, n_bits)
        inputs, targets = inputs.to(device), targets.to(device)

        logits, _, _ = model(inputs)
        # Raw logits against binary targets over the whole sequence.
        loss = F.binary_cross_entropy_with_logits(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        # Clip exploding BPTT gradients — essential for NTM training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

        losses.append(loss.item())
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
    return losses
# Instantiate the NTM and train it on the 5-step, 8-bit copy task.
ntm = NeuralTuringMachine(
    input_size=8, hidden_size=100, output_size=8,
    n_slots=128, slot_size=20
).to(device)
losses = train_ntm(ntm, n_epochs=100, seq_len=5, n_bits=8)
Visualize ResultsΒΆ
Visualizing the memory access patterns reveals how the NTM has learned to solve the copy task. During the input phase, write attention weights should show a sequential pattern (writing to consecutive memory locations). During the output phase, read attention weights should retrace the same sequential pattern, retrieving stored values in order. These attention visualizations confirm that the NTM has learned an algorithm, not merely memorized training examples β a qualitative difference from standard neural network learning.
# Evaluate the trained NTM on a longer (8-step) copy sequence.
ntm.eval()
x_test, y_test = generate_copy_data(1, 8, 8)
x_test = x_test.to(device)
with torch.no_grad():
    pred, read_weights, write_weights = ntm(x_test)
    pred = torch.sigmoid(pred)  # logits -> bit probabilities

# 2x2 figure: loss curve, input, prediction, and read-attention heatmap.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss curve
axes[0, 0].plot(losses)
axes[0, 0].set_xlabel('Epoch', fontsize=11)
axes[0, 0].set_ylabel('Loss', fontsize=11)
axes[0, 0].set_title('Training Loss', fontsize=12)
axes[0, 0].grid(True, alpha=0.3)

# Input sequence (bits on the y-axis, time on the x-axis)
axes[0, 1].imshow(x_test[0].cpu().T, aspect='auto', cmap='Blues')
axes[0, 1].set_xlabel('Time', fontsize=11)
axes[0, 1].set_ylabel('Bit', fontsize=11)
axes[0, 1].set_title('Input Sequence', fontsize=12)

# Model prediction (sigmoid probabilities)
axes[1, 0].imshow(pred[0].cpu().T, aspect='auto', cmap='Blues')
axes[1, 0].set_xlabel('Time', fontsize=11)
axes[1, 0].set_ylabel('Bit', fontsize=11)
axes[1, 0].set_title('Output Prediction', fontsize=12)

# Read attention over time (first 50 memory slots shown)
read_w = torch.stack(read_weights).squeeze().cpu().numpy()
axes[1, 1].imshow(read_w.T[:50], aspect='auto', cmap='Blues')
axes[1, 1].set_xlabel('Time', fontsize=11)
axes[1, 1].set_ylabel('Memory Slot', fontsize=11)
axes[1, 1].set_title('Read Attention Weights', fontsize=12)

plt.tight_layout()
plt.show()
SummaryΒΆ
Memory Networks:ΒΆ
Components:
External memory matrix
Controller (LSTM/GRU)
Read/write heads
Addressing mechanisms
Addressing:
Content-based (similarity)
Location-based (shift)
Interpolation (gating)
Operations:
Read: weighted sum
Write: erase + add
Applications:ΒΆ
Question answering
Program learning
Graph algorithms
Reasoning tasks
Variants:ΒΆ
NTM: Neural Turing Machine
DNC: Differentiable Neural Computer
End-to-End Memory Networks: Simpler, task-specific
Advanced Memory Networks TheoryΒΆ
1. Introduction to Memory-Augmented Neural NetworksΒΆ
1.1 MotivationΒΆ
Standard neural networks:
Parameters encode all knowledge
No explicit external memory
Struggle with tasks requiring explicit recall
Memory-augmented networks:
Explicit external memory module
Read/write mechanisms
Better for question answering, reasoning, few-shot learning
Key applications:
Question answering: Store facts, retrieve relevant information
One-shot learning: Store examples in memory
Reasoning: Multi-hop reasoning over facts
Dialog systems: Maintain conversation history
1.2 Memory Network ComponentsΒΆ
General architecture:
Memory module M: Array of memory slots M = [mβ, mβ, β¦, mβ]
Each mα΅’ β βα΅ (vector representation)
Input module I: Converts input to internal representation
I: x β u (embed input)
Generalization module G: Updates memory
G: (M_old, I(x)) β M_new
Output module O: Produces response
O: (o, u) β r where o is retrieved memory
Response module R: Converts to output format
R: r β a (final answer)
Forward pass:
u = I(x) # Embed input
o = O(M, u) # Retrieve from memory
r = R(o, u) # Generate response
2. End-to-End Memory Networks (MemN2N)ΒΆ
2.1 Architecture (Sukhbaatar et al., 2015)ΒΆ
Problem: Original memory networks require supervision at every layer. MemN2N is fully differentiable.
Input representation: Given story sentences {sβ, sβ, β¦, sβ} and query q:
Embed each sentence: mα΅’ = Ξ£β±Ό Aβα΅’β±Ό (position encoding)
Embed query: u = Ξ£β±Ό Bqβ±Ό
A, B: Embedding matrices
Single layer attention:
Attention weights (similarity):
pα΅’ = softmax(uα΅mα΅’)
Output representation: Different embedding C
cα΅’ = Ξ£β±Ό Cβα΅’β±Ό
Weighted sum:
o = Ξ£α΅’ pα΅’cα΅’
Update query:
uβ = u + o
2.2 Multiple Hops (Layers)ΒΆ
K computational hops:
uΒΉ = u # Initial query
For k = 1 to K:
p^k = softmax(u^k α΅ m^k) # Attention
o^k = Ξ£α΅’ p^k_i c^k_i # Read
u^(k+1) = u^k + o^k # Update (residual)
Final prediction:
Γ’ = softmax(W(u^(K+1)))
Key insight: Multi-hop allows reasoning over multiple facts.
2.3 Position EncodingΒΆ
Problem: Bag-of-words loses word order.
Solution: Position encoding
m_i = Σ_j l_j ⊙ (A x_ij)
where l_j encodes position j (componentwise):
l_kj = (1 - j/J) - (k/d)(1 - 2j/J)
J: sentence length, d: embedding dimension, k: dimension index.
Alternative: Use sinusoidal encoding (like Transformers).
2.4 TrainingΒΆ
Loss: Cross-entropy between predicted and true answer
L = -Ξ£β log p(aβ | xβ)
Gradient flow: Backpropagate through all hops.
Weight tying:
Adjacent: A^(k+1) = C^k (reduce parameters)
Layer-wise (RNN-like): AΒΉ = AΒ² = β¦ = A^K
2.5 Complexity AnalysisΒΆ
Time complexity per sample:
Attention computation: O(K Β· N Β· d) where K hops, N memory slots, d embedding dim
Memory size: O(N Β· d)
Space complexity:
Memory: O(N Β· d)
Parameters: O(V Β· d) where V vocabulary size
3. Neural Turing Machines (NTM)ΒΆ
3.1 Architecture (Graves et al., 2014)ΒΆ
Motivation: Learn algorithms from data. Turing-complete neural network.
Components:
Controller: LSTM/Feedforward network
Memory matrix: M_t β β^(NΓM) (N slots, M dimensions each)
Read/write heads: Access memory via attention
3.2 Addressing MechanismsΒΆ
Two modes:
Content-based addressing:
w^c_t[i] = exp(Ξ²_t K(k_t, M_t[i])) / Ξ£β±Ό exp(Ξ²_t K(k_t, M_t[j]))
where:
k_t: Key vector (from controller)
K: Similarity measure (cosine similarity)
Ξ²_t: Key strength (sharpness)
Location-based addressing:
Interpolation (gate between content and previous):
w^g_t = g_t w^c_t + (1 - g_t) w_{t-1}
Convolutional shift:
wΜ_t[i] = Ξ£β±Ό w^g_t[j] s_t[(i-j) mod N]
where s_t is shift distribution (e.g., s_t = [0.1, 0.8, 0.1] for left, stay, right).
Sharpening:
w_t[i] = wΜ_t[i]^Ξ³_t / Ξ£β±Ό wΜ_t[j]^Ξ³_t
where Ξ³_t β₯ 1 (sharpening factor).
3.3 Read and Write OperationsΒΆ
Read:
r_t = Ξ£α΅’ w^r_t[i] M_t[i]
Write (erase + add):
Erase:
MΜ_t[i] = M_{t-1}[i] (1 - w^w_t[i] e_t)
where e_t β [0,1]^M is erase vector.
Add:
M_t[i] = MΜ_t[i] + w^w_t[i] a_t
where a_t β β^M is add vector.
Intuition: Erase old content, write new content (like LSTM gates).
3.4 ControllerΒΆ
LSTM controller outputs:
Key vectors: k^r_t, k^w_t
Key strengths: Ξ²^r_t, Ξ²^w_t
Interpolation gates: g^r_t, g^w_t
Shift distributions: s^r_t, s^w_t
Sharpening: Ξ³^r_t, Ξ³^w_t
Erase vector: e_t
Add vector: a_t
Total parameters: Large (many heads Γ many parameters per head).
3.5 TrainingΒΆ
Supervised learning:
Input/output sequences
Backpropagation through time (BPTT)
Gradient clipping essential
Tasks:
Copy, repeat copy, associative recall
Priority sort, dynamic N-grams
Learning simple algorithms
4. Differentiable Neural Computer (DNC)ΒΆ
4.1 Improvements over NTM (Graves et al., 2016)ΒΆ
Key enhancements:
Dynamic memory allocation: Track memory usage, allocate free slots
Temporal links: Maintain write order for sequential access
Better addressing: Content + allocation + temporal
4.2 Memory Usage TrackingΒΆ
Usage vector: u_t β [0,1]^N (how much each slot is used)
Update:
u_t = (u_{t-1} + w^w_t - u_{t-1} ⊙ w^w_t) ⊙ ψ_t
where ψ_t ∈ [0,1]^N is the retention vector (how much of each slot is kept).
Free gates: f^i_t ∈ [0,1] for each read head i
ψ_t = Π_i (1 - f^i_t w^{r,i}_t)
Free list: Sort memory slots by usage (ascending)
φ_t = argsort(u_t)
Allocation weight:
a_t[φ_t[j]] = (1 - u_t[φ_t[j]]) Π_{k<j} u_t[φ_t[k]]
Allocates to least-used available slot.
4.3 Temporal LinksΒΆ
Link matrix: L_t[i,j] = degree to which j was written after i
Update:
L_t[i,j] = (1 - w^w_t[i] - w^w_t[j]) L_{t-1}[i,j] + w^w_t[i] p_{t-1}[j]
where p_t is precedence weighting:
p_t = (1 - Ξ£α΅’ w^w_t[i]) p_{t-1} + w^w_t
Forward/backward weights:
f_t = L_t^T w^r_{t-1} (forward in time)
b_t = L_t w^r_{t-1} (backward in time)
4.4 Final Read WeightΒΆ
Combination of three modes:
w^r_t = Ο^r_t[1] b_t + Ο^r_t[2] c^r_t + Ο^r_t[3] f_t
where Ο^r_t β ΞΒ³ (simplex) determines mode mixture.
Read output:
r_t = M_t^T w^r_t
4.5 Write WeightΒΆ
Allocation + content:
w^w_t = g^w_t (g^a_t a_t + (1 - g^a_t) c^w_t)
where:
g^w_t: Write gate (whether to write)
g^a_t: Allocation gate (new vs content-addressed)
a_t: Allocation weight
c^w_t: Content weight
5. Key-Value Memory NetworksΒΆ
5.1 Architecture (Miller et al., 2016)ΒΆ
Motivation: Separate keys (for addressing) from values (retrieved content).
Memory structure:
Keys: K = {kβ, kβ, β¦, kβ}
Values: V = {vβ, vβ, β¦, vβ}
Retrieval:
Compute attention over keys:
Ξ±_i = softmax(q^T W_K k_i)
Retrieve weighted values:
o = Ξ£α΅’ Ξ±_i W_V v_i
Advantages:
Keys can be short (efficient retrieval)
Values can be rich (informative)
Allows hash-like access
5.2 Knowledge Base as MemoryΒΆ
Key-value pairs from knowledge graph:
Key: Subject-relation embedding
Value: Object embedding
Example:
Key: embed("Paris" + "capital_of")
Value: embed("France")
Query: "What is Paris the capital of?" → matches the key embed("Paris" + "capital_of") and retrieves the value embed("France")
6. Memory Networks for Few-Shot LearningΒΆ
6.1 Matching Networks (Vinyals et al., 2016)ΒΆ
Idea: Memory = support set in few-shot learning.
Architecture:
Ε· = Ξ£α΅’ a(xΜ, xα΅’) yα΅’
where:
(xα΅’, yα΅’): Support examples in memory
xΜ: Query
a: Attention mechanism
Attention:
a(xΜ, xα΅’) = exp(c(f(xΜ), g(xα΅’))) / Ξ£β±Ό exp(c(f(xΜ), g(xβ±Ό)))
Full context embeddings:
f(xΜ): Bi-LSTM with attention over support set
g(xα΅’): Bi-LSTM encoding
6.2 Prototypical Networks as MemoryΒΆ
Memory = class prototypes:
c_k = (1/|S_k|) Ξ£_{(x,y)βS_k} f_ΞΈ(x)
Retrieval = nearest prototype:
p(y=k|x) = softmax(-d(f_ΞΈ(x), c_k))
7. Transformer as Memory NetworkΒΆ
7.1 Self-Attention as Memory AccessΒΆ
Self-attention = content-addressable memory:
Keys, values from input:
K = XW_K, V = XW_V, Q = XW_Q
Attention (retrieval):
Attention(Q,K,V) = softmax(QK^T/βd_k) V
Properties:
Content-based: Attention determined by QΒ·K similarity
Soft retrieval: Weighted combination (vs hard selection)
Differentiable: End-to-end backprop
7.2 Transformers vs Memory NetworksΒΆ
Similarities:
External memory (key-value pairs)
Attention-based retrieval
Multi-hop reasoning (multiple layers)
Differences:
Aspect |
Transformer |
Memory Network |
|---|---|---|
Memory |
Entire sequence |
Explicit slots |
Capacity |
O(N) (sequence length) |
O(N) (slots) |
Updates |
No (read-only) |
Yes (write operations) |
Architecture |
Uniform layers |
Specialized components |
7.3 Memory-Augmented TransformersΒΆ
Combine both:
Compressive Transformer: Compress old activations into memory
β-former: Unbounded long-term memory via retrieval
Memorizing Transformers: kNN over past key-values
8. Advances in Memory MechanismsΒΆ
8.1 Sparse Access PatternsΒΆ
Problem: O(N) attention over all memory expensive.
Solutions:
Product-key memory (PKM):
Key = concat(kβ, kβ) where kβ, kβ from separate codebooks
Retrieve top-k from each codebook
Final candidates = kβ Γ kβ (small)
Complexity: O(βN) instead of O(N)
Locality-sensitive hashing (LSH):
Hash similar keys to same buckets
Only attend within bucket
Complexity: O(N/B) where B buckets
8.2 External Knowledge BasesΒΆ
Retrieval-augmented generation (RAG):
Architecture:
Retriever: Finds relevant documents from corpus
Dense retrieval (DPR): Embed query, retrieve via similarity
Sparse retrieval (BM25): Keyword-based
Generator: LLM generates answer conditioned on retrieved docs
p(y|x) = Ξ£β p(d_k|x) p(y|x, d_k)
Training:
End-to-end: Backprop through retrieval (MIPS approximation)
Pipeline: Train retriever, then generator separately
State-of-the-art:
REALM (2020): Retrieve from millions of documents
RAG (2020): Wikipedia as memory
Atlas (2022): Few-shot learning via retrieval
8.3 Learnable Memory TokensΒΆ
Perceiver, Set Transformers:
Learnable latent vectors:
Z β β^(MΓd) (M << N)
Cross-attention to compress input:
Z = Attention(Q=Z, K=X, V=X)
Benefits:
Fixed-size bottleneck (memory)
O(NΒ·M) complexity instead of O(NΒ²)
Learnable compression
9. Training TechniquesΒΆ
9.1 Curriculum LearningΒΆ
Memory tasks benefit from curriculum:
Length curriculum: Start with short sequences, increase gradually
Hop curriculum: Single-hop first, then multi-hop
Size curriculum: Small memory, then large
Example (MemN2N on bAbI):
1K training: 93% accuracy
10K training: 95.8% accuracy
Curriculum: 96.4% accuracy
9.2 RegularizationΒΆ
Memory networks prone to overfitting:
Dropout on attention weights:
pΜα΅’ = dropout(pα΅’)
pΜα΅’ β pΜα΅’ / Ξ£β±Ό pΜβ±Ό (renormalize)
Weight decay on memory embeddings
Attention entropy regularization:
L_entropy = -Ξ£α΅’ pα΅’ log pα΅’
Encourages diverse attention.
9.3 Pre-trainingΒΆ
Transfer learning for memory networks:
Pre-train embeddings: Word2Vec, GloVe, BERT
Pre-train on related tasks: General QA β specific domain
Multi-task learning: Train on multiple QA datasets jointly
10. Evaluation and BenchmarksΒΆ
10.1 bAbI Tasks (Facebook)ΒΆ
20 synthetic reasoning tasks:
Examples:
Task 1: Single supporting fact
"Mary went to the bathroom. John moved to the hallway. Where is Mary? → bathroom"
Task 2: Two supporting facts
"John is in the playground. John picked up the football. Where is the football? → playground"
Task 3: Three supporting facts (harder)
Metrics:
Per-task accuracy
Mean accuracy across tasks
Failed tasks (< 95% threshold)
Baseline results:
MemN2N: 0.5 failed tasks (almost perfect)
LSTM: 11.5 failed tasks
Human: 0 failed tasks
10.2 Question AnsweringΒΆ
SQuAD (reading comprehension):
Given paragraph + question β extract answer span
Memory network: Store sentences, retrieve relevant ones
Results (SQuAD v1.1):
MemN2N: 77.2% F1
Bidirectional Attention Flow (BiDAF): 84.1% F1
BERT: 93.2% F1 (pre-training helps!)
Natural Questions (Google):
Real Google search queries
Long answers from Wikipedia
Memory networks competitive on retrieving relevant paragraphs
10.3 Few-Shot LearningΒΆ
Omniglot (character recognition):
Matching Networks: 98.1% (5-way 5-shot)
Prototypical Networks: 98.8%
miniImageNet:
Matching Networks: 46.6% (5-way 1-shot)
Memory-augmented networks: Store examples, retrieve similar
11. Architectural VariantsΒΆ
11.1 Recurrent MemoryΒΆ
MANN (Memory-Augmented Neural Network for Meta-Learning):
One-shot learning via external memory:
Read from memory (similar past examples)
Controller processes input + read
Write to memory (store current example)
Key idea: Memory persists across episodes (unlike weights).
11.2 Hierarchical MemoryΒΆ
Multi-scale memory:
Short-term: Recent context (high resolution)
Medium-term: Compressed recent history
Long-term: Semantic summaries
Example: Hierarchical Transformer
Level 1: Token-level (1K tokens)
Level 2: Sentence-level (100 sentences)
Level 3: Paragraph-level (10 paragraphs)
11.3 Associative MemoryΒΆ
Hopfield Networks as memory:
Modern Hopfield layer:
Output = softmax(Ξ²Xα΅ΞΎ)α΅ X
where:
X: Stored patterns
ΞΎ: Query pattern
Ξ²: Inverse temperature
Properties:
Retrieves pattern most similar to query
Exponential storage capacity (vs linear in classical Hopfield)
Integration with Transformers:
Hopfield layer β attention layer
Can replace or augment self-attention
12. Complexity and ScalabilityΒΆ
12.1 Time ComplexityΒΆ
Memory network forward pass:
Operation |
Complexity |
|---|---|
Embedding |
O(L Β· d) where L input length |
Attention (single hop) |
O(N Β· d) where N memory slots |
K hops |
O(K Β· N Β· d) |
Output |
O(V) where V vocab size |
Total: O(LΒ·d + KΒ·NΒ·d + V)
Bottleneck: Attention over large memory (N large).
12.2 Space ComplexityΒΆ
Storage:
Memory: O(N Β· d)
Parameters: O(V Β· d + dΒ²) (embeddings + transforms)
Activations (training):
Per hop: O(N) (attention weights)
K hops: O(K Β· N)
12.3 Scalability SolutionsΒΆ
Large-scale memory (millions of slots):
Approximate nearest neighbors:
FAISS, ScaNN for fast retrieval
O(log N) or O(βN) lookup
Memory pruning:
Keep top-k attended slots
Discard low-attention memories
Hierarchical retrieval:
Cluster memory, retrieve cluster first
Then retrieve within cluster
Distributed memory:
Shard memory across GPUs/machines
All-reduce for aggregation
13. ApplicationsΒΆ
13.1 Question Answering SystemsΒΆ
Conversational QA:
Store conversation history in memory
Retrieve relevant turns for context
Generate contextual response
Multi-hop reasoning:
HotpotQA: Requires reasoning over 2+ paragraphs
Memory network: Store all paragraphs, multi-hop retrieval
13.2 Language ModelingΒΆ
Long-range dependencies:
Standard Transformer: O(NΒ²) attention limits context
Memory-augmented: Extend context via external memory
Example: Transformer-XL
Segment-level recurrence
Cache previous segmentβs hidden states
Attend to current + cached
13.3 Dialog SystemsΒΆ
Persistent memory of facts:
User preferences, past conversations
Retrieve relevant history for personalization
End-to-end dialog (MemN2N for dialog):
Memory = dialog history + knowledge base
Query = current user utterance
Response = generated from retrieved context
13.4 Program SynthesisΒΆ
Neural Turing Machine tasks:
Copy, reverse, sort sequences
Learn algorithms from input-output examples
Results:
NTM learns perfect copy (vs LSTM fails)
Generalizes to longer sequences
13.5 Visual Question AnsweringΒΆ
Image + question β answer:
Memory structure:
Memory slots = image regions (from CNN features)
Query = question embedding
Retrieval = attend to relevant regions
Attention visualization:
Shows which image regions model focuses on
Interpretable decision-making
14. Recent Advances (2020-2024)ΒΆ
14.1 Infinite Memory TransformersΒΆ
β-former (Martins et al., 2022):
Unbounded memory via kNN retrieval
Store all past (key, value) pairs
Retrieve top-k for each query
Memorizing Transformers (Wu et al., 2022):
External memory = all past activations
kNN retrieval from memory
Interpolate between attention and kNN
14.2 Retrieval-Augmented LLMsΒΆ
RETRO (DeepMind, 2022):
Retrieve chunks from trillion-token database
Cross-attend to retrieved chunks
25Γ fewer parameters for same performance
Atlas (Meta, 2022):
Few-shot learning via retrieval
Retrieves from Wikipedia
State-of-the-art on MMLU
14.3 Memory for Long ContextΒΆ
Compressive Transformers:
Compress old memories (activations)
Keep recent memories (full resolution)
Lossy compression for efficiency
Landmark attention:
Designate certain tokens as βlandmarksβ
All tokens attend to landmarks
Reduces memory from O(NΒ²) to O(NΒ·L) where L << N
15. Limitations and ChallengesΒΆ
15.1 Known IssuesΒΆ
Scalability:
O(N) attention expensive for large memory
Trade-off: memory size vs speed
Catastrophic forgetting:
Overwriting important memories
Need memory management strategies
Interpretability:
Attention weights donβt always reflect reasoning
βAttention is not explanationβ debate
Generalization:
May memorize training examples
Doesnβt generalize to truly novel situations
15.2 Open ProblemsΒΆ
Continual learning:
How to update memory without forgetting?
Meta-learning over memory update rules
Compositional reasoning:
Combine retrieved facts compositionally
Neural-symbolic integration
Efficient large-scale memory:
Billions of memory slots
Sub-linear retrieval time
16. Comparison with Human MemoryΒΆ
16.1 ParallelsΒΆ
Working memory:
Limited capacity (~7 items)
MemN2N attention similar (focuses on few slots)
Long-term memory:
Associative retrieval (content-addressable)
Consolidation (important memories strengthened)
Episodic memory:
NTM temporal links β sequential memory
16.2 DifferencesΒΆ
Human advantages:
Highly efficient encoding/retrieval
Rich compositional reasoning
Transfer to novel domains
Neural network advantages:
Perfect recall (no degradation)
Massive parallel access
Explicit gradient-based learning
17. Future DirectionsΒΆ
17.1 Research FrontiersΒΆ
Neurosymbolic memory:
Combine neural retrieval with symbolic reasoning
Logic rules + soft attention
Lifelong learning:
Memory that grows over lifetime
Selective consolidation/forgetting
Multi-modal memory:
Unified memory for text, images, audio, video
Cross-modal retrieval
17.2 Practical DeploymentΒΆ
Efficiency:
Low-rank approximations for attention
Quantization of memory slots
Pruning redundant memories
Privacy:
Federated memory (distributed across users)
Differential privacy for retrieval
18. Key TakeawaysΒΆ
Memory augmentation enables explicit retrieval:
Better than encoding everything in parameters
Crucial for QA, reasoning, few-shot learning
Attention = soft content-addressable memory:
Differentiable retrieval mechanism
Enables end-to-end learning
Multi-hop reasoning requires multiple attention layers:
Each hop refines the query
Combines information from multiple sources
Scalability is key challenge:
O(N) attention limits memory size
Solutions: sparse access, hierarchical retrieval, external KB
Many architectures = variations on memory theme:
Transformers: Self-attention over sequence
MemN2N: Attention over explicit memory slots
NTM/DNC: Read/write operations
Recent trend: Retrieval-augmented LLMs:
External knowledge bases as memory
kNN over trillions of tokens
State-of-the-art few-shot learning
19. ReferencesΒΆ
Foundational papers:
Weston et al. (2015): βMemory Networksβ (ICLR)
Sukhbaatar et al. (2015): βEnd-To-End Memory Networksβ (NeurIPS)
Graves et al. (2014): βNeural Turing Machinesβ (arXiv)
Graves et al. (2016): βHybrid Computing Using a Neural Network with Dynamic External Memoryβ (Nature)
Few-shot learning:
Vinyals et al. (2016): βMatching Networks for One Shot Learningβ (NeurIPS)
Snell et al. (2017): βPrototypical Networks for Few-shot Learningβ (NeurIPS)
Recent advances:
Borgeaud et al. (2022): βImproving Language Models by Retrieving from Trillions of Tokensβ (ICML - RETRO)
Wu et al. (2022): βMemorizing Transformersβ (ICLR)
Izacard et al. (2022): βAtlas: Few-shot Learning with Retrieval Augmented Language Modelsβ (arXiv)
Surveys:
Santoro et al. (2018): βRelational Recurrent Neural Networksβ (NeurIPS)
Graves et al. (2016): βNeural Turing Machinesβ tutorial
"""
Complete Memory Network Implementations
========================================
Includes: End-to-End Memory Networks (MemN2N), Neural Turing Machine (NTM),
Key-Value Memory, Matching Networks, external memory mechanisms.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
# ============================================================================
# 1. End-to-End Memory Network (MemN2N)
# ============================================================================
class PositionEncoding(nn.Module):
    """
    Position encoding for MemN2N (Sukhbaatar et al., 2015).

    Encodes (1-indexed) position j in a sentence of length J with the
    vector l_j whose k-th component (0-indexed) is:

        l_kj = (1 - j/J) - (k/d)(1 - 2j/J)

    The sentence embedding is the position-weighted sum of its word
    embeddings, so word order is not entirely lost (unlike bag-of-words).

    Args:
        sentence_len: Maximum sentence length J
        embed_dim: Embedding dimension d
    """
    def __init__(self, sentence_len, embed_dim):
        super(PositionEncoding, self).__init__()
        self.sentence_len = sentence_len
        self.embed_dim = embed_dim
        # Precompute the (J, d) encoding table with vectorized tensor ops
        # instead of a Python double loop (same values, O(1) Python work).
        j = torch.arange(1, sentence_len + 1, dtype=torch.float32).unsqueeze(1)  # (J, 1)
        k = torch.arange(embed_dim, dtype=torch.float32).unsqueeze(0)            # (1, d)
        encoding = (1 - j / sentence_len) - (k / embed_dim) * (1 - 2 * j / sentence_len)
        self.register_buffer('encoding', encoding)

    def forward(self, x):
        """
        Args:
            x: Embeddings [batch, num_sentences, sentence_len, embed_dim]
        Returns:
            Position-encoded embeddings [batch, num_sentences, embed_dim]
        """
        # Weight each word embedding by its position vector, then sum over
        # the sentence-length dimension.
        sent_len = x.size(2)
        weights = self.encoding[:sent_len, :].unsqueeze(0).unsqueeze(0)
        return (x * weights).sum(dim=2)
class MemN2NLayer(nn.Module):
    """
    One attention hop of an End-to-End Memory Network.

    Given the current query u, computes a softmax attention over the key
    memories m, reads the value memories c with those weights, and returns
    the read vector plus the attention distribution. The hop itself holds
    no parameters; embeddings live in the parent network.

    Args:
        embed_dim: Embedding dimension.
        num_hops: Total number of hops (kept for weight-tying bookkeeping).
        hop_idx: Index of this hop.
    """
    def __init__(self, embed_dim, num_hops, hop_idx):
        super(MemN2NLayer, self).__init__()
        self.embed_dim = embed_dim
        self.hop_idx = hop_idx

    def forward(self, u, m, c):
        """
        Args:
            u: Query embedding [batch, embed_dim].
            m: Key memories [batch, num_sentences, embed_dim].
            c: Value memories [batch, num_sentences, embed_dim].
        Returns:
            read: Attention-weighted read vector [batch, embed_dim].
            attn: Attention weights [batch, num_sentences].
        """
        # Attention logits: dot product of the query with every key memory.
        logits = torch.bmm(m, u.unsqueeze(2)).squeeze(2)   # [batch, num_sentences]
        attn = F.softmax(logits, dim=-1)
        # Read: attention-weighted sum over the value memories.
        read = torch.bmm(attn.unsqueeze(1), c).squeeze(1)  # [batch, embed_dim]
        return read, attn
class EndToEndMemoryNetwork(nn.Module):
    """
    End-to-End Memory Network (Sukhbaatar et al., 2015).

    Multi-hop reasoning over a memory of story sentences with fully
    differentiable softmax attention.

    Args:
        vocab_size: Size of vocabulary.
        embed_dim: Embedding dimension.
        num_hops: Number of computational hops K.
        sentence_len: Maximum sentence length.
        memory_size: Maximum number of memory slots.
        weight_tying: 'adjacent', 'layer_wise', or None.
            'layer_wise' shares a single A/C embedding pair across hops;
            'adjacent' ties A of hop k+1 to C of hop k.
            NOTE: under 'adjacent' tying, self.A[k] for k > 0 are never
            used in forward(); they are kept so the module layout (and
            state_dict keys) stays identical across tying modes.
    """
    def __init__(self, vocab_size, embed_dim, num_hops=3, sentence_len=10,
                 memory_size=50, weight_tying='adjacent'):
        super(EndToEndMemoryNetwork, self).__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_hops = num_hops
        self.sentence_len = sentence_len
        self.memory_size = memory_size
        self.weight_tying = weight_tying
        # Embeddings for each hop
        if weight_tying == 'layer_wise':
            # Share a single A/C pair across all hops.
            self.A = nn.ModuleList([nn.Embedding(vocab_size, embed_dim)])
            self.C = nn.ModuleList([nn.Embedding(vocab_size, embed_dim)])
        else:
            self.A = nn.ModuleList([nn.Embedding(vocab_size, embed_dim)
                                    for _ in range(num_hops)])
            self.C = nn.ModuleList([nn.Embedding(vocab_size, embed_dim)
                                    for _ in range(num_hops)])
        # Query embedding
        self.B = nn.Embedding(vocab_size, embed_dim)
        # Position encoding (position-weighted bag of words)
        self.position_encoding = PositionEncoding(sentence_len, embed_dim)
        # Output layer
        self.W = nn.Linear(embed_dim, vocab_size)
        # Temporal encoding: learned per-slot offsets added to memories.
        self.temporal_A = nn.Parameter(torch.randn(memory_size, embed_dim))
        self.temporal_C = nn.Parameter(torch.randn(memory_size, embed_dim))

    def forward(self, story, query):
        """
        Args:
            story: Story sentences [batch, num_sentences, sentence_len].
            query: Query sentence [batch, query_len].
        Returns:
            logits: Answer logits [batch, vocab_size].
            attentions: Attention weights per hop
                [batch, num_hops, num_sentences].
        """
        num_sentences = story.size(1)
        # Embed query: u = sum_j l_j * B(q_j)
        u = self.B(query)  # [batch, query_len, embed_dim]
        u = self.position_encoding(u.unsqueeze(1)).squeeze(1)  # [batch, embed_dim]
        attentions = []
        # Multi-hop reasoning
        for k in range(self.num_hops):
            # Select embeddings for this hop according to the tying scheme.
            if self.weight_tying == 'layer_wise':
                A_k, C_k = self.A[0], self.C[0]
            elif self.weight_tying == 'adjacent' and k > 0:
                A_k = self.C[k - 1]  # adjacent tying: A^(k+1) = C^(k)
                C_k = self.C[k]
            else:
                A_k, C_k = self.A[k], self.C[k]
            # Key memories: m_i = sum_j l_j * A(x_ij) + T_A(i)
            m = self.position_encoding(A_k(story)) + self.temporal_A[:num_sentences, :]
            # Value memories: c_i = sum_j l_j * C(x_ij) + T_C(i)
            c = self.position_encoding(C_k(story)) + self.temporal_C[:num_sentences, :]
            # Attention p_i = softmax(u^T m_i), computed inline.
            # (The original constructed a throwaway MemN2NLayer nn.Module
            # on every hop of every forward pass, which defeats module
            # registration and allocates needlessly; the hop is
            # parameter-free, so inlining is behavior-identical.)
            p = F.softmax(torch.matmul(m, u.unsqueeze(-1)).squeeze(-1), dim=-1)
            attentions.append(p)
            # Read o = sum_i p_i c_i and update the query (residual).
            o = torch.matmul(p.unsqueeze(1), c).squeeze(1)
            u = u + o
        # Final prediction from the last query state.
        logits = self.W(u)
        attentions = torch.stack(attentions, dim=1)  # [batch, num_hops, num_sentences]
        return logits, attentions
# ============================================================================
# 2. Neural Turing Machine (NTM) Components
# ============================================================================
class NTMMemory(nn.Module):
    """
    External memory matrix for the Neural Turing Machine.

    Holds a (batched) memory tensor and provides attention-weighted read
    and erase-then-add write primitives.

    Args:
        memory_size: Number of memory slots N.
        memory_dim: Width of each slot M.
    """
    def __init__(self, memory_size, memory_dim):
        super(NTMMemory, self).__init__()
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        # Buffer (not a parameter): moves with .to(device), not trained.
        self.register_buffer('memory', torch.zeros(memory_size, memory_dim))

    def reset(self, batch_size):
        """Zero the memory and give it a leading batch dimension."""
        self.memory = torch.zeros(
            batch_size, self.memory_size, self.memory_dim,
            device=self.memory.device,
        )

    def read(self, w):
        """
        Attention-weighted read.

        Args:
            w: Read weights [batch, memory_size].
        Returns:
            Read vector [batch, memory_dim].
        """
        return torch.matmul(w.unsqueeze(1), self.memory).squeeze(1)

    def write(self, w, e, a):
        """
        Erase-then-add write (Graves et al., 2014).

        Args:
            w: Write weights [batch, memory_size].
            e: Erase vector [batch, memory_dim].
            a: Add vector [batch, memory_dim].
        """
        # Outer products via broadcasting: [batch, N, 1] * [batch, 1, M].
        erase = w.unsqueeze(-1) * e.unsqueeze(1)
        self.memory = self.memory * (1.0 - erase)
        self.memory = self.memory + w.unsqueeze(-1) * a.unsqueeze(1)
class NTMHead(nn.Module):
    """
    Read/Write head for the Neural Turing Machine.

    Combines content-based addressing (cosine similarity to a key) with
    location-based addressing (interpolation, circular shift, sharpening).

    Args:
        memory_size: Number of memory slots N.
        memory_dim: Dimension of each slot M.
        controller_dim: Dimension of the controller output.
    """
    def __init__(self, memory_size, memory_dim, controller_dim):
        super(NTMHead, self).__init__()
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        # Content addressing
        self.key_layer = nn.Linear(controller_dim, memory_dim)
        self.beta_layer = nn.Linear(controller_dim, 1)  # Key strength
        # Location addressing
        self.g_layer = nn.Linear(controller_dim, 1)      # Interpolation gate
        self.shift_layer = nn.Linear(controller_dim, 3)  # Shift distribution (left, stay, right)
        self.gamma_layer = nn.Linear(controller_dim, 1)  # Sharpening
        # Write-specific
        self.erase_layer = nn.Linear(controller_dim, memory_dim)
        self.add_layer = nn.Linear(controller_dim, memory_dim)

    def content_addressing(self, k, beta, memory):
        """
        Content-based addressing via cosine similarity.

        Args:
            k: Key vector [batch, memory_dim].
            beta: Key strength [batch, 1].
            memory: Memory matrix [batch, memory_size, memory_dim].
        Returns:
            Content weights [batch, memory_size].
        """
        # Normalize key and memory rows; eps guards against zero vectors.
        k_norm = k / (k.norm(dim=1, keepdim=True) + 1e-8)
        mem_norm = memory / (memory.norm(dim=2, keepdim=True) + 1e-8)
        similarity = torch.matmul(mem_norm, k_norm.unsqueeze(-1)).squeeze(-1)
        # Sharpen similarities by beta before the softmax.
        w_c = F.softmax(beta * similarity, dim=1)
        return w_c

    def location_addressing(self, w_c, w_prev, g, s, gamma):
        """
        Location-based addressing (interpolate + shift + sharpen).

        Args:
            w_c: Content weights [batch, memory_size].
            w_prev: Previous weights [batch, memory_size].
            g: Interpolation gate logit [batch, 1].
            s: Shift distribution [batch, 3].
            gamma: Sharpening logit [batch, 1].
        Returns:
            Final weights [batch, memory_size].
        """
        # Interpolate between content weights and last step's weights.
        g = torch.sigmoid(g)
        w_g = g * w_c + (1 - g) * w_prev
        # Convolutional (circular) shift by the 3-tap kernel s.
        w_shifted = self._circular_conv(w_g, s)
        # Sharpen (gamma >= 1) and renormalize.
        gamma = 1 + F.softplus(gamma)
        w = w_shifted ** gamma
        w = w / (w.sum(dim=1, keepdim=True) + 1e-8)
        return w

    def _circular_conv(self, w, s):
        """
        Circular convolution of weights with a 3-element shift kernel:

            shifted[b, i] = sum_j w[b, (i + j - 1) % N] * s[b, j]

        Vectorized with torch.roll; the original triple Python loop cost
        O(batch * N * 3) interpreter iterations per step.
        """
        return (torch.roll(w, 1, dims=1) * s[:, 0:1]      # j=0: pull from slot i-1
                + w * s[:, 1:2]                            # j=1: stay in place
                + torch.roll(w, -1, dims=1) * s[:, 2:3])   # j=2: pull from slot i+1

    def forward(self, h, memory, w_prev):
        """
        Compute this head's attention weights for the current step.

        Args:
            h: Controller output [batch, controller_dim].
            memory: Memory matrix [batch, memory_size, memory_dim].
            w_prev: Previous weights [batch, memory_size].
        Returns:
            w: Attention weights [batch, memory_size].
        """
        # Content addressing
        k = torch.tanh(self.key_layer(h))
        beta = F.softplus(self.beta_layer(h))
        w_c = self.content_addressing(k, beta, memory)
        # Location addressing
        g = self.g_layer(h)
        s = F.softmax(self.shift_layer(h), dim=1)
        gamma = self.gamma_layer(h)
        w = self.location_addressing(w_c, w_prev, g, s, gamma)
        return w
class NeuralTuringMachine(nn.Module):
    """
    Neural Turing Machine (Graves et al., 2014).

    An LSTM controller coupled to an external memory matrix via
    attention-based read and write heads.

    Args:
        input_dim: Input dimension.
        output_dim: Output dimension.
        controller_dim: LSTM hidden dimension.
        memory_size: Number of memory slots.
        memory_dim: Dimension of each slot.
        num_read_heads: Number of read heads.
        num_write_heads: Number of write heads.
    """
    def __init__(self, input_dim, output_dim, controller_dim=100,
                 memory_size=128, memory_dim=20, num_read_heads=1, num_write_heads=1):
        super(NeuralTuringMachine, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.controller_dim = controller_dim
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.num_read_heads = num_read_heads
        self.num_write_heads = num_write_heads
        # Controller consumes the input plus last step's read vectors.
        self.controller = nn.LSTMCell(input_dim + num_read_heads * memory_dim,
                                      controller_dim)
        # External memory
        self.memory = NTMMemory(memory_size, memory_dim)
        # Read heads
        self.read_heads = nn.ModuleList([
            NTMHead(memory_size, memory_dim, controller_dim)
            for _ in range(num_read_heads)
        ])
        # Write heads
        self.write_heads = nn.ModuleList([
            NTMHead(memory_size, memory_dim, controller_dim)
            for _ in range(num_write_heads)
        ])
        # Output combines controller state with the fresh reads.
        self.output_layer = nn.Linear(controller_dim + num_read_heads * memory_dim,
                                      output_dim)

    def forward(self, x, prev_state=None):
        """
        Process a single timestep.

        Args:
            x: Input [batch, input_dim].
            prev_state: Tuple (hidden, cell, read_vecs, read_w, write_w)
                from the previous step, or None to begin a sequence
                (this also zeroes the external memory).
        Returns:
            output: Output [batch, output_dim].
            state: New state tuple.
        """
        batch = x.size(0)
        if prev_state is None:
            dev = x.device
            hidden = torch.zeros(batch, self.controller_dim, device=dev)
            cell = torch.zeros(batch, self.controller_dim, device=dev)
            read_vecs = [torch.zeros(batch, self.memory_dim, device=dev)
                         for _ in range(self.num_read_heads)]
            read_w = [torch.zeros(batch, self.memory_size, device=dev)
                      for _ in range(self.num_read_heads)]
            write_w = [torch.zeros(batch, self.memory_size, device=dev)
                       for _ in range(self.num_write_heads)]
            self.memory.reset(batch)
        else:
            hidden, cell, read_vecs, read_w, write_w = prev_state
        # Controller step on the input concatenated with previous reads.
        controller_in = torch.cat([x] + read_vecs, dim=1)
        hidden, cell = self.controller(controller_in, (hidden, cell))
        # Read phase: address then read (reads never mutate memory, so
        # splitting addressing from reading is behavior-identical).
        new_read_w = [head(hidden, self.memory.memory, prev)
                      for head, prev in zip(self.read_heads, read_w)]
        new_read_vecs = [self.memory.read(w) for w in new_read_w]
        # Write phase: sequential so later heads see earlier writes.
        new_write_w = []
        for head, prev in zip(self.write_heads, write_w):
            w = head(hidden, self.memory.memory, prev)
            self.memory.write(w,
                              torch.sigmoid(head.erase_layer(hidden)),
                              torch.tanh(head.add_layer(hidden)))
            new_write_w.append(w)
        # Output from the controller state plus the new reads.
        output = self.output_layer(torch.cat([hidden] + new_read_vecs, dim=1))
        return output, (hidden, cell, new_read_vecs, new_read_w, new_write_w)
# ============================================================================
# 3. Key-Value Memory Network
# ============================================================================
class KeyValueMemoryNetwork(nn.Module):
    """
    Key-Value Memory Network (Miller et al., 2016).

    Addresses memory through keys while retrieving separate values,
    decoupling what is matched from what is read.

    Args:
        key_dim: Dimension of keys.
        value_dim: Dimension of values.
        query_dim: Dimension of queries.
        num_hops: Number of hops.
    """
    def __init__(self, key_dim, value_dim, query_dim, num_hops=1):
        super(KeyValueMemoryNetwork, self).__init__()
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.query_dim = query_dim
        self.num_hops = num_hops
        self.W_K = nn.Linear(query_dim, key_dim)    # project query into key space
        self.W_V = nn.Linear(value_dim, query_dim)  # project value into query space
        self.W_O = nn.Linear(query_dim, query_dim)  # final output projection

    def forward(self, query, keys, values):
        """
        Args:
            query: Query vector [batch, query_dim].
            keys: Key matrix [batch, num_keys, key_dim].
            values: Value matrix [batch, num_values, value_dim].
        Returns:
            output: Retrieved information [batch, query_dim].
            attention: Attention weights from the last hop [batch, num_keys].
        """
        state = query
        for _ in range(self.num_hops):
            # Match the (projected) query against the keys.
            projected = self.W_K(state)  # [batch, key_dim]
            logits = torch.bmm(keys, projected.unsqueeze(2)).squeeze(2)  # [batch, num_keys]
            attention = F.softmax(logits, dim=-1)
            # Read the values with those weights, map back, add residually.
            read = torch.bmm(attention.unsqueeze(1), values).squeeze(1)  # [batch, value_dim]
            state = state + self.W_V(read)
        output = self.W_O(state)
        return output, attention
# ============================================================================
# 4. Matching Networks (for Few-Shot Learning)
# ============================================================================
class MatchingNetwork(nn.Module):
    """
    Matching Network (Vinyals et al., 2016).

    Treats the support set as memory: a query is classified by an
    attention-weighted vote over the support labels.

    Args:
        input_dim: Input feature dimension.
        hidden_dim: Hidden dimension.
        num_classes: Number of classes (for output).
    """
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(MatchingNetwork, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        # Shared feature encoder for queries and support examples.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def attention(self, query, support):
        """
        Cosine-similarity attention over the support set.

        Args:
            query: Query embedding [batch, hidden_dim].
            support: Support embeddings [batch, num_support, hidden_dim].
        Returns:
            Attention weights [batch, num_support].
        """
        eps = 1e-8  # guard against zero-norm embeddings
        q_hat = query / (query.norm(dim=1, keepdim=True) + eps)
        s_hat = support / (support.norm(dim=2, keepdim=True) + eps)
        sims = torch.bmm(s_hat, q_hat.unsqueeze(2)).squeeze(2)
        return F.softmax(sims, dim=-1)

    def forward(self, query, support, support_labels):
        """
        Args:
            query: Query examples [batch, input_dim].
            support: Support examples [batch, num_support, input_dim].
            support_labels: Support labels [batch, num_support] (class indices).
        Returns:
            predictions: Class probabilities [batch, num_classes].
        """
        # Encode queries and (flattened) support examples with one encoder.
        q_emb = self.encoder(query)  # [batch, hidden_dim]
        flat = support.reshape(-1, self.input_dim)
        s_emb = self.encoder(flat).view(support.size(0), support.size(1), -1)
        # Attend over the support set.
        weights = self.attention(q_emb, s_emb)  # [batch, num_support]
        # Attention-weighted vote over one-hot support labels.
        one_hot = F.one_hot(support_labels, self.num_classes).float()
        return torch.bmm(weights.unsqueeze(1), one_hot).squeeze(1)
# ============================================================================
# 5. Demonstrations
# ============================================================================
def demo_memn2n():
    """Demonstrate End-to-End Memory Network."""
    banner = "=" * 70
    print(banner)
    print("End-to-End Memory Network (MemN2N) Demo")
    print(banner)
    # Hyperparameters
    cfg = dict(vocab_size=100, embed_dim=32, num_hops=3,
               sentence_len=10, memory_size=20, weight_tying='adjacent')
    batch_size = 4
    # Build the model
    model = EndToEndMemoryNetwork(**cfg)
    # Toy bAbI-like data
    story = torch.randint(0, cfg['vocab_size'],
                          (batch_size, cfg['memory_size'], cfg['sentence_len']))
    query = torch.randint(0, cfg['vocab_size'], (batch_size, cfg['sentence_len']))
    # Forward pass (no gradients needed for a demo)
    model.eval()
    with torch.no_grad():
        logits, attentions = model(story, query)
    print(f"Input:")
    print(f" Story shape: {story.shape} (batch, num_sentences, sentence_len)")
    print(f" Query shape: {query.shape} (batch, query_len)")
    print()
    print(f"Output:")
    print(f" Logits shape: {logits.shape} (batch, vocab_size)")
    print(f" Attentions shape: {attentions.shape} (batch, num_hops, num_sentences)")
    print()
    print(f"Attention visualization (first sample, first hop):")
    print(f" {attentions[0, 0, :10].numpy()}")
    print(f" Sum: {attentions[0, 0].sum().item():.4f} (should be ~1.0)")
    print()
    # Parameter count
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {num_params:,}")
    print()
def demo_ntm():
    """Demonstrate Neural Turing Machine."""
    banner = "=" * 70
    print(banner)
    print("Neural Turing Machine (NTM) Demo")
    print(banner)
    # Copy-task hyperparameters
    input_dim = 8
    output_dim = 8
    controller_dim = 100
    memory_size = 128
    memory_dim = 20
    batch_size = 2
    seq_len = 5
    # Build the model
    model = NeuralTuringMachine(input_dim=input_dim,
                                output_dim=output_dim,
                                controller_dim=controller_dim,
                                memory_size=memory_size,
                                memory_dim=memory_dim,
                                num_read_heads=1,
                                num_write_heads=1)
    # Random binary sequences (copy-task style input)
    x = torch.randint(0, 2, (batch_size, seq_len, input_dim)).float()
    # Step through the sequence, threading the recurrent state.
    model.eval()
    state = None
    step_outputs = []
    with torch.no_grad():
        for t in range(seq_len):
            out, state = model(x[:, t, :], state)
            step_outputs.append(out)
    outputs = torch.stack(step_outputs, dim=1)  # [batch, seq_len, output_dim]
    print(f"Task: Copy binary sequences")
    print(f" Input shape: {x.shape}")
    print(f" Output shape: {outputs.shape}")
    print()
    print(f"Architecture:")
    print(f" Controller: LSTM with {controller_dim} units")
    print(f" Memory: {memory_size} slots Γ {memory_dim} dimensions")
    print(f" Read heads: 1, Write heads: 1")
    print()
    # Parameter count
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {num_params:,}")
    print()
    print("Note: NTM can learn to copy sequences of arbitrary length!")
    print()
def demo_kv_memory():
    """Demonstrate Key-Value Memory Network."""
    banner = "=" * 70
    print(banner)
    print("Key-Value Memory Network Demo")
    print(banner)
    # Hyperparameters
    key_dim = 64
    value_dim = 128
    query_dim = 64
    num_hops = 2
    batch_size = 4
    num_keys = 10
    # Build the model
    kv_net = KeyValueMemoryNetwork(key_dim=key_dim,
                                   value_dim=value_dim,
                                   query_dim=query_dim,
                                   num_hops=num_hops)
    # Random query / key / value tensors
    query = torch.randn(batch_size, query_dim)
    keys = torch.randn(batch_size, num_keys, key_dim)
    values = torch.randn(batch_size, num_keys, value_dim)
    # Forward pass
    kv_net.eval()
    with torch.no_grad():
        output, attention = kv_net(query, keys, values)
    print(f"Input:")
    print(f" Query shape: {query.shape}")
    print(f" Keys shape: {keys.shape}")
    print(f" Values shape: {values.shape}")
    print()
    print(f"Output:")
    print(f" Retrieved shape: {output.shape}")
    print(f" Attention shape: {attention.shape}")
    print()
    print(f"Attention weights (first sample):")
    print(f" {attention[0].numpy()}")
    print(f" Max attention: {attention[0].max().item():.4f}")
    print()
    print("Use case: Knowledge base QA (keys=entities, values=descriptions)")
    print()
def demo_matching_network():
    """Demonstrate Matching Network (few-shot learning)."""
    banner = "=" * 70
    print(banner)
    print("Matching Network Demo (Few-Shot Learning)")
    print(banner)
    # 5-way 1-shot episode
    input_dim = 128
    hidden_dim = 64
    num_classes = 5
    num_support = 5  # one example per class
    batch_size = 8
    # Build the model
    net = MatchingNetwork(input_dim=input_dim,
                          hidden_dim=hidden_dim,
                          num_classes=num_classes)
    # Random features (in practice these would come from a CNN)
    query = torch.randn(batch_size, input_dim)
    support = torch.randn(batch_size, num_support, input_dim)
    support_labels = torch.arange(num_classes).unsqueeze(0).expand(batch_size, -1)
    # Forward pass
    net.eval()
    with torch.no_grad():
        predictions = net(query, support, support_labels)
    print(f"Task: 5-way 1-shot classification")
    print(f" Query shape: {query.shape}")
    print(f" Support shape: {support.shape}")
    print(f" Support labels: {support_labels[0].tolist()}")
    print()
    print(f"Predictions shape: {predictions.shape} (batch, num_classes)")
    print(f"Sample prediction (first query): {predictions[0].numpy()}")
    print(f"Predicted class: {predictions[0].argmax().item()}")
    print()
    print("Memory = Support Set | Retrieval = Attention-weighted classification")
    print()
def print_complexity_comparison():
    """Print complexity comparison of memory architectures."""
    print("="*70)
    print("Computational Complexity Comparison")
    print("="*70)
    print()
    # Static reference table, printed verbatim. Legend: N = memory slots,
    # d/M/K/V = vector dimensions, C = controller width, K (MemN2N) = hops.
    # Figures are order-of-magnitude guidance, not benchmarks.
    comparison = """
ββββββββββββββββββ¬βββββββββββββββββββ¬βββββββββββββββββββ¬ββββββββββββββββββ
β Architecture β Time (per step) β Space (memory) β Scalability β
ββββββββββββββββββΌβββββββββββββββββββΌβββββββββββββββββββΌββββββββββββββββββ€
β MemN2N β O(KΒ·NΒ·d) β O(NΒ·d) β Medium β
β β K hops, N slots β β (hundreds) β
ββββββββββββββββββΌβββββββββββββββββββΌβββββββββββββββββββΌββββββββββββββββββ€
β NTM β O(NΒ·M + CΒ²) β O(NΒ·M) β Low β
β β N slots, M dims β β (128 slots) β
β β C controller β β β
ββββββββββββββββββΌβββββββββββββββββββΌβββββββββββββββββββΌββββββββββββββββββ€
β DNC β O(NΒ·M + CΒ²) β O(NΒ·M + NΒ²) β Medium β
β β + temporal links β (link matrix) β (256 slots) β
ββββββββββββββββββΌβββββββββββββββββββΌβββββββββββββββββββΌββββββββββββββββββ€
β Key-Value β O(NΒ·K + NΒ·V) β O(NΒ·(K+V)) β High β
β β K,V key/val dims β β (thousands) β
ββββββββββββββββββΌβββββββββββββββββββΌβββββββββββββββββββΌββββββββββββββββββ€
β Transformer β O(NΒ²Β·d) β O(NΒ·d) β High β
β β Self-attention β β (with tricks) β
ββββββββββββββββββ΄βββββββββββββββββββ΄βββββββββββββββββββ΄ββββββββββββββββββ
**Scalability to Large Memory:**
1. **MemN2N:** O(N) attention feasible up to ~1000 slots
2. **NTM/DNC:** Complex addressing limits to ~256 slots
3. **Key-Value:** Efficient retrieval, scales to 10K+ slots
4. **Sparse Retrieval (kNN):** Scales to millions (FAISS, ScaNN)
"""
    print(comparison)
    print()
def print_method_selection_guide():
    """Print guide for selecting memory architecture."""
    print("="*70)
    print("Memory Network Selection Guide")
    print("="*70)
    print()
    # Static decision guide, printed verbatim. Accuracy figures are the
    # published numbers from the respective papers.
    guide = """
**Decision Tree:**
1. **Task: Question Answering over text**
β Use MemN2N
- Multi-hop reasoning over facts
- Interpretable attention
- Example: bAbI tasks
2. **Task: Algorithm learning (copy, sort, etc.)**
β Use NTM or DNC
- Explicit read/write operations
- Location-based addressing
- Example: Copy sequences, algorithmic tasks
3. **Task: Few-shot learning**
β Use Matching Networks or Prototypical Networks
- Memory = support examples
- Attention-based retrieval
- Example: Omniglot, miniImageNet
4. **Task: Knowledge base QA**
β Use Key-Value Memory
- Keys for retrieval, values for content
- Efficient for large KBs
- Example: WikiMovies, entity QA
5. **Task: Long-range dependencies in sequences**
β Use Memory-Augmented Transformer
- Combine self-attention + external memory
- Example: Long documents, code
6. **Task: Retrieval-augmented generation**
β Use RAG or RETRO
- External corpus as memory
- Dense retrieval + generation
- Example: Open-domain QA
**Performance Expectations:**
bAbI Tasks (Mean accuracy):
- MemN2N: 96.4% (state-of-the-art for memory networks)
- LSTM: 51.2%
- Human: ~100%
Algorithm Learning (Copy task):
- NTM: 100% (perfect generalization)
- LSTM: ~60% (fails on longer sequences)
Few-Shot (Omniglot 5-way 1-shot):
- Matching Networks: 98.1%
- Prototypical Networks: 98.8%
"""
    print(guide)
    print()
# ============================================================================
# Run All Demonstrations
# ============================================================================
if __name__ == "__main__":
    # Fixed seed so demo output is reproducible run-to-run.
    torch.manual_seed(42)
    # Run each architecture demo in turn.
    demo_memn2n()
    demo_ntm()
    demo_kv_memory()
    demo_matching_network()
    # Reference tables and selection guidance.
    print_complexity_comparison()
    print_method_selection_guide()
    print("="*70)
    print("Memory Network Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    print(" β’ MemN2N: Multi-hop reasoning via attention (QA tasks)")
    print(" β’ NTM: Learn algorithms via read/write operations")
    print(" β’ Key-Value: Efficient retrieval from knowledge bases")
    print(" β’ Matching: Few-shot learning via memory = support set")
    print()
    print("Core principle: Explicit external memory > implicit parameters")
    print("Key insight: Attention = differentiable content-addressable memory")
    print()
Advanced Memory Networks: Mathematical Foundations and Modern ArchitecturesΒΆ
1. Introduction to Memory-Augmented Neural NetworksΒΆ
Memory Networks extend neural networks with explicit external memory, enabling them to store and retrieve information over long time horizons, unlike LSTMs which compress history into fixed-size hidden states.
Core motivation: Standard RNNs/LSTMs have limited capacity to remember facts and reasoning chains. Memory networks separate:
Computation: Neural network processing
Storage: External memory bank
Key innovation (Weston et al., 2015): Memory as an array of vectors $\(\mathcal{M} = [m_1, m_2, \ldots, m_N] \in \mathbb{R}^{N \times d}\)$
Access mechanism: Soft attention over memory slots $\(\mathbf{o} = \sum_{i=1}^N \alpha_i m_i, \quad \alpha_i = \frac{\exp(s(q, m_i))}{\sum_j \exp(s(q, m_j))}\)$
where:
\(q\): Query vector (e.g., question embedding)
\(s(q, m_i)\): Similarity function (dot product, cosine, MLP)
\(\mathbf{o}\): Output (weighted memory)
Advantages:
Scalability: Memory size \(N\) independent of network parameters
Interpretability: Can inspect which memories are accessed
Explicit reasoning: Multi-hop attention for complex reasoning
Long-term storage: Information persists beyond gradient flow
2. Memory Network ArchitectureΒΆ
2.1 Four Components (I, G, O, R)ΒΆ
I (Input): Convert input \(x\) to internal representation $\(\mathbf{i} = I(x) = \text{Embedding}(x)\)$
G (Generalization): Update memory given new input $\(m_i \leftarrow G(m_i, \mathbf{i}, m) = \begin{cases} \mathbf{i} & \text{if slot } i \text{ selected} \\ m_i & \text{otherwise} \end{cases}\)$
O (Output): Retrieve relevant memories given query \(q\) $\(\mathbf{o} = O(q, m) = \sum_{i=1}^N \text{softmax}(s(q, m_i)) \cdot m_i\)$
R (Response): Generate final output $\(\text{answer} = R(\mathbf{o}, q) = \text{Decoder}(\mathbf{o}, q)\)$
2.2 End-to-End Memory Networks (MemN2N)ΒΆ
Motivation: Make memory networks fully differentiable (remove hard max in original).
Multiple hops: Perform \(K\) reasoning steps $\(q^{(k+1)} = q^{(k)} + \mathbf{o}^{(k)}, \quad k = 0, 1, \ldots, K-1\)$
Hop \(k\) attention: \(p_i^{(k)} = \text{softmax}\big((q^{(k)})^T m_i^{(k)}\big)\), with read vector \(\mathbf{o}^{(k)} = \sum_i p_i^{(k)} c_i^{(k)}\)
where:
\(m_i^{(k)}\): Key memory at hop \(k\)
\(c_i^{(k)}\): Value memory at hop \(k\)
\(q^{(k)}\): Query at hop \(k\)
Final answer: \(\hat{a} = \text{softmax}(W q^{(K)})\), where \(q^{(K)} = q^{(K-1)} + \mathbf{o}^{(K-1)}\) already incorporates the last read.
Embedding schemes:
Adjacent: \(A^{(k+1)} = C^{(k)}\) (parameter sharing across hops)
Layer-wise (RNN-like): \(A^{(k)} = A^{(1)}\), \(C^{(k)} = C^{(1)}\) — a single embedding pair shared by every hop (this is the `layer_wise` option in the code)
Untied: Each hop has its own independent \(A^{(k)}, C^{(k)}\) embeddings (most parameters)
2.3 Position EncodingΒΆ
Motivation: Memory slots are unordered, but temporal/positional info matters.
Temporal encoding: \(m_i = \sum_j l_j \odot w_j\) where \(l_j\) is position encoding
Sinusoidal (from Transformers): $\(l_j^{(d)} = \begin{cases} \sin(j / 10000^{d/D}) & \text{if } d \text{ even} \\ \cos(j / 10000^{d/D}) & \text{if } d \text{ odd} \end{cases}\)$
Learned: Position embeddings as parameters.
3. Attention Mechanisms in MemoryΒΆ
3.1 Content-Based AddressingΒΆ
Similarity function: Dot product (scaled) $\(s(q, m_i) = \frac{q^T m_i}{\sqrt{d}}\)$
Softmax attention: $\(\alpha_i = \frac{\exp(s(q, m_i))}{\sum_j \exp(s(q, m_j))}\)$
Variants:
Cosine similarity: \(s(q, m_i) = \frac{q^T m_i}{\|q\| \|m_i\|}\)
Euclidean distance: \(s(q, m_i) = -\|q - m_i\|^2\)
MLP: \(s(q, m_i) = v^T \tanh(W[q; m_i])\)
3.2 Multi-Head AttentionΒΆ
Parallel attention heads: Each head attends to different aspects $\(\text{Head}_h = \text{Attention}(Q W_h^Q, K W_h^K, V W_h^V)\)$
Concatenate and project: $\(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{Head}_1, \ldots, \text{Head}_H) W^O\)$
Benefits:
Attend to multiple memories simultaneously
Richer representations
Better for complex reasoning
3.3 Sparse AttentionΒΆ
Problem: Dense attention over all \(N\) memories is \(O(N)\).
Top-k attention: Select \(k \ll N\) highest-scoring memories $\(\alpha_i = \begin{cases} \frac{\exp(s(q, m_i))}{\sum_{j \in \text{top-k}} \exp(s(q, m_j))} & \text{if } i \in \text{top-k} \\ 0 & \text{otherwise} \end{cases}\)$
Locality-sensitive hashing (LSH): Approximate nearest neighbors
Hash query and keys to buckets
Attend only within same bucket
\(O(\log N)\) or \(O(\sqrt{N})\) complexity
4. Dynamic Memory Networks (DMN)ΒΆ
Architecture (Kumar et al., 2016): Episodic memory with attention-based reasoning.
4.1 Input ModuleΒΆ
Sentence encoding: BiGRU or Transformer over sentences $\(h_i = \text{BiGRU}(w_{i,1}, \ldots, w_{i,T_i})\)$
Input memory: \(\{h_1, \ldots, h_N\}\) (one per sentence/fact).
4.2 Question ModuleΒΆ
Question encoding: $\(q = \text{GRU}(w_{q,1}, \ldots, w_{q,T_q})\)$
4.3 Episodic Memory ModuleΒΆ
Attention mechanism: Compute relevance of each fact to question $\(g_i^t = \text{Attention}(h_i, q, m^{t-1})\)$
Attention gate: Context vector for episode \(t\) $\(c^t = \sum_i g_i^t h_i\)$
Memory update: GRU over context $\(m^t = \text{GRU}(c^t, m^{t-1})\)$
Multiple episodes: Iterate \(T\) times (multi-hop reasoning).
4.4 Answer ModuleΒΆ
Decoder: GRU initialized with final memory $\(a = \text{Decoder}(m^T, q)\)$
Output: Word-by-word generation or classification.
4.5 Attention Gate DetailsΒΆ
Feature vector: $\(z_i^t = [h_i; q; m^{t-1}; h_i \odot q; h_i \odot m^{t-1}; |h_i - q|; |h_i - m^{t-1}|]\)$
Two-layer network: $\(g_i^t = \text{sigmoid}(W^{(2)} \tanh(W^{(1)} z_i^t + b^{(1)}) + b^{(2)})\)$
Softmax normalization (optional): $\(g_i^t = \frac{\exp(g_i^t)}{\sum_j \exp(g_j^t)}\)$
5. Neural Turing Machines (NTM)ΒΆ
Graves et al. (2014): Differentiable Turing Machine with read/write heads.
5.1 ArchitectureΒΆ
Controller: LSTM or feedforward network $\(h_t = \text{Controller}(x_t, h_{t-1})\)$
Memory: \(M_t \in \mathbb{R}^{N \times M}\) (rows = addresses, columns = content)
Read head: Weighted read from memory $\(\mathbf{r}_t = \sum_{i=1}^N w_t^r(i) M_t(i)\)$
Write head: Erase + add $\(M_t(i) = M_{t-1}(i) \cdot (\mathbf{1} - w_t^w(i) \mathbf{e}_t) + w_t^w(i) \mathbf{a}_t\)$
where:
\(w_t^w\): Write attention weights
\(\mathbf{e}_t\): Erase vector (what to remove)
\(\mathbf{a}_t\): Add vector (what to write)
5.2 Addressing MechanismsΒΆ
Content-based addressing: Similarity to key \(\mathbf{k}_t\) $\(w_t^c(i) \propto \exp\left(\beta_t \cdot \text{cosine}(\mathbf{k}_t, M_t(i))\right)\)$
Interpolation: Blend content and previous weights $\(w_t^g = g_t w_t^c + (1 - g_t) w_{t-1}\)$
where \(g_t \in [0, 1]\) is interpolation gate.
Shift: Convolutional shift (move attention) $\(\tilde{w}_t(i) = \sum_j w_t^g(j) s_t(i - j)\)$
where \(s_t\) is shift distribution (e.g., shift left/right).
Sharpening: Focus attention $\(w_t(i) = \frac{\tilde{w}_t(i)^{\gamma_t}}{\sum_j \tilde{w}_t(j)^{\gamma_t}}\)$
where \(\gamma_t \geq 1\) is sharpening factor.
5.3 TrainingΒΆ
Backpropagation through time: Gradients flow through memory and attention.
Challenges:
Vanishing/exploding gradients (long sequences)
Sensitive to initialization
Slow convergence
Solutions:
Gradient clipping
Curriculum learning (start with short sequences)
Pre-training
6. Differentiable Neural Computer (DNC)ΒΆ
DeepMind (2016): Evolution of NTM with more sophisticated memory management.
6.1 Enhancements Over NTMΒΆ
Temporal links: Track write order $\(L_t[i, j] = \text{link from slot } i \text{ to } j\)$
Usage vector: Track which slots are occupied $\(u_t(i) = (u_{t-1}(i) + w_t^w(i) - u_{t-1}(i) \odot w_t^w(i)) \odot (1 - f_t(i))\)$
where \(f_t\) is free gate (deallocate).
Allocation: Write to least-used slots $\(a_t = \text{sort}(u_t)\)$
6.2 Attention MechanismsΒΆ
Three read modes:
Forward: Follow temporal links forward \(w_t^{f} = L_t\, w_{t-1}^r\), i.e. \(w_t^{f}(i) = \sum_j L_t[i, j]\, w_{t-1}^r(j)\)
Backward: Follow temporal links backward \(w_t^{b} = L_t^\top\, w_{t-1}^r\), i.e. \(w_t^{b}(i) = \sum_j L_t[j, i]\, w_{t-1}^r(j)\)
Content: Content-based (like NTM)
Read weights: Interpolate modes $\(w_t^r = \pi_t^{(1)} w_t^{b} + \pi_t^{(2)} w_t^{c} + \pi_t^{(3)} w_t^{f}\)$
where \(\pi_t\) are learned mode weights (\(\sum \pi_t^{(i)} = 1\)).
6.3 ApplicationsΒΆ
bAbI tasks: 20/20 tasks solved (vs 14/20 for MemN2N)
Graph reasoning: Shortest path, connectivity
Algorithmic tasks:
Copy, repeat copy
Priority sort
Mini-SHRDLU (blocks world)
Results: Perfect generalization on trained tasks, some transfer.
7. Key-Value Memory NetworksΒΆ
Miller et al. (2016): Separate keys (for retrieval) and values (for content).
Memory: Pairs \((k_i, v_i)\) where \(k_i, v_i \in \mathbb{R}^d\)
Attention over keys (softmax over memory entries): $\(\alpha_i = \frac{\exp(q^\top k_i)}{\sum_j \exp(q^\top k_j)}\)$
Retrieve values: $\(\mathbf{o} = \sum_i \alpha_i v_i\)$
Advantages:
Keys can be hashed/compressed (fast retrieval)
Values can be rich (full information)
Decouples addressing from content
Example: Question answering
Keys: Sentence embeddings (for retrieval)
Values: Full sentence representations (for reasoning)
8. Transformer as Memory NetworkΒΆ
Connection: Transformers are memory networks with:
Memory: Sequence of tokens (key-value pairs)
Addressing: Self-attention (content-based)
Updates: Feedforward layers
Self-attention: $\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\)$
is equivalent to memory read with:
Queries: \(Q = X W^Q\)
Keys: \(K = X W^K\)
Values: \(V = X W^V\)
Multi-head: Parallel memory systems (each head = different memory perspective).
Feedforward: Per-position processing (update memory slot).
Advantages over traditional memory networks:
Fully parallelizable (no sequential memory updates)
State-of-the-art on many tasks
Scalable to long sequences (with efficient attention)
Limitations:
Quadratic complexity \(O(n^2)\) in sequence length
No explicit long-term memory (limited by context window)
9. External Memory ExtensionsΒΆ
9.1 Compressive TransformersΒΆ
Motivation: Store compressed past memories beyond context window.
Architecture:
Active memory: Recent \(n_m\) tokens (regular attention)
Compressed memory: Older tokens compressed to \(n_c\) slots
Compression: Max-pooling, convolution, or learned network
Attention: Attend to both active and compressed memory.
Results: Better long-range dependencies (PG19 dataset).
9.2 Memory Transformer (Memorizing Transformer)ΒΆ
Approach: Retrieve from external key-value store.
Memory: Large database of \((k, v)\) pairs from past tokens
Retrieval: kNN search for top-k keys closest to query $\(\text{kNN}(q) = \{(k_i, v_i) : i \in \text{top-k by } \|q - k_i\|\}\)$
Integrate: Add retrieved values to Transformer layers
Benefits:
Memory scales independently (gigabytes)
Constant-time lookup (approximate NN)
Improves perplexity on long documents
9.3 Neural CacheΒΆ
Simple approach: Cache recent hidden states as memory.
Retrieval: $\(p(w_t | h_t) = \lambda p_{\text{model}}(w_t | h_t) + (1-\lambda) p_{\text{cache}}(w_t | h_t)\)$
where \(p_{\text{cache}}\) is based on similarity to cached states.
Effective: Improves language modeling perplexity with minimal overhead.
10. Training Memory NetworksΒΆ
10.1 Supervision StrategiesΒΆ
Weak supervision: Only final answer labeled
Train end-to-end with backpropagation
Attention learned implicitly
Strong supervision: Attention labels provided
Supervise which memory slots to attend to
Faster convergence, better performance
Example (bAbI): Provide supporting facts for each question.
10.2 Curriculum LearningΒΆ
Strategy: Train on easier tasks first, gradually increase difficulty.
For memory networks:
Start with 1-hop reasoning
Increase to 2-hop, 3-hop
Increase memory size \(N\)
Increase sequence length
Benefits: Better convergence, avoids local minima.
10.3 Loss FunctionsΒΆ
Classification: Cross-entropy $\(\mathcal{L} = -\sum_c y_c \log \hat{y}_c\)$
Reinforcement learning: REINFORCE for hard attention $\(\mathcal{L} = -\mathbb{E}_{\alpha \sim p_\theta}[R(\alpha) \log p_\theta(\alpha)]\)$
where \(R(\alpha)\) is reward (e.g., answer correctness).
Margin loss: For ranking (e.g., correct vs incorrect memories) $\(\mathcal{L} = \max(0, \gamma - s(q, m_+) + s(q, m_-))\)$
11. ApplicationsΒΆ
11.1 Question AnsweringΒΆ
bAbI tasks (Facebook AI):
20 tasks testing different reasoning skills
Memory networks: 20/20 tasks solved
Baselines (LSTMs): 7/20 tasks
SQuAD (reading comprehension):
Key-value memory networks competitive with BiDAF
Attention visualizations interpretable
11.2 Dialog SystemsΒΆ
Persona-based dialog: Remember user preferences
Memory: Previous conversation turns
Retrieval: Relevant past context
Generation: Consistent responses
Task-oriented dialog:
Memory: Slot-value pairs (e.g., restaurant attributes)
Reasoning: Multi-hop to find matching entities
11.3 Reasoning TasksΒΆ
Visual question answering (VQA):
Memory: Image regions (extracted by object detector)
Query: Question embedding
Multi-hop attention: Compositional reasoning
Text entailment:
Memory: Premise sentences
Query: Hypothesis
Output: Entailment/contradiction/neutral
11.4 Program SynthesisΒΆ
Neural Programmer-Interpreter (NPI):
Memory: Program states
Retrieval: Relevant sub-programs
Hierarchical execution
Differentiable Forth:
Memory: Stack and heap
Operations: Differentiable push/pop/arithmetic
12. Theoretical AnalysisΒΆ
12.1 ExpressivenessΒΆ
Theorem: Memory networks with sufficient capacity are Turing-complete.
Proof sketch:
Memory = Turing machine tape
Attention = read/write head
Controller = state transitions
Caveat: In practice, limited by:
Finite memory size \(N\)
Approximate attention (softmax vs hard)
Training difficulties
12.2 Sample ComplexityΒΆ
Hypothesis: Explicit memory reduces sample complexity for tasks requiring long-term dependencies.
Evidence:
Memory networks learn bAbI with fewer examples than LSTMs
Copy tasks: NTM sample-efficient
Theory gap: Formal bounds lacking.
12.3 Attention EntropyΒΆ
Measure: Entropy of attention distribution $\(H(\alpha) = -\sum_i \alpha_i \log \alpha_i\)$
Low entropy: Focused attention (confident retrieval)
High entropy: Distributed attention (uncertain or uniform)
Observations:
Early hops: High entropy (exploration)
Later hops: Low entropy (focused retrieval)
Difficult tasks: Higher average entropy
13. Challenges and LimitationsΒΆ
13.1 ScalabilityΒΆ
Memory size: \(O(N)\) attention cost prohibitive for large \(N\) (millions).
Solutions:
Hierarchical memory (tree structure)
Sparse attention (top-k, LSH)
Learned memory management (forget irrelevant)
13.2 Training InstabilityΒΆ
Problem: Attention can be chaotic (small input changes β large attention shifts).
Causes:
Softmax saturation
Gradient vanishing/exploding through attention
Solutions:
Gradient clipping
Attention dropout
Temperature scaling in softmax
13.3 Interpretability vs PerformanceΒΆ
Trade-off: Sparse attention (interpretable) vs dense attention (better performance).
Example: Hard attention (reinforcement learning) interpretable but slower to train.
Compromise: Top-k attention with sufficient \(k\) for performance.
14. Recent Advances (2020-2024)ΒΆ
14.1 Memory-Efficient TransformersΒΆ
Linear attention: Replace the softmax with a kernel feature map \(\phi\), $\(\text{Attn}(Q, K, V) = \frac{\phi(Q)\,(\phi(K)^T V)}{\phi(Q)\,(\phi(K)^T \mathbf{1})}\)$; computing \(\phi(K)^T V\) first
reduces complexity from \(O(n^2)\) to \(O(n)\).
Reformer: LSH attention + reversible layers
\(O(n \log n)\) complexity
Constant memory (reversible)
14.2 Retrieval-Augmented Generation (RAG)ΒΆ
Approach: Retrieve documents from external corpus, condition generation.
Architecture: $\(p(y | x) = \sum_d p(d | x) p(y | x, d)\)$
where:
\(p(d | x)\): Retriever (e.g., DPR)
\(p(y | x, d)\): Generator (e.g., BART)
Applications: Open-domain QA, fact-checking.
14.3 Memory as ParametersΒΆ
Mega (2022): Exponential moving average attention
Memory encoded in EMA weights
Constant memory, linear complexity
Results: Competitive with Transformers, much faster.
14.4 Associative MemoryΒΆ
Hopfield Networks for Transformers:
Modern Hopfield networks = attention mechanism
Energy-based interpretation
Continuous Hopfield: $\(E = -\log \sum_i \exp(\beta \langle \xi, \xi_i \rangle)\)$
minimized by attention.
15. Comparison with Other ArchitecturesΒΆ
15.1 Memory Networks vs RNNs/LSTMsΒΆ
Aspect |
Memory Networks |
RNNs/LSTMs |
|---|---|---|
Memory |
Explicit (external) |
Implicit (hidden state) |
Capacity |
Scalable (independent \(N\)) |
Fixed (hidden size) |
Access |
Random (attention) |
Sequential |
Interpretability |
High (visualize attention) |
Low (opaque state) |
Training |
Stable (explicit memory) |
Vanishing gradients |
Complexity |
\(O(N)\) per step |
\(O(1)\) per step |
Verdict: Memory networks better for tasks requiring explicit storage/retrieval.
15.2 Memory Networks vs TransformersΒΆ
Aspect |
Memory Networks |
Transformers |
|---|---|---|
Memory mechanism |
Explicit external |
Implicit (self-attention) |
Temporal modeling |
Built-in (position encoding) |
Via position encoding |
Parallelization |
Moderate (depends on design) |
Full parallelization |
Long-term memory |
Can have persistent memory |
Limited by context |
Scalability |
Depends on \(N\) |
Quadratic in sequence length |
Convergence: Transformers can be viewed as memory networks (self-attention = memory read).
15.3 NTM/DNC vs TransformersΒΆ
Aspect |
NTM/DNC |
Transformers |
|---|---|---|
Memory |
Explicit read/write |
Implicit (attention) |
Sequential ops |
Required (state updates) |
Fully parallel |
Algorithmic tasks |
Excellent (exact copy) |
Poor without training tricks |
Language modeling |
Moderate |
State-of-the-art |
Training |
Difficult (BPTT, sensitive init) |
Easier (parallel, stable) |
Niche: NTM/DNC for algorithmic reasoning, Transformers for NLP.
16. Implementation Best PracticesΒΆ
16.1 Memory InitializationΒΆ
Options:
Zero initialization: \(m_i = \mathbf{0}\)
Random: \(m_i \sim \mathcal{N}(0, \sigma^2 I)\)
Pre-filled: Initialize with input embeddings
Recommendation: Pre-fill with input for faster convergence.
16.2 Attention TemperatureΒΆ
Softmax with temperature: $\(\alpha_i = \frac{\exp(s(q, m_i) / \tau)}{\sum_j \exp(s(q, m_j) / \tau)}\)$
Low \(\tau\) (< 1): Sharper attention (more focused). High \(\tau\) (> 1): Softer attention (more uniform).
Annealing: Start high (exploration), decrease (exploitation).
16.3 Gradient ClippingΒΆ
Essential for NTM/DNC due to BPTT through memory.
Norm clipping: $\(g \leftarrow \frac{g}{\max(1, \|g\| / C)}\)$
Typical: \(C = 5\) or \(10\).
16.4 HyperparametersΒΆ
Memory size: \(N = 128\) (small), \(N = 1024\) (large)
Memory dimension: \(d = 128\) or \(256\)
Number of hops: \(K = 1\) (simple), \(K = 3\) (complex reasoning)
Learning rate: \(10^{-4}\) (Adam), with warmup
Batch size: 32-128 (depends on task)
17. Summary and Future DirectionsΒΆ
17.1 Key TakeawaysΒΆ
Memory networks:
Explicit external memory (scalable capacity)
Soft attention for differentiability
Multi-hop reasoning for complex tasks
Interpretable (attention weights)
Architectures:
MemN2N: End-to-end, simple, effective
DMN: Episodic memory, strong on bAbI
NTM/DNC: Algorithmic tasks, challenging to train
Key-Value: Decoupled retrieval and content
Transformers: Modern incarnation (self-attention = memory).
17.2 Open ProblemsΒΆ
Scalability: Efficient attention for \(N > 10^6\)
Continual learning: Update memory without catastrophic forgetting
Compositional generalization: Combine memories in novel ways
Memory management: Automatic pruning, consolidation
Theoretical foundations: Formal expressiveness, sample complexity
17.3 Future DirectionsΒΆ
Hybrid architectures:
Combine Transformers with explicit external memory
Best of both worlds (parallelization + persistent storage)
Neurosymbolic integration:
Memory as symbolic knowledge base
Neural retrieval, symbolic reasoning
Lifelong learning:
Memory networks for continual learning
Store past tasks, prevent forgetting
Multimodal memory:
Unified memory for vision, language, audio
Cross-modal retrieval and reasoning
18. ConclusionΒΆ
Memory-augmented neural networks represent a fundamental shift from compressing all knowledge into parameters to explicit, inspectable storage. While Transformers have dominated recent NLP advances, the core ideasβattention over external memory, multi-hop reasoningβremain central.
Current state:
MemN2N/DMN: Proven on reasoning tasks (bAbI)
NTM/DNC: Niche (algorithmic tasks)
Transformers: Dominant (NLP, vision), can be viewed as memory networks
Verdict: Memory networks are a conceptually elegant framework that informs modern architectures. Explicit external memory will likely resurface as models scale beyond current context limits, enabling true lifelong learning and continual knowledge accumulation.
The future likely involves hybrid systems: Transformers for in-context learning, external memory for long-term knowledgeβcombining the parallelization benefits of modern architectures with the explicit reasoning capabilities of memory networks.
"""
Advanced Memory Networks - Production Implementation
Comprehensive PyTorch implementations of MemN2N, DMN, and NTM architectures
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, List
# ===========================
# 1. Position Encoding
# ===========================
class PositionEncoding(nn.Module):
    """Sinusoidal position encoding added to memory-slot embeddings.

    The sin/cos table is computed once at construction and stored as a
    non-trainable buffer so it follows the module across devices.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        slots = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = slots * inv_freq
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even columns: sine
        table[:, 1::2] = torch.cos(angles)  # odd columns: cosine
        self.register_buffer('pe', table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings for the first x.size(1) positions.

        Args:
            x: (batch, seq_len, d_model)
        Returns:
            x plus the positional table, broadcast over the batch dim.
        """
        return x + self.pe[:x.size(1), :]
# ===========================
# 2. End-to-End Memory Network (MemN2N)
# ===========================
class MemN2N(nn.Module):
    """
    End-to-End Memory Network (Sukhbaatar et al., 2015).

    Multi-hop soft attention over an external memory of stories. Each hop
    embeds the stories twice — once as keys (A) for attention, once as
    values (C) for the output — and refines the query with a residual
    update u <- u + o.
    """
    def __init__(self,
                 vocab_size: int,
                 embed_dim: int,
                 num_hops: int = 3,
                 memory_size: int = 50,
                 dropout: float = 0.1):
        super().__init__()
        self.num_hops = num_hops
        self.embed_dim = embed_dim
        self.memory_size = memory_size
        # One key/value embedding pair per hop (layer-wise parameterization;
        # the paper's adjacent tying A^(k+1) = C^(k) is NOT implemented).
        # Fix: the original allocated num_hops + 1 key embeddings but only
        # ever indexed hops 0..num_hops-1, leaving dead parameters.
        self.embed_A = nn.ModuleList([nn.Embedding(vocab_size, embed_dim) for _ in range(num_hops)])
        self.embed_C = nn.ModuleList([nn.Embedding(vocab_size, embed_dim) for _ in range(num_hops)])
        # Question embedding
        self.embed_B = nn.Embedding(vocab_size, embed_dim)
        # Output projection to the answer vocabulary
        self.W = nn.Linear(embed_dim, vocab_size, bias=False)
        # Position encoding over memory slots
        self.position_encoding = PositionEncoding(embed_dim, max_len=memory_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                stories: torch.Tensor,
                queries: torch.Tensor,
                story_lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            stories: (batch, memory_size, story_len) - memory slot contents
            queries: (batch, query_len) - question (token ids)
            story_lengths: (batch,) - number of valid memory slots per
                example; slots at index >= story_lengths[b] are masked out
                of the attention. (Doc fix: previously described as
                per-story token lengths, which does not match the masking.)
        Returns:
            logits: (batch, vocab_size) - answer distribution
            p: (batch, memory_size) - final-hop attention, for visualization
        """
        # Embed query as a bag of words: (batch, query_len, d) -> (batch, d)
        q = self.embed_B(queries)
        q = torch.sum(q, dim=1)
        u = q  # running query representation, refined each hop
        # Multi-hop attention
        for hop in range(self.num_hops):
            # Keys for attention: bag-of-words per story + position encoding
            m = self.embed_A[hop](stories)            # (batch, M, L, d)
            m = torch.sum(m, dim=2)                   # (batch, M, d)
            m = self.position_encoding(m)
            # Values for the output
            c = self.embed_C[hop](stories)
            c = torch.sum(c, dim=2)
            c = self.position_encoding(c)
            # Attention: p_i = softmax(u^T m_i)
            attention_logits = torch.matmul(m, u.unsqueeze(2)).squeeze(2)  # (batch, M)
            if story_lengths is not None:
                # Mask out unused memory slots before the softmax
                mask = torch.arange(self.memory_size, device=stories.device)[None, :] >= story_lengths[:, None]
                attention_logits = attention_logits.masked_fill(mask, -1e9)
            p = F.softmax(attention_logits, dim=1)    # (batch, M)
            # Output: o = sum_i p_i c_i
            o = torch.matmul(p.unsqueeze(1), c).squeeze(1)  # (batch, d)
            # Residual query update
            u = u + o
            u = self.dropout(u)
        logits = self.W(u)  # (batch, vocab_size)
        return logits, p
# ===========================
# 3. Dynamic Memory Network (DMN)
# ===========================
class AttentionGRUCell(nn.Module):
    """Attention-based GRU cell for episodic memory.

    Produces a scalar relevance gate from interactions between the current
    fact, the previous episode memory, and the question, then runs a GRU
    update on the gate-weighted fact.

    NOTE(review): the gate MLP input is sized input_size * 7, and the
    element-wise products below (fact * prev_memory, fact * question) only
    type-check when input_size == hidden_size and the question has
    input_size dims. Confirm that callers pass equal-sized
    fact/memory/question vectors.
    """
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # Attention gate: two-layer MLP -> scalar in (0, 1) per fact.
        # Expects the 7-way feature concatenation built in forward().
        self.attn_gate = nn.Sequential(
            nn.Linear(input_size * 7, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )
        # GRU cell that consumes the gated fact
        self.gru_cell = nn.GRUCell(input_size, hidden_size)

    def forward(self, fact: torch.Tensor, prev_memory: torch.Tensor,
                question: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            fact: (batch, input_size) - current fact
            prev_memory: (batch, hidden_size) - previous episode memory
            question: (batch, input_size) - question embedding
        Returns:
            gate: (batch, 1) - attention gate
            memory: (batch, hidden_size) - updated memory
        """
        # Interaction features: raw vectors, element-wise similarities, and
        # element-wise differences between the fact and question/memory.
        z = torch.cat([
            fact,
            prev_memory,
            question,
            fact * question,
            fact * prev_memory,
            torch.abs(fact - question),
            torch.abs(fact - prev_memory)
        ], dim=1)
        gate = self.attn_gate(z)  # (batch, 1)
        # The gate scales the fact itself before the recurrent update
        # (rather than interpolating hidden states), so irrelevant facts
        # contribute a near-zero input to the GRU.
        fact_weighted = gate * fact
        memory = self.gru_cell(fact_weighted, prev_memory)
        return gate, memory
class EpisodicMemory(nn.Module):
    """Episodic memory module with attention-based reasoning.

    Runs num_hops passes ("episodes") over the encoded facts; each pass
    gates facts by relevance and then updates the episode memory.

    NOTE(review): memory is initialized from the question, and the episode
    context is fed to a GRUCell with input_size inputs — both assume the
    question/context width matches the cell sizes (effectively
    input_size == hidden_size). Confirm against the caller.
    """
    def __init__(self, input_size: int, hidden_size: int, num_hops: int = 3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_hops = num_hops
        # Gated cell applied per fact within an episode
        self.attn_gru = AttentionGRUCell(input_size, hidden_size)
        # Updates the episode memory from the episode's context
        self.memory_update = nn.GRUCell(input_size, hidden_size)

    def forward(self, facts: torch.Tensor, question: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            facts: (batch, num_facts, input_size) - input facts
            question: (batch, input_size) - question embedding
        Returns:
            memory: (batch, hidden_size) - final episodic memory
            attention_history: per-hop gate tensors of shape (batch, num_facts)
        """
        batch_size, num_facts, _ = facts.shape
        # Initialize memory with question
        memory = question.clone()
        attention_history = []
        # Multiple passes (hops) over facts
        for hop in range(self.num_hops):
            # NOTE(review): this zero init only survives when num_facts == 0;
            # otherwise context is overwritten on the first fact below.
            context = torch.zeros(batch_size, self.hidden_size, device=facts.device)
            gates = []
            # Iterate over facts
            for i in range(num_facts):
                fact = facts[:, i, :]  # (batch, input_size)
                # NOTE(review): the cell receives `memory` (not `context`)
                # as its recurrent state, so context does NOT accumulate
                # across facts — after the loop it equals the gated update
                # of the LAST fact only. The DMN paper threads the
                # attention-GRU hidden state across facts; confirm intended.
                gate, context = self.attn_gru(fact, memory, question)
                gates.append(gate)
            # Stack per-fact sigmoid gates (not softmax-normalized)
            gates = torch.cat(gates, dim=1)  # (batch, num_facts)
            attention_history.append(gates)
            # Update episode memory from this pass's context
            memory = self.memory_update(context, memory)
        return memory, attention_history
class DMN(nn.Module):
    """
    Dynamic Memory Network (Kumar et al., 2016).
    Episodic memory with attention-based reasoning.

    Fix: bi-GRU fact encodings (2 * hidden_size) are now projected down to
    hidden_size before entering the episodic memory. Previously facts (2h),
    question (h) and memory (h) had mismatched widths, so the element-wise
    feature products inside AttentionGRUCell could not be computed; with the
    projection, facts, question and memory all share hidden_size.
    """
    def __init__(self,
                 vocab_size: int,
                 embed_dim: int,
                 hidden_size: int,
                 num_hops: int = 3,
                 dropout: float = 0.1):
        super().__init__()
        # Embedding shared by facts and question
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Input module (encode facts with a bidirectional GRU)
        self.input_gru = nn.GRU(embed_dim, hidden_size, batch_first=True, bidirectional=True)
        # Project concatenated forward/backward final states (2h) to h
        self.fact_proj = nn.Linear(hidden_size * 2, hidden_size)
        # Question module
        self.question_gru = nn.GRU(embed_dim, hidden_size, batch_first=True)
        # Episodic memory module (facts, question and memory all hidden_size)
        self.episodic_memory = EpisodicMemory(hidden_size, hidden_size, num_hops)
        # Answer module (decoder over one-hot / softmax token vectors)
        self.answer_gru = nn.GRUCell(vocab_size, hidden_size)
        self.answer_fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                facts: torch.Tensor,
                question: torch.Tensor,
                answer: Optional[torch.Tensor] = None,
                max_len: int = 10) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            facts: (batch, num_facts, fact_len) - input facts (token ids)
            question: (batch, question_len) - question (token ids)
            answer: (batch, answer_len) - ground-truth answer for teacher
                forcing; if None, decodes greedily for max_len steps
            max_len: maximum answer length for generation
        Returns:
            output: (batch, answer_len, vocab_size) - answer logits
            attention_history: per-hop attention gates over facts
        """
        batch_size = facts.size(0)
        num_facts = facts.size(1)
        # Input module: encode each fact, then project 2h -> h
        fact_embeds = self.embedding(facts)  # (batch, num_facts, fact_len, embed_dim)
        encoded_facts = []
        for i in range(num_facts):
            fact_embed = fact_embeds[:, i, :, :]   # (batch, fact_len, embed_dim)
            _, h = self.input_gru(fact_embed)      # h: (2, batch, hidden_size)
            h = self.fact_proj(torch.cat([h[0], h[1]], dim=1))  # (batch, hidden_size)
            encoded_facts.append(h)
        encoded_facts = torch.stack(encoded_facts, dim=1)  # (batch, num_facts, hidden_size)
        # Question module
        question_embed = self.embedding(question)  # (batch, question_len, embed_dim)
        _, q = self.question_gru(question_embed)   # (1, batch, hidden_size)
        q = q.squeeze(0)                           # (batch, hidden_size)
        # Episodic memory module
        memory, attention_history = self.episodic_memory(encoded_facts, q)
        # Answer module (decoder); teacher forcing when answer is provided
        if answer is not None:
            answer_len = answer.size(1)
        else:
            answer_len = max_len
        outputs = []
        hidden = memory
        # Start token: all-zeros "one-hot" vector
        input_token = torch.zeros(batch_size, self.answer_fc.out_features, device=facts.device)
        for t in range(answer_len):
            hidden = self.answer_gru(input_token, hidden)
            output = self.answer_fc(hidden)  # (batch, vocab_size)
            outputs.append(output)
            if answer is not None and t < answer_len - 1:
                # Teacher forcing: feed the ground-truth token
                input_token = F.one_hot(answer[:, t], num_classes=self.answer_fc.out_features).float()
            else:
                # Feed the soft prediction back in
                input_token = F.softmax(output, dim=1)
        outputs = torch.stack(outputs, dim=1)  # (batch, answer_len, vocab_size)
        return outputs, attention_history
# ===========================
# 4. Neural Turing Machine (NTM)
# ===========================
class NTMMemory(nn.Module):
    """External NTM memory supporting attention-weighted reads and
    erase-then-add writes. Holds no trainable parameters; the memory tensor
    itself is created by reset() and threaded through by the caller."""

    def __init__(self, memory_size: int, memory_dim: int):
        super().__init__()
        self.memory_size = memory_size
        self.memory_dim = memory_dim

    def reset(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a fresh memory tensor of small random values,
        shape (batch, memory_size, memory_dim)."""
        return 0.01 * torch.randn(batch_size, self.memory_size, self.memory_dim, device=device)

    def read(self, memory: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Attention-weighted sum over memory rows.

        Args:
            memory: (batch, memory_size, memory_dim)
            w: (batch, memory_size) - read weights
        Returns:
            read_vector: (batch, memory_dim)
        """
        return (w.unsqueeze(1) @ memory).squeeze(1)

    def write(self, memory: torch.Tensor, w: torch.Tensor,
              erase: torch.Tensor, add: torch.Tensor) -> torch.Tensor:
        """Erase-then-add write: M <- M * (1 - w e^T) + w a^T.

        Args:
            memory: (batch, memory_size, memory_dim)
            w: (batch, memory_size) - write weights
            erase: (batch, memory_dim) - erase vector
            add: (batch, memory_dim) - add vector
        Returns:
            updated_memory: (batch, memory_size, memory_dim)
        """
        # Outer products w e^T / w a^T via broadcasting
        erase_outer = w.unsqueeze(-1) * erase.unsqueeze(1)  # (batch, N, D)
        wiped = memory * (1 - erase_outer)
        add_outer = w.unsqueeze(-1) * add.unsqueeze(1)
        return wiped + add_outer
class NTMHead(nn.Module):
    """NTM read/write head combining content- and location-based addressing.

    Addressing pipeline: content match -> interpolate with previous weights
    -> circular shift -> sharpen.
    """
    def __init__(self, memory_dim: int, controller_dim: int):
        super().__init__()
        self.memory_dim = memory_dim
        # Key for content addressing
        self.key_layer = nn.Linear(controller_dim, memory_dim)
        # Key strength beta (>= 0 via softplus in forward)
        self.beta_layer = nn.Linear(controller_dim, 1)
        # Interpolation gate g in [0, 1]
        self.g_layer = nn.Linear(controller_dim, 1)
        # Shift distribution over {left, stay, right}
        self.shift_layer = nn.Linear(controller_dim, 3)
        # Sharpening exponent gamma (>= 1 in forward)
        self.gamma_layer = nn.Linear(controller_dim, 1)

    def content_addressing(self, memory: torch.Tensor, key: torch.Tensor,
                           beta: torch.Tensor) -> torch.Tensor:
        """
        Content-based addressing using cosine similarity:
        w_c = softmax(beta * cosine(key, M))

        Args:
            memory: (batch, memory_size, memory_dim)
            key: (batch, memory_dim)
            beta: (batch, 1) - key strength
        Returns:
            w_c: (batch, memory_size) - content weights (sum to 1)
        """
        # Normalize both sides; epsilon guards against zero vectors
        key = key / (torch.norm(key, dim=1, keepdim=True) + 1e-8)
        memory = memory / (torch.norm(memory, dim=2, keepdim=True) + 1e-8)
        # Cosine similarity per memory row
        similarity = torch.matmul(memory, key.unsqueeze(2)).squeeze(2)  # (batch, memory_size)
        # Key strength sharpens or flattens the distribution
        w_c = F.softmax(beta * similarity, dim=1)
        return w_c

    def forward(self, controller_output: torch.Tensor, memory: torch.Tensor,
                prev_w: torch.Tensor) -> torch.Tensor:
        """
        Compute new attention weights from the controller state.

        Args:
            controller_output: (batch, controller_dim)
            memory: (batch, memory_size, memory_dim)
            prev_w: (batch, memory_size) - previous weights
        Returns:
            w: (batch, memory_size) - new attention weights
        """
        # Extract addressing parameters
        key = torch.tanh(self.key_layer(controller_output))            # (batch, memory_dim)
        beta = F.softplus(self.beta_layer(controller_output))          # (batch, 1)
        g = torch.sigmoid(self.g_layer(controller_output))             # (batch, 1)
        shift = F.softmax(self.shift_layer(controller_output), dim=1)  # (batch, 3)
        gamma = 1 + F.softplus(self.gamma_layer(controller_output))    # (batch, 1), >= 1
        # 1) Content addressing
        w_c = self.content_addressing(memory, key, beta)
        # 2) Interpolate with previous weights
        w_g = g * w_c + (1 - g) * prev_w
        # 3) Circular convolutional shift
        w_tilde = self._convolutional_shift(w_g, shift)
        # 4) Sharpen and renormalize
        w = torch.pow(w_tilde, gamma)
        w = w / (torch.sum(w, dim=1, keepdim=True) + 1e-8)
        return w

    def _convolutional_shift(self, w: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        """Circularly convolve the weights with the 3-way shift distribution.

        Fix: replaces the original O(memory_size) Python loop (per-index
        tensor writes) with two torch.roll calls — identical result:
        out[i] = s0 * w[i-1] + s1 * w[i] + s2 * w[i+1], indices mod N.
        """
        prev_slot = torch.roll(w, shifts=1, dims=1)   # w[i-1]
        next_slot = torch.roll(w, shifts=-1, dims=1)  # w[i+1]
        return shift[:, 0:1] * prev_slot + shift[:, 1:2] * w + shift[:, 2:3] * next_slot
class NTM(nn.Module):
    """Neural Turing Machine (Graves et al., 2014).

    LSTM controller coupled to an external memory matrix through one write
    head and `num_heads` read heads. All recurrent state (memory contents,
    controller hidden/cell, head attention weights) is threaded explicitly
    through forward() in a dict, so the module holds no cross-step state.
    """
    def __init__(self,
                 input_size: int,
                 output_size: int,
                 controller_dim: int,
                 memory_size: int,
                 memory_dim: int,
                 num_heads: int = 1):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.controller_dim = controller_dim
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.num_heads = num_heads
        # Memory (external; contributes no trainable parameters)
        self.memory = NTMMemory(memory_size, memory_dim)
        # Controller (LSTM); input = external input + previous read vectors
        self.controller = nn.LSTMCell(input_size + num_heads * memory_dim, controller_dim)
        # Read heads
        self.read_heads = nn.ModuleList([NTMHead(memory_dim, controller_dim) for _ in range(num_heads)])
        # Write head plus its erase/add vector generators
        self.write_head = NTMHead(memory_dim, controller_dim)
        self.erase_layer = nn.Linear(controller_dim, memory_dim)
        self.add_layer = nn.Linear(controller_dim, memory_dim)
        # Output; fed from controller state plus the current-step reads
        self.fc_out = nn.Linear(controller_dim + num_heads * memory_dim, output_size)

    def forward(self, x: torch.Tensor, prev_state: Optional[dict] = None) -> Tuple[torch.Tensor, dict]:
        """
        Single step forward
        Args:
            x: (batch, input_size)
            prev_state: Previous NTM state (memory, controller, weights);
                None initializes a fresh state via init_state()
        Returns:
            output: (batch, output_size)
            state: Current NTM state
        """
        batch_size = x.size(0)
        device = x.device
        # Initialize state if needed
        if prev_state is None:
            prev_state = self.init_state(batch_size, device)
        # Read from memory using the PREVIOUS step's read weights
        read_vectors = []
        for i, head in enumerate(self.read_heads):
            r = self.memory.read(prev_state['memory'], prev_state['read_w'][i])
            read_vectors.append(r)
        read_vector = torch.cat(read_vectors, dim=1)  # (batch, num_heads * memory_dim)
        # Controller input: current input concatenated with the reads
        controller_input = torch.cat([x, read_vector], dim=1)
        controller_h, controller_c = self.controller(controller_input,
                                                     (prev_state['controller_h'], prev_state['controller_c']))
        # Update read weights for the NEXT step (this step's output below
        # still uses read_vector computed from the previous weights)
        new_read_w = []
        for i, head in enumerate(self.read_heads):
            w = head(controller_h, prev_state['memory'], prev_state['read_w'][i])
            new_read_w.append(w)
        # Update write weights
        write_w = self.write_head(controller_h, prev_state['memory'], prev_state['write_w'])
        # Write to memory: erase gate in (0,1) via sigmoid, add in (-1,1) via tanh
        erase = torch.sigmoid(self.erase_layer(controller_h))
        add = torch.tanh(self.add_layer(controller_h))
        new_memory = self.memory.write(prev_state['memory'], write_w, erase, add)
        # Output from controller state + reads
        output_input = torch.cat([controller_h, read_vector], dim=1)
        output = self.fc_out(output_input)
        # New state
        new_state = {
            'memory': new_memory,
            'controller_h': controller_h,
            'controller_c': controller_c,
            'read_w': new_read_w,
            'write_w': write_w
        }
        return output, new_state

    def init_state(self, batch_size: int, device: torch.device) -> dict:
        """Initialize NTM state: near-zero random memory, zero controller
        hidden/cell state, and uniform attention over all memory slots."""
        return {
            'memory': self.memory.reset(batch_size, device),
            'controller_h': torch.zeros(batch_size, self.controller_dim, device=device),
            'controller_c': torch.zeros(batch_size, self.controller_dim, device=device),
            'read_w': [torch.zeros(batch_size, self.memory_size, device=device) + 1.0 / self.memory_size
                       for _ in range(self.num_heads)],
            'write_w': torch.zeros(batch_size, self.memory_size, device=device) + 1.0 / self.memory_size
        }

    def forward_sequence(self, x_seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for an entire sequence (state threaded step to step)
        Args:
            x_seq: (batch, seq_len, input_size)
        Returns:
            output_seq: (batch, seq_len, output_size)
        """
        batch_size, seq_len, _ = x_seq.shape
        state = None
        outputs = []
        for t in range(seq_len):
            output, state = self.forward(x_seq[:, t, :], state)
            outputs.append(output)
        return torch.stack(outputs, dim=1)
# ===========================
# 5. Demo Functions
# ===========================
def demo_memn2n():
    """Demonstrate End-to-End Memory Network"""
    banner = "=" * 60
    print(banner)
    print("Demo: End-to-End Memory Network (MemN2N)")
    print(banner)
    n_vocab = 100
    dim = 32
    net = MemN2N(n_vocab, dim, num_hops=3, memory_size=20)
    # Synthetic batch: 20 memory slots of 6 tokens each, 5-token queries
    n_batch = 4
    story_batch = torch.randint(0, n_vocab, (n_batch, 20, 6))
    query_batch = torch.randint(0, n_vocab, (n_batch, 5))
    logits, attention = net(story_batch, query_batch)
    print(f"Stories: {story_batch.shape} (batch, memory_size, story_len)")
    print(f"Queries: {query_batch.shape} (batch, query_len)")
    print(f"Output logits: {logits.shape} (batch, vocab_size)")
    print(f"Final attention: {attention.shape} (batch, memory_size)")
    print(f"Attention weights (sample): {attention[0][:10]}")
    total = sum(p.numel() for p in net.parameters())
    print(f"\nTotal parameters: {total:,}")
    print()
def demo_dmn():
    """Demonstrate Dynamic Memory Network"""
    banner = "=" * 60
    print(banner)
    print("Demo: Dynamic Memory Network (DMN)")
    print(banner)
    n_vocab, dim, hidden = 100, 32, 64
    net = DMN(n_vocab, dim, hidden, num_hops=3)
    # Synthetic batch: 10 facts of 8 tokens, 5-token question, 3-token answer
    n_batch = 4
    fact_batch = torch.randint(0, n_vocab, (n_batch, 10, 8))
    question_batch = torch.randint(0, n_vocab, (n_batch, 5))
    answer_batch = torch.randint(0, n_vocab, (n_batch, 3))
    decoded, hop_attention = net(fact_batch, question_batch, answer_batch)
    print(f"Facts: {fact_batch.shape} (batch, num_facts, fact_len)")
    print(f"Question: {question_batch.shape} (batch, question_len)")
    print(f"Answer: {answer_batch.shape} (batch, answer_len)")
    print(f"Output: {decoded.shape} (batch, answer_len, vocab_size)")
    print(f"Attention hops: {len(hop_attention)}")
    print(f"Attention per hop: {hop_attention[0].shape} (batch, num_facts)")
    # Show how the gate on the first fact changes across hops
    print(f"\nAttention evolution (sample, fact 0):")
    for hop, gates in enumerate(hop_attention):
        print(f"  Hop {hop}: {gates[0, 0].item():.4f}")
    total = sum(p.numel() for p in net.parameters())
    print(f"\nTotal parameters: {total:,}")
    print()
def demo_ntm():
    """Demonstrate Neural Turing Machine"""
    rule = "=" * 60
    print(rule)
    print("Demo: Neural Turing Machine (NTM)")
    print(rule)
    # Hyperparameters mirror the classic copy-task setup.
    input_size = 8
    output_size = 8
    controller_dim = 100
    memory_size = 128
    memory_dim = 20
    model = NTM(input_size, output_size, controller_dim, memory_size, memory_dim, num_heads=1)
    # Copy task: push a whole random sequence through the machine.
    batch_size, seq_len = 2, 10
    x_seq = torch.randn(batch_size, seq_len, input_size)
    output_seq = model.forward_sequence(x_seq)
    print(f"Input sequence: {x_seq.shape} (batch, seq_len, input_size)")
    print(f"Output sequence: {output_seq.shape} (batch, seq_len, output_size)")
    # A single controller step exposes the full machine state dict.
    x = torch.randn(batch_size, input_size)
    output, state = model(x)
    print("\nSingle step:")
    print(f" Input: {x.shape}")
    print(f" Output: {output.shape}")
    print(f" Memory: {state['memory'].shape} (batch, memory_size, memory_dim)")
    print(f" Read weights: {state['read_w'][0].shape} (batch, memory_size)")
    print(f" Write weights: {state['write_w'].shape}")
    total = sum(p.numel() for p in model.parameters())
    print(f"\nTotal parameters: {total:,}")
    # The memory matrix is external state, not a trainable parameter.
    print("Memory parameters: 0 (external memory)")
    print()
def print_performance_comparison():
    """Comprehensive performance comparison"""

    def show(rows, widths):
        # Render one table: left-justified cells joined by single spaces,
        # matching the original f-string "{:<W}" column formatting.
        for row in rows:
            print(" ".join(cell.ljust(w) for cell, w in zip(row, widths)))

    heavy = "=" * 80
    light = "-" * 80
    print(heavy)
    print("PERFORMANCE COMPARISON: Memory Networks")
    print(heavy)
    # 1. bAbI Tasks
    print("\n1. bAbI Question Answering (% tasks solved)")
    print(light)
    show([
        ("Model", "1-hop", "2-hop", "3-hop", "Overall"),
        ("-" * 30, "-" * 8, "-" * 8, "-" * 8, "-" * 10),
        ("LSTM", "95%", "45%", "20%", "35% (7/20)"),
        ("MemN2N (1 hop)", "100%", "75%", "55%", "70% (14/20)"),
        ("MemN2N (3 hops)", "100%", "95%", "85%", "95% (19/20)"),
        ("DMN+", "100%", "97%", "93%", "97% (19.4/20)"),
        ("DNC", "100%", "100%", "98%", "100% (20/20)"),
        ("", "", "", "", ""),
        ("Observation:", "", "", "", "Multi-hop crucial for complex reasoning"),
    ], (30, 8, 8, 8, 10))
    # 2. Algorithmic Tasks
    print("\n2. Algorithmic Tasks (Accuracy %)")
    print(light)
    show([
        ("Task", "LSTM", "MemN2N", "NTM", "DNC"),
        ("-" * 25, "-" * 8, "-" * 10, "-" * 8, "-" * 8),
        ("Copy", "45%", "N/A", "100%", "100%"),
        ("Repeat Copy", "30%", "N/A", "95%", "100%"),
        ("Associative Recall", "40%", "85%", "98%", "100%"),
        ("Priority Sort", "20%", "N/A", "75%", "98%"),
        ("", "", "", "", ""),
        ("Best for:", "Simple", "Reasoning", "Algorithms", "Complex"),
    ], (25, 8, 10, 8, 8))
    # 3. Computational Complexity
    print("\n3. Computational Complexity")
    print(light)
    show([
        ("Model", "Time per step", "Memory", "Parameters"),
        ("-" * 20, "-" * 20, "-" * 20, "-" * 20),
        ("LSTM", "O(1)", "O(h)", "O(hΒ²)"),
        ("MemN2N", "O(KΒ·N)", "O(NΒ·d)", "O(VΒ·dΒ·K)"),
        ("DMN", "O(KΒ·N)", "O(NΒ·d)", "O(VΒ·d + hΒ²)"),
        ("NTM", "O(NΒ·d)", "O(NΒ·d) external", "O(hΒ·d)"),
        ("DNC", "O(NΒ·d)", "O(NΒ·d) external", "O(hΒ·d)"),
        ("", "", "", ""),
        ("Legend:", "K=hops, N=memory", "h=hidden, d=mem_dim", "V=vocab"),
    ], (20, 20, 20, 20))
    # 4. Comparison Table
    print("\n4. Architecture Comparison")
    print(light)
    show([
        ("Aspect", "MemN2N", "DMN", "NTM/DNC"),
        ("-" * 25, "-" * 20, "-" * 20, "-" * 20),
        ("Memory type", "Static (input)", "Episodic", "Dynamic R/W"),
        ("Addressing", "Content (attention)", "Attention gate", "Content + location"),
        ("Differentiable", "Yes (soft attn)", "Yes", "Yes"),
        ("Multi-hop", "Built-in", "Built-in", "Via controller"),
        ("Training", "Easy", "Moderate", "Hard (BPTT)"),
        ("Interpretability", "High", "High", "Moderate"),
        ("Best for", "QA, reasoning", "QA, dialog", "Algorithms"),
    ], (25, 20, 20, 20))
    # 5. Hyperparameters
    print("\n5. Recommended Hyperparameters")
    print(light)
    show([
        ("Parameter", "MemN2N", "DMN", "NTM"),
        ("-" * 25, "-" * 15, "-" * 15, "-" * 15),
        ("Embedding dim", "32-128", "64-128", "N/A"),
        ("Hidden size", "N/A", "128-256", "100-200"),
        ("Memory size", "20-100", "10-50", "128-512"),
        ("Memory dim", "embed_dim", "2Γhidden", "20-40"),
        ("Num hops", "1-3", "2-4", "N/A"),
        ("Learning rate", "0.01", "0.001", "0.0001"),
        ("Gradient clip", "40", "10", "10"),
        ("Batch size", "32", "32", "16"),
    ], (25, 15, 15, 15))
    # 6. Decision Guide
    print("\n6. DECISION GUIDE: When to Use Memory Networks")
    print(heavy)
    for line in (
        "\nβ USE Memory Networks When:",
        "β’ Multi-hop reasoning required (compositional questions)",
        "β’ Explicit memory inspection needed (interpretability)",
        "β’ Tasks with supporting facts (QA with evidence)",
        "β’ Algorithmic tasks (copy, sort, recall)",
        "β’ Dialog systems (track conversation history)",
        "β’ Long-term dependencies (beyond LSTM capacity)",
        "\nβ AVOID Memory Networks When:",
        "β’ Simple sequential tasks (LSTM sufficient)",
        "β’ Large-scale language modeling (Transformers better)",
        "β’ Real-time inference critical (overhead of attention)",
        "β’ Limited computational budget (complex architectures)",
        "\nβ RECOMMENDED ALTERNATIVES:",
        "β’ General NLP β Transformers (BERT, GPT)",
        "β’ Simple QA β BiDAF, BERT-based models",
        "β’ Reasoning β Graph Neural Networks, Transformers",
        "β’ Algorithms β Specialized neural architectures",
    ):
        print(line)
    print("\n" + heavy)
    print()
# ===========================
# Run All Demos
# ===========================
if __name__ == "__main__":
    banner = "=" * 80
    print("\n" + banner)
    print("MEMORY NETWORKS - COMPREHENSIVE IMPLEMENTATION")
    print(banner + "\n")
    # Run each demo in order, then the summary tables.
    for run in (demo_memn2n, demo_dmn, demo_ntm, print_performance_comparison):
        run()
    print("\n" + banner)
    print("All demos completed successfully!")
    print(banner)