import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Global seaborn theme applied to all subsequent matplotlib figures.
sns.set_style('whitegrid')

1. Motivation: Discrete Latent SpacesΒΆ

Standard VAE LimitationsΒΆ

Continuous latent space \(z \sim \mathcal{N}(\mu, \sigma^2)\):

Problems:

  1. Posterior collapse: Decoder ignores latent code

  2. Blurry reconstructions: Gaussian assumption

  3. Difficult autoregressive modeling: Continuous \(z\)

VQ-VAE Solution (van den Oord et al., 2017)ΒΆ

Discrete latent space with learned codebook!

\[e \in \mathbb{R}^{K \times D}\]

where:

  • \(K\) = codebook size (e.g., 512)

  • \(D\) = embedding dimension (e.g., 64)

Key Ideas:ΒΆ

  1. Encoder produces continuous \(z_e(x)\)

  2. Vector quantization: Map \(z_e\) to nearest codebook entry

  3. Decoder reconstructs from discrete code

  4. Codebook learning: Update embeddings via EMA or gradients

Advantages:ΒΆ

  • No posterior collapse (discrete bottleneck)

  • Sharp reconstructions

  • Enable powerful priors (PixelCNN, transformers)

  • Compression (discrete codes)

📚 Reference Materials:

2. Vector Quantization LayerΒΆ

Forward PassΒΆ

Given encoder output \(z_e(x) \in \mathbb{R}^{H \times W \times D}\):

1. Find nearest neighbor: $\(k^* = \arg\min_k \|z_e(x) - e_k\|_2\)$

2. Quantize: $\(z_q(x) = e_{k^*}\)$

Gradient Flow ProblemΒΆ

Quantization is non-differentiable!

Straight-through estimator: $\(\nabla_{z_e} L = \nabla_{z_q} L\)$

Copy gradients from decoder to encoder.

VQ-VAE LossΒΆ

\[L = \underbrace{\|x - D(z_q)\|^2}_{\text{reconstruction}} + \underbrace{\|\text{sg}[z_e] - e\|^2}_{\text{codebook}} + \beta \underbrace{\|z_e - \text{sg}[e]\|^2}_{\text{commitment}}\]

where \(\text{sg}[\cdot]\) = stop gradient.

Terms:

  1. Reconstruction: Standard autoencoder loss

  2. Codebook loss: Update embeddings toward encoder outputs

  3. Commitment loss: Encourage encoder to commit to codebook

Vector Quantizer ImplementationΒΆ

The vector quantizer is the core innovation of VQ-VAE: instead of a continuous latent space, it maintains a codebook of \(K\) embedding vectors and maps each encoder output to its nearest codebook entry. Formally, given encoder output \(z_e(x)\), the quantized representation is \(z_q(x) = e_k\) where \(k = \arg\min_j \|z_e(x) - e_j\|_2\). Because argmin is not differentiable, gradients are passed through the quantization step using the straight-through estimator – the decoder receives the quantized vector on the forward pass but the encoder gradient flows as if quantization were the identity. The codebook itself is updated via an exponential moving average of the encoder outputs assigned to each entry, which is more stable than gradient-based updates.

