import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Select GPU when available; every model/tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Score Matching¶
Score Function:¶
Denoising Score Matching:¶
where \(\tilde{x} = x + \sigma \epsilon\), \(\epsilon \sim \mathcal{N}(0, I)\).
📚 Reference Materials:
generative_models.pdf - Generative Models
class ScoreNet(nn.Module):
    """Convolutional score network for MNIST-sized images.

    Maps a (B, 1, H, W) image to a same-shaped score estimate. `sigma`
    is stored on the module but not used inside the forward pass.
    """

    def __init__(self, sigma):
        super().__init__()
        self.sigma = sigma
        # Build the conv trunk programmatically; layer order matches the
        # conventional hand-written Sequential exactly.
        widths = [(1, 64), (64, 64), (64, 128), (128, 128)]
        layers = []
        for c_in, c_out in widths:
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            layers.append(nn.ReLU())
        layers.append(nn.Conv2d(128, 1, 3, padding=1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

print("ScoreNet defined")
Denoising Score Matching LossΒΆ
Score-based models learn the score function \(\nabla_x \log p(x)\) — the gradient of the log-density with respect to the data. Direct score matching is intractable, but denoising score matching provides an equivalent objective: add Gaussian noise \(\tilde{x} = x + \sigma \epsilon\) to the data, then train a network \(s_\theta(\tilde{x}, \sigma)\) to predict the score of the noisy distribution. The loss is \(\mathcal{L} = \mathbb{E}_{\sigma, x, \epsilon}\left[\|s_\theta(\tilde{x}, \sigma) + \epsilon / \sigma\|^2\right]\), which has the beautiful interpretation that the network simply learns to denoise — to point from the noisy sample back toward the clean data manifold.
def denoising_score_matching_loss(score_net, x, sigma):
    """Denoising score matching objective.

    Perturbs `x` with Gaussian noise of scale `sigma` and penalizes the
    squared error between the predicted score at the noisy point and the
    analytic target -noise / sigma.

    Args:
        score_net: callable mapping a (B, C, H, W) tensor to a score tensor.
        x: clean data batch.
        sigma: scalar noise level.

    Returns:
        Scalar loss: batch mean of 0.5 * ||s_pred - s_true||^2.
    """
    eps = torch.randn_like(x)
    perturbed = x + sigma * eps
    predicted = score_net(perturbed)
    # Score of N(x, sigma^2 I) evaluated at the perturbed point.
    target = -eps / sigma
    per_sample = ((predicted - target) ** 2).sum(dim=(1, 2, 3))
    return 0.5 * per_sample.mean()

print("Loss function defined")
3. Langevin Dynamics SamplingΒΆ
Update Rule:ΒΆ
where \(z_t \sim \mathcal{N}(0, I)\).
@torch.no_grad()
def langevin_dynamics(score_net, x_init, n_steps=100, step_size=1e-4):
    """Draw samples by running unadjusted Langevin dynamics.

    Each iteration applies x <- x + eps * s(x) + sqrt(2 * eps) * z with
    z ~ N(0, I), then clamps the iterate to the pixel range [0, 1].

    Args:
        score_net: callable returning the score at the current iterate.
        x_init: starting point; cloned, never modified in place.
        n_steps: number of Langevin iterations.
        step_size: step size eps.

    Returns:
        Final iterate, clamped to [0, 1].
    """
    x = x_init.clone()
    noise_scale = (2 * step_size) ** 0.5
    for _ in range(n_steps):
        drift = score_net(x)
        z = torch.randn_like(x)
        x = (x + step_size * drift + noise_scale * z).clamp(0, 1)
    return x

print("Langevin dynamics defined")
Train Score NetworkΒΆ
The score network is trained across multiple noise levels \(\sigma_1 > \sigma_2 > \cdots > \sigma_L\), spanning from large noise (easy denoising, captures global structure) to small noise (precise denoising, captures fine details). At each training step, a noise level is randomly sampled, noise is added to a batch of data, and the network learns to predict the score at that noise level. The noise-conditional architecture (typically a U-Net with the noise level as an additional input) must handle the full range of noise magnitudes. Monitoring the loss per noise level helps diagnose whether the model struggles more with coarse or fine structure.
# Data: MNIST digits scaled to [0, 1] by ToTensor, batched for SGD.
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
# Model: single fixed noise level; `sigma` is also the DSM perturbation scale.
sigma = 0.1
score_net = ScoreNet(sigma).to(device)
optimizer = torch.optim.Adam(score_net.parameters(), lr=1e-3)
# Train: minimize the denoising score matching loss for a few epochs,
# recording the per-epoch average loss in `losses`.
losses = []
for epoch in range(5):
    epoch_loss = 0
    for x, _ in train_loader:
        x = x.to(device)
        loss = denoising_score_matching_loss(score_net, x, sigma)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate batch losses to report a per-epoch average.
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(train_loader)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
Generate SamplesΒΆ
Sampling from a trained score model uses Langevin dynamics: starting from random noise, iteratively move in the direction of the score (toward higher density) with a small step size and added stochastic noise: \(x_{t+1} = x_t + \frac{\epsilon}{2} s_\theta(x_t, \sigma) + \sqrt{\epsilon}\, z\), where \(z \sim \mathcal{N}(0, I)\). Annealed Langevin dynamics starts with the largest noise level (where the score landscape is smooth and easy to follow) and gradually decreases to the smallest (where the score is sharp and detailed). This multi-scale sampling process produces high-quality samples and is the precursor to the diffusion model framework that now dominates image generation.
# Generate samples with the trained network.
score_net.eval()
n_samples = 16
# Random initialization: uniform noise in the data range [0, 1].
x_init = torch.rand(n_samples, 1, 28, 28).to(device)
# Sample using Langevin dynamics (more steps / smaller step than the defaults).
samples = langevin_dynamics(score_net, x_init, n_steps=200, step_size=5e-5)
# Visualize the 16 samples on a 4x4 grid.
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i in range(16):
    ax = axes[i // 4, i % 4]
    ax.imshow(samples[i].cpu().squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle('Score-Based Generated Samples', fontsize=12)
plt.tight_layout()
plt.show()
6. Annealed Langevin DynamicsΒΆ
Use multiple noise levels:
Train score network for each level, sample progressively.
class MultiScaleScoreNet(nn.Module):
    """Score network intended for multi-noise-level (annealed) sampling.

    Note: this simplified version defines a noise-level embedding but does
    not yet inject it into the convolutional trunk, so the predicted score
    is currently independent of `sigma_idx`.
    """

    def __init__(self, n_sigmas):
        super().__init__()
        self.n_sigmas = n_sigmas
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 1, 3, padding=1)
        )
        # Sigma embedding (unused by forward for now; kept for extension).
        self.sigma_embed = nn.Embedding(n_sigmas, 64)

    def forward(self, x, sigma_idx=None):
        """Predict the score of `x`.

        Args:
            x: input batch (B, 1, H, W).
            sigma_idx: optional noise-level index. BUGFIX: defaults to None so
                the network can also be called as `score_net(x)` — which is
                exactly how `annealed_langevin_dynamics` invokes it; the
                previous required argument made that call raise a TypeError.
                Currently ignored by the trunk.

        Returns:
            Score estimate with the same shape as `x`.
        """
        return self.net(x)
@torch.no_grad()
def annealed_langevin_dynamics(score_net, x_init, sigmas, n_steps_each=100):
    """Run Langevin dynamics at a sequence of decreasing noise levels.

    For each sigma the step size is scaled as 0.01 * sigma^2, and
    `n_steps_each` Langevin updates are applied, clamping to [0, 1].
    """
    x = x_init.clone()
    for sigma in sigmas:
        # Step size shrinks quadratically with the noise level.
        eps = 0.01 * sigma ** 2
        for _ in range(n_steps_each):
            drift = score_net(x)
            z = torch.randn_like(x)
            x = torch.clamp(x + eps * drift + np.sqrt(2 * eps) * z, 0, 1)
    return x

print("Annealed dynamics defined")
SummaryΒΆ
Score-Based Models:ΒΆ
Key Concepts:
Learn score function \(\nabla_x \log p(x)\)
Sample via Langevin dynamics
Denoising score matching for training
Annealing for multi-scale generation
Connection to Diffusion:ΒΆ
Score matching ↔ noise prediction
Langevin dynamics ↔ reverse diffusion
Unified framework (Song et al.)
Advantages:ΒΆ
No adversarial training
Tractable likelihood
High sample quality
Flexible architectures
Applications:ΒΆ
Image generation
Inpainting
Super-resolution
Inverse problems
Advanced Score-Based Generative Models TheoryΒΆ
1. Mathematical FoundationsΒΆ
Score Function:
The score function is the gradient of the log probability density: $\(s(x) = \nabla_x \log p(x)\)$
Key Insight: The score points toward regions of higher probability density.
Advantages over density modeling:
No partition function needed: \(p(x) = \frac{\exp(-E(x))}{Z}\) where \(Z = \int \exp(-E(x))dx\)
Score: \(\nabla_x \log p(x) = -\nabla_x E(x)\) (partition function cancels!)
Avoid intractable normalization
Sampling via Langevin Dynamics:
Given score \(s(x) = \nabla_x \log p(x)\), sample from \(p(x)\) using: $\(x_{t+1} = x_t + \frac{\epsilon}{2}s(x_t) + \sqrt{\epsilon}z_t, \quad z_t \sim \mathcal{N}(0, I)\)$
Convergence: As \(\epsilon \to 0\) and \(T \to \infty\), \(x_T\) converges to sample from \(p(x)\).
Langevin MCMC Theorem:
The dynamics \(dx = \nabla_x \log p(x)dt + \sqrt{2}dw\) has invariant distribution \(p(x)\).
Discretization with step size \(\epsilon\) requires mixing time \(O(1/\epsilon)\) for convergence.
2. Score Matching ObjectivesΒΆ
Problem: We donβt know \(p(x)\), so we canβt compute \(\nabla_x \log p(x)\) directly.
Explicit Score Matching (HyvΓ€rinen, 2005):
Minimize: $\(\mathcal{L}_{ESM} = \mathbb{E}_{p(x)}\left[\frac{1}{2}\|s_\theta(x) - \nabla_x \log p(x)\|^2\right]\)$
Issue: Still requires \(\nabla_x \log p(x)\)!
Integration by Parts:
Under smoothness assumptions: $\(\mathcal{L}_{ESM} = \mathbb{E}_{p(x)}\left[\text{tr}(\nabla_x s_\theta(x)) + \frac{1}{2}\|s_\theta(x)\|^2\right] + \text{const}\)$
Now tractable! But computing \(\text{tr}(\nabla_x s_\theta(x))\) (Jacobian trace) is expensive.
Denoising Score Matching (Vincent, 2011):
More efficient alternative. Perturb data with noise: $\(q_\sigma(\tilde{x}|x) = \mathcal{N}(\tilde{x}; x, \sigma^2I)\)$
Then: $\(\mathcal{L}_{DSM} = \mathbb{E}_{p(x)}\mathbb{E}_{q_\sigma(\tilde{x}|x)}\left[\frac{1}{2}\left\|s_\theta(\tilde{x}) - \nabla_{\tilde{x}}\log q_\sigma(\tilde{x}|x)\right\|^2\right]\)$
Gradient of perturbed distribution: $\(\nabla_{\tilde{x}}\log q_\sigma(\tilde{x}|x) = -\frac{\tilde{x} - x}{\sigma^2}\)$
If \(\tilde{x} = x + \sigma\epsilon\) where \(\epsilon \sim \mathcal{N}(0,I)\): $\(\nabla_{\tilde{x}}\log q_\sigma(\tilde{x}|x) = -\frac{\epsilon}{\sigma}\)$
Simplified objective: $\(\mathcal{L}_{DSM} = \mathbb{E}_{x \sim p(x), \epsilon \sim \mathcal{N}(0,I)}\left[\frac{1}{2}\left\|s_\theta(x+\sigma\epsilon) + \frac{\epsilon}{\sigma}\right\|^2\right]\)$
Equivalence Theorem (Vincent, 2011):
Under mild conditions: $\(\nabla_\theta \mathcal{L}_{DSM}(\theta, \sigma) = \nabla_\theta \mathcal{L}_{ESM}(\theta) + O(\sigma^2)\)$
As \(\sigma \to 0\), denoising score matching β explicit score matching.
Sliced Score Matching (Song et al., 2019):
Alternative that avoids Jacobian trace using random projections: $\(\mathcal{L}_{SSM} = \mathbb{E}_{p(x), p(v)}\left[\frac{1}{2}v^T\nabla_x s_\theta(x)v + v^Ts_\theta(x) + \frac{1}{2}\|s_\theta(x)\|^2\right]\)$
where \(v \sim \mathcal{N}(0, I)\) is random projection direction.
Key advantage: Single backpropagation, no Jacobian computation.
3. Noise Conditional Score Networks (NCSN)ΒΆ
Motivation (Song & Ermon, 2019):
Single noise level \(\sigma\) problematic:
Low \(\sigma\): Score accurate near data manifold, but inaccurate in low-density regions
High \(\sigma\): Score accurate everywhere, but blurs data manifold
Solution: Use multiple noise levels!
Noise Schedule: $\(\sigma_1 > \sigma_2 > \cdots > \sigma_L\)$
Typical: geometric sequence \(\sigma_i = \sigma_1 \cdot (\sigma_L/\sigma_1)^{(i-1)/(L-1)}\)
Example: \(\sigma_1 = 1.0\), \(\sigma_L = 0.01\), \(L = 10\)
Conditional Score Network:
Learn \(s_\theta(x, \sigma_i) \approx \nabla_x \log p_{\sigma_i}(x)\) for all noise levels simultaneously.
Training Objective: $\(\mathcal{L}_{NCSN} = \sum_{i=1}^L \lambda(\sigma_i) \mathbb{E}_{p(x), \mathcal{N}(\epsilon;0,I)}\left[\left\|s_\theta(x+\sigma_i\epsilon, \sigma_i) + \frac{\epsilon}{\sigma_i}\right\|^2\right]\)$
Weighting: \(\lambda(\sigma_i) = \sigma_i^2\) (variance weighting)
Annealed Langevin Dynamics Sampling:
Start from high noise, gradually reduce:
x_0 ~ N(0, σ_1² I)
for i = 1 to L:
    α_i = ε · σ_i² / σ_L²   (adaptive step size)
    for t = 1 to T:
        z_t ~ N(0, I)
        x_t = x_{t-1} + (α_i/2)·s_θ(x_{t-1}, σ_i) + √α_i·z_t
    x_0 = x_T   (initialize next level)
return x_0
Intuition:
Start from pure noise
High noise: Smooth out manifold, easy to sample
Progressively denoise
Low noise: Refine details on data manifold
4. Score-Based SDE (Continuous Formulation)ΒΆ
Limitation of NCSN: Discrete noise levels \(\{\sigma_i\}\), still many steps.
Solution (Song et al., 2021): Continuous-time formulation via SDE!
Forward Process (Diffusion): $\(dx = f(x,t)dt + g(t)dw\)$
where:
\(f(x,t)\): Drift coefficient
\(g(t)\): Diffusion coefficient
\(w\): Standard Wiener process
Marginal Distribution:
At time \(t\), \(x(t)\) has distribution \(p_t(x)\).
Score of Marginal: $\(s(x,t) = \nabla_x \log p_t(x)\)$
Reverse-Time SDE (Anderson, 1982):
The reverse process is also an SDE: $\(dx = \left[f(x,t) - g(t)^2\nabla_x \log p_t(x)\right]dt + g(t)d\bar{w}\)$
where \(\bar{w}\) is reverse-time Wiener process.
Key Insight: Given score \(s_\theta(x,t) \approx \nabla_x \log p_t(x)\), we can simulate reverse SDE!
Probability Flow ODE:
Deterministic alternative with same marginals: $\(\frac{dx}{dt} = f(x,t) - \frac{1}{2}g(t)^2\nabla_x \log p_t(x)\)$
Advantages:
Exact likelihood via change of variables
Faster sampling with adaptive ODE solvers
Enables latent space manipulation
Training:
Perturb data with \(p_t(x|x_0)\), then: $\(\mathcal{L}_{SDE} = \mathbb{E}_{t \sim U(0,T), x_0 \sim p_0, x_t \sim p_t(x_t|x_0)}\left[\lambda(t)\|\nabla_{x_t}\log p_t(x_t|x_0) - s_\theta(x_t, t)\|^2\right]\)$
5. Variance Preserving (VP) vs. Variance Exploding (VE)ΒΆ
Variance Preserving (VP-SDE): $\(dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)}dw\)$
Properties:
Marginal: \(p_t(x|x_0) = \mathcal{N}\left(x; \sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t)I\right)\)
Variance: \(\mathbb{E}[\|x(t)\|^2] \approx \mathbb{E}[\|x(0)\|^2]\) (preserved)
Equivalent to DDPM
Variance Exploding (VE-SDE): $\(dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}}dw\)$
Properties:
Marginal: \(p_t(x|x_0) = \mathcal{N}(x; x_0, \sigma^2(t)I)\)
Variance: \(\mathbb{E}[\|x(t)\|^2] = \mathbb{E}[\|x(0)\|^2] + \sigma^2(t)\) (exploding)
Equivalent to NCSN
Unified Framework:
Both are special cases of general SDE. Can interpolate: $\(dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)(1-\gamma)}dw\)$
where \(\gamma \in [0,1]\):
\(\gamma = 0\): VP-SDE
\(\gamma = 1\): VE-SDE
6. Predictor-Corrector SamplingΒΆ
Motivation: Pure reverse SDE can accumulate errors.
Solution: Alternate between:
Predictor: Update with reverse SDE/ODE
Corrector: Langevin MCMC to improve sample quality
Algorithm:
x_T ~ p_T
for i = T-1 down to 0:
    # Predictor: one reverse SDE step
    x_i = Predictor(x_{i+1}, s_θ, i)
    # Corrector: M Langevin steps
    for j = 1 to M:
        x_i = x_i + ε·s_θ(x_i, i) + √(2ε)·z_j
