import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns
# Notebook-level setup: plotting style and compute device selection.
sns.set_style('whitegrid')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
1. Motivation: Multi-Scale Latent Representations
Limitations of Standard VAE
Single latent layer: \(p(x, z) = p(z)\, p_{\theta}(x \mid z)\)
Problems:
All information compressed into one layer
No hierarchical structure (coarse → fine)
Posterior collapse in deep models
Hierarchical VAE Solution
Multiple latent layers \(z_1, z_2, \ldots, z_L\):
Generative model (top-down): \(p(x, z_{1:L}) = p(z_L) \prod_{l=1}^{L-1} p_{\theta}(z_l \mid z_{l+1}) \cdot p_{\theta}(x \mid z_1)\)
Inference model (bottom-up + top-down): \(q_{\phi}(z_{1:L} \mid x) = \prod_{l=1}^L q_{\phi}(z_l \mid z_{<l}, x)\)
Advantages
Multi-scale: Different layers capture different abstractions
Better posterior: Avoids collapse
Expressive: More flexible generative process
Interpretable: Hierarchical latent structure
Reference Materials:
2. ELBO for Hierarchical VAE
Standard ELBO
Decomposition
Interpretation:
Reconstruction from \(z_1\) (bottom layer)
KL penalty at each layer (regularization)
Key Design Choice: Inference Architecture
Bottom-up only: \(q(z_l \mid z_{<l}, x) = q(z_l \mid x)\)
Bottom-up + top-down (Ladder VAE): \(q(z_l \mid z_{<l}, x) = q(z_l \mid z_{l+1}, x)\)
Ladder VAE uses both deterministic and stochastic paths.
Simple 2-Layer Hierarchical VAE
A hierarchical VAE introduces multiple levels of latent variables, forming a chain \(z_2 \to z_1 \to x\) where higher-level latents capture global structure and lower-level latents capture local detail. The generative model factors as \(p(x, z_1, z_2) = p(z_2)\, p(z_1 | z_2)\, p(x | z_1)\), giving the model strictly more expressive power than a single-layer VAE. The inference network similarly decomposes: \(q(z_1 | x)\) first, then \(q(z_2 | z_1)\) (or \(q(z_2 | x, z_1)\) for a top-down inference path). This hierarchy encourages the latent space to organize into a meaningful abstraction ladder — a principle that underpins modern deep generative models like NVAE and VDVAE.
class HierarchicalVAE(nn.Module):
    """Two-level hierarchical VAE.

    Latent chain: z2 (top, global) -> z1 (bottom, local) -> x.
    Generative model:  p(z2) p(z1 | z2) p(x | z1)
    Inference model:   q(z1 | x) q(z2 | z1, x)
    """

    def __init__(self, input_dim=784, z1_dim=32, z2_dim=16, hidden_dim=256):
        super().__init__()
        self.z1_dim = z1_dim
        self.z2_dim = z2_dim

        # Bottom-up trunk: x -> shared features feeding q(z1 | x).
        self.enc_z1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.enc_z1_mu = nn.Linear(hidden_dim, z1_dim)
        self.enc_z1_logvar = nn.Linear(hidden_dim, z1_dim)

        # Second stochastic level: [z1, trunk features] -> q(z2 | z1, x).
        self.enc_z2 = nn.Sequential(
            nn.Linear(z1_dim + hidden_dim, hidden_dim // 2),
            nn.ReLU(),
        )
        self.enc_z2_mu = nn.Linear(hidden_dim // 2, z2_dim)
        self.enc_z2_logvar = nn.Linear(hidden_dim // 2, z2_dim)

        # Conditional prior network p(z1 | z2).
        self.prior_z1 = nn.Sequential(
            nn.Linear(z2_dim, hidden_dim // 2),
            nn.ReLU(),
        )
        self.prior_z1_mu = nn.Linear(hidden_dim // 2, z1_dim)
        self.prior_z1_logvar = nn.Linear(hidden_dim // 2, z1_dim)

        # Decoder p(x | z1); final Sigmoid keeps outputs in [0, 1] for BCE.
        self.decoder = nn.Sequential(
            nn.Linear(z1_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Bottom-up inference: sample z1 ~ q(z1|x), then z2 ~ q(z2|z1, x).

        Returns (z1, z1_mu, z1_logvar, z2, z2_mu, z2_logvar).
        """
        feats = self.enc_z1(x)
        mu1 = self.enc_z1_mu(feats)
        logvar1 = self.enc_z1_logvar(feats)
        z1 = self.reparameterize(mu1, logvar1)
        # z2 conditions on both the sampled z1 and the deterministic features.
        h = self.enc_z2(torch.cat([z1, feats], dim=1))
        mu2 = self.enc_z2_mu(h)
        logvar2 = self.enc_z2_logvar(h)
        z2 = self.reparameterize(mu2, logvar2)
        return z1, mu1, logvar1, z2, mu2, logvar2

    def get_prior_z1(self, z2):
        """Parameters (mu, logvar) of the conditional prior p(z1 | z2)."""
        hidden = self.prior_z1(z2)
        return self.prior_z1_mu(hidden), self.prior_z1_logvar(hidden)

    def decode(self, z1):
        """Map a bottom-level latent to pixel means in [0, 1]."""
        return self.decoder(z1)

    def reparameterize(self, mu, logvar):
        """Differentiable sample: mu + sigma * eps, eps ~ N(0, I)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Full pass: encode, evaluate p(z1|z2) at the sampled z2, decode z1."""
        z1, mu1_q, logvar1_q, z2, mu2_q, logvar2_q = self.encode(x)
        mu1_p, logvar1_p = self.get_prior_z1(z2)
        recon = self.decode(z1)
        return recon, mu1_q, logvar1_q, mu1_p, logvar1_p, mu2_q, logvar2_q
def loss_function(recon_x, x, z1_mu_q, z1_logvar_q, z1_mu_p, z1_logvar_p, z2_mu_q, z2_logvar_q):
    """Negative ELBO for the 2-layer hierarchical VAE.

    Returns (total, reconstruction, KL_z1, KL_z2), each summed over the batch.
    """
    # Reconstruction term: -E_q[log p(x | z1)] under a Bernoulli likelihood.
    recon_term = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # KL(q(z2 | z1, x) || N(0, I)), standard closed form.
    kl_z2 = -0.5 * torch.sum(1 + z2_logvar_q - z2_mu_q.pow(2) - z2_logvar_q.exp())

    # KL(q(z1 | x) || p(z1 | z2)) between two diagonal Gaussians:
    # 0.5 * sum(log(vp/vq) + (vq + (mq - mp)^2) / vp - 1)
    var_q = z1_logvar_q.exp()
    var_p = z1_logvar_p.exp()
    kl_z1 = 0.5 * torch.sum(
        z1_logvar_p - z1_logvar_q
        + (var_q + (z1_mu_q - z1_mu_p).pow(2)) / var_p
        - 1
    )

    return recon_term + kl_z1 + kl_z2, recon_term, kl_z1, kl_z2
# Smoke test: one forward pass and loss evaluation on a random batch.
model = HierarchicalVAE().to(device)
# BUG FIX: use torch.rand (uniform in [0, 1]) instead of torch.randn.
# F.binary_cross_entropy requires its target to lie in [0, 1]; Gaussian
# samples from randn fall outside that range and make the loss raise.
x = torch.rand(16, 784).to(device)
recon, z1_mu_q, z1_logvar_q, z1_mu_p, z1_logvar_p, z2_mu_q, z2_logvar_q = model(x)
loss, bce, kl1, kl2 = loss_function(recon, x, z1_mu_q, z1_logvar_q, z1_mu_p, z1_logvar_p, z2_mu_q, z2_logvar_q)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Loss: {loss.item():.2f}, BCE: {bce.item():.2f}, KL_z1: {kl1.item():.2f}, KL_z2: {kl2.item():.2f}")
Training on MNIST
Training a hierarchical VAE on MNIST optimizes the same ELBO objective as a standard VAE, but now summed over all latent layers: \(\mathcal{L} = \mathbb{E}_q[\log p(x|z_1)] - \text{KL}(q(z_1|x) \| p(z_1|z_2)) - \text{KL}(q(z_2|z_1) \| p(z_2))\). A common training challenge is posterior collapse, where the model ignores higher-level latents. Techniques like KL annealing (gradually increasing the KL weight from 0 to 1) and free bits (setting a minimum KL per group) help ensure all layers are utilized. Monitoring the KL divergence at each level during training is the primary diagnostic for detecting and mitigating collapse.
# MNIST pipeline: ToTensor yields pixels in [0, 1], matching the Bernoulli decoder.
to_tensor = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
test_dataset = datasets.MNIST('./data', train=False, transform=to_tensor)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)

# Fresh model and optimizer for the full training run.
model = HierarchicalVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def train_epoch(model, loader, optimizer):
    """Run one optimization epoch.

    Returns per-sample averages: (loss, bce, kl_z1, kl_z2).
    """
    model.train()
    running = [0.0, 0.0, 0.0, 0.0]  # totals: loss, reconstruction, KL(z1), KL(z2)
    for images, _ in loader:
        flat = images.view(images.size(0), -1).to(device)
        optimizer.zero_grad()
        recon, mu1_q, lv1_q, mu1_p, lv1_p, mu2_q, lv2_q = model(flat)
        terms = loss_function(recon, flat, mu1_q, lv1_q, mu1_p, lv1_p, mu2_q, lv2_q)
        terms[0].backward()
        optimizer.step()
        for idx, term in enumerate(terms):
            running[idx] += term.item()
    n = len(loader.dataset)
    return tuple(total / n for total in running)
# Train and record the loss decomposition per epoch.
n_epochs = 20
history = {key: [] for key in ('loss', 'bce', 'kl1', 'kl2')}
for epoch in range(1, n_epochs + 1):
    metrics = train_epoch(model, train_loader, optimizer)
    for key, value in zip(('loss', 'bce', 'kl1', 'kl2'), metrics):
        history[key].append(value)
    loss, bce, kl1, kl2 = metrics
    print(f"Epoch {epoch}/{n_epochs}: Loss={loss:.4f}, BCE={bce:.4f}, "
          f"KL_z1={kl1:.4f}, KL_z2={kl2:.4f}")
print("\nTraining complete!")
# Training diagnostics: total/reconstruction loss (left) and per-layer KL (right).
fig, (ax_loss, ax_kl) = plt.subplots(1, 2, figsize=(14, 5))

ax_loss.plot(history['loss'], linewidth=2, label='Total Loss')
ax_loss.plot(history['bce'], linewidth=2, label='Reconstruction')
ax_loss.set_xlabel('Epoch', fontsize=12)
ax_loss.set_ylabel('Loss', fontsize=12)
ax_loss.set_title('Training Loss', fontsize=13)
ax_loss.legend()
ax_loss.grid(True, alpha=0.3)

# A healthy hierarchy keeps both KL terms nonzero (no posterior collapse).
ax_kl.plot(history['kl1'], linewidth=2, label='KL(z1)')
ax_kl.plot(history['kl2'], linewidth=2, label='KL(z2)')
ax_kl.set_xlabel('Epoch', fontsize=12)
ax_kl.set_ylabel('KL Divergence', fontsize=12)
ax_kl.set_title('KL Terms by Layer', fontsize=13)
ax_kl.legend()
ax_kl.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Analyzing Hierarchical Representations
After training, we inspect what each latent level has learned by encoding test images and examining the activations. The top level (\(z_2\)) should capture global attributes like digit identity and overall shape, while the bottom level (\(z_1\)) should encode local details like stroke thickness and exact pixel placement. Visualizing reconstructions with only \(z_2\) active (sampling \(z_1\) from the prior) versus both levels active reveals the information each level contributes, confirming whether the hierarchy has learned a meaningful decomposition.
model.eval()

# Ancestral sampling through the hierarchy:
# z2 ~ p(z2) = N(0, I), then z1 ~ p(z1 | z2), then x = decode(z1).
with torch.no_grad():
    z2 = torch.randn(10, model.z2_dim).to(device)
    z1_mu, z1_logvar = model.get_prior_z1(z2)
    z1 = model.reparameterize(z1_mu, z1_logvar)
    samples = model.decode(z1).cpu().view(-1, 28, 28)

fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for i in range(10):
    axes[i].imshow(samples[i], cmap='gray')
    axes[i].axis('off')
# FIX: the title previously contained mojibake ("β") where the arrow
# character was intended.
plt.suptitle('Hierarchical Generation: z2 → z1 → x', fontsize=14)
plt.tight_layout()
plt.show()
# Probe the top-level latent: perturb z2 around a base point and decode
# through the conditional prior's mean.
print("\nVarying z2 (high-level features):")
with torch.no_grad():
    anchor = torch.randn(1, model.z2_dim).to(device)
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    for col in range(10):
        z2 = anchor + 0.5 * torch.randn(1, model.z2_dim).to(device)
        # Deterministic decode: take the prior mean of z1 rather than a sample.
        z1_mean, _ = model.get_prior_z1(z2)
        image = model.decode(z1_mean).cpu().view(28, 28)
        axes[col].imshow(image, cmap='gray')
        axes[col].axis('off')
    plt.suptitle('Varying z2 (high-level)', fontsize=14)
    plt.tight_layout()
    plt.show()
Latent Space Interpolation
Interpolating between two points in latent space produces a smooth morphing between the corresponding images — one of the hallmarks of a well-trained generative model. In a hierarchical VAE, we can interpolate at different levels independently: interpolating \(z_2\) while keeping \(z_1\) fixed changes the global structure (e.g., digit identity), while interpolating \(z_1\) changes fine details. Smooth, semantically meaningful interpolations indicate that the latent space has a well-organized geometry, which is important for downstream tasks like data augmentation, style transfer, and conditional generation.
# Interpolate between two test digits, separately at each latent level.
img1, _ = test_dataset[0]
img2, _ = test_dataset[5]

with torch.no_grad():
    xa = img1.view(1, -1).to(device)
    xb = img2.view(1, -1).to(device)
    # Encode both endpoints; keep only the posterior means.
    _, z1_mu_a, _, _, z2_mu_a, _ = model.encode(xa)
    _, z1_mu_b, _, _, z2_mu_b, _ = model.encode(xb)

    n_steps = 10
    blend = np.linspace(0, 1, n_steps)
    fig, axes = plt.subplots(2, n_steps, figsize=(15, 4))
    for col, alpha in enumerate(blend):
        # Top row: blend z2, then map through the conditional prior p(z1 | z2).
        z2_mix = (1 - alpha) * z2_mu_a + alpha * z2_mu_b
        z1_from_prior, _ = model.get_prior_z1(z2_mix)
        axes[0, col].imshow(model.decode(z1_from_prior).cpu().view(28, 28), cmap='gray')
        axes[0, col].axis('off')
        # Bottom row: blend z1 directly and decode.
        z1_mix = (1 - alpha) * z1_mu_a + alpha * z1_mu_b
        axes[1, col].imshow(model.decode(z1_mix).cpu().view(28, 28), cmap='gray')
        axes[1, col].axis('off')

    axes[0, 0].set_ylabel('z2 interp', fontsize=11)
    axes[1, 0].set_ylabel('z1 interp', fontsize=11)
    plt.suptitle('Latent Space Interpolation', fontsize=14)
    plt.tight_layout()
    plt.show()

print("z2 interpolation shows smoother high-level transitions")
print("z1 interpolation shows more detailed low-level changes")
Summary
Key Contributions:
Hierarchical latents: Multi-scale representation (z2=coarse, z1=fine)
Conditional priors: \(p(z_l \mid z_{l+1})\) model dependencies between levels
Better ELBO: KL terms at each layer prevent posterior collapse
Expressive: More flexible than single-layer VAE
Hierarchical VAE Variants:
Ladder VAE (Sønderby et al., 2016): Bottom-up + top-down paths
BIVA (Maaløe et al., 2019): Bidirectional inference
NVAE (Vahdat & Kautz, 2020): Very deep hierarchies
VDVAE (Child, 2020): Very deep hierarchical VAE with many stochastic layers
When to Use:
Complex data (images, audio, video)
Need for an interpretable hierarchy
Avoiding posterior collapse
Multi-resolution generation
Challenges:
Training complexity (multiple KL terms)
Balancing layers (KL annealing)
Computational cost (deeper networks)
Next Steps:
12_vq_vae.ipynb - Discrete hierarchies
14_normalizing_flows.ipynb - Exact likelihood
03_variational_autoencoders_advanced.ipynb - Review VAE basics