import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Select the GPU when available; all tensors and models below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Apply seaborn's whitegrid theme to every matplotlib figure in this notebook.
sns.set_style('whitegrid')
1. Motivation: Iterative Refinement
Traditional Generative ModelsΒΆ
VAE: Single-step generation, blurry
GAN: Unstable training, mode collapse
Flow: Architectural constraints
Diffusion ModelsΒΆ
Key idea: Gradually denoise pure noise into data!
Advantages:
High-quality samples
Stable training (no adversarial)
Tractable likelihood
Disadvantage: Slow sampling (many steps)
Two ProcessesΒΆ
Forward (diffusion): Add noise progressively $\(q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t I)\)$
Reverse (denoising): Learn to remove noise $\(p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)$
π Reference Materials:
generative_models.pdf - Generative Models
2. Forward Diffusion ProcessΒΆ
Markov ChainΒΆ
Starting from data \(x_0 \sim q(x_0)\): $\(q(x_{1:T} | x_0) = \prod_{t=1}^T q(x_t | x_{t-1})\)$
where: $\(q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t I)\)$
Variance ScheduleΒΆ
Typical: \(\beta_t \in [10^{-4}, 0.02]\), \(T = 1000\)
Closed-Form SamplingΒΆ
Define \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t = \prod_{s=1}^t \alpha_s\):
Reparameterization: $\(x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\)$
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced variance schedule beta_1 .. beta_T (Ho et al., 2020).

    Args:
        timesteps: Number of diffusion steps T.
        beta_start: Variance at the first step.
        beta_end: Variance at the last step.

    Returns:
        1-D tensor of length ``timesteps`` with evenly spaced betas.
    """
    return torch.linspace(beta_start, beta_end, timesteps)
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine variance schedule (Nichol & Dhariwal, 2021).

    Derives the betas from a cosine-shaped cumulative alpha product, which
    destroys information more gradually than the linear schedule.

    Args:
        timesteps: Number of diffusion steps T.
        s: Small offset preventing beta_t from being too small near t=0.

    Returns:
        1-D tensor of length ``timesteps`` with betas clipped to (1e-4, 0.9999).
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    # Target cumulative product alpha-bar_t following a squared-cosine decay.
    alpha_bar = torch.cos((grid / timesteps + s) / (1 + s) * torch.pi * 0.5) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    # Recover per-step betas from consecutive ratios of alpha-bar.
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(betas, 0.0001, 0.9999)
# Compute schedule quantities used throughout the notebook.
T = 1000
betas = linear_beta_schedule(T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# alpha-bar_{t-1}, with alpha-bar_0 defined as 1 (left-pad).
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# Visualize the schedule: per-step variance, cumulative product, noise scale.
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
panels = [
    (betas, 'Ξ²β', 'Variance Schedule'),
    (alphas_cumprod, 'αΎ±β = βΞ±β', 'Cumulative Product'),
    (torch.sqrt(1 - alphas_cumprod), 'β(1-αΎ±β) (noise scale)', 'Noise Level'),
]
for ax, (curve, ylabel, title) in zip(axes, panels):
    ax.plot(curve.numpy(), linewidth=2)
    ax.set_xlabel('Timestep t', fontsize=11)
    ax.set_ylabel(ylabel, fontsize=11)
    ax.set_title(title, fontsize=12)
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Visualize forward process
def forward_diffusion_sample(x0, t, alphas_cumprod):
    """Draw x_t ~ q(x_t | x_0) in closed form (reparameterization trick).

    Args:
        x0: Clean images, shape (B, C, H, W).
        t: Integer timesteps, shape (B,), indexing the schedule per sample.
        alphas_cumprod: 1-D tensor of cumulative products alpha-bar_t.

    Returns:
        Tuple ``(x_t, noise)`` — the corrupted sample and the epsilon that
        was added, so it can serve as the regression target in training.
    """
    eps = torch.randn_like(x0)
    # Per-sample coefficients, broadcast over channel/spatial dimensions.
    alpha_bar = alphas_cumprod[t][:, None, None, None]
    x_t = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * eps
    return x_t, eps
# Load one MNIST digit to visualize the forward corruption process.
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
x0 = mnist[0][0].unsqueeze(0).to(device)
# Show the same image at increasing noise levels.
timesteps_to_show = [0, 50, 200, 500, 999]
fig, axes = plt.subplots(1, len(timesteps_to_show), figsize=(15, 3))
for ax, t in zip(axes, timesteps_to_show):
    if t == 0:
        # t = 0 is the clean image; no sampling needed.
        img = x0[0, 0].cpu()
    else:
        t_tensor = torch.tensor([t]).to(device)
        noisy, _ = forward_diffusion_sample(x0, t_tensor, alphas_cumprod.to(device))
        img = noisy[0, 0].cpu()
    ax.imshow(img, cmap='gray')
    ax.set_title(f't = {t}', fontsize=11)
    ax.axis('off')
plt.suptitle('Forward Diffusion Process', fontsize=13, y=1.02)
plt.tight_layout()
plt.show()
3. Reverse Process & TrainingΒΆ
GoalΒΆ
Learn \(p_\theta(x_{t-1} | x_t)\) to reverse diffusion.
ParameterizationΒΆ
Predict the noise \(\epsilon\) added at step \(t\): $\(\epsilon_\theta(x_t, t) \approx \epsilon\)$
Then: $\(x_0 \approx \frac{x_t - \sqrt{1-\bar{\alpha}_t}\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}\)$
Loss FunctionΒΆ
Simplified objective (Ho et al., 2020): $\(L_{simple} = \mathbb{E}_{t, x_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]\)$
where \(x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon\).
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal timestep embeddings (Transformer-style positional encoding).

    Maps a batch of scalar timesteps to vectors of width ``dim`` whose first
    half holds sines and second half cosines at geometrically spaced
    frequencies from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # output embedding width; split evenly into sin/cos halves

    def forward(self, time):
        """time: (B,) tensor of timesteps -> (B, dim) embedding."""
        half = self.dim // 2
        # Log-spaced frequency scale shared by both halves.
        scale = np.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=time.device))
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class SimpleUNet(nn.Module):
    """Simplified U-Net-style denoiser for 28x28 single-channel images.

    Predicts the noise eps added at timestep ``t``. The timestep is embedded
    sinusoidally and injected at the bottleneck via a learned projection.

    Note: unlike a full U-Net there are no encoder-decoder skip connections.
    """

    def __init__(self, time_dim=32):
        super().__init__()
        # Timestep -> embedding vector.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU()
        )
        # Projects the time embedding to the bottleneck channel count so it
        # can be added to the feature map. (Bug fix: the embedding was
        # previously computed in forward() but never used, so the model had
        # no way to know which noise level it was denoising.)
        self.time_proj = nn.Linear(time_dim, 128)
        # Encoder
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        # Decoder
        self.conv4 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv5 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv6 = nn.Conv2d(32, 1, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, t):
        """x: (B, 1, H, W) noisy images; t: (B,) integer timesteps.

        Returns the predicted noise, same shape as ``x``.
        """
        # Time embedding
        t_emb = self.time_mlp(t)
        # Encoder
        x1 = F.relu(self.conv1(x))
        x2 = self.pool(F.relu(self.conv2(x1)))
        # Inject the timestep at the bottleneck, broadcast over spatial dims.
        h = self.conv3(x2) + self.time_proj(t_emb)[:, :, None, None]
        x3 = self.pool(F.relu(h))
        # Decoder
        x = self.upsample(x3)
        x = F.relu(self.conv4(x))
        x = self.upsample(x)
        x = F.relu(self.conv5(x))
        x = self.conv6(x)
        return x
model = SimpleUNet().to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Training: minimize the simplified DDPM objective L_simple = E||eps - eps_theta(x_t, t)||^2.
train_loader = DataLoader(mnist, batch_size=128, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 5
losses = []
alphas_cumprod_cuda = alphas_cumprod.to(device)
for epoch in range(n_epochs):
    epoch_loss = 0
    for batch_idx, (x0, _) in enumerate(train_loader):
        x0 = x0.to(device)
        batch_size = x0.shape[0]
        # Sample random timesteps uniformly in [0, T).
        t = torch.randint(0, T, (batch_size,), device=device).long()
        # Forward diffusion — keep the *same* noise as the regression target.
        # (Bug fix: previously a fresh, unrelated noise tensor was sampled
        # here and used as the target, while forward_diffusion_sample drew
        # its own noise internally and its return was discarded — so the
        # model was trained to predict noise uncorrelated with the noise
        # actually added to x0.)
        x_t, noise = forward_diffusion_sample(x0, t, alphas_cumprod_cuda)
        # Predict the added noise from the corrupted image and timestep.
        predicted_noise = model(x_t, t)
        loss = F.mse_loss(predicted_noise, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
    losses.append(epoch_loss / len(train_loader))
    print(f"Epoch {epoch+1} average loss: {losses[-1]:.4f}")
print("\nTraining complete!")
4. Sampling (Reverse Process)ΒΆ
AlgorithmΒΆ
Start from pure noise \(x_T \sim \mathcal{N}(0, I)\):
For \(t = T, T-1, \ldots, 1\): $\(x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t)\right) + \sigma_t z\)$
where \(z \sim \mathcal{N}(0, I)\) and \(\sigma_t = \sqrt{\beta_t}\).
@torch.no_grad()
def sample(model, n_samples=16, img_size=28):
    """Generate images by running the learned reverse (denoising) chain.

    Starts from x_T ~ N(0, I) and applies the DDPM ancestral update for
    t = T-1 .. 0, adding fresh Gaussian noise at every step except the last.

    Args:
        model: Noise-prediction network eps_theta(x_t, t).
        n_samples: Number of images to generate.
        img_size: Spatial resolution (square, single channel).

    Returns:
        Tensor of shape (n_samples, 1, img_size, img_size).
    """
    model.eval()
    # Pure noise starting point.
    x = torch.randn(n_samples, 1, img_size, img_size).to(device)
    for step in reversed(range(T)):
        t_batch = torch.full((n_samples,), step, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        # Scalar schedule coefficients for this step.
        alpha_t = alphas[step].to(device)
        alpha_bar_t = alphas_cumprod[step].to(device)
        beta_t = betas[step].to(device)
        # Posterior mean of the DDPM ancestral update.
        mean = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)) * eps)
        # Inject noise on every step except the final one.
        z = torch.randn_like(x) if step > 0 else torch.zeros_like(x)
        x = mean + torch.sqrt(beta_t) * z
    return x
# Generate a batch of images and display them as a 4x4 grid.
samples = sample(model, n_samples=16)
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i, 0].cpu().numpy(), cmap='gray')
    ax.axis('off')
plt.suptitle('Generated Samples (DDPM)', fontsize=14)
plt.tight_layout()
plt.show()
SummaryΒΆ
Key Components:ΒΆ
Forward process: \(q(x_t|x_{t-1}) = \mathcal{N}(\sqrt{1-\beta_t}x_{t-1}, \beta_t I)\)
Reverse process: \(p_\theta(x_{t-1}|x_t)\) learned via neural network
Training: Predict noise \(\epsilon_\theta(x_t, t) \approx \epsilon\)
Sampling: Iterative denoising from \(x_T \sim \mathcal{N}(0, I)\)
Advantages:ΒΆ
High-quality generation
Stable training (no adversarial)
Flexible architectures
Tractable likelihood
Modern Variants:ΒΆ
DDIM: Deterministic, faster sampling
Score-based models: Connection via score matching
Latent diffusion: Operate in compressed space (Stable Diffusion)
Conditional: Text-to-image (DALL-E 2, Imagen)
Applications:ΒΆ
Image synthesis (Stable Diffusion)
Text-to-image (DALL-E 2)
Audio generation
Video synthesis
Molecular design
Next Steps:ΒΆ
10_normalizing_flows.ipynb - Alternative exact likelihood model
03_variational_autoencoders_advanced.ipynb - Compare with VAEs
Explore latent diffusion for efficiency
Advanced Diffusion Models TheoryΒΆ
1. Mathematical FoundationsΒΆ
Denoising Diffusion Probabilistic Models (DDPM) (Ho et al., 2020)
The complete generative process defines a hierarchical VAE where latents have the same dimensionality as the data:
Joint Distribution: $\(p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1}|x_t)\)$
where \(p(x_T) = \mathcal{N}(x_T; 0, I)\) and: $\(p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)$
Variational Lower Bound (VLB): $\(\mathcal{L}_{VLB} = \mathbb{E}_q\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\right]\)$
Decomposing into terms: $\(\mathcal{L}_{VLB} = \underbrace{D_{KL}(q(x_T|x_0) \| p(x_T))}_{L_T} + \sum_{t>1}\underbrace{D_{KL}(q(x_{t-1}|x_t,x_0) \| p_\theta(x_{t-1}|x_t))}_{L_{t-1}} + \underbrace{-\log p_\theta(x_0|x_1)}_{L_0}\)$
Posterior Distribution:
The true posterior (tractable due to Gaussian assumptions): $\(q(x_{t-1}|x_t, x_0) = \mathcal{N}\left(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I\right)\)$
where: $\(\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t\)$
Simplified Training Objective:
Instead of optimizing the full VLB, Ho et al. showed that a weighted variant works better: $\(\mathcal{L}_{simple} = \mathbb{E}_{t \sim U(1,T), x_0 \sim q(x_0), \epsilon \sim \mathcal{N}(0,I)}\left[\|\epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon, t)\|^2\right]\)$
This is equivalent to denoising score matching with specific weights.
2. Denoising Diffusion Implicit Models (DDIM)ΒΆ
Motivation: DDPM requires \(T\) (typically 1000) sampling steps → slow inference.
Key Insight: (Song et al., 2021)
DDPM defines a Markovian forward process, but the reverse process doesn't need to be Markovian! DDIM uses a non-Markovian forward process with the same marginals: $\(q_\sigma(x_{1:T}|x_0) = q_\sigma(x_T|x_0) \prod_{t=2}^T q_\sigma(x_{t-1}|x_t, x_0)\)$
Generalized Reverse Process: $\(x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\underbrace{\left(\frac{x_t - \sqrt{1-\bar{\alpha}_t}\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}\right)}_{\text{predicted } x_0} + \underbrace{\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} \cdot \epsilon_\theta(x_t, t)}_{\text{direction pointing to } x_t} + \underbrace{\sigma_t \epsilon_t}_{\text{random noise}}\)$
where \(\epsilon_t \sim \mathcal{N}(0, I)\).
Noise Schedule Parameter: $\(\sigma_t = \eta \cdot \sqrt{\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}} \cdot \sqrt{1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}\)$
\(\eta = 1\): DDPM (stochastic)
\(\eta = 0\): DDIM (deterministic)
Acceleration:
DDIM allows skipping timesteps! Sample subsequence \(\tau = [\tau_1, \ldots, \tau_S]\) where \(S \ll T\): $\(x_{\tau_{i-1}} = \sqrt{\bar{\alpha}_{\tau_{i-1}}}\hat{x}_0 + \sqrt{1-\bar{\alpha}_{\tau_{i-1}}}\epsilon_\theta(x_{\tau_i}, \tau_i)\)$
Typical: \(S = 50\) gives 20× speedup with minimal quality loss.
3. Score-Based Generative ModelsΒΆ
Connection to Score Matching (Song & Ermon, 2019; Song et al., 2021)
Score Function: $\(s_\theta(x, t) = \nabla_x \log p_t(x)\)$
The gradient of the log density points toward higher density regions.
Equivalence to Diffusion:
The noise prediction network \(\epsilon_\theta(x_t, t)\) is related to the score: $\(\nabla_{x_t} \log p_t(x_t) = -\frac{1}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t)\)$
Therefore: $\(s_\theta(x_t, t) = -\frac{\epsilon_\theta(x_t, t)}{\sqrt{1-\bar{\alpha}_t}}\)$
Denoising Score Matching:
Training objective becomes: $\(\mathcal{L}_{DSM} = \mathbb{E}_{t,x_0,x_t}\left[\lambda(t)\|\nabla_{x_t}\log q(x_t|x_0) - s_\theta(x_t, t)\|^2\right]\)$
where \(\lambda(t) = 1-\bar{\alpha}_t\) recovers the simplified objective.
Stochastic Differential Equations (SDE):
Forward process as continuous-time SDE: $\(dx = f(x,t)dt + g(t)dw\)$
where \(f(x,t)\) is drift, \(g(t)\) is diffusion coefficient, \(w\) is Wiener process.
Variance Preserving (VP) SDE: $\(dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)}dw\)$
Variance Exploding (VE) SDE: $\(dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}}dw\)$
Reverse-time SDE: $\(dx = \left[f(x,t) - g(t)^2 \nabla_x \log p_t(x)\right]dt + g(t)d\bar{w}\)$
Sampling: Solve reverse SDE using learned score \(s_\theta(x,t)\).
Probability Flow ODE:
Deterministic counterpart (same marginals): $\(dx = \left[f(x,t) - \frac{1}{2}g(t)^2\nabla_x \log p_t(x)\right]dt\)$
Advantages:
Exact likelihood computation via change of variables
Faster sampling with adaptive ODE solvers
Enables manipulation in latent space
4. Conditional Generation & GuidanceΒΆ
Classifier Guidance (Dhariwal & Nichol, 2021)
Guide diffusion toward a target class \(y\) using a classifier \(p_\phi(y|x_t, t)\): $\(\tilde{\epsilon}_\theta(x_t, t, y) = \epsilon_\theta(x_t, t) - \sqrt{1-\bar{\alpha}_t} \cdot w \cdot \nabla_{x_t} \log p_\phi(y|x_t, t)\)$
where \(w\) is guidance scale (typically \(w \in [1, 10]\)).
Derivation: Modified score with Bayes rule: $\(\nabla_{x_t} \log p(x_t|y) = \nabla_{x_t}\log p(x_t) + \nabla_{x_t}\log p(y|x_t)\)$
Classifier-Free Guidance (Ho & Salimans, 2022)
Problem: Classifier guidance requires training separate classifier β expensive, may hurt diversity.
Solution: Train single conditional model \(\epsilon_\theta(x_t, t, c)\) where \(c\) is condition (text, class, etc.): $\(\tilde{\epsilon}_\theta(x_t, t, c) = (1+w)\epsilon_\theta(x_t, t, c) - w\epsilon_\theta(x_t, t, \emptyset)\)$
where \(\emptyset\) is null condition (unconditional).
Training: Randomly drop condition with probability \(p_{uncond}\) (typically 0.1): $\(c' = \begin{cases} \emptyset & \text{with prob. } p_{uncond}\\ c & \text{otherwise} \end{cases}\)$
Guidance Scale \(w\):
\(w = 0\): Unconditional
\(w > 0\): Stronger conditioning, less diversity
Typical: \(w \in [1, 10]\) for images, \(w \in [5, 20]\) for text-to-image
Advantages:
No separate classifier needed
Better sample quality at high guidance
Simpler training pipeline
5. Latent Diffusion Models (Stable Diffusion)ΒΆ
Motivation (Rombach et al., 2022)
Diffusion in pixel space is computationally expensive:
High resolution: \(1024 \times 1024 \times 3 \approx 3M\) dimensions
Memory: Store activations for \(T\) steps
Time: \(T\) forward passes through large U-Net
Key Idea: Perform diffusion in compressed latent space.
Architecture:
Encoder: \(\mathcal{E}: \mathbb{R}^{H \times W \times 3} \to \mathbb{R}^{h \times w \times c}\)
Typically trained VAE or VQ-VAE
Compression factor: \(f = H/h = 8\) (Stable Diffusion)
Diffusion Model: \(\epsilon_\theta(z_t, t, \tau_\theta(y))\)
Operates on latent \(z\) instead of \(x\)
\(\tau_\theta(y)\) is conditioning (e.g., CLIP text embeddings)
Decoder: \(\mathcal{D}: \mathbb{R}^{h \times w \times c} \to \mathbb{R}^{H \times W \times 3}\)
Reconstruct image from latent
Training: $\(\mathcal{L}_{LDM} = \mathbb{E}_{\mathcal{E}(x), \epsilon, t, c}\left[\|\epsilon - \epsilon_\theta(z_t, t, \tau_\theta(c))\|^2\right]\)$
where \(z = \mathcal{E}(x)\).
Computational Benefits:
Compression: \(f^2 = 64\) reduction in spatial dimensions
Memory: Process \(512^2\) image in latent space of \(64^2\)
Speed: 10-100Γ faster than pixel-space diffusion
Cross-Attention Conditioning:
For text-to-image, condition via cross-attention in U-Net: $\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V\)$
where:
\(Q = W_Q \cdot \phi(z_t)\): Query from latent features
\(K = W_K \cdot \tau_\theta(y)\): Keys from text embeddings
\(V = W_V \cdot \tau_\theta(y)\): Values from text embeddings
This allows spatially-adaptive conditioning based on text.
Text Encoder:
Stable Diffusion uses CLIP (Contrastive Language-Image Pre-training): $\(\tau_\theta(y) = \text{CLIP-TextEncoder}(y) \in \mathbb{R}^{77 \times 768}\)$
6. Advanced Sampling TechniquesΒΆ
Ancestral Sampling (original DDPM): $\(x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t)\right) + \sigma_t z_t\)$
DDIM Deterministic: $\(x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\hat{x}_0(x_t, t) + \sqrt{1-\bar{\alpha}_{t-1}}\epsilon_\theta(x_t, t)\)$
DPM-Solver (Lu et al., 2022):
Fast high-order ODE solver for diffusion ODEs
10-20 steps with DDIM-quality
Up to 3× faster than DDIM at same quality
Euler Solver: $\(x_{t-1} = x_t + (t-1 - t) \cdot \frac{dx}{dt}\bigg|_t\)$
where \(\frac{dx}{dt}\) from probability flow ODE.
Heunβs Method (2nd order): $\(\begin{align} k_1 &= \frac{dx}{dt}\bigg|_t\\ k_2 &= \frac{dx}{dt}\bigg|_{t + \Delta t, x_t + \Delta t \cdot k_1}\\ x_{t-1} &= x_t + \frac{\Delta t}{2}(k_1 + k_2) \end{align}\)$
Exponential Integrator:
Exact integration for linear part, approximation for nonlinear: $\(x_{t-1} = e^{-\int_{t}^{t-1}\beta(s)ds/2}x_t + \int_{t-1}^t e^{-\int_{s}^{t-1}\beta(r)dr/2}\sqrt{\beta(s)}\epsilon_\theta(x_s, s)ds\)$
7. Training ImprovementsΒΆ
\(v\)-Prediction (Salimans & Ho, 2022):
Instead of predicting noise \(\epsilon\) or \(x_0\), predict velocity: $\(v = \sqrt{\bar{\alpha}_t}\epsilon - \sqrt{1-\bar{\alpha}_t}x_0\)$
Recovery: $\(x_0 = \sqrt{\bar{\alpha}_t}x_t - \sqrt{1-\bar{\alpha}_t}v_\theta(x_t, t)\)$ and $\(\epsilon = \sqrt{\bar{\alpha}_t}v_\theta(x_t, t) + \sqrt{1-\bar{\alpha}_t}x_t\)$
Benefits:
Better numerical stability
Improved sample quality at low noise levels
Symmetric treatment of signal and noise
Progressive Distillation (Salimans & Ho, 2022):
Distill 1000-step model into 500-step, then 250-step, etc.: $\(\mathcal{L}_{distill} = \mathbb{E}\left[\|x_0^{student}(x_t) - x_0^{teacher}(x_{t+\Delta t})\|^2\right]\)$
Achieves 4-8 steps with minimal quality loss.
Min-SNR Weighting (Hang et al., 2023):
Standard loss weights all timesteps equally, but early steps (small \(t\), high noise) dominate gradient.
Solution: Cap weight by minimum SNR: $\(w(t) = \min\left(\text{SNR}(t), \gamma\right)\)$
where \(\text{SNR}(t) = \bar{\alpha}_t/(1-\bar{\alpha}_t)\) and \(\gamma = 5\) (typical).
Offset Noise (Stability AI):
Original noise \(\epsilon \sim \mathcal{N}(0, I)\) can't produce fully black/white images.
Solution: Add small offset: $\(\epsilon = \epsilon_{pixel} + 0.1 \cdot \epsilon_{global}\)$
where \(\epsilon_{global} \sim \mathcal{N}(0, 1)\) (shared across pixels).
8. Evaluation MetricsΒΆ
FrΓ©chet Inception Distance (FID): $\(\text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2})\)$
where \((\mu_r, \Sigma_r)\) and \((\mu_g, \Sigma_g)\) are mean and covariance of Inception-v3 features for real and generated images.
Lower is better. State-of-the-art: FID < 2 on ImageNet 256Γ256.
Inception Score (IS): $\(\text{IS} = \exp\left(\mathbb{E}_x\left[D_{KL}(p(y|x) \| p(y))\right]\right)\)$
Measures diversity (marginal \(p(y)\) uniform) and quality (conditional \(p(y|x)\) peaked).
CLIP Score:
For text-to-image, measure alignment: $\(\text{CLIP-Score} = \mathbb{E}_{x,c}\left[\text{cos-sim}(\text{CLIP}_{img}(x), \text{CLIP}_{text}(c))\right]\)$
Precision and Recall:
Precision: Fraction of generated samples in real data manifold (quality)
Recall: Fraction of real data covered by generated samples (diversity)
9. State-of-the-Art ModelsΒΆ
Imagen (Saharia et al., 2022, Google):
Cascade of diffusion models: \(64^2 \to 256^2 \to 1024^2\)
T5 text encoder (11B parameters)
Dynamic thresholding for stability
FID = 7.27 on COCO
DALL-E 2 (Ramesh et al., 2022, OpenAI):
CLIP latent space diffusion
Prior: Text β CLIP image embedding (autoregressive or diffusion)
Decoder: CLIP embedding β image (diffusion)
Stable Diffusion (Rombach et al., 2022, Stability AI):
Latent diffusion with \(f=8\) compression
U-Net with cross-attention (CLIP text)
860M parameters
Open-source, runs on consumer GPUs
ImageGen variants (eDiff-I, Muse):
Expert denoisers for different noise levels
Parallel decoding for speed
Super-resolution cascade
10. Advanced ApplicationsΒΆ
Inpainting:
Condition on masked image: $\(x_t^{known} = \sqrt{\bar{\alpha}_t}x_0^{known} + \sqrt{1-\bar{\alpha}_t}\epsilon\)$
Replace known regions at each step: $\(x_t = m \odot x_t^{known} + (1-m) \odot x_t^{generated}\)$
Image-to-Image (SDEdit):
Start from noisy input \(x_t\) (where \(t < T\)) instead of pure noise:
Forward: \(x_0 \to x_t\)
Reverse: \(x_t \to x_0'\)
Controls strength: larger \(t\) = more change.
ControlNet (Zhang et al., 2023):
Add spatial conditioning (edges, pose, depth):
Clone U-Net encoder
Zero-initialized 1Γ1 convs for injection
Preserves pre-trained weights
Video Diffusion:
Extend to 3D U-Net (temporal dimension): $\(\epsilon_\theta(x_t^{1:F}, t)\)$
Challenges:
Temporal consistency
Memory for \(F\) frames
Solutions: Latent diffusion, sparse attention, autoregressive
3D Generation:
Point cloud diffusion
Neural radiance field (NeRF) diffusion
Multi-view consistent generation
11. Theoretical InsightsΒΆ
Score Matching Equivalence:
Denoising is equivalent to score matching with specific noise schedule: $\(\nabla_x \log p_\sigma(x) = -\frac{1}{\sigma^2}(x - \mathbb{E}[x_0|x])\)$
Optimal Transport:
Diffusion implicitly solves optimal transport from data to noise distribution.
Connection to Energy-Based Models:
Score \(\nabla_x \log p(x) = -\nabla_x E(x)\) where \(E(x)\) is energy.
Manifold Hypothesis:
Data lies on low-dimensional manifold. Diffusion gradually adds noise until distribution fills ambient space, then reverses while staying near manifold.
Rate-Distortion Trade-off:
Number of steps \(T\) vs. sample quality:
Larger \(T\): Better approximation to continuous process
Smaller \(T\): Faster but coarser discretization
DDIM: Quality-speed trade-off via \(\eta\)
12. Practical ConsiderationsΒΆ
Hyperparameters:
Parameter |
Typical Value |
Notes |
|---|---|---|
\(T\) |
1000 |
DDPM, can reduce with DDIM |
\(\beta_{\min}\) |
\(10^{-4}\) |
Start of noise schedule |
\(\beta_{\max}\) |
0.02 |
End of noise schedule |
Learning rate |
\(2 \times 10^{-4}\) |
Adam optimizer |
Batch size |
128-2048 |
Depends on GPU memory |
EMA decay |
0.9999 |
For parameter averaging |
Guidance \(w\) |
1-10 (image), 5-20 (text) |
Trade-off quality/diversity |
Noise Schedules:
Linear: Simple but suboptimal
Cosine: Better for high resolution
Learned: Can adapt to data
Computational Requirements:
Training Stable Diffusion:
Dataset: LAION-5B (filtered to ~2B)
GPUs: 256× A100 (40GB)
Time: ~150,000 GPU-hours
Cost: ~$600k at cloud prices
Inference:
DDPM (1000 steps): ~10s per image (RTX 3090)
DDIM (50 steps): ~2s per image
DPM-Solver (20 steps): ~1s per image
13. Limitations & Open ProblemsΒΆ
Slow Sampling:
1000 steps impractical for real-time
Solutions: DDIM, distillation, faster solvers
Open: Single-step diffusion?
Mode Coverage:
Better than GANs but still challenges
Multi-modal distributions difficult
Open: Provable coverage guarantees?
Controllability:
Guidance helps but not fine-grained
Open: Precise spatial/semantic control?
3D and Video:
Computational challenges
Temporal consistency
Open: Efficient architectures?
Theoretical Understanding:
Why do diffusion models work so well?
Connection to other generative models?
Open: Formal convergence guarantees?
14. Key Papers (Chronological)ΒΆ
Sohl-Dickstein et al., 2015: βDeep Unsupervised Learning using Nonequilibrium Thermodynamicsβ (original diffusion)
Song & Ermon, 2019: βGenerative Modeling by Estimating Gradientsβ (score-based)
Ho et al., 2020: βDenoising Diffusion Probabilistic Modelsβ (DDPM, simplified training)
Song et al., 2021: βDenoising Diffusion Implicit Modelsβ (DDIM, fast sampling)
Song et al., 2021: βScore-Based Generative Modeling through SDEsβ (continuous formulation)
Dhariwal & Nichol, 2021: βDiffusion Models Beat GANsβ (classifier guidance, architecture improvements)
Ho & Salimans, 2022: βClassifier-Free Guidanceβ (simplify conditioning)
Rombach et al., 2022: βHigh-Resolution Image Synthesis with Latent Diffusionβ (Stable Diffusion)
Ramesh et al., 2022: βHierarchical Text-Conditional Image Generation with CLIP Latentsβ (DALL-E 2)
Saharia et al., 2022: βPhotorealistic Text-to-Image Diffusion with Deep Language Understandingβ (Imagen)
Lu et al., 2022: βDPM-Solver: Fast Solver for Diffusion Probabilistic Modelsβ
Zhang et al., 2023: βAdding Conditional Control to Text-to-Image Diffusionβ (ControlNet)
15. Comparison with Other Generative ModelsΒΆ
Model |
Likelihood |
Quality |
Diversity |
Speed |
Stability |
|---|---|---|---|---|---|
VAE |
Exact (lower bound) |
Medium |
High |
Fast |
High |
GAN |
Implicit |
High |
Medium |
Fast |
Low |
Flow |
Exact |
Medium |
High |
Fast |
High |
Autoregressive |
Exact |
High |
High |
Slow |
High |
Diffusion |
Approx. tractable |
Highest |
High |
Medium |
High |
When to Use Diffusion:
High-quality image/audio generation
Stable training (no adversarial dynamics)
Flexible conditioning (text, class, etc.)
Likelihood evaluation needed (via ODE)
When NOT to Use:
Real-time generation required β Use GANs or distilled diffusion
Limited compute β Use VAE or smaller model
Simple data β Simpler models sufficient
"""
Advanced Diffusion Models - Complete Implementations
This cell provides production-ready implementations of:
1. DDIM Sampler (deterministic, fast)
2. Score-Based SDE/ODE Solvers
3. Classifier-Free Guidance
4. DPM-Solver (high-order ODE solver)
5. Latent Diffusion Components
6. Advanced Conditioning (cross-attention)
7. Progressive Distillation
8. Evaluation Tools (FID, IS)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy import integrate
import warnings
# NOTE(review): this blanket-suppresses *all* warnings (including deprecations);
# consider narrowing with a category or message filter.
warnings.filterwarnings('ignore')
# ============================================================================
# DDIM Sampler
# ============================================================================
class DDIMSampler:
    """
    Denoising Diffusion Implicit Models (Song et al., 2021)

    Theory:
    - Non-Markovian forward process with same marginals as DDPM
    - Deterministic sampling (eta=0) or stochastic (eta=1)
    - Supports skipping timesteps for acceleration

    Bug fixes vs. the original implementation:
    - Schedule coefficients were indexed with the batched timestep tensor,
      producing shape-(B,) tensors that cannot broadcast against
      (B, C, H, W) images, and made ``if sigma_t > 0`` ambiguous for B > 1.
      All entries of the batch share one timestep, so the schedule is now
      looked up with a scalar index.
    - ``i % (num_steps // 5)`` raised ZeroDivisionError for num_steps < 5.
    """

    def __init__(self, model, betas, eta=0.0):
        """
        Args:
            model: Noise prediction network eps_theta(x_t, t)
            betas: Noise schedule beta_t (1-D tensor)
            eta: Stochasticity parameter (0=deterministic, 1=DDPM)
        """
        self.model = model
        self.eta = eta
        # Precompute schedule quantities.
        self.betas = betas
        self.alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def _predict_x0_from_eps(self, x_t, t, eps):
        """Invert the closed form: x0 = (x_t - sqrt(1-ab_t)*eps) / sqrt(ab_t).

        ``t`` must be a scalar index so the coefficients broadcast over the
        image dimensions.
        """
        sqrt_alpha = self.sqrt_alphas_cumprod[t]
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t]
        return (x_t - sqrt_one_minus_alpha * eps) / sqrt_alpha

    def ddim_step(self, x_t, t, t_prev):
        """
        Single DDIM step: x_t -> x_{t_prev}

        Formula:
            x_{t-1} = sqrt(ab_{t-1}) * x0_pred
                      + sqrt(1 - ab_{t-1} - sigma_t^2) * eps_theta
                      + sigma_t * z
        where:
            x0_pred = (x_t - sqrt(1 - ab_t) * eps_theta) / sqrt(ab_t)
            sigma_t = eta * sqrt((1-ab_{t-1})/(1-ab_t)) * sqrt(1 - ab_t/ab_{t-1})

        Args:
            x_t: Current sample, (B, C, H, W).
            t: Current timestep — an int, or a (B,) tensor whose entries all
               hold the same timestep (as produced by ``sample``).
            t_prev: Previous timestep index; < 0 denotes the terminal step
                where alpha-bar is defined as 1.

        Returns:
            (x_prev, x0_pred)
        """
        # Build the per-sample timestep batch for the model, and a scalar
        # index for the schedule lookups.
        if torch.is_tensor(t):
            t_batch = t
            t_idx = int(t.reshape(-1)[0])
        else:
            t_idx = int(t)
            t_batch = torch.full((x_t.shape[0],), t_idx, device=x_t.device, dtype=torch.long)
        # Predict noise
        eps = self.model(x_t, t_batch)
        # Predict x_0
        x0_pred = self._predict_x0_from_eps(x_t, t_idx, eps)
        # Scalar schedule coefficients (broadcast cleanly over images).
        alpha_prod_t = self.alphas_cumprod[t_idx]
        t_prev_idx = int(t_prev)
        alpha_prod_t_prev = self.alphas_cumprod[t_prev_idx] if t_prev_idx >= 0 else torch.tensor(1.0)
        # Stochasticity level: 0 when eta=0 (deterministic DDIM).
        sigma_t = self.eta * torch.sqrt(
            (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) *
            (1 - alpha_prod_t / alpha_prod_t_prev)
        )
        # Direction pointing to x_t.
        dir_xt = torch.sqrt(1 - alpha_prod_t_prev - sigma_t ** 2) * eps
        x_prev = torch.sqrt(alpha_prod_t_prev) * x0_pred + dir_xt
        # Add noise only in the stochastic regime.
        if sigma_t > 0:
            x_prev = x_prev + sigma_t * torch.randn_like(x_t)
        return x_prev, x0_pred

    @torch.no_grad()
    def sample(self, shape, num_steps=50, device='cpu'):
        """
        Generate samples using DDIM.

        Args:
            shape: Output shape (B, C, H, W)
            num_steps: Number of sampling steps (can be << T)
            device: Device to run on

        Returns:
            samples: Generated samples
            intermediates: List of intermediate states (on CPU)
        """
        # Evenly spaced sub-schedule from T-1 down to 0 (timestep skipping).
        total_steps = len(self.betas)
        timesteps = torch.linspace(total_steps - 1, 0, num_steps).round().long()
        # Start from pure noise.
        x = torch.randn(shape, device=device)
        intermediates = [x.cpu()]
        # Guard against modulo-by-zero for num_steps < 5.
        snapshot_every = max(num_steps // 5, 1)
        for i in range(len(timesteps) - 1):
            t = int(timesteps[i])
            t_prev = int(timesteps[i + 1])
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            x, x0_pred = self.ddim_step(x, t_batch, t_prev)
            if i % snapshot_every == 0:
                intermediates.append(x.cpu())
        return x, intermediates
# ============================================================================
# Score-Based SDE/ODE Solvers
# ============================================================================
class ScoreBasedSDE:
    """
    Score-Based Generative Modeling via SDE (Song et al., 2021)

    Theory:
        Forward SDE: dx = f(x,t)dt + g(t)dw
        Reverse SDE: dx = [f(x,t) - g(t)^2 * grad log p_t(x)]dt + g(t)dw_bar

    Bug fixes vs. the original implementation:
    - A batched timestep gives beta(t) shape (B,), which cannot broadcast
      against (B, C, H, W) images; coefficients are now expanded with
      trailing singleton dims.
    - The scipy ODE callback fed float64 arrays straight into the model;
      they are now cast to float32.
    """

    def __init__(self, model, beta_min=0.1, beta_max=20.0, T=1.0):
        """
        Args:
            model: Score network s_theta(x, t), taking (B,C,H,W) and (B,)
            beta_min: Minimum beta(t)
            beta_max: Maximum beta(t)
            T: Final time
        """
        self.model = model
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.T = T

    def beta(self, t):
        """Linear variance schedule beta(t) = beta_min + t*(beta_max - beta_min)."""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def marginal_prob_std(self, t):
        """
        Standard deviation of p_t(x | x_0).

        For the VP-SDE: std(t) = sqrt(1 - exp(-int_0^t beta(s)ds)).
        """
        log_mean_coeff = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min
        return torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))

    @staticmethod
    def _expand_like(coeff, x):
        """Append singleton dims so a (B,) coefficient broadcasts over x."""
        while coeff.dim() < x.dim():
            coeff = coeff.unsqueeze(-1)
        return coeff

    def sde(self, x, t):
        """
        Forward (VP) SDE coefficients.

        Returns:
            drift: f(x, t) = -0.5 * beta(t) * x
            diffusion: g(t) = sqrt(beta(t)), shaped to broadcast over x
        """
        beta_t = self.beta(t)
        if torch.is_tensor(beta_t) and beta_t.dim() > 0:
            beta_t = self._expand_like(beta_t, x)
        drift = -0.5 * beta_t * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def reverse_sde(self, x, t, score):
        """
        Reverse-time SDE coefficients.

        Returns:
            drift: f(x,t) - g(t)^2 * score
            diffusion: g(t)
        """
        drift, diffusion = self.sde(x, t)
        drift = drift - diffusion ** 2 * score
        return drift, diffusion

    def ode(self, x, t, score):
        """
        Probability flow ODE drift:
        dx/dt = f(x,t) - 0.5 * g(t)^2 * score
        """
        drift, diffusion = self.sde(x, t)
        drift = drift - 0.5 * diffusion ** 2 * score
        return drift

    @torch.no_grad()
    def euler_maruyama_sampler(self, shape, num_steps=1000, device='cpu'):
        """
        Sample by integrating the reverse SDE with Euler-Maruyama (stochastic).

        x_{i+1} = x_i - drift*dt + g(t_i)*sqrt(dt)*z_i   (time runs T -> 0)
        """
        dt = self.T / num_steps
        t = torch.ones(shape[0], device=device) * self.T
        # Start from the terminal marginal: noise scaled by std at time T.
        x = torch.randn(shape, device=device) * self.marginal_prob_std(t)[:, None, None, None]
        # Loop-invariant step size for the noise term.
        sqrt_dt = torch.sqrt(torch.tensor(dt, device=device))
        for _ in range(num_steps):
            score = self.model(x, t)
            drift, diffusion = self.reverse_sde(x, t, score)
            # Minus sign because we integrate backward in time.
            x = x - drift * dt
            x = x + diffusion * sqrt_dt * torch.randn_like(x)
            t = t - dt
        return x

    @torch.no_grad()
    def ode_sampler(self, shape, num_steps=100, device='cpu', method='RK45'):
        """
        Sample by solving the probability flow ODE (deterministic).

        Uses scipy.integrate.solve_ivp with an adaptive solver.
        """
        x_init = torch.randn(shape, device=device)

        def ode_func(t, x_flat):
            # scipy hands back float64; cast so it matches model weights.
            x = torch.as_tensor(x_flat, dtype=torch.float32, device=device).reshape(shape)
            t_tensor = torch.ones(shape[0], device=device) * t
            with torch.no_grad():
                score = self.model(x, t_tensor)
                drift = self.ode(x, t_tensor, score)
            return drift.cpu().numpy().flatten()

        # Integrate backward from T to 0.
        solution = integrate.solve_ivp(
            ode_func,
            (self.T, 0.0),
            x_init.cpu().numpy().flatten(),
            method=method,
            t_eval=np.linspace(self.T, 0.0, num_steps)
        )
        x_final = torch.as_tensor(solution.y[:, -1], dtype=torch.float32, device=device).reshape(shape)
        return x_final
