import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns
# Run on GPU when available; every model/tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Global seaborn style for all matplotlib figures in this notebook.
sns.set_style('whitegrid')
1. Motivation: Discrete Latent SpacesΒΆ
Standard VAE LimitationsΒΆ
Continuous latent space \(z \sim \mathcal{N}(\mu, \sigma^2)\):
Problems:
Posterior collapse: Decoder ignores latent code
Blurry reconstructions: Gaussian assumption
Difficult autoregressive modeling: Continuous \(z\)
VQ-VAE Solution (van den Oord et al., 2017)ΒΆ
Discrete latent space with learned codebook!
where:
\(K\) = codebook size (e.g., 512)
\(D\) = embedding dimension (e.g., 64)
Key Ideas:ΒΆ
Encoder produces continuous \(z_e(x)\)
Vector quantization: Map \(z_e\) to nearest codebook entry
Decoder reconstructs from discrete code
Codebook learning: Update embeddings via EMA or gradients
Advantages:ΒΆ
No posterior collapse (discrete bottleneck)
Sharp reconstructions
Enable powerful priors (PixelCNN, transformers)
Compression (discrete codes)
Reference Materials:
vae.pdf - Vae
generative_models.pdf - Generative Models
2. Vector Quantization LayerΒΆ
Forward PassΒΆ
Given encoder output \(z_e(x) \in \mathbb{R}^{H \times W \times D}\):
1. Find nearest neighbor: $\(k^* = \arg\min_k \|z_e(x) - e_k\|_2\)$
2. Quantize: $\(z_q(x) = e_{k^*}\)$
Gradient Flow ProblemΒΆ
Quantization is non-differentiable!
Straight-through estimator: $\(\nabla_{z_e} L = \nabla_{z_q} L\)$
Copy gradients from decoder to encoder.
VQ-VAE LossΒΆ
where \(\text{sg}[\cdot]\) = stop gradient.
Terms:
Reconstruction: Standard autoencoder loss
Codebook loss: Update embeddings toward encoder outputs
Commitment loss: Encourage encoder to commit to codebook
Vector Quantizer ImplementationΒΆ
The vector quantizer is the core innovation of VQ-VAE: instead of a continuous latent space, it maintains a codebook of \(K\) embedding vectors and maps each encoder output to its nearest codebook entry. Formally, given encoder output \(z_e(x)\), the quantized representation is \(z_q(x) = e_k\) where \(k = \arg\min_j \|z_e(x) - e_j\|_2\). Because argmin is not differentiable, gradients are passed through the quantization step using the straight-through estimator — the decoder receives the quantized vector on the forward pass but the encoder gradient flows as if quantization were the identity. The codebook itself is updated via an exponential moving average of the encoder outputs assigned to each entry, which is more stable than gradient-based updates.
class VectorQuantizer(nn.Module):
    """Vector quantization layer with a learned codebook.

    Maps each D-dimensional encoder output vector to its nearest codebook
    entry (L2 distance) and passes gradients through the non-differentiable
    argmin with the straight-through estimator.

    Args:
        num_embeddings: Codebook size K.
        embedding_dim: Dimension D of each codebook vector.
        commitment_cost: Weight (beta) of the commitment loss term.
    """
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        # Codebook: K x D embedding table, small uniform init.
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)

    def forward(self, z_e):
        """
        z_e: [B, C, H, W] encoder output (C == embedding_dim)
        Returns: z_q [B, C, H, W], loss, perplexity, encoding_indices
        """
        # [B, C, H, W] -> [B, H, W, C]: put the embedding dim last, once.
        z_e_perm = z_e.permute(0, 2, 3, 1).contiguous()
        z_e_flat = z_e_perm.view(-1, self.embedding_dim)
        # Squared distances to all codebook entries:
        # ||z_e - e||^2 = ||z_e||^2 + ||e||^2 - 2*z_e^T*e
        distances = (torch.sum(z_e_flat**2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight**2, dim=1)
                     - 2 * torch.matmul(z_e_flat, self.embedding.weight.t()))
        # Nearest-neighbor code assignment per spatial position.
        encoding_indices = torch.argmin(distances, dim=1)
        encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
        # Quantize: look up assigned codebook vectors ([B, H, W, C]).
        z_q = torch.matmul(encodings, self.embedding.weight).view(z_e_perm.shape)
        # Commitment loss pulls encoder outputs toward their codes;
        # codebook loss pulls codebook entries toward encoder outputs.
        e_latent_loss = F.mse_loss(z_q.detach(), z_e_perm)
        q_latent_loss = F.mse_loss(z_q, z_e_perm.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss
        # Straight-through estimator: forward value is z_q, but the gradient
        # w.r.t. the encoder output is the identity.
        # BUGFIX: this must use the permuted [B, H, W, C] tensor; the original
        # code added the raw [B, C, H, W] z_e here, which raises a
        # broadcasting error whenever C != W.
        z_q = z_e_perm + (z_q - z_e_perm).detach()
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        # Perplexity = exp(entropy of code usage); equals K at full usage.
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        return z_q, loss, perplexity, encoding_indices
# Test VQ layer
# Sanity check: shapes go in/out as [B, C, H, W]; perplexity is bounded above
# by the codebook size (128 here), and the number of distinct indices shows
# how many codebook entries this random batch actually touched.
vq = VectorQuantizer(num_embeddings=128, embedding_dim=64).to(device)
z_e = torch.randn(4, 64, 7, 7).to(device)
z_q, loss, perplexity, indices = vq(z_e)
print(f"Input shape: {z_e.shape}")
print(f"Quantized shape: {z_q.shape}")
print(f"VQ loss: {loss.item():.4f}")
print(f"Perplexity: {perplexity.item():.2f} / {128}")
print(f"Unique codes used: {len(torch.unique(indices))}")
Complete VQ-VAE ModelΒΆ
The VQ-VAE model chains together an encoder (convolutional layers that compress the image to spatial feature maps), the vector quantizer (which discretizes each spatial position into a codebook index), and a decoder (transposed convolutions that reconstruct the image from quantized features). The total loss has three terms: reconstruction loss (how well the decoder reconstructs the input), codebook loss (how close codebook vectors are to encoder outputs), and commitment loss (how close encoder outputs are to their assigned codebook vectors, weighted by \(\beta\)). The discrete bottleneck prevents posterior collapse and produces a compressed, discrete representation that can be modeled autoregressively for generation.
class Encoder(nn.Module):
    """Convolutional encoder mapping an image to a continuous feature map z_e."""

    def __init__(self, in_channels=1, hidden_dims=[32, 64], latent_dim=64):
        super().__init__()
        stages = []
        channels = in_channels
        # Each hidden stage halves the spatial resolution (stride-2 conv + ReLU).
        for width in hidden_dims:
            stages.append(nn.Sequential(
                nn.Conv2d(channels, width, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
            ))
            channels = width
        # Final 3x3 projection to the embedding dimension; no downsampling.
        stages.append(nn.Sequential(
            nn.Conv2d(hidden_dims[-1], latent_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ))
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return z_e of shape [B, latent_dim, H/2^len(hidden_dims), W/2^len(hidden_dims)]."""
        return self.encoder(x)
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping quantized features to an image in [0, 1]."""

    def __init__(self, latent_dim=64, hidden_dims=[64, 32], out_channels=1):
        super().__init__()
        stages = []
        channels = latent_dim
        # Each stage doubles the spatial resolution (stride-2 transposed conv + ReLU).
        for width in hidden_dims:
            stages.append(nn.Sequential(
                nn.ConvTranspose2d(channels, width, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
            ))
            channels = width
        # Final 3x3 conv + sigmoid keeps pixel values in [0, 1].
        stages.append(nn.Sequential(
            nn.Conv2d(hidden_dims[-1], out_channels, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ))
        self.decoder = nn.Sequential(*stages)

    def forward(self, z):
        """Return the reconstructed image for quantized features z."""
        return self.decoder(z)
class VQVAE(nn.Module):
    """Full VQ-VAE: encoder -> vector quantizer -> decoder."""

    def __init__(self, num_embeddings=128, embedding_dim=64):
        super().__init__()
        # Encoder/decoder share the embedding dimension with the codebook.
        self.encoder = Encoder(latent_dim=embedding_dim)
        self.vq = VectorQuantizer(num_embeddings, embedding_dim)
        self.decoder = Decoder(latent_dim=embedding_dim)

    def forward(self, x):
        """Return (reconstruction, vq_loss, perplexity) for a batch x."""
        z_e = self.encoder(x)
        z_q, vq_loss, perplexity, _ = self.vq(z_e)
        x_recon = self.decoder(z_q)
        return x_recon, vq_loss, perplexity
# Initialize model
# 256 codebook entries of dimension 64; moved to the globally selected device.
model = VQVAE(num_embeddings=256, embedding_dim=64).to(device)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
Training VQ-VAEΒΆ
During training, the three loss components work in concert: the reconstruction loss drives the encoder and decoder to preserve information, while the codebook and commitment losses keep the quantization tightly coupled. Key metrics to monitor include codebook utilization (what fraction of the \(K\) entries are actively used) and perplexity (the effective number of codebook entries being used, computed as \(\exp(-\sum_k p_k \log p_k)\)). Low utilization or perplexity indicates codebook collapse — a common failure mode where only a few entries are used. Techniques like codebook reset (reinitializing dead entries) and EMA updates help maintain healthy utilization.
# --- Data: MNIST as [0, 1] tensors ---------------------------------------
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# --- Optimization setup ---------------------------------------------------
optimizer = optim.Adam(model.parameters(), lr=2e-4)
n_epochs = 20
history = {'recon_loss': [], 'vq_loss': [], 'perplexity': []}

model.train()
for epoch in range(n_epochs):
    # Running per-epoch sums for the three monitored quantities.
    sums = {'recon': 0.0, 'vq': 0.0, 'ppl': 0.0}
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        recon, vq_loss, perplexity = model(images)
        recon_loss = F.mse_loss(recon, images)
        # Total objective = reconstruction + (codebook + commitment) terms.
        (recon_loss + vq_loss).backward()
        optimizer.step()
        sums['recon'] += recon_loss.item()
        sums['vq'] += vq_loss.item()
        sums['ppl'] += perplexity.item()
    n_batches = len(train_loader)
    history['recon_loss'].append(sums['recon'] / n_batches)
    history['vq_loss'].append(sums['vq'] / n_batches)
    history['perplexity'].append(sums['ppl'] / n_batches)
    print(f"Epoch [{epoch+1}/{n_epochs}] "
          f"Recon: {history['recon_loss'][-1]:.4f} "
          f"VQ: {history['vq_loss'][-1]:.4f} "
          f"Perplexity: {history['perplexity'][-1]:.1f}")
print("\nTraining complete!")

# --- Training curves ------------------------------------------------------
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
panels = [
    ('recon_loss', 'MSE Loss', 'Reconstruction Loss', None),
    ('vq_loss', 'VQ Loss', 'Vector Quantization Loss', 'orange'),
    ('perplexity', 'Perplexity', 'Codebook Usage', 'green'),
]
for ax, (key, ylabel, title, color) in zip(axes, panels):
    ax.plot(history[key], linewidth=2, color=color)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.grid(True, alpha=0.3)
# Reference line: perplexity can be at most the codebook size.
axes[2].axhline(y=256, color='r', linestyle='--', label='Max (codebook size)')
axes[2].legend()
plt.tight_layout()
plt.show()
Reconstruction and Codebook AnalysisΒΆ
Evaluating a VQ-VAE involves both reconstruction quality (comparing original images to their encode-quantize-decode outputs) and codebook health (visualizing which codes are used and how frequently). High-quality reconstructions with low codebook utilization suggest the model has found a compact representation, while poor reconstructions with high utilization indicate the codebook is too small or the encoder-decoder lacks capacity. Visualizing the codebook entries themselves (by decoding each entry independently) reveals what visual concepts each code represents β an important step toward understanding the learned discrete vocabulary.
# Switch to inference mode for reconstruction and codebook inspection.
model.eval()

# A small batch of held-out digits.
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=10)
images, labels = next(iter(test_loader))
images = images.to(device)

# Encode -> quantize -> decode without tracking gradients.
with torch.no_grad():
    recon, _, _ = model(images)

# Originals on the top row, reconstructions below.
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for col in range(10):
    top, bottom = axes[0, col], axes[1, col]
    top.imshow(images[col].cpu().squeeze(), cmap='gray')
    top.axis('off')
    bottom.imshow(recon[col].cpu().squeeze(), cmap='gray')
    bottom.axis('off')
    if col == 0:
        top.set_ylabel('Original', fontsize=11)
        bottom.set_ylabel('Reconstructed', fontsize=11)
plt.suptitle('VQ-VAE Reconstructions', fontsize=14)
plt.tight_layout()
plt.show()

# Codebook utilization over the whole test set.
print("\nCodebook analysis:")
with torch.no_grad():
    collected = []
    for images, _ in test_loader:
        images = images.to(device)
        _, _, _, batch_indices = model.vq(model.encoder(images))
        collected.append(batch_indices.cpu())
    all_indices = torch.cat(collected)
    unique_codes = torch.unique(all_indices)
    print(f"Total codes used: {len(unique_codes)} / 256")
    print(f"Utilization: {100 * len(unique_codes) / 256:.1f}%")

# Histogram of how often each codebook entry is selected.
plt.figure(figsize=(12, 4))
code_counts = torch.bincount(all_indices, minlength=256)
plt.bar(range(256), code_counts.numpy(), alpha=0.7)
plt.xlabel('Codebook Index', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.title('Codebook Entry Usage', fontsize=13)
plt.grid(True, alpha=0.3, axis='y')
plt.show()
SummaryΒΆ
Key Innovations:ΒΆ
Discrete latent space: Codebook of learned embeddings
Vector quantization: Nearest neighbor assignment
Straight-through estimator: Gradient flow through quantization
No posterior collapse: Discrete bottleneck enforces usage
VQ-VAE Loss:ΒΆ
Advantages:ΒΆ
Sharp reconstructions: No Gaussian blurring
Compression: Discrete codes highly compressible
Powerful priors: Can use autoregressive models on codes
No mode collapse: Unlike GANs
Applications:ΒΆ
Image generation: VQ-VAE-2, DALL-E
Audio synthesis: WaveNet, Jukebox
Video modeling: VideoGPT
Reinforcement learning: World models
Extensions:ΒΆ
VQ-VAE-2 (Razavi et al., 2019): Hierarchical codebooks
EMA updates: Exponential moving average for codebook
Reset unused codes: Prevent codebook collapse
Product quantization: Multiple codebooks
Training Tips:ΒΆ
Monitor perplexity (codebook usage)
Tune commitment cost \(\beta\)
Use EMA for more stable training
Reset dead codes periodically
Next Steps:ΒΆ
14_normalizing_flows.ipynb - Exact likelihood models
07_hierarchical_vae.ipynb - Hierarchical latent variables
03_variational_autoencoders_advanced.ipynb - VAE foundations
Advanced Vector Quantized Variational Autoencoders (VQ-VAE): Theory and PracticeΒΆ
1. Introduction to VQ-VAEΒΆ
Vector Quantized Variational Autoencoders (VQ-VAE) combine the benefits of VAEs with discrete latent representations by introducing a vector quantization bottleneck. Unlike traditional VAEs that learn continuous latent distributions, VQ-VAE learns a discrete codebook of embedding vectors and maps encoder outputs to the nearest codebook entry.
Key innovation: Replace continuous latent distribution \(q(z|x)\) with discrete codebook lookup, enabling:
Discrete latent codes: Natural for structured data (text, music, images)
Powerful priors: Autoregressive models (PixelCNN, Transformers) for generation
High-quality reconstruction: No posterior collapse, better than vanilla VAE
Architecture overview:
Input x → Encoder → z_e(x) → Quantize (nearest codebook) → z_q → Decoder → Reconstruction x̂
Mathematical formulation:
Encoder: \(z_e(x) = E_\theta(x) \in \mathbb{R}^{H \times W \times D}\)
Quantization: \(z_q(x) = e_k\) where \(k = \arg\min_j \|z_e(x) - e_j\|_2\)
Codebook: \(\mathcal{E} = \{e_1, e_2, \ldots, e_K\} \subset \mathbb{R}^D\) with \(K\) entries
Decoder: \(\hat{x} = D_\phi(z_q)\)
Advantages over VAE:
No KL divergence term (no posterior collapse)
Discrete codes easier to model with autoregressive priors
Better reconstruction quality
Codebook learns meaningful discrete representations
Applications:
High-fidelity image generation (DALL-E, Imagen)
Audio synthesis (WaveNet, Jukebox)
Video generation (VideoGPT)
Compression (learned compression with discrete codes)
2. Vector Quantization: Core MechanismΒΆ
2.1 Codebook and Nearest Neighbor LookupΒΆ
The codebook \(\mathcal{E} = \{e_1, \ldots, e_K\}\) contains \(K\) embedding vectors \(e_i \in \mathbb{R}^D\).
Quantization operation: $\(q(z_e) = e_k, \quad \text{where } k = \arg\min_{j \in \{1,\ldots,K\}} \|z_e - e_j\|_2^2\)$
For spatial encoder outputs \(z_e \in \mathbb{R}^{H \times W \times D}\):
Each spatial location \((h, w)\) gets quantized independently
\(z_q[h, w] = e_{k_{h,w}}\) where \(k_{h,w} = \arg\min_j \|z_e[h, w] - e_j\|_2^2\)
Results in discrete code map: \(k \in \{1, \ldots, K\}^{H \times W}\)
Properties:
Non-differentiable: \(\arg\min\) operation has no gradient
Straight-through estimator: Copy gradients from decoder to encoder
Codebook learning: Update embeddings via exponential moving average or gradient descent
2.2 Straight-Through EstimatorΒΆ
Since quantization is non-differentiable, use straight-through estimator for backpropagation:
Forward pass: \(z_q = q(z_e) = e_k\) (actual quantization)
Backward pass: \(\frac{\partial \mathcal{L}}{\partial z_e} = \frac{\partial \mathcal{L}}{\partial z_q}\) (copy gradients)
Intuition: Pretend quantization is identity during backprop.
Implementation:
z_q = z_e + (quantized - z_e).detach()
This creates \(z_q\) that equals quantized in forward pass but has gradients of \(z_e\) in backward pass.
3. Loss FunctionΒΆ
3.1 Total LossΒΆ
VQ-VAE optimizes three terms:
where \(\text{sg}[\cdot]\) denotes stop-gradient (no gradient flows through).
Components:
Reconstruction loss \(\mathcal{L}_{\text{recon}}\): $\(\mathcal{L}_{\text{recon}} = \|x - D(z_q)\|_2^2 \quad \text{or} \quad -\log p(x | z_q)\)$
Trains decoder to reconstruct input from quantized codes.
Codebook loss (commitment from encoder): $\(\mathcal{L}_{\text{codebook}} = \|\text{sg}[z_e] - e\|_2^2\)$
Moves codebook embeddings \(e\) towards encoder outputs \(z_e\) (no gradient to encoder).
Commitment loss (from encoder to codebook): $\(\mathcal{L}_{\text{commit}} = \beta \|z_e - \text{sg}[e]\|_2^2\)$
Encourages encoder outputs to stay close to chosen codebook entry (no gradient to codebook).
Typical value: \(\beta = 0.25\)
3.2 Exponential Moving Average (EMA) Codebook UpdateΒΆ
Alternative to gradient-based codebook updates: use EMA statistics.
Update rule: $\(N_i^{(t)} = \gamma N_i^{(t-1)} + (1 - \gamma) n_i^{(t)}\)\( \)\(m_i^{(t)} = \gamma m_i^{(t-1)} + (1 - \gamma) \sum_{z_e: q(z_e)=e_i} z_e\)\( \)\(e_i^{(t)} = \frac{m_i^{(t)}}{N_i^{(t)}}\)$
where:
\(N_i\) = count of assignments to codebook entry \(i\)
\(m_i\) = sum of encoder outputs assigned to \(i\)
\(\gamma\) = decay rate (e.g., 0.99)
Advantage: More stable than gradient descent, no codebook loss term needed.
Loss with EMA: $\(\mathcal{L}_{\text{EMA}} = \mathcal{L}_{\text{recon}} + \beta \|z_e - \text{sg}[e]\|_2^2\)$
4. Codebook Initialization and CollapseΒΆ
4.1 Codebook Collapse ProblemΒΆ
Problem: Some codebook entries never get used (dead codes).
Causes:
Poor initialization (random embeddings far from data)
Training dynamics (some codes dominate)
High codebook size \(K\) relative to data diversity
Detection: Monitor codebook usage statistics (perplexity).
Perplexity: $\(\text{Perplexity} = \exp\left(-\sum_{i=1}^K p_i \log p_i\right)\)$
where \(p_i\) = fraction of assignments to codebook entry \(i\).
High perplexity (close to \(K\)): Good utilization
Low perplexity: Collapse (few codes used)
4.2 Solutions to Codebook CollapseΒΆ
1. Random restart:
Periodically reinitialize unused codes
Replace dead codes with random encoder outputs
2. Codebook initialization:
Initialize from k-means clustering on encoder outputs
Better starting point than random initialization
3. Commitment loss weight:
Lower \(\beta\) encourages more exploration
Higher \(\beta\) forces encoder to commit to existing codes
4. Gumbel-Softmax relaxation:
Replace hard quantization with soft differentiable version
Gradually anneal temperature to approach hard quantization
5. Product quantization:
Use multiple smaller codebooks instead of one large one
Reduces chance of collapse, increases effective codebook size
5. VQ-VAE Architecture DetailsΒΆ
5.1 EncoderΒΆ
Purpose: Map input \(x\) to continuous representation \(z_e\).
Architecture:
Convolutional layers with downsampling (stride 2)
Residual blocks for capacity
Final layer outputs \(z_e \in \mathbb{R}^{H \times W \times D}\)
Example (CIFAR-10):
Input: [3, 32, 32]
Conv(128, 4Γ4, stride=2, padding=1) β [128, 16, 16]
ResBlock(128) β [128, 16, 16]
ResBlock(128) β [128, 16, 16]
Conv(256, 4Γ4, stride=2, padding=1) β [256, 8, 8]
ResBlock(256) β [256, 8, 8]
Conv(D, 3Γ3, padding=1) β [D, 8, 8]
Output: z_e of shape [D, 8, 8]
Embedding dimension: \(D = 64\) or \(D = 256\) typical.
5.2 QuantizerΒΆ
Operation: For each spatial location, find nearest codebook entry.
Codebook: \(\mathcal{E} \in \mathbb{R}^{K \times D}\), typically \(K = 512\) or \(K = 1024\).
Distance metric: \(L_2\) distance $\(d(z_e[h,w], e_j) = \|z_e[h,w] - e_j\|_2^2 = \sum_{d=1}^D (z_e[h,w,d] - e_j[d])^2\)$
Efficient implementation: $\(\|z_e - e_j\|^2 = \|z_e\|^2 - 2 z_e^T e_j + \|e_j\|^2\)$
Compute as matrix multiplication: \(z_e \mathcal{E}^T\) (batch operation).
5.3 DecoderΒΆ
Purpose: Reconstruct input from quantized codes \(z_q\).
Architecture:
Transposed convolutions for upsampling
Residual blocks
Final layer outputs reconstruction \(\hat{x}\)
Example (CIFAR-10):
Input: z_q of shape [D, 8, 8]
Conv(256, 3Γ3, padding=1) β [256, 8, 8]
ResBlock(256) β [256, 8, 8]
ConvTranspose(128, 4Γ4, stride=2, padding=1) β [128, 16, 16]
ResBlock(128) β [128, 16, 16]
ResBlock(128) β [128, 16, 16]
ConvTranspose(3, 4Γ4, stride=2, padding=1) β [3, 32, 32]
Output: xΜ of shape [3, 32, 32]
Symmetry: Decoder often mirrors encoder architecture.
6. Training ProcedureΒΆ
6.1 AlgorithmΒΆ
Input: Dataset \(\{x_i\}\), codebook size \(K\), embedding dim \(D\), commitment weight \(\beta\)
Initialize:
Encoder \(E_\theta\), Decoder \(D_\phi\) with random weights
Codebook \(\mathcal{E} \in \mathbb{R}^{K \times D}\) randomly or via k-means
Optimizer (Adam with \(\text{lr} = 10^{-4}\))
Training loop:
for each batch x:
# Forward pass
z_e = E_θ(x)                        # Encode
k = argmin_j ||z_e - e_j||²         # Quantize
z_q = e_k                           # Lookup
z_q = z_e + (z_q - z_e).detach()    # Straight-through
x̂ = D_φ(z_q)                       # Decode
# Compute losses
L_recon = ||x - x̂||²
L_codebook = ||sg[z_e] - e||²
L_commit = β ||z_e - sg[e]||²
L = L_recon + L_codebook + L_commit
# Backward pass
L.backward()
optimizer.step()
# Optional: EMA codebook update
update_codebook_ema(z_e, k)
6.2 HyperparametersΒΆ
Parameter |
Typical Value |
Impact |
|---|---|---|
Codebook size \(K\) |
512-1024 |
Larger = more capacity, risk of collapse |
Embedding dim \(D\) |
64-256 |
Higher = more expressive |
Commitment \(\beta\) |
0.25 |
Higher = stronger commitment |
Learning rate |
\(10^{-4}\) |
Standard Adam |
Batch size |
32-128 |
Larger helps codebook statistics |
EMA decay \(\gamma\) |
0.99 |
Codebook update smoothing |
6.3 MonitoringΒΆ
Key metrics:
Reconstruction loss: Should decrease steadily
Perplexity: Should be high (close to \(K\))
Codebook usage: Histogram of code assignments
Commitment loss: Should stabilize
Early stopping: Monitor validation reconstruction loss.
7. Prior Models for GenerationΒΆ
7.1 Two-Stage TrainingΒΆ
VQ-VAE enables powerful two-stage generation:
Stage 1: Train VQ-VAE
Learn encoder, decoder, codebook
Obtain discrete latent codes for dataset
Stage 2: Train prior on codes
Model distribution \(p(k)\) over discrete codes
Use autoregressive models (PixelCNN, Transformers)
Generation:
1. Sample code sequence: k ~ p(k) # Prior model
2. Lookup embeddings: z_q = e_k # Codebook
3. Decode to image: xΜ = D(z_q) # Decoder
7.2 PixelCNN PriorΒΆ
Architecture: Autoregressive CNN that models \(p(k_{h,w} | k_{<(h,w)})\)
Factorization: $\(p(\mathbf{k}) = \prod_{h=1}^H \prod_{w=1}^W p(k_{h,w} | k_{1:h-1,:}, k_{h,1:w-1})\)$
Masked convolutions: Ensure autoregressive property (no future information).
Training: Maximize log-likelihood $\(\mathcal{L}_{\text{prior}} = -\mathbb{E}_{\mathbf{k} \sim \text{data}}[\log p(\mathbf{k})]\)$
Sampling: Sequential sampling row-by-row, left-to-right.
7.3 Transformer PriorΒΆ
Architecture: Apply Transformer decoder to flattened code sequence.
Sequence: Flatten \(k \in \{1, \ldots, K\}^{H \times W}\) to \(k \in \{1, \ldots, K\}^{H \cdot W}\)
Autoregressive modeling: $\(p(\mathbf{k}) = \prod_{t=1}^{H \cdot W} p(k_t | k_{<t})\)$
Advantages over PixelCNN:
Global attention (long-range dependencies)
Parallel training (teacher forcing)
State-of-the-art quality
DALL-E: VQ-VAE + Transformer prior for text-to-image generation.
8. VQ-VAE-2: Hierarchical ExtensionΒΆ
8.1 MotivationΒΆ
Limitation of VQ-VAE: Single-scale quantization misses hierarchical structure.
Solution (VQ-VAE-2): Multi-level hierarchy of discrete codes.
Architecture:
Input β Encoder β [Top codes (coarse)] β Prior_top
β
Decoder_top β [Bottom codes (fine)] β Prior_bottom
β
Decoder_bottom β Reconstruction
8.2 Two-Level HierarchyΒΆ
Top level: Coarse codes \(k_{\text{top}} \in \{1, \ldots, K\}^{H_1 \times W_1}\)
Lower resolution (\(H_1 \times W_1 = 8 \times 8\) for 256Γ256 input)
Captures global structure
Bottom level: Fine codes \(k_{\text{bot}} \in \{1, \ldots, K\}^{H_2 \times W_2}\)
Higher resolution (\(H_2 \times W_2 = 32 \times 32\))
Captures local details
Conditional prior: $\(p(k_{\text{bot}} | k_{\text{top}}) = \prod_{i} p(k_{\text{bot}}^i | k_{\text{top}}, k_{\text{bot}}^{<i})\)$
Generation:
Sample top codes: \(k_{\text{top}} \sim p(k_{\text{top}})\)
Sample bottom codes: \(k_{\text{bot}} \sim p(k_{\text{bot}} | k_{\text{top}})\)
Decode: \(x = D(z_q^{\text{bot}}, z_q^{\text{top}})\)
Benefits:
Better quality (ImageNet 256Γ256)
More structured latent space
Conditional generation easier
9. Variants and ExtensionsΒΆ
9.1 Gumbel-Softmax VQ-VAEΒΆ
Problem: Straight-through estimator is biased.
Solution: Replace hard quantization with Gumbel-Softmax (temperature-annealed soft quantization).
Gumbel-Softmax: $\(\pi_j = \frac{\exp((\log \alpha_j + g_j) / \tau)}{\sum_{k=1}^K \exp((\log \alpha_k + g_k) / \tau)}\)$
where:
\(\alpha_j \propto \exp(-\|z_e - e_j\|^2)\) (similarity to codebook)
\(g_j \sim \text{Gumbel}(0, 1)\) (noise for exploration)
\(\tau\) = temperature (anneal from 1.0 to 0.1)
Soft quantization: $\(z_q = \sum_{j=1}^K \pi_j e_j\)$
Advantage: Fully differentiable, gradients flow to codebook and encoder.
Disadvantage: Soft codes during training, need to switch to hard codes for inference.
9.2 Product Quantization (PQ)ΒΆ
Idea: Use \(M\) smaller codebooks instead of one large codebook.
Codebooks: \(\mathcal{E}_1, \ldots, \mathcal{E}_M\) each of size \(K_s\) (subcodebook size)
Embedding split: \(z_e = [z_e^1, \ldots, z_e^M]\) where \(z_e^m \in \mathbb{R}^{D/M}\)
Quantization: Quantize each subvector independently $\(z_q^m = \arg\min_{e \in \mathcal{E}_m} \|z_e^m - e\|^2\)$
Total codes: \(K = K_s^M\) (exponential in \(M\))
Example: \(M=4\) subcodebooks of size \(K_s=256\) gives \(256^4 \approx 4.3\) billion effective codes!
Advantages:
Huge effective codebook size
Less prone to collapse
More efficient memory
9.3 Residual VQ (RVQ)ΒΆ
Idea: Apply quantization iteratively on residuals.
Algorithm:
r_0 = z_e
for i = 1 to L:
k_i = argmin_j ||r_{i-1} - e_j^(i)||Β²
q_i = e_{k_i}^(i)
r_i = r_{i-1} - q_i
z_q = q_1 + q_2 + ... + q_L
Benefits:
Finer-grained quantization (\(L\) levels)
Better reconstruction quality
Codebooks learn hierarchical structure
Used in: SoundStream (audio codec), EnCodec
10. ApplicationsΒΆ
10.1 Image GenerationΒΆ
DALL-E (OpenAI, 2021):
VQ-VAE with 8,192 codebook size
Latent codes: 32Γ32 grid for 256Γ256 images
Transformer prior (12 billion parameters)
Text-to-image generation (text β codes β image)
Parti (Google, 2022):
ViT-VQGAN encoder
Encoder-decoder Transformer prior
State-of-the-art text-to-image
ImageGPT:
VQ-VAE for image tokenization
GPT-style Transformer for generation
Unsupervised pre-training for vision
10.2 Audio GenerationΒΆ
Jukebox (OpenAI, 2020):
Three-level VQ-VAE hierarchy
Top: 8 seconds @ 44.1kHz β 8,192 codes
Transformer priors for each level
Generates coherent music with lyrics
SoundStream (Google, 2021):
RVQ with 8 quantizers
3 kbps audio compression
Better than MP3 at same bitrate
10.3 Video GenerationΒΆ
VideoGPT:
VQ-VAE on video frames (spatiotemporal)
3D convolutions for encoder/decoder
Transformer prior for frame sequence
NUWA (Microsoft, 2021):
VQ-VAE for video tokenization
3D-Transformer prior
Text-to-video generation
10.4 CompressionΒΆ
Learned compression:
VQ-VAE as lossy compressor
Discrete codes transmitted/stored
Arithmetic coding on code sequence
Advantages over JPEG/PNG:
Learned on data distribution (better rate-distortion)
End-to-end optimized
Flexible bitrate (codebook size)
11. Mathematical AnalysisΒΆ
11.1 Information-Theoretic ViewΒΆ
Rate-distortion theory: Trade-off between compression rate and reconstruction quality.
VQ-VAE objective: $\(\min_{E, D, \mathcal{E}} \mathbb{E}_{x \sim p_{\text{data}}}[\underbrace{\|x - D(q(E(x)))\|^2}_{\text{Distortion}}] \quad \text{subject to } \underbrace{H \cdot W \cdot \log_2 K \text{ bits}}_{\text{Rate}}\)$
where \(H \times W\) is latent spatial size and \(K\) is codebook size.
Codebook entropy: $\(H(\mathcal{E}) = -\sum_{k=1}^K p(k) \log_2 p(k)\)$
Lower entropy (skewed usage) β compressible codes via entropy coding.
11.2 Comparison to VAEΒΆ
VAE: Continuous latent \(z \sim q(z|x) = \mathcal{N}(\mu(x), \sigma^2(x))\)
VQ-VAE: Discrete latent \(k = \arg\min_j \|E(x) - e_j\|^2\)
ELBO comparison:
VAE: $\(\log p(x) \geq \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) \| p(z))\)$
VQ-VAE: No KL term! Just reconstruction + commitment: $\(\log p(x) \approx -\|x - D(e_k)\|^2 - \beta \|E(x) - e_k\|^2\)$
Advantage: No posterior collapse (KL → 0 problem in VAE).
11.3 Codebook CapacityΒΆ
Effective capacity: Number of actively used codes.
Perplexity: $\(\text{Perplexity} = \exp(H(\mathcal{E})) = \exp\left(-\sum_{k=1}^K p(k) \log p(k)\right)\)$
Uniform usage: Perplexity = \(K\) (full capacity)
Collapsed: Perplexity ≪ \(K\) (low capacity)
Empirical: VQ-VAE typically achieves 50-90% of theoretical capacity (\(K\)).
12. Theoretical PropertiesΒΆ
12.1 Representational PowerΒΆ
Theorem (Universal approximation): For any continuous function \(f: \mathbb{R}^n \to \mathbb{R}^m\) and \(\epsilon > 0\), there exists a VQ-VAE with sufficiently large codebook \(K\) such that: $\(\|f(x) - D(q(E(x)))\| < \epsilon\)$
Proof sketch:
Encoder \(E\) can partition input space into \(K\) regions (Voronoi cells)
Each region assigned to codebook entry
Decoder \(D\) approximates function in each region
As \(K \to \infty\), partition becomes arbitrarily fine
12.2 Gradient BiasΒΆ
Straight-through estimator: $\(\frac{\partial \mathcal{L}}{\partial z_e} = \frac{\partial \mathcal{L}}{\partial z_q}\)$
Bias: Gradients pretend quantization is identity.
Analysis: Gradient error bounded by quantization error $\(\left\|\frac{\partial \mathcal{L}}{\partial z_e} - \frac{\partial \mathcal{L}}{\partial z_e}^{\text{true}}\right\| \leq C \|z_e - z_q\|\)$
where \(C\) is Lipschitz constant of loss.
Mitigation: Commitment loss \(\|z_e - z_q\|^2\) keeps quantization error small.
13. Training Challenges and SolutionsΒΆ
13.1 Codebook CollapseΒΆ
Symptoms:
Low perplexity (<10% of \(K\))
Most codes never used
Poor reconstruction despite low training loss
Solutions:
EMA updates: More stable than gradient-based
Random restart: Reinitialize unused codes every \(N\) iterations
Laplace smoothing: Add small count to all codes $\(N_i^{(t)} = \gamma N_i^{(t-1)} + (1-\gamma) n_i^{(t)} + \epsilon\)$
Lower commitment weight: \(\beta = 0.1\) instead of 0.25
Larger codebook: Redundancy helps (\(K=2048\) instead of 512)
13.2 Training InstabilityΒΆ
Symptoms:
Reconstruction loss oscillates
Codebook embeddings diverge
Solutions:
Gradient clipping: Clip encoder/decoder gradients
Learning rate warmup: Start with low LR, increase linearly
Spectral normalization: Constrain network Lipschitz constant
Batch normalization: Stabilize activations (or GroupNorm/LayerNorm)
13.3 Mode CollapseΒΆ
Symptom: Decoder ignores some codebook entries, generates same output.
Solutions:
Diverse initialization: k-means on encoder outputs
Entropy regularization: Add \(-\lambda H(\mathcal{E})\) to loss
Curriculum learning: Start with small codebook, grow over time
14. Implementation Best PracticesΒΆ
14.1 Architecture DesignΒΆ
Encoder:
Use residual connections (ResNet-style)
Gradually downsample (avoid large strides)
Final layer: Linear projection to embedding dimension
Decoder:
Mirror encoder architecture
Use transposed convolutions or upsampling + conv
Skip connections help reconstruction
Codebook:
Initialize with k-means or normal distribution \(\mathcal{N}(0, 1/D)\)
Normalize embeddings periodically (optional)
14.2 Loss WeightingΒΆ
Typical weights:
Reconstruction: 1.0 (base)
Codebook: 1.0 (if using gradients)
Commitment: 0.25
Annealing: Some works anneal commitment weight: $\(\beta(t) = \beta_{\text{max}} \cdot \min(1, t / T_{\text{warmup}})\)$
14.3 Computational EfficiencyΒΆ
Quantization:
Precompute codebook norms: \(\|e_j\|^2\)
Use matrix multiplication: \(z_e \mathcal{E}^T\)
Avoid explicit distance computation loops
Memory:
Store codes as uint8 or uint16 (not one-hot)
Gradient checkpointing for large models
Parallelization:
Quantization independent across spatial locations (parallelize)
EMA updates can be batched
15. Evaluation MetricsΒΆ
15.1 Reconstruction QualityΒΆ
Metrics:
MSE/PSNR: \(\text{PSNR} = 10 \log_{10}(255^2 / \text{MSE})\)
SSIM: Structural similarity (perceptual quality)
LPIPS: Learned perceptual similarity (better than MSE)
Target: PSNR > 25 dB for good quality, LPIPS < 0.2
15.2 Codebook MetricsΒΆ
Perplexity: $\(\text{Perplexity} = \exp\left(-\sum_k p(k) \log p(k)\right)\)$
Target: >80% of codebook size (e.g., >410 for \(K=512\))
Active codes: Number of codes used at least once in batch/epoch.
Usage histogram: Visualize distribution of code assignments.
15.3 Generation Quality (with Prior)ΒΆ
Metrics:
FID: Fréchet Inception Distance (lower better, <10 excellent)
IS: Inception Score (higher better)
Precision/Recall: Sample quality vs. diversity
Benchmark (CIFAR-10):
VQ-VAE-2 + PixelCNN: FID ~20
VQ-VAE-2 + Transformer: FID ~15
16. Comparison with Other Discrete RepresentationsΒΆ
Method |
Latent Type |
Training |
Prior |
Quality |
Use Case |
|---|---|---|---|---|---|
VQ-VAE |
Discrete codes |
Reconstruction + commitment |
Autoregressive |
High |
Images, audio, video |
DALL-E |
VQ-VAE codes |
Same as VQ-VAE |
Transformer |
State-of-art |
Text-to-image |
VQGAN |
VQ-VAE + GAN |
+ Adversarial + perceptual |
Transformer |
Higher |
High-res images |
Discrete VAE |
Categorical latent |
Gumbel-Softmax |
MLP/Transformer |
Moderate |
Structured data |
VQ-GAN |
Codebook + patch discriminator |
Adversarial |
Transformer |
Highest |
Photorealistic |
VQ-VAE vs. VQGAN:
VQGAN adds adversarial loss (discriminator) + perceptual loss (LPIPS)
Better perceptual quality, more stable training
Used in Stable Diffusion, Parti
17. Advanced TopicsΒΆ
17.1 Conditional VQ-VAEΒΆ
Conditioning: Add class label \(y\) or text embedding \(c\) to decoder.
Decoder: $\(\hat{x} = D(z_q, c)\)$
Architecture:
Concatenate \(c\) to \(z_q\) spatially (via broadcasting)
Use FiLM (Feature-wise Linear Modulation) layers
Cross-attention for text conditioning
Applications: Class-conditional generation, style transfer, inpainting.
17.2 Multi-Modal VQ-VAEΒΆ
Idea: Shared codebook across modalities (image + text, image + audio).
Architecture:
Separate encoders: \(E_{\text{img}}, E_{\text{text}}\)
Shared codebook: \(\mathcal{E}\)
Separate decoders: \(D_{\text{img}}, D_{\text{text}}\)
Training: Align representations via shared codes.
Cross-modal generation:
Text → codes → image
Image → codes → caption
17.3 DisentanglementΒΆ
Goal: Learn interpretable codebook (each dimension = factor of variation).
Approaches:
Factorized codes: Separate codebooks for different factors (shape, color, texture)
Supervised disentanglement: Use labeled attributes during training
Unsupervised: Add \(\beta\)-VAE style penalty on codebook usage
Evaluation: Factor-VAE metric, MIG (Mutual Information Gap).
18. Recent Advances (2020-2024)ΒΆ
18.1 ViT-VQGANΒΆ
Innovation: Use Vision Transformer (ViT) as encoder/decoder instead of CNNs.
Advantages:
Global receptive field (vs. local CNN)
Scalability to high-resolution (1024×1024+)
Better long-range dependencies
Architecture:
Patch embedding → ViT encoder → quantization → ViT decoder → reconstruction
Results: State-of-the-art FID on ImageNet 256×256.
18.2 RQ-VAE (Residual Quantized VAE)ΒΆ
Idea: Iterative refinement with residual quantization (RVQ).
Algorithm: Apply VQ-VAE on residuals \(L\) times
z_q^(0) = 0
for l = 1 to L:
r^(l) = z_e - z_q^(l-1)
z_q^(l) = z_q^(l-1) + VQ(r^(l))
Benefits:
Better reconstruction (\(L=8\) quantizers)
Hierarchical representation
Used in audio codecs (SoundStream, EnCodec)
18.3 Finite Scalar Quantization (FSQ)ΒΆ
Motivation: Avoid codebook collapse entirely.
Idea: Replace codebook lookup with deterministic rounding.
Quantization: Round each dimension to \(\{-1, 0, 1\}\) or \(\{0, 1, \ldots, L-1\}\)
Example: \(D=8\) dimensions, \(L=3\) levels → \(3^8 = 6,561\) codes
Advantages:
No codebook (no collapse)
Fully differentiable (via straight-through on rounding)
Simpler implementation
Disadvantage: Less flexible than learned codebook.
18.4 Masked Generative ModelsΒΆ
MaskGIT: Replace autoregressive prior with masked Transformer.
Training: Mask random codes, predict masked codes (BERT-style).
Sampling: Iterative refinement (like BERT)
Start with all codes masked
Predict all codes (confidence scores)
Unmask top-\(k\) confident predictions
Repeat until all unmasked
Advantage: Parallel sampling (~10× faster than autoregressive).
19. Limitations and Future DirectionsΒΆ
19.1 Current LimitationsΒΆ
Codebook collapse:
Still an issue despite mitigation strategies
Requires careful tuning and monitoring
Scalability:
Large codebooks (\(K > 10,000\)) difficult to train
Memory overhead for very high-dimensional embeddings
Prior complexity:
Strong prior (large Transformer) needed for good generation
Two-stage training awkward (VQ-VAE then prior)
Reconstruction vs. generation trade-off:
Better reconstruction → more codes used
Better generation → fewer codes (easier to model)
19.2 Future DirectionsΒΆ
End-to-end training:
Train VQ-VAE and prior jointly (avoid two-stage)
Differentiable codebook via Gumbel-Softmax
Continuous relaxations:
Soft VQ-VAE (temperature annealing)
Hybrid discrete-continuous latent spaces
Theoretical understanding:
Formal analysis of codebook capacity
Convergence guarantees for EMA updates
Optimal codebook size selection
Novel applications:
3D shape generation (point clouds, meshes)
Protein structure design (discrete sequence β structure)
Reinforcement learning (discrete action representations)
Hardware acceleration:
Custom hardware for fast nearest-neighbor search
Approximate nearest neighbors (LSH, FAISS)
20. SummaryΒΆ
Key Takeaways:
Core idea: VQ-VAE replaces continuous VAE latent with discrete codebook lookup
Encoder → quantize to nearest codebook entry → decoder
Straight-through estimator for gradients
Loss function:
Reconstruction + codebook loss + commitment loss
Or: Reconstruction + commitment with EMA codebook updates
Training challenges:
Codebook collapse (low perplexity)
Solutions: EMA updates, random restart, proper initialization
Two-stage generation:
Stage 1: Train VQ-VAE (encoder, decoder, codebook)
Stage 2: Train prior on discrete codes (PixelCNN, Transformer)
Advantages over VAE:
No posterior collapse (no KL term)
Powerful autoregressive priors on discrete codes
Better reconstruction quality
Variants:
VQ-VAE-2: Hierarchical codes (coarse + fine)
Product quantization: Multiple small codebooks
Residual VQ: Iterative refinement
Gumbel-Softmax: Differentiable quantization
Applications:
Image generation: DALL-E, Parti, ImageGPT
Audio: Jukebox, SoundStream
Video: VideoGPT, NUWA
Compression: Learned codecs
Recent advances:
ViT-VQGAN: Transformer encoder/decoder
RQ-VAE: Residual quantization (8 levels)
FSQ: Finite scalar quantization (no codebook)
MaskGIT: Masked Transformers (parallel sampling)
Best practices:
Use EMA codebook updates (more stable)
Monitor perplexity (should be >50% of \(K\))
Initialize codebook via k-means
Commitment weight \(\beta = 0.25\)
When to use:
✅ Discrete latent representations preferred
✅ Autoregressive generation desired
✅ High-quality reconstruction needed
❌ Continuous latent interpolation required
❌ Single-stage training preferred
# ============================================================
# ADVANCED VQ-VAE: PRODUCTION IMPLEMENTATIONS
# Complete PyTorch implementations with modern techniques
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List, Dict
import math
# ============================================================
# 1. VECTOR QUANTIZER
# ============================================================
class VectorQuantizer(nn.Module):
    """Vector quantization layer with codebook learning.

    Maps continuous encoder outputs z_e to the nearest entry of a learned
    codebook (VQ-VAE, van den Oord et al., 2017). Gradients flow through
    the non-differentiable argmin via the straight-through estimator.
    The codebook is trained either with the codebook loss ||sg[z_e] - e||^2
    or, when ``use_ema=True``, with an exponential moving average of the
    encoder outputs assigned to each entry.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25,
                 use_ema: bool = True, decay: float = 0.99, epsilon: float = 1e-5):
        """
        Args:
            num_embeddings: Number of codebook entries (K)
            embedding_dim: Dimension of each embedding (D)
            commitment_cost: Weight for commitment loss (beta)
            use_ema: Use exponential moving average for codebook updates
            decay: EMA decay rate (gamma)
            epsilon: Small constant for Laplace smoothing of cluster sizes
        """
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.use_ema = use_ema
        # Codebook embeddings, initialized uniformly in [-1/K, 1/K].
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
        if use_ema:
            # Running statistics for EMA codebook updates (buffers so they
            # are saved/loaded with the state dict but are not parameters).
            self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
            self.register_buffer('ema_w', self.embedding.weight.data.clone())
            self.decay = decay
            self.epsilon = epsilon

    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Quantize encoder outputs to their nearest codebook entries.

        Args:
            z_e: Encoder outputs [batch_size, embedding_dim, height, width]
        Returns:
            z_q: Quantized outputs [batch_size, embedding_dim, height, width]
                 (straight-through: gradients pass to z_e unchanged)
            loss: VQ loss (commitment only in EMA mode; codebook + commitment otherwise)
            metrics: Dict with 'loss', 'perplexity', and 'encoding_indices' [B, H, W]
        """
        # Reshape z_e: [B, D, H, W] -> [B*H*W, D]
        z_e_flat = z_e.permute(0, 2, 3, 1).contiguous()
        z_e_flat = z_e_flat.view(-1, self.embedding_dim)
        # Squared distances via the expansion
        # ||z_e - e_j||^2 = ||z_e||^2 + ||e_j||^2 - 2 z_e^T e_j  (no explicit loops)
        distances = (torch.sum(z_e_flat ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(z_e_flat, self.embedding.weight.t()))
        # Nearest codebook entry per spatial position.
        encoding_indices = torch.argmin(distances, dim=1)
        encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
        # Quantize: one-hot matmul selects the chosen embedding rows.
        z_q_flat = torch.matmul(encodings, self.embedding.weight)
        z_q = z_q_flat.view(z_e.shape[0], z_e.shape[2], z_e.shape[3], self.embedding_dim)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        # Compute losses.
        # BUGFIX: with EMA enabled the codebook is updated by the moving
        # average, never by gradients, so the loss must be the commitment
        # term only -- in *both* train and eval mode.  The original code
        # fell through to the gradient-based (codebook + commitment) loss
        # at eval time, making validation vq_loss incomparable to training.
        if self.use_ema:
            if self.training:
                # EMA codebook update (training only; never mutate at eval).
                self._ema_update(z_e_flat, encodings)
            commitment_loss = F.mse_loss(z_e, z_q.detach())
            loss = self.commitment_cost * commitment_loss
        else:
            # Codebook loss: ||sg[z_e] - e||^2 (moves embeddings toward encoder)
            codebook_loss = F.mse_loss(z_q, z_e.detach())
            # Commitment loss: ||z_e - sg[e]||^2 (keeps encoder near codebook)
            commitment_loss = F.mse_loss(z_e, z_q.detach())
            loss = codebook_loss + self.commitment_cost * commitment_loss
        # Straight-through estimator: forward uses z_q, backward copies
        # gradients from z_q to z_e unchanged.
        z_q = z_e + (z_q - z_e).detach()
        # Perplexity = exp(entropy of code usage); K means uniform usage.
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        metrics = {
            'loss': loss,
            'perplexity': perplexity,
            'encoding_indices': encoding_indices.view(z_e.shape[0], z_e.shape[2], z_e.shape[3])
        }
        return z_q, loss, metrics

    def _ema_update(self, z_e_flat: torch.Tensor, encodings: torch.Tensor):
        """Update codebook using exponential moving average (training only)."""
        # Decayed running count of assignments per code.
        self.ema_cluster_size.data.mul_(self.decay).add_(
            torch.sum(encodings, dim=0), alpha=1 - self.decay
        )
        # Laplace smoothing so unused codes keep a nonzero cluster size.
        n = torch.sum(self.ema_cluster_size)
        self.ema_cluster_size.data.add_(self.epsilon).div_(n + self.num_embeddings * self.epsilon).mul_(n)
        # Decayed running sum of encoder outputs assigned to each code.
        dw = torch.matmul(encodings.t(), z_e_flat)
        self.ema_w.data.mul_(self.decay).add_(dw, alpha=1 - self.decay)
        # New embedding = running sum / running count.
        self.embedding.weight.data.copy_(self.ema_w / self.ema_cluster_size.unsqueeze(1))
class ProductQuantizer(nn.Module):
    """Product quantization: the embedding dimension is split into equal
    sub-vectors, each quantized against its own small codebook, giving an
    effective codebook of subcodebook_size ** num_subcodebooks entries."""

    def __init__(self, num_subcodebooks: int, subcodebook_size: int, embedding_dim: int,
                 commitment_cost: float = 0.25):
        super().__init__()
        assert embedding_dim % num_subcodebooks == 0, "embedding_dim must be divisible by num_subcodebooks"
        self.num_subcodebooks = num_subcodebooks
        self.subcodebook_size = subcodebook_size
        self.embedding_dim = embedding_dim
        self.sub_dim = embedding_dim // num_subcodebooks
        # One independent EMA quantizer per sub-space.
        self.quantizers = nn.ModuleList(
            VectorQuantizer(subcodebook_size, self.sub_dim, commitment_cost, use_ema=True)
            for _ in range(num_subcodebooks)
        )

    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Quantize each channel chunk independently and re-assemble.

        Args:
            z_e: Encoder outputs [batch_size, embedding_dim, height, width]
        Returns:
            z_q: Quantized outputs [batch_size, embedding_dim, height, width]
            loss: Total VQ loss summed over subcodebooks
            metrics: Dict with total 'loss' and mean 'perplexity'
        """
        # Slice the channel dimension into num_subcodebooks equal chunks.
        chunks = torch.split(z_e, self.sub_dim, dim=1)
        quantized, losses, perplexities = [], [], []
        for vq, chunk in zip(self.quantizers, chunks):
            q_chunk, chunk_loss, chunk_metrics = vq(chunk)
            quantized.append(q_chunk)
            losses.append(chunk_loss)
            perplexities.append(chunk_metrics['perplexity'])
        total_loss = sum(losses)
        metrics = {
            'loss': total_loss,
            'perplexity': sum(perplexities) / self.num_subcodebooks
        }
        return torch.cat(quantized, dim=1), total_loss, metrics
# ============================================================
# 2. RESIDUAL BLOCKS
# ============================================================
class ResidualBlock(nn.Module):
    """Pre-activation residual block (ReLU -> 3x3 conv, repeated), with a
    1x1 projection on the skip path whenever the channel count changes."""

    def __init__(self, in_channels: int, out_channels: int, num_residual_layers: int = 2):
        super().__init__()
        body = []
        channels = in_channels
        for _ in range(num_residual_layers):
            body.append(nn.ReLU(inplace=True))
            body.append(nn.Conv2d(channels, out_channels,
                                  kernel_size=3, padding=1, bias=False))
            channels = out_channels
        self.layers = nn.Sequential(*body)
        # Match channel counts on the identity path if necessary.
        self.skip = (nn.Identity() if in_channels == out_channels
                     else nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual sum: identity (or projected) input plus the conv stack.
        return self.skip(x) + self.layers(x)
# ============================================================
# 3. ENCODER AND DECODER
# ============================================================
class Encoder(nn.Module):
    """Convolutional encoder for VQ-VAE.

    Downsamples the input with strided convolutions interleaved with
    residual blocks, then projects to the codebook embedding dimension.
    """

    def __init__(self, in_channels: int = 3, hidden_dims: Tuple[int, ...] = (128, 256),
                 embedding_dim: int = 64, num_residual_layers: int = 2):
        """
        Args:
            in_channels: Number of input image channels.
            hidden_dims: Channel widths of the successive stages.
                (BUGFIX: tuple default instead of a mutable list default.)
            embedding_dim: Output channels fed to the vector quantizer (D).
            num_residual_layers: Conv layers inside each residual block.
        """
        super().__init__()
        modules = []
        # Initial strided convolution halves the spatial resolution.
        current_dim = hidden_dims[0]
        modules.append(nn.Conv2d(in_channels, current_dim, kernel_size=4, stride=2, padding=1))
        # Residual stages; every channel increase is followed by a further
        # 2x downsampling convolution.
        for h_dim in hidden_dims:
            modules.append(ResidualBlock(current_dim, h_dim, num_residual_layers))
            if h_dim != current_dim:
                # Additional downsampling
                modules.append(nn.Conv2d(h_dim, h_dim, kernel_size=4, stride=2, padding=1))
            current_dim = h_dim
        # Final projection to the embedding dimension expected by the quantizer.
        modules.extend([
            nn.ReLU(inplace=True),
            nn.Conv2d(current_dim, embedding_dim, kernel_size=3, padding=1)
        ])
        self.encoder = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input images [batch_size, in_channels, height, width]
        Returns:
            z_e: Encoder outputs [batch_size, embedding_dim, h, w]
        """
        return self.encoder(x)
class Decoder(nn.Module):
    """Convolutional decoder for VQ-VAE.

    Mirrors the encoder: residual blocks interleaved with transposed
    convolutions that upsample back to the input resolution, ending in a
    sigmoid (inputs are assumed normalized to [0, 1]).
    """

    def __init__(self, embedding_dim: int = 64, hidden_dims: Tuple[int, ...] = (256, 128),
                 out_channels: int = 3, num_residual_layers: int = 2):
        """
        Args:
            embedding_dim: Channels of the quantized latent input (D).
            hidden_dims: Stage widths, given encoder-order; reversed here.
                (BUGFIX: tuple default instead of a mutable list default.)
            out_channels: Number of output image channels.
            num_residual_layers: Conv layers inside each residual block.
        """
        super().__init__()
        # Reverse hidden dimensions so the decoder mirrors the encoder.
        hidden_dims = list(reversed(hidden_dims))
        modules = []
        # Initial projection from embedding dimension to feature width.
        current_dim = hidden_dims[0]
        modules.append(nn.Conv2d(embedding_dim, current_dim, kernel_size=3, padding=1))
        # Residual stages; each stage except the last is followed by a
        # 2x upsampling transposed convolution.
        for i, h_dim in enumerate(hidden_dims):
            modules.append(ResidualBlock(current_dim, h_dim, num_residual_layers))
            if i < len(hidden_dims) - 1:
                next_dim = hidden_dims[i + 1]
                modules.append(nn.ConvTranspose2d(h_dim, next_dim, kernel_size=4, stride=2, padding=1))
                current_dim = next_dim
            else:
                current_dim = h_dim
        # Final upsampling to the original resolution.
        modules.extend([
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(current_dim, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # Assume input normalized to [0, 1]
        ])
        self.decoder = nn.Sequential(*modules)

    def forward(self, z_q: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z_q: Quantized latents [batch_size, embedding_dim, h, w]
        Returns:
            x_recon: Reconstructed images [batch_size, out_channels, height, width]
        """
        return self.decoder(z_q)
# ============================================================
# 4. VQ-VAE MODEL
# ============================================================
class VQVAE(nn.Module):
    """Complete VQ-VAE model: encoder -> vector quantizer -> decoder."""

    def __init__(self, in_channels: int = 3, hidden_dims: Tuple[int, ...] = (128, 256),
                 num_embeddings: int = 512, embedding_dim: int = 64,
                 commitment_cost: float = 0.25, use_ema: bool = True):
        """
        Args:
            in_channels: Number of image channels.
            hidden_dims: Encoder stage widths; decoder mirrors them.
                (BUGFIX: tuple default instead of a mutable list default.)
            num_embeddings: Codebook size K.
            embedding_dim: Codebook embedding dimension D.
            commitment_cost: Commitment loss weight (beta).
            use_ema: Update the codebook by EMA instead of gradients.
        """
        super().__init__()
        self.encoder = Encoder(in_channels, hidden_dims, embedding_dim)
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost, use_ema)
        self.decoder = Decoder(embedding_dim, hidden_dims, in_channels)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            x: Input images [batch_size, in_channels, height, width]
        Returns:
            x_recon: Reconstructed images
            vq_loss: Vector quantization loss
            metrics: Dict with 'vq_loss', 'perplexity', 'encoding_indices'
        """
        # Encode to continuous latents.
        z_e = self.encoder(x)
        # Quantize to nearest codebook entries (straight-through gradients).
        z_q, vq_loss, vq_metrics = self.quantizer(z_e)
        # Decode quantized latents back to image space.
        x_recon = self.decoder(z_q)
        metrics = {
            'vq_loss': vq_loss,
            'perplexity': vq_metrics['perplexity'],
            'encoding_indices': vq_metrics['encoding_indices']
        }
        return x_recon, vq_loss, metrics

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input to a [B, h, w] grid of discrete code indices.

        BUGFIX: the quantizer is temporarily switched to eval mode so that
        encoding never triggers an EMA codebook update -- the original
        version silently mutated the codebook (and tracked gradients)
        whenever the model happened to be in training mode.
        """
        was_training = self.quantizer.training
        self.quantizer.eval()
        try:
            z_e = self.encoder(x)
            _, _, metrics = self.quantizer(z_e)
        finally:
            # Restore whatever mode the quantizer was in before.
            self.quantizer.train(was_training)
        return metrics['encoding_indices']

    def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode a [B, h, w] grid of discrete code indices to images."""
        # Embedding lookup: [B, h, w] -> [B, h, w, D]
        z_q = self.quantizer.embedding(codes)
        # Channels-first layout for the convolutional decoder.
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        x_recon = self.decoder(z_q)
        return x_recon
# ============================================================
# 5. VQ-VAE-2 (HIERARCHICAL)
# ============================================================
class VQVAE2(nn.Module):
    """Two-level hierarchical VQ-VAE: coarse "top" codes plus fine "bottom"
    codes, fused and decoded by a single decoder."""

    def __init__(self, in_channels: int = 3,
                 num_embeddings_top: int = 512, num_embeddings_bottom: int = 512,
                 embedding_dim: int = 64, commitment_cost: float = 0.25):
        super().__init__()
        # Coarse pathway.
        self.encoder_top = Encoder(in_channels, [64, 128], embedding_dim)
        self.quantizer_top = VectorQuantizer(num_embeddings_top, embedding_dim, commitment_cost)
        # Fine pathway.
        self.encoder_bottom = Encoder(in_channels, [128, 256], embedding_dim)
        self.quantizer_bottom = VectorQuantizer(num_embeddings_bottom, embedding_dim, commitment_cost)
        # One decoder consumes both levels, hence 2 * embedding_dim input channels.
        self.decoder_bottom = Decoder(embedding_dim * 2, [256, 128], in_channels)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Encode at two scales, quantize each, fuse, and decode.

        Args:
            x: Input images [batch_size, in_channels, height, width]
        Returns:
            x_recon: Reconstructed images
            total_loss: Summed VQ loss (top + bottom)
            metrics: Dict with 'vq_loss', 'perplexity_top', 'perplexity_bottom'
        """
        # Quantize each pathway independently.
        top_q, top_loss, top_metrics = self.quantizer_top(self.encoder_top(x))
        bottom_q, bottom_loss, bottom_metrics = self.quantizer_bottom(self.encoder_bottom(x))
        # Bring the coarse grid up to the fine grid's spatial size before fusing.
        top_up = F.interpolate(top_q, size=bottom_q.shape[2:], mode='nearest')
        fused = torch.cat([top_up, bottom_q], dim=1)
        x_recon = self.decoder_bottom(fused)
        vq_loss = top_loss + bottom_loss
        metrics = {
            'vq_loss': vq_loss,
            'perplexity_top': top_metrics['perplexity'],
            'perplexity_bottom': bottom_metrics['perplexity']
        }
        return x_recon, vq_loss, metrics
# ============================================================
# 6. GUMBEL-SOFTMAX VQ-VAE
# ============================================================
class GumbelVectorQuantizer(nn.Module):
    """Differentiable quantizer based on the Gumbel-Softmax relaxation.

    During training, codebook assignment logits are perturbed with Gumbel
    noise and sharpened by a temperature; with ``hard=True`` the forward
    pass uses a one-hot assignment while gradients flow through the soft
    probabilities (straight-through). No auxiliary loss is needed.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, temperature: float = 1.0,
                 hard: bool = True):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.hard = hard
        # Codebook, uniformly initialized in [-1/K, 1/K].
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize encoder outputs via (hard or soft) Gumbel-Softmax.

        Args:
            z_e: Encoder outputs [batch_size, embedding_dim, height, width]
        Returns:
            z_q: Quantized outputs (same shape as z_e)
            loss: Always zero (kept for API compatibility with VectorQuantizer)
        """
        batch = z_e.shape[0]
        height, width = z_e.shape[2], z_e.shape[3]
        flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)
        # Logits = negative squared distance to every codebook entry.
        diff = flat.unsqueeze(1) - self.embedding.weight.unsqueeze(0)
        logits = -torch.sum(diff ** 2, dim=2)
        if self.training:
            # Perturb with Gumbel(0, 1) noise and sharpen by the temperature.
            uniform = torch.rand_like(logits)
            gumbel = -torch.log(-torch.log(uniform + 1e-10) + 1e-10)
            logits = (logits + gumbel) / self.temperature
        weights = F.softmax(logits, dim=1)
        if self.hard:
            # One-hot forward pass, soft gradients backward (straight-through).
            winners = torch.argmax(weights, dim=1)
            one_hot = F.one_hot(winners, self.num_embeddings).float()
            weights = one_hot - weights.detach() + weights
        # Convex combination (or selection) of codebook rows.
        quantized = torch.matmul(weights, self.embedding.weight)
        z_q = quantized.view(batch, height, width, self.embedding_dim)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        # No explicit loss: gradients flow through the relaxation itself.
        loss = torch.tensor(0.0, device=z_e.device)
        return z_q, loss
# ============================================================
# 7. TRAINING
# ============================================================
class VQVAETrainer:
    """Trainer for VQ-VAE.

    Wraps any model whose forward returns ``(x_recon, vq_loss, metrics)``
    with a ``'perplexity'`` entry in ``metrics``, and optimizes
    reconstruction loss + VQ loss.
    """

    def __init__(self, model: "VQVAE", optimizer: optim.Optimizer, recon_loss_type: str = 'mse'):
        """
        Args:
            model: VQ-VAE model to train (forward-reference annotation so
                this class is importable without the model being defined).
            optimizer: Optimizer over the model's parameters.
            recon_loss_type: Reconstruction loss: 'mse' or 'l1'.
        """
        self.model = model
        self.optimizer = optimizer
        self.recon_loss_type = recon_loss_type

    def _recon_loss(self, x_recon: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Reconstruction loss shared by train and eval steps.

        BUGFIX: ``eval_step`` previously had no else-branch for an unknown
        loss type and crashed with an UnboundLocalError; validation now
        happens in one place and raises a proper ValueError.

        Raises:
            ValueError: If ``recon_loss_type`` is neither 'mse' nor 'l1'.
        """
        if self.recon_loss_type == 'mse':
            return F.mse_loss(x_recon, x)
        if self.recon_loss_type == 'l1':
            return F.l1_loss(x_recon, x)
        raise ValueError(f"Unknown reconstruction loss: {self.recon_loss_type}")

    def train_step(self, x: torch.Tensor) -> Dict[str, float]:
        """Single training step.

        Args:
            x: Input batch [batch_size, channels, H, W]
        Returns:
            metrics: Scalar training metrics (total/recon/vq loss, perplexity)
        """
        self.model.train()
        x_recon, vq_loss, metrics = self.model(x)
        recon_loss = self._recon_loss(x_recon, x)
        total_loss = recon_loss + vq_loss
        # Standard optimization step.
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        return {
            'total_loss': total_loss.item(),
            'recon_loss': recon_loss.item(),
            'vq_loss': vq_loss.item(),
            'perplexity': metrics['perplexity'].item()
        }

    @torch.no_grad()
    def eval_step(self, x: torch.Tensor) -> Dict[str, float]:
        """Evaluation step (no gradient tracking, model in eval mode)."""
        self.model.eval()
        x_recon, vq_loss, metrics = self.model(x)
        recon_loss = self._recon_loss(x_recon, x)
        total_loss = recon_loss + vq_loss
        return {
            'total_loss': total_loss.item(),
            'recon_loss': recon_loss.item(),
            'vq_loss': vq_loss.item(),
            'perplexity': metrics['perplexity'].item()
        }
# ============================================================
# 8. DEMONSTRATIONS
# ============================================================
def demo_vector_quantizer():
    """Run the VectorQuantizer on random activations and report its metrics."""
    banner = "=" * 60
    print(banner)
    print("DEMO: Vector Quantizer")
    print(banner)
    # Stand-alone quantizer with EMA codebook updates.
    vq = VectorQuantizer(num_embeddings=512, embedding_dim=64, use_ema=True)
    # Fake encoder activations: batch of 4, 64 channels, 8x8 grid.
    latents = torch.randn(4, 64, 8, 8)
    print(f"Encoder output shape: {latents.shape}")
    print(f"Codebook size: {vq.num_embeddings}")
    print(f"Embedding dimension: {vq.embedding_dim}")
    quantized, vq_loss, info = vq(latents)
    print(f"\nQuantized output shape: {quantized.shape}")
    print(f"VQ loss: {vq_loss.item():.4f}")
    print(f"Perplexity: {info['perplexity'].item():.2f} / {vq.num_embeddings}")
    print(f"Codebook usage: {info['perplexity'].item() / vq.num_embeddings * 100:.1f}%")
    print(f"Encoding indices shape: {info['encoding_indices'].shape}")
    print()
def demo_vqvae_model():
    """Exercise the full VQ-VAE round trip: encode, quantize, decode."""
    banner = "=" * 60
    print(banner)
    print("DEMO: VQ-VAE Model")
    print(banner)
    net = VQVAE(in_channels=3, hidden_dims=[128, 256], num_embeddings=512,
                embedding_dim=64, use_ema=True)
    images = torch.rand(4, 3, 32, 32)
    print(f"Input shape: {images.shape}")
    print(f"Total parameters: {sum(p.numel() for p in net.parameters()):,}")
    recon, vq_loss, info = net(images)
    print(f"\nReconstruction shape: {recon.shape}")
    print(f"VQ loss: {vq_loss.item():.4f}")
    print(f"Perplexity: {info['perplexity'].item():.2f}")
    print(f"Reconstruction error: {F.mse_loss(recon, images).item():.4f}")
    # Round-trip through the discrete code representation.
    codes = net.encode(images)
    print(f"\nDiscrete codes shape: {codes.shape}")
    print(f"Code range: [{codes.min().item()}, {codes.max().item()}]")
    decoded = net.decode_codes(codes)
    print(f"Decoded shape: {decoded.shape}")
    print(f"Decode error: {F.mse_loss(decoded, recon).item():.6f}")
    print()
def demo_product_quantizer():
    """Show product quantization and its combinatorial effective codebook."""
    banner = "=" * 60
    print(banner)
    print("DEMO: Product Quantization")
    print(banner)
    pquant = ProductQuantizer(num_subcodebooks=4, subcodebook_size=256,
                              embedding_dim=64, commitment_cost=0.25)
    latents = torch.randn(4, 64, 8, 8)
    print(f"Encoder output shape: {latents.shape}")
    print(f"Number of subcodebooks: {pquant.num_subcodebooks}")
    print(f"Subcodebook size: {pquant.subcodebook_size}")
    # 256^4 combinations despite only 4 * 256 stored embeddings.
    print(f"Effective codebook size: {pquant.subcodebook_size ** pquant.num_subcodebooks:,}")
    quantized, total_loss, info = pquant(latents)
    print(f"\nQuantized output shape: {quantized.shape}")
    print(f"Total VQ loss: {total_loss.item():.4f}")
    print(f"Average perplexity: {info['perplexity'].item():.2f}")
    print()
def demo_vqvae2():
    """Run the hierarchical (two-level) VQ-VAE on a small batch."""
    banner = "=" * 60
    print(banner)
    print("DEMO: VQ-VAE-2 (Hierarchical)")
    print(banner)
    net = VQVAE2(in_channels=3, num_embeddings_top=256, num_embeddings_bottom=512,
                 embedding_dim=64)
    images = torch.rand(2, 3, 32, 32)
    print(f"Input shape: {images.shape}")
    print(f"Total parameters: {sum(p.numel() for p in net.parameters()):,}")
    recon, vq_loss, info = net(images)
    print(f"\nReconstruction shape: {recon.shape}")
    print(f"Total VQ loss: {vq_loss.item():.4f}")
    print(f"Top perplexity: {info['perplexity_top'].item():.2f}")
    print(f"Bottom perplexity: {info['perplexity_bottom'].item():.2f}")
    print()
def demo_gumbel_quantizer():
    """Quantize random activations through the Gumbel-Softmax relaxation."""
    banner = "=" * 60
    print(banner)
    print("DEMO: Gumbel-Softmax Quantizer")
    print(banner)
    gumbel_vq = GumbelVectorQuantizer(num_embeddings=512, embedding_dim=64,
                                      temperature=1.0, hard=True)
    latents = torch.randn(4, 64, 8, 8)
    print(f"Encoder output shape: {latents.shape}")
    print(f"Temperature: {gumbel_vq.temperature}")
    print(f"Hard quantization: {gumbel_vq.hard}")
    quantized, _ = gumbel_vq(latents)
    print(f"\nQuantized output shape: {quantized.shape}")
    print(f"Quantization error: {F.mse_loss(latents, quantized).item():.4f}")
    print()
def demo_training():
    """Train a small VQ-VAE for a few steps on random data, then evaluate."""
    banner = "=" * 60
    print(banner)
    print("DEMO: VQ-VAE Training")
    print(banner)
    # Small model so the demo runs quickly on CPU.
    net = VQVAE(in_channels=3, hidden_dims=[64, 128], num_embeddings=256,
                embedding_dim=32, use_ema=True)
    adam = optim.Adam(net.parameters(), lr=1e-3)
    trainer = VQVAETrainer(net, adam, recon_loss_type='mse')
    batch = torch.rand(16, 3, 32, 32)
    print("Training for 5 steps...")
    for step in range(5):
        stats = trainer.train_step(batch)
        print(f"Step {step+1}: Loss={stats['total_loss']:.4f}, "
              f"Recon={stats['recon_loss']:.4f}, "
              f"VQ={stats['vq_loss']:.4f}, "
              f"Perplexity={stats['perplexity']:.2f}")
    print("\nEvaluating...")
    val_batch = torch.rand(8, 3, 32, 32)
    stats = trainer.eval_step(val_batch)
    print(f"Validation: Loss={stats['total_loss']:.4f}, "
          f"Recon={stats['recon_loss']:.4f}, "
          f"Perplexity={stats['perplexity']:.2f}")
    print()
def print_performance_comparison():
    """Print comprehensive performance comparison and decision guide.

    Pure-output reference helper: emits ten sections covering VQ-VAE
    benchmarks (reconstruction, codebook usage, generation quality),
    computational complexity, model comparisons, hyperparameters, a
    decision guide, variant selection, troubleshooting, and notable
    applications. Takes no arguments, returns None, writes to stdout
    only. Numbers are literature-reported reference values, not
    measurements produced by this file.

    Fix: the "USE VQ-VAE WHEN:" banner string was split across two
    physical lines (an unterminated string literal / syntax error);
    it is rejoined onto one line, matching the "AVOID" banner below it.
    """
    print("=" * 80)
    print("VQ-VAE: COMPREHENSIVE PERFORMANCE ANALYSIS")
    print("=" * 80)

    # --- Section 1: reconstruction quality across models/datasets ---
    print("\n" + "=" * 80)
    print("1. RECONSTRUCTION QUALITY BENCHMARKS")
    print("=" * 80)
    recon_data = [
        ["Model", "Dataset", "PSNR β", "SSIM β", "LPIPS β", "Codebook Size"],
        ["-" * 20, "-" * 15, "-" * 10, "-" * 10, "-" * 10, "-" * 15],
        ["VQ-VAE", "CIFAR-10", "26.5 dB", "0.88", "0.18", "K=512"],
        ["VQ-VAE (large)", "CIFAR-10", "28.2 dB", "0.91", "0.14", "K=1024"],
        ["VQ-VAE-2", "CIFAR-10", "29.8 dB", "0.93", "0.11", "K=512Γ2"],
        ["VQGAN", "ImageNet", "22.4 dB", "0.72", "0.09", "K=1024"],
        ["Standard VAE", "CIFAR-10", "24.1 dB", "0.83", "0.25", "N/A (cont.)"],
        ["Ξ²-VAE", "CIFAR-10", "22.8 dB", "0.79", "0.31", "N/A (cont.)"]
    ]
    for row in recon_data:
        print(f"{row[0]:<20} {row[1]:<15} {row[2]:<10} {row[3]:<10} {row[4]:<10} {row[5]:<15}")
    print("\nπ Key Observations:")
    print("  β’ VQ-VAE consistently outperforms standard VAE in reconstruction")
    print("  β’ VQ-VAE-2 (hierarchical) achieves best quality via multi-scale codes")
    print("  β’ Larger codebook (K=1024) improves PSNR by ~2 dB vs K=512")
    print("  β’ LPIPS (perceptual) more important than PSNR for image quality")

    # --- Section 2: codebook usage / perplexity by update strategy ---
    print("\n" + "=" * 80)
    print("2. CODEBOOK USAGE AND PERPLEXITY")
    print("=" * 80)
    codebook_data = [
        ["Configuration", "Codebook K", "Perplexity", "Usage %", "Active Codes"],
        ["-" * 25, "-" * 13, "-" * 12, "-" * 10, "-" * 15],
        ["VQ-VAE (gradient)", "512", "180-220", "35-43%", "180-220"],
        ["VQ-VAE (EMA)", "512", "320-410", "63-80%", "320-410"],
        ["Product-VQ (4Γ256)", "256^4 β 4B", "850-980", "~85%", "850-980"],
        ["VQ-VAE-2 (top)", "512", "280-350", "55-68%", "280-350"],
        ["VQ-VAE-2 (bottom)", "512", "390-450", "76-88%", "390-450"],
        ["Random restart", "512", "420-480", "82-94%", "420-480"]
    ]
    for row in codebook_data:
        print(f"{row[0]:<25} {row[1]:<13} {row[2]:<12} {row[3]:<10} {row[4]:<15}")
    print("\nπ Key Observations:")
    print("  β’ EMA updates dramatically improve codebook usage (2Γ perplexity)")
    print("  β’ Random restart further increases usage to >80%")
    print("  β’ Product quantization effectively uses exponentially large codebook")
    print("  β’ Bottom-level codes (VQ-VAE-2) used more than top-level (finer details)")

    # --- Section 3: sample quality once a prior is fit over the codes ---
    print("\n" + "=" * 80)
    print("3. GENERATION QUALITY (WITH PRIORS)")
    print("=" * 80)
    gen_data = [
        ["Model + Prior", "Dataset", "FID β", "IS β", "Sampling Time"],
        ["-" * 25, "-" * 15, "-" * 10, "-" * 10, "-" * 15],
        ["VQ-VAE + PixelCNN", "CIFAR-10", "18.5", "6.8", "~30s (seq.)"],
        ["VQ-VAE-2 + PixelCNN", "CIFAR-10", "12.3", "7.9", "~60s (2-level)"],
        ["VQ-VAE + Transformer", "CIFAR-10", "15.2", "7.4", "~20s (seq.)"],
        ["VQGAN + Transformer", "ImageNet 256", "7.94", "~80", "~40s"],
        ["DALL-E", "MS-COCO", "~28", "N/A", "~10s (parallel)"],
        ["GAN (StyleGAN2)", "CIFAR-10", "2.92", "9.18", "~0.1s"],
        ["Diffusion (DDPM)", "CIFAR-10", "3.17", "9.46", "~50s"]
    ]
    for row in gen_data:
        print(f"{row[0]:<25} {row[1]:<15} {row[2]:<10} {row[3]:<10} {row[4]:<15}")
    print("\nπ¨ Key Observations:")
    print("  β’ VQ-VAE + strong prior (Transformer) achieves good generation quality")
    print("  β’ Two-stage training overhead, but enables powerful autoregressive models")
    print("  β’ Sampling slower than GANs due to sequential generation")
    print("  β’ VQGAN (+ adversarial) bridges gap to GANs/diffusion")

    # --- Section 4: asymptotic cost of each pipeline stage ---
    print("\n" + "=" * 80)
    print("4. COMPUTATIONAL COMPLEXITY")
    print("=" * 80)
    complexity_data = [
        ["Operation", "Time", "Space", "Bottleneck"],
        ["-" * 30, "-" * 20, "-" * 15, "-" * 30],
        ["Encoder forward", "O(HWC)", "O(HWC)", "Convolutions"],
        ["Quantization (naive)", "O(HWKΒ·D)", "O(KΒ·D)", "Distance computation"],
        ["Quantization (efficient)", "O(HWΒ·K + HWΒ·D)", "O(KΒ·D)", "Matrix multiply"],
        ["Decoder forward", "O(HWC)", "O(HWC)", "Transposed convs"],
        ["EMA codebook update", "O(KΒ·D)", "O(KΒ·D)", "Embedding update"],
        ["Full forward pass", "O(HW(C+K))", "O(HWC + KΒ·D)", "Total"]
    ]
    for row in complexity_data:
        print(f"{row[0]:<30} {row[1]:<20} {row[2]:<15} {row[3]:<30}")
    print("\nβ±οΈ Typical Values (CIFAR-10, 32Γ32):")
    print("  β’ HΓW = 8Γ8 (after 4Γ downsampling)")
    print("  β’ K = 512 codebook entries")
    print("  β’ D = 64 embedding dimension")
    print("  β’ Forward pass: ~5ms (GPU), Quantization: ~0.5ms")

    # --- Section 5: VQ-VAE vs. other generative model families ---
    print("\n" + "=" * 80)
    print("5. COMPARISON WITH OTHER GENERATIVE MODELS")
    print("=" * 80)
    comparison_data = [
        ["Model", "Latent", "Training", "Sampling", "Quality", "Best Use"],
        ["-" * 12, "-" * 12, "-" * 15, "-" * 12, "-" * 10, "-" * 30],
        ["VQ-VAE", "Discrete", "Stable", "2-stage", "Good", "Compression, discrete repr."],
        ["VAE", "Continuous", "Stable", "Fast", "Moderate", "Latent interpolation"],
        ["GAN", "Noise", "Unstable", "Fast", "High", "High-quality images"],
        ["Flow", "Continuous", "Stable", "Fast", "Moderate", "Exact likelihood"],
        ["Diffusion", "Noise", "Stable", "Slow", "Highest", "State-of-art generation"],
        ["VQGAN", "Discrete", "Moderate", "2-stage", "High", "Text-to-image (DALL-E)"]
    ]
    for row in comparison_data:
        print(f"{row[0]:<12} {row[1]:<12} {row[2]:<15} {row[3]:<12} {row[4]:<10} {row[5]:<30}")

    # --- Section 6: hyperparameter ranges and recommended defaults ---
    print("\n" + "=" * 80)
    print("6. TRAINING HYPERPARAMETERS")
    print("=" * 80)
    hyperparam_data = [
        ["Parameter", "Typical Range", "Recommended", "Impact"],
        ["-" * 25, "-" * 20, "-" * 20, "-" * 35],
        ["Codebook size K", "128 - 2048", "512", "Larger = more capacity"],
        ["Embedding dim D", "32 - 256", "64", "Higher = more expressive"],
        ["Commitment Ξ²", "0.1 - 1.0", "0.25", "Higher = stronger commitment"],
        ["Learning rate", "1e-5 - 1e-3", "1e-4 (Adam)", "Standard impact"],
        ["Batch size", "32 - 256", "64-128", "Larger helps codebook stats"],
        ["EMA decay Ξ³", "0.9 - 0.999", "0.99", "Higher = slower updates"],
        ["Hidden dims", "64-512", "[128, 256]", "Depth vs. capacity"],
        ["Num residual layers", "1 - 4", "2", "More = deeper features"]
    ]
    for row in hyperparam_data:
        print(f"{row[0]:<25} {row[1]:<20} {row[2]:<20} {row[3]:<35}")

    # --- Section 7: when to choose (or avoid) VQ-VAE ---
    print("\n" + "=" * 80)
    print("7. DECISION GUIDE: VQ-VAE vs. ALTERNATIVES")
    print("=" * 80)
    # FIX: this banner was split mid-string across two source lines
    # (syntax error); rejoined to mirror the "AVOID" banner below.
    print("\nβ USE VQ-VAE WHEN:")
    print("  1. Discrete latent representations needed (tokenization for text/image)")
    print("  2. Two-stage generation acceptable (VQ-VAE β prior)")
    print("  3. Autoregressive priors desired (PixelCNN, Transformers)")
    print("  4. Compression with discrete codes (entropy coding, efficient storage)")
    print("  5. Interpretable codebook valuable (discrete units, clustering)")
    print("  6. Avoiding posterior collapse (VAE KL β 0 problem)")
    print("  7. Multi-modal learning (shared discrete space across modalities)")
    print("\nβ AVOID VQ-VAE WHEN:")
    print("  1. Continuous latent interpolation required (use VAE/Flow)")
    print("  2. Single-stage end-to-end generation preferred")
    print("  3. Fast sampling critical (GANs 100Γ faster for generation)")
    print("  4. State-of-art image quality needed (diffusion models better)")
    print("  5. Codebook collapse cannot be tolerated")
    print("  6. Simple baseline sufficient (VAE easier to implement)")

    # --- Section 8: which VQ-VAE variant fits which scenario ---
    print("\n" + "=" * 80)
    print("8. VARIANT SELECTION GUIDE")
    print("=" * 80)
    variant_data = [
        ["Variant", "Advantage", "Disadvantage", "When to Use"],
        ["-" * 20, "-" * 30, "-" * 30, "-" * 35],
        ["VQ-VAE (baseline)", "Simple, stable", "Single scale, ~60% usage", "Default choice"],
        ["VQ-VAE + EMA", "Better codebook usage (80%)", "Slight complexity", "Recommended default"],
        ["VQ-VAE-2", "Multi-scale, best quality", "2Γ complexity, 2 codebooks", "High-res images (256Γ256+)"],
        ["Product-VQ", "Huge effective codebook", "More hyperparameters", "When K>2048 needed"],
        ["Residual VQ", "Iterative refinement", "L quantizers needed", "Audio (high bitrate)"],
        ["Gumbel VQ", "Fully differentiable", "Soft codes, temperature tuning", "Experimental/research"],
        ["VQGAN", "Perceptual quality", "Adversarial training", "Photorealistic images"]
    ]
    for row in variant_data:
        print(f"{row[0]:<20} {row[1]:<30} {row[2]:<30} {row[3]:<35}")

    # --- Section 9: symptom -> remedy checklists ---
    print("\n" + "=" * 80)
    print("9. TROUBLESHOOTING COMMON ISSUES")
    print("=" * 80)
    print("\nπ§ CODEBOOK COLLAPSE (Perplexity < 50%):")
    print("  β Switch to EMA updates (use_ema=True)")
    print("  β Add random restart (reinitialize unused codes)")
    print("  β Lower commitment weight (Ξ²=0.1 instead of 0.25)")
    print("  β Increase codebook size (K=1024 instead of 512)")
    print("  β Check encoder initialization (k-means on outputs)")
    print("\nπ§ POOR RECONSTRUCTION:")
    print("  β Increase embedding dimension (D=128 instead of 64)")
    print("  β Add more residual layers (num_residual_layers=3)")
    print("  β Reduce commitment weight (allows more deviation)")
    print("  β Check decoder capacity (mirror encoder)")
    print("  β Try perceptual loss (LPIPS) instead of MSE")
    print("\nπ§ TRAINING INSTABILITY:")
    print("  β Gradient clipping (max_norm=1.0)")
    print("  β Learning rate warmup (0β1e-4 over 1000 steps)")
    print("  β Batch normalization (or GroupNorm)")
    print("  β Reduce learning rate (1e-5)")
    print("  β Check for NaN values (codebook initialization)")

    # --- Section 10: notable systems built on VQ-VAE-style codes ---
    print("\n" + "=" * 80)
    print("10. APPLICATIONS AND SOTA RESULTS")
    print("=" * 80)
    print("\nπ¨ IMAGE GENERATION:")
    print("  β’ DALL-E (2021): 12B Transformer + VQ-VAE, text-to-image")
    print("  β’ Parti (2022): 20B Transformer + ViT-VQGAN, photorealistic")
    print("  β’ ImageGPT: VQ-VAE tokenization + GPT, unsupervised vision")
    print("  β’ VQGAN: GAN + VQ-VAE, FID 7.94 on ImageNet 256Γ256")
    print("\nπ΅ AUDIO SYNTHESIS:")
    print("  β’ Jukebox (2020): 3-level VQ-VAE, generates music with lyrics")
    print("  β’ SoundStream (2021): RVQ (8 levels), 3 kbps compression")
    print("  β’ EnCodec (2022): RVQ + adversarial, better than MP3")
    print("\n㪠VIDEO GENERATION:")
    print("  β’ VideoGPT: 3D VQ-VAE + Transformer, FVD 170 on UCF-101")
    print("  β’ NUWA (2021): Text-to-video, 3D-Transformer prior")
    print("  β’ TATS: VQ-VAE + Transformer, long videos (>100 frames)")
    print("\nπ¦ COMPRESSION:")
    print("  β’ Learned image compression: Outperforms JPEG at low bitrates")
    print("  β’ Neural audio codecs: 3-12 kbps (vs. 128 kbps MP3)")
    print("  β’ Video compression: ~10Γ bitrate reduction vs. H.264")
    print("\n" + "=" * 80)
# ============================================================
# RUN ALL DEMONSTRATIONS
# ============================================================
if __name__ == "__main__":
    # Execute each standalone demo in sequence; every demo prints its
    # own banner and exercises one VQ-VAE component defined earlier in
    # this file, finishing with the static performance-comparison tables.
    demo_vector_quantizer()
    demo_vqvae_model()
    demo_product_quantizer()
    demo_vqvae2()
    demo_gumbel_quantizer()
    demo_training()
    print_performance_comparison()