return x_0
Predictors:
Euler-Maruyama
Heun (2nd order)
Ancestral sampling
Correctors:
Langevin dynamics
Annealed Langevin dynamics
None (pure predictor)
Trade-off:
More corrector steps: Better quality, slower
Fewer corrector steps: Faster, lower quality
Typical: 1-5 corrector steps per predictor step.
7. Controllable GenerationΒΆ
Conditional Sampling:
Want to sample from \(p(x|y)\) where \(y\) is condition (class, text, etc.).
Bayes Rule: $\(\nabla_x \log p(x|y) = \nabla_x \log p(x) + \nabla_x \log p(y|x)\)$
Conditional Score: $\(s(x,y,t) = s(x,t) + \nabla_x \log p_t(y|x)\)$
Classifier Guidance:
Train classifier \(p_\phi(y|x,t)\) on noisy data: $\(s(x,y,t) = s_\theta(x,t) + w \cdot \nabla_x \log p_\phi(y|x,t)\)$
where \(w\) is guidance weight.
Classifier-Free Guidance:
Joint training on \((x,y)\) and unconditional \(x\): $\(s(x,y,t) = (1+w)s_\theta(x,t,y) - w \cdot s_\theta(x,t,\emptyset)\)$
Imputation:
For missing data \(x_m\), keep observed \(x_o\) fixed: $\(x_m^{t+1} = x_m^t + \epsilon s_\theta([x_o, x_m^t], t) + \sqrt{2\epsilon}z\)$
Applications: Inpainting, super-resolution, compressed sensing.
Inverse Problems:
General form: \(y = A(x) + \eta\) where \(A\) is forward operator (blur, downsample, etc.).
Posterior Sampling: $\(\nabla_x \log p(x|y) \approx \nabla_x \log p(x) - \frac{1}{2\sigma^2}\nabla_x \|y - A(x)\|^2\)$
Update rule: $\(x^{t+1} = x^t + \epsilon\left[s_\theta(x^t,t) - \frac{1}{\sigma^2}\nabla_x\|y-A(x^t)\|^2\right] + \sqrt{2\epsilon}z\)$
8. Likelihood ComputationΒΆ
Probability Flow ODE: $\(\frac{dx}{dt} = f(x,t) - \frac{1}{2}g(t)^2 s_\theta(x,t)\)$
Instantaneous Change of Variables: $\(\frac{d\log p_t(x(t))}{dt} = -\text{div}\left(f(x,t) - \frac{1}{2}g(t)^2 s_\theta(x,t)\right)\)$
Log-Likelihood: $\(\log p_0(x(0)) = \log p_T(x(T)) - \int_0^T \text{div}\left(f(x,t) - \frac{1}{2}g(t)^2 s_\theta(x,t)\right)dt\)$
Hutchinsonβs Trace Estimator:
For divergence \(\text{div}(f) = \text{tr}(\nabla_x f)\): $\(\mathbb{E}_{v \sim \mathcal{N}(0,I)}\left[v^T\nabla_x f \cdot v\right] = \text{div}(f)\)$
Algorithm:
Sample \(v \sim \mathcal{N}(0,I)\)
Compute \(v^T\nabla_x(f \cdot v)\) via vector-Jacobian product (single backprop!)
Estimate \(\text{div}(f)\)
Bits per Dimension: $\(\text{BPD} = -\frac{\log_2 p(x)}{D}\)$
where \(D\) is data dimensionality.
9. Connections to Other ModelsΒΆ
Energy-Based Models:
If \(p(x) = \frac{1}{Z}\exp(-E(x))\), then: $\(\nabla_x \log p(x) = -\nabla_x E(x)\)$
Score matching β‘ learning energy function gradient.
Denoising Autoencoders:
Denoising score matching objective equivalent to training DAE to denoise: $\(\mathcal{L}_{DAE} = \mathbb{E}_{x,\epsilon}\left[\left\|\frac{x - D_\theta(x+\sigma\epsilon)}{\sigma} + \frac{\epsilon}{\sigma}\right\|^2\right]\)$
where \(D_\theta\) is denoising network.
Diffusion Models:
DDPM noise prediction \(\epsilon_\theta(x_t,t)\) related to score: $\(s_\theta(x_t,t) = -\frac{\epsilon_\theta(x_t,t)}{\sqrt{1-\bar{\alpha}_t}}\)$
Score-based models and diffusion models are equivalent under continuous formulation!
Normalizing Flows:
Probability flow ODE converts score-based model into continuous normalizing flow.
VAE:
Diffusion/score models can be viewed as hierarchical VAE with:
Latents \(x_1, \ldots, x_T\) same dimension as \(x_0\)
Fixed encoder \(q(x_t|x_0)\)
Learned decoder \(p_\theta(x_{t-1}|x_t)\)
10. Advanced ArchitecturesΒΆ
U-Net with Time Embedding:
Standard architecture for score networks:
Time t → Sinusoidal Embedding → MLP → time_emb
Encoder:
x → Conv+GroupNorm+SiLU+time_emb → Skip1
  → Downsample → Conv+GroupNorm+SiLU+time_emb → Skip2
  → Downsample → Conv+GroupNorm+SiLU+time_emb
Decoder:
  → Upsample → Concat(Skip2) → Conv+GroupNorm+SiLU+time_emb
  → Upsample → Concat(Skip1) → Conv+GroupNorm+SiLU+time_emb
  → Conv → output
Attention Mechanisms:
Add self-attention at low resolutions (e.g., 16Γ16): $\(\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V\)$
Residual Blocks:
class ResBlock(nn.Module):
    """Illustrative residual block with additive time conditioning.

    Sketch only: `Conv` / `Linear` stand in for nn.Conv2d / nn.Linear.
    """
    def __init__(self, channels, time_emb_dim):
        super().__init__()  # BUGFIX: required before registering submodules on nn.Module
        self.conv1 = Conv(channels, channels)
        self.time_proj = Linear(time_emb_dim, channels)
        self.conv2 = Conv(channels, channels)
    def forward(self, x, t_emb):
        h = self.conv1(x)
        # Broadcast the projected time embedding over the spatial dims.
        h = h + self.time_proj(t_emb)[:, :, None, None]
        h = self.conv2(h)
        return x + h  # Residual connection