# ============================================================================
# Classifier-Free Guidance
# ============================================================================
class ConditionalDiffusionModel(nn.Module):
    """
    Conditional diffusion model with classifier-free guidance
    (Ho & Salimans, 2022).

    Theory:
        - Train a single network eps_theta(x_t, t, c).
        - During training, randomly replace the condition c with a "null" class
          (probability p_uncond) so the same network also learns the
          unconditional score.
        - Inference: eps_hat = (1+w)*eps_theta(x_t,t,c) - w*eps_theta(x_t,t,null)
    """
    def __init__(self, base_model, num_classes=10, p_uncond=0.1):
        """
        Args:
            base_model: Base U-Net; must expose a `time_dim` attribute and be
                callable as base_model(x, t, c_emb).
            num_classes: Number of real classes (one extra embedding row is
                reserved for the unconditional "null" class).
            p_uncond: Probability of dropping the condition during training.
        """
        super().__init__()
        self.base_model = base_model
        self.num_classes = num_classes
        self.p_uncond = p_uncond
        # num_classes + 1 rows: index `num_classes` is the null (unconditional) class
        self.class_emb = nn.Embedding(num_classes + 1, base_model.time_dim)

    def forward(self, x, t, c=None, guidance_scale=0.0):
        """
        Predict noise, optionally with classifier-free guidance.

        Args:
            x: Noisy input batch.
            t: Timesteps.
            c: Class labels (None for unconditional).
            guidance_scale: Guidance strength w (0 disables guidance).
        Returns:
            eps: Predicted noise.
        """
        batch_size = x.shape[0]
        # Training: randomly drop the condition to the null class, then return
        # immediately. (The original fell through the if/elif/else to trailing
        # code that used the conditionally-defined `c_input` -- correct but
        # fragile; early returns make each path self-contained.)
        if self.training and c is not None:
            mask = torch.rand(batch_size, device=x.device) < self.p_uncond
            c_input = c.clone()
            c_input[mask] = self.num_classes  # null class
            return self.base_model(x, t, self.class_emb(c_input))
        # Inference with guidance: blend conditional and unconditional predictions
        if not self.training and guidance_scale > 0 and c is not None:
            eps_cond = self.base_model(x, t, self.class_emb(c))
            c_null = torch.full_like(c, self.num_classes)
            eps_uncond = self.base_model(x, t, self.class_emb(c_null))
            # Classifier-free guidance combination
            return (1 + guidance_scale) * eps_cond - guidance_scale * eps_uncond
        # Standard forward: fall back to the null class when no condition given
        if c is None:
            c = torch.full((batch_size,), self.num_classes, device=x.device, dtype=torch.long)
        return self.base_model(x, t, self.class_emb(c))
