Variational Autoencoders (VAEs)

Learning Objectives:

  • Understand variational inference and ELBO

  • Implement reparameterization trick

  • Train VAE for image generation

  • Explore latent space structure

Prerequisites: Deep learning, probability theory, variational inference

Time: 90 minutes


1. Latent Variable Models

The Setup

Generative story:

  1. Sample latent code: \(z \sim p(z)\) (prior)

  2. Generate data: \(x \sim p(x|z)\) (likelihood)

Goal: Learn \(p(x) = \int p(x|z)p(z)dz\) (marginal likelihood)

Problem: Integral is intractable for complex \(p(x|z)\)!
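Since the marginal is an expectation under the prior, a naive Monte Carlo estimate is possible in principle. The sketch below (a toy 1D linear-Gaussian model of our own choosing, where \(p(x)\) happens to be tractable so the estimate can be checked) shows the idea; for a deep, high-dimensional \(p(x|z)\) the variance of this estimator blows up, which is the intractability in practice.

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_pdf(x, mean, var):
    return np.exp(-(x - mean) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

# Toy linear-Gaussian model: z ~ N(0, 1), x | z ~ N(z, 0.25).
# Here the marginal is tractable: p(x) = N(x; 0, 1 + 0.25).
x_obs = 1.0
exact = gaussian_pdf(x_obs, 0.0, 1.25)

# Naive Monte Carlo over the prior: p(x) ~ (1/N) sum_i p(x | z_i), z_i ~ p(z)
z = rng.standard_normal(100_000)
mc = gaussian_pdf(x_obs, z, 0.25).mean()

print(f"exact p(x) = {exact:.4f}, Monte Carlo estimate = {mc:.4f}")
```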

Variational Approach

Idea: Introduce variational posterior \(q(z|x) \approx p(z|x)\)

Evidence Lower Bound (ELBO): \[\log p(x) \geq \mathbb{E}_{z \sim q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) || p(z))\]

Maximizing the ELBO approximately maximizes the log-likelihood.

VAE Architecture

  • Encoder: \(q_\phi(z|x)\) (approximate posterior)

  • Decoder: \(p_\theta(x|z)\) (likelihood)

  • Prior: \(p(z) = \mathcal{N}(0, I)\) (standard Gaussian)

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats  # used below to overlay Gaussian densities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal, kl_divergence
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)
np.random.seed(42)
torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

2. ELBO Derivation

Variational Bound

Starting from the log-likelihood: \[\log p(x) = \log \int p(x, z) dz = \log \int q(z|x) \frac{p(x, z)}{q(z|x)} dz\]

By Jensen's inequality: \[\log p(x) \geq \int q(z|x) \log \frac{p(x, z)}{q(z|x)} dz = \mathcal{L}(x; \theta, \phi)\]

This is the Evidence Lower Bound (ELBO).

ELBO Decomposition

\[\mathcal{L} = \mathbb{E}_{q(z|x)}[\log p(x, z) - \log q(z|x)]\]

Expand \(p(x, z) = p(x|z)p(z)\):

\[\mathcal{L} = \mathbb{E}_{q(z|x)}[\log p(x|z)] + \mathbb{E}_{q(z|x)}[\log p(z) - \log q(z|x)]\]
\[\boxed{\mathcal{L} = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{D_{KL}(q(z|x) || p(z))}_{\text{KL Regularization}}}\]

Interpretation

  1. Reconstruction term: How well decoder reconstructs \(x\) from \(z\)

  2. KL term: How close \(q(z|x)\) is to prior \(p(z)\)

Trade-off: Good reconstruction vs. simple latent distribution

2.5. Advanced ELBO Analysis

Tightness of the Bound

The gap between \(\log p(x)\) and the ELBO is: \[\log p(x) - \mathcal{L} = D_{KL}(q(z|x) || p(z|x))\]

Proof: \[\log p(x) = \mathbb{E}_{q(z|x)}\left[\log p(x)\right] = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x, z)}{p(z|x)}\right] = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x, z)}{q(z|x)}\right] + \mathbb{E}_{q(z|x)}\left[\log \frac{q(z|x)}{p(z|x)}\right] = \mathcal{L} + D_{KL}(q(z|x) || p(z|x))\]

Implication: ELBO is tight when \(q(z|x) = p(z|x)\) (exact posterior)!
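The identity can be sanity-checked numerically in a conjugate toy model where the exact posterior is available in closed form (a 1D linear-Gaussian model chosen for illustration; the grid of \(q\) parameters is arbitrary):

```python
import numpy as np

# Conjugate toy model: z ~ N(0, 1), x | z ~ N(z, v) with v = 0.25,
# so the exact posterior p(z|x) = N(mu_p, s2_p) is known in closed form.
v, x = 0.25, 1.0
s2_p = 1.0 / (1.0 + 1.0 / v)   # posterior variance = 0.2
mu_p = (x / v) * s2_p          # posterior mean = 0.8 * x

log_px = -0.5 * np.log(2 * np.pi * (1 + v)) - x**2 / (2 * (1 + v))

def elbo(m, s2):
    # E_q[log p(x|z)] for q = N(m, s2), with Gaussian likelihood N(x; z, v)
    recon = -0.5 * np.log(2 * np.pi * v) - ((x - m) ** 2 + s2) / (2 * v)
    # KL(N(m, s2) || N(0, 1)) in closed form
    kl = 0.5 * (m**2 + s2 - np.log(s2) - 1)
    return recon - kl

def kl_to_posterior(m, s2):
    return 0.5 * (np.log(s2_p / s2) + (s2 + (m - mu_p) ** 2) / s2_p - 1)

for m, s2 in [(0.0, 1.0), (0.5, 0.3), (mu_p, s2_p)]:
    gap = log_px - elbo(m, s2)
    print(f"q=N({m:.2f},{s2:.2f}): gap={gap:.6f}, KL(q||post)={kl_to_posterior(m, s2):.6f}")
```

At the exact posterior the gap vanishes and the ELBO equals \(\log p(x)\).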

Alternative Formulation: Importance Weighted Bound

The standard ELBO uses a single sample. We can improve tightness with Importance Weighted Autoencoder (IWAE):

\[\log p(x) \geq \mathbb{E}_{z_1, \ldots, z_k \sim q(z|x)}\left[\log \frac{1}{k} \sum_{i=1}^k \frac{p(x, z_i)}{q(z_i|x)}\right] = \mathcal{L}_k\]

Properties:

  • \(\mathcal{L}_1\) is standard ELBO

  • \(\mathcal{L}_k \geq \mathcal{L}_1\) (tighter bound)

  • As \(k \to \infty\), \(\mathcal{L}_k \to \log p(x)\)

Beta-VAE: Disentangling Latent Factors

Modify the ELBO with hyperparameter \(\beta\): \[\mathcal{L}_\beta = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \beta \cdot D_{KL}(q(z|x) || p(z))\]

Effect of \(\beta\):

  • \(\beta < 1\): Prioritize reconstruction (sharper images)

  • \(\beta = 1\): Standard VAE

  • \(\beta > 1\): Stronger regularization (more disentangled latent space)

Intuition: Higher \(\beta\) forces the encoder to use latent dimensions more independently

Information Theoretic View

Averaging the KL term over the data distribution \(p(x)\) gives an information-theoretic decomposition: \[\mathbb{E}_{p(x)}\left[D_{KL}(q(z|x) || p(z))\right] = \underbrace{I_q(x; z)}_{\text{Mutual Information}} + \underbrace{D_{KL}(q(z) || p(z))}_{\text{Marginal KL}}\]

where:

  • \(I_q(x; z)\) measures how much information \(z\) captures about \(x\)

  • \(D_{KL}(q(z) || p(z))\) encourages the marginal \(q(z) = \int q(z|x)p(x)dx\) to match the prior

Trade-off: Encode information vs. maintain simple prior

This perspective explains:

  • Rate-distortion trade-off in compression

  • Information bottleneck in representation learning

# Visualize ELBO decomposition and trade-offs

def visualize_elbo_components(recon_losses, kl_losses, betas=[0.5, 1.0, 2.0, 4.0]):
    """
    Visualize how different beta values affect the ELBO components
    
    Demonstrates:
    - Reconstruction vs KL trade-off
    - Effect of beta on latent space regularization
    - Optimal beta selection
    """
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Plot 1: ELBO components for different betas
    ax = axes[0, 0]
    epochs = np.arange(len(recon_losses))
    
    for beta in betas:
        elbo = -(recon_losses + beta * kl_losses)  # Negative for plotting
        ax.plot(epochs, elbo, label=f'β={beta}', linewidth=2)
    
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('ELBO (higher is better)', fontsize=11)
    ax.set_title('ELBO vs Beta', fontweight='bold', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Reconstruction loss
    ax = axes[0, 1]
    ax.plot(epochs, recon_losses, color='blue', linewidth=2)
    ax.fill_between(epochs, 0, recon_losses, alpha=0.3, color='blue')
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('Reconstruction Loss', fontsize=11)
    ax.set_title(r'Reconstruction Term: $E_q[\log p(x|z)]$', fontweight='bold', fontsize=12)
    ax.grid(True, alpha=0.3)
    
    # Plot 3: KL divergence
    ax = axes[1, 0]
    ax.plot(epochs, kl_losses, color='red', linewidth=2)
    ax.fill_between(epochs, 0, kl_losses, alpha=0.3, color='red')
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('KL Divergence', fontsize=11)
    ax.set_title('KL Term: $D_{KL}(q(z|x) || p(z))$', fontweight='bold', fontsize=12)
    ax.grid(True, alpha=0.3)
    
    # Plot 4: Trade-off visualization
    ax = axes[1, 1]
    for beta in betas:
        total = recon_losses + beta * kl_losses
        ax.plot(epochs, total, label=f'β={beta}', linewidth=2)
    
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('Total Loss', fontsize=11)
    ax.set_title('Total Loss: Recon + β·KL', fontweight='bold', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\n" + "="*70)
    print("BETA-VAE: UNDERSTANDING THE TRADE-OFF")
    print("="*70)
    print("\nReconstruction vs Regularization:")
    print("  • Reconstruction: How accurately decoder reconstructs input")
    print("  • KL Divergence: How close posterior is to prior N(0,I)")
    print("\nEffect of Beta:")
    print("  • β < 1: Prioritize reconstruction → sharper images")
    print("  • β = 1: Standard VAE (balanced)")
    print("  • β > 1: Prioritize regularization → disentangled representations")
    print("\nChoosing Beta:")
    print("  • Image generation: β ∈ [0.5, 1.0]")
    print("  • Disentanglement: β ∈ [2.0, 10.0]")
    print("  • Representation learning: β ∈ [1.0, 4.0]")
    print("="*70)

# Example: Generate synthetic data for demonstration
np.random.seed(42)
n_epochs = 100

# Simulate training dynamics
recon_loss = 150 * np.exp(-np.arange(n_epochs) / 30) + 20
kl_loss = 5 * (1 - np.exp(-np.arange(n_epochs) / 20))

# Add noise
recon_loss += np.random.randn(n_epochs) * 2
kl_loss += np.random.randn(n_epochs) * 0.3

print("Visualizing ELBO components for different beta values...\n")
visualize_elbo_components(recon_loss, kl_loss)

3. The Reparameterization Trick

Problem

Need to compute the gradient: \[\nabla_\phi \mathbb{E}_{z \sim q_\phi(z|x)}[f(z)]\]

Issue: Can't backpropagate through the sampling operation!

Solution: Reparameterization

Key idea: Express \(z\) as deterministic function of \(\phi\) and noise \(\epsilon\):

\[z = g_\phi(\epsilon, x) \quad \text{where} \quad \epsilon \sim p(\epsilon)\]

For Gaussian: If \(q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \sigma_\phi^2(x)I)\), then: \[z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon \quad \text{where} \quad \epsilon \sim \mathcal{N}(0, I)\]

Now: \[\nabla_\phi \mathbb{E}_{q_\phi(z|x)}[f(z)] = \nabla_\phi \mathbb{E}_{\epsilon \sim p(\epsilon)}[f(g_\phi(\epsilon, x))] = \mathbb{E}_{\epsilon}[\nabla_\phi f(g_\phi(\epsilon, x))]\]

The gradient flows through! ✅
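A minimal PyTorch check of this (standalone, with an arbitrary test function \(f(z) = z^2\), for which the true gradients are \(2\mu\) and \(2\sigma\)):

```python
import torch

torch.manual_seed(0)
mu = torch.tensor(1.5, requires_grad=True)
sigma = torch.tensor(0.5, requires_grad=True)

# E[z^2] = mu^2 + sigma^2, so d/dmu = 2*mu and d/dsigma = 2*sigma.
eps = torch.randn(100_000)
z = mu + sigma * eps     # reparameterized sample: differentiable in mu, sigma
loss = (z ** 2).mean()
loss.backward()          # gradients flow through the sampling step

print(f"d/dmu    = {mu.grad:.3f}  (analytic: {2 * mu.item():.3f})")
print(f"d/dsigma = {sigma.grad:.3f}  (analytic: {2 * sigma.item():.3f})")
```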

# Visualize reparameterization trick

def visualize_reparameterization():
    """Demonstrate reparameterization for 1D Gaussian"""
    
    # Parameters
    mu = 2.0
    sigma = 1.5
    
    # Sample using reparameterization
    epsilon = np.random.randn(10000)
    z = mu + sigma * epsilon
    
    # Plot
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Standard normal
    axes[0].hist(epsilon, bins=50, density=True, alpha=0.7, edgecolor='black')
    x_range = np.linspace(-4, 4, 100)
    axes[0].plot(x_range, stats.norm.pdf(x_range), 'r-', linewidth=2, label=r'$\mathcal{N}(0,1)$')
    axes[0].set_xlabel(r'$\epsilon$')
    axes[0].set_ylabel('Density')
    axes[0].set_title(r'1. Sample $\epsilon \sim \mathcal{N}(0, 1)$', fontsize=12, fontweight='bold')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Transformation
    axes[1].text(0.5, 0.5, '$z = \\mu + \\sigma \\cdot \\epsilon$\n\n$z = 2.0 + 1.5 \\cdot \\epsilon$',
                ha='center', va='center', fontsize=20, 
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
                transform=axes[1].transAxes)
    axes[1].axis('off')
    axes[1].set_title('2. Reparameterize', fontsize=12, fontweight='bold')
    
    # Resulting distribution
    axes[2].hist(z, bins=50, density=True, alpha=0.7, color='green', edgecolor='black')
    z_range = np.linspace(-3, 7, 100)
    axes[2].plot(z_range, stats.norm.pdf(z_range, mu, sigma), 'r-', linewidth=2,
                label=rf'$\mathcal{{N}}({mu}, {sigma}^2)$')
    axes[2].set_xlabel('$z$')
    axes[2].set_ylabel('Density')
    axes[2].set_title(rf'3. Result: $z \sim \mathcal{{N}}({mu}, {sigma}^2)$', fontsize=12, fontweight='bold')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("✅ Reparameterization allows gradients to flow through sampling!")
    print(f"   Mean: {z.mean():.3f} ≈ {mu}")
    print(f"   Std:  {z.std():.3f} ≈ {sigma}")

visualize_reparameterization()

3.5. Advanced: Reparameterization for Other Distributions

General Reparameterization Condition

A distribution \(q_\phi(z)\) admits reparameterization if we can write: \[z = g_\phi(\epsilon) \quad \text{where} \quad \epsilon \sim p(\epsilon)\]

and \(g_\phi\) is differentiable w.r.t. \(\phi\).

Common Reparameterizable Distributions

1. Gaussian (already seen): \[z \sim \mathcal{N}(\mu, \sigma^2) \implies z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0,1)\]

2. Log-Normal: \[z \sim \text{LogNormal}(\mu, \sigma^2) \implies z = \exp(\mu + \sigma \cdot \epsilon), \quad \epsilon \sim \mathcal{N}(0,1)\]

3. Exponential: \[z \sim \text{Exp}(\lambda) \implies z = -\frac{1}{\lambda}\log(u), \quad u \sim \text{Uniform}(0,1)\]

4. Gumbel (used in Gumbel-Softmax): \[z \sim \text{Gumbel}(\mu, \beta) \implies z = \mu - \beta\log(-\log(u)), \quad u \sim \text{Uniform}(0,1)\]
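The inverse-CDF constructions above can be sketched and checked against known moments (sample sizes and parameter values here are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(2)
u = rng.uniform(size=200_000)

# Exponential(lam) via inverse CDF: z = -log(u)/lam, so E[z] = 1/lam
lam = 2.0
z_exp = -np.log(u) / lam

# Gumbel(mu, beta) via inverse CDF: z = mu - beta*log(-log(u)),
# with E[z] = mu + beta*gamma (gamma ~ 0.5772, Euler-Mascheroni constant)
mu_g, beta_g = 0.0, 1.0
z_gum = mu_g - beta_g * np.log(-np.log(u))

print(f"Exponential mean: {z_exp.mean():.3f}  (expected {1 / lam:.3f})")
print(f"Gumbel mean:      {z_gum.mean():.3f}  (expected 0.577)")
```

Both transforms are differentiable in their parameters, so gradients pass through exactly as in the Gaussian case.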

Distributions That Don't Reparameterize

Discrete distributions generally don't admit reparameterization:

  • Bernoulli: \(z \in \{0, 1\}\)

  • Categorical: \(z \in \{1, \ldots, K\}\)

Solutions for discrete:

  1. Gumbel-Softmax: Continuous relaxation of categorical

  2. REINFORCE: Score function estimator (high variance)

  3. Straight-Through Estimator: Biased gradient
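A minimal NumPy sketch of the Gumbel-Softmax forward pass (relaxation only, no straight-through backward; the temperature and class probabilities are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(3)

def gumbel_softmax(logits, tau, rng):
    """Continuous relaxation of a categorical sample (forward pass only)."""
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))   # Gumbel(0,1) noise
    y = (logits + g) / tau
    y = np.exp(y - y.max(axis=-1, keepdims=True))          # numerically stable softmax
    return y / y.sum(axis=-1, keepdims=True)

probs = np.array([0.2, 0.3, 0.5])
samples = gumbel_softmax(np.tile(np.log(probs), (50_000, 1)), tau=0.1, rng=rng)

# At low temperature the relaxed samples are nearly one-hot, and their
# argmax frequencies match the categorical probabilities (Gumbel-max trick).
freqs = np.bincount(samples.argmax(axis=1), minlength=3) / len(samples)
print("argmax frequencies:", np.round(freqs, 3))
```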

Pathwise Derivatives vs Score Function

Reparameterization (Pathwise): \[\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{p(\epsilon)}[\nabla_\phi f(g_\phi(\epsilon))]\]

Score Function (REINFORCE): \[\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q_\phi(z)}[f(z) \nabla_\phi \log q_\phi(z)]\]

Advantages of Reparameterization:

  • βœ… Lower variance

  • βœ… No baseline needed

  • βœ… Works for continuous variables

Advantages of Score Function:

  • βœ… Works for discrete variables

  • βœ… More general applicability

Multi-Sample Monte Carlo Estimates

For ELBO maximization, use multiple samples to reduce variance:

Single sample: \(\mathcal{L} \approx \log p_\theta(x|z^{(1)}) - D_{KL}(q_\phi(z|x) || p(z))\)

Multiple samples: \(\mathcal{L} \approx \frac{1}{K} \sum_{k=1}^K \log p_\theta(x|z^{(k)}) - D_{KL}(q_\phi(z|x) || p(z))\)

where \(z^{(k)} = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon^{(k)}\), \(\epsilon^{(k)} \sim \mathcal{N}(0, I)\)

Trade-off: Computational cost vs. gradient variance

4. VAE Implementation

Architecture

Encoder \(q_\phi(z|x)\): \[q_\phi(z|x) = \mathcal{N}(z; \mu_\phi(x), \text{diag}(\sigma_\phi^2(x)))\]

Neural network outputs \(\mu_\phi(x)\) and \(\log \sigma_\phi^2(x)\)

Decoder \(p_\theta(x|z)\): \[p_\theta(x|z) = \mathcal{N}(x; \mu_\theta(z), \sigma^2 I) \quad \text{or} \quad \text{Bernoulli}(\mu_\theta(z))\]

Neural network outputs parameters

Loss Function

\[\mathcal{L} = \frac{1}{N} \sum_{i=1}^N \left[ \mathbb{E}_{q(z|x_i)}[\log p(x_i|z)] - D_{KL}(q(z|x_i) || p(z)) \right]\]

For Gaussian decoder: \[\log p(x|z) = -\frac{1}{2\sigma^2}||x - \mu_\theta(z)||^2 + \text{const}\]

For Bernoulli decoder (binary data): \[\log p(x|z) = \sum_j x_j \log \mu_{\theta,j}(z) + (1-x_j)\log(1-\mu_{\theta,j}(z))\]
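For the Bernoulli decoder, PyTorch's `F.binary_cross_entropy(..., reduction='sum')` is exactly the negative of this log-likelihood; a quick standalone check (random tensors with illustrative MNIST-like shapes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randint(0, 2, (4, 784)).float()   # binary "pixels"
mu = torch.sigmoid(torch.randn(4, 784))     # decoder means in (0, 1)

# Bernoulli log-likelihood, summed over pixels and batch
log_lik = (x * torch.log(mu) + (1 - x) * torch.log(1 - mu)).sum()

# F.binary_cross_entropy(reduction='sum') is its negation
bce = F.binary_cross_entropy(mu, x, reduction='sum')
print(f"-log_lik = {-log_lik:.2f}, BCE = {bce:.2f}")
```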

KL divergence (Gaussian → Gaussian): \[D_{KL}(q(z|x) || p(z)) = \frac{1}{2}\sum_j \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right)\]
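This closed form can be verified against `torch.distributions` (a standalone sanity check with arbitrary random parameters):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 20)
logvar = torch.randn(4, 20)

# Closed form used in the VAE loss below
kl_closed = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)

# The same quantity via torch.distributions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_lib = kl_divergence(q, p).sum()

print(f"closed form: {kl_closed:.4f}   torch.distributions: {kl_lib:.4f}")
```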

# VAE implementation

class VAE(nn.Module):
    """Variational Autoencoder"""
    
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        
        self.latent_dim = latent_dim
        
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
    
    def encode(self, x):
        """Encoder: x -> mu, logvar"""
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """Reparameterization trick: z = mu + std * epsilon"""
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        z = mu + std * epsilon
        return z
    
    def decode(self, z):
        """Decoder: z -> x_reconstructed"""
        h = F.relu(self.fc3(z))
        x_recon = torch.sigmoid(self.fc4(h))  # Bernoulli
        return x_recon
    
    def forward(self, x):
        """Full forward pass"""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar
    
    def loss_function(self, x_recon, x, mu, logvar):
        """Compute VAE loss: -ELBO"""
        # Reconstruction loss (binary cross-entropy)
        BCE = F.binary_cross_entropy(x_recon, x, reduction='sum')
        
        # KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        
        return BCE + KLD, BCE, KLD

# Initialize model
vae = VAE(input_dim=784, hidden_dim=400, latent_dim=20).to(device)

print("VAE Architecture:")
print(vae)
print(f"\nLatent dimension: {vae.latent_dim}")
print(f"Total parameters: {sum(p.numel() for p in vae.parameters()):,}")
# Load MNIST data

transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Visualize some samples
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for i, ax in enumerate(axes.flat):
    img, label = train_dataset[i]
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'{label}')
    ax.axis('off')
plt.tight_layout()
plt.show()
# Train VAE

def train_vae(model, train_loader, n_epochs=10, lr=1e-3):
    """Train VAE"""
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    history = {'loss': [], 'bce': [], 'kld': []}
    
    model.train()
    for epoch in range(n_epochs):
        total_loss, total_bce, total_kld = 0, 0, 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.view(-1, 784).to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            x_recon, mu, logvar = model(data)
            
            # Compute loss
            loss, bce, kld = model.loss_function(x_recon, data, mu, logvar)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_bce += bce.item()
            total_kld += kld.item()
        
        # Average over epoch
        avg_loss = total_loss / len(train_loader.dataset)
        avg_bce = total_bce / len(train_loader.dataset)
        avg_kld = total_kld / len(train_loader.dataset)
        
        history['loss'].append(avg_loss)
        history['bce'].append(avg_bce)
        history['kld'].append(avg_kld)
        
        print(f"Epoch {epoch+1}/{n_epochs} | "
              f"Loss: {avg_loss:.4f} | BCE: {avg_bce:.4f} | KLD: {avg_kld:.4f}")
    
    return history

print("Training VAE...")
print("="*60)
history = train_vae(vae, train_loader, n_epochs=10, lr=1e-3)
# Visualize training

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Total loss
axes[0].plot(history['loss'], linewidth=2, label='Total Loss (-ELBO)')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('VAE Training Loss', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Components
axes[1].plot(history['bce'], linewidth=2, label='Reconstruction (BCE)', alpha=0.7)
axes[1].plot(history['kld'], linewidth=2, label='KL Divergence', alpha=0.7)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss Component')
axes[1].set_title('Loss Components', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✅ VAE training complete!")
# Visualize reconstructions and samples

vae.eval()

# Reconstructions
with torch.no_grad():
    test_data = next(iter(test_loader))[0][:8].view(-1, 784).to(device)
    recon, _, _ = vae(test_data)

fig, axes = plt.subplots(2, 8, figsize=(14, 4))

for i in range(8):
    # Original
    axes[0, i].imshow(test_data[i].cpu().view(28, 28), cmap='gray')
    axes[0, i].axis('off')
    if i == 0:
        axes[0, i].set_title('Original', fontweight='bold')
    
    # Reconstruction
    axes[1, i].imshow(recon[i].cpu().view(28, 28), cmap='gray')
    axes[1, i].axis('off')
    if i == 0:
        axes[1, i].set_title('Reconstructed', fontweight='bold')

plt.tight_layout()
plt.show()

# Generate new samples
with torch.no_grad():
    z = torch.randn(16, vae.latent_dim).to(device)
    samples = vae.decode(z)

fig, axes = plt.subplots(2, 8, figsize=(14, 4))
axes = axes.flatten()

for i, ax in enumerate(axes):
    ax.imshow(samples[i].cpu().view(28, 28), cmap='gray')
    ax.axis('off')

plt.suptitle('Generated Samples from VAE', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ VAE successfully generates realistic MNIST digits!")

5. β-VAE and Disentanglement

β-VAE Objective

Modify ELBO with weight \(\beta\) on KL term:

\[\mathcal{L}_{\beta} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \beta \cdot D_{KL}(q(z|x) || p(z))\]

Effect:

  • \(\beta > 1\): Stronger regularization, more disentangled representations

  • \(\beta < 1\): Weaker regularization, better reconstruction

Disentanglement

Goal: Each latent dimension captures independent factor of variation

Example (faces):

  • \(z_1\): Pose

  • \(z_2\): Lighting

  • \(z_3\): Expression

Trade-off: Reconstruction quality vs. disentanglement
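A latent traversal makes this concrete: fix a code and sweep one dimension. The helper below is a sketch; `decode` can be any callable mapping latent codes to images (with the trained model above, a wrapper around `vae.decode` would do), and the linear stand-in decoder here is purely hypothetical, used only to keep the example self-contained:

```python
import numpy as np

def latent_traversal(decode, z_base, dim, values):
    """Decode copies of z_base with latent dimension `dim` swept over `values`."""
    z = np.repeat(z_base[None, :], len(values), axis=0)
    z[:, dim] = values
    return decode(z)

# Stand-in linear "decoder" (hypothetical), just to demonstrate the mechanics
rng = np.random.default_rng(5)
W = rng.standard_normal((20, 784))
decode = lambda z: z @ W

imgs = latent_traversal(decode, np.zeros(20), dim=3, values=np.linspace(-3, 3, 7))
print(imgs.shape)   # one decoded image per traversal value
```

With a disentangled β-VAE, sweeping one dimension should change a single factor of variation (e.g. pose) while the others stay fixed.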

6. Summary

VAE Framework

✅ Latent variable model: \(p(x) = \int p(x|z)p(z)dz\)
✅ Variational inference: Approximate \(p(z|x)\) with \(q_\phi(z|x)\)
✅ ELBO: Lower bound on log-likelihood

Key Components

  1. Encoder: \(q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \sigma_\phi^2(x)I)\)

  2. Decoder: \(p_\theta(x|z)\) (Gaussian or Bernoulli)

  3. Reparameterization: \(z = \mu + \sigma \odot \epsilon\)

Loss Function

\[-\text{ELBO} = \underbrace{\text{Reconstruction Loss}}_{\text{BCE or MSE}} + \underbrace{\text{KL Divergence}}_{\text{Regularization}}\]

Implementation

✅ Derived ELBO rigorously
✅ Implemented reparameterization trick
✅ Trained VAE on MNIST
✅ Generated new samples
✅ Discussed β-VAE for disentanglement

Next Notebook: 04_neural_ode.ipynb