import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
sns.set_style('whitegrid')

1. Motivation: Multi-Scale Latent Representations

Limitations of Standard VAE

Single latent layer: \[p(x, z) = p(z)\, p_{\theta}(x \mid z)\]

Problems:

  • All information compressed into one layer

  • No hierarchical structure (coarse → fine)

  • Posterior collapse in deep models

Hierarchical VAE Solution

Multiple latent layers \(z_1, z_2, \ldots, z_L\):

Generative model (top-down): \[p(x, z_{1:L}) = p(z_L) \prod_{l=1}^{L-1} p_{\theta}(z_l \mid z_{l+1}) \cdot p_{\theta}(x \mid z_1)\]

Inference model (bottom-up factorization, as used in the implementation below): \[q_{\phi}(z_{1:L} \mid x) = \prod_{l=1}^L q_{\phi}(z_l \mid z_{<l}, x)\]

Advantages

  1. Multi-scale: Different layers capture different abstractions

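  2. Richer posterior: more flexible than a single Gaussian latent layer (upper layers can still collapse without care, e.g. KL annealing or free bits)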

  3. Expressive: More flexible generative process

  4. Interpretable: Hierarchical latent structure

📚 Reference Materials:

2. ELBO for Hierarchical VAE

Standard ELBO

\[\log p(x) \geq \mathbb{E}_{q(z_{1:L}|x)} [\log p(x, z_{1:L})] - \mathbb{E}_{q(z_{1:L}|x)}[\log q(z_{1:L}|x)]\]

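Decomposition

\[\mathcal{L} = \mathbb{E}_{q} [\log p_{\theta}(x \mid z_1)] + \sum_{l=1}^L \mathbb{E}_{q}\left[\log p_{\theta}(z_l \mid z_{l+1}) - \log q_{\phi}(z_l \mid z_{<l}, x)\right], \qquad p_{\theta}(z_L \mid z_{L+1}) := p(z_L)\]

When inference runs top-down, i.e. with \(q_{\phi}(z_l \mid z_{>l}, x)\) in place of \(q_{\phi}(z_l \mid z_{<l}, x)\), each summand equals \(-\mathbb{E}_{q}\left[KL(q_{\phi}(z_l \mid z_{>l}, x) \,\|\, p_{\theta}(z_l \mid z_{>l}))\right]\), so the objective becomes a reconstruction term minus one KL penalty per layer.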

Interpretation:

  • Reconstruction from \(z_1\) (bottom layer)

  • KL penalty at each layer (regularization)

Key Design Choice: Inference Architecture

Bottom-up only: \[q(z_l \mid z_{<l}, x) = q(z_l \mid x)\]

Bottom-up + top-down (Ladder VAE): \[q(z_l \mid z_{>l}, x) = q(z_l \mid z_{l+1}, x)\]

Ladder VAE uses both deterministic and stochastic paths.

Simple 2-Layer Hierarchical VAE

A hierarchical VAE introduces multiple levels of latent variables, forming a chain \(z_2 \to z_1 \to x\) where higher-level latents capture global structure and lower-level latents capture local detail. The generative model factors as \(p(x, z_1, z_2) = p(z_2)\, p(z_1 | z_2)\, p(x | z_1)\), giving the model strictly more expressive power than a single-layer VAE. The inference network used here decomposes bottom-up: \(q(z_1 | x)\) first, then \(q(z_2 | z_1, x)\), where the \(z_2\) encoder sees both the sampled \(z_1\) and the hidden features of \(x\). A top-down inference path (as in the Ladder VAE) instead infers \(z_2\) first and then \(q(z_1 | z_2, x)\). This hierarchy encourages the latent space to organize into a meaningful abstraction ladder – a principle that underpins modern deep generative models like NVAE and VDVAE.

class HierarchicalVAE(nn.Module):
    """Two-layer hierarchical VAE with a bottom latent z1 and a top latent z2.

    Generative path: z2 -> p(z1 | z2) (the ``prior_z1*`` layers) -> p(x | z1)
    (``decoder``). The top-level prior p(z2) has no parameters in this module.

    Inference path (bottom-up): q(z1 | x) is computed from features h1 of x;
    q(z2 | z1, x) is computed from the *sampled* z1 concatenated with h1.

    Args:
        input_dim: Length of each flattened input vector.
        z1_dim: Dimensionality of the bottom latent z1.
        z2_dim: Dimensionality of the top latent z2.
        hidden_dim: Width of the x-encoder and decoder hidden layers; the z2
            encoder and the p(z1 | z2) network use ``hidden_dim // 2``.
    """
    def __init__(self, input_dim=784, z1_dim=32, z2_dim=16, hidden_dim=256):
        super().__init__()
        self.z1_dim = z1_dim
        self.z2_dim = z2_dim
        
        # Encoder (bottom-up): x -> features h1 -> Gaussian q(z1 | x)
        self.enc_z1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.enc_z1_mu = nn.Linear(hidden_dim, z1_dim)
        self.enc_z1_logvar = nn.Linear(hidden_dim, z1_dim)
        
        # q(z2 | z1, x): input is [sampled z1, h1] concatenated
        self.enc_z2 = nn.Sequential(
            nn.Linear(z1_dim + hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )
        self.enc_z2_mu = nn.Linear(hidden_dim // 2, z2_dim)
        self.enc_z2_logvar = nn.Linear(hidden_dim // 2, z2_dim)
        
        # Learned conditional prior p(z1 | z2): diagonal Gaussian parameterized by z2
        self.prior_z1 = nn.Sequential(
            nn.Linear(z2_dim, hidden_dim // 2),
            nn.ReLU()
        )
        self.prior_z1_mu = nn.Linear(hidden_dim // 2, z1_dim)
        self.prior_z1_logvar = nn.Linear(hidden_dim // 2, z1_dim)
        
        # Decoder p(x | z1): final sigmoid gives one value in (0, 1) per input element
        self.decoder = nn.Sequential(
            nn.Linear(z1_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
    
    def encode(self, x):
        """Run bottom-up inference and draw one reparameterized sample per latent.

        Args:
            x: Input batch of shape (batch, input_dim).

        Returns:
            ``(z1, z1_mu, z1_logvar, z2, z2_mu, z2_logvar)``: samples and the
            Gaussian mean / log-variance of q(z1 | x) and q(z2 | z1, x).
        """
        # q(z1 | x)
        h1 = self.enc_z1(x)
        z1_mu = self.enc_z1_mu(h1)
        z1_logvar = self.enc_z1_logvar(h1)
        z1 = self.reparameterize(z1_mu, z1_logvar)
        
        # q(z2 | z1, x): conditions on the sampled z1 and the x-features h1
        h2 = self.enc_z2(torch.cat([z1, h1], dim=1))
        z2_mu = self.enc_z2_mu(h2)
        z2_logvar = self.enc_z2_logvar(h2)
        z2 = self.reparameterize(z2_mu, z2_logvar)
        
        return z1, z1_mu, z1_logvar, z2, z2_mu, z2_logvar
    
    def get_prior_z1(self, z2):
        """Return ``(mu, logvar)`` of the conditional prior p(z1 | z2)."""
        h = self.prior_z1(z2)
        mu = self.prior_z1_mu(h)
        logvar = self.prior_z1_logvar(h)
        return mu, logvar
    
    def decode(self, z1):
        """Map z1 to per-element outputs in (0, 1) via the sigmoid decoder."""
        return self.decoder(z1)
    
    def reparameterize(self, mu, logvar):
        """Sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable w.r.t. ``mu`` and
        ``logvar`` (the reparameterization trick).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        """Encode ``x``, evaluate the prior p(z1 | z2) at the sampled z2, decode z1.

        Returns:
            ``(recon, z1_mu_q, z1_logvar_q, z1_mu_p, z1_logvar_p, z2_mu_q,
            z2_logvar_q)`` -- the reconstruction followed by the Gaussian
            parameters of q(z1 | x), p(z1 | z2) and q(z2 | z1, x), in the
            order expected by ``loss_function``.
        """
        # Encode
        z1, z1_mu_q, z1_logvar_q, z2, z2_mu_q, z2_logvar_q = self.encode(x)
        
        # Prior p(z1 | z2), evaluated at the posterior sample of z2
        z1_mu_p, z1_logvar_p = self.get_prior_z1(z2)
        
        # Decode
        recon = self.decode(z1)
        
        return recon, z1_mu_q, z1_logvar_q, z1_mu_p, z1_logvar_p, z2_mu_q, z2_logvar_q

def loss_function(recon_x, x, z1_mu_q, z1_logvar_q, z1_mu_p, z1_logvar_p,
                  z2_mu_q, z2_logvar_q, beta=1.0):
    """Hierarchical VAE loss (ELBO-style), summed over the batch.

    Args:
        recon_x: Decoder output, values in [0, 1].
        x: Binary-cross-entropy targets; values must lie in [0, 1].
        z1_mu_q, z1_logvar_q: Mean / log-variance of q(z1 | x).
        z1_mu_p, z1_logvar_p: Mean / log-variance of the prior p(z1 | z2).
        z2_mu_q, z2_logvar_q: Mean / log-variance of q(z2 | z1, x).
        beta: Weight applied to both KL terms in the total. The default 1.0
            gives the unweighted loss; smaller values allow e.g. KL annealing.

    Returns:
        ``(total, BCE, KLD_z1, KLD_z2)`` with
        ``total = BCE + beta * KLD_z1 + beta * KLD_z2``. The three components
        are returned unweighted.

    NOTE(review): q(z2 | z1, x) is conditioned on the *sampled* z1, so the
    closed-form KL(q(z1 | x) || p(z1 | z2)) evaluated at a z2 drawn from it
    is not an unbiased estimate of the ELBO's z1 term. The single-sample
    estimate log q(z1 | x) - log p(z1 | z2) at that same z1 would be. Confirm
    this approximation is intended.
    """
    # Reconstruction: negative Bernoulli log-likelihood of x under recon_x
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # KL(q(z2 | z1, x) || p(z2)) with p(z2) = N(0, I), closed form
    KLD_z2 = -0.5 * torch.sum(1 + z2_logvar_q - z2_mu_q.pow(2) - z2_logvar_q.exp())
    
    # KL(q(z1 | x) || p(z1 | z2)) between two diagonal Gaussians, closed form
    KLD_z1 = -0.5 * torch.sum(
        1 + z1_logvar_q - z1_logvar_p - 
        ((z1_mu_q - z1_mu_p).pow(2) + z1_logvar_q.exp()) / z1_logvar_p.exp()
    )
    
    # beta multiplies each KL separately so beta=1.0 reproduces the
    # unweighted sum BCE + KLD_z1 + KLD_z2 bit-for-bit.
    total = BCE + beta * KLD_z1 + beta * KLD_z2
    return total, BCE, KLD_z1, KLD_z2

# Smoke test: one forward pass and loss evaluation on random inputs.
# Binary cross-entropy needs targets in [0, 1], so draw uniform noise
# (torch.rand). Gaussian noise (torch.randn) would produce targets outside
# that range and a meaningless, possibly negative, reconstruction term.
model = HierarchicalVAE().to(device)
x = torch.rand(16, 784).to(device)
recon, z1_mu_q, z1_logvar_q, z1_mu_p, z1_logvar_p, z2_mu_q, z2_logvar_q = model(x)
loss, bce, kl1, kl2 = loss_function(recon, x, z1_mu_q, z1_logvar_q, z1_mu_p, z1_logvar_p, z2_mu_q, z2_logvar_q)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Loss: {loss.item():.2f}, BCE: {bce.item():.2f}, KL_z1: {kl1.item():.2f}, KL_z2: {kl2.item():.2f}")

Training on MNIST

Training a hierarchical VAE on MNIST optimizes an ELBO-style objective as for a standard VAE, but with one KL term per latent layer: \(\mathcal{L} = \mathbb{E}_q[\log p(x|z_1)] - \mathbb{E}_q[\text{KL}(q(z_1|x) \| p(z_1|z_2))] - \mathbb{E}_q[\text{KL}(q(z_2|z_1, x) \| p(z_2))]\). A common training challenge is posterior collapse, where the model ignores higher-level latents. Techniques like KL annealing (gradually increasing the KL weight from 0 to 1) and free bits (setting a minimum KL per group) help ensure all layers are utilized. Monitoring the KL divergence at each level during training is the primary diagnostic for detecting and mitigating collapse.

# Data: MNIST digits converted by ToTensor, which scales pixels to [0, 1]
# as the binary cross-entropy reconstruction loss requires.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128)

# Training: a freshly initialized model optimized with Adam (lr = 1e-3).
model = HierarchicalVAE().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_epoch(model, loader, optimizer):
    """Run one optimization pass over ``loader``.

    Each batch is flattened to (batch, -1), moved to ``device``, scored with
    ``loss_function`` and used for one optimizer step.

    Returns:
        ``(loss, bce, kl1, kl2)``: the epoch sum of each term divided by the
        number of samples in ``loader.dataset``.
    """
    model.train()
    running = [0, 0, 0, 0]  # totals for loss, BCE, KL(z1), KL(z2), in that order

    for batch, _labels in loader:
        flat = batch.view(batch.size(0), -1).to(device)
        optimizer.zero_grad()

        outputs = model(flat)
        # outputs[0] is the reconstruction; the remaining six entries are the
        # Gaussian parameters, already in loss_function's argument order.
        terms = loss_function(outputs[0], flat, *outputs[1:])

        terms[0].backward()
        optimizer.step()

        for slot, term in enumerate(terms):
            running[slot] += term.item()

    sample_count = len(loader.dataset)
    return tuple(total / sample_count for total in running)

n_epochs = 20
history = {'loss': [], 'bce': [], 'kl1': [], 'kl2': []}

for epoch in range(n_epochs):
    epoch_stats = train_epoch(model, train_loader, optimizer)
    loss, bce, kl1, kl2 = epoch_stats

    # history's key order ('loss', 'bce', 'kl1', 'kl2') matches the order of
    # the values returned by train_epoch, so the two can be zipped together.
    for metric_name, metric_value in zip(history, epoch_stats):
        history[metric_name].append(metric_value)

    print(f"Epoch {epoch+1}/{n_epochs}: Loss={loss:.4f}, BCE={bce:.4f}, "
          f"KL_z1={kl1:.4f}, KL_z2={kl2:.4f}")

print("\nTraining complete!")
# Plot training curves: total loss and reconstruction on the left,
# the per-layer KL terms on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

panel_specs = (
    (ax1, (('loss', 'Total Loss'), ('bce', 'Reconstruction')), 'Loss', 'Training Loss'),
    (ax2, (('kl1', 'KL(z1)'), ('kl2', 'KL(z2)')), 'KL Divergence', 'KL Terms by Layer'),
)
for ax, curves, y_label, title in panel_specs:
    for history_key, curve_label in curves:
        ax.plot(history[history_key], linewidth=2, label=curve_label)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Analyzing Hierarchical Representations

After training, we probe what each latent level has learned by generating through the hierarchy. Ideally the top level (\(z_2\)) captures global attributes like digit identity and overall shape, while the bottom level (\(z_1\)) encodes local details like stroke thickness and exact pixel placement. Below we first draw \(z_2\) from its prior and decode through \(p(z_1 \mid z_2)\), then perturb a single \(z_2\) to see how much variation the top level controls. A complementary diagnostic is to compare reconstructions that keep only the encoded \(z_2\) (sampling \(z_1\) from \(p(z_1 \mid z_2)\)) with full reconstructions that use both levels. The difference reveals the information each level contributes and whether the hierarchy has learned a meaningful decomposition.

model.eval()

# Ancestral sampling through the hierarchy: z2 ~ N(0, I), z1 ~ p(z1 | z2),
# then decode z1 and reshape each sample to a 28x28 image.
with torch.no_grad():
    z2 = torch.randn(10, model.z2_dim).to(device)
    z1_mu, z1_logvar = model.get_prior_z1(z2)
    z1 = model.reparameterize(z1_mu, z1_logvar)
    samples = model.decode(z1).cpu().view(-1, 28, 28)

fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for ax, image in zip(axes, samples):
    ax.imshow(image, cmap='gray')
    ax.axis('off')
plt.suptitle('Hierarchical Generation: z2 → z1 → x', fontsize=14)
plt.tight_layout()
plt.show()

# Perturb one base z2 and decode the prior mean of p(z1 | z2) for each
# perturbation. z1 is not held fixed: it is recomputed from every perturbed
# z2, so the row shows how top-level changes propagate down the hierarchy.
print("\nVarying z2 (high-level features):")
with torch.no_grad():
    z2_base = torch.randn(1, model.z2_dim).to(device)
    
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    for i in range(10):
        z2 = z2_base + 0.5 * torch.randn(1, model.z2_dim).to(device)
        z1_mu, _ = model.get_prior_z1(z2)
        # Use the prior mean rather than a sample, so the only randomness
        # across panels comes from the z2 perturbation.
        z1 = z1_mu
        sample = model.decode(z1).cpu().view(28, 28)
        axes[i].imshow(sample, cmap='gray')
        axes[i].axis('off')
    plt.suptitle('Varying z2 (high-level)', fontsize=14)
    plt.tight_layout()
    plt.show()

Latent Space Interpolation

Interpolating between two points in latent space produces a smooth morphing between the corresponding images – one of the hallmarks of a well-trained generative model. In a hierarchical VAE, we can interpolate at different levels. Interpolating \(z_2\) (and decoding the prior mean of \(p(z_1 \mid z_2)\)) is expected to change the global structure (e.g., digit identity). Interpolating \(z_1\) directly changes finer details. Smooth, semantically meaningful interpolations indicate that the latent space has a well-organized geometry, which is important for downstream tasks like data augmentation, style transfer, and conditional generation.

# Get two test images
img1, _ = test_dataset[0]
img2, _ = test_dataset[5]

with torch.no_grad():
    x1 = img1.view(1, -1).to(device)
    x2 = img2.view(1, -1).to(device)
    
    # Encode both; only the posterior means of z1 and z2 are interpolated.
    z1_1, z1_mu_1, z1_logvar_1, z2_1, z2_mu_1, z2_logvar_1 = model.encode(x1)
    z1_2, z1_mu_2, z1_logvar_2, z2_2, z2_mu_2, z2_logvar_2 = model.encode(x2)
    
    n_steps = 10
    alphas = np.linspace(0, 1, n_steps)
    
    fig, axes = plt.subplots(2, n_steps, figsize=(15, 4))
    
    # .tolist() yields plain Python floats for the mixing weights.
    for i, alpha in enumerate(alphas.tolist()):
        # Row 0: interpolate z2, map it through the mean of p(z1 | z2), decode.
        z2_interp = (1 - alpha) * z2_mu_1 + alpha * z2_mu_2
        z1_mu, _ = model.get_prior_z1(z2_interp)
        sample = model.decode(z1_mu).cpu().view(28, 28)
        axes[0, i].imshow(sample, cmap='gray')
        
        # Row 1: interpolate directly between the two z1 posterior means, decode.
        z1_interp = (1 - alpha) * z1_mu_1 + alpha * z1_mu_2
        sample_z1 = model.decode(z1_interp).cpu().view(28, 28)
        axes[1, i].imshow(sample_z1, cmap='gray')
    
    # Hide ticks and frames instead of calling ax.axis('off'): axis('off')
    # also hides axis labels, which would suppress the row labels set below.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    
    axes[0, 0].set_ylabel('z2 interp', fontsize=11)
    axes[1, 0].set_ylabel('z1 interp', fontsize=11)
    plt.suptitle('Latent Space Interpolation', fontsize=14)
    plt.tight_layout()
    plt.show()

print("z2 interpolation shows smoother high-level transitions")
print("z1 interpolation shows more detailed low-level changes")

Summary

Key Contributions:

  1. Hierarchical latents: Multi-scale representation (z2=coarse, z1=fine)

  2. Conditional priors: \(p(z_l | z_{l+1})\) models dependencies

  3. Layer-wise ELBO: one KL term per layer, which makes each layer's usage easy to monitor (the KL terms alone do not prevent posterior collapse)

  4. Expressive: More flexible than single-layer VAE

Hierarchical VAE Variants:

  • Ladder VAE (Sønderby et al., 2016): Bottom-up + top-down paths

  • BIVA (Maaløe et al., 2019): Bidirectional inference

  • NVAE (Vahdat & Kautz, 2020): Very deep hierarchies

  • VD-VAE (Child, 2020): Very deep VAEs with dozens of stochastic layers

When to Use:

  • Complex data (images, audio, video)

  • Need interpretable hierarchy

  • Reduce posterior collapse (combined with KL annealing or free bits)

  • Multi-resolution generation

Challenges:

  • Training complexity (multiple KL terms)

  • Balancing layers (KL annealing)

  • Computational cost (deeper networks)

Next Steps:

  • 12_vq_vae.ipynb - Discrete hierarchies

  • 14_normalizing_flows.ipynb - Exact likelihood

  • 03_variational_autoencoders_advanced.ipynb - Review VAE basics