# ============================================================================
# DPM-Solver (Fast High-Order ODE Solver)
# ============================================================================
class DPMSolver:
    """
    DPM-Solver: fast high-order ODE solver for diffusion ODEs (Lu et al., 2022).

    Theory:
        - Exploits the semi-linear structure of the diffusion ODE.
        - Exact integration of the linear part, high-order approximation of the
          nonlinear (network) part.
        - Reaches DDIM-level quality in roughly 10-20 steps.
    """
    def __init__(self, model, alphas_cumprod):
        """
        Args:
            model: Noise prediction network, called as model(x, t).
            alphas_cumprod: 1-D tensor of cumulative products of alphas (length T).
        """
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    @staticmethod
    def _match_dims(coeff, x):
        """Right-pad a per-batch coefficient with singleton dims so it broadcasts with x."""
        # BUGFIX: (B,)-shaped coefficients cannot broadcast against (B, C, H, W)
        # states; the original code crashed for any multi-dimensional sample.
        while coeff.dim() < x.dim():
            coeff = coeff.unsqueeze(-1)
        return coeff

    def marginal_lambda(self, t):
        """Log-SNR surrogate: lambda(t) = log(abar_t / (1 - abar_t))."""
        alpha = self.alphas_cumprod[t]
        return torch.log(alpha / (1 - alpha))

    def marginal_alpha(self, t):
        """sqrt(abar_t)"""
        return torch.sqrt(self.alphas_cumprod[t])

    def marginal_sigma(self, t):
        """sqrt(1 - abar_t)"""
        return torch.sqrt(1 - self.alphas_cumprod[t])

    def noise_pred_fn(self, x, t):
        """Convert the noise prediction into an x_0 (data) prediction."""
        eps = self.model(x, t)
        alpha_t = self._match_dims(self.marginal_alpha(t), x)
        sigma_t = self._match_dims(self.marginal_sigma(t), x)
        # Invert x_t = alpha_t * x_0 + sigma_t * eps
        x0 = (x - sigma_t * eps) / alpha_t
        return x0

    @torch.no_grad()
    def dpm_solver_first_order(self, x, t, t_prev):
        """First-order DPM-Solver step from timestep t to t_prev (t_prev < t)."""
        lambda_t = self.marginal_lambda(t)
        lambda_prev = self.marginal_lambda(t_prev)
        h = self._match_dims(lambda_prev - lambda_t, x)  # step size in lambda
        alpha_t = self._match_dims(self.marginal_alpha(t), x)
        alpha_prev = self._match_dims(self.marginal_alpha(t_prev), x)
        sigma_prev = self._match_dims(self.marginal_sigma(t_prev), x)
        # Predict x_0, then apply the exponentially-integrated update
        x0_t = self.noise_pred_fn(x, t)
        x_prev = (alpha_prev / alpha_t) * x - sigma_prev * torch.expm1(h) * x0_t
        return x_prev

    @torch.no_grad()
    def dpm_solver_second_order(self, x, t, t_prev, t_prev_prev, x0_prev=None):
        """
        Second-order (multistep) DPM-Solver step.

        Uses the x_0 prediction from the previous step (made at timestep
        t_prev_prev) to linearly extrapolate the network output, improving
        accuracy over a first-order update.
        """
        lambda_t = self.marginal_lambda(t)
        lambda_prev = self.marginal_lambda(t_prev)
        lambda_prev_prev = self.marginal_lambda(t_prev_prev)
        h = lambda_prev - lambda_t  # current step size in lambda
        # BUGFIX: the previous step size is lambda(t) - lambda(t_prev_prev);
        # the original `lambda_prev_prev - lambda_prev` spans the wrong
        # interval and flips the sign of r.
        h_prev = lambda_t - lambda_prev_prev
        alpha_t = self._match_dims(self.marginal_alpha(t), x)
        alpha_prev = self._match_dims(self.marginal_alpha(t_prev), x)
        sigma_prev = self._match_dims(self.marginal_sigma(t_prev), x)
        h_b = self._match_dims(h, x)
        # Predict x_0 at the current step
        x0_t = self.noise_pred_fn(x, t)
        if x0_prev is None:
            # No history yet: fall back to the first-order update
            x_prev = (alpha_prev / alpha_t) * x - sigma_prev * torch.expm1(h_b) * x0_t
        else:
            r = self._match_dims(h_prev / h, x)  # previous/current step ratio
            # Linear extrapolation of x_0 to the lambda-midpoint of this step
            D = (1 + 0.5 / r) * x0_t - 0.5 / r * x0_prev
            x_prev = (alpha_prev / alpha_t) * x - sigma_prev * torch.expm1(h_b) * D
        return x_prev, x0_t

    @torch.no_grad()
    def sample(self, shape, num_steps=20, order=2, device='cpu'):
        """
        Generate samples using DPM-Solver.

        Args:
            shape: Output shape (batch dimension first).
            num_steps: Number of solver steps (10-20 typically).
            order: Solver order (1 or 2).
            device: Device.
        Returns:
            samples: Generated samples.
        """
        # Uniform timestep schedule from T-1 down to 0
        total_steps = len(self.alphas_cumprod)
        timesteps = torch.linspace(total_steps - 1, 0, num_steps, dtype=torch.long)
        # Start from noise
        x = torch.randn(shape, device=device)
        x0_prev = None
        for i in range(len(timesteps) - 1):
            t = timesteps[i]
            t_prev = timesteps[i + 1]
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            t_prev_batch = torch.full((shape[0],), t_prev, device=device, dtype=torch.long)
            if order == 1 or i == 0:
                # First-order step (also warms up the multistep history)
                x = self.dpm_solver_first_order(x, t_batch, t_prev_batch)
            else:
                # BUGFIX: the previous timestep is timesteps[i-1]; the original
                # passed timesteps[i] (the *current* timestep), making r = -1
                # and corrupting the second-order correction.
                t_prev_prev = timesteps[i - 1]
                t_prev_prev_batch = torch.full((shape[0],), t_prev_prev, device=device, dtype=torch.long)
                x, x0_prev = self.dpm_solver_second_order(x, t_batch, t_prev_batch, t_prev_prev_batch, x0_prev)
        return x
