Variational Autoencoders (VAEs)
Learning Objectives:
Understand variational inference and ELBO
Implement reparameterization trick
Train VAE for image generation
Explore latent space structure
Prerequisites: Deep learning, probability theory, variational inference
Time: 90 minutes
Reference Materials:
vae.pdf - VAE theory and mathematical foundations
generative_models.pdf - Overview of generative models including VAEs
1. Latent Variable Models
The Setup
Generative story:
Sample latent code: \(z \sim p(z)\) (prior)
Generate data: \(x \sim p(x|z)\) (likelihood)
Goal: Learn \(p(x) = \int p(x|z)p(z)dz\) (marginal likelihood)
Problem: Integral is intractable for complex \(p(x|z)\)!
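Before introducing the variational machinery, a tiny numeric sanity check helps: for a toy linear-Gaussian model (an illustrative choice, not from the notebook) the marginal is available in closed form, so a naive Monte Carlo estimate of the integral can be compared against it:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model: z ~ N(0, 1), x|z ~ N(z, 0.25)  =>  exact marginal p(x) = N(x; 0, 1.25)
x = 0.7
z = rng.standard_normal(100_000)

# Naive Monte Carlo: p(x) ~ (1/N) * sum_i p(x | z_i) with z_i drawn from the prior
px_mc = np.mean(np.exp(-0.5 * (x - z) ** 2 / 0.25) / np.sqrt(2 * np.pi * 0.25))
px_exact = np.exp(-0.5 * x ** 2 / 1.25) / np.sqrt(2 * np.pi * 1.25)
print(px_mc, px_exact)  # the two should agree closely
```

For an expressive decoder network there is no such closed form, and naive sampling from the prior becomes hopelessly inefficient in high dimensions; this is what motivates the variational approach below.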
Variational Approach
Idea: Introduce variational posterior \(q(z|x) \approx p(z|x)\)
Evidence Lower Bound (ELBO): $\(\log p(x) \geq \mathbb{E}_{z \sim q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) || p(z))\)$
Maximize ELBO = Maximize log-likelihood (approximately)
VAE Architecture
Encoder: \(q_\phi(z|x)\) (approximate posterior)
Decoder: \(p_\theta(x|z)\) (likelihood)
Prior: \(p(z) = \mathcal{N}(0, I)\) (standard Gaussian)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal, kl_divergence
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from scipy import stats  # used below for plotting Gaussian density curves
# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)
np.random.seed(42)
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
2. ELBO Derivation
Variational Bound
Starting from log-likelihood: $\(\log p(x) = \log \int p(x, z) dz = \log \int q(z|x) \frac{p(x, z)}{q(z|x)} dz\)$
By Jensen's inequality: $\(\log p(x) \geq \int q(z|x) \log \frac{p(x, z)}{q(z|x)} dz = \mathcal{L}(x; \theta, \phi)\)$
This is the Evidence Lower Bound (ELBO).
ELBO Decomposition
Expand \(p(x, z) = p(x|z)p(z)\) to obtain: $\(\mathcal{L}(x; \theta, \phi) = \mathbb{E}_{z \sim q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) || p(z))\)$
Interpretation
Reconstruction term: How well decoder reconstructs \(x\) from \(z\)
KL term: How close \(q(z|x)\) is to prior \(p(z)\)
Trade-off: Good reconstruction vs. simple latent distribution
2.5. Advanced ELBO Analysis
Tightness of the Bound
The gap between \(\log p(x)\) and ELBO is: $\(\log p(x) - \mathcal{L} = D_{KL}(q(z|x) || p(z|x))\)$
Proof: $\(\log p(x) = \mathbb{E}_{q(z|x)}[\log p(x)] = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x, z)}{p(z|x)}\right] = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x, z)}{q(z|x)}\right] + \mathbb{E}_{q(z|x)}\left[\log \frac{q(z|x)}{p(z|x)}\right] = \mathcal{L} + D_{KL}(q(z|x) || p(z|x))\)$
Implication: ELBO is tight when \(q(z|x) = p(z|x)\) (exact posterior)!
Alternative Formulation: Importance Weighted Bound
The standard ELBO uses a single sample. We can improve tightness with the Importance Weighted Autoencoder (IWAE) bound: $\(\mathcal{L}_k = \mathbb{E}_{z^{(1)}, \ldots, z^{(k)} \sim q(z|x)}\left[\log \frac{1}{k} \sum_{i=1}^k \frac{p(x, z^{(i)})}{q(z^{(i)}|x)}\right]\)$
Properties:
\(\mathcal{L}_1\) is standard ELBO
\(\mathcal{L}_k \geq \mathcal{L}_1\) (tighter bound)
As \(k \to \infty\), \(\mathcal{L}_k \to \log p(x)\)
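These properties can be demonstrated on the same tractable toy model (\(p(z)=\mathcal{N}(0,1)\), \(p(x|z)=\mathcal{N}(z,1)\)) with a deliberately crude proposal; all numbers here are illustrative choices:

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)

def log_gauss(v, mu, var):
    return -0.5 * (np.log(2 * np.pi * var) + (v - mu) ** 2 / var)

# Toy model: p(z) = N(0,1), p(x|z) = N(z,1)  =>  exact log p(x) = log N(x; 0, 2)
x = 1.5
log_px = log_gauss(x, 0.0, 2.0)

def iwae(k, n_rep=5000):
    """Monte Carlo estimate of L_k with a crude proposal q(z|x) = N(0, 1)."""
    z = rng.standard_normal((n_rep, k))  # k samples from q per estimate
    # log importance weights: log p(z) + log p(x|z) - log q(z|x)
    log_w = log_gauss(z, 0.0, 1.0) + log_gauss(x, z, 1.0) - log_gauss(z, 0.0, 1.0)
    return np.mean(logsumexp(log_w, axis=1) - np.log(k))

l1, l10, l100 = iwae(1), iwae(10), iwae(100)
print(l1, l10, l100, log_px)  # bounds tighten toward log p(x) as k grows
```

Here \(\mathcal{L}_1\) is the plain ELBO, and increasing \(k\) monotonically closes the gap to \(\log p(x)\).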
Beta-VAE: Disentangling Latent Factors
Modify ELBO with hyperparameter \(\beta\): $\(\mathcal{L}_\beta = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \beta \cdot D_{KL}(q(z|x) || p(z))\)$
Effect of \(\beta\):
\(\beta < 1\): Prioritize reconstruction (sharper images)
\(\beta = 1\): Standard VAE
\(\beta > 1\): Stronger regularization (more disentangled latent space)
Intuition: Higher \(\beta\) forces encoder to use latent dimensions more independently
Information Theoretic View
Averaging the ELBO over the data distribution, it can be rewritten as: $\(\mathbb{E}_{p_{\text{data}}(x)}[\mathcal{L}] = \underbrace{\mathbb{E}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{I_q(x; z)}_{\text{Mutual Information}} - \underbrace{D_{KL}(q(z) || p(z))}_{\text{Marginal KL}}\)$
where:
\(I_q(x; z)\) measures how much information \(z\) captures about \(x\)
\(D_{KL}(q(z) || p(z))\) encourages marginal \(q(z) = \int q(z|x)p(x)dx\) to match prior
Trade-off: Encode information vs. maintain simple prior
This perspective explains:
Rate-distortion trade-off in compression
Information bottleneck in representation learning
# Visualize ELBO decomposition and trade-offs
def visualize_elbo_components(recon_losses, kl_losses, betas=[0.5, 1.0, 2.0, 4.0]):
"""
Visualize how different beta values affect the ELBO components
Demonstrates:
- Reconstruction vs KL trade-off
- Effect of beta on latent space regularization
- Optimal beta selection
"""
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Plot 1: ELBO components for different betas
ax = axes[0, 0]
epochs = np.arange(len(recon_losses))
for beta in betas:
        elbo = -(recon_losses + beta * kl_losses)  # negate so higher is better
        ax.plot(epochs, elbo, label=f'β={beta}', linewidth=2)
ax.set_xlabel('Epoch', fontsize=11)
ax.set_ylabel('ELBO (higher is better)', fontsize=11)
ax.set_title('ELBO vs Beta', fontweight='bold', fontsize=12)
ax.legend()
ax.grid(True, alpha=0.3)
# Plot 2: Reconstruction loss
ax = axes[0, 1]
ax.plot(epochs, recon_losses, color='blue', linewidth=2)
ax.fill_between(epochs, 0, recon_losses, alpha=0.3, color='blue')
ax.set_xlabel('Epoch', fontsize=11)
ax.set_ylabel('Reconstruction Loss', fontsize=11)
    ax.set_title(r'Reconstruction Term: $E_q[\log p(x|z)]$', fontweight='bold', fontsize=12)
ax.grid(True, alpha=0.3)
# Plot 3: KL divergence
ax = axes[1, 0]
ax.plot(epochs, kl_losses, color='red', linewidth=2)
ax.fill_between(epochs, 0, kl_losses, alpha=0.3, color='red')
ax.set_xlabel('Epoch', fontsize=11)
ax.set_ylabel('KL Divergence', fontsize=11)
ax.set_title('KL Term: $D_{KL}(q(z|x) || p(z))$', fontweight='bold', fontsize=12)
ax.grid(True, alpha=0.3)
# Plot 4: Trade-off visualization
ax = axes[1, 1]
    for beta in betas:
        total = recon_losses + beta * kl_losses
        ax.plot(epochs, total, label=f'β={beta}', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('Total Loss', fontsize=11)
    ax.set_title('Total Loss: Recon + β·KL', fontweight='bold', fontsize=12)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\n" + "="*70)
print("BETA-VAE: UNDERSTANDING THE TRADE-OFF")
print("="*70)
print("\nReconstruction vs Regularization:")
print(" β’ Reconstruction: How accurately decoder reconstructs input")
print(" β’ KL Divergence: How close posterior is to prior N(0,I)")
print("\nEffect of Beta:")
print(" β’ Ξ² < 1: Prioritize reconstruction β sharper images")
print(" β’ Ξ² = 1: Standard VAE (balanced)")
print(" β’ Ξ² > 1: Prioritize regularization β disentangled representations")
print("\nChoosing Beta:")
print(" β’ Image generation: Ξ² β [0.5, 1.0]")
print(" β’ Disentanglement: Ξ² β [2.0, 10.0]")
print(" β’ Representation learning: Ξ² β [1.0, 4.0]")
print("="*70)
# Example: Generate synthetic data for demonstration
np.random.seed(42)
n_epochs = 100
# Simulate training dynamics
recon_loss = 150 * np.exp(-np.arange(n_epochs) / 30) + 20
kl_loss = 5 * (1 - np.exp(-np.arange(n_epochs) / 20))
# Add noise
recon_loss += np.random.randn(n_epochs) * 2
kl_loss += np.random.randn(n_epochs) * 0.3
print("Visualizing ELBO components for different beta values...\n")
visualize_elbo_components(recon_loss, kl_loss)
3. The Reparameterization Trick
Problem
Need to compute gradient: $\(\nabla_\phi \mathbb{E}_{z \sim q_\phi(z|x)}[f(z)]\)$
Issue: Can't backpropagate through the sampling operation!
Solution: Reparameterization
Key idea: Express \(z\) as a deterministic, differentiable function of \(\phi\) and noise: \(z = g_\phi(\epsilon, x)\), where \(\epsilon \sim p(\epsilon)\).
For Gaussian: If \(q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \sigma_\phi^2(x)I)\), then: $\(z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon \quad \text{where} \quad \epsilon \sim \mathcal{N}(0, I)\)$
Now: $\(\nabla_\phi \mathbb{E}_{q_\phi(z|x)}[f(z)] = \nabla_\phi \mathbb{E}_{\epsilon \sim p(\epsilon)}[f(g_\phi(\epsilon, x))] = \mathbb{E}_{\epsilon}[\nabla_\phi f(g_\phi(\epsilon, x))]\)$
Gradient flows through! ✅
# Visualize reparameterization trick
def visualize_reparameterization():
"""Demonstrate reparameterization for 1D Gaussian"""
# Parameters
mu = 2.0
sigma = 1.5
# Sample using reparameterization
epsilon = np.random.randn(10000)
z = mu + sigma * epsilon
# Plot
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# Standard normal
axes[0].hist(epsilon, bins=50, density=True, alpha=0.7, edgecolor='black')
x_range = np.linspace(-4, 4, 100)
    axes[0].plot(x_range, stats.norm.pdf(x_range), 'r-', linewidth=2, label=r'$\mathcal{N}(0,1)$')
    axes[0].set_xlabel(r'$\epsilon$')
    axes[0].set_ylabel('Density')
    axes[0].set_title(r'1. Sample $\epsilon \sim \mathcal{N}(0, 1)$', fontsize=12, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Transformation
axes[1].text(0.5, 0.5, '$z = \\mu + \\sigma \\cdot \\epsilon$\n\n$z = 2.0 + 1.5 \\cdot \\epsilon$',
ha='center', va='center', fontsize=20,
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
transform=axes[1].transAxes)
axes[1].axis('off')
axes[1].set_title('2. Reparameterize', fontsize=12, fontweight='bold')
# Resulting distribution
axes[2].hist(z, bins=50, density=True, alpha=0.7, color='green', edgecolor='black')
z_range = np.linspace(-3, 7, 100)
    axes[2].plot(z_range, stats.norm.pdf(z_range, mu, sigma), 'r-', linewidth=2,
                 label=rf'$\mathcal{{N}}({mu}, {sigma}^2)$')
    axes[2].set_xlabel('$z$')
    axes[2].set_ylabel('Density')
    axes[2].set_title(rf'3. Result: $z \sim \mathcal{{N}}({mu}, {sigma}^2)$', fontsize=12, fontweight='bold')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("β
Reparameterization allows gradients to flow through sampling!")
print(f" Mean: {z.mean():.3f} β {mu}")
print(f" Std: {z.std():.3f} β {sigma}")
visualize_reparameterization()
3.5. Advanced: Reparameterization for Other Distributions
General Reparameterization Condition
A distribution \(q_\phi(z)\) admits reparameterization if we can write: $\(z = g_\phi(\epsilon) \quad \text{where} \quad \epsilon \sim p(\epsilon)\)$
and \(g_\phi\) is differentiable w.r.t. \(\phi\)
Common Reparameterizable Distributions
1. Gaussian (already seen): $\(z \sim \mathcal{N}(\mu, \sigma^2) \implies z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0,1)\)$
2. Log-Normal: $\(z \sim \text{LogNormal}(\mu, \sigma^2) \implies z = \exp(\mu + \sigma \cdot \epsilon), \quad \epsilon \sim \mathcal{N}(0,1)\)$
3. Exponential: $\(z \sim \text{Exp}(\lambda) \implies z = -\frac{1}{\lambda}\log(u), \quad u \sim \text{Uniform}(0,1)\)$
4. Gumbel (used in Gumbel-Softmax): $\(z \sim \text{Gumbel}(\mu, \beta) \implies z = \mu - \beta\log(-\log(u)), \quad u \sim \text{Uniform}(0,1)\)$
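These transforms can be verified empirically by checking sample means against the known closed-form means; the parameter values below are arbitrary illustrative choices:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
u = rng.uniform(size=n)
eps = rng.standard_normal(n)

# Log-Normal(mu, sigma^2): exponentiate a reparameterized Gaussian
mu, sigma = 0.5, 0.3
z_ln = np.exp(mu + sigma * eps)

# Exponential(lam): inverse-CDF transform of a uniform
lam = 2.0
z_exp = -np.log(u) / lam

# Gumbel(mu0, beta0): double-log transform of a uniform
mu0, beta0 = 0.0, 1.0
z_gum = mu0 - beta0 * np.log(-np.log(u))

# Compare sample means with the known closed-form means
print(z_ln.mean(), np.exp(mu + sigma ** 2 / 2))    # LogNormal: exp(mu + sigma^2/2)
print(z_exp.mean(), 1 / lam)                       # Exponential: 1/lam
print(z_gum.mean(), mu0 + beta0 * np.euler_gamma)  # Gumbel: mu + beta * Euler gamma
```

Each sampler is a differentiable function of its parameters, so gradients pass through exactly as in the Gaussian case.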
Distributions That Don't Reparameterize
Discrete distributions generally don't admit reparameterization:
Bernoulli: \(z \in \{0, 1\}\)
Categorical: \(z \in \{1, \ldots, K\}\)
Solutions for discrete:
Gumbel-Softmax: Continuous relaxation of categorical
REINFORCE: Score function estimator (high variance)
Straight-Through Estimator: Biased gradient
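A short PyTorch sketch of the Gumbel-Softmax relaxation using the built-in `torch.nn.functional.gumbel_softmax`; the logits, temperature, and toy loss are arbitrary illustrative choices:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([2.0, 0.5, -1.0], requires_grad=True)

# Soft sample: a point on the probability simplex, fully differentiable
y_soft = F.gumbel_softmax(logits, tau=0.5, hard=False)
# Straight-through: one-hot on the forward pass, soft gradient on the backward pass
y_hard = F.gumbel_softmax(logits, tau=0.5, hard=True)

# Gradients flow back to the logits despite the discrete-looking sample
loss = (y_hard * torch.arange(3.0)).sum()
loss.backward()

print(y_soft, y_hard, logits.grad)
```

Lower temperatures `tau` make the soft samples closer to one-hot but increase gradient variance.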
Pathwise Derivatives vs Score Function
Reparameterization (Pathwise): $\(\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{p(\epsilon)}[\nabla_\phi f(g_\phi(\epsilon))]\)$
Score Function (REINFORCE): $\(\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q_\phi(z)}[f(z) \nabla_\phi \log q_\phi(z)]\)$
Advantages of Reparameterization:
✅ Lower variance
✅ No baseline needed
✅ Works for continuous variables
Advantages of Score Function:
✅ Works for discrete variables
✅ More general applicability
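The variance gap can be seen on a one-dimensional example, estimating \(\nabla_\mu \mathbb{E}_{\mathcal{N}(\mu,1)}[z^2] = 2\mu\) both ways (a toy setup chosen for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, n = 1.0, 20_000

eps = rng.standard_normal(n)
z = mu + eps  # reparameterized samples from N(mu, 1)

# Pathwise: differentiate f(mu + eps) = (mu + eps)^2 directly  =>  2 * z
g_path = 2 * z
# Score function: f(z) * d/dmu log N(z; mu, 1) = z^2 * (z - mu)
g_score = z ** 2 * (z - mu)

# Both estimators are unbiased for the true gradient 2*mu = 2,
# but the score-function estimator has far higher variance
print(g_path.mean(), g_score.mean())
print(g_path.var(), g_score.var())
```

Both means hover around 2, while the score-function variance is several times larger; this is why VAEs use the pathwise estimator whenever the latent distribution allows it.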
Multi-Sample Monte Carlo Estimates
For ELBO maximization, use multiple samples to reduce variance:
Single sample: \(\mathcal{L} \approx \log p_\theta(x|z^{(1)}) - D_{KL}(q_\phi(z|x) || p(z))\)
Multiple samples: \(\mathcal{L} \approx \frac{1}{K} \sum_{k=1}^K \log p_\theta(x|z^{(k)}) - D_{KL}(q_\phi(z|x) || p(z))\)
where \(z^{(k)} = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon^{(k)}\), \(\epsilon^{(k)} \sim \mathcal{N}(0, I)\)
Trade-off: Computational cost vs. gradient variance
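A quick simulation (toy Gaussian decoder, illustrative numbers) shows the expected roughly 1/K variance reduction of the K-sample reconstruction estimate:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.0, 1.0
x = 1.0

def recon_estimate(k):
    """K-sample Monte Carlo estimate of E_q[log p(x|z)] (toy Gaussian decoder)."""
    eps = rng.standard_normal(k)
    z = mu + sigma * eps                  # reparameterized samples
    return np.mean(-0.5 * (x - z) ** 2)   # log p(x|z) up to constants

# Repeat each estimator many times and compare variances
est1 = np.array([recon_estimate(1) for _ in range(3000)])
est10 = np.array([recon_estimate(10) for _ in range(3000)])
print(est1.var(), est10.var())  # roughly a 10x variance reduction
```

In practice a single sample per datapoint usually suffices for training, because mini-batching already averages over many datapoints.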
4. VAE Implementation
Architecture
Encoder \(q_\phi(z|x)\): $\(q_\phi(z|x) = \mathcal{N}(z; \mu_\phi(x), \text{diag}(\sigma_\phi^2(x)))\)$
Neural network outputs \(\mu_\phi(x)\) and \(\log \sigma_\phi^2(x)\)
Decoder \(p_\theta(x|z)\): $\(p_\theta(x|z) = \mathcal{N}(x; \mu_\theta(z), \sigma^2 I) \quad \text{or} \quad \text{Bernoulli}(\mu_\theta(z))\)$
Neural network outputs parameters
Loss Function
For Gaussian decoder: $\(\log p(x|z) = -\frac{1}{2\sigma^2}||x - \mu_\theta(z)||^2 + \text{const}\)$
For Bernoulli decoder (binary data): $\(\log p(x|z) = \sum_j x_j \log \mu_{\theta,j}(z) + (1-x_j)\log(1-\mu_{\theta,j}(z))\)$
KL divergence (Gaussian β Gaussian): $\(D_{KL}(q(z|x) || p(z)) = \frac{1}{2}\sum_j \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right)\)$
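This closed form can be cross-checked against `torch.distributions.kl_divergence` on random parameters:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(5)
logvar = torch.randn(5)
std = torch.exp(0.5 * logvar)

# Closed form: 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1)
kl_closed = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
# Library version: KL(q(z|x) || N(0, I)), summed over dimensions
kl_lib = kl_divergence(Normal(mu, std), Normal(torch.zeros(5), torch.ones(5))).sum()
print(kl_closed.item(), kl_lib.item())  # should match
```

Having the KL in closed form means only the reconstruction term needs Monte Carlo sampling.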
# VAE implementation
class VAE(nn.Module):
"""Variational Autoencoder"""
def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
super().__init__()
self.latent_dim = latent_dim
# Encoder
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
# Decoder
self.fc3 = nn.Linear(latent_dim, hidden_dim)
self.fc4 = nn.Linear(hidden_dim, input_dim)
def encode(self, x):
"""Encoder: x -> mu, logvar"""
h = F.relu(self.fc1(x))
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar
def reparameterize(self, mu, logvar):
"""Reparameterization trick: z = mu + std * epsilon"""
std = torch.exp(0.5 * logvar)
epsilon = torch.randn_like(std)
z = mu + std * epsilon
return z
def decode(self, z):
"""Decoder: z -> x_reconstructed"""
h = F.relu(self.fc3(z))
x_recon = torch.sigmoid(self.fc4(h)) # Bernoulli
return x_recon
def forward(self, x):
"""Full forward pass"""
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
x_recon = self.decode(z)
return x_recon, mu, logvar
def loss_function(self, x_recon, x, mu, logvar):
"""Compute VAE loss: -ELBO"""
# Reconstruction loss (binary cross-entropy)
BCE = F.binary_cross_entropy(x_recon, x, reduction='sum')
# KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD, BCE, KLD
# Initialize model
vae = VAE(input_dim=784, hidden_dim=400, latent_dim=20).to(device)
print("VAE Architecture:")
print(vae)
print(f"\nLatent dimension: {vae.latent_dim}")
print(f"Total parameters: {sum(p.numel() for p in vae.parameters()):,}")
# Load MNIST data
transform = transforms.Compose([
transforms.ToTensor(),
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
# Visualize some samples
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for i, ax in enumerate(axes.flat):
img, label = train_dataset[i]
ax.imshow(img.squeeze(), cmap='gray')
ax.set_title(f'{label}')
ax.axis('off')
plt.tight_layout()
plt.show()
# Train VAE
def train_vae(model, train_loader, n_epochs=10, lr=1e-3):
"""Train VAE"""
optimizer = optim.Adam(model.parameters(), lr=lr)
history = {'loss': [], 'bce': [], 'kld': []}
model.train()
for epoch in range(n_epochs):
total_loss, total_bce, total_kld = 0, 0, 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.view(-1, 784).to(device)
optimizer.zero_grad()
# Forward pass
x_recon, mu, logvar = model(data)
# Compute loss
loss, bce, kld = model.loss_function(x_recon, data, mu, logvar)
# Backward pass
loss.backward()
optimizer.step()
total_loss += loss.item()
total_bce += bce.item()
total_kld += kld.item()
# Average over epoch
avg_loss = total_loss / len(train_loader.dataset)
avg_bce = total_bce / len(train_loader.dataset)
avg_kld = total_kld / len(train_loader.dataset)
history['loss'].append(avg_loss)
history['bce'].append(avg_bce)
history['kld'].append(avg_kld)
print(f"Epoch {epoch+1}/{n_epochs} | "
f"Loss: {avg_loss:.4f} | BCE: {avg_bce:.4f} | KLD: {avg_kld:.4f}")
return history
print("Training VAE...")
print("="*60)
history = train_vae(vae, train_loader, n_epochs=10, lr=1e-3)
# Visualize training
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
# Total loss
axes[0].plot(history['loss'], linewidth=2, label='Total Loss (-ELBO)')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('VAE Training Loss', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Components
axes[1].plot(history['bce'], linewidth=2, label='Reconstruction (BCE)', alpha=0.7)
axes[1].plot(history['kld'], linewidth=2, label='KL Divergence', alpha=0.7)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss Component')
axes[1].set_title('Loss Components', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("β
VAE training complete!")
# Visualize reconstructions and samples
vae.eval()
# Reconstructions
with torch.no_grad():
test_data = next(iter(test_loader))[0][:8].view(-1, 784).to(device)
recon, _, _ = vae(test_data)
fig, axes = plt.subplots(2, 8, figsize=(14, 4))
for i in range(8):
# Original
axes[0, i].imshow(test_data[i].cpu().view(28, 28), cmap='gray')
axes[0, i].axis('off')
if i == 0:
axes[0, i].set_title('Original', fontweight='bold')
# Reconstruction
axes[1, i].imshow(recon[i].cpu().view(28, 28), cmap='gray')
axes[1, i].axis('off')
if i == 0:
axes[1, i].set_title('Reconstructed', fontweight='bold')
plt.tight_layout()
plt.show()
# Generate new samples
with torch.no_grad():
z = torch.randn(16, vae.latent_dim).to(device)
samples = vae.decode(z)
fig, axes = plt.subplots(2, 8, figsize=(14, 4))
axes = axes.flatten()
for i, ax in enumerate(axes):
ax.imshow(samples[i].cpu().view(28, 28), cmap='gray')
ax.axis('off')
plt.suptitle('Generated Samples from VAE', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
print("β
VAE successfully generates realistic MNIST digits!")
5. β-VAE and Disentanglement
β-VAE Objective
Modify the ELBO with weight \(\beta\) on the KL term: $\(\mathcal{L}_\beta = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \beta \cdot D_{KL}(q(z|x) || p(z))\)$
Effect:
\(\beta > 1\): Stronger regularization, more disentangled representations
\(\beta < 1\): Weaker regularization, better reconstruction
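A minimal sketch of the corresponding loss, adapting the Bernoulli-decoder `loss_function` from Section 4 (the helper name `beta_vae_loss` and the test tensors are illustrative, not from the notebook):

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(x_recon, x, mu, logvar, beta=4.0):
    """Negative beta-ELBO: reconstruction + beta * KL."""
    bce = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + beta * kld

# Quick check on random tensors: beta=1 recovers the standard VAE loss,
# larger beta penalizes the (nonnegative) KL term more heavily
torch.manual_seed(0)
x = torch.rand(4, 10)
x_recon = torch.sigmoid(torch.randn(4, 10))
mu, logvar = torch.randn(4, 5), torch.randn(4, 5)
print(beta_vae_loss(x_recon, x, mu, logvar, beta=1.0),
      beta_vae_loss(x_recon, x, mu, logvar, beta=4.0))
```

Only the weighting changes relative to the standard VAE; the architecture and training loop are untouched.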
Disentanglement
Goal: Each latent dimension captures independent factor of variation
Example (faces):
\(z_1\): Pose
\(z_2\): Lighting
\(z_3\): Expression
Trade-off: Reconstruction quality vs. disentanglement
6. Summary
VAE Framework
✅ Latent variable model: \(p(x) = \int p(x|z)p(z)dz\)
✅ Variational inference: Approximate \(p(z|x)\) with \(q_\phi(z|x)\)
✅ ELBO: Lower bound on the log-likelihood
Key Components
Encoder: \(q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \sigma_\phi^2(x)I)\)
Decoder: \(p_\theta(x|z)\) (Gaussian or Bernoulli)
Reparameterization: \(z = \mu + \sigma \odot \epsilon\)
Loss Function
The training loss is the negative ELBO (Bernoulli decoder): $\(-\mathcal{L} = \underbrace{\text{BCE}(x, \hat{x})}_{\text{reconstruction}} + \underbrace{\frac{1}{2}\sum_j \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right)}_{\text{KL divergence}}\)$
Implementation
✅ Derived the ELBO rigorously
✅ Implemented the reparameterization trick
✅ Trained a VAE on MNIST
✅ Generated new samples
✅ Discussed β-VAE for disentanglement
Next Notebook: 04_neural_ode.ipynb