class VectorQuantizer(nn.Module):
    """Vector quantization layer with a learned codebook.

    Maps each spatial position of the encoder output to its nearest
    codebook embedding (squared L2 distance) and passes gradients through
    the non-differentiable argmin via the straight-through estimator.

    Args:
        num_embeddings: codebook size K.
        embedding_dim: embedding dimension D.
        commitment_cost: weight (beta) on the commitment loss term.
    """
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        
        # Codebook: K embeddings of dimension D, uniform init in [-1/K, 1/K].
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)
    
    def forward(self, z_e):
        """
        z_e: [B, C, H, W] encoder output (C must equal embedding_dim)
        Returns: z_q [B, C, H, W], vq loss (scalar), perplexity, encoding indices [B*H*W]
        """
        # [B, C, H, W] -> [B, H, W, C], then flatten to [B*H*W, C]
        z_e_perm = z_e.permute(0, 2, 3, 1).contiguous()
        z_e_flat = z_e_perm.view(-1, self.embedding_dim)
        
        # Squared distances via the expansion
        # ||z_e - e||^2 = ||z_e||^2 + ||e||^2 - 2*z_e^T*e
        distances = (torch.sum(z_e_flat**2, dim=1, keepdim=True)
                    + torch.sum(self.embedding.weight**2, dim=1)
                    - 2 * torch.matmul(z_e_flat, self.embedding.weight.t()))
        
        # Nearest codebook entry for each spatial position
        encoding_indices = torch.argmin(distances, dim=1)
        encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
        
        # Quantize: look up embeddings, restore the [B, H, W, C] layout
        z_q = torch.matmul(encodings, self.embedding.weight).view(z_e_perm.shape)
        
        # Codebook loss (pulls embeddings toward encoder outputs) and
        # commitment loss (keeps encoder outputs near the chosen embedding)
        e_latent_loss = F.mse_loss(z_q.detach(), z_e_perm)
        q_latent_loss = F.mse_loss(z_q, z_e_perm.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss
        
        # Straight-through estimator.
        # BUG FIX: the residual must be taken against the *permuted* encoder
        # output (same [B, H, W, C] layout as z_q). The original added z_q to
        # the unpermuted [B, C, H, W] tensor, which fails to broadcast
        # whenever H or W differs from C (e.g. the [4, 64, 7, 7] smoke test).
        z_q = z_e_perm + (z_q - z_e_perm).detach()
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        
        # Perplexity = exp(entropy of codebook usage); close to K means
        # uniform utilization, close to 1 means collapse.
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        return z_q, loss, perplexity, encoding_indices

# Test VQ layer
# Smoke test: quantize a random [B=4, C=64, H=7, W=7] feature map and
# inspect output shapes, the VQ loss, and codebook utilization.
vq = VectorQuantizer(num_embeddings=128, embedding_dim=64).to(device)
z_e = torch.randn(4, 64, 7, 7).to(device)
z_q, loss, perplexity, indices = vq(z_e)

print(f"Input shape: {z_e.shape}")
print(f"Quantized shape: {z_q.shape}")
print(f"VQ loss: {loss.item():.4f}")
# Perplexity near 128 would indicate uniform usage of all codebook entries.
print(f"Perplexity: {perplexity.item():.2f} / {128}")
print(f"Unique codes used: {len(torch.unique(indices))}")

Complete VQ-VAE ModelΒΆ

The VQ-VAE model chains together an encoder (convolutional layers that compress the image to spatial feature maps), the vector quantizer (which discretizes each spatial position into a codebook index), and a decoder (transposed convolutions that reconstruct the image from quantized features). The total loss has three terms: reconstruction loss (how well the decoder reconstructs the input), codebook loss (how close codebook vectors are to encoder outputs), and commitment loss (how close encoder outputs are to their assigned codebook vectors, weighted by \(\beta\)). The discrete bottleneck prevents posterior collapse and produces a compressed, discrete representation that can be modeled autoregressively for generation.

class Encoder(nn.Module):
    """Convolutional encoder mapping an image to a continuous feature map z_e.

    Each entry in hidden_dims contributes a stride-2 convolution (halving
    spatial resolution); a final stride-1 convolution projects the features
    to latent_dim channels.
    """
    def __init__(self, in_channels=1, hidden_dims=[32, 64], latent_dim=64):
        super().__init__()
        stages = []
        channels = in_channels
        for width in hidden_dims:
            stages.append(
                nn.Sequential(
                    nn.Conv2d(channels, width, kernel_size=4, stride=2, padding=1),
                    nn.ReLU(),
                )
            )
            channels = width
        # Project to the latent dimension without further downsampling.
        stages.append(
            nn.Sequential(
                nn.Conv2d(hidden_dims[-1], latent_dim, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            )
        )
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return the continuous latent map z_e(x)."""
        return self.encoder(x)

class Decoder(nn.Module):
    """Convolutional decoder reconstructing an image from quantized codes.

    Each entry in hidden_dims contributes a stride-2 transposed convolution
    (doubling spatial resolution); the final convolution maps to out_channels
    with a sigmoid, so outputs lie in [0, 1].
    """
    def __init__(self, latent_dim=64, hidden_dims=[64, 32], out_channels=1):
        super().__init__()
        stages = []
        channels = latent_dim
        for width in hidden_dims:
            stages.append(
                nn.Sequential(
                    nn.ConvTranspose2d(channels, width, kernel_size=4, stride=2, padding=1),
                    nn.ReLU(),
                )
            )
            channels = width
        # Final projection to image channels, squashed into [0, 1].
        stages.append(
            nn.Sequential(
                nn.Conv2d(hidden_dims[-1], out_channels, kernel_size=3, padding=1),
                nn.Sigmoid(),
            )
        )
        self.decoder = nn.Sequential(*stages)

    def forward(self, z):
        """Return the reconstruction decoded from latent map z."""
        return self.decoder(z)

class VQVAE(nn.Module):
    """Full VQ-VAE: Encoder -> VectorQuantizer -> Decoder.

    The encoder's latent_dim and the quantizer's embedding_dim are tied so
    that each spatial position of z_e can be matched against the codebook.
    """
    def __init__(self, num_embeddings=128, embedding_dim=64):
        super().__init__()
        self.encoder = Encoder(latent_dim=embedding_dim)
        self.vq = VectorQuantizer(num_embeddings, embedding_dim)
        self.decoder = Decoder(latent_dim=embedding_dim)

    def forward(self, x):
        """Encode, quantize, decode; returns (reconstruction, vq_loss, perplexity)."""
        continuous = self.encoder(x)
        quantized, vq_loss, perplexity, _ = self.vq(continuous)
        reconstruction = self.decoder(quantized)
        return reconstruction, vq_loss, perplexity

# Initialize model
# 256-entry codebook with 64-dim embeddings (matches the encoder latent dim).
model = VQVAE(num_embeddings=256, embedding_dim=64).to(device)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

Training VQ-VAEΒΆ

During training, the three loss components work in concert: the reconstruction loss drives the encoder and decoder to preserve information, while the codebook and commitment losses keep the quantization tightly coupled. Key metrics to monitor include codebook utilization (what fraction of the \(K\) entries are actively used) and perplexity (the effective number of codebook entries being used, computed as \(\exp(-\sum_k p_k \log p_k)\)). Low utilization or perplexity indicates codebook collapse – a common failure mode where only a few entries are used. Techniques like codebook reset (reinitializing dead entries) and EMA updates help maintain healthy utilization.

# Load MNIST
# Pixels stay in [0, 1] (ToTensor only), matching the decoder's sigmoid range.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=2e-4)

# Training loop
n_epochs = 20
history = {'recon_loss': [], 'vq_loss': [], 'perplexity': []}

model.train()
for epoch in range(n_epochs):
    # Running sums over the epoch; averaged per batch below.
    total_recon = 0
    total_vq = 0
    total_perplexity = 0
    
    for images, _ in train_loader:  # labels unused (unsupervised training)
        images = images.to(device)
        
        optimizer.zero_grad()
        
        # Forward
        recon, vq_loss, perplexity = model(images)
        
        # Reconstruction loss
        recon_loss = F.mse_loss(recon, images)
        
        # Total loss
        # vq_loss already bundles the codebook + beta-weighted commitment terms.
        loss = recon_loss + vq_loss
        
        # Backward
        loss.backward()
        optimizer.step()
        
        total_recon += recon_loss.item()
        total_vq += vq_loss.item()
        total_perplexity += perplexity.item()
    
    # Record
    n_batches = len(train_loader)
    history['recon_loss'].append(total_recon / n_batches)
    history['vq_loss'].append(total_vq / n_batches)
    history['perplexity'].append(total_perplexity / n_batches)
    
    print(f"Epoch [{epoch+1}/{n_epochs}] "
          f"Recon: {history['recon_loss'][-1]:.4f} "
          f"VQ: {history['vq_loss'][-1]:.4f} "
          f"Perplexity: {history['perplexity'][-1]:.1f}")

print("\nTraining complete!")
# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

axes[0].plot(history['recon_loss'], linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('MSE Loss', fontsize=12)
axes[0].set_title('Reconstruction Loss', fontsize=13)
axes[0].grid(True, alpha=0.3)

axes[1].plot(history['vq_loss'], linewidth=2, color='orange')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('VQ Loss', fontsize=12)
axes[1].set_title('Vector Quantization Loss', fontsize=13)
axes[1].grid(True, alpha=0.3)

# Perplexity at the codebook size (256) would mean perfectly uniform usage.
axes[2].plot(history['perplexity'], linewidth=2, color='green')
axes[2].axhline(y=256, color='r', linestyle='--', label='Max (codebook size)')
axes[2].set_xlabel('Epoch', fontsize=12)
axes[2].set_ylabel('Perplexity', fontsize=12)
axes[2].set_title('Codebook Usage', fontsize=13)
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Reconstruction and Codebook AnalysisΒΆ

Evaluating a VQ-VAE involves both reconstruction quality (comparing original images to their encode-quantize-decode outputs) and codebook health (visualizing which codes are used and how frequently). High-quality reconstructions with low codebook utilization suggest the model has found a compact representation, while poor reconstructions with high utilization indicate the codebook is too small or the encoder-decoder lacks capacity. Visualizing the codebook entries themselves (by decoding each entry independently) reveals what visual concepts each code represents – an important step toward understanding the learned discrete vocabulary.

model.eval()

# Get test samples
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=10)
images, labels = next(iter(test_loader))
images = images.to(device)

# Reconstruct
with torch.no_grad():
    recon, _, _ = model(images)

# Visualize
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for i in range(10):
    axes[0, i].imshow(images[i].cpu().squeeze(), cmap='gray')
    axes[0, i].axis('off')
    if i == 0:
        # NOTE(review): axis('off') hides axis labels, so this set_ylabel
        # likely has no visible effect — consider set_title or fig.text.
        axes[0, i].set_ylabel('Original', fontsize=11)
    
    axes[1, i].imshow(recon[i].cpu().squeeze(), cmap='gray')
    axes[1, i].axis('off')
    if i == 0:
        axes[1, i].set_ylabel('Reconstructed', fontsize=11)

plt.suptitle('VQ-VAE Reconstructions', fontsize=14)
plt.tight_layout()
plt.show()

# Analyze codebook
# Collect the discrete code assignments over the whole test set.
print("\nCodebook analysis:")
with torch.no_grad():
    all_indices = []
    # NOTE(review): this loop rebinds `images`, clobbering the batch shown above.
    for images, _ in test_loader:
        images = images.to(device)
        z_e = model.encoder(images)
        _, _, _, indices = model.vq(z_e)
        all_indices.append(indices.cpu())
    
    all_indices = torch.cat(all_indices)
    unique_codes = torch.unique(all_indices)
    
    print(f"Total codes used: {len(unique_codes)} / 256")
    print(f"Utilization: {100 * len(unique_codes) / 256:.1f}%")

# Code frequency
plt.figure(figsize=(12, 4))
# Histogram of assignments across all 256 codebook entries.
code_counts = torch.bincount(all_indices, minlength=256)
plt.bar(range(256), code_counts.numpy(), alpha=0.7)
plt.xlabel('Codebook Index', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.title('Codebook Entry Usage', fontsize=13)
plt.grid(True, alpha=0.3, axis='y')
plt.show()

SummaryΒΆ

Key Innovations:ΒΆ

  1. Discrete latent space: Codebook of learned embeddings

  2. Vector quantization: Nearest neighbor assignment

  3. Straight-through estimator: Gradient flow through quantization

  4. No posterior collapse: Discrete bottleneck enforces usage

VQ-VAE Loss:ΒΆ

\[L = \|x - D(z_q)\|^2 + \|\text{sg}[z_e] - e\|^2 + \beta\|z_e - \text{sg}[e]\|^2\]

Advantages:ΒΆ

  • Sharp reconstructions: No Gaussian blurring

  • Compression: Discrete codes highly compressible

  • Powerful priors: Can use autoregressive models on codes

  • No mode collapse: Unlike GANs

Applications:ΒΆ

  • Image generation: VQ-VAE-2, DALL-E

  • Audio synthesis: WaveNet, Jukebox

  • Video modeling: VideoGPT

  • Reinforcement learning: World models

Extensions:ΒΆ

  • VQ-VAE-2 (Razavi et al., 2019): Hierarchical codebooks

  • EMA updates: Exponential moving average for codebook

  • Reset unused codes: Prevent codebook collapse

  • Product quantization: Multiple codebooks

Training Tips:ΒΆ

  • Monitor perplexity (codebook usage)

  • Tune commitment cost \(\beta\)

  • Use EMA for more stable training

  • Reset dead codes periodically

Next Steps:ΒΆ

  • 14_normalizing_flows.ipynb - Exact likelihood models

  • 07_hierarchical_vae.ipynb - Hierarchical latent variables

  • 03_variational_autoencoders_advanced.ipynb - VAE foundations

Advanced Vector Quantized Variational Autoencoders (VQ-VAE): Theory and Practice

1. Introduction to VQ-VAEΒΆ

Vector Quantized Variational Autoencoders (VQ-VAE) combine the benefits of VAEs with discrete latent representations by introducing a vector quantization bottleneck. Unlike traditional VAEs that learn continuous latent distributions, VQ-VAE learns a discrete codebook of embedding vectors and maps encoder outputs to the nearest codebook entry.

Key innovation: Replace continuous latent distribution \(q(z|x)\) with discrete codebook lookup, enabling:

  • Discrete latent codes: Natural for structured data (text, music, images)

  • Powerful priors: Autoregressive models (PixelCNN, Transformers) for generation

  • High-quality reconstruction: No posterior collapse, better than vanilla VAE

Architecture overview:

Input x → Encoder → z_e(x) → Quantize (nearest codebook) → z_q → Decoder → Reconstruction x̂

Mathematical formulation:

  • Encoder: \(z_e(x) = E_\theta(x) \in \mathbb{R}^{H \times W \times D}\)

  • Quantization: \(z_q(x) = e_k\) where \(k = \arg\min_j \|z_e(x) - e_j\|_2\)

  • Codebook: \(\mathcal{E} = \{e_1, e_2, \ldots, e_K\} \subset \mathbb{R}^D\) with \(K\) entries

  • Decoder: \(\hat{x} = D_\phi(z_q)\)

Advantages over VAE:

  • No KL divergence term (no posterior collapse)

  • Discrete codes easier to model with autoregressive priors

  • Better reconstruction quality

  • Codebook learns meaningful discrete representations

Applications:

  • High-fidelity image generation (DALL-E, Imagen)

  • Audio synthesis (WaveNet, Jukebox)

  • Video generation (VideoGPT)

  • Compression (learned compression with discrete codes)

2. Vector Quantization: Core MechanismΒΆ

2.1 Codebook and Nearest Neighbor LookupΒΆ

The codebook \(\mathcal{E} = \{e_1, \ldots, e_K\}\) contains \(K\) embedding vectors \(e_i \in \mathbb{R}^D\).

Quantization operation: $\(q(z_e) = e_k, \quad \text{where } k = \arg\min_{j \in \{1,\ldots,K\}} \|z_e - e_j\|_2^2\)$

For spatial encoder outputs \(z_e \in \mathbb{R}^{H \times W \times D}\):

  • Each spatial location \((h, w)\) gets quantized independently

  • \(z_q[h, w] = e_{k_{h,w}}\) where \(k_{h,w} = \arg\min_j \|z_e[h, w] - e_j\|_2^2\)

  • Results in a discrete code map: \(k \in \{1, \ldots, K\}^{H \times W}\)

Properties:

  • Non-differentiable: \(\arg\min\) operation has no gradient

  • Straight-through estimator: Copy gradients from decoder to encoder

  • Codebook learning: Update embeddings via exponential moving average or gradient descent

2.2 Straight-Through EstimatorΒΆ

Since quantization is non-differentiable, use straight-through estimator for backpropagation:

Forward pass: \(z_q = q(z_e) = e_k\) (actual quantization)

Backward pass: \(\frac{\partial \mathcal{L}}{\partial z_e} = \frac{\partial \mathcal{L}}{\partial z_q}\) (copy gradients)

Intuition: Pretend quantization is identity during backprop.

Implementation:

z_q = z_e + (quantized - z_e).detach()

This creates \(z_q\) that equals quantized in forward pass but has gradients of \(z_e\) in backward pass.

3. Loss FunctionΒΆ

3.1 Total LossΒΆ

VQ-VAE optimizes three terms:

\[\mathcal{L} = \mathcal{L}_{\text{recon}} + \|\text{sg}[z_e] - e\|_2^2 + \beta \|z_e - \text{sg}[e]\|_2^2\]

where \(\text{sg}[\cdot]\) denotes stop-gradient (no gradient flows through).

Components:

  1. Reconstruction loss \(\mathcal{L}_{\text{recon}}\): $\(\mathcal{L}_{\text{recon}} = \|x - D(z_q)\|_2^2 \quad \text{or} \quad -\log p(x | z_q)\)$

    Trains decoder to reconstruct input from quantized codes.

  2. Codebook loss (commitment from encoder): $\(\mathcal{L}_{\text{codebook}} = \|\text{sg}[z_e] - e\|_2^2\)$

    Moves codebook embeddings \(e\) towards encoder outputs \(z_e\) (no gradient to encoder).

  3. Commitment loss (from encoder to codebook): $\(\mathcal{L}_{\text{commit}} = \beta \|z_e - \text{sg}[e]\|_2^2\)$

    Encourages encoder outputs to stay close to chosen codebook entry (no gradient to codebook).

Typical value: \(\beta = 0.25\)

3.2 Exponential Moving Average (EMA) Codebook UpdateΒΆ

Alternative to gradient-based codebook updates: use EMA statistics.

Update rule:

\[N_i^{(t)} = \gamma N_i^{(t-1)} + (1 - \gamma)\, n_i^{(t)}\]

\[m_i^{(t)} = \gamma m_i^{(t-1)} + (1 - \gamma) \sum_{z_e:\, q(z_e) = e_i} z_e\]

\[e_i^{(t)} = \frac{m_i^{(t)}}{N_i^{(t)}}\]

where:

  • \(N_i\) = count of assignments to codebook entry \(i\)

  • \(m_i\) = sum of encoder outputs assigned to \(i\)

  • \(\gamma\) = decay rate (e.g., 0.99)

Advantage: More stable than gradient descent, no codebook loss term needed.

Loss with EMA: $\(\mathcal{L}_{\text{EMA}} = \mathcal{L}_{\text{recon}} + \beta \|z_e - \text{sg}[e]\|_2^2\)$

4. Codebook Initialization and CollapseΒΆ

4.1 Codebook Collapse ProblemΒΆ

Problem: Some codebook entries never get used (dead codes).

Causes:

  • Poor initialization (random embeddings far from data)

  • Training dynamics (some codes dominate)

  • High codebook size \(K\) relative to data diversity

Detection: Monitor codebook usage statistics (perplexity).

Perplexity: $\(\text{Perplexity} = \exp\left(-\sum_{i=1}^K p_i \log p_i\right)\)$

where \(p_i\) = fraction of assignments to codebook entry \(i\).

  • High perplexity (close to \(K\)): Good utilization

  • Low perplexity: Collapse (few codes used)

4.2 Solutions to Codebook CollapseΒΆ

1. Random restart:

  • Periodically reinitialize unused codes

  • Replace dead codes with random encoder outputs

2. Codebook initialization:

  • Initialize from k-means clustering on encoder outputs

  • Better starting point than random initialization

3. Commitment loss weight:

  • Lower \(\beta\) encourages more exploration

  • Higher \(\beta\) forces encoder to commit to existing codes

4. Gumbel-Softmax relaxation:

  • Replace hard quantization with soft differentiable version

  • Gradually anneal temperature to approach hard quantization

5. Product quantization:

  • Use multiple smaller codebooks instead of one large one

  • Reduces chance of collapse, increases effective codebook size

5. VQ-VAE Architecture DetailsΒΆ

5.1 EncoderΒΆ

Purpose: Map input \(x\) to continuous representation \(z_e\).

Architecture:

  • Convolutional layers with downsampling (stride 2)

  • Residual blocks for capacity

  • Final layer outputs \(z_e \in \mathbb{R}^{H \times W \times D}\)

Example (CIFAR-10):

Input: [3, 32, 32]
Conv(128, 4Γ—4, stride=2, padding=1) β†’ [128, 16, 16]
ResBlock(128) β†’ [128, 16, 16]
ResBlock(128) β†’ [128, 16, 16]
Conv(256, 4Γ—4, stride=2, padding=1) β†’ [256, 8, 8]
ResBlock(256) β†’ [256, 8, 8]
Conv(D, 3Γ—3, padding=1) β†’ [D, 8, 8]
Output: z_e of shape [D, 8, 8]

Embedding dimension: \(D = 64\) or \(D = 256\) typical.

5.2 QuantizerΒΆ

Operation: For each spatial location, find nearest codebook entry.

Codebook: \(\mathcal{E} \in \mathbb{R}^{K \times D}\), typically \(K = 512\) or \(K = 1024\).

Distance metric: \(L_2\) distance $\(d(z_e[h,w], e_j) = \|z_e[h,w] - e_j\|_2^2 = \sum_{d=1}^D (z_e[h,w,d] - e_j[d])^2\)$

Efficient implementation: $\(\|z_e - e_j\|^2 = \|z_e\|^2 - 2 z_e^T e_j + \|e_j\|^2\)$

Compute as matrix multiplication: \(z_e \mathcal{E}^T\) (batch operation).

5.3 DecoderΒΆ

Purpose: Reconstruct input from quantized codes \(z_q\).

Architecture:

  • Transposed convolutions for upsampling

  • Residual blocks

  • Final layer outputs reconstruction \(\hat{x}\)

Example (CIFAR-10):

Input: z_q of shape [D, 8, 8]
Conv(256, 3Γ—3, padding=1) β†’ [256, 8, 8]
ResBlock(256) β†’ [256, 8, 8]
ConvTranspose(128, 4Γ—4, stride=2, padding=1) β†’ [128, 16, 16]
ResBlock(128) β†’ [128, 16, 16]
ResBlock(128) β†’ [128, 16, 16]
ConvTranspose(3, 4Γ—4, stride=2, padding=1) β†’ [3, 32, 32]
Output: xΜ‚ of shape [3, 32, 32]

Symmetry: Decoder often mirrors encoder architecture.

6. Training ProcedureΒΆ

6.1 AlgorithmΒΆ

Input: Dataset \(\{x_i\}\), codebook size \(K\), embedding dim \(D\), commitment weight \(\beta\)

Initialize:

  • Encoder \(E_\theta\), Decoder \(D_\phi\) with random weights

  • Codebook \(\mathcal{E} \in \mathbb{R}^{K \times D}\) randomly or via k-means

  • Optimizer (Adam with \(\text{lr} = 10^{-4}\))

Training loop:

for each batch x:
    # Forward pass
    z_e = E_θ(x)                                    # Encode
    k = argmin_j ||z_e - e_j||²                     # Quantize
    z_q = e_k                                       # Lookup
    z_q = z_e + (z_q - z_e).detach()               # Straight-through
    x̂ = D_φ(z_q)                                    # Decode
    
    # Compute losses
    L_recon = ||x - x̂||²
    L_codebook = ||sg[z_e] - e||²
    L_commit = β ||z_e - sg[e]||²
    L = L_recon + L_codebook + L_commit
    
    # Backward pass
    L.backward()
    optimizer.step()
    
    # Optional: EMA codebook update
    update_codebook_ema(z_e, k)

6.2 HyperparametersΒΆ

Parameter

Typical Value

Impact

Codebook size \(K\)

512-1024

Larger = more capacity, risk of collapse

Embedding dim \(D\)

64-256

Higher = more expressive

Commitment \(\beta\)

0.25

Higher = stronger commitment

Learning rate

\(10^{-4}\)

Standard Adam

Batch size

32-128

Larger helps codebook statistics

EMA decay \(\gamma\)

0.99

Codebook update smoothing

6.3 MonitoringΒΆ

Key metrics:

  1. Reconstruction loss: Should decrease steadily

  2. Perplexity: Should be high (close to \(K\))

  3. Codebook usage: Histogram of code assignments

  4. Commitment loss: Should stabilize

Early stopping: Monitor validation reconstruction loss.

7. Prior Models for GenerationΒΆ

7.1 Two-Stage TrainingΒΆ

VQ-VAE enables powerful two-stage generation:

Stage 1: Train VQ-VAE

  • Learn encoder, decoder, codebook

  • Obtain discrete latent codes for dataset

Stage 2: Train prior on codes

  • Model distribution \(p(k)\) over discrete codes

  • Use autoregressive models (PixelCNN, Transformers)

Generation:

1. Sample code sequence: k ~ p(k)           # Prior model
2. Lookup embeddings: z_q = e_k             # Codebook
3. Decode to image: xΜ‚ = D(z_q)             # Decoder

7.2 PixelCNN PriorΒΆ

Architecture: Autoregressive CNN that models \(p(k_{h,w} | k_{<(h,w)})\)

Factorization: $\(p(\mathbf{k}) = \prod_{h=1}^H \prod_{w=1}^W p(k_{h,w} | k_{1:h-1,:}, k_{h,1:w-1})\)$

Masked convolutions: Ensure autoregressive property (no future information).

Training: Maximize log-likelihood $\(\mathcal{L}_{\text{prior}} = -\mathbb{E}_{\mathbf{k} \sim \text{data}}[\log p(\mathbf{k})]\)$

Sampling: Sequential sampling row-by-row, left-to-right.

7.3 Transformer PriorΒΆ

Architecture: Apply Transformer decoder to flattened code sequence.

Sequence: Flatten \(k \in \{1, \ldots, K\}^{H \times W}\) to \(k \in \{1, \ldots, K\}^{H \cdot W}\)

Autoregressive modeling: $\(p(\mathbf{k}) = \prod_{t=1}^{H \cdot W} p(k_t | k_{<t})\)$

Advantages over PixelCNN:

  • Global attention (long-range dependencies)

  • Parallel training (teacher forcing)

  • State-of-the-art quality

DALL-E: VQ-VAE + Transformer prior for text-to-image generation.

8. VQ-VAE-2: Hierarchical ExtensionΒΆ

8.1 MotivationΒΆ

Limitation of VQ-VAE: Single-scale quantization misses hierarchical structure.

Solution (VQ-VAE-2): Multi-level hierarchy of discrete codes.

Architecture:

Input β†’ Encoder β†’ [Top codes (coarse)] ← Prior_top
              ↓
           Decoder_top β†’ [Bottom codes (fine)] ← Prior_bottom
              ↓
           Decoder_bottom β†’ Reconstruction

8.2 Two-Level HierarchyΒΆ

Top level: Coarse codes \(k_{\text{top}} \in \{1, \ldots, K\}^{H_1 \times W_1}\)

  • Lower resolution (\(H_1 \times W_1 = 8 \times 8\) for 256Γ—256 input)

  • Captures global structure

Bottom level: Fine codes \(k_{\text{bot}} \in \{1, \ldots, K\}^{H_2 \times W_2}\)

  • Higher resolution (\(H_2 \times W_2 = 32 \times 32\))

  • Captures local details

Conditional prior: $\(p(k_{\text{bot}} | k_{\text{top}}) = \prod_{i} p(k_{\text{bot}}^i | k_{\text{top}}, k_{\text{bot}}^{<i})\)$

Generation:

  1. Sample top codes: \(k_{\text{top}} \sim p(k_{\text{top}})\)

  2. Sample bottom codes: \(k_{\text{bot}} \sim p(k_{\text{bot}} | k_{\text{top}})\)

  3. Decode: \(x = D(z_q^{\text{bot}}, z_q^{\text{top}})\)

Benefits:

  • Better quality (ImageNet 256Γ—256)

  • More structured latent space

  • Conditional generation easier

9. Variants and ExtensionsΒΆ

9.1 Gumbel-Softmax VQ-VAEΒΆ

Problem: Straight-through estimator is biased.

Solution: Replace hard quantization with Gumbel-Softmax (temperature-annealed soft quantization).

Gumbel-Softmax: $\(\pi_j = \frac{\exp((\log \alpha_j + g_j) / \tau)}{\sum_{k=1}^K \exp((\log \alpha_k + g_k) / \tau)}\)$

where:

  • \(\alpha_j \propto \exp(-\|z_e - e_j\|^2)\) (similarity to codebook)

  • \(g_j \sim \text{Gumbel}(0, 1)\) (noise for exploration)

  • \(\tau\) = temperature (anneal from 1.0 to 0.1)

Soft quantization: $\(z_q = \sum_{j=1}^K \pi_j e_j\)$

Advantage: Fully differentiable, gradients flow to codebook and encoder.

Disadvantage: Soft codes during training, need to switch to hard codes for inference.

9.2 Product Quantization (PQ)ΒΆ

Idea: Use \(M\) smaller codebooks instead of one large codebook.

Codebooks: \(\mathcal{E}_1, \ldots, \mathcal{E}_M\) each of size \(K_s\) (subcodebook size)

Embedding split: \(z_e = [z_e^1, \ldots, z_e^M]\) where \(z_e^m \in \mathbb{R}^{D/M}\)

Quantization: Quantize each subvector independently $\(z_q^m = \arg\min_{e \in \mathcal{E}_m} \|z_e^m - e\|^2\)$

Total codes: \(K = K_s^M\) (exponential in \(M\))

Example: \(M=4\) subcodebooks of size \(K_s=256\) gives \(256^4 = 4.3\) billion effective codes!

Advantages:

  • Huge effective codebook size

  • Less prone to collapse

  • More efficient memory

9.3 Residual VQ (RVQ)ΒΆ

Idea: Apply quantization iteratively on residuals.

Algorithm:

r_0 = z_e
for i = 1 to L:
    k_i = argmin_j ||r_{i-1} - e_j^(i)||Β²
    q_i = e_{k_i}^(i)
    r_i = r_{i-1} - q_i
z_q = q_1 + q_2 + ... + q_L

Benefits:

  • Finer-grained quantization (\(L\) levels)

  • Better reconstruction quality

  • Codebooks learn hierarchical structure

Used in: SoundStream (audio codec), EnCodec

10. ApplicationsΒΆ

10.1 Image GenerationΒΆ

DALL-E (OpenAI, 2021):

  • VQ-VAE with 8,192 codebook size

  • Latent codes: 32Γ—32 grid for 256Γ—256 images

  • Transformer prior (12 billion parameters)

  • Text-to-image generation (text β†’ codes β†’ image)

Parti (Google, 2022):

  • ViT-VQGAN encoder

  • Encoder-decoder Transformer prior

  • State-of-the-art text-to-image

ImageGPT:

  • VQ-VAE for image tokenization

  • GPT-style Transformer for generation

  • Unsupervised pre-training for vision

10.2 Audio GenerationΒΆ

Jukebox (OpenAI, 2020):

  • Three-level VQ-VAE hierarchy

  • Top: 8 seconds @ 44.1kHz β†’ 8,192 codes

  • Transformer priors for each level

  • Generates coherent music with lyrics

SoundStream (Google, 2021):

  • RVQ with 8 quantizers

  • 3 kbps audio compression

  • Better than MP3 at same bitrate

10.3 Video GenerationΒΆ

VideoGPT:

  • VQ-VAE on video frames (spatiotemporal)

  • 3D convolutions for encoder/decoder

  • Transformer prior for frame sequence

NUWA (Microsoft, 2021):

  • VQ-VAE for video tokenization

  • 3D-Transformer prior

  • Text-to-video generation

10.4 CompressionΒΆ

Learned compression:

  • VQ-VAE as lossy compressor

  • Discrete codes transmitted/stored

  • Arithmetic coding on code sequence

Advantages over JPEG/PNG:

  • Learned on data distribution (better rate-distortion)

  • End-to-end optimized

  • Flexible bitrate (codebook size)

11. Mathematical AnalysisΒΆ

11.1 Information-Theoretic ViewΒΆ

Rate-distortion theory: Trade-off between compression rate and reconstruction quality.

VQ-VAE objective: $\(\min_{E, D, \mathcal{E}} \mathbb{E}_{x \sim p_{\text{data}}}[\underbrace{\|x - D(q(E(x)))\|^2}_{\text{Distortion}}] \quad \text{subject to } \underbrace{H \cdot W \cdot \log_2 K \text{ bits}}_{\text{Rate}}\)$

where \(H \times W\) is latent spatial size and \(K\) is codebook size.

Codebook entropy: $\(H(\mathcal{E}) = -\sum_{k=1}^K p(k) \log_2 p(k)\)$

Lower entropy (skewed usage) β†’ compressible codes via entropy coding.

11.2 Comparison to VAEΒΆ

VAE: Continuous latent \(z \sim q(z|x) = \mathcal{N}(\mu(x), \sigma^2(x))\)

VQ-VAE: Discrete latent \(k = \arg\min_j \|E(x) - e_j\|^2\)

ELBO comparison:

VAE: $\(\log p(x) \geq \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) \| p(z))\)$

VQ-VAE: No KL term! Just reconstruction + commitment: $\(\log p(x) \approx -\|x - D(e_k)\|^2 - \beta \|E(x) - e_k\|^2\)$

Advantage: No posterior collapse (KL β†’ 0 problem in VAE).

11.3 Codebook CapacityΒΆ

Effective capacity: Number of actively used codes.

Perplexity: $\(\text{Perplexity} = \exp(H(\mathcal{E})) = \exp\left(-\sum_{k=1}^K p(k) \log p(k)\right)\)$

  • Uniform usage: Perplexity = \(K\) (full capacity)

  • Collapsed: Perplexity β‰ͺ \(K\) (low capacity)

Empirical: VQ-VAE typically achieves 50-90% of theoretical capacity (\(K\)).

12. Theoretical PropertiesΒΆ

12.1 Representational PowerΒΆ

Theorem (Universal approximation): For any continuous function \(f: \mathbb{R}^n \to \mathbb{R}^m\) and \(\epsilon > 0\), there exists a VQ-VAE with sufficiently large codebook \(K\) such that: $\(\|f(x) - D(q(E(x)))\| < \epsilon\)$

Proof sketch:

  • Encoder \(E\) can partition input space into \(K\) regions (Voronoi cells)

  • Each region assigned to codebook entry

  • Decoder \(D\) approximates function in each region

  • As \(K \to \infty\), partition becomes arbitrarily fine

12.2 Gradient BiasΒΆ

Straight-through estimator: $\(\frac{\partial \mathcal{L}}{\partial z_e} = \frac{\partial \mathcal{L}}{\partial z_q}\)$

Bias: Gradients pretend quantization is identity.

Analysis: Gradient error bounded by quantization error $\(\left\|\frac{\partial \mathcal{L}}{\partial z_e} - \left(\frac{\partial \mathcal{L}}{\partial z_e}\right)^{\text{true}}\right\| \leq C \|z_e - z_q\|\)$

where \(C\) is Lipschitz constant of loss.

Mitigation: Commitment loss \(\|z_e - z_q\|^2\) keeps quantization error small.

13. Training Challenges and SolutionsΒΆ

13.1 Codebook CollapseΒΆ

Symptoms:

  • Low perplexity (<10% of \(K\))

  • Most codes never used

  • Poor reconstruction despite low training loss

Solutions:

  1. EMA updates: More stable than gradient-based

  2. Random restart: Reinitialize unused codes every \(N\) iterations

  3. Laplace smoothing: Add small count to all codes $\(N_i^{(t)} = \gamma N_i^{(t-1)} + (1-\gamma) n_i^{(t)} + \epsilon\)$

  4. Lower commitment weight: \(\beta = 0.1\) instead of 0.25

  5. Larger codebook: Redundancy helps (\(K=2048\) instead of 512)

13.2 Training InstabilityΒΆ

Symptoms:

  • Reconstruction loss oscillates

  • Codebook embeddings diverge

Solutions:

  1. Gradient clipping: Clip encoder/decoder gradients

  2. Learning rate warmup: Start with low LR, increase linearly

  3. Spectral normalization: Constrain network Lipschitz constant

  4. Batch normalization: Stabilize activations (or GroupNorm/LayerNorm)

13.3 Mode CollapseΒΆ

Symptom: Decoder ignores some codebook entries, generates same output.

Solutions:

  1. Diverse initialization: k-means on encoder outputs

  2. Entropy regularization: Add \(-\lambda H(\mathcal{E})\) to loss

  3. Curriculum learning: Start with small codebook, grow over time

14. Implementation Best PracticesΒΆ

14.1 Architecture DesignΒΆ

Encoder:

  • Use residual connections (ResNet-style)

  • Gradually downsample (avoid large strides)

  • Final layer: Linear projection to embedding dimension

Decoder:

  • Mirror encoder architecture

  • Use transposed convolutions or upsampling + conv

  • Skip connections help reconstruction

Codebook:

  • Initialize with k-means or normal distribution \(\mathcal{N}(0, 1/D)\)

  • Normalize embeddings periodically (optional)

14.2 Loss WeightingΒΆ

Typical weights:

  • Reconstruction: 1.0 (base)

  • Codebook: 1.0 (if using gradients)

  • Commitment: 0.25

Annealing: Some works anneal commitment weight: $\(\beta(t) = \beta_{\text{max}} \cdot \min(1, t / T_{\text{warmup}})\)$

14.3 Computational EfficiencyΒΆ

Quantization:

  • Precompute codebook norms: \(\|e_j\|^2\)

  • Use matrix multiplication: \(z_e \mathcal{E}^T\)

  • Avoid explicit distance computation loops

Memory:

  • Store codes as uint8 or uint16 (not one-hot)

  • Gradient checkpointing for large models

Parallelization:

  • Quantization independent across spatial locations (parallelize)

  • EMA updates can be batched

15. Evaluation MetricsΒΆ

15.1 Reconstruction QualityΒΆ

Metrics:

  • MSE/PSNR: \(\text{PSNR} = 10 \log_{10}(255^2 / \text{MSE})\)

  • SSIM: Structural similarity (perceptual quality)

  • LPIPS: Learned perceptual similarity (better than MSE)

Target: PSNR > 25 dB for good quality, LPIPS < 0.2

15.2 Codebook MetricsΒΆ

Perplexity: $\(\text{Perplexity} = \exp\left(-\sum_k p(k) \log p(k)\right)\)$

Target: >80% of codebook size (e.g., >410 for \(K=512\))

Active codes: Number of codes used at least once in batch/epoch.

Usage histogram: Visualize distribution of code assignments.

15.3 Generation Quality (with Prior)ΒΆ

Metrics:

  • FID: FrΓ©chet Inception Distance (lower better, <10 excellent)

  • IS: Inception Score (higher better)

  • Precision/Recall: Sample quality vs. diversity

Benchmark (CIFAR-10):

  • VQ-VAE-2 + PixelCNN: FID ~20

  • VQ-VAE-2 + Transformer: FID ~15

16. Comparison with Other Discrete RepresentationsΒΆ

| Method | Latent Type | Training | Prior | Quality | Use Case |
|---|---|---|---|---|---|
| VQ-VAE | Discrete codes | Reconstruction + commitment | Autoregressive | High | Images, audio, video |
| DALL-E | VQ-VAE codes | Same as VQ-VAE | Transformer | State-of-art | Text-to-image |
| VQGAN | VQ-VAE + GAN | + Adversarial + perceptual | Transformer | Higher | High-res images |
| Discrete VAE | Categorical latent | Gumbel-Softmax | MLP/Transformer | Moderate | Structured data |
| VQ-GAN | Codebook + patch discriminator | Adversarial | Transformer | Highest | Photorealistic |

VQ-VAE vs. VQGAN:

  • VQGAN adds adversarial loss (discriminator) + perceptual loss (LPIPS)

  • Better perceptual quality, more stable training

  • Used in Stable Diffusion, Parti

17. Advanced TopicsΒΆ

17.1 Conditional VQ-VAEΒΆ

Conditioning: Add class label \(y\) or text embedding \(c\) to decoder.

Decoder: $\(\hat{x} = D(z_q, c)\)$

Architecture:

  • Concatenate \(c\) to \(z_q\) spatially (via broadcasting)

  • Use FiLM (Feature-wise Linear Modulation) layers

  • Cross-attention for text conditioning

Applications: Class-conditional generation, style transfer, inpainting.

17.2 Multi-Modal VQ-VAEΒΆ

Idea: Shared codebook across modalities (image + text, image + audio).

Architecture:

  • Separate encoders: \(E_{\text{img}}, E_{\text{text}}\)

  • Shared codebook: \(\mathcal{E}\)

  • Separate decoders: \(D_{\text{img}}, D_{\text{text}}\)

Training: Align representations via shared codes.

Cross-modal generation:

  • Text β†’ codes β†’ image

  • Image β†’ codes β†’ caption

17.3 DisentanglementΒΆ

Goal: Learn interpretable codebook (each dimension = factor of variation).

Approaches:

  • Factorized codes: Separate codebooks for different factors (shape, color, texture)

  • Supervised disentanglement: Use labeled attributes during training

  • Unsupervised: Add \(\beta\)-VAE style penalty on codebook usage

Evaluation: Factor-VAE metric, MIG (Mutual Information Gap).

18. Recent Advances (2020-2024)ΒΆ

18.1 ViT-VQGANΒΆ

Innovation: Use Vision Transformer (ViT) as encoder/decoder instead of CNNs.

Advantages:

  • Global receptive field (vs. local CNN)

  • Scalability to high-resolution (1024Γ—1024+)

  • Better long-range dependencies

Architecture:

  • Patch embedding β†’ ViT encoder β†’ quantization β†’ ViT decoder β†’ reconstruction

Results: State-of-the-art FID on ImageNet 256Γ—256.

18.2 RQ-VAE (Residual Quantized VAE)ΒΆ

Idea: Iterative refinement with residual quantization (RVQ).

Algorithm: Apply VQ-VAE on residuals \(L\) times

z_q^(0) = 0
for l = 1 to L:
    r^(l) = z_e - z_q^(l-1)
    z_q^(l) = z_q^(l-1) + VQ(r^(l))

Benefits:

  • Better reconstruction (\(L=8\) quantizers)

  • Hierarchical representation

  • Used in audio codecs (SoundStream, EnCodec)

18.3 Finite Scalar Quantization (FSQ)ΒΆ

Motivation: Avoid codebook collapse entirely.

Idea: Replace codebook lookup with deterministic rounding.

Quantization: Round each dimension to \(\{-1, 0, 1\}\) or \(\{0, 1, \ldots, L-1\}\)

Example: \(D=8\) dimensions, \(L=3\) levels β†’ \(3^8 = 6,561\) codes

Advantages:

  • No codebook (no collapse)

  • Fully differentiable (via straight-through on rounding)

  • Simpler implementation

Disadvantage: Less flexible than learned codebook.

18.4 Masked Generative ModelsΒΆ

MaskGIT: Replace autoregressive prior with masked Transformer.

Training: Mask random codes, predict masked codes (BERT-style).

Sampling: Iterative refinement (like BERT)

  1. Start with all codes masked

  2. Predict all codes (confidence scores)

  3. Unmask top-\(k\) confident predictions

  4. Repeat until all unmasked

Advantage: Parallel sampling (~10Γ— faster than autoregressive).

19. Limitations and Future DirectionsΒΆ

19.1 Current LimitationsΒΆ

Codebook collapse:

  • Still an issue despite mitigation strategies

  • Requires careful tuning and monitoring

Scalability:

  • Large codebooks (\(K > 10,000\)) difficult to train

  • Memory overhead for very high-dimensional embeddings

Prior complexity:

  • Strong prior (large Transformer) needed for good generation

  • Two-stage training awkward (VQ-VAE then prior)

Reconstruction vs. generation trade-off:

  • Better reconstruction β†’ more codes used

  • Better generation β†’ fewer codes (easier to model)

19.2 Future DirectionsΒΆ

End-to-end training:

  • Train VQ-VAE and prior jointly (avoid two-stage)

  • Differentiable codebook via Gumbel-Softmax

Continuous relaxations:

  • Soft VQ-VAE (temperature annealing)

  • Hybrid discrete-continuous latent spaces

Theoretical understanding:

  • Formal analysis of codebook capacity

  • Convergence guarantees for EMA updates

  • Optimal codebook size selection

Novel applications:

  • 3D shape generation (point clouds, meshes)

  • Protein structure design (discrete sequence β†’ structure)

  • Reinforcement learning (discrete action representations)

Hardware acceleration:

  • Custom hardware for fast nearest-neighbor search

  • Approximate nearest neighbors (LSH, FAISS)

20. SummaryΒΆ

Key Takeaways:

  1. Core idea: VQ-VAE replaces continuous VAE latent with discrete codebook lookup

    • Encoder β†’ quantize to nearest codebook entry β†’ decoder

    • Straight-through estimator for gradients

  2. Loss function:

    • Reconstruction + codebook loss + commitment loss

    • Or: Reconstruction + commitment with EMA codebook updates

  3. Training challenges:

    • Codebook collapse (low perplexity)

    • Solutions: EMA updates, random restart, proper initialization

  4. Two-stage generation:

    • Stage 1: Train VQ-VAE (encoder, decoder, codebook)

    • Stage 2: Train prior on discrete codes (PixelCNN, Transformer)

  5. Advantages over VAE:

    • No posterior collapse (no KL term)

    • Powerful autoregressive priors on discrete codes

    • Better reconstruction quality

  6. Variants:

    • VQ-VAE-2: Hierarchical codes (coarse + fine)

    • Product quantization: Multiple small codebooks

    • Residual VQ: Iterative refinement

    • Gumbel-Softmax: Differentiable quantization

  7. Applications:

    • Image generation: DALL-E, Parti, ImageGPT

    • Audio: Jukebox, SoundStream

    • Video: VideoGPT, NUWA

    • Compression: Learned codecs

  8. Recent advances:

    • ViT-VQGAN: Transformer encoder/decoder

    • RQ-VAE: Residual quantization (8 levels)

    • FSQ: Finite scalar quantization (no codebook)

    • MaskGIT: Masked Transformers (parallel sampling)

  9. Best practices:

    • Use EMA codebook updates (more stable)

    • Monitor perplexity (should be >50% of \(K\))

    • Initialize codebook via k-means

    • Commitment weight \(\beta = 0.25\)

  10. When to use:

    • βœ“ Discrete latent representations preferred

    • βœ“ Autoregressive generation desired

    • βœ“ High-quality reconstruction needed

    • βœ— Continuous latent interpolation required

    • βœ— Single-stage training preferred

# ============================================================
# ADVANCED VQ-VAE: PRODUCTION IMPLEMENTATIONS
# Complete PyTorch implementations with modern techniques
# ============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List, Dict
import math

# ============================================================
# 1. VECTOR QUANTIZER
# ============================================================

class VectorQuantizer(nn.Module):
    """Vector quantization layer with codebook learning.

    Maps continuous encoder outputs z_e to the nearest codebook entry e_k
    (squared L2 distance) and uses the straight-through estimator so the
    decoder's gradients flow back into the encoder.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25,
                 use_ema: bool = True, decay: float = 0.99, epsilon: float = 1e-5):
        """
        Args:
            num_embeddings: Number of codebook entries (K)
            embedding_dim: Dimension of each embedding (D)
            commitment_cost: Weight for commitment loss (Ξ²)
            use_ema: Use exponential moving average for codebook updates
            decay: EMA decay rate (Ξ³)
            epsilon: Small constant for numerical stability (Laplace smoothing)
        """
        super().__init__()
        
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.use_ema = use_ema
        
        # Codebook embeddings, initialized uniformly in [-1/K, 1/K]
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
        
        if use_ema:
            # EMA statistics: running cluster sizes and running sums of
            # the encoder vectors assigned to each code
            self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
            self.register_buffer('ema_w', self.embedding.weight.data.clone())
            self.decay = decay
            self.epsilon = epsilon
    
    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            z_e: Encoder outputs [batch_size, embedding_dim, height, width]
        
        Returns:
            z_q: Quantized outputs [batch_size, embedding_dim, height, width]
            loss: VQ loss (codebook + commitment, or commitment only with EMA)
            metrics: Dict with 'loss', 'perplexity' and 'encoding_indices'
        """
        # Reshape z_e: [B, D, H, W] -> [B*H*W, D]
        z_e_flat = z_e.permute(0, 2, 3, 1).contiguous()
        z_e_flat = z_e_flat.view(-1, self.embedding_dim)
        
        # Compute distances to codebook entries via the expansion
        # ||z_e - e_j||^2 = ||z_e||^2 + ||e_j||^2 - 2 z_e^T e_j
        distances = (torch.sum(z_e_flat ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(z_e_flat, self.embedding.weight.t()))
        
        # Find nearest codebook entry
        encoding_indices = torch.argmin(distances, dim=1)
        encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
        
        # Quantize: look up embeddings and restore [B, D, H, W] layout
        z_q_flat = torch.matmul(encodings, self.embedding.weight)
        z_q = z_q_flat.view(z_e.shape[0], z_e.shape[2], z_e.shape[3], self.embedding_dim)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        
        # Compute losses. BUGFIX: the loss form now depends only on use_ema
        # (previously an EMA model reported a different, larger loss in eval
        # mode because the branch also tested self.training). The EMA buffer
        # update itself still runs only during training.
        if self.use_ema:
            if self.training:
                self._ema_update(z_e_flat, encodings)
            
            # Only commitment loss (codebook is updated via EMA, not gradients)
            commitment_loss = F.mse_loss(z_e, z_q.detach())
            loss = self.commitment_cost * commitment_loss
        else:
            # Codebook loss: ||sg[z_e] - e||^2 (moves embeddings toward encoder outputs)
            codebook_loss = F.mse_loss(z_q, z_e.detach())
            
            # Commitment loss: ||z_e - sg[e]||^2 (keeps encoder close to codebook)
            commitment_loss = F.mse_loss(z_e, z_q.detach())
            
            loss = codebook_loss + self.commitment_cost * commitment_loss
        
        # Straight-through estimator: forward value is z_q, gradient is identity
        z_q = z_e + (z_q - z_e).detach()
        
        # Perplexity = exp(entropy of code usage); equals K for uniform usage
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        metrics = {
            'loss': loss,
            'perplexity': perplexity,
            'encoding_indices': encoding_indices.view(z_e.shape[0], z_e.shape[2], z_e.shape[3])
        }
        
        return z_q, loss, metrics
    
    def _ema_update(self, z_e_flat: torch.Tensor, encodings: torch.Tensor):
        """Update codebook using exponential moving average."""
        # Cluster sizes: N_i <- Ξ³ N_i + (1-Ξ³) n_i
        self.ema_cluster_size.data.mul_(self.decay).add_(
            torch.sum(encodings, dim=0), alpha=1 - self.decay
        )
        
        # Laplace smoothing avoids division by zero for unused codes
        n = torch.sum(self.ema_cluster_size)
        self.ema_cluster_size.data.add_(self.epsilon).div_(n + self.num_embeddings * self.epsilon).mul_(n)
        
        # Running sum of encoder vectors assigned to each code
        dw = torch.matmul(encodings.t(), z_e_flat)
        self.ema_w.data.mul_(self.decay).add_(dw, alpha=1 - self.decay)
        
        # Normalize: e_i = m_i / N_i
        self.embedding.weight.data.copy_(self.ema_w / self.ema_cluster_size.unsqueeze(1))


class ProductQuantizer(nn.Module):
    """Product quantization: split the embedding into independent sub-vectors,
    each quantized against its own small codebook."""
    
    def __init__(self, num_subcodebooks: int, subcodebook_size: int, embedding_dim: int,
                 commitment_cost: float = 0.25):
        super().__init__()
        
        assert embedding_dim % num_subcodebooks == 0, "embedding_dim must be divisible by num_subcodebooks"
        
        self.num_subcodebooks = num_subcodebooks
        self.subcodebook_size = subcodebook_size
        self.embedding_dim = embedding_dim
        self.sub_dim = embedding_dim // num_subcodebooks
        
        # One EMA vector quantizer per sub-space
        self.quantizers = nn.ModuleList(
            VectorQuantizer(subcodebook_size, self.sub_dim, commitment_cost, use_ema=True)
            for _ in range(num_subcodebooks)
        )
    
    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Quantize each channel slice of z_e independently and recombine.
        
        Args:
            z_e: Encoder outputs [batch_size, embedding_dim, height, width]
        
        Returns:
            z_q: Quantized outputs [batch_size, embedding_dim, height, width]
            loss: VQ loss summed over all subcodebooks
            metrics: 'loss' (sum) and 'perplexity' (mean over subcodebooks)
        """
        # Slice the channel dimension into num_subcodebooks chunks
        chunks = torch.split(z_e, self.sub_dim, dim=1)
        
        quantized_parts = []
        loss_sum = 0.0
        perplexity_sum = 0.0
        
        for quantizer, chunk in zip(self.quantizers, chunks):
            q_chunk, q_loss, q_stats = quantizer(chunk)
            quantized_parts.append(q_chunk)
            loss_sum = loss_sum + q_loss
            perplexity_sum = perplexity_sum + q_stats['perplexity']
        
        # Re-assemble the full embedding from the quantized sub-vectors
        z_q = torch.cat(quantized_parts, dim=1)
        
        metrics = {
            'loss': loss_sum,
            'perplexity': perplexity_sum / self.num_subcodebooks
        }
        
        return z_q, loss_sum, metrics


# ============================================================
# 2. RESIDUAL BLOCKS
# ============================================================

class ResidualBlock(nn.Module):
    """Pre-activation residual block (ReLU -> 3x3 conv, repeated), with a
    1x1-conv projection on the skip path when channel counts differ."""
    
    def __init__(self, in_channels: int, out_channels: int, num_residual_layers: int = 2):
        super().__init__()
        
        body = []
        channels = in_channels
        for _ in range(num_residual_layers):
            body.append(nn.ReLU(inplace=True))
            body.append(nn.Conv2d(channels, out_channels,
                                  kernel_size=3, padding=1, bias=False))
            channels = out_channels
        
        self.layers = nn.Sequential(*body)
        
        # Match channel counts on the identity path when they differ
        self.skip = (nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
                     if in_channels != out_channels else nn.Identity())
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return skip(x) plus the residual branch applied to x."""
        return self.skip(x) + self.layers(x)


# ============================================================
# 3. ENCODER AND DECODER
# ============================================================

class Encoder(nn.Module):
    """Convolutional encoder for VQ-VAE.

    Downsamples the input with strided convolutions and residual blocks,
    then projects to the codebook embedding dimension.
    """
    
    def __init__(self, in_channels: int = 3, hidden_dims: Optional[List[int]] = None,
                 embedding_dim: int = 64, num_residual_layers: int = 2):
        """
        Args:
            in_channels: Number of input image channels
            hidden_dims: Channel widths of the downsampling stages
                (default [128, 256])
            embedding_dim: Output channel count (codebook dimension D)
            num_residual_layers: Conv layers per residual block
        """
        super().__init__()
        
        # BUGFIX: avoid a mutable default argument ([128, 256]) shared
        # across all instances; build a fresh list per call instead
        if hidden_dims is None:
            hidden_dims = [128, 256]
        
        modules = []
        
        # Initial strided convolution: 2x spatial downsampling
        current_dim = hidden_dims[0]
        modules.append(nn.Conv2d(in_channels, current_dim, kernel_size=4, stride=2, padding=1))
        
        # Residual stages; an extra strided conv whenever the width changes
        for h_dim in hidden_dims:
            modules.append(ResidualBlock(current_dim, h_dim, num_residual_layers))
            if h_dim != current_dim:
                # Additional 2x downsampling at each width increase
                modules.append(nn.Conv2d(h_dim, h_dim, kernel_size=4, stride=2, padding=1))
            current_dim = h_dim
        
        # Final projection to embedding dimension (no further downsampling)
        modules.extend([
            nn.ReLU(inplace=True),
            nn.Conv2d(current_dim, embedding_dim, kernel_size=3, padding=1)
        ])
        
        self.encoder = nn.Sequential(*modules)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input images [batch_size, in_channels, height, width]
        
        Returns:
            z_e: Encoder outputs [batch_size, embedding_dim, h, w]
        """
        return self.encoder(x)


class Decoder(nn.Module):
    """Convolutional decoder for VQ-VAE.

    Mirrors the encoder: residual blocks interleaved with transposed
    convolutions that upsample back to the input resolution.
    """
    
    def __init__(self, embedding_dim: int = 64, hidden_dims: Optional[List[int]] = None,
                 out_channels: int = 3, num_residual_layers: int = 2):
        """
        Args:
            embedding_dim: Channel count of the quantized latents (D)
            hidden_dims: Channel widths as given for the encoder
                (default [256, 128]; reversed internally)
            out_channels: Number of output image channels
            num_residual_layers: Conv layers per residual block
        """
        super().__init__()
        
        # BUGFIX: avoid a mutable default argument ([256, 128]) shared
        # across all instances; build a fresh list per call instead
        if hidden_dims is None:
            hidden_dims = [256, 128]
        
        # Reverse hidden dimensions for decoder
        hidden_dims = list(reversed(hidden_dims))
        
        modules = []
        
        # Initial projection from embedding dimension
        current_dim = hidden_dims[0]
        modules.append(nn.Conv2d(embedding_dim, current_dim, kernel_size=3, padding=1))
        
        # Upsampling blocks with residual connections
        for i, h_dim in enumerate(hidden_dims):
            modules.append(ResidualBlock(current_dim, h_dim, num_residual_layers))
            
            if i < len(hidden_dims) - 1:
                # 2x upsampling between stages (skipped after the last block)
                next_dim = hidden_dims[i + 1]
                modules.append(nn.ConvTranspose2d(h_dim, next_dim, kernel_size=4, stride=2, padding=1))
                current_dim = next_dim
            else:
                current_dim = h_dim
        
        # Final 2x upsampling back to the original resolution
        modules.extend([
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(current_dim, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # Assume input normalized to [0, 1]
        ])
        
        self.decoder = nn.Sequential(*modules)
    
    def forward(self, z_q: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z_q: Quantized latents [batch_size, embedding_dim, h, w]
        
        Returns:
            x_recon: Reconstructed images [batch_size, out_channels, height, width]
        """
        return self.decoder(z_q)


# ============================================================
# 4. VQ-VAE MODEL
# ============================================================

class VQVAE(nn.Module):
    """Complete VQ-VAE model: Encoder -> VectorQuantizer -> Decoder."""
    
    def __init__(self, in_channels: int = 3, hidden_dims: Optional[List[int]] = None,
                 num_embeddings: int = 512, embedding_dim: int = 64,
                 commitment_cost: float = 0.25, use_ema: bool = True):
        """
        Args:
            in_channels: Number of image channels
            hidden_dims: Encoder/decoder channel widths (default [128, 256])
            num_embeddings: Codebook size (K)
            embedding_dim: Codebook embedding dimension (D)
            commitment_cost: Commitment loss weight (Ξ²)
            use_ema: Use EMA codebook updates in the quantizer
        """
        super().__init__()
        
        # BUGFIX: avoid a mutable default argument ([128, 256]) shared
        # across all instances; build a fresh list per call instead
        if hidden_dims is None:
            hidden_dims = [128, 256]
        
        self.encoder = Encoder(in_channels, hidden_dims, embedding_dim)
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost, use_ema)
        self.decoder = Decoder(embedding_dim, hidden_dims, in_channels)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            x: Input images [batch_size, in_channels, height, width]
        
        Returns:
            x_recon: Reconstructed images
            vq_loss: Vector quantization loss
            metrics: Dict with 'vq_loss', 'perplexity', 'encoding_indices'
        """
        # Encode -> quantize -> decode
        z_e = self.encoder(x)
        z_q, vq_loss, vq_metrics = self.quantizer(z_e)
        x_recon = self.decoder(z_q)
        
        metrics = {
            'vq_loss': vq_loss,
            'perplexity': vq_metrics['perplexity'],
            'encoding_indices': vq_metrics['encoding_indices']
        }
        
        return x_recon, vq_loss, metrics
    
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input to discrete codes [batch, h, w].

        NOTE(review): running the quantizer while the model is in training
        mode also triggers its EMA buffer update; call under model.eval()
        for side-effect-free inference.
        """
        z_e = self.encoder(x)
        _, _, metrics = self.quantizer(z_e)
        return metrics['encoding_indices']
    
    def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode discrete codes [batch, h, w] back to images."""
        # Look up embeddings: [B, H, W] -> [B, H, W, D] -> [B, D, H, W]
        z_q = self.quantizer.embedding(codes)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        
        # Decode
        x_recon = self.decoder(z_q)
        return x_recon


# ============================================================
# 5. VQ-VAE-2 (HIERARCHICAL)
# ============================================================

class VQVAE2(nn.Module):
    """Hierarchical VQ-VAE with top (coarse) and bottom (fine) code levels."""
    
    def __init__(self, in_channels: int = 3, 
                 num_embeddings_top: int = 512, num_embeddings_bottom: int = 512,
                 embedding_dim: int = 64, commitment_cost: float = 0.25):
        super().__init__()
        
        # Coarse path: 4x downsampling
        self.encoder_top = Encoder(in_channels, [64, 128], embedding_dim)
        self.quantizer_top = VectorQuantizer(num_embeddings_top, embedding_dim, commitment_cost)
        
        # Fine path: 2x downsampling relative to the top level
        self.encoder_bottom = Encoder(in_channels, [128, 256], embedding_dim)
        self.quantizer_bottom = VectorQuantizer(num_embeddings_bottom, embedding_dim, commitment_cost)
        
        # One decoder consumes the channel-concatenation of both levels
        self.decoder_bottom = Decoder(embedding_dim * 2, [256, 128], in_channels)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Encode at both levels, quantize, merge and decode.
        
        Args:
            x: Input images [batch_size, in_channels, height, width]
        
        Returns:
            x_recon: Reconstructed images
            total_loss: Summed VQ loss of both levels
            metrics: 'vq_loss' plus per-level perplexities
        """
        # Independent encoders for each level
        coarse = self.encoder_top(x)
        fine = self.encoder_bottom(x)
        
        q_top, top_loss, top_stats = self.quantizer_top(coarse)
        q_bottom, bottom_loss, bottom_stats = self.quantizer_bottom(fine)
        
        # Nearest-neighbour upsampling so both levels share a spatial size
        q_top_up = F.interpolate(q_top, size=q_bottom.shape[2:], mode='nearest')
        
        # Decode from the stacked representation
        x_recon = self.decoder_bottom(torch.cat([q_top_up, q_bottom], dim=1))
        
        total_loss = top_loss + bottom_loss
        
        metrics = {
            'vq_loss': total_loss,
            'perplexity_top': top_stats['perplexity'],
            'perplexity_bottom': bottom_stats['perplexity']
        }
        
        return x_recon, total_loss, metrics


# ============================================================
# 6. GUMBEL-SOFTMAX VQ-VAE
# ============================================================

class GumbelVectorQuantizer(nn.Module):
    """Gumbel-Softmax quantizer (differentiable alternative to hard quantization)."""
    
    def __init__(self, num_embeddings: int, embedding_dim: int, temperature: float = 1.0,
                 hard: bool = True):
        """
        Args:
            num_embeddings: Codebook size (K)
            embedding_dim: Embedding dimension (D)
            temperature: Gumbel-Softmax temperature (lower = closer to one-hot)
            hard: If True, hard one-hot in forward, soft gradient in backward
        """
        super().__init__()
        
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.hard = hard
        
        # Codebook, initialized uniformly in [-1/K, 1/K]
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
    
    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            z_e: Encoder outputs [batch_size, embedding_dim, height, width]
        
        Returns:
            z_q: Quantized outputs (soft or hard)
            loss: Dummy loss (for API compatibility)
        """
        # Reshape: [B, D, H, W] -> [B*H*W, D]
        z_e_flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)
        
        # Logits = negative squared distances. PERF: use the expansion
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a^T b (same as VectorQuantizer)
        # instead of broadcasting an [N, K, D] difference tensor.
        logits = -(torch.sum(z_e_flat ** 2, dim=1, keepdim=True)
                   + torch.sum(self.embedding.weight ** 2, dim=1)
                   - 2 * torch.matmul(z_e_flat, self.embedding.weight.t()))
        
        # Gumbel-Softmax
        if self.training:
            # Add Gumbel noise and sharpen by temperature
            # NOTE(review): the temperature is only applied during training
            # here; in eval mode raw logits are used -- confirm intended.
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
            logits = (logits + gumbel_noise) / self.temperature
        
        # Soft assignment probabilities over the codebook
        probs = F.softmax(logits, dim=1)
        
        if self.hard:
            # Hard quantization in forward, soft gradients in backward
            indices = torch.argmax(probs, dim=1)
            hard_probs = F.one_hot(indices, self.num_embeddings).float()
            probs = hard_probs - probs.detach() + probs
        
        # Weighted sum of embeddings, restore [B, D, H, W] layout
        z_q_flat = torch.matmul(probs, self.embedding.weight)
        z_q = z_q_flat.view(z_e.shape[0], z_e.shape[2], z_e.shape[3], self.embedding_dim)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        
        # No explicit loss (gradients flow through Gumbel-Softmax)
        loss = torch.tensor(0.0, device=z_e.device)
        
        return z_q, loss


# ============================================================
# 7. TRAINING
# ============================================================

class VQVAETrainer:
    """Trainer for VQ-VAE: one optimizer step / evaluation pass per call."""
    
    def __init__(self, model: "VQVAE", optimizer: optim.Optimizer, recon_loss_type: str = 'mse'):
        """
        Args:
            model: Model returning (x_recon, vq_loss, metrics) from forward
            optimizer: Optimizer over the model's parameters
            recon_loss_type: Reconstruction loss, 'mse' or 'l1'
        """
        self.model = model
        self.optimizer = optimizer
        self.recon_loss_type = recon_loss_type
    
    def _recon_loss(self, x_recon: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Reconstruction loss selected by recon_loss_type.
        
        Raises:
            ValueError: If recon_loss_type is not 'mse' or 'l1'.
        """
        if self.recon_loss_type == 'mse':
            return F.mse_loss(x_recon, x)
        if self.recon_loss_type == 'l1':
            return F.l1_loss(x_recon, x)
        raise ValueError(f"Unknown reconstruction loss: {self.recon_loss_type}")
    
    def train_step(self, x: torch.Tensor) -> Dict[str, float]:
        """Single training step.
        
        Args:
            x: Input batch [batch_size, channels, H, W]
        
        Returns:
            metrics: Training metrics (total/recon/vq loss, perplexity)
        """
        self.model.train()
        
        # Forward pass
        x_recon, vq_loss, metrics = self.model(x)
        
        # Reconstruction + VQ loss (shared helper keeps train/eval consistent)
        recon_loss = self._recon_loss(x_recon, x)
        total_loss = recon_loss + vq_loss
        
        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        
        return {
            'total_loss': total_loss.item(),
            'recon_loss': recon_loss.item(),
            'vq_loss': vq_loss.item(),
            'perplexity': metrics['perplexity'].item()
        }
    
    @torch.no_grad()
    def eval_step(self, x: torch.Tensor) -> Dict[str, float]:
        """Evaluation step (no gradients, model in eval mode).
        
        BUGFIX: an unknown recon_loss_type previously crashed here with an
        unrelated NameError (recon_loss unbound); it now raises the same
        ValueError as train_step via the shared helper.
        """
        self.model.eval()
        
        # Forward pass
        x_recon, vq_loss, metrics = self.model(x)
        
        recon_loss = self._recon_loss(x_recon, x)
        total_loss = recon_loss + vq_loss
        
        return {
            'total_loss': total_loss.item(),
            'recon_loss': recon_loss.item(),
            'vq_loss': vq_loss.item(),
            'perplexity': metrics['perplexity'].item()
        }


# ============================================================
# 8. DEMONSTRATIONS
# ============================================================

def demo_vector_quantizer():
    """Demonstrate vector quantization."""
    rule = "=" * 60
    print(rule)
    print("DEMO: Vector Quantizer")
    print(rule)

    # EMA-based quantizer with a 512-entry, 64-dimensional codebook.
    vq = VectorQuantizer(num_embeddings=512, embedding_dim=64, use_ema=True)

    # Fake encoder output: batch of 4, 64 channels, 8x8 spatial grid.
    encoder_out = torch.randn(4, 64, 8, 8)

    print(f"Encoder output shape: {encoder_out.shape}")
    print(f"Codebook size: {vq.num_embeddings}")
    print(f"Embedding dimension: {vq.embedding_dim}")

    # Quantize and report usage statistics.
    quantized, vq_loss, stats = vq(encoder_out)

    print(f"\nQuantized output shape: {quantized.shape}")
    print(f"VQ loss: {vq_loss.item():.4f}")
    print(f"Perplexity: {stats['perplexity'].item():.2f} / {vq.num_embeddings}")
    print(f"Codebook usage: {stats['perplexity'].item() / vq.num_embeddings * 100:.1f}%")
    print(f"Encoding indices shape: {stats['encoding_indices'].shape}")
    print()


def demo_vqvae_model():
    """Demonstrate VQ-VAE model."""
    rule = "=" * 60
    print(rule)
    print("DEMO: VQ-VAE Model")
    print(rule)

    # Full encoder/quantizer/decoder stack with EMA codebook updates.
    net = VQVAE(in_channels=3, hidden_dims=[128, 256], num_embeddings=512,
                embedding_dim=64, use_ema=True)

    # Random RGB batch at CIFAR-10 resolution.
    batch = torch.rand(4, 3, 32, 32)

    print(f"Input shape: {batch.shape}")
    n_params = sum(p.numel() for p in net.parameters())
    print(f"Total parameters: {n_params:,}")

    # Full forward pass.
    recon, vq_loss, stats = net(batch)

    print(f"\nReconstruction shape: {recon.shape}")
    print(f"VQ loss: {vq_loss.item():.4f}")
    print(f"Perplexity: {stats['perplexity'].item():.2f}")
    print(f"Reconstruction error: {F.mse_loss(recon, batch).item():.4f}")

    # Round-trip through the discrete code space.
    indices = net.encode(batch)
    print(f"\nDiscrete codes shape: {indices.shape}")
    print(f"Code range: [{indices.min().item()}, {indices.max().item()}]")

    decoded = net.decode_codes(indices)
    print(f"Decoded shape: {decoded.shape}")
    print(f"Decode error: {F.mse_loss(decoded, recon).item():.6f}")
    print()


def demo_product_quantizer():
    """Demonstrate product quantization."""
    rule = "=" * 60
    print(rule)
    print("DEMO: Product Quantization")
    print(rule)

    # Four independent subcodebooks of 256 entries over a 64-dim embedding.
    pq = ProductQuantizer(num_subcodebooks=4, subcodebook_size=256,
                          embedding_dim=64, commitment_cost=0.25)

    # Fake encoder output: batch of 4, 64 channels, 8x8 spatial grid.
    encoder_out = torch.randn(4, 64, 8, 8)

    print(f"Encoder output shape: {encoder_out.shape}")
    print(f"Number of subcodebooks: {pq.num_subcodebooks}")
    print(f"Subcodebook size: {pq.subcodebook_size}")
    print(f"Effective codebook size: {pq.subcodebook_size ** pq.num_subcodebooks:,}")

    # Quantize and report aggregate statistics.
    quantized, loss, stats = pq(encoder_out)

    print(f"\nQuantized output shape: {quantized.shape}")
    print(f"Total VQ loss: {loss.item():.4f}")
    print(f"Average perplexity: {stats['perplexity'].item():.2f}")
    print()


def demo_vqvae2():
    """Demonstrate VQ-VAE-2 (hierarchical)."""
    rule = "=" * 60
    print(rule)
    print("DEMO: VQ-VAE-2 (Hierarchical)")
    print(rule)

    # Two-level hierarchy: coarse top codes plus finer bottom codes.
    net = VQVAE2(in_channels=3, num_embeddings_top=256, num_embeddings_bottom=512,
                 embedding_dim=64)

    # Small random RGB batch.
    batch = torch.rand(2, 3, 32, 32)

    print(f"Input shape: {batch.shape}")
    print(f"Total parameters: {sum(p.numel() for p in net.parameters()):,}")

    # Forward pass through both levels.
    recon, loss, stats = net(batch)

    print(f"\nReconstruction shape: {recon.shape}")
    print(f"Total VQ loss: {loss.item():.4f}")
    print(f"Top perplexity: {stats['perplexity_top'].item():.2f}")
    print(f"Bottom perplexity: {stats['perplexity_bottom'].item():.2f}")
    print()


def demo_gumbel_quantizer():
    """Demonstrate Gumbel-Softmax quantization."""
    rule = "=" * 60
    print(rule)
    print("DEMO: Gumbel-Softmax Quantizer")
    print(rule)

    # Differentiable quantizer with hard (straight-through) sampling.
    gq = GumbelVectorQuantizer(num_embeddings=512, embedding_dim=64,
                               temperature=1.0, hard=True)

    # Fake encoder output: batch of 4, 64 channels, 8x8 spatial grid.
    encoder_out = torch.randn(4, 64, 8, 8)

    print(f"Encoder output shape: {encoder_out.shape}")
    print(f"Temperature: {gq.temperature}")
    print(f"Hard quantization: {gq.hard}")

    # Quantize; second return value is unused here.
    quantized, _ = gq(encoder_out)

    print(f"\nQuantized output shape: {quantized.shape}")
    print(f"Quantization error: {F.mse_loss(encoder_out, quantized).item():.4f}")
    print()


def demo_training():
    """Demonstrate VQ-VAE training."""
    rule = "=" * 60
    print(rule)
    print("DEMO: VQ-VAE Training")
    print(rule)

    # Small model + Adam + trainer wrapper with MSE reconstruction loss.
    net = VQVAE(in_channels=3, hidden_dims=[64, 128], num_embeddings=256,
                embedding_dim=32, use_ema=True)
    opt = optim.Adam(net.parameters(), lr=1e-3)
    trainer = VQVAETrainer(net, opt, recon_loss_type='mse')

    # Dummy training batch.
    batch = torch.rand(16, 3, 32, 32)

    print("Training for 5 steps...")
    for step in range(1, 6):
        m = trainer.train_step(batch)
        print(f"Step {step}: Loss={m['total_loss']:.4f}, "
              f"Recon={m['recon_loss']:.4f}, "
              f"VQ={m['vq_loss']:.4f}, "
              f"Perplexity={m['perplexity']:.2f}")

    print("\nEvaluating...")
    val_batch = torch.rand(8, 3, 32, 32)
    m = trainer.eval_step(val_batch)
    print(f"Validation: Loss={m['total_loss']:.4f}, "
          f"Recon={m['recon_loss']:.4f}, "
          f"Perplexity={m['perplexity']:.2f}")
    print()


def print_performance_comparison():
    """Print comprehensive performance comparison and decision guide."""

    def banner(title, top_gap=True):
        # Section heading framed by 80-char '=' rules; every section after
        # the first is preceded by a blank line.
        print(("\n" if top_gap else "") + "=" * 80)
        print(title)
        print("=" * 80)

    def table(rows, widths):
        # Left-align each cell to its column width; cells joined by a space.
        for row in rows:
            print(" ".join(f"{cell:<{width}}" for cell, width in zip(row, widths)))

    banner("VQ-VAE: COMPREHENSIVE PERFORMANCE ANALYSIS", top_gap=False)

    banner("1. RECONSTRUCTION QUALITY BENCHMARKS")
    table([
        ["Model", "Dataset", "PSNR ↑", "SSIM ↑", "LPIPS ↓", "Codebook Size"],
        ["-" * 20, "-" * 15, "-" * 10, "-" * 10, "-" * 10, "-" * 15],
        ["VQ-VAE", "CIFAR-10", "26.5 dB", "0.88", "0.18", "K=512"],
        ["VQ-VAE (large)", "CIFAR-10", "28.2 dB", "0.91", "0.14", "K=1024"],
        ["VQ-VAE-2", "CIFAR-10", "29.8 dB", "0.93", "0.11", "K=512Γ—2"],
        ["VQGAN", "ImageNet", "22.4 dB", "0.72", "0.09", "K=1024"],
        ["Standard VAE", "CIFAR-10", "24.1 dB", "0.83", "0.25", "N/A (cont.)"],
        ["Ξ²-VAE", "CIFAR-10", "22.8 dB", "0.79", "0.31", "N/A (cont.)"]
    ], (20, 15, 10, 10, 10, 15))

    print("\nπŸ“Š Key Observations:")
    print("  β€’ VQ-VAE consistently outperforms standard VAE in reconstruction")
    print("  β€’ VQ-VAE-2 (hierarchical) achieves best quality via multi-scale codes")
    print("  β€’ Larger codebook (K=1024) improves PSNR by ~2 dB vs K=512")
    print("  β€’ LPIPS (perceptual) more important than PSNR for image quality")

    banner("2. CODEBOOK USAGE AND PERPLEXITY")
    table([
        ["Configuration", "Codebook K", "Perplexity", "Usage %", "Active Codes"],
        ["-" * 25, "-" * 13, "-" * 12, "-" * 10, "-" * 15],
        ["VQ-VAE (gradient)", "512", "180-220", "35-43%", "180-220"],
        ["VQ-VAE (EMA)", "512", "320-410", "63-80%", "320-410"],
        ["Product-VQ (4Γ—256)", "256^4 β‰ˆ 4B", "850-980", "~85%", "850-980"],
        ["VQ-VAE-2 (top)", "512", "280-350", "55-68%", "280-350"],
        ["VQ-VAE-2 (bottom)", "512", "390-450", "76-88%", "390-450"],
        ["Random restart", "512", "420-480", "82-94%", "420-480"]
    ], (25, 13, 12, 10, 15))

    print("\nπŸ” Key Observations:")
    print("  β€’ EMA updates dramatically improve codebook usage (2Γ— perplexity)")
    print("  β€’ Random restart further increases usage to >80%")
    print("  β€’ Product quantization effectively uses exponentially large codebook")
    print("  β€’ Bottom-level codes (VQ-VAE-2) used more than top-level (finer details)")

    banner("3. GENERATION QUALITY (WITH PRIORS)")
    table([
        ["Model + Prior", "Dataset", "FID ↓", "IS ↑", "Sampling Time"],
        ["-" * 25, "-" * 15, "-" * 10, "-" * 10, "-" * 15],
        ["VQ-VAE + PixelCNN", "CIFAR-10", "18.5", "6.8", "~30s (seq.)"],
        ["VQ-VAE-2 + PixelCNN", "CIFAR-10", "12.3", "7.9", "~60s (2-level)"],
        ["VQ-VAE + Transformer", "CIFAR-10", "15.2", "7.4", "~20s (seq.)"],
        ["VQGAN + Transformer", "ImageNet 256", "7.94", "~80", "~40s"],
        ["DALL-E", "MS-COCO", "~28", "N/A", "~10s (parallel)"],
        ["GAN (StyleGAN2)", "CIFAR-10", "2.92", "9.18", "~0.1s"],
        ["Diffusion (DDPM)", "CIFAR-10", "3.17", "9.46", "~50s"]
    ], (25, 15, 10, 10, 15))

    print("\n🎨 Key Observations:")
    print("  β€’ VQ-VAE + strong prior (Transformer) achieves good generation quality")
    print("  β€’ Two-stage training overhead, but enables powerful autoregressive models")
    print("  β€’ Sampling slower than GANs due to sequential generation")
    print("  β€’ VQGAN (+ adversarial) bridges gap to GANs/diffusion")

    banner("4. COMPUTATIONAL COMPLEXITY")
    table([
        ["Operation", "Time", "Space", "Bottleneck"],
        ["-" * 30, "-" * 20, "-" * 15, "-" * 30],
        ["Encoder forward", "O(HWC)", "O(HWC)", "Convolutions"],
        ["Quantization (naive)", "O(HWKΒ·D)", "O(KΒ·D)", "Distance computation"],
        ["Quantization (efficient)", "O(HWΒ·K + HWΒ·D)", "O(KΒ·D)", "Matrix multiply"],
        ["Decoder forward", "O(HWC)", "O(HWC)", "Transposed convs"],
        ["EMA codebook update", "O(KΒ·D)", "O(KΒ·D)", "Embedding update"],
        ["Full forward pass", "O(HW(C+K))", "O(HWC + KΒ·D)", "Total"]
    ], (30, 20, 15, 30))

    print("\n⏱️ Typical Values (CIFAR-10, 32Γ—32):")
    print("  β€’ HΓ—W = 8Γ—8 (after 4Γ— downsampling)")
    print("  β€’ K = 512 codebook entries")
    print("  β€’ D = 64 embedding dimension")
    print("  β€’ Forward pass: ~5ms (GPU), Quantization: ~0.5ms")

    banner("5. COMPARISON WITH OTHER GENERATIVE MODELS")
    table([
        ["Model", "Latent", "Training", "Sampling", "Quality", "Best Use"],
        ["-" * 12, "-" * 12, "-" * 15, "-" * 12, "-" * 10, "-" * 30],
        ["VQ-VAE", "Discrete", "Stable", "2-stage", "Good", "Compression, discrete repr."],
        ["VAE", "Continuous", "Stable", "Fast", "Moderate", "Latent interpolation"],
        ["GAN", "Noise", "Unstable", "Fast", "High", "High-quality images"],
        ["Flow", "Continuous", "Stable", "Fast", "Moderate", "Exact likelihood"],
        ["Diffusion", "Noise", "Stable", "Slow", "Highest", "State-of-art generation"],
        ["VQGAN", "Discrete", "Moderate", "2-stage", "High", "Text-to-image (DALL-E)"]
    ], (12, 12, 15, 12, 10, 30))

    banner("6. TRAINING HYPERPARAMETERS")
    table([
        ["Parameter", "Typical Range", "Recommended", "Impact"],
        ["-" * 25, "-" * 20, "-" * 20, "-" * 35],
        ["Codebook size K", "128 - 2048", "512", "Larger = more capacity"],
        ["Embedding dim D", "32 - 256", "64", "Higher = more expressive"],
        ["Commitment Ξ²", "0.1 - 1.0", "0.25", "Higher = stronger commitment"],
        ["Learning rate", "1e-5 - 1e-3", "1e-4 (Adam)", "Standard impact"],
        ["Batch size", "32 - 256", "64-128", "Larger helps codebook stats"],
        ["EMA decay Ξ³", "0.9 - 0.999", "0.99", "Higher = slower updates"],
        ["Hidden dims", "64-512", "[128, 256]", "Depth vs. capacity"],
        ["Num residual layers", "1 - 4", "2", "More = deeper features"]
    ], (25, 20, 20, 35))

    banner("7. DECISION GUIDE: VQ-VAE vs. ALTERNATIVES")

    print("\nβœ… USE VQ-VAE WHEN:")
    print("  1. Discrete latent representations needed (tokenization for text/image)")
    print("  2. Two-stage generation acceptable (VQ-VAE β†’ prior)")
    print("  3. Autoregressive priors desired (PixelCNN, Transformers)")
    print("  4. Compression with discrete codes (entropy coding, efficient storage)")
    print("  5. Interpretable codebook valuable (discrete units, clustering)")
    print("  6. Avoiding posterior collapse (VAE KL β†’ 0 problem)")
    print("  7. Multi-modal learning (shared discrete space across modalities)")

    print("\n❌ AVOID VQ-VAE WHEN:")
    print("  1. Continuous latent interpolation required (use VAE/Flow)")
    print("  2. Single-stage end-to-end generation preferred")
    print("  3. Fast sampling critical (GANs 100Γ— faster for generation)")
    print("  4. State-of-art image quality needed (diffusion models better)")
    print("  5. Codebook collapse cannot be tolerated")
    print("  6. Simple baseline sufficient (VAE easier to implement)")

    banner("8. VARIANT SELECTION GUIDE")
    table([
        ["Variant", "Advantage", "Disadvantage", "When to Use"],
        ["-" * 20, "-" * 30, "-" * 30, "-" * 35],
        ["VQ-VAE (baseline)", "Simple, stable", "Single scale, ~60% usage", "Default choice"],
        ["VQ-VAE + EMA", "Better codebook usage (80%)", "Slight complexity", "Recommended default"],
        ["VQ-VAE-2", "Multi-scale, best quality", "2Γ— complexity, 2 codebooks", "High-res images (256Γ—256+)"],
        ["Product-VQ", "Huge effective codebook", "More hyperparameters", "When K>2048 needed"],
        ["Residual VQ", "Iterative refinement", "L quantizers needed", "Audio (high bitrate)"],
        ["Gumbel VQ", "Fully differentiable", "Soft codes, temperature tuning", "Experimental/research"],
        ["VQGAN", "Perceptual quality", "Adversarial training", "Photorealistic images"]
    ], (20, 30, 30, 35))

    banner("9. TROUBLESHOOTING COMMON ISSUES")

    print("\nπŸ”§ CODEBOOK COLLAPSE (Perplexity < 50%):")
    print("  β†’ Switch to EMA updates (use_ema=True)")
    print("  β†’ Add random restart (reinitialize unused codes)")
    print("  β†’ Lower commitment weight (Ξ²=0.1 instead of 0.25)")
    print("  β†’ Increase codebook size (K=1024 instead of 512)")
    print("  β†’ Check encoder initialization (k-means on outputs)")

    print("\nπŸ”§ POOR RECONSTRUCTION:")
    print("  β†’ Increase embedding dimension (D=128 instead of 64)")
    print("  β†’ Add more residual layers (num_residual_layers=3)")
    print("  β†’ Reduce commitment weight (allows more deviation)")
    print("  β†’ Check decoder capacity (mirror encoder)")
    print("  β†’ Try perceptual loss (LPIPS) instead of MSE")

    print("\nπŸ”§ TRAINING INSTABILITY:")
    print("  β†’ Gradient clipping (max_norm=1.0)")
    print("  β†’ Learning rate warmup (0β†’1e-4 over 1000 steps)")
    print("  β†’ Batch normalization (or GroupNorm)")
    print("  β†’ Reduce learning rate (1e-5)")
    print("  β†’ Check for NaN values (codebook initialization)")

    banner("10. APPLICATIONS AND SOTA RESULTS")

    print("\n🎨 IMAGE GENERATION:")
    print("  β€’ DALL-E (2021): 12B Transformer + VQ-VAE, text-to-image")
    print("  β€’ Parti (2022): 20B Transformer + ViT-VQGAN, photorealistic")
    print("  β€’ ImageGPT: VQ-VAE tokenization + GPT, unsupervised vision")
    print("  β€’ VQGAN: GAN + VQ-VAE, FID 7.94 on ImageNet 256Γ—256")

    print("\n🎡 AUDIO SYNTHESIS:")
    print("  β€’ Jukebox (2020): 3-level VQ-VAE, generates music with lyrics")
    print("  β€’ SoundStream (2021): RVQ (8 levels), 3 kbps compression")
    print("  β€’ EnCodec (2022): RVQ + adversarial, better than MP3")

    print("\n🎬 VIDEO GENERATION:")
    print("  β€’ VideoGPT: 3D VQ-VAE + Transformer, FVD 170 on UCF-101")
    print("  β€’ NUWA (2021): Text-to-video, 3D-Transformer prior")
    print("  β€’ TATS: VQ-VAE + Transformer, long videos (>100 frames)")

    print("\nπŸ“¦ COMPRESSION:")
    print("  β€’ Learned image compression: Outperforms JPEG at low bitrates")
    print("  β€’ Neural audio codecs: 3-12 kbps (vs. 128 kbps MP3)")
    print("  β€’ Video compression: ~10Γ— bitrate reduction vs. H.264")

    print("\n" + "=" * 80)


# ============================================================
# RUN ALL DEMONSTRATIONS
# ============================================================

if __name__ == "__main__":
    # Run every demonstration in sequence, then the summary report.
    for demo in (
        demo_vector_quantizer,
        demo_vqvae_model,
        demo_product_quantizer,
        demo_vqvae2,
        demo_gumbel_quantizer,
        demo_training,
        print_performance_comparison,
    ):
        demo()