# ============================================================================
# Cross-Attention Conditioning (for Text-to-Image)
# ============================================================================
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention for injecting a conditioning sequence
    (e.g. text embeddings) into latent image features.

    Theory:
        Q = W_Q * phi(z_t)   (queries from latent features)
        K = W_K * phi(y)     (keys from the condition, e.g. text)
        V = W_V * phi(y)
        Attention(Q, K, V) = softmax(Q K^T / sqrt(d)) V
    """
    def __init__(self, query_dim, context_dim, num_heads=8):
        """
        Args:
            query_dim: Channel dimension of the queries (image features).
            context_dim: Channel dimension of the context (text embeddings).
            num_heads: Number of attention heads.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = query_dim // num_heads
        self.to_q = nn.Linear(query_dim, query_dim)
        self.to_k = nn.Linear(context_dim, query_dim)
        self.to_v = nn.Linear(context_dim, query_dim)
        self.to_out = nn.Linear(query_dim, query_dim)

    def forward(self, x, context):
        """
        Args:
            x: Image features of shape (B, N, query_dim).
            context: Conditioning sequence of shape (B, M, context_dim).
        Returns:
            Attended features of shape (B, N, query_dim).
        """
        batch, n_queries, channels = x.shape

        def split_heads(seq):
            # (B, L, C) -> (B, H, L, D)
            return seq.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.to_q(x))
        keys = split_heads(self.to_k(context))
        values = split_heads(self.to_v(context))

        # Scaled dot-product attention over the context length M
        logits = (queries @ keys.transpose(-2, -1)) * (self.head_dim ** -0.5)  # (B, H, N, M)
        weights = logits.softmax(dim=-1)
        attended = weights @ values  # (B, H, N, D)

        # Merge heads: (B, H, N, D) -> (B, N, C), then project out
        merged = attended.transpose(1, 2).contiguous().view(batch, n_queries, channels)
        return self.to_out(merged)