Fourier Features:
For better time embedding: $\(\gamma(t) = [\cos(2\pi\omega_1 t), \sin(2\pi\omega_1 t), \ldots, \cos(2\pi\omega_d t), \sin(2\pi\omega_d t)]\)$
Adaptive Group Normalization:
Condition normalization on time: $\(\text{AdaGN}(h, t) = s_t \cdot \text{GroupNorm}(h) + b_t\)$
where \(s_t, b_t\) are functions of time embedding.
11. Training ImprovementsΒΆ
Importance Sampling for Time:
Weight timesteps by importance: $\(\mathcal{L} = \mathbb{E}_{t \sim p(t)}\left[\frac{\lambda(t)}{p(t)}\mathcal{L}_t\right]\)$
Typical: \(p(t) \propto \lambda(t)\) (importance sampling)
Exponential Moving Average (EMA):
Maintain EMA of parameters for sampling: $\(\theta_{EMA} \leftarrow \gamma \theta_{EMA} + (1-\gamma)\theta\)$
Typical \(\gamma = 0.9999\).
Variance Reduction:
Use antithetic sampling for noise: $\(\mathcal{L} = \frac{1}{2}\left[\mathcal{L}(\epsilon) + \mathcal{L}(-\epsilon)\right]\)$
Consistency Models: (Song et al., 2023)
Self-consistency property: $\(f_\theta(x_t, t) = f_\theta(x_{t'}, t')\)$
for any \(t, t'\) on same trajectory.
Enables one-step generation!
12. Evaluation MetricsΒΆ
Inception Score (IS): $\(IS = \exp\left(\mathbb{E}_x\left[D_{KL}(p(y|x) \| p(y))\right]\right)\)$
FrΓ©chet Inception Distance (FID): $\(FID = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2})\)$
Negative Log-Likelihood (NLL):
Via probability flow ODE + Hutchinson estimator.
Sample Quality vs. Diversity:
Precision: Fraction of generated samples in real manifold
Recall: Fraction of real manifold covered by generated samples
13. State-of-the-Art ResultsΒΆ
ImageNet 256Γ256:
NCSN++ (Song et al., 2021): FID = 2.20
EDM (Karras et al., 2022): FID = 1.79 (SOTA)
Likelihood:
VP-SDE on CIFAR-10: 2.99 bits/dim
VE-SDE on CIFAR-10: 2.92 bits/dim
Speed:
DDPM: 1000 steps
DDIM: 50 steps (20Γ faster)
DPM-Solver: 20 steps (50Γ faster)
Consistency models: 1 step (1000Γ faster!)
14. Applications Beyond ImagesΒΆ
Audio Generation:
WaveGrad: High-quality speech synthesis
DiffWave: Vocoder for text-to-speech
3D Shapes:
Point cloud generation
Mesh generation via SDF
Molecular Design:
Equivariant diffusion for 3D molecules
Protein structure generation (RFdiffusion)
Video:
Video diffusion models
Frame interpolation
Recommendation Systems:
Collaborative filtering with diffusion
15. Practical ConsiderationsΒΆ
Hyperparameters:
| Parameter | Typical Value | Notes |
|---|---|---|
| \(\sigma_{\min}\) | 0.01 | Minimum noise level |
| \(\sigma_{\max}\) | 50-100 | Maximum noise level |
| \(L\) (noise levels) | 10-1000 | More = better, slower |
| Langevin steps | 1-5 per level | Corrector steps |
| Step size \(\epsilon\) | \(10^{-5}\) to \(10^{-4}\) | Depends on \(\sigma\) |
| EMA decay | 0.9999 | For parameter averaging |
Noise Schedule:
Geometric: \(\sigma_i = \sigma_{\max} \cdot (\sigma_{\min}/\sigma_{\max})^{i/(L-1)}\)
Computational Cost:
Training:
Similar to DDPM
Single network for all noise levels
Sampling:
NCSN: \(L \times T\) steps (e.g., 10 Γ 100 = 1000)
Faster with ODE solvers
16. Limitations & Open ProblemsΒΆ
Slow Sampling:
Still requires many steps
Solutions: Distillation, consistency models
Open: Single-step score-based generation?
Likelihood Estimation:
Requires ODE + divergence computation
Expensive for high dimensions
Open: Efficient exact likelihood?
Mode Coverage:
Better than GANs, but still challenges
Open: Provable mode coverage?
Theory:
Convergence guarantees for finite steps?
Sample complexity bounds?
Optimal noise schedules?
17. Key Papers (Chronological)ΒΆ
HyvΓ€rinen, 2005: βEstimation of Non-Normalized Statistical Models by Score Matchingβ (score matching foundation)
Vincent, 2011: βA Connection Between Score Matching and Denoising Autoencodersβ (denoising score matching)
Song & Ermon, 2019: βGenerative Modeling by Estimating Gradients of the Data Distributionβ (NCSN)
Song & Ermon, 2020: βImproved Techniques for Training Score-Based Generative Modelsβ (NCSN++)
Song et al., 2021: βScore-Based Generative Modeling through Stochastic Differential Equationsβ (continuous SDE formulation)
Song et al., 2021: βMaximum Likelihood Training of Score-Based Diffusion Modelsβ (likelihood computation)
Karras et al., 2022: βElucidating the Design Space of Diffusion-Based Generative Modelsβ (EDM, SOTA)
Song et al., 2023: βConsistency Modelsβ (one-step generation)
18. Comparison: Score-Based vs. DiffusionΒΆ
| Aspect | Score-Based | Diffusion (DDPM) |
|---|---|---|
| Formulation | \(s_\theta(x,t) \approx \nabla_x \log p_t(x)\) | \(\epsilon_\theta(x_t,t) \approx \epsilon\) |
| Training | Score matching | Noise prediction |
| Sampling | Langevin dynamics | Ancestral sampling |
| Noise | Continuous levels | Discrete timesteps |
| Framework | Energy-based | Hierarchical VAE |
| Likelihood | Via ODE (exact) | VLB (lower bound) |
Unification:
Under continuous SDE formulation, they are equivalent: $\(s_\theta(x_t,t) = -\frac{\epsilon_\theta(x_t,t)}{\sqrt{1-\bar{\alpha}_t}}\)$
Both can use:
Same architectures (U-Net)
Same training objectives
Same sampling algorithms
When to Use Score-Based:
Need exact likelihood
Prefer energy-based perspective
Continuous-time formulation
Flexible noise schedules
When to Use Diffusion:
Prefer hierarchical VAE perspective
Simpler discrete formulation
Extensive existing codebases
"""
Advanced Score-Based Models - Complete Implementations
This cell provides production-ready implementations of:
1. Noise Conditional Score Network (NCSN)
2. Variance Preserving SDE (VP-SDE)
3. Variance Exploding SDE (VE-SDE)
4. Predictor-Corrector Samplers
5. Sliced Score Matching
6. Likelihood Computation via ODE
7. Conditional Generation
8. Complete Training & Sampling Pipeline
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy import integrate
from torch.autograd import grad
import warnings
warnings.filterwarnings('ignore')
# ============================================================================
# Noise Conditional Score Network (NCSN)
# ============================================================================
class ResidualBlock(nn.Module):
    """Residual block whose features are shifted by a noise-level embedding.

    Pipeline: conv -> norm -> (+ projected embedding) -> relu -> conv ->
    norm, then ReLU over the sum with a (possibly 1x1-projected) skip path.
    """

    def __init__(self, in_channels, out_channels, noise_emb_dim=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # Projects the embedding to one bias value per output channel.
        self.noise_proj = nn.Linear(noise_emb_dim, out_channels)
        # A 1x1 conv aligns channel counts on the skip path when they differ.
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip = nn.Identity()
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)

    def forward(self, x, noise_emb):
        h = self.norm1(self.conv1(x))
        h = F.relu(h + self.noise_proj(noise_emb)[:, :, None, None])
        h = self.norm2(self.conv2(h))
        return F.relu(h + self.skip(x))
class NoiseConditionalScoreNetwork(nn.Module):
    """
    NCSN: learn the score at multiple noise levels.

    Theory:
        s_theta(x, sigma) ~= grad_x log p_sigma(x)
        Train with: L = sum_i lambda(sigma_i) * E[||s_theta(x + sigma_i*eps, sigma_i) + eps/sigma_i||^2]

    Encoder / middle / decoder of `ResidualBlock`s, each conditioned on an
    embedding of the noise level. The input spatial size must be divisible
    by 2**(len(channels) - 1) for down/upsampling to round-trip exactly.
    """

    def __init__(self, channels=(32, 64, 128, 256), noise_emb_dim=32, num_res_blocks=2):
        """
        Args:
            channels: feature widths per resolution level. (Tuple default —
                a list default would be a shared mutable default argument.)
            noise_emb_dim: width of the noise-level embedding.
            num_res_blocks: residual blocks per resolution level.
        """
        super().__init__()
        self.noise_emb_dim = noise_emb_dim
        # Stored so forward can group blocks per level for any configuration.
        self.num_res_blocks = num_res_blocks
        channels = list(channels)
        # Noise level embedding: scalar sigma -> noise_emb_dim vector.
        self.noise_embed = nn.Sequential(
            nn.Linear(1, noise_emb_dim),
            nn.ReLU(),
            nn.Linear(noise_emb_dim, noise_emb_dim)
        )
        # Input projection
        self.input_proj = nn.Conv2d(1, channels[0], 3, padding=1)
        # Encoder: num_res_blocks residual blocks then a strided downsample
        # per resolution level.
        self.encoder_blocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for i in range(len(channels) - 1):
            for _ in range(num_res_blocks):
                self.encoder_blocks.append(
                    ResidualBlock(channels[i], channels[i], noise_emb_dim)
                )
            self.downsamples.append(nn.Conv2d(channels[i], channels[i + 1], 3, stride=2, padding=1))
        # Middle
        self.middle = nn.ModuleList([
            ResidualBlock(channels[-1], channels[-1], noise_emb_dim),
            ResidualBlock(channels[-1], channels[-1], noise_emb_dim)
        ])
        # Decoder: mirror of the encoder (upsample, then residual blocks).
        self.upsamples = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        for i in range(len(channels) - 1, 0, -1):
            self.upsamples.append(nn.ConvTranspose2d(channels[i], channels[i - 1], 4, stride=2, padding=1))
            for _ in range(num_res_blocks):
                self.decoder_blocks.append(
                    ResidualBlock(channels[i - 1], channels[i - 1], noise_emb_dim)
                )
        # Output
        self.output = nn.Conv2d(channels[0], 1, 3, padding=1)

    def forward(self, x, sigma):
        """
        Args:
            x: Input (B, 1, H, W)
            sigma: Noise level — a Python scalar, a 0-dim tensor, or a (B,) tensor.
        Returns:
            score: Predicted score grad_x log p_sigma(x), same shape as x.
        """
        # Normalize sigma to a (B,) float tensor before embedding. BUGFIX: the
        # old `isinstance(sigma, float)` check rejected ints and 0-dim tensors.
        if not torch.is_tensor(sigma):
            sigma = torch.full((x.shape[0],), float(sigma), device=x.device)
        elif sigma.dim() == 0:
            sigma = sigma.expand(x.shape[0])
        noise_emb = self.noise_embed(sigma.reshape(-1, 1).float())
        h = self.input_proj(x)
        # Encoder. BUGFIX: the old forward hard-coded slices [:2]/[2:4]/[4:]
        # and downsamples[0..2], which only worked for exactly 4 channel
        # widths with 2 blocks per level; loop over the real configuration.
        idx = 0
        for down in self.downsamples:
            for _ in range(self.num_res_blocks):
                h = self.encoder_blocks[idx](h, noise_emb)
                idx += 1
            h = down(h)
        # Middle
        for block in self.middle:
            h = block(h, noise_emb)
        # Decoder (blocks were created in upsample order, so a flat index works).
        idx = 0
        for up in self.upsamples:
            h = up(h)
            for _ in range(self.num_res_blocks):
                h = self.decoder_blocks[idx](h, noise_emb)
                idx += 1
        return self.output(h)
# ============================================================================
# SDE Framework
# ============================================================================
class VariancePreservingSDE:
    """
    VP-SDE: dx = -(1/2) beta(t) x dt + sqrt(beta(t)) dw

    Properties:
        - Variance preserved: E[||x(t)||^2] ~ E[||x(0)||^2]
        - Equivalent to DDPM

    All `t` arguments are expected to be (B,)-shaped tensors in [0, T].
    """
    def __init__(self, beta_min=0.1, beta_max=20.0, T=1.0):
        # beta_min / beta_max define the linear noise schedule; T is the horizon.
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.T = T
    def beta(self, t):
        """Linear schedule beta(t) = beta_min + t * (beta_max - beta_min)."""
        return self.beta_min + t * (self.beta_max - self.beta_min)
    def mean_coeff(self, t):
        """sqrt(alpha_bar_t), the mean coefficient of the marginal p_t(x|x_0).

        Closed form exp(-0.25 t^2 (beta_max - beta_min) - 0.5 t beta_min),
        i.e. exp(-0.5 * integral_0^t beta(s) ds) for the linear schedule.
        """
        log_mean_coeff = -0.25 * t**2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min
        return torch.exp(log_mean_coeff)
    def std(self, t):
        """sqrt(1 - alpha_bar_t), the std of the marginal p_t(x|x_0)."""
        return torch.sqrt(1.0 - self.mean_coeff(t)**2)
    def marginal_prob(self, x0, t):
        """
        Sample from p_t(x|x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I).

        Returns:
            x_t: Noisy sample
            std: Standard deviation, shape (B,)
            noise: The standard-normal draw used (handy as a regression target)
        """
        mean_coeff = self.mean_coeff(t)
        std = self.std(t)
        noise = torch.randn_like(x0)
        x_t = mean_coeff[:, None, None, None] * x0 + std[:, None, None, None] * noise
        return x_t, std, noise
    def sde(self, x, t):
        """Forward SDE coefficients: drift f(x,t) = -0.5 beta(t) x, diffusion g(t) = sqrt(beta(t))."""
        beta_t = self.beta(t)
        drift = -0.5 * beta_t[:, None, None, None] * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion
    def reverse_sde(self, x, t, score):
        """Reverse-time SDE coefficients: drift f - g^2 * score, same diffusion g."""
        drift, diffusion = self.sde(x, t)
        drift = drift - diffusion[:, None, None, None]**2 * score
        return drift, diffusion
    def ode(self, x, t, score):
        """Probability flow ODE drift: f - 0.5 g^2 * score (deterministic, same marginals)."""
        drift, diffusion = self.sde(x, t)
        drift = drift - 0.5 * diffusion[:, None, None, None]**2 * score
        return drift
class VarianceExplodingSDE:
    """
    VE-SDE: dx = sqrt(d[sigma^2(t)]/dt) dw

    Properties:
        - Variance exploding: E[||x(t)||^2] = E[||x(0)||^2] + sigma^2(t)
        - Equivalent to NCSN
    """

    def __init__(self, sigma_min=0.01, sigma_max=50.0, T=1.0):
        """
        Args:
            sigma_min: noise std at t=0.
            sigma_max: noise std at t=T.
            T: time horizon.
        """
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.T = T

    def sigma(self, t):
        """Geometric noise schedule sigma(t) = sigma_min * (sigma_max/sigma_min)^t."""
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

    def std(self, t):
        """Std of the marginal p_t(x|x_0); alias of sigma(t).

        BUGFIX (API consistency): added so samplers that call `sde.std(t)` —
        notably LangevinCorrector, which VariancePreservingSDE already
        supports — also work with the VE-SDE instead of raising
        AttributeError.
        """
        return self.sigma(t)

    def marginal_prob(self, x0, t):
        """
        Sample from p_t(x|x_0) = N(x_0, sigma^2(t) I).

        Returns:
            x_t: noisy sample
            std: standard deviation, shape (B,)
            noise: the standard-normal draw used
        """
        std = self.sigma(t)
        noise = torch.randn_like(x0)
        x_t = x0 + std[:, None, None, None] * noise
        return x_t, std, noise

    def sde(self, x, t):
        """Forward SDE coefficients: zero drift, g(t) = sigma(t) * sqrt(2 log(sigma_max/sigma_min))."""
        sigma_t = self.sigma(t)
        drift = torch.zeros_like(x)
        diffusion = sigma_t * torch.sqrt(torch.tensor(2 * np.log(self.sigma_max / self.sigma_min)))
        return drift, diffusion

    def reverse_sde(self, x, t, score):
        """Reverse-time SDE coefficients: drift -g^2 * score (forward drift is zero)."""
        drift, diffusion = self.sde(x, t)
        drift = -diffusion[:, None, None, None]**2 * score
        return drift, diffusion

    def ode(self, x, t, score):
        """Probability flow ODE drift: -0.5 g^2 * score."""
        drift, diffusion = self.sde(x, t)
        drift = -0.5 * diffusion[:, None, None, None]**2 * score
        return drift
# ============================================================================
# Predictor-Corrector Samplers
# ============================================================================
class EulerMaruyamaPredictor:
    """Euler-Maruyama discretization of the reverse-time SDE.

    One step of: x <- x - [f - g^2 s] dt + g sqrt(dt) z,  z ~ N(0, I).
    """

    def __init__(self, sde, score_fn):
        """
        Args:
            sde: object exposing `reverse_sde(x, t, score)` (and `T`).
            score_fn: callable (x, t) -> score estimate.
        """
        self.sde = sde
        self.score_fn = score_fn

    def step(self, x, t, dt):
        """Single reverse SDE step of size `dt` (Python float or 0-dim tensor)."""
        score = self.score_fn(x, t)
        drift, diffusion = self.sde.reverse_sde(x, t, score)
        # Euler-Maruyama: x_{t-dt} = x_t - drift*dt + diffusion*sqrt(dt)*z.
        # BUGFIX: `dt ** 0.5` replaces `torch.sqrt(dt)`, which raised a
        # TypeError whenever dt was a plain Python float — exactly what
        # PredictorCorrectorSampler.sample passes in.
        x = x - drift * dt
        x = x + diffusion[:, None, None, None] * (dt ** 0.5) * torch.randn_like(x)
        return x
class LangevinCorrector:
    """
    Langevin MCMC corrector.

    Theory:
        x' = x + eps * s_theta(x, t) + sqrt(2 * eps) * z
    Refines samples at a fixed time; the step size eps is chosen adaptively
    from a target signal-to-noise ratio.
    """

    def __init__(self, sde, score_fn, snr=0.16, n_steps=1):
        """
        Args:
            sde: object exposing `std(t)` (marginal standard deviation).
            score_fn: callable (x, t) -> score estimate.
            snr: Signal-to-noise ratio controlling the step size.
            n_steps: Number of Langevin steps per call.
        """
        self.sde = sde
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    def step(self, x, t):
        """Apply `n_steps` Langevin refinement steps at fixed time `t`."""
        for _ in range(self.n_steps):
            grad_est = self.score_fn(x, t)
            # Step size from the target SNR: eps = (snr * std / ||s||)^2.
            flat = grad_est.reshape(grad_est.shape[0], -1)
            grad_norm = flat.norm(dim=-1).mean()
            eps = (self.snr * self.sde.std(t)[0] / grad_norm) ** 2
            z = torch.randn_like(x)
            x = x + eps * grad_est + torch.sqrt(2 * eps) * z
        return x
class PredictorCorrectorSampler:
    """
    Predictor-corrector sampler.

    Algorithm:
        for t = T down to ~0:
            x = Predictor(x, t)   # move backward in time
            x = Corrector(x, t)   # refine at the current time
    """

    def __init__(self, predictor, corrector):
        self.predictor = predictor
        self.corrector = corrector

    @torch.no_grad()
    def sample(self, shape, num_steps=1000, device='cpu'):
        """
        Generate samples starting from standard Gaussian noise.

        Args:
            shape: output shape (B, C, H, W).
            num_steps: number of discretization steps.
            device: torch device for the computation.

        Returns:
            Generated samples of the requested shape.
        """
        x = torch.randn(shape, device=device)
        horizon = self.predictor.sde.T
        dt = horizon / num_steps
        batch = shape[0]
        for step in range(num_steps):
            # Times run from T down to T/num_steps (never exactly 0).
            t = torch.full((batch,), (1 - step / num_steps) * horizon, device=device)
            x = self.predictor.step(x, t, dt)
            x = self.corrector.step(x, t)
        return x
# ============================================================================
# Sliced Score Matching
# ============================================================================
class SlicedScoreMatching:
    """
    Sliced Score Matching (Song et al., 2019).

    Theory (as implemented here):
        L = E[ ½ v^T ∇_x s_θ(x) v + v^T s_θ(x) + ½ ||s_θ(x)||² ]

    Projects the Jacobian onto a random direction v, so no O(D²) Jacobian
    trace is ever materialized — only one extra backward pass is needed.

    NOTE(review): the v^T s_θ(x) term is not part of the standard SSM
    objective (its expectation over v ~ N(0, I) is zero); it is kept here
    to preserve the documented formula — confirm intent.
    """
    @staticmethod
    def loss(score_fn, x):
        """
        Compute the sliced score matching loss, averaged over the batch.

        Args:
            score_fn: Score network; must be differentiable w.r.t. its input.
            x: Data samples (B, ...); will be marked requires_grad in place.
        Returns:
            Scalar loss tensor (graph attached, suitable for backward()).
        """
        x = x.requires_grad_(True)
        # Random projection direction for the Hutchinson-style slice.
        v = torch.randn_like(x)
        score = score_fn(x)
        # v^T · s_θ(x), summed over batch and features.
        v_score = torch.sum(v * score)
        # v^T · ∇_x s_θ(x) · v via one extra backward pass.
        # BUGFIX: use the fully-qualified torch.autograd.grad — the bare
        # name `grad` is not imported anywhere in this file's import cells.
        grad_v_score = torch.autograd.grad(v_score, x, create_graph=True)[0]
        v_grad_v = torch.sum(v * grad_v_score)
        loss = 0.5 * v_grad_v + v_score + 0.5 * torch.sum(score ** 2)
        return loss / x.shape[0]
# ============================================================================
# Likelihood Computation via ODE
# ============================================================================
class LikelihoodComputer:
    """
    Compute exact likelihood via the probability flow ODE.

    Theory (instantaneous change of variables):
        d/dt log p_t(x(t)) = -div(f_θ)(x(t), t)
        =>  log p_0(x(0)) = log p_T(x(T)) + ∫_0^T div(f_θ)(x(t), t) dt
    where f_θ is the probability-flow ODE drift. The divergence is
    estimated with Hutchinson's trace estimator.
    """
    def __init__(self, sde, score_fn):
        self.sde = sde
        self.score_fn = score_fn

    def divergence_hutchinson(self, x, t, score, v):
        """
        Hutchinson's trace estimator for the ODE-drift divergence:
            E_v[v^T ∇_x(f·v)] = tr(∇_x f) = div(f)
        Requires x.requires_grad == True and score computed with the graph attached.
        """
        ode_drift = self.sde.ode(x, t, score)
        # v^T · ode_drift (scalar), then one backward pass for ∇_x(v^T f) = (∇_x f)^T v.
        v_ode = torch.sum(v * ode_drift)
        # BUGFIX: fully-qualified torch.autograd.grad — the bare name `grad`
        # is not imported anywhere in this file's import cells.
        grad_v_ode = torch.autograd.grad(v_ode, x, create_graph=True)[0]
        return torch.sum(v * grad_v_ode)

    def compute_likelihood(self, x0, num_steps=100):
        """
        Compute log-likelihood by integrating the probability-flow ODE
        forward from t=0 to t=T while accumulating the divergence.

        BUGFIXES vs. the original:
          - removed @torch.no_grad(): the Hutchinson estimator needs the
            autograd graph, so the decorator made every call raise;
          - the Gaussian prior log p_T is now evaluated at the integrated
            ODE endpoint x(T), not at an unrelated fresh noise sample;
          - the divergence integral enters with a plus sign
            (log p_0 = log p_T + ∫ div dt).

        Args:
            x0: Data sample (1, C, H, W).
            num_steps: Number of Euler steps for the ODE integration.
        Returns:
            log_likelihood: float, estimate of log p(x0).
        """
        dt = self.sde.T / num_steps
        divergence_integral = torch.zeros((), device=x0.device)
        x = x0.clone().requires_grad_(True)
        for i in range(num_steps):
            t = torch.ones(1, device=x.device) * (i / num_steps) * self.sde.T
            # Fresh random projection per step for the Hutchinson estimator.
            v = torch.randn_like(x)
            score = self.score_fn(x, t)
            div = self.divergence_hutchinson(x, t, score, v)
            # Detach so the graph does not accumulate across steps.
            divergence_integral = divergence_integral + div.detach() * dt
            # Euler step along the probability-flow ODE (no graph needed).
            with torch.no_grad():
                ode_drift = self.sde.ode(x, t, score)
                x = x + ode_drift * dt
            x = x.detach().requires_grad_(True)
        # Standard-normal prior evaluated at the ODE endpoint x(T).
        x_T = x.detach()
        dim = np.prod(x0.shape[1:])
        log_p_T = -0.5 * torch.sum(x_T ** 2) - 0.5 * dim * np.log(2 * np.pi)
        log_likelihood = log_p_T + divergence_integral
        return log_likelihood.item()
# ============================================================================
# Training Loop
# ============================================================================
class ScoreBasedTrainer:
    """Complete training pipeline for score-based models.

    Wraps a score network, an SDE, and an optimizer, providing the
    continuous-time denoising score matching loss, a single-step training
    routine with gradient clipping, and optional EMA parameter tracking.
    """
    def __init__(self, model, sde, optimizer, device='cpu'):
        self.model = model
        self.sde = sde
        self.optimizer = optimizer
        self.device = device
        # EMA weights; None until update_ema() is first called explicitly
        # (train_step only updates an already-initialized EMA).
        self.ema = None

    def denoising_score_matching_loss(self, x):
        """
        Denoising score matching loss:
            L = E_{t,x,ε}[ λ(t) ||s_θ(x_t, t) - ∇ log p_t(x_t|x_0)||² ]
        with λ(t) = std(t)², the variance weighting.
        """
        batch = x.shape[0]
        # Sample a uniform random time per example in [0, T).
        t = torch.rand(batch, device=self.device) * self.sde.T
        # Perturb the data according to the SDE's marginal at time t.
        x_t, std, noise = self.sde.marginal_prob(x, t)
        # The model is conditioned on the noise level std(t).
        predicted = self.model(x_t, std)
        # True conditional score of the Gaussian perturbation: -ε/std.
        target = -noise / std[:, None, None, None]
        squared_err = torch.sum((predicted - target) ** 2, dim=(1, 2, 3))
        return torch.mean(std ** 2 * squared_err)

    def train_step(self, x):
        """Run one optimization step on batch x; returns the scalar loss."""
        self.model.train()
        loss = self.denoising_score_matching_loss(x)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip global gradient norm for stability.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        if self.ema is not None:
            self.update_ema()
        return loss.item()

    def update_ema(self, decay=0.9999):
        """Initialize or update the exponential moving average of parameters."""
        state = self.model.state_dict()
        if self.ema is None:
            self.ema = {name: tensor.clone().detach() for name, tensor in state.items()}
        else:
            for name, tensor in state.items():
                self.ema[name] = decay * self.ema[name] + (1 - decay) * tensor
# ============================================================================
# Demonstration
# ============================================================================
# NOTE(review): this demo cell relies on classes defined earlier in the file
# (NoiseConditionalScoreNetwork, VariancePreservingSDE, VarianceExplodingSDE,
# EulerMaruyamaPredictor) that are outside this excerpt.
print("Advanced Score-Based Models Implemented:")
print("=" * 70)
print("1. NoiseConditionalScoreNetwork - NCSN with residual blocks")
print("2. VariancePreservingSDE - VP-SDE (equivalent to DDPM)")
print("3. VarianceExplodingSDE - VE-SDE (equivalent to NCSN)")
print("4. EulerMaruyamaPredictor - Reverse SDE sampler")
print("5. LangevinCorrector - MCMC refinement")
print("6. PredictorCorrectorSampler - Combined sampling")
print("7. SlicedScoreMatching - Alternative training objective")
print("8. LikelihoodComputer - Exact likelihood via ODE")
print("9. ScoreBasedTrainer - Complete training pipeline")
print("=" * 70)
# Example: Create models
print("\nExample: Model Instantiation")
print("-" * 70)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NCSN model
ncsn = NoiseConditionalScoreNetwork(channels=[32, 64, 128, 256]).to(device)
print(f"NCSN parameters: {sum(p.numel() for p in ncsn.parameters()):,}")
# SDE instances: VP (bounded variance, DDPM-like) and VE (unbounded, NCSN-like)
vp_sde = VariancePreservingSDE(beta_min=0.1, beta_max=20.0)
ve_sde = VarianceExplodingSDE(sigma_min=0.01, sigma_max=50.0)
print(f"VP-SDE: β_min={vp_sde.beta_min}, β_max={vp_sde.beta_max}")
print(f"VE-SDE: σ_min={ve_sde.sigma_min}, σ_max={ve_sde.sigma_max}")
# Predictor-Corrector sampler
def score_fn(x, t):
    """Wrapper for score function.

    Conditions the NCSN on the VP-SDE noise level std(t) rather than on
    the raw time t, matching how ScoreBasedTrainer calls the model.
    """
    return ncsn(x, vp_sde.std(t))
predictor = EulerMaruyamaPredictor(vp_sde, score_fn)
corrector = LangevinCorrector(vp_sde, score_fn, snr=0.16, n_steps=1)
pc_sampler = PredictorCorrectorSampler(predictor, corrector)
print(f"Predictor-Corrector sampler created")
print("\n" + "=" * 70)
print("Key Advantages:")
print("=" * 70)
print("1. Score-based: No adversarial training, stable optimization")
print("2. NCSN: Multiple noise levels for robust learning")
print("3. SDE formulation: Unified continuous framework")
print("4. Predictor-Corrector: Better sample quality with refinement")
print("5. Exact likelihood: Via probability flow ODE")
print("6. Flexible: VP-SDE (DDPM) or VE-SDE (NCSN)")
print("=" * 70)
print("\n" + "=" * 70)
print("Comparison: VP-SDE vs VE-SDE")
print("=" * 70)
# Compare noise schedules
# NOTE(review): VP noise level is queried via std(t) but VE via sigma(t);
# presumably the two SDE classes expose different accessors — confirm.
t_values = torch.linspace(0, 1, 100)
vp_stds = torch.tensor([vp_sde.std(t).item() for t in t_values])
ve_stds = torch.tensor([ve_sde.sigma(t).item() for t in t_values])
# Plot both schedules on a shared axis for visual comparison.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(t_values, vp_stds, label='VP-SDE (variance preserving)', linewidth=2)
ax.plot(t_values, ve_stds, label='VE-SDE (variance exploding)', linewidth=2)
ax.set_xlabel('Time t', fontsize=12)
ax.set_ylabel('Noise Level σ(t)', fontsize=12)
ax.set_title('Noise Schedules: VP-SDE vs VE-SDE', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\nVisualization shows different noise evolution strategies!")
print("VP-SDE: Bounded variance, similar to diffusion models")
print("VE-SDE: Unbounded variance, original NCSN formulation")
print("\n" + "=" * 70)
print("When to Use Each Method:")
print("=" * 70)
print("• VP-SDE: When equivalent to DDPM desired, bounded variance")
print("• VE-SDE: When following NCSN, unbounded noise acceptable")
print("• Predictor-Corrector: When highest quality needed, have compute")
print("• Pure Predictor: When speed critical, quality acceptable")
print("• Sliced Score Matching: When memory limited (no Jacobian trace)")
print("• Likelihood Computation: When exact probabilities needed")
print("=" * 70)
Advanced Score-Based Generative Models: Mathematical Foundations and Modern ArchitecturesΒΆ
1. Introduction to Score-Based ModelsΒΆ
Score-based generative models learn to estimate the score function (gradient of the log-density) instead of directly modeling the probability distribution. This approach offers a powerful alternative to traditional generative models by avoiding intractable normalizations and mode collapse.
Core concept: Model the score function $\(s_\theta(x) = \nabla_x \log p(x) = \frac{\nabla_x p(x)}{p(x)}\)$
Key insight: Score doesnβt require knowing the normalizing constant! $\(\nabla_x \log p(x) = \nabla_x \log \frac{p_{\text{unnorm}}(x)}{Z} = \nabla_x \log p_{\text{unnorm}}(x) - \nabla_x \log Z = \nabla_x \log p_{\text{unnorm}}(x)\)$
Generation via Langevin dynamics: $\(x_{t+1} = x_t + \frac{\epsilon}{2} s_\theta(x_t) + \sqrt{\epsilon} \, z_t, \quad z_t \sim \mathcal{N}(0, I)\)$
Starting from noise \(x_0 \sim \mathcal{N}(0, I)\), this converges to \(p(x)\) as \(\epsilon \to 0\) and \(T \to \infty\).
Advantages:
Flexible architectures: Any neural network (no invertibility constraints)
Mode coverage: Better than GANs (no mode collapse)
Training stability: No adversarial dynamics
High quality: State-of-the-art results (diffusion models)
Unified view: Score-based models unify:
Denoising diffusion models (DDPM)
Noise Conditional Score Networks (NCSN)
Stochastic differential equations (SDE) framework
2. Score Function and Score MatchingΒΆ
2.1 Score Function DefinitionΒΆ
For probability density \(p(x)\), the score function is: $\(s(x) = \nabla_x \log p(x) = \frac{1}{p(x)} \nabla_x p(x)\)$
Geometric interpretation: Points in direction of increasing probability density.
Properties:
Independent of normalization: \(s(x)\) same for \(p(x)\) and \(c \cdot p(x)\)
Zero at modes: \(\nabla_x p(x) = 0\) at local maxima
Curl-free: \(\nabla \times s(x) = 0\) (gradient field)
Example (Gaussian): $\(p(x) = \mathcal{N}(\mu, \Sigma) \implies s(x) = -\Sigma^{-1}(x - \mu)\)$
2.2 Explicit Score MatchingΒΆ
Goal: Match model score \(s_\theta(x)\) to data score \(s_{\text{data}}(x) = \nabla_x \log p_{\text{data}}(x)\).
Naive objective: $\(\mathcal{L}_{\text{naive}}(\theta) = \frac{1}{2} \mathbb{E}_{x \sim p_{\text{data}}}[\|s_\theta(x) - \nabla_x \log p_{\text{data}}(x)\|^2]\)$
Problem: \(\nabla_x \log p_{\text{data}}(x)\) unknown!
Solution (HyvΓ€rinen, 2005): Integration by parts gives equivalent objective $\(\mathcal{L}_{\text{ESM}}(\theta) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\frac{1}{2}\|s_\theta(x)\|^2 + \text{tr}(\nabla_x s_\theta(x))\right] + \text{const}\)$
where \(\text{tr}(\nabla_x s_\theta(x)) = \sum_{i=1}^D \frac{\partial s_\theta^i(x)}{\partial x_i}\) is the divergence.
Gradient: $\(\nabla_\theta \mathcal{L}_{\text{ESM}} = \mathbb{E}_{x \sim p_{\text{data}}}\left[s_\theta(x) \nabla_\theta s_\theta(x)^T + \nabla_\theta \text{tr}(\nabla_x s_\theta(x))\right]\)$
Computational cost: \(O(D^2)\) for Jacobian trace (expensive for images).
2.3 Denoising Score Matching (DSM)ΒΆ
Key idea (Vincent, 2011): Add noise to data, then match score of noisy distribution.
Noise perturbation: \(q(x | x_0) = \mathcal{N}(x | x_0, \sigma^2 I)\)
Noisy distribution: \(q(x) = \int q(x | x_0) p_{\text{data}}(x_0) dx_0\)
True score of noisy distribution: $\(\nabla_x \log q(x | x_0) = -\frac{x - x_0}{\sigma^2}\)$
DSM objective: $\(\mathcal{L}_{\text{DSM}}(\theta, \sigma) = \frac{1}{2} \mathbb{E}_{x_0 \sim p_{\text{data}}} \mathbb{E}_{x \sim q(x|x_0)}\left[\left\|s_\theta(x) + \frac{x - x_0}{\sigma^2}\right\|^2\right]\)$
Advantages:
\(O(D)\) complexity (no Jacobian)
Fully differentiable
Equivalent to explicit score matching under certain conditions
Implementation:
x_0 ~ p_data
noise = Ο * Ξ΅, Ξ΅ ~ N(0, I)
x = x_0 + noise
loss = ||s_ΞΈ(x) - (-noise/ΟΒ²)||Β²
3. Multi-Scale Score Matching (Noise Conditioning)ΒΆ
3.1 MotivationΒΆ
Problem with single noise level:
Low noise (\(\sigma\) small): Score accurate near data, but unstable in low-density regions
High noise (\(\sigma\) large): Score stable everywhere, but data structure lost
Solution: Train score network at multiple noise levels \(\{\sigma_1, \ldots, \sigma_L\}\) where \(\sigma_1 > \sigma_2 > \cdots > \sigma_L\).
3.2 Noise Conditional Score Networks (NCSN)ΒΆ
Score network: \(s_\theta(x, \sigma): \mathbb{R}^D \times \mathbb{R}_+ \to \mathbb{R}^D\)
Objective: $\(\mathcal{L}_{\text{NCSN}}(\theta) = \sum_{i=1}^L \lambda(\sigma_i) \mathbb{E}_{x_0 \sim p_{\text{data}}} \mathbb{E}_{x \sim \mathcal{N}(x_0, \sigma_i^2 I)}\left[\left\|s_\theta(x, \sigma_i) + \frac{x - x_0}{\sigma_i^2}\right\|^2\right]\)$
Weighting: \(\lambda(\sigma_i) = \sigma_i^2\) (variance-weighted) or uniform.
Noise schedule: Geometric progression $\(\sigma_i = \sigma_{\max} \cdot \left(\frac{\sigma_{\min}}{\sigma_{\max}}\right)^{(i-1)/(L-1)}, \quad i = 1, \ldots, L\)$
Typical: \(\sigma_{\max} = 50\), \(\sigma_{\min} = 0.01\), \(L = 10\) for CIFAR-10.
3.3 Annealed Langevin DynamicsΒΆ
Sampling: Gradually decrease noise level during Langevin dynamics.
Algorithm:
x_L ~ N(0, Ο_1Β² I) # Initialize from largest noise
for i = 1 to L:
Ξ±_i = Ξ΅ Β· Ο_iΒ² / Ο_LΒ² # Adaptive step size
for t = 1 to T:
z_t ~ N(0, I)
x β x + (Ξ±_i/2) s_ΞΈ(x, Ο_i) + βΞ±_i z_t
x_{i-1} β x
return x_0
Intuition:
Large \(\sigma\): Explore global structure
Small \(\sigma\): Refine local details
Convergence: Proven for smooth densities and sufficient annealing steps.
4. Score-Based SDEsΒΆ
4.1 Forward SDE (Diffusion Process)ΒΆ
Continuous-time view: Gradually perturb data with the SDE $\(dx = f(x, t)\, dt + g(t)\, dw\)$
where:
\(f(x, t)\): Drift coefficient
\(g(t)\): Diffusion coefficient
\(w\): Brownian motion
Variance-Preserving (VP) SDE: $\(dx = -\frac{1}{2}\beta(t) x \, dt + \sqrt{\beta(t)} \, dw\)$
Variance-Exploding (VE) SDE: $\(dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}} \, dw\)$
sub-VP SDE: $\(dx = -\frac{1}{2}\beta(t) x \, dt + \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s) ds})} \, dw\)$
4.2 Reverse SDEΒΆ
Anderson (1982) theorem: The reverse-time SDE is $\(dx = \left[f(x, t) - g(t)^2 \nabla_x \log p_t(x)\right] dt + g(t)\, d\bar{w}\)$
where \(\bar{w}\) is reverse-time Brownian motion.
Key insight: If we know score \(\nabla_x \log p_t(x)\) at all times \(t\), we can reverse the diffusion!
Score approximation: Replace \(\nabla_x \log p_t(x)\) with \(s_\theta(x, t)\)
4.3 Probability Flow ODEΒΆ
Alternative: Reverse process as ODE (deterministic, no noise)
Approximation: $\(\frac{dx}{dt} = f(x, t) - \frac{1}{2} g(t)^2 s_\theta(x, t)\)$
Properties:
Same marginals as reverse SDE: \(p_t(x)\) identical
Deterministic trajectories (useful for inversion, interpolation)
Faster sampling (adaptive ODE solvers)
Example (VP-SDE): $\(\frac{dx}{dt} = -\frac{1}{2}\beta(t)[x + s_\theta(x, t)]\)$
5. Training ObjectivesΒΆ
5.1 Denoising Score Matching with TimeΒΆ
Continuous-time objective: $\(\mathcal{L}_{\text{DSM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}(0, T)} \mathbb{E}_{x_0 \sim p_{\text{data}}} \mathbb{E}_{x_t \sim p_t(x_t|x_0)}\left[\lambda(t) \left\|s_\theta(x_t, t) - \nabla_{x_t} \log p_t(x_t | x_0)\right\|^2\right]\)$
where \(p_t(x_t | x_0)\) is the transition kernel of the forward SDE.
VP-SDE: \(p_t(x_t | x_0) = \mathcal{N}(x_t | \alpha_t x_0, \beta_t^2 I)\)
\(\alpha_t = e^{-\frac{1}{2}\int_0^t \beta(s) ds}\)
\(\beta_t^2 = 1 - e^{-\int_0^t \beta(s) ds}\)
True score: \(\nabla_{x_t} \log p_t(x_t | x_0) = -\frac{x_t - \alpha_t x_0}{\beta_t^2}\)
Loss: $\(\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\lambda(t) \left\|s_\theta(x_t, t) + \frac{x_t - \alpha_t x_0}{\beta_t^2}\right\|^2\right]\)$
where \(x_t = \alpha_t x_0 + \beta_t \epsilon\), \(\epsilon \sim \mathcal{N}(0, I)\).
5.2 Noise Prediction ParameterizationΒΆ
Reparameterization: Predict noise instead of score $\(s_\theta(x_t, t) = -\frac{\epsilon_\theta(x_t, t)}{\beta_t}\)$
Noise prediction objective: $\(\mathcal{L}_{\epsilon}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\left[\lambda(t) \|\epsilon_\theta(x_t, t) - \epsilon\|^2\right]\)$
Weighting:
Simple: \(\lambda(t) = 1\)
SNR-weighted: \(\lambda(t) = \beta_t^2\)
Variance-preserving: \(\lambda(t) = \frac{1}{2\beta_t^2}\)
Equivalence: Score matching β noise prediction (up to weighting).
5.3 Likelihood WeightingΒΆ
Optimal weighting (Song et al., 2021): $\(\lambda(t) = g(t)^2\)$
Corresponds to maximizing ELBO (variational lower bound on likelihood).
Weighted loss: $\(\mathcal{L}_{\text{weighted}}(\theta) = \mathbb{E}_{t \sim p(t)} \left[w(t) \mathbb{E}_{x_0, \epsilon}\left[\|s_\theta(x_t, t) - s_t(x_t; x_0)\|^2\right]\right]\)$
where \(w(t) = g(t)^2 / p(t)\).
6. Sampling AlgorithmsΒΆ
6.1 Predictor-Corrector SamplingΒΆ
Two-step process:
Predictor: Numerical SDE/ODE solver step
Corrector: Langevin MCMC step to refine
Algorithm:
x_T ~ N(0, I)
for i = T-1 to 0:
# Predictor (e.g., Euler-Maruyama)
x_i β x_{i+1} + Ξt Β· f(x_{i+1}, t_{i+1}) + βΞt Β· g(t_{i+1}) Β· z
# Corrector (Langevin)
for j = 1 to M:
x_i β x_i + Ξ΅ s_ΞΈ(x_i, t_i) + β(2Ξ΅) z
return x_0
Advantage: Corrector improves sample quality at each timestep.
6.2 SDE SolversΒΆ
Euler-Maruyama (first-order): $\(x_{t-\Delta t} = x_t + [f(x_t, t) - g(t)^2 s_\theta(x_t, t)] \Delta t + g(t) \sqrt{\Delta t} \, z\)$
Stochastic Runge-Kutta (higher-order):
RK2, RK4 adaptations for SDEs
Better accuracy with larger timesteps
Adaptive stepping:
Error estimation (Richardson extrapolation)
Adjust \(\Delta t\) based on local truncation error
6.3 ODE SolversΒΆ
Euler method: $\(x_{t-\Delta t} = x_t + [f(x_t, t) - \frac{1}{2}g(t)^2 s_\theta(x_t, t)] \Delta t\)$
Heunβs method (RK2):
k1 = f(x_t, t) - 0.5 g(t)Β² s_ΞΈ(x_t, t)
xΜ = x_t + Ξt k1
k2 = f(xΜ, t-Ξt) - 0.5 g(t-Ξt)Β² s_ΞΈ(xΜ, t-Ξt)
x_{t-Ξt} = x_t + Ξt (k1 + k2) / 2
DPM-Solver (Lu et al., 2022):
Exponential integrator
10-20Γ faster than Euler
~10-20 steps for high quality
DDIM (Song et al., 2021):
Deterministic sampling (ODE)
Skip timesteps (acceleration)
Invertible (can encode images to latent)
7. Continuous vs. Discrete TimeΒΆ
7.1 DDPM (Discrete-Time Diffusion)ΒΆ
Forward process (Markov chain): $\(q(x_t | x_{t-1}) = \mathcal{N}(x_t | \sqrt{1 - \beta_t} x_{t-1}, \beta_t I)\)$
Cumulative: $\(q(x_t | x_0) = \mathcal{N}(x_t | \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I)\)$
where \(\bar{\alpha}_t = \prod_{s=1}^t (1 - \beta_s)\).
Reverse process: $\(p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1} | \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)$
Mean prediction: $\(\mu_\theta(x_t, t) = \frac{1}{\sqrt{1 - \beta_t}}\left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right)\)$
Variance: Fixed \(\Sigma_\theta = \beta_t I\) or learned.
7.2 Connection to Score-Based ModelsΒΆ
Score from noise prediction: $\(\nabla_x \log p(x_t) = -\frac{\epsilon_\theta(x_t, t)}{\sqrt{1 - \bar{\alpha}_t}}\)$
Unified framework:
DDPM: Discrete-time formulation, variance schedule \(\{\beta_t\}\)
Score SDE: Continuous-time formulation, SDE coefficients \(f, g\)
Conversion: DDPM with \(\beta_t = \beta(t) \Delta t\) β VP-SDE as \(\Delta t \to 0\)
8. Architecture DesignΒΆ
8.1 U-Net for Score NetworksΒΆ
Standard choice: U-Net with:
Downsampling path (encoder)
Upsampling path (decoder)
Skip connections (preserve spatial info)
Time embedding (condition on \(t\) or \(\sigma\))
Time embedding: Sinusoidal positional encoding $\(\gamma(t) = [\sin(2\pi f_1 t), \cos(2\pi f_1 t), \ldots, \sin(2\pi f_k t), \cos(2\pi f_k t)]\)$
Injected via:
FiLM (Feature-wise Linear Modulation): \(\text{FiLM}(h, \gamma) = \gamma_s \odot h + \gamma_b\)
Cross-attention
Adaptive group normalization
8.2 Attention MechanismsΒΆ
Self-attention layers: Model long-range dependencies
Multi-head self-attention: $\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V\)$
where \(Q = h W_Q\), \(K = h W_K\), \(V = h W_V\).
Placement: Typically at coarser resolutions (16Γ16, 8Γ8) due to \(O(n^2)\) cost.
Cross-attention: For conditional generation (text, class labels)
Queries from image features
Keys, values from conditioning
8.3 Modern ImprovementsΒΆ
EDM (Karras et al., 2022):
Preconditioning (input/output scaling)
Optimal noise schedule
Improved architecture
DiT (Diffusion Transformer):
Replace U-Net with Transformer
Better scalability (parameter count)
State-of-the-art on ImageNet
Latent Diffusion (Stable Diffusion):
Apply diffusion in VAE latent space
4-8Γ faster than pixel-space
Maintains quality
9. Conditional GenerationΒΆ
9.1 Class-Conditional GenerationΒΆ
Conditional score: \(s_\theta(x, t, y)\) where \(y\) is class label
Training: Standard DSM with \((x, y)\) pairs
Classifier-Free Guidance: $\(\tilde{s}_\theta(x, t, y) = s_\theta(x, t) + w \cdot [s_\theta(x, t, y) - s_\theta(x, t)]\)$
where:
\(s_\theta(x, t)\): Unconditional score (train with \(y = \emptyset\) dropout)
\(s_\theta(x, t, y)\): Conditional score
\(w\): Guidance weight (e.g., 1.5-7.5)
Effect: Higher \(w\) β stronger conditioning, lower diversity.
9.2 Text-to-Image GenerationΒΆ
Cross-attention conditioning:
Text embedding: \(c = \text{CLIP/T5}(\text{prompt})\)
Cross-attention: \(\text{Attn}(Q=f(x), K=c, V=c)\)
Classifier-Free Guidance: $\(\tilde{s}_\theta(x, t, c) = s_\theta(x, t) + w \cdot [s_\theta(x, t, c) - s_\theta(x, t)]\)$
Examples:
DALL-E 2: CLIP conditioning + diffusion
Imagen: T5 text encoder + cascaded diffusion
Stable Diffusion: CLIP + latent diffusion
9.3 Image Editing and InpaintingΒΆ
Inpainting: Generate missing region \(x_{\bar{M}}\) given observed \(x_M\)
Method 1 (Repaint): Resample known region at each step
for t = T to 0:
x_t β reverse_step(x_t)
x_t[M] β forward_step(x_0[M], t) # Restore known region
Method 2 (Conditioning): Train \(s_\theta(x, t, x_M)\) with masked inputs
Image-to-image: SDEdit (Meng et al., 2021)
Add noise to input: \(x_T = x_{\text{input}} + \sigma_T \epsilon\)
Denoise with score model
Result: Variation of input
10. Likelihood ComputationΒΆ
10.1 Exact Likelihood via ODEΒΆ
Instantaneous change of variables: $\(\log p_0(x_0) = \log p_T(x_T) + \int_0^T \text{div}(f_\theta)(x_t, t)\, dt\)$ (the divergence term carries a plus sign, since \(\frac{d}{dt}\log p_t(x(t)) = -\text{div}(f_\theta)\) integrated from \(0\) to \(T\))
where \(f_\theta(x, t) = f(x, t) - \frac{1}{2}g(t)^2 s_\theta(x, t)\) is the ODE drift.
Divergence: \(\text{div}(f_\theta) = \text{tr}(\nabla_x f_\theta(x, t))\)
Computational cost: \(O(D^2)\) for Jacobian trace.
Hutchinsonβs estimator (unbiased): $\(\text{tr}(\nabla_x f) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}[\epsilon^T (\nabla_x f) \epsilon]\)$
Can be computed via vector-Jacobian product (efficient in autograd).
10.2 Approximate LikelihoodΒΆ
ELBO: Variational lower bound (from DDPM) $\(\log p(x_0) \geq \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log \frac{p_T(x_T) \prod_{t=1}^T p_\theta(x_{t-1}|x_t)}{q(x_{1:T}|x_0)}\right]\)$
Simplified: Sum of denoising scores across timesteps.
11. Advanced TopicsΒΆ
11.1 Riemannian Score-Based ModelsΒΆ
Manifold data: \(x \in \mathcal{M}\) (e.g., SO(3) rotations, protein structures)
Score on manifold: \(s_\theta: \mathcal{M} \times \mathbb{R}_+ \to T\mathcal{M}\) (tangent bundle)
Riemannian diffusion: $\(dx = s_\theta(x, t) dt + \sqrt{2} \, dW_{\mathcal{M}}\)$
where \(W_{\mathcal{M}}\) is Brownian motion on manifold.
Applications:
SE(3) diffusion for protein design
SO(3) for 3D rotations
Hyperbolic space for hierarchical data
11.2 Consistency ModelsΒΆ
Motivation: Distill score-based models to single-step generators.
Consistency function: \(f: (x_t, t) \mapsto x_0\) satisfies $\(f(x_t, t) = f(x_s, s) \quad \forall s, t\)$
Training:
Consistency distillation: Use pretrained score model
Consistency training: Train from scratch
Advantage: 1-step generation (1000Γ faster than DDPM)
Trade-off: Slight quality degradation vs. iterative sampling.
11.3 Flow MatchingΒΆ
Alternative: Directly learn vector field instead of score.
Continuous normalizing flow: $\(\frac{dx}{dt} = v_\theta(x, t)\)$
Training: Regression to optimal transport paths $\(\mathcal{L} = \mathbb{E}_{t, x_0, x_1}\left[\|v_\theta(x_t, t) - (x_1 - x_0)\|^2\right]\)$
where \(x_t = (1-t)x_0 + t x_1\) (linear interpolation).
Relation to score models: Flow matching β score matching with specific drift.
12. Theoretical AnalysisΒΆ
12.1 Score Matching ConsistencyΒΆ
Theorem (HyvΓ€rinen, 2005): Explicit score matching recovers true distribution.
If \(\theta^* = \arg\min_\theta \mathcal{L}_{\text{ESM}}(\theta)\), then \(s_{\theta^*}(x) = \nabla_x \log p_{\text{data}}(x)\) (under regularity conditions).
Proof sketch:
Objective minimized when \(s_\theta(x) = \nabla_x \log p_{\text{data}}(x)\) for all \(x\)
Integration gives \(\log p_\theta(x) = \log p_{\text{data}}(x) + c\)
Normalization constraint β \(c = 0\)
12.2 Convergence of Langevin DynamicsΒΆ
Theorem: Under log-concavity and smoothness, Langevin dynamics converges exponentially to target distribution.
$\(W_2(p_t, p) \le e^{-\lambda t}\, W_2(p_0, p)\)$ for some rate \(\lambda > 0\), where \(W_2\) is the 2-Wasserstein distance.
Practice: Non-log-concave data, finite steps β approximate samples.
12.3 Sample ComplexityΒΆ
Theorem (Song et al., 2021): Score matching sample complexity is \(\tilde{O}(d^2 / \epsilon^2)\) for \(\epsilon\)-accurate score estimation in dimension \(d\).
Comparison:
GANs: Potentially exponential in \(d\) (mode collapse)
Normalizing flows: \(O(d)\) architectures only
Score models: Polynomial, flexible architectures
13. Training TechniquesΒΆ
13.1 Noise Schedule DesignΒΆ
Linear schedule (DDPM): $\(\beta_t = \beta_{\min} + (\beta_{\max} - \beta_{\min}) \frac{t}{T}\)$
Cosine schedule (Improved DDPM): $\(\bar{\alpha}_t = \frac{f(t)}{f(0)}, \quad f(t) = \cos\left(\frac{t/T + s}{1 + s} \cdot \frac{\pi}{2}\right)^2\)$
Learned schedule: Optimize \(\beta_t\) or \(\sigma_t\) via variational bound.
EDM schedule (Karras et al., 2022): $\(\sigma_i = \left(\sigma_{\max}^{1/\rho} + \frac{i-1}{N-1}(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho})\right)^\rho\)$
with \(\rho = 7\), \(\sigma_{\max} = 80\), \(\sigma_{\min} = 0.002\).
13.2 EMA (Exponential Moving Average)ΒΆ
Maintain EMA of parameters: $\(\theta_{\text{EMA}} \leftarrow \gamma \theta_{\text{EMA}} + (1 - \gamma) \theta\)$
Typical: \(\gamma = 0.9999\)
Benefit: Smoother model, better generation quality.
13.3 Mixed Precision TrainingΒΆ
FP16 training: Reduce memory, accelerate training.
Loss scaling: Prevent gradient underflow $\(\mathcal{L}_{\text{scaled}} = \text{scale} \cdot \mathcal{L}\)$
Gradient clipping: Stabilize training $\(g \leftarrow \frac{g}{\max(1, \|g\| / \text{clip\_norm})}\)$
14. Applications and ResultsΒΆ
14.1 Image GenerationΒΆ
CIFAR-10:
NCSN++: FID 2.2
DDPM++: FID 2.78
EDM: FID 1.97
ImageNet 256Γ256:
Improved DDPM: FID 10.94
CDM (Cascaded): FID 4.88
DiT-XL/2: FID 2.27
High-resolution:
Latent Diffusion (Stable Diffusion): 512Γ512, FID 12.63 on COCO
DALL-E 2: 1024Γ1024
Imagen: 1024Γ1024, state-of-art text-to-image
14.2 Audio and SpeechΒΆ
WaveGrad: Raw audio waveform generation
24 kHz audio
FID competitive with GANs
DiffWave: Vocoder (mel-spectrogram to waveform)
MOS (Mean Opinion Score) 4.4+ (near human quality)
Grad-TTS: Text-to-speech
End-to-end diffusion
Natural prosody
14.3 Video GenerationΒΆ
Video Diffusion Models (Ho et al., 2022):
Factorized space-time U-Net
16 frames @ 64Γ64, FVD 481
Imagen Video: High-resolution text-to-video
Cascaded diffusion (24 β 128 β 1024)
Temporal attention
14.4 3D and MoleculesΒΆ
Point-E (OpenAI): Text-to-3D point clouds
Diffusion on point clouds
1-2 minutes per shape
DreamFusion: Text-to-3D via distillation
Score Distillation Sampling (SDS)
Neural radiance fields (NeRF)
Molecule generation:
EDM for molecular graphs
SE(3)-equivariant diffusion
15. ComparisonsΒΆ
15.1 Score-Based vs. GANsΒΆ
Aspect |
Score-Based |
GANs |
|---|---|---|
Training |
Stable |
Unstable (mode collapse) |
Sampling |
Slow (iterative) |
Fast (single pass) |
Diversity |
High |
Can be limited |
Likelihood |
Tractable (ODE) |
Intractable |
Architecture |
Flexible |
Generator + Discriminator |
Quality |
State-of-art |
High |
Recommendation: Score models for quality/diversity, GANs for speed.
15.2 Score-Based vs. VAEsΒΆ
Aspect |
Score-Based |
VAEs |
|---|---|---|
Latent space |
No explicit latent |
Structured latent |
Likelihood |
Exact (ODE) |
ELBO (approximate) |
Sampling |
Slow |
Fast |
Interpolation |
ODE trajectories |
Latent interpolation |
Quality |
Higher |
Moderate (blurry) |
Recommendation: Score models for quality, VAEs for latent manipulation.
15.3 Continuous vs. Discrete TimeΒΆ
Aspect |
SDE (Continuous) |
DDPM (Discrete) |
|---|---|---|
Formulation |
Stochastic differential equation |
Markov chain |
Flexibility |
General SDEs (VP, VE, sub-VP) |
Fixed variance schedule |
Likelihood |
ODE change of variables |
ELBO |
Solvers |
Adaptive SDE/ODE solvers |
Ancestral sampling |
Theory |
Anderson theorem |
Variational inference |
Unified: Both are equivalent in the limit.
16. Limitations and Future DirectionsΒΆ
16.1 Current LimitationsΒΆ
Sampling speed:
50-1000 steps for high quality (vs. 1 for GANs)
~50 seconds for 256Γ256 image (vs. 0.1s GAN)
Computational cost:
Training: 100-1000 GPU-days for large models
Inference: Multiple forward passes
Determinism:
Stochastic generation (randomness in sampling)
Reproducibility requires seed control
16.2 Acceleration TechniquesΒΆ
Faster samplers:
DPM-Solver: 10-20 steps (20-50Γ speedup)
DDIM: Deterministic, skip steps
Consistency models: 1-4 steps
Distillation:
Progressive distillation (Salimans & Ho, 2022)
4 steps β 2 steps β 1 step
Latent diffusion:
Diffuse in compressed latent space
4-8Γ faster than pixel space
16.3 Future DirectionsΒΆ
Unified frameworks:
Connect score models, flows, and ODEs
Optimal transport theory
Continuous-time models:
Neural ODEs/SDEs
Infinitely deep networks
Applications:
Scientific computing (PDEs, molecular dynamics)
Inverse problems (super-resolution, inpainting, CT reconstruction)
Controllable generation (fine-grained control)
Theoretical understanding:
Sample complexity bounds
Approximation theory
Convergence guarantees
17. SummaryΒΆ
Key Concepts:
Score function: \(s(x) = \nabla_x \log p(x)\) independent of normalization
Score matching: Train via denoising (avoid intractable score)
Multi-scale: Noise conditioning for stable training across densities
SDE framework: Continuous-time diffusion (forward) and reverse processes
Sampling: Langevin dynamics, predictor-corrector, ODE solvers
Training Recipe:
Choose SDE (VP, VE, sub-VP)
Design U-Net with time embedding
Train with denoising score matching loss
Use EMA, mixed precision, gradient clipping
Sampling Recipe:
Initialize from noise \(x_T \sim \mathcal{N}(0, I)\)
Run reverse SDE/ODE with score network
Use adaptive solvers (DPM-Solver, Heunβs method)
Apply classifier-free guidance for conditioning
Advantages:
State-of-the-art generation quality
Training stability (no adversarial dynamics)
Mode coverage (no collapse)
Exact likelihood computation (ODE)
Disadvantages:
Slow sampling (iterative process)
Computational cost (training and inference)
Memory requirements (U-Net parameters)
Best for:
High-quality image/audio/video generation
Text-to-image synthesis
Likelihood-based modeling
Controllable generation
"""
Advanced Score-Based Models - Production Implementation
Comprehensive PyTorch implementations with SDE solvers and modern architectures
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, List, Callable
from dataclasses import dataclass
# ===========================
# 1. Time Embeddings
# ===========================
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embeddings for time conditioning.

    Maps a batch of scalar timesteps to dim-dimensional features built from
    sin/cos pairs at geometrically spaced frequencies (Transformer-style).
    """

    def __init__(self, dim: int, max_period: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            t: (batch_size,) tensor of timesteps
        Returns:
            (batch_size, dim) embeddings: [sin terms | cos terms]
        """
        half = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/max_period.
        log_scale = math.log(self.max_period) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -log_scale)
        angles = t[:, None] * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