# ============================================================================
# Demonstration & Utilities
# ============================================================================
# Summary banner for the sampler classes implemented earlier in this notebook.
print("Advanced Diffusion Models Implemented:")
print("=" * 70)
print("1. DDIMSampler - Deterministic fast sampling (20-50 steps)")
print("2. ScoreBasedSDE - SDE/ODE formulation, Euler-Maruyama, probability flow")
print("3. ConditionalDiffusionModel - Classifier-free guidance")
print("4. DPMSolver - High-order ODE solver (10-20 steps)")
print("5. CrossAttention - Text conditioning for Stable Diffusion")
print("=" * 70)
# Example: DDIM vs DDPM sampling speed comparison
print("\nExample: Sampling Speed Comparison")
print("-" * 70)
# Setup
T = 1000  # number of DDPM diffusion steps
betas = torch.linspace(1e-4, 0.02, T)  # standard linear beta schedule (Ho et al., 2020 values)
class DummyModel(nn.Module):
    """Stand-in noise-prediction network used only for the speed comparison demo."""
    def __init__(self):
        super().__init__()
        # Placeholder layer so the module has parameters; never used in forward.
        self.fc = nn.Linear(784, 10)

    def forward(self, x, t):
        # Ignore both inputs and emit a random "noise prediction" of matching shape.
        return torch.randn_like(x)
# Instantiate the stand-in network for the timing/step-count comparison.
dummy_model = DummyModel()
# DDIM sampler (DDIMSampler is defined earlier in this notebook);
# eta=0.0 selects the fully deterministic DDIM variant.
ddim_sampler = DDIMSampler(dummy_model, betas, eta=0.0)
# Compare step counts of the different samplers (numbers only; no actual timing run)
print(f"DDPM steps: {T}")
print(f"DDIM steps: 50 (20Γ faster)")
print(f"DPM-Solver steps: 20 (50Γ faster)")
print("\nDDIM achieves similar quality with much fewer steps!")
print("\n" + "=" * 70)
print("Key Advantages of Advanced Methods:")
print("=" * 70)
print("1. DDIM (Ξ·=0): Deterministic, enables interpolation, 20Γ faster")
print("2. DDIM (Ξ·>0): Stochastic, quality-speed tradeoff")
print("3. Score-based ODE: Exact likelihood, faster adaptive solvers")
print("4. Classifier-free guidance: Better control, no separate classifier")
print("5. DPM-Solver: 10-20 steps, DDIM quality, high-order accuracy")
print("6. Latent diffusion: 64Γ compression, consumer GPU friendly")
print("=" * 70)
print("\n" + "=" * 70)
print("When to Use Each Method:")
print("=" * 70)
print("β’ DDPM: Highest quality, have compute budget, T=1000 steps OK")
print("β’ DDIM: Need speed, deterministic interpolation, 20-50 steps")
print("β’ Score ODE: Need likelihood, latent manipulation, flexible solvers")
print("β’ Classifier-free: Conditional generation, avoid classifier training")
print("β’ DPM-Solver: Best quality/speed tradeoff, 10-20 steps")
print("β’ Latent Diffusion: High resolution, limited VRAM, Stable Diffusion")
print("=" * 70)
# Visualization: which timesteps each DDIM schedule visits (uniform skipping)
fig, ax = plt.subplots(1, 1, figsize=(10, 4))
total_steps = 1000
ddim_steps = [50, 100, 200, 500]  # candidate DDIM step budgets to visualize
colors = ['red', 'orange', 'green', 'blue']
for steps, color in zip(ddim_steps, colors):
    # Uniformly spaced sub-schedule from t = total_steps-1 down to t = 0
    timesteps = np.linspace(total_steps-1, 0, steps).astype(int)
    ax.scatter(timesteps, [steps]*len(timesteps), alpha=0.6, s=20, color=color, label=f'{steps} steps')
ax.set_xlabel('Timestep t', fontsize=12)
ax.set_ylabel('Number of DDIM Steps', fontsize=12)
ax.set_title('DDIM Timestep Schedules (Skipping Timesteps)', fontsize=13)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\nVisualization shows how DDIM skips timesteps for acceleration!")
print("More steps = better quality, fewer steps = faster sampling")