# ===========================
# 2. U-Net Building Blocks
# ===========================
class ResidualBlock(nn.Module):
    """Residual block with time conditioning via FiLM.

    Pre-activation layout (norm -> SiLU -> conv), with the time embedding
    injected between the two convolutions as a per-channel scale and bias
    (FiLM). A 1x1 conv matches channels on the shortcut when needed.
    """
    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int,
                 dropout: float = 0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # Time embedding projection (FiLM parameters: scale and bias)
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, 2 * out_channels)  # scale and bias
        )
        # BUGFIX: norm1 normalizes the block INPUT, which has in_channels
        # (the original used out_channels and crashed whenever
        # in_channels != out_channels). Both channel counts must be
        # divisible by the group count (8).
        self.norm1 = nn.GroupNorm(8, in_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.dropout = nn.Dropout(dropout)
        # Shortcut: 1x1 conv only when the channel count changes
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C_in, H, W)
            time_emb: (B, time_emb_dim)
        Returns:
            (B, C_out, H, W)
        """
        h = self.conv1(F.silu(self.norm1(x)))
        # FiLM conditioning: scale and bias from time embedding
        time_out = self.time_mlp(time_emb)
        scale, bias = time_out.chunk(2, dim=1)
        h = h * (1 + scale[:, :, None, None]) + bias[:, :, None, None]
        h = self.dropout(h)
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.shortcut(x)
class AttentionBlock(nn.Module):
    """Multi-head self-attention over spatial positions, with residual output."""

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        assert channels % num_heads == 0, "channels must be divisible by num_heads"
        self.channels = channels
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj_out = nn.Conv2d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend across all H*W positions; (B, C, H, W) -> (B, C, H, W)."""
        B, C, H, W = x.shape
        head_dim = C // self.num_heads

        def to_heads(u: torch.Tensor) -> torch.Tensor:
            # (B, C, H, W) -> (B, heads, HW, head_dim) for batched matmul.
            return u.view(B, self.num_heads, head_dim, H * W).transpose(2, 3)

        # One projection yields q, k, v stacked along the channel axis.
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)
        q, k, v = to_heads(q), to_heads(k), to_heads(v)
        # Scaled dot-product attention weights, (B, heads, HW, HW).
        weights = torch.softmax(q @ k.transpose(-2, -1) * head_dim ** -0.5, dim=-1)
        attended = (weights @ v).transpose(2, 3).reshape(B, C, H, W)
        # Residual connection around the attention output.
        return x + self.proj_out(attended)
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, C, H, W) -> (B, C, ceil(H/2), ceil(W/2))."""
        return self.conv(x)
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbor upsample followed by a conv."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, C, H, W) -> (B, C, 2H, 2W)."""
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)
# ===========================
# 3. Score Network (U-Net)
# ===========================
class ScoreNet(nn.Module):
    """
    U-Net architecture for score function estimation.

    Outputs the score s_theta(x, t), an estimate of grad_x log p_t(x) for a
    noisy input x at continuous time t.  Layout: input conv -> downsampling
    levels (ResidualBlocks + optional attention + Downsample) -> middle
    (res / attention / res) -> upsampling levels that consume the stored
    skip tensors -> GroupNorm/SiLU/Conv head.

    Args:
        in_channels: channels of the input image.
        model_channels: base width; per-level widths are model_channels * mult.
        out_channels: channels of the predicted score (usually == in_channels).
        num_res_blocks: residual blocks per resolution level.
        attention_resolutions: where to insert self-attention.
            NOTE(review): the membership test uses 2**level (1, 2, 4, ...),
            NOT the actual spatial resolution of the feature map, so with the
            default [16, 8] and a 4-level net it never fires — confirm the
            intended semantics.
        dropout: dropout inside residual blocks.
        channel_mult: per-level channel multipliers (also sets the depth).
        num_heads: attention heads for AttentionBlock.
    """
    def __init__(self,
                 in_channels: int = 3,
                 model_channels: int = 128,
                 out_channels: int = 3,
                 num_res_blocks: int = 2,
                 attention_resolutions: List[int] = [16, 8],
                 dropout: float = 0.1,
                 channel_mult: List[int] = [1, 2, 2, 2],
                 num_heads: int = 4):
        # NOTE(review): the list defaults are mutable and shared across
        # instances; harmless here because they are only read.
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        # Time embedding: sinusoidal features expanded by a two-layer MLP.
        time_emb_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(model_channels),
            nn.Linear(model_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        # Initial convolution
        self.input_conv = nn.Conv2d(in_channels, model_channels, 3, padding=1)
        # Downsampling path.  `channels` records, in push order, the channel
        # count of every skip tensor forward() will store, so the up path can
        # pop them in reverse.
        self.down_blocks = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        channels = [model_channels]
        now_channels = model_channels
        for level, mult in enumerate(channel_mult):
            out_ch = model_channels * mult
            for _ in range(num_res_blocks):
                layers = [ResidualBlock(now_channels, out_ch, time_emb_dim, dropout)]
                now_channels = out_ch
                # Add attention at specified resolutions
                if 2 ** level in attention_resolutions:
                    layers.append(AttentionBlock(now_channels, num_heads))
                self.down_blocks.append(nn.ModuleList(layers))
                channels.append(now_channels)
            # Downsample (except last level); its output is also stored as a skip.
            if level != len(channel_mult) - 1:
                self.down_samples.append(Downsample(now_channels))
                channels.append(now_channels)
        # Middle blocks
        self.middle_block = nn.ModuleList([
            ResidualBlock(now_channels, now_channels, time_emb_dim, dropout),
            AttentionBlock(now_channels, num_heads),
            ResidualBlock(now_channels, now_channels, time_emb_dim, dropout)
        ])
        # Upsampling path: num_res_blocks + 1 blocks per level, one per stored
        # skip tensor (the +1 absorbs the skip pushed after each Downsample
        # and the initial conv output).
        self.up_blocks = nn.ModuleList()
        self.up_samples = nn.ModuleList()  # unused; kept for state_dict compatibility
        for level, mult in enumerate(reversed(channel_mult)):
            for i in range(num_res_blocks + 1):
                # Skip connection from downsampling
                skip_ch = channels.pop()
                out_ch = model_channels * mult
                layers = [ResidualBlock(now_channels + skip_ch, out_ch, time_emb_dim, dropout)]
                now_channels = out_ch
                # Add attention at specified resolutions (mirrors the down path)
                if 2 ** (len(channel_mult) - 1 - level) in attention_resolutions:
                    layers.append(AttentionBlock(now_channels, num_heads))
                # Upsample at the end of each level (except the last)
                if i == num_res_blocks and level != len(channel_mult) - 1:
                    layers.append(Upsample(now_channels))
                self.up_blocks.append(nn.ModuleList(layers))
        # Output
        self.output_conv = nn.Sequential(
            nn.GroupNorm(8, now_channels),
            nn.SiLU(),
            nn.Conv2d(now_channels, out_channels, 3, padding=1)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) noisy input
            t: (B,) timesteps
        Returns:
            score: (B, C, H, W) estimated score
        """
        # Time embedding
        t_emb = self.time_embed(t)
        # Initial conv
        h = self.input_conv(x)
        # Downsampling: push a skip after every residual block and after every
        # downsample, matching the channel bookkeeping in __init__.
        # BUG FIX: the downsample was previously applied after down-block i
        # whenever i < len(self.down_samples), i.e. after the first few
        # residual blocks instead of once per resolution level.  With
        # num_res_blocks > 1 that desynchronized the skip stack from the up
        # path (channel/spatial mismatch in torch.cat).  It is now applied
        # only when a level's blocks are complete.
        skip_connections = [h]
        level = 0
        for i, block in enumerate(self.down_blocks):
            for layer in block:
                if isinstance(layer, ResidualBlock):
                    h = layer(h, t_emb)
                else:
                    h = layer(h)
            skip_connections.append(h)
            if (i + 1) % self.num_res_blocks == 0:
                if level < len(self.down_samples):
                    h = self.down_samples[level](h)
                    skip_connections.append(h)
                level += 1
        # Middle
        for layer in self.middle_block:
            if isinstance(layer, ResidualBlock):
                h = layer(h, t_emb)
            else:
                h = layer(h)
        # Upsampling: each block consumes one skip tensor, in reverse order.
        for block in self.up_blocks:
            skip = skip_connections.pop()
            h = torch.cat([h, skip], dim=1)
            for layer in block:
                if isinstance(layer, ResidualBlock):
                    h = layer(h, t_emb)
                else:
                    h = layer(h)
        # Output
        return self.output_conv(h)
# ===========================
# 4. Noise Schedules (VP, VE, sub-VP)
# ===========================
@dataclass
class NoiseScheduleConfig:
    """Configuration for noise schedule.

    beta_min/beta_max parameterize the linear beta(t) used by the VP and
    sub-VP schedules; sigma_min/sigma_max parameterize the geometric sigma(t)
    used by the VE schedule; T is the diffusion time horizon.
    """
    schedule_type: str  # 'VP', 'VE', 'sub-VP'
    beta_min: float = 0.1    # beta(0) for VP / sub-VP
    beta_max: float = 20.0   # beta(T) for VP / sub-VP
    sigma_min: float = 0.01  # sigma at t=0 for VE
    sigma_max: float = 50.0  # sigma at t=T for VE
    T: float = 1.0  # Total time
class NoiseSchedule:
    """Noise schedule for the forward SDE (VP / VE / sub-VP).

    Conventions follow Song et al., "Score-Based Generative Modeling through
    Stochastic Differential Equations" (ICLR 2021), with
    B(t) = integral_0^t beta(s) ds for the linear beta schedule:

      VP:      x_t = alpha_t x_0 + sigma_t eps, alpha_t = exp(-B(t)/2),
               sigma_t = sqrt(1 - exp(-B(t)))
      sub-VP:  same alpha_t, sigma_t = 1 - exp(-B(t))
      VE:      alpha_t = 1, sigma_t geometric from sigma_min to sigma_max
    """
    def __init__(self, config: NoiseScheduleConfig):
        self.config = config

    def _beta_integral(self, t: torch.Tensor) -> torch.Tensor:
        """B(t) = integral_0^t beta(s) ds for the linear beta schedule."""
        return self.config.beta_min * t + 0.5 * (self.config.beta_max - self.config.beta_min) * t ** 2

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        """Instantaneous variance schedule beta(t); identically zero for VE."""
        if self.config.schedule_type in ['VP', 'sub-VP']:
            # Linear schedule from beta_min to beta_max.
            return self.config.beta_min + (self.config.beta_max - self.config.beta_min) * t
        else:
            return torch.zeros_like(t)

    def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
        """Mean coefficient alpha_t = exp(-B(t)/2); identically 1 for VE."""
        if self.config.schedule_type in ['VP', 'sub-VP']:
            return torch.exp(-0.5 * self._beta_integral(t))
        else:  # VE
            return torch.ones_like(t)

    def sigma_t(self, t: torch.Tensor) -> torch.Tensor:
        """Standard deviation sigma(t) of the perturbation kernel."""
        if self.config.schedule_type == 'VP':
            # Var = 1 - exp(-B(t)) = 1 - alpha_t^2
            alpha = self.alpha_t(t)
            return torch.sqrt(1 - alpha ** 2)
        elif self.config.schedule_type == 'sub-VP':
            # BUG FIX: the sub-VP perturbation kernel has variance
            # (1 - exp(-B(t)))^2, i.e. std = 1 - exp(-B(t)) (Song et al.
            # 2021, Eq. 29) — not sqrt(1 - exp(-B(t))) as previously coded.
            return 1.0 - torch.exp(-self._beta_integral(t))
        else:  # VE
            # Geometric interpolation: sigma_min^(1-t) * sigma_max^t.
            log_sigma = math.log(self.config.sigma_min) + \
                t * (math.log(self.config.sigma_max) - math.log(self.config.sigma_min))
            return torch.exp(log_sigma)

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Marginal distribution p_t(x | x_0) = N(alpha_t x_0, sigma_t^2 I).

        Returns:
            (mean, std); std is shaped (B, 1, 1, 1) for broadcasting.
        """
        alpha = self.alpha_t(t).view(-1, 1, 1, 1)
        sigma = self.sigma_t(t).view(-1, 1, 1, 1)
        mean = alpha * x_0
        return mean, sigma

    def perturbation_kernel(self, x_0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample x_t ~ p_t(x | x_0).

        Returns:
            (x_t, noise) with x_t = mean + std * noise.
        """
        mean, std = self.marginal_prob(x_0, t)
        noise = torch.randn_like(x_0)
        x_t = mean + std * noise
        return x_t, noise
# ===========================
# 5. Denoising Score Matching Trainer
# ===========================
class DenoisingScoreMatching:
    """Denoising Score Matching (DSM) training.

    Regresses s_theta(x_t, t) onto the score of the Gaussian perturbation
    kernel, grad log p_t(x_t | x_0) = -(x_t - alpha_t x_0) / sigma_t^2.
    """
    def __init__(self,
                 score_net: ScoreNet,
                 noise_schedule: NoiseSchedule,
                 device: str = 'cuda',
                 t_eps: float = 1e-5):
        """
        Args:
            score_net: score network to train (moved to `device`).
            noise_schedule: forward-process noise schedule.
            device: training device.
            t_eps: lower cutoff for sampled timesteps.  As t -> 0,
                sigma_t -> 0 and the 1/sigma^2 target diverges; clipping t
                away from the origin is standard practice (reference
                implementations use 1e-5).
        """
        self.score_net = score_net.to(device)
        self.noise_schedule = noise_schedule
        self.device = device
        self.t_eps = t_eps

    def loss(self, x_0: torch.Tensor, t: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, dict]:
        """
        Denoising score matching loss
        L = E_t E_{x_0} E_{x_t|x_0} [lambda(t) ||s_theta(x_t, t) - grad log p_t(x_t | x_0)||^2]

        Returns:
            (loss, metrics) where metrics holds 'loss', 'mean_t', 'mean_sigma'.
        """
        batch_size = x_0.shape[0]
        # Sample random timesteps uniformly on [t_eps, T].
        # BUG FIX: previously uniform on [0, T], which occasionally produced
        # sigma_t ~ 0 and a numerically exploding 1/sigma^2 target.
        if t is None:
            T = self.noise_schedule.config.T
            t = self.t_eps + (T - self.t_eps) * torch.rand(batch_size, device=self.device)
        # Perturb data: x_t = alpha_t x_0 + sigma_t eps
        x_t, noise = self.noise_schedule.perturbation_kernel(x_0, t)
        # True score: grad log p_t(x_t | x_0) = -(x_t - alpha_t x_0) / sigma_t^2
        alpha = self.noise_schedule.alpha_t(t).view(-1, 1, 1, 1)
        sigma = self.noise_schedule.sigma_t(t).view(-1, 1, 1, 1)
        true_score = -(x_t - alpha * x_0) / (sigma ** 2)
        # Predicted score
        pred_score = self.score_net(x_t, t)
        # Unweighted MSE between scores.
        # NOTE(review): with weighting lambda(t) = sigma_t^2 this objective
        # reduces to epsilon-prediction; the unweighted form used here puts
        # more emphasis on small-sigma timesteps — confirm it is intentional.
        loss = torch.mean((pred_score - true_score) ** 2)
        metrics = {
            'loss': loss.item(),
            'mean_t': t.mean().item(),
            'mean_sigma': sigma.mean().item()
        }
        return loss, metrics

    def train_step(self, x_0: torch.Tensor, optimizer: torch.optim.Optimizer) -> dict:
        """Single training step: backward, clip gradients, optimizer step."""
        optimizer.zero_grad()
        loss, metrics = self.loss(x_0)
        loss.backward()
        # Gradient clipping for stability.
        torch.nn.utils.clip_grad_norm_(self.score_net.parameters(), max_norm=1.0)
        optimizer.step()
        return metrics
# ===========================
# 6. SDE Samplers
# ===========================
class EulerMaruyamaSampler:
    """Euler-Maruyama solver for reverse-time SDE sampling.

    Integrates dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw_bar from
    t = T down to t = 0 (Song et al. 2021).  Note the FULL g^2 factor on the
    score term; g^2 / 2 would instead give the probability-flow ODE drift.
    """
    def __init__(self,
                 score_net: ScoreNet,
                 noise_schedule: NoiseSchedule,
                 num_steps: int = 1000,
                 device: str = 'cuda'):
        self.score_net = score_net
        self.noise_schedule = noise_schedule
        self.num_steps = num_steps
        self.device = device

    @torch.no_grad()
    def sample(self, batch_size: int, img_shape: Tuple[int, int, int]) -> torch.Tensor:
        """
        Sample from p_0 by integrating the reverse SDE from pure noise.

        Args:
            batch_size: number of samples to draw.
            img_shape: (C, H, W) of each sample.
        Returns:
            (batch_size, C, H, W) tensor of samples.
        """
        T = self.noise_schedule.config.T
        dt = -T / self.num_steps  # negative: integrating backward in time
        # BUG FIX: torch.sqrt(torch.abs(dt)) raised a TypeError because dt is
        # a Python float; compute the noise scale once with math instead.
        sqrt_abs_dt = math.sqrt(-dt)
        # Initialize from the prior at t = T.
        x = torch.randn(batch_size, *img_shape, device=self.device)
        x = x * self.noise_schedule.sigma_t(torch.tensor([T], device=self.device))
        timesteps = torch.linspace(T, 0, self.num_steps + 1, device=self.device)
        for i in range(self.num_steps):
            t_batch = torch.full((batch_size,), timesteps[i].item(), device=self.device)
            # Compute score
            score = self.score_net(x, t_batch)
            # SDE coefficients
            beta_t = self.noise_schedule.beta(t_batch).view(-1, 1, 1, 1)
            schedule_type = self.noise_schedule.config.schedule_type
            if schedule_type == 'VP':
                # f = -1/2 beta x, g^2 = beta.
                # BUG FIX: the reverse drift is f - g^2 score =
                # -1/2 beta x - beta score; the previous -1/2 beta (x + score)
                # halved the score term (that form is the ODE drift).
                drift = -0.5 * beta_t * x - beta_t * score
                diffusion = torch.sqrt(beta_t)
            elif schedule_type == 'VE':
                # f = 0, g^2 = d[sigma^2]/dt = 2 sigma^2 ln(sigma_max/sigma_min)
                sigma = self.noise_schedule.sigma_t(t_batch).view(-1, 1, 1, 1)
                log_ratio = math.log(self.noise_schedule.config.sigma_max /
                                     self.noise_schedule.config.sigma_min)
                g2 = 2.0 * sigma ** 2 * log_ratio
                # BUG FIX: the drift is -g^2 * score; the previous
                # +sigma^2 * score had the wrong sign (with dt < 0 it pushed
                # samples away from the data) and dropped the log-ratio factor.
                drift = -g2 * score
                diffusion = torch.sqrt(g2)
            else:  # sub-VP
                # f = -1/2 beta x, g^2 = beta (1 - exp(-2 B(t)));
                # with alpha_t = exp(-B(t)/2) this equals beta (1 - alpha^4).
                alpha = self.noise_schedule.alpha_t(t_batch).view(-1, 1, 1, 1)
                g2 = beta_t * (1 - alpha ** 4)
                drift = -0.5 * beta_t * x - g2 * score
                diffusion = torch.sqrt(g2)
            # Euler-Maruyama step; no noise on the final step so the output
            # is the mean of the last transition.
            z = torch.randn_like(x) if i < self.num_steps - 1 else torch.zeros_like(x)
            x = x + drift * dt + diffusion * sqrt_abs_dt * z
        return x
class ODESampler:
    """Deterministic sampler integrating the probability-flow ODE.

    dx/dt = f(x, t) - 1/2 g(t)^2 * score(x, t), integrated from t = T to 0.
    """
    def __init__(self,
                 score_net: ScoreNet,
                 noise_schedule: NoiseSchedule,
                 num_steps: int = 100,
                 device: str = 'cuda'):
        self.score_net = score_net
        self.noise_schedule = noise_schedule
        self.num_steps = num_steps
        self.device = device

    @torch.no_grad()
    def sample(self, batch_size: int, img_shape: Tuple[int, int, int],
               method: str = 'heun') -> torch.Tensor:
        """
        Sample using the probability flow ODE.

        Args:
            batch_size: number of samples.
            img_shape: (C, H, W) of each sample.
            method: 'euler' or 'heun' (RK2).
        """
        T = self.noise_schedule.config.T
        dt = -T / self.num_steps  # negative: integrating backward in time
        # Initialize from the prior at t = T.
        x = torch.randn(batch_size, *img_shape, device=self.device)
        x = x * self.noise_schedule.sigma_t(torch.tensor([T], device=self.device))
        timesteps = torch.linspace(T, 0, self.num_steps + 1, device=self.device)
        for i in range(self.num_steps):
            t_batch = torch.full((batch_size,), timesteps[i].item(), device=self.device)
            if method == 'euler':
                # Euler method
                x = x + self._ode_drift(x, t_batch) * dt
            elif method == 'heun':
                # Heun's method (RK2): average the slope at t with the slope
                # at the Euler prediction for t + dt.
                drift1 = self._ode_drift(x, t_batch)
                x_tilde = x + drift1 * dt
                t_next_batch = torch.full((batch_size,), timesteps[i + 1].item(), device=self.device)
                drift2 = self._ode_drift(x_tilde, t_next_batch)
                x = x + 0.5 * (drift1 + drift2) * dt
        return x

    def _ode_drift(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Probability-flow drift: f(x, t) - 1/2 g(t)^2 s_theta(x, t)."""
        score = self.score_net(x, t)
        beta_t = self.noise_schedule.beta(t).view(-1, 1, 1, 1)
        if self.noise_schedule.config.schedule_type == 'VP':
            # f = -1/2 beta x, g^2 = beta  ->  -1/2 beta (x + score)
            drift = -0.5 * beta_t * (x + score)
        elif self.noise_schedule.config.schedule_type == 'VE':
            # f = 0, g^2 = d[sigma^2]/dt = 2 sigma^2 ln(sigma_max/sigma_min),
            # so the drift is -sigma^2 ln(sigma_max/sigma_min) * score.
            # BUG FIX: the previous +0.5 sigma^2 * score had the wrong sign
            # (samples drifted away from the data as t decreased) and was
            # missing the ln(sigma_max/sigma_min) factor.
            sigma = self.noise_schedule.sigma_t(t).view(-1, 1, 1, 1)
            log_ratio = math.log(self.noise_schedule.config.sigma_max /
                                 self.noise_schedule.config.sigma_min)
            drift = -(sigma ** 2) * log_ratio * score
        else:  # sub-VP
            # g^2 = beta (1 - alpha^4), alpha = exp(-B(t)/2).
            # BUG FIX: the previous form dropped the (1 - alpha^4) factor.
            alpha = self.noise_schedule.alpha_t(t).view(-1, 1, 1, 1)
            drift = -0.5 * beta_t * (x + (1 - alpha ** 4) * score)
        return drift
# ===========================
# 7. Predictor-Corrector Sampler
# ===========================
class PredictorCorrectorSampler:
    """Predictor-Corrector sampling: reverse-SDE step + Langevin refinement.

    NOTE: the predictor hard-codes the VP coefficients (f = -1/2 beta x,
    g^2 = beta); use this sampler with a VP-style schedule.
    """
    def __init__(self,
                 score_net: ScoreNet,
                 noise_schedule: NoiseSchedule,
                 num_steps: int = 1000,
                 num_corrector_steps: int = 1,
                 snr: float = 0.16,
                 device: str = 'cuda'):
        self.score_net = score_net
        self.noise_schedule = noise_schedule
        self.num_steps = num_steps
        self.num_corrector_steps = num_corrector_steps
        self.snr = snr  # Signal-to-noise ratio for Langevin step sizing
        self.device = device

    @torch.no_grad()
    def sample(self, batch_size: int, img_shape: Tuple[int, int, int]) -> torch.Tensor:
        """Predictor-Corrector sampling from t = T down to t = 0."""
        T = self.noise_schedule.config.T
        dt = -T / self.num_steps  # negative: integrating backward in time
        # BUG FIX: torch.sqrt(torch.abs(dt)) raised a TypeError because dt is
        # a Python float; compute the noise scale once with math.
        sqrt_abs_dt = math.sqrt(-dt)
        x = torch.randn(batch_size, *img_shape, device=self.device)
        x = x * self.noise_schedule.sigma_t(torch.tensor([T], device=self.device))
        timesteps = torch.linspace(T, 0, self.num_steps + 1, device=self.device)
        for i in range(self.num_steps):
            t_batch = torch.full((batch_size,), timesteps[i].item(), device=self.device)
            # Predictor: one Euler-Maruyama step of the reverse VP-SDE.
            score = self.score_net(x, t_batch)
            beta_t = self.noise_schedule.beta(t_batch).view(-1, 1, 1, 1)
            # BUG FIX: the reverse drift is f - g^2 score =
            # -1/2 beta x - beta score; the previous -1/2 beta (x + score)
            # halved the score term.
            drift = -0.5 * beta_t * x - beta_t * score
            diffusion = torch.sqrt(beta_t)
            z = torch.randn_like(x) if i < self.num_steps - 1 else torch.zeros_like(x)
            x = x + drift * dt + diffusion * sqrt_abs_dt * z
            # Corrector: Langevin MCMC steps at the current noise level.
            if i < self.num_steps - 1:
                for _ in range(self.num_corrector_steps):
                    score = self.score_net(x, t_batch)
                    # Step size chosen so the update's signal-to-noise ratio
                    # matches self.snr (Song et al. 2021, Appendix G).
                    grad_norm = torch.norm(score.reshape(batch_size, -1), dim=-1).mean()
                    noise_norm = math.sqrt(np.prod(img_shape))
                    step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2
                    z = torch.randn_like(x)
                    x = x + step_size * score + torch.sqrt(2 * step_size) * z
        return x
# ===========================
# 8. Classifier-Free Guidance
# ===========================
class ConditionalScoreNet(nn.Module):
    """Score network wrapper for class conditioning + classifier-free guidance.

    A learned embedding table holds one row per class plus one extra row that
    represents "unconditional".  NOTE: the embedding is created but not yet
    injected into the base network — a full implementation would feed it into
    the ResidualBlocks alongside the time embedding (FiLM or cross-attention).
    """

    def __init__(self, base_net: ScoreNet, num_classes: int, dropout_prob: float = 0.1):
        super().__init__()
        self.base_net = base_net
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob
        # Row index `num_classes` is reserved for the unconditional token.
        self.class_embed = nn.Embedding(num_classes + 1, base_net.model_channels * 4)

    def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Score estimate for (x, t); y=None means fully unconditional.

        Args:
            x: (B, C, H, W)
            t: (B,)
            y: (B,) class labels, or None.
        """
        batch = x.shape[0]
        if y is None:
            # Treat the whole batch as unconditional.
            y = torch.full((batch,), self.num_classes, device=x.device, dtype=torch.long)
        if self.training:
            # Randomly replace labels with the unconditional token so the
            # model also learns the unconditional score (needed for CFG).
            drop = torch.rand(batch, device=x.device) < self.dropout_prob
            y = torch.where(drop, self.num_classes, y)
        # The class embedding is not yet routed into base_net (see docstring).
        return self.base_net(x, t)

    @torch.no_grad()
    def forward_with_guidance(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor,
                              guidance_scale: float = 1.0) -> torch.Tensor:
        """Classifier-free guidance:
        s = s_uncond + w * (s_cond - s_uncond), with w = guidance_scale.
        """
        conditional = self.forward(x, t, y)
        if guidance_scale == 1.0:
            return conditional
        unconditional = self.forward(x, t, None)
        return unconditional + guidance_scale * (conditional - unconditional)
# ===========================
# 9. Demo Functions
# ===========================
def demo_time_embedding():
    """Demonstrate sinusoidal time embeddings"""
    banner = "=" * 50
    print(banner)
    print("Demo: Sinusoidal Time Embedding")
    print(banner)
    embedder = SinusoidalPosEmb(dim=128)
    timesteps = torch.linspace(0, 1, 10)
    vectors = embedder(timesteps)
    print(f"Input timesteps: {timesteps.shape} -> {timesteps[:5].tolist()[:5]}...")
    print(f"Embeddings shape: {vectors.shape}")
    print(f"First embedding (t=0): {vectors[0, :8].tolist()}")
    print(f"Last embedding (t=1): {vectors[-1, :8].tolist()}")
    print()
def demo_score_network():
    """Demonstrate score network forward pass"""
    banner = "=" * 50
    print(banner)
    print("Demo: Score Network (U-Net)")
    print(banner)
    net = ScoreNet(
        in_channels=3,
        model_channels=64,
        out_channels=3,
        num_res_blocks=2,
        attention_resolutions=[16],
        channel_mult=[1, 2, 2]
    )
    noisy = torch.randn(2, 3, 32, 32)
    times = torch.rand(2)
    score = net(noisy, times)
    param_count = sum(p.numel() for p in net.parameters())
    print(f"Input: x={noisy.shape}, t={times.shape}")
    print(f"Output score: {score.shape}")
    print(f"Score range: [{score.min():.3f}, {score.max():.3f}]")
    print(f"Total parameters: {param_count:,}")
    print(f"Model size: ~{param_count * 4 / 1024**2:.1f} MB (FP32)")
    print()
def demo_noise_schedules():
    """Demonstrate different noise schedules.

    Prints beta(t), alpha_t and sigma_t for VP, VE and sub-VP at a handful of
    timesteps.  (An unused `torch.linspace` local was removed, and the table
    header characters — previously mojibake 'Ξ²'/'Ξ±'/'Ο' — were restored to
    the intended Greek letters.)
    """
    print("=" * 50)
    print("Demo: Noise Schedules (VP, VE, sub-VP)")
    print("=" * 50)
    for schedule_type in ['VP', 'VE', 'sub-VP']:
        config = NoiseScheduleConfig(schedule_type=schedule_type)
        schedule = NoiseSchedule(config)
        print(f"\n{schedule_type} Schedule:")
        print(f"{'t':>6} {'β(t)':>8} {'α_t':>8} {'σ_t':>8}")
        print("-" * 32)
        for ti in [0.0, 0.25, 0.5, 0.75, 1.0]:
            ti_tensor = torch.tensor([ti])
            beta = schedule.beta(ti_tensor).item()
            alpha = schedule.alpha_t(ti_tensor).item()
            sigma = schedule.sigma_t(ti_tensor).item()
            print(f"{ti:>6.2f} {beta:>8.3f} {alpha:>8.3f} {sigma:>8.3f}")
    print()
def demo_denoising_score_matching():
    """Demonstrate DSM training step"""
    banner = "=" * 50
    print(banner)
    print("Demo: Denoising Score Matching")
    print(banner)
    device = 'cpu'
    net = ScoreNet(in_channels=3, model_channels=32, out_channels=3,
                   num_res_blocks=1, channel_mult=[1, 2])
    schedule = NoiseSchedule(NoiseScheduleConfig(schedule_type='VP'))
    trainer = DenoisingScoreMatching(net, schedule, device=device)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
    # Dummy batch reused for every step.
    batch = torch.randn(4, 3, 32, 32)
    print("Training for 5 steps...")
    for step in range(5):
        metrics = trainer.train_step(batch, optimizer)
        print(f"Step {step+1}: loss={metrics['loss']:.4f}, mean_t={metrics['mean_t']:.3f}, "
              f"mean_sigma={metrics['mean_sigma']:.3f}")
    print()
def demo_sampling():
    """Demonstrate sampling with different methods"""
    banner = "=" * 50
    print(banner)
    print("Demo: Sampling Methods")
    print(banner)
    device = 'cpu'
    net = ScoreNet(in_channels=3, model_channels=32, out_channels=3,
                   num_res_blocks=1, channel_mult=[1, 2])
    net.eval()
    schedule = NoiseSchedule(NoiseScheduleConfig(schedule_type='VP'))

    def report(samples: torch.Tensor) -> None:
        # Shared pretty-printer for each sampler's output.
        print(f" Generated samples: {samples.shape}")
        print(f" Sample range: [{samples.min():.2f}, {samples.max():.2f}]")

    # Euler-Maruyama (SDE)
    print("\n1. Euler-Maruyama SDE Sampler:")
    sde = EulerMaruyamaSampler(net, schedule, num_steps=10, device=device)
    report(sde.sample(batch_size=2, img_shape=(3, 32, 32)))
    # ODE (deterministic)
    print("\n2. ODE Sampler (Heun):")
    ode = ODESampler(net, schedule, num_steps=10, device=device)
    report(ode.sample(batch_size=2, img_shape=(3, 32, 32), method='heun'))
    # Predictor-Corrector
    print("\n3. Predictor-Corrector Sampler:")
    pc = PredictorCorrectorSampler(net, schedule, num_steps=10,
                                   num_corrector_steps=1, device=device)
    report(pc.sample(batch_size=2, img_shape=(3, 32, 32)))
    print()
def demo_classifier_free_guidance():
    """Demonstrate classifier-free guidance at several guidance scales."""
    print("=" * 50)
    print("Demo: Classifier-Free Guidance")
    print("=" * 50)
    base_net = ScoreNet(in_channels=3, model_channels=32, out_channels=3,
                        num_res_blocks=1, channel_mult=[1, 2])
    model = ConditionalScoreNet(base_net, num_classes=10, dropout_prob=0.1)
    model.eval()
    x = torch.randn(2, 3, 32, 32)
    t = torch.rand(2)
    y = torch.tensor([3, 7])
    # Sweep the guidance scale; w=1.0 is plain conditional scoring.
    for w in (1.0, 2.0, 5.0):
        score = model.forward_with_guidance(x, t, y, guidance_scale=w)
        print(f"Score (w={w}): {score.shape}, range=[{score.min():.3f}, {score.max():.3f}]")
    print("\nInterpretation:")
    # BUG FIX: restored the arrows that had been mojibake-corrupted to 'β'.
    print(" Higher guidance scale → stronger conditioning → less diversity")
    print(" Typical values: 1.0 (no guidance) to 7.5 (strong guidance)")
    print()
def print_performance_comparison():
"""Comprehensive performance comparison and decision guide"""
print("=" * 80)
print("PERFORMANCE COMPARISON: Score-Based Generative Models")
print("=" * 80)
# 1. Image Generation Quality
print("\n1. Image Generation Quality (FID β, IS β)")
print("-" * 80)
data = [
("Model", "CIFAR-10 FID", "ImageNet 256 FID", "Notes"),
("-" * 30, "-" * 12, "-" * 15, "-" * 30),
("NCSN (Song 2019)", "25.3", "N/A", "Early score-based model"),
("NCSN++ (Song 2020)", "2.2", "N/A", "Improved architecture"),
("DDPM (Ho 2020)", "3.17", "N/A", "Discrete-time diffusion"),
("Improved DDPM", "2.9", "10.94", "Cosine schedule + hybrid loss"),
("Score SDE (VP)", "2.20", "9.89", "Continuous-time VP-SDE"),
("Score SDE (VE)", "2.38", "11.3", "Variance-exploding SDE"),
("Score SDE (sub-VP)", "2.61", "9.56", "Sub-variance-preserving"),
("EDM (Karras 2022)", "1.97", "N/A", "Optimal preconditioning"),
("DiT-XL/2 (Peebles 2023)", "N/A", "2.27", "Diffusion Transformer"),
("", "", "", ""),
("COMPARISON:", "", "", ""),
("StyleGAN2", "2.92", "2.71", "Best GAN (fast sampling)"),
("BigGAN-deep", "6.95", "6.95", "Class-conditional GAN"),
("VAE (Ξ²=1)", "~80", "N/A", "Blurry reconstructions"),
]
for row in data:
print(f"{row[0]:<30} {row[1]:<12} {row[2]:<15} {row[3]:<30}")
# 2. Sampling Speed
print("\n2. Sampling Speed Comparison")
print("-" * 80)
data = [
("Method", "Steps", "Time (256Γ256)", "Quality", "Type"),
("-" * 25, "-" * 6, "-" * 15, "-" * 10, "-" * 15),
("DDPM (ancestral)", "1000", "~50 sec", "Excellent", "Stochastic"),
("DDIM (deterministic)", "50", "~5 sec", "Very Good", "Deterministic"),
("DPM-Solver++", "10-20", "~1-2 sec", "Excellent", "Deterministic"),
("Consistency Model", "1-4", "~0.1-0.5 sec", "Good", "Deterministic"),
("EDM (Heun solver)", "35-79", "~3-8 sec", "SOTA", "Hybrid"),
("Latent Diffusion", "50", "~2-3 sec", "Excellent", "Latent space"),
("", "", "", "", ""),
("GAN (StyleGAN2)", "1", "~0.1 sec", "Excellent", "Single-step"),
("VAE", "1", "~0.05 sec", "Moderate", "Single-step"),
("Normalizing Flow", "1", "~0.2 sec", "Good", "Invertible"),
]
for row in data:
print(f"{row[0]:<25} {row[1]:<6} {row[2]:<15} {row[3]:<10} {row[4]:<15}")
# 3. Training Stability
print("\n3. Training Stability and Convergence")
print("-" * 80)
data = [
("Model", "Stability", "Mode Coverage", "Hyperparameter Sensitivity"),
("-" * 20, "-" * 12, "-" * 15, "-" * 30),
("Score-based/Diffusion", "Very Stable", "Excellent", "Low (robust)"),
("GAN", "Unstable", "Poor-Moderate", "Very High (careful tuning)"),
("VAE", "Stable", "Good", "Moderate (Ξ²-VAE)"),
("Normalizing Flow", "Stable", "Good", "Moderate (architecture)"),
("Energy-based", "Moderate", "Good", "High (MCMC sensitive)"),
]
for row in data:
print(f"{row[0]:<20} {row[1]:<12} {row[2]:<15} {row[3]:<30}")
# 4. Likelihood Evaluation
print("\n4. Likelihood Evaluation (bits/dim on CIFAR-10, β better)")
print("-" * 80)
data = [
("Model", "Likelihood", "Method", "Notes"),
("-" * 25, "-" * 12, "-" * 20, "-" * 30),
("Score SDE (VP)", "2.99", "ODE (exact)", "Continuous normalizing flow"),
("DDPM (improved)", "2.94", "ELBO (lower bound)", "Discrete-time variational"),
("Glow (Flow)", "3.35", "Exact", "Invertible architecture"),
("VAE (PixelCNN++)", "2.92", "ELBO", "Hybrid VAE + autoregressive"),
("PixelCNN++", "2.92", "Exact", "Autoregressive"),
("", "", "", ""),
("Score models", "Comparable", "ODE tractable", "Flexible architecture"),
("Note:", "", "", "Likelihood β sample quality"),
]
for row in data:
print(f"{row[0]:<25} {row[1]:<12} {row[2]:<20} {row[3]:<30}")
# 5. SDE Type Comparison
print("\n5. SDE Type Comparison")
print("-" * 80)
data = [
("SDE Type", "Forward SDE", "Best Use Case", "FID (ImageNet 256)"),
("-" * 15, "-" * 35, "-" * 25, "-" * 20),
("VP", "dx = -Β½Ξ²(t)x dt + βΞ²(t) dw", "General purpose", "9.89"),
("VE", "dx = β(d[ΟΒ²]/dt) dw", "High-resolution images", "11.3"),
("sub-VP", "dx = -Β½Ξ²(t)x dt + β(Ξ²(1-Ξ±Β²)) dw", "Better likelihood", "9.56"),
("DDPM-equiv", "Discrete Markov chain", "Simple implementation", "10.94"),
]
for row in data:
print(f"{row[0]:<15} {row[1]:<35} {row[2]:<25} {row[3]:<20}")
# 6. Training Hyperparameters
print("\n6. Recommended Training Hyperparameters")
print("-" * 80)
data = [
("Parameter", "CIFAR-10", "ImageNet 256Γ256", "Notes"),
("-" * 25, "-" * 15, "-" * 20, "-" * 30),
("Model channels", "128", "256", "Base width"),
("Channel multipliers", "[1,2,2,2]", "[1,1,2,2,4,4]", "Depth scaling"),
("Num res blocks", "2-4", "2-3", "Per resolution"),
("Attention res", "[16]", "[32,16,8]", "Apply self-attention"),
("Dropout", "0.1", "0.1", "Regularization"),
("", "", "", ""),
("Batch size", "128", "256-2048", "Larger better"),
("Learning rate", "2e-4", "1e-4", "Adam optimizer"),
("EMA decay", "0.9999", "0.9999", "For sampling"),
("Gradient clip", "1.0", "1.0", "Stability"),
("", "", "", ""),
("Training steps", "800K", "1M-3M", "Until convergence"),
("Noise schedule", "Linear", "Cosine/EDM", "Ξ²(t) or Ο(t)"),
("T (time horizon)", "1.0", "1.0", "Total diffusion time"),
("Ο_min / Ο_max (VE)", "0.01 / 50", "0.002 / 80", "Noise range"),
]
for row in data:
print(f"{row[0]:<25} {row[1]:<15} {row[2]:<20} {row[3]:<30}")
# 7. Sampling Configuration
print("\n7. Sampling Configuration Trade-offs")
print("-" * 80)
data = [
("Method", "Steps", "Quality", "Speed", "Deterministic", "Use Case"),
("-" * 20, "-" * 6, "-" * 10, "-" * 8, "-" * 12, "-" * 25),
("Ancestral (DDPM)", "1000", "Excellent", "Slow", "No", "Best quality"),
("DDIM", "50-100", "Very Good", "Medium", "Yes", "Fast + invertible"),
("DPM-Solver++", "10-20", "Excellent", "Fast", "Yes", "Production (recommended)"),
("Euler-Maruyama", "100-500", "Good", "Medium", "No", "General SDE"),
("Heun (RK2)", "35-79", "SOTA", "Medium", "Yes", "EDM (best quality)"),
("Predictor-Corrector", "100-500", "Excellent", "Slow", "No", "High quality + diversity"),
("ODE (prob flow)", "20-100", "Very Good", "Fast", "Yes", "Likelihood / inversion"),
("Consistency", "1-4", "Good", "Very Fast", "Yes", "Real-time applications"),
]
for row in data:
print(f"{row[0]:<20} {row[1]:<6} {row[2]:<10} {row[3]:<8} {row[4]:<12} {row[5]:<25}")
# 8. Guidance Trade-offs
print("\n8. Classifier-Free Guidance Trade-offs")
print("-" * 80)
data = [
("Guidance Scale (w)", "Quality", "Diversity", "Condition Strength", "Use Case"),
("-" * 18, "-" * 12, "-" * 12, "-" * 18, "-" * 30),
("1.0 (no guidance)", "Good", "High", "Weak", "Unconditional / diverse"),
("1.5", "Good", "High", "Moderate", "Slight conditioning"),
("3.0", "Very Good", "Moderate", "Strong", "Balanced (text-to-image)"),
("5.0", "Excellent", "Low", "Very Strong", "Precise control"),
("7.5", "SOTA", "Very Low", "Extreme", "DALL-E 2 / Stable Diffusion"),
("10.0+", "Saturated", "Minimal", "Maximum", "Overfitting / artifacts"),
]
for row in data:
print(f"{row[0]:<18} {row[1]:<12} {row[2]:<12} {row[3]:<18} {row[4]:<30}")
print("Note: Classifier-free guidance requires dropout_prob=0.1 during training")
# 9. Application-Specific Results
print("\n9. Application-Specific Results")
print("-" * 80)
data = [
("Application", "Model", "Result", "Notes"),
("-" * 25, "-" * 25, "-" * 30, "-" * 30),
("Text-to-Image", "DALL-E 2", "Human-quality 1024Γ1024", "CLIP + diffusion"),
("", "Imagen", "SOTA photorealism", "T5 + cascaded diffusion"),
("", "Stable Diffusion", "512Γ512, open-source", "Latent diffusion"),
("", "Midjourney v5", "Artistic generation", "Commercial"),
("", "", "", ""),
("Image Editing", "SDEdit", "Stroke-to-image", "Stochastic editing"),
("", "Repaint", "Inpainting", "Resample known region"),
("", "DiffEdit", "Text-guided editing", "Mask + diffusion"),
("", "", "", ""),
("Audio", "WaveGrad", "24kHz waveform generation", "Raw audio diffusion"),
("", "DiffWave", "MOS 4.4+ vocoder", "Mel-to-waveform"),
("", "Grad-TTS", "Natural TTS", "End-to-end diffusion"),
("", "", "", ""),
("Video", "Video Diffusion", "16 frames @ 64Γ64", "Factorized space-time"),
("", "Imagen Video", "1024p text-to-video", "Cascaded diffusion"),
("", "", "", ""),
("3D", "Point-E", "Text-to-3D point clouds", "Diffusion on point clouds"),
("", "DreamFusion", "Text-to-NeRF", "Score distillation"),
("", "", "", ""),
("Science", "Molecule generation", "Valid molecules", "Graph diffusion"),
("", "Protein design", "SE(3)-equivariant", "Manifold diffusion"),
("", "Inverse problems", "CT reconstruction", "Posterior sampling"),
]
for row in data:
print(f"{row[0]:<25} {row[1]:<25} {row[2]:<30} {row[3]:<30}")
# 10. Decision Guide
# Fix: repaired mojibake — "β’" restored to "•" bullets, lone "β" heading
# markers restored to "✓"/"✗"/"→" glyphs, "Γ" restored to "×".
print("\n10. DECISION GUIDE: When to Use Score-Based Models")
print("=" * 80)
print("\n✓ USE Score-Based Models When:")
advantages = [
    "• Need state-of-the-art generation quality (images, audio, video)",
    "• Require diverse samples (avoid mode collapse)",
    "• Want training stability (no adversarial dynamics)",
    "• Need exact likelihood (via ODE)",
    "• Flexible architecture constraints (any network)",
    "• Conditional generation (text-to-image, class-conditional)",
    "• Image editing and inpainting applications",
    "• Scientific computing (inverse problems, molecular generation)",
]
for adv in advantages:
    print(adv)
print("\n✗ AVOID Score-Based Models When:")
limitations = [
    "• Real-time generation required (GANs 100× faster)",
    "• Limited computational budget (training costly)",
    "• Single-step sampling mandatory (VAEs, GANs, Flows)",
    "• Exact control over latent space needed (VAEs better)",
    "• Very low-resolution images (overkill, simple models sufficient)",
]
for lim in limitations:
    print(lim)
print("\n→ RECOMMENDED ALTERNATIVE:")
alternatives = [
    "• Fast sampling → Latent Diffusion (4-8× faster) or DPM-Solver++ (20 steps)",
    "• Real-time → Consistency Models (1-4 steps) or distilled models",
    "• Latent manipulation → VAE or GAN inversion",
    "• Extremely fast → StyleGAN2 or other GANs",
]
for alt in alternatives:
    print(alt)
# 11. Variant Selection Guide
# Comparison of diffusion-model variants and their trade-offs.
# Fix: repaired mojibake "Γ" back to "×" in the Latent Diffusion row.
print("\n11. VARIANT SELECTION GUIDE")
print("=" * 80)
# Columns: Variant | When to Use | Pros | Cons.
data = [
    ("Variant", "When to Use", "Pros", "Cons"),
    ("-" * 20, "-" * 30, "-" * 30, "-" * 30),
    ("DDPM", "Simple implementation", "Well-documented, stable", "Slower sampling (1000 steps)"),
    ("Score SDE (VP)", "General purpose", "Flexible, continuous-time", "Complex formulation"),
    ("Score SDE (VE)", "High-res images", "Better for large images", "Slightly lower FID"),
    ("DDIM", "Fast deterministic", "50 steps, invertible", "Slight quality loss"),
    ("DPM-Solver++", "Production deployment", "10-20 steps, SOTA", "Newer (less tested)"),
    ("EDM", "Best quality", "SOTA FID, optimal", "Complex preconditioning"),
    ("Latent Diffusion", "Fast high-res", "4-8× faster, 512×512+", "Requires VAE training"),
    ("Consistency", "Real-time", "1-4 steps, very fast", "Quality trade-off"),
]
# Left-justified fixed-width columns (20/30/30/30).
for row in data:
    print(f"{row[0]:<20} {row[1]:<30} {row[2]:<30} {row[3]:<30}")
# 12. Troubleshooting
print("\n12. TROUBLESHOOTING COMMON ISSUES")
print("=" * 80)
data = [
("Problem", "Possible Cause", "Solution"),
("-" * 30, "-" * 35, "-" * 40),
("Poor sample quality", "Insufficient training", "Train longer (1M+ steps)"),
("", "Bad noise schedule", "Try cosine or EDM schedule"),
("", "Too few sampling steps", "Increase to 50-100 steps"),
("", "", ""),
("Slow convergence", "Learning rate too low", "Increase LR to 1e-4 or 2e-4"),
("", "Small batch size", "Increase to 128+"),
("", "", ""),
("Training instability", "Exploding gradients", "Gradient clipping (max_norm=1.0)"),
("", "Bad initialization", "Use default PyTorch init"),
("", "", ""),
("Blurry samples", "Score network too small", "Increase model_channels"),
("", "Not enough denoising", "More sampling steps"),
("", "", ""),
("OOM (out of memory)", "Batch size too large", "Reduce batch size or resolution"),
("", "Model too large", "Reduce model_channels"),
("", "", "Use mixed precision (FP16)"),
("", "", ""),
("Slow sampling", "Too many steps", "Use DPM-Solver++ (10-20 steps)"),
("", "High resolution", "Use latent diffusion"),
]
for row in data:
print(f"{row[0]:<30} {row[1]:<35} {row[2]:<40}")
# Closing summary banner for the comparison report.
rule = "=" * 80
summary_lines = (
    "Summary: Score-based models offer SOTA quality with stable training,",
    "but require iterative sampling. Use DPM-Solver++ or latent diffusion",
    "for practical deployment. Classifier-free guidance boosts quality.",
)
print("\n" + rule)
for line in summary_lines:
    print(line)
print(rule)
print()
# ===========================
# Run All Demos
# ===========================
if __name__ == "__main__":
    banner = "=" * 80
    print("\n" + banner)
    print("SCORE-BASED GENERATIVE MODELS - COMPREHENSIVE IMPLEMENTATION")
    print(banner + "\n")
    # Execute every demo section in order (all defined earlier in this file).
    demos = (
        demo_time_embedding,
        demo_score_network,
        demo_noise_schedules,
        demo_denoising_score_matching,
        demo_sampling,
        demo_classifier_free_guidance,
        print_performance_comparison,
    )
    for demo in demos:
        demo()
    print("\n" + banner)
    print("All demos completed successfully!")
    print(banner)