import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Run on GPU when available; all models and tensors below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Score Matching

Score Function:

\[s_\theta(x) \approx \nabla_x \log p(x)\]

Denoising Score Matching:

\[\mathcal{L} = \mathbb{E}_{x, \tilde{x}} \left[\frac{1}{2}\|s_\theta(\tilde{x}) - \nabla_{\tilde{x}} \log q(\tilde{x}|x)\|^2\right]\]

where \(\tilde{x} = x + \sigma \epsilon\), \(\epsilon \sim \mathcal{N}(0, I)\).

📚 Reference Materials:

class ScoreNet(nn.Module):
    """Convolutional score estimator for single-channel (MNIST) images.

    Predicts s_theta(x), an estimate of the score grad_x log p(x), as a
    map with the same spatial shape as the input.
    """

    def __init__(self, sigma):
        super().__init__()
        # Noise scale this network is trained at (stored for reference;
        # the forward pass itself is unconditional).
        self.sigma = sigma

        # Same layer sequence as a hand-written Sequential:
        # four 3x3 conv+ReLU stages, then a 3x3 conv back to one channel.
        layers = []
        for c_in, c_out in [(1, 64), (64, 64), (64, 128), (128, 128)]:
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            layers.append(nn.ReLU())
        layers.append(nn.Conv2d(128, 1, 3, padding=1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the estimated score field for a batch of images."""
        return self.net(x)

print("ScoreNet defined")

Denoising Score Matching Loss

Score-based models learn the score function \(\nabla_x \log p(x)\) – the gradient of the log-density with respect to the data. Direct score matching is intractable, but denoising score matching provides an equivalent objective: add Gaussian noise \(\tilde{x} = x + \sigma \epsilon\) to the data, then train a network \(s_\theta(\tilde{x}, \sigma)\) to predict the score of the noisy distribution. The loss is \(\mathcal{L} = \mathbb{E}_{\sigma, x, \epsilon}\left[\|s_\theta(\tilde{x}, \sigma) + \epsilon / \sigma\|^2\right]\), which has the beautiful interpretation that the network simply learns to denoise – to point from the noisy sample back toward the clean data manifold.

def denoising_score_matching_loss(score_net, x, sigma):
    """Denoising score matching (DSM) objective.

    Perturbs x with Gaussian noise of scale sigma and penalizes the
    squared error between the predicted score at the noisy point and the
    analytic score of the perturbation kernel, -noise / sigma.

    Args:
        score_net: callable mapping a (B, C, H, W) batch to a score map.
        x: clean data batch, shape (B, C, H, W).
        sigma: scalar noise level.

    Returns:
        Scalar loss (mean over the batch of per-sample squared errors).
    """
    eps = torch.randn_like(x)      # epsilon ~ N(0, I)
    perturbed = x + sigma * eps    # x_tilde = x + sigma * epsilon

    predicted = score_net(perturbed)
    target = -eps / sigma          # grad_x_tilde log q(x_tilde | x)

    per_sample = ((predicted - target) ** 2).sum(dim=(1, 2, 3))
    return 0.5 * per_sample.mean()

print("Loss function defined")

3. Langevin Dynamics Sampling

Update Rule:

\[x_{t+1} = x_t + \epsilon \nabla_x \log p(x_t) + \sqrt{2\epsilon} z_t\]

where \(z_t \sim \mathcal{N}(0, I)\).

@torch.no_grad()
def langevin_dynamics(score_net, x_init, n_steps=100, step_size=1e-4):
    """Sample using Langevin dynamics."""
    x = x_init.clone()
    
    for t in range(n_steps):
        # Compute score
        score = score_net(x)
        
        # Langevin update
        noise = torch.randn_like(x)
        x = x + step_size * score + np.sqrt(2 * step_size) * noise
        
        # Clip to valid range
        x = torch.clamp(x, 0, 1)
    
    return x

print("Langevin dynamics defined")

Train Score Network

The score network is trained across multiple noise levels \(\sigma_1 > \sigma_2 > \cdots > \sigma_L\), spanning from large noise (easy denoising, captures global structure) to small noise (precise denoising, captures fine details). At each training step, a noise level is randomly sampled, noise is added to a batch of data, and the network learns to predict the score at that noise level. The noise-conditional architecture (typically a U-Net with the noise level as an additional input) must handle the full range of noise magnitudes. Monitoring the loss per noise level helps diagnose whether the model struggles more with coarse or fine structure.

# Data: MNIST digits; ToTensor scales pixels to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)

# Model: a single fixed noise level (the forward pass is unconditional;
# sigma is only used inside the loss).
sigma = 0.1
score_net = ScoreNet(sigma).to(device)
optimizer = torch.optim.Adam(score_net.parameters(), lr=1e-3)

# Train with denoising score matching at the fixed sigma.
losses = []
for epoch in range(5):
    epoch_loss = 0
    for x, _ in train_loader:  # labels unused — unconditional model
        x = x.to(device)
        
        loss = denoising_score_matching_loss(score_net, x, sigma)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    # Record the mean minibatch loss per epoch for the learning curve.
    avg_loss = epoch_loss / len(train_loader)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

Generate Samples

Sampling from a trained score model uses Langevin dynamics: starting from random noise, iteratively move in the direction of the score (toward higher density) with a small step size and added stochastic noise: \(x_{t+1} = x_t + \frac{\epsilon}{2} s_\theta(x_t, \sigma) + \sqrt{\epsilon}\, z\), where \(z \sim \mathcal{N}(0, I)\). Annealed Langevin dynamics starts with the largest noise level (where the score landscape is smooth and easy to follow) and gradually decreases to the smallest (where the score is sharp and detailed). This multi-scale sampling process produces high-quality samples and is the precursor to the diffusion model framework that now dominates image generation.

# Generate samples with the trained single-level score network.
score_net.eval()
n_samples = 16

# Random initialization: uniform noise in the valid image range [0, 1].
x_init = torch.rand(n_samples, 1, 28, 28).to(device)

# Sample using Langevin dynamics (more steps / smaller step size than the
# function defaults).
samples = langevin_dynamics(score_net, x_init, n_steps=200, step_size=5e-5)

# Visualize the 16 samples on a 4x4 grid.
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i in range(16):
    ax = axes[i // 4, i % 4]
    ax.imshow(samples[i].cpu().squeeze(), cmap='gray')
    ax.axis('off')

plt.suptitle('Score-Based Generated Samples', fontsize=12)
plt.tight_layout()
plt.show()

6. Annealed Langevin Dynamics

Use multiple noise levels:

\[\sigma_1 > \sigma_2 > ... > \sigma_L\]

Train the score network across all levels, then sample progressively from the largest level to the smallest.

class MultiScaleScoreNet(nn.Module):
    """Score network conditioned on a discrete noise-level index.

    Fixes the original sketch, which defined ``sigma_embed`` but never
    used it: ``forward`` ignored ``sigma_idx`` entirely, so the network
    could not distinguish noise levels. The embedding is now injected as
    a per-channel bias after the first convolution.
    """

    def __init__(self, n_sigmas):
        super().__init__()
        self.n_sigmas = n_sigmas  # number of discrete noise levels

        # First conv kept separate so the 64-dim sigma embedding can be
        # added to its 64-channel feature map.
        self.head = nn.Conv2d(1, 64, 3, padding=1)
        self.body = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 1, 3, padding=1)
        )

        # One learned 64-dim embedding per noise level.
        self.sigma_embed = nn.Embedding(n_sigmas, 64)

    def forward(self, x, sigma_idx):
        """Predict the score of p_{sigma_i}(x).

        Args:
            x: images, shape (B, 1, H, W).
            sigma_idx: noise-level index — Python int or LongTensor (B,).

        Returns:
            Score map with the same shape as ``x``.
        """
        if isinstance(sigma_idx, int):
            sigma_idx = torch.full((x.shape[0],), sigma_idx,
                                   dtype=torch.long, device=x.device)

        h = self.head(x)
        # Noise conditioning: broadcast the embedding over spatial dims.
        h = h + self.sigma_embed(sigma_idx)[:, :, None, None]
        return self.body(h)

@torch.no_grad()
def annealed_langevin_dynamics(score_net, x_init, sigmas, n_steps_each=100):
    """Annealed Langevin dynamics."""
    x = x_init.clone()
    
    for i, sigma in enumerate(sigmas):
        step_size = sigma ** 2 * 0.01
        
        for t in range(n_steps_each):
            score = score_net(x)
            noise = torch.randn_like(x)
            x = x + step_size * score + np.sqrt(2 * step_size) * noise
            x = torch.clamp(x, 0, 1)
    
    return x

print("Annealed dynamics defined")

Summary

Score-Based Models:

Key Concepts:

  1. Learn score function \(\nabla_x \log p(x)\)

  2. Sample via Langevin dynamics

  3. Denoising score matching for training

  4. Annealing for multi-scale generation

Connection to Diffusion:

  • Score matching ≈ noise prediction

  • Langevin dynamics ≈ reverse diffusion

  • Unified framework (Song et al.)

Advantages:ΒΆ

  • No adversarial training

  • Tractable likelihood

  • High sample quality

  • Flexible architectures

Applications:ΒΆ

  • Image generation

  • Inpainting

  • Super-resolution

  • Inverse problems

Advanced Score-Based Generative Models TheoryΒΆ

1. Mathematical FoundationsΒΆ

Score Function:

The score function is the gradient of the log probability density: $\(s(x) = \nabla_x \log p(x)\)$

Key Insight: The score points toward regions of higher probability density.

Advantages over density modeling:

  • No partition function needed: \(p(x) = \frac{\exp(-E(x))}{Z}\) where \(Z = \int \exp(-E(x))dx\)

  • Score: \(\nabla_x \log p(x) = -\nabla_x E(x)\) (partition function cancels!)

  • Avoid intractable normalization

Sampling via Langevin Dynamics:

Given score \(s(x) = \nabla_x \log p(x)\), sample from \(p(x)\) using: $\(x_{t+1} = x_t + \frac{\epsilon}{2}s(x_t) + \sqrt{\epsilon}z_t, \quad z_t \sim \mathcal{N}(0, I)\)$

Convergence: As \(\epsilon \to 0\) and \(T \to \infty\), \(x_T\) converges to sample from \(p(x)\).

Langevin MCMC Theorem:

The dynamics \(dx = \nabla_x \log p(x)dt + \sqrt{2}dw\) has invariant distribution \(p(x)\).

Discretization with step size \(\epsilon\) requires mixing time \(O(1/\epsilon)\) for convergence.

2. Score Matching ObjectivesΒΆ

Problem: We don’t know \(p(x)\), so we can’t compute \(\nabla_x \log p(x)\) directly.

Explicit Score Matching (HyvΓ€rinen, 2005):

Minimize: $\(\mathcal{L}_{ESM} = \mathbb{E}_{p(x)}\left[\frac{1}{2}\|s_\theta(x) - \nabla_x \log p(x)\|^2\right]\)$

Issue: Still requires \(\nabla_x \log p(x)\)!

Integration by Parts:

Under smoothness assumptions: $\(\mathcal{L}_{ESM} = \mathbb{E}_{p(x)}\left[\text{tr}(\nabla_x s_\theta(x)) + \frac{1}{2}\|s_\theta(x)\|^2\right] + \text{const}\)$

Now tractable! But computing \(\text{tr}(\nabla_x s_\theta(x))\) (Jacobian trace) is expensive.

Denoising Score Matching (Vincent, 2011):

More efficient alternative. Perturb data with noise: $\(q_\sigma(\tilde{x}|x) = \mathcal{N}(\tilde{x}; x, \sigma^2I)\)$

Then: $\(\mathcal{L}_{DSM} = \mathbb{E}_{p(x)}\mathbb{E}_{q_\sigma(\tilde{x}|x)}\left[\frac{1}{2}\left\|s_\theta(\tilde{x}) - \nabla_{\tilde{x}}\log q_\sigma(\tilde{x}|x)\right\|^2\right]\)$

Gradient of perturbed distribution: $\(\nabla_{\tilde{x}}\log q_\sigma(\tilde{x}|x) = -\frac{\tilde{x} - x}{\sigma^2}\)$

If \(\tilde{x} = x + \sigma\epsilon\) where \(\epsilon \sim \mathcal{N}(0,I)\): $\(\nabla_{\tilde{x}}\log q_\sigma(\tilde{x}|x) = -\frac{\epsilon}{\sigma}\)$

Simplified objective: $\(\mathcal{L}_{DSM} = \mathbb{E}_{x \sim p(x), \epsilon \sim \mathcal{N}(0,I)}\left[\frac{1}{2}\left\|s_\theta(x+\sigma\epsilon) + \frac{\epsilon}{\sigma}\right\|^2\right]\)$

Equivalence Theorem (Vincent, 2011):

Under mild conditions: $\(\nabla_\theta \mathcal{L}_{DSM}(\theta, \sigma) = \nabla_\theta \mathcal{L}_{ESM}(\theta) + O(\sigma^2)\)$

As \(\sigma \to 0\), denoising score matching β‰ˆ explicit score matching.

Sliced Score Matching (Song et al., 2019):

Alternative that avoids Jacobian trace using random projections: $\(\mathcal{L}_{SSM} = \mathbb{E}_{p(x), p(v)}\left[\frac{1}{2}v^T\nabla_x s_\theta(x)v + v^Ts_\theta(x) + \frac{1}{2}\|s_\theta(x)\|^2\right]\)$

where \(v \sim \mathcal{N}(0, I)\) is random projection direction.

Key advantage: Single backpropagation, no Jacobian computation.

3. Noise Conditional Score Networks (NCSN)ΒΆ

Motivation (Song & Ermon, 2019):

Single noise level \(\sigma\) problematic:

  • Low \(\sigma\): Score accurate near data manifold, but inaccurate in low-density regions

  • High \(\sigma\): Score accurate everywhere, but blurs data manifold

Solution: Use multiple noise levels!

Noise Schedule: $\(\sigma_1 > \sigma_2 > \cdots > \sigma_L\)$

Typical: geometric sequence \(\sigma_i = \sigma_1 \cdot (\sigma_L/\sigma_1)^{(i-1)/(L-1)}\)

Example: \(\sigma_1 = 1.0\), \(\sigma_L = 0.01\), \(L = 10\)

Conditional Score Network:

Learn \(s_\theta(x, \sigma_i) \approx \nabla_x \log p_{\sigma_i}(x)\) for all noise levels simultaneously.

Training Objective: $\(\mathcal{L}_{NCSN} = \sum_{i=1}^L \lambda(\sigma_i) \mathbb{E}_{p(x), \mathcal{N}(\epsilon;0,I)}\left[\left\|s_\theta(x+\sigma_i\epsilon, \sigma_i) + \frac{\epsilon}{\sigma_i}\right\|^2\right]\)$

Weighting: \(\lambda(\sigma_i) = \sigma_i^2\) (variance weighting)

Annealed Langevin Dynamics Sampling:

Start from high noise, gradually reduce:

x_0 ~ N(0, Οƒ_1Β² I)
for i = 1 to L:
    Ξ±_i = Ξ΅ Β· Οƒ_iΒ² / Οƒ_LΒ²  (adaptive step size)
    for t = 1 to T:
        z_t ~ N(0, I)
        x_t = x_{t-1} + (Ξ±_i/2)Β·s_ΞΈ(x_{t-1}, Οƒ_i) + √α_iΒ·z_t
    x_0 = x_T  (initialize next level)
return x_0

Intuition:

  1. Start from pure noise

  2. High noise: Smooth out manifold, easy to sample

  3. Progressively denoise

  4. Low noise: Refine details on data manifold

4. Score-Based SDE (Continuous Formulation)ΒΆ

Limitation of NCSN: Discrete noise levels \(\{\sigma_i\}\), still many steps.

Solution (Song et al., 2021): Continuous-time formulation via SDE!

Forward Process (Diffusion): $\(dx = f(x,t)dt + g(t)dw\)$

where:

  • \(f(x,t)\): Drift coefficient

  • \(g(t)\): Diffusion coefficient

  • \(w\): Standard Wiener process

Marginal Distribution:

At time \(t\), \(x(t)\) has distribution \(p_t(x)\).

Score of Marginal: $\(s(x,t) = \nabla_x \log p_t(x)\)$

Reverse-Time SDE (Anderson, 1982):

The reverse process is also an SDE: $\(dx = \left[f(x,t) - g(t)^2\nabla_x \log p_t(x)\right]dt + g(t)d\bar{w}\)$

where \(\bar{w}\) is reverse-time Wiener process.

Key Insight: Given score \(s_\theta(x,t) \approx \nabla_x \log p_t(x)\), we can simulate reverse SDE!

Probability Flow ODE:

Deterministic alternative with same marginals: $\(\frac{dx}{dt} = f(x,t) - \frac{1}{2}g(t)^2\nabla_x \log p_t(x)\)$

Advantages:

  • Exact likelihood via change of variables

  • Faster sampling with adaptive ODE solvers

  • Enables latent space manipulation

Training:

Perturb data with \(p_t(x|x_0)\), then: $\(\mathcal{L}_{SDE} = \mathbb{E}_{t \sim U(0,T), x_0 \sim p_0, x_t \sim p_t(x_t|x_0)}\left[\lambda(t)\|\nabla_{x_t}\log p_t(x_t|x_0) - s_\theta(x_t, t)\|^2\right]\)$

5. Variance Preserving (VP) vs. Variance Exploding (VE)ΒΆ

Variance Preserving (VP-SDE): $\(dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)}dw\)$

Properties:

  • Marginal: \(p_t(x|x_0) = \mathcal{N}\left(x; \sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t)I\right)\)

  • Variance: \(\mathbb{E}[\|x(t)\|^2] \approx \mathbb{E}[\|x(0)\|^2]\) (preserved)

  • Equivalent to DDPM

Variance Exploding (VE-SDE): $\(dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}}dw\)$

Properties:

  • Marginal: \(p_t(x|x_0) = \mathcal{N}(x; x_0, \sigma^2(t)I)\)

  • Variance: \(\mathbb{E}[\|x(t)\|^2] = \mathbb{E}[\|x(0)\|^2] + \sigma^2(t)\) (exploding)

  • Equivalent to NCSN

Unified Framework:

Both are special cases of general SDE. Can interpolate: $\(dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)(1-\gamma)}dw\)$

where \(\gamma \in [0,1]\):

  • \(\gamma = 0\): VP-SDE

  • \(\gamma = 1\): VE-SDE

6. Predictor-Corrector SamplingΒΆ

Motivation: Pure reverse SDE can accumulate errors.

Solution: Alternate between:

  1. Predictor: Update with reverse SDE/ODE

  2. Corrector: Langevin MCMC to improve sample quality

Algorithm:

x_T ~ p_T
for i = T-1 down to 0:
    # Predictor: one reverse SDE step
    x_i = Predictor(x_{i+1}, s_ΞΈ, i)
    
    # Corrector: M Langevin steps
    for j = 1 to M:
        x_i = x_i + Ρ·s_θ(x_i, i) + √(2Ρ)·z_j
return x_0

Predictors:

  • Euler-Maruyama

  • Heun (2nd order)

  • Ancestral sampling

Correctors:

  • Langevin dynamics

  • Annealed Langevin dynamics

  • None (pure predictor)

Trade-off:

  • More corrector steps: Better quality, slower

  • Fewer corrector steps: Faster, lower quality

Typical: 1-5 corrector steps per predictor step.

7. Controllable GenerationΒΆ

Conditional Sampling:

Want to sample from \(p(x|y)\) where \(y\) is condition (class, text, etc.).

Bayes Rule: $\(\nabla_x \log p(x|y) = \nabla_x \log p(x) + \nabla_x \log p(y|x)\)$

Conditional Score: $\(s(x,y,t) = s(x,t) + \nabla_x \log p_t(y|x)\)$

Classifier Guidance:

Train classifier \(p_\phi(y|x,t)\) on noisy data: $\(s(x,y,t) = s_\theta(x,t) + w \cdot \nabla_x \log p_\phi(y|x,t)\)$

where \(w\) is guidance weight.

Classifier-Free Guidance:

Joint training on \((x,y)\) and unconditional \(x\): $\(s(x,y,t) = (1+w)s_\theta(x,t,y) - w \cdot s_\theta(x,t,\emptyset)\)$

Imputation:

For missing data \(x_m\), keep observed \(x_o\) fixed: $\(x_m^{t+1} = x_m^t + \epsilon s_\theta([x_o, x_m^t], t) + \sqrt{2\epsilon}z\)$

Applications: Inpainting, super-resolution, compressed sensing.

Inverse Problems:

General form: \(y = A(x) + \eta\) where \(A\) is forward operator (blur, downsample, etc.).

Posterior Sampling: $\(\nabla_x \log p(x|y) \approx \nabla_x \log p(x) - \frac{1}{2\sigma^2}\nabla_x \|y - A(x)\|^2\)$

Update rule: $\(x^{t+1} = x^t + \epsilon\left[s_\theta(x^t,t) - \frac{1}{\sigma^2}\nabla_x\|y-A(x^t)\|^2\right] + \sqrt{2\epsilon}z\)$

8. Likelihood ComputationΒΆ

Probability Flow ODE: $\(\frac{dx}{dt} = f(x,t) - \frac{1}{2}g(t)^2 s_\theta(x,t)\)$

Instantaneous Change of Variables: $\(\frac{d\log p_t(x(t))}{dt} = -\text{div}\left(f(x,t) - \frac{1}{2}g(t)^2 s_\theta(x,t)\right)\)$

Log-Likelihood: $\(\log p_0(x(0)) = \log p_T(x(T)) - \int_0^T \text{div}\left(f(x,t) - \frac{1}{2}g(t)^2 s_\theta(x,t)\right)dt\)$

Hutchinson’s Trace Estimator:

For divergence \(\text{div}(f) = \text{tr}(\nabla_x f)\): $\(\mathbb{E}_{v \sim \mathcal{N}(0,I)}\left[v^T\nabla_x f \cdot v\right] = \text{div}(f)\)$

Algorithm:

  1. Sample \(v \sim \mathcal{N}(0,I)\)

  2. Compute \(v^T\nabla_x(f \cdot v)\) via vector-Jacobian product (single backprop!)

  3. Estimate \(\text{div}(f)\)

Bits per Dimension: $\(\text{BPD} = -\frac{\log_2 p(x)}{D}\)$

where \(D\) is data dimensionality.

9. Connections to Other ModelsΒΆ

Energy-Based Models:

If \(p(x) = \frac{1}{Z}\exp(-E(x))\), then: $\(\nabla_x \log p(x) = -\nabla_x E(x)\)$

Score matching ≑ learning energy function gradient.

Denoising Autoencoders:

Denoising score matching objective equivalent to training DAE to denoise: $\(\mathcal{L}_{DAE} = \mathbb{E}_{x,\epsilon}\left[\left\|\frac{x - D_\theta(x+\sigma\epsilon)}{\sigma} + \frac{\epsilon}{\sigma}\right\|^2\right]\)$

where \(D_\theta\) is denoising network.

Diffusion Models:

DDPM noise prediction \(\epsilon_\theta(x_t,t)\) related to score: $\(s_\theta(x_t,t) = -\frac{\epsilon_\theta(x_t,t)}{\sqrt{1-\bar{\alpha}_t}}\)$

Score-based models and diffusion models are equivalent under continuous formulation!

Normalizing Flows:

Probability flow ODE converts score-based model into continuous normalizing flow.

VAE:

Diffusion/score models can be viewed as hierarchical VAE with:

  • Latents \(x_1, \ldots, x_T\) same dimension as \(x_0\)

  • Fixed encoder \(q(x_t|x_0)\)

  • Learned decoder \(p_\theta(x_{t-1}|x_t)\)

10. Advanced ArchitecturesΒΆ

U-Net with Time Embedding:

Standard architecture for score networks:

Time t β†’ Sinusoidal Embedding β†’ MLP β†’ time_emb

Encoder:
  x β†’ Conv+GroupNorm+SiLU+time_emb β†’ Skip1
  β†’ Downsample β†’ Conv+GroupNorm+SiLU+time_emb β†’ Skip2
  β†’ Downsample β†’ Conv+GroupNorm+SiLU+time_emb

Decoder:
  β†’ Upsample β†’ Concat(Skip2) β†’ Conv+GroupNorm+SiLU+time_emb
  β†’ Upsample β†’ Concat(Skip1) β†’ Conv+GroupNorm+SiLU+time_emb
  β†’ Conv β†’ output

Attention Mechanisms:

Add self-attention at low resolutions (e.g., 16Γ—16): $\(\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V\)$

Residual Blocks:

# NOTE(review): illustrative sketch from the accompanying text — `Conv` and
# `Linear` are shorthand (presumably nn.Conv2d / nn.Linear — confirm) and
# `super().__init__()` is omitted, so this snippet will not run as written.
class ResBlock(nn.Module):
    """Residual block that adds a projected time embedding between two convs."""

    def __init__(self, channels, time_emb_dim):
        self.conv1 = Conv(channels, channels)
        self.time_proj = Linear(time_emb_dim, channels)
        self.conv2 = Conv(channels, channels)
    
    def forward(self, x, t_emb):
        h = self.conv1(x)
        # Broadcast the (B, C) time embedding over the spatial dimensions.
        h = h + self.time_proj(t_emb)[:, :, None, None]
        h = self.conv2(h)
        return x + h  # Residual connection

Fourier Features:

For better time embedding: $\(\gamma(t) = [\cos(2\pi\omega_1 t), \sin(2\pi\omega_1 t), \ldots, \cos(2\pi\omega_d t), \sin(2\pi\omega_d t)]\)$

Adaptive Group Normalization:

Condition normalization on time: $\(\text{AdaGN}(h, t) = s_t \cdot \text{GroupNorm}(h) + b_t\)$

where \(s_t, b_t\) are functions of time embedding.

11. Training ImprovementsΒΆ

Importance Sampling for Time:

Weight timesteps by importance: $\(\mathcal{L} = \mathbb{E}_{t \sim p(t)}\left[\frac{\lambda(t)}{p(t)}\mathcal{L}_t\right]\)$

Typical: \(p(t) \propto \lambda(t)\) (importance sampling)

Exponential Moving Average (EMA):

Maintain EMA of parameters for sampling: $\(\theta_{EMA} \leftarrow \gamma \theta_{EMA} + (1-\gamma)\theta\)$

Typical \(\gamma = 0.9999\).

Variance Reduction:

Use antithetic sampling for noise: $\(\mathcal{L} = \frac{1}{2}\left[\mathcal{L}(\epsilon) + \mathcal{L}(-\epsilon)\right]\)$

Consistency Models: (Song et al., 2023)

Self-consistency property: $\(f_\theta(x_t, t) = f_\theta(x_{t'}, t')\)$

for any \(t, t'\) on same trajectory.

Enables one-step generation!

12. Evaluation MetricsΒΆ

Inception Score (IS): $\(IS = \exp\left(\mathbb{E}_x\left[D_{KL}(p(y|x) \| p(y))\right]\right)\)$

FrΓ©chet Inception Distance (FID): $\(FID = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2})\)$

Negative Log-Likelihood (NLL):

Via probability flow ODE + Hutchinson estimator.

Sample Quality vs. Diversity:

  • Precision: Fraction of generated samples in real manifold

  • Recall: Fraction of real manifold covered by generated samples

13. State-of-the-Art ResultsΒΆ

ImageNet 256Γ—256:

  • NCSN++ (Song et al., 2021): FID = 2.20

  • EDM (Karras et al., 2022): FID = 1.79 (SOTA)

Likelihood:

  • VP-SDE on CIFAR-10: 2.99 bits/dim

  • VE-SDE on CIFAR-10: 2.92 bits/dim

Speed:

  • DDPM: 1000 steps

  • DDIM: 50 steps (20Γ— faster)

  • DPM-Solver: 20 steps (50Γ— faster)

  • Consistency models: 1 step (1000Γ— faster!)

14. Applications Beyond ImagesΒΆ

Audio Generation:

  • WaveGrad: High-quality speech synthesis

  • DiffWave: Vocoder for text-to-speech

3D Shapes:

  • Point cloud generation

  • Mesh generation via SDF

Molecular Design:

  • Equivariant diffusion for 3D molecules

  • Protein structure generation (RFdiffusion)

Video:

  • Video diffusion models

  • Frame interpolation

Recommendation Systems:

  • Collaborative filtering with diffusion

15. Practical ConsiderationsΒΆ

Hyperparameters:

Parameter

Typical Value

Notes

\(\sigma_{\min}\)

0.01

Minimum noise level

\(\sigma_{\max}\)

50-100

Maximum noise level

\(L\) (noise levels)

10-1000

More = better, slower

Langevin steps

1-5 per level

Corrector steps

Step size \(\epsilon\)

\(10^{-5}\) to \(10^{-4}\)

Depends on \(\sigma\)

EMA decay

0.9999

For parameter averaging

Noise Schedule:

Geometric: \(\sigma_i = \sigma_{\max} \cdot (\sigma_{\min}/\sigma_{\max})^{i/(L-1)}\)

Computational Cost:

Training:

  • Similar to DDPM

  • Single network for all noise levels

Sampling:

  • NCSN: \(L \times T\) steps (e.g., 10 Γ— 100 = 1000)

  • Faster with ODE solvers

16. Limitations & Open ProblemsΒΆ

Slow Sampling:

  • Still requires many steps

  • Solutions: Distillation, consistency models

  • Open: Single-step score-based generation?

Likelihood Estimation:

  • Requires ODE + divergence computation

  • Expensive for high dimensions

  • Open: Efficient exact likelihood?

Mode Coverage:

  • Better than GANs, but still challenges

  • Open: Provable mode coverage?

Theory:

  • Convergence guarantees for finite steps?

  • Sample complexity bounds?

  • Optimal noise schedules?

17. Key Papers (Chronological)ΒΆ

  1. HyvΓ€rinen, 2005: β€œEstimation of Non-Normalized Statistical Models by Score Matching” (score matching foundation)

  2. Vincent, 2011: β€œA Connection Between Score Matching and Denoising Autoencoders” (denoising score matching)

  3. Song & Ermon, 2019: β€œGenerative Modeling by Estimating Gradients of the Data Distribution” (NCSN)

  4. Song & Ermon, 2020: β€œImproved Techniques for Training Score-Based Generative Models” (NCSN++)

  5. Song et al., 2021: β€œScore-Based Generative Modeling through Stochastic Differential Equations” (continuous SDE formulation)

  6. Song et al., 2021: β€œMaximum Likelihood Training of Score-Based Diffusion Models” (likelihood computation)

  7. Karras et al., 2022: β€œElucidating the Design Space of Diffusion-Based Generative Models” (EDM, SOTA)

  8. Song et al., 2023: β€œConsistency Models” (one-step generation)

18. Comparison: Score-Based vs. DiffusionΒΆ

Aspect

Score-Based

Diffusion (DDPM)

Formulation

\(s_\theta(x,t) \approx \nabla_x \log p_t(x)\)

\(\epsilon_\theta(x_t,t) \approx \epsilon\)

Training

Score matching

Noise prediction

Sampling

Langevin dynamics

Ancestral sampling

Noise

Continuous levels

Discrete timesteps

Framework

Energy-based

Hierarchical VAE

Likelihood

Via ODE (exact)

VLB (lower bound)

Unification:

Under continuous SDE formulation, they are equivalent: $\(s_\theta(x_t,t) = -\frac{\epsilon_\theta(x_t,t)}{\sqrt{1-\bar{\alpha}_t}}\)$

Both can use:

  • Same architectures (U-Net)

  • Same training objectives

  • Same sampling algorithms

When to Use Score-Based:

  • Need exact likelihood

  • Prefer energy-based perspective

  • Continuous-time formulation

  • Flexible noise schedules

When to Use Diffusion:

  • Prefer hierarchical VAE perspective

  • Simpler discrete formulation

  • Extensive existing codebases

"""
Advanced Score-Based Models - Complete Implementations

This cell provides production-ready implementations of:
1. Noise Conditional Score Network (NCSN)
2. Variance Preserving SDE (VP-SDE)
3. Variance Exploding SDE (VE-SDE)
4. Predictor-Corrector Samplers
5. Sliced Score Matching
6. Likelihood Computation via ODE
7. Conditional Generation
8. Complete Training & Sampling Pipeline
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy import integrate
from torch.autograd import grad
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# Noise Conditional Score Network (NCSN)
# ============================================================================

class ResidualBlock(nn.Module):
    """Conv residual block whose features are biased by a noise embedding."""

    def __init__(self, in_channels, out_channels, noise_emb_dim=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # Projects the noise-level embedding to one additive bias per channel.
        self.noise_proj = nn.Linear(noise_emb_dim, out_channels)

        # 1x1 conv only when the channel count changes; identity otherwise.
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip = nn.Identity()

        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)

    def forward(self, x, noise_emb):
        h = self.norm1(self.conv1(x))
        # Inject conditioning as a spatially-broadcast channel bias.
        h = F.relu(h + self.noise_proj(noise_emb)[:, :, None, None])
        h = self.norm2(self.conv2(h))
        return F.relu(h + self.skip(x))


class NoiseConditionalScoreNetwork(nn.Module):
    """
    NCSN: learn the score at multiple noise levels.

    Theory:
    s_θ(x, σ) ≈ ∇_x log p_σ(x)

    Train with: L = Σ_i λ(σ_i) E[||s_θ(x+σ_i·ε, σ_i) + ε/σ_i||²]

    Fix vs. the original: ``forward`` previously hard-coded the traversal
    (slices ``[:2]``, ``[2:4]``, ``[4:]`` and ``downsamples[0..2]``) so it
    only worked for the default ``channels=[32, 64, 128, 256]`` with
    ``num_res_blocks=2``; any other configuration crashed or silently
    skipped blocks. The traversal now iterates the module lists.

    NOTE(review): each down/upsample exactly halves/doubles the spatial
    size, so inputs whose side is not divisible by 2**(len(channels)-1)
    (e.g. 28x28 MNIST with the default 3 levels) come back at a different
    resolution — confirm callers pad or crop accordingly.
    """

    def __init__(self, channels=[32, 64, 128, 256], noise_emb_dim=32, num_res_blocks=2):
        super().__init__()
        self.noise_emb_dim = noise_emb_dim
        # Stored so forward() can walk encoder/decoder blocks per level.
        self.num_res_blocks = num_res_blocks

        # Noise level embedding: scalar sigma -> noise_emb_dim vector.
        self.noise_embed = nn.Sequential(
            nn.Linear(1, noise_emb_dim),
            nn.ReLU(),
            nn.Linear(noise_emb_dim, noise_emb_dim)
        )

        # Input projection from 1 channel to the first width.
        self.input_proj = nn.Conv2d(1, channels[0], 3, padding=1)

        # Encoder: num_res_blocks residual blocks per level, then a strided
        # conv that downsamples and widens.
        self.encoder_blocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for i in range(len(channels) - 1):
            for _ in range(num_res_blocks):
                self.encoder_blocks.append(
                    ResidualBlock(channels[i], channels[i], noise_emb_dim)
                )
            self.downsamples.append(nn.Conv2d(channels[i], channels[i+1], 3, stride=2, padding=1))

        # Middle: two residual blocks at the bottleneck width.
        self.middle = nn.ModuleList([
            ResidualBlock(channels[-1], channels[-1], noise_emb_dim),
            ResidualBlock(channels[-1], channels[-1], noise_emb_dim)
        ])

        # Decoder: transposed-conv upsample, then residual blocks per level.
        self.upsamples = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        for i in range(len(channels) - 1, 0, -1):
            self.upsamples.append(nn.ConvTranspose2d(channels[i], channels[i-1], 4, stride=2, padding=1))
            for _ in range(num_res_blocks):
                self.decoder_blocks.append(
                    ResidualBlock(channels[i-1], channels[i-1], noise_emb_dim)
                )

        # Output projection back to one channel (the score map).
        self.output = nn.Conv2d(channels[0], 1, 3, padding=1)

    def forward(self, x, sigma):
        """
        Args:
            x: Input (B, 1, H, W)
            sigma: Noise level — (B,) tensor or a Python scalar

        Returns:
            score: Predicted score ∇_x log p_σ(x), same shape as x
        """
        # Embed noise level (accept plain Python numbers as well as tensors).
        if isinstance(sigma, (int, float)):
            sigma = torch.full((x.shape[0],), float(sigma), device=x.device)

        noise_emb = self.noise_embed(sigma.view(-1, 1))

        h = self.input_proj(x)

        # Encoder: walk the blocks in groups of num_res_blocks per level.
        for level, down in enumerate(self.downsamples):
            start = level * self.num_res_blocks
            for block in self.encoder_blocks[start:start + self.num_res_blocks]:
                h = block(h, noise_emb)
            h = down(h)

        # Middle
        for block in self.middle:
            h = block(h, noise_emb)

        # Decoder mirrors the encoder level-by-level.
        for level, up in enumerate(self.upsamples):
            h = up(h)
            start = level * self.num_res_blocks
            for block in self.decoder_blocks[start:start + self.num_res_blocks]:
                h = block(h, noise_emb)

        return self.output(h)


# ============================================================================
# SDE Framework
# ============================================================================

class VariancePreservingSDE:
    """
    VP-SDE: dx = -½β(t)x dt + √β(t) dw

    The marginal variance stays roughly constant over time
    (E[||x(t)||²] ≈ E[||x(0)||²]); this is the continuous-time
    counterpart of DDPM.
    """

    def __init__(self, beta_min=0.1, beta_max=20.0, T=1.0):
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.T = T

    def beta(self, t):
        """Linearly interpolated noise rate β(t)."""
        return self.beta_min + (self.beta_max - self.beta_min) * t

    def mean_coeff(self, t):
        """√ᾱ_t — the contraction applied to x_0 in the marginal."""
        # Closed-form integral of -½β(s) ds over [0, t] for the linear schedule.
        exponent = -0.25 * (self.beta_max - self.beta_min) * t ** 2 - 0.5 * self.beta_min * t
        return torch.exp(exponent)

    def std(self, t):
        """√(1-ᾱ_t) — the marginal's standard deviation."""
        return torch.sqrt(1.0 - self.mean_coeff(t) ** 2)

    def marginal_prob(self, x0, t):
        """
        Sample x_t ~ p_t(x|x_0) = N(√ᾱ_t·x_0, (1-ᾱ_t)I).

        Returns:
            (x_t, std, noise) — the noisy sample, the marginal std, and
            the Gaussian draw, so callers can form DSM targets.
        """
        scale = self.mean_coeff(t)
        std = self.std(t)

        eps = torch.randn_like(x0)
        x_t = scale[:, None, None, None] * x0 + std[:, None, None, None] * eps

        return x_t, std, eps

    def sde(self, x, t):
        """Drift and diffusion coefficients of the forward SDE."""
        rate = self.beta(t)
        return -0.5 * rate[:, None, None, None] * x, torch.sqrt(rate)

    def reverse_sde(self, x, t, score):
        """Drift and diffusion coefficients of the reverse-time SDE."""
        drift, diffusion = self.sde(x, t)
        return drift - diffusion[:, None, None, None] ** 2 * score, diffusion

    def ode(self, x, t, score):
        """Drift of the probability flow ODE (same marginals, no noise term)."""
        drift, diffusion = self.sde(x, t)
        return drift - 0.5 * diffusion[:, None, None, None] ** 2 * score


class VarianceExplodingSDE:
    """
    Variance-exploding SDE: dx = √(dσ²(t)/dt) dw

    Properties:
    - Variance exploding: E[||x(t)||²] = E[||x(0)||²] + σ²(t)
    - Continuous-time limit of NCSN
    """

    def __init__(self, sigma_min=0.01, sigma_max=50.0, T=1.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.T = T

    def sigma(self, t):
        """Geometric noise schedule σ(t) interpolating σ_min → σ_max."""
        ratio = self.sigma_max / self.sigma_min
        return self.sigma_min * ratio ** t

    def marginal_prob(self, x0, t):
        """
        Draw x_t ~ p_t(x|x_0) = N(x_0, σ²(t)I).

        Returns the perturbed sample, the std σ(t), and the noise ε used.
        """
        std = self.sigma(t)
        eps = torch.randn_like(x0)
        return x0 + std[:, None, None, None] * eps, std, eps

    def sde(self, x, t):
        """Forward SDE coefficients: zero drift, diffusion g(t) = σ(t)·√(2 ln(σ_max/σ_min))."""
        g_scale = torch.sqrt(torch.tensor(2 * np.log(self.sigma_max / self.sigma_min)))
        return torch.zeros_like(x), self.sigma(t) * g_scale

    def reverse_sde(self, x, t, score):
        """Reverse-time SDE coefficients: drift -g²·score (forward drift is zero)."""
        _, g = self.sde(x, t)
        return -(g ** 2)[:, None, None, None] * score, g

    def ode(self, x, t, score):
        """Probability flow ODE drift: -½g²·score."""
        _, g = self.sde(x, t)
        return -0.5 * (g ** 2)[:, None, None, None] * score


# ============================================================================
# Predictor-Corrector Samplers
# ============================================================================

class EulerMaruyamaPredictor:
    """Euler-Maruyama discretization of the reverse-time SDE.

    One predictor step moves a sample backward in time by dt:
        x_{t-dt} = x_t - drift·dt + g(t)·√dt·z,  z ~ N(0, I)
    """

    def __init__(self, sde, score_fn):
        """
        Args:
            sde: SDE object providing reverse_sde(x, t, score).
            score_fn: Callable s_θ(x, t) approximating ∇_x log p_t(x).
        """
        self.sde = sde
        self.score_fn = score_fn

    def step(self, x, t, dt):
        """Single reverse-SDE step of size dt.

        dt may be a plain Python float (as passed by
        PredictorCorrectorSampler) or a tensor.
        """
        score = self.score_fn(x, t)
        drift, diffusion = self.sde.reverse_sde(x, t, score)

        # Deterministic (drift) part of the Euler-Maruyama update.
        x = x - drift * dt
        # Stochastic part. BUGFIX: torch.sqrt(dt) raised a TypeError for a
        # Python-float dt; dt ** 0.5 handles both floats and tensors.
        x = x + diffusion[:, None, None, None] * (dt ** 0.5) * torch.randn_like(x)

        return x


class LangevinCorrector:
    """
    Langevin MCMC corrector.

    Theory:
    x' = x + ε·s_θ(x,t) + √(2ε)·z

    Refines samples at a fixed time t; the step size ε is chosen adaptively
    from a target signal-to-noise ratio.
    """

    def __init__(self, sde, score_fn, snr=0.16, n_steps=1):
        """
        Args:
            sde: SDE exposing std(t) (VP-SDE) or sigma(t) (VE-SDE).
            score_fn: Score network s_θ(x, t).
            snr: Signal-to-noise ratio for the adaptive step size.
            n_steps: Number of Langevin steps per correction.
        """
        self.sde = sde
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    def _noise_level(self, t):
        """Marginal noise scale at time t.

        BUGFIX/generalization: VarianceExplodingSDE exposes sigma() rather
        than std(), so fall back to it — previously this corrector crashed
        with an AttributeError when paired with a VE-SDE.
        """
        std_fn = getattr(self.sde, 'std', None)
        if std_fn is None:
            std_fn = self.sde.sigma
        return std_fn(t)

    def step(self, x, t):
        """Run n_steps of Langevin dynamics at fixed time t."""
        for _ in range(self.n_steps):
            score = self.score_fn(x, t)

            # Adaptive step size: ε = (snr · σ(t) / ||s_θ||)².
            # (This is the mean score-gradient norm, not a noise norm.)
            grad_norm = torch.norm(score.reshape(score.shape[0], -1), dim=-1).mean()
            step_size = (self.snr * self._noise_level(t)[0] / grad_norm) ** 2

            # Langevin update: ascend the score with injected noise.
            x = x + step_size * score + torch.sqrt(2 * step_size) * torch.randn_like(x)

        return x


class PredictorCorrectorSampler:
    """
    Predictor-corrector sampler.

    Algorithm:
    for t = T down to 0:
        x = Predictor(x, t)   # reverse-SDE step backward in time
        x = Corrector(x, t)   # Langevin refinement at the current time
    """

    def __init__(self, predictor, corrector):
        self.predictor = predictor
        self.corrector = corrector

    @torch.no_grad()
    def sample(self, shape, num_steps=1000, device='cpu'):
        """
        Generate samples by integrating the reverse process from t=T to t≈0.

        Args:
            shape: Output shape (B, C, H, W)
            num_steps: Number of discretization steps
            device: Device

        Returns:
            samples: Generated samples
        """
        horizon = self.predictor.sde.T

        # Initialize from the prior (standard Gaussian at time T).
        x = torch.randn(shape, device=device)
        dt = horizon / num_steps

        for step_idx in range(num_steps):
            # Current time, broadcast over the batch.
            frac = 1 - step_idx / num_steps
            t = torch.ones(shape[0], device=device) * frac * horizon

            x = self.predictor.step(x, t, dt)  # move backward in time
            x = self.corrector.step(x, t)      # refine at fixed time

        return x


# ============================================================================
# Sliced Score Matching
# ============================================================================

class SlicedScoreMatching:
    """
    Sliced Score Matching (Song et al., 2019).

    Theory:
    L = E_v E_x[v^T ∇_x s_θ(x) v + ½||s_θ(x)||²],  v ~ N(0, I)

    The random projection v replaces the intractable Jacobian trace of
    exact score matching with a single Hessian-vector product.
    """

    @staticmethod
    def loss(score_fn, x):
        """
        Compute the sliced score matching loss.

        Args:
            score_fn: Score network (must support double backprop).
            x: Data samples of shape (B, ...).

        Returns:
            loss: Scalar SSM loss, averaged over the batch.
        """
        x = x.requires_grad_(True)

        # Random projection direction v ~ N(0, I).
        v = torch.randn_like(x)

        # s_θ(x)
        score = score_fn(x)

        # v^T · s_θ(x), summed over the batch.
        v_score = torch.sum(v * score)

        # v^T · ∇_x s_θ(x) · v via one extra backward pass.
        # BUGFIX: `grad` was an undefined name — use torch.autograd.grad.
        # create_graph=True keeps the graph so the loss is differentiable in θ.
        grad_v_score = torch.autograd.grad(v_score, x, create_graph=True)[0]
        v_jacobian_v = torch.sum(v * grad_v_score)

        # Hyvärinen objective with the sliced trace estimator.
        # BUGFIX: previous form ½·v^T∇s·v + v^T s + ½||s||² had a spurious
        # linear term (zero-mean noise) and a wrong ½ on the trace term.
        loss = v_jacobian_v + 0.5 * torch.sum(score ** 2)

        return loss / x.shape[0]


# ============================================================================
# Likelihood Computation via ODE
# ============================================================================

class LikelihoodComputer:
    """
    Compute exact likelihood via the probability flow ODE.

    Theory (instantaneous change of variables):
    log p_0(x(0)) = log p_T(x(T)) + ∫_0^T div(f_θ)(x(t), t) dt

    where f_θ is the ODE drift and x(t) is obtained by integrating x(0)
    forward to time T. The divergence is estimated with Hutchinson's
    trace estimator.
    """

    def __init__(self, sde, score_fn):
        self.sde = sde
        self.score_fn = score_fn

    def divergence_hutchinson(self, x, t, score, v):
        """
        Hutchinson's trace estimator for the divergence of the ODE drift.

        E_v[v^T ∇_x(f·v)] = tr(∇_x f) = div(f) for v ~ N(0, I).
        x must require grad. Returns a scalar estimate.
        """
        ode_drift = self.sde.ode(x, t, score)

        # v^T · f(x, t)
        v_dot_f = torch.sum(v * ode_drift)

        # (∇_x (v^T f)) · v — a single vector-Jacobian product.
        # BUGFIX: `grad` was undefined — use torch.autograd.grad. No higher
        # derivatives are needed, so create_graph=False saves memory.
        grad_v_dot_f = torch.autograd.grad(v_dot_f, x, create_graph=False)[0]
        return torch.sum(v * grad_v_dot_f)

    def compute_likelihood(self, x0, num_steps=100):
        """
        Compute log-likelihood by Euler-integrating the probability flow
        ODE forward from t=0 to t=T.

        Args:
            x0: Data sample (1, C, H, W)
            num_steps: ODE integration steps

        Returns:
            log_likelihood: Estimate of log p(x0)

        BUGFIXES vs. previous version:
        - no @torch.no_grad(): the divergence estimator needs autograd;
          gradients are instead discarded per step via detach().
        - the prior term is evaluated at the integrated endpoint x(T),
          not at a fresh random sample.
        - the divergence integral is ADDED (change-of-variables sign).
        """
        dt = self.sde.T / num_steps
        divergence_integral = 0.0

        x = x0.clone()

        for i in range(num_steps):
            t = torch.ones(1, device=x0.device) * (i / num_steps) * self.sde.T

            with torch.enable_grad():
                x = x.detach().requires_grad_(True)

                # Fresh random probe for the Hutchinson estimator.
                v = torch.randn_like(x)

                score = self.score_fn(x, t)
                div = self.divergence_hutchinson(x, t, score, v)
                ode_drift = self.sde.ode(x, t, score)

            divergence_integral = divergence_integral + div.detach() * dt

            # Euler step along the probability flow ODE.
            x = (x + ode_drift * dt).detach()

        # Standard-normal prior log-density at the ODE endpoint x(T).
        dim = np.prod(x0.shape[1:])
        log_p_T = -0.5 * torch.sum(x ** 2) - 0.5 * dim * np.log(2 * np.pi)

        log_likelihood = log_p_T + divergence_integral

        return log_likelihood.item()


# ============================================================================
# Training Loop
# ============================================================================

class ScoreBasedTrainer:
    """Complete training pipeline for score-based models."""

    def __init__(self, model, sde, optimizer, device='cpu'):
        """
        Args:
            model: Score network called as model(x_t, std).
            sde: SDE exposing T and marginal_prob(x, t).
            optimizer: Optimizer over model.parameters().
            device: Device used to draw random times t.
        """
        self.model = model
        self.sde = sde
        self.optimizer = optimizer
        self.device = device
        # EMA of parameters; initialized lazily on first update_ema() call.
        self.ema = None

    def denoising_score_matching_loss(self, x):
        """
        Denoising score matching loss.

        L = E_t,x,ε[λ(t)||s_θ(x_t,t) - ∇log p_t(x_t|x_0)||²]
        with λ(t) = σ(t)², which cancels the 1/σ² blow-up of the target.
        """
        batch_size = x.shape[0]

        # Uniform random time in [0, T).
        t = torch.rand(batch_size, device=self.device) * self.sde.T

        # Perturb data with the forward marginal.
        x_t, std, noise = self.sde.marginal_prob(x, t)

        # Predict score, conditioned on the noise level.
        score_pred = self.model(x_t, std)

        # True conditional score: ∇ log p_t(x_t|x_0) = -noise/std.
        score_true = -noise / std[:, None, None, None]

        # σ²-weighted squared error, summed over pixels, averaged over batch.
        loss = torch.mean(std**2 * torch.sum((score_pred - score_true)**2, dim=(1,2,3)))

        return loss

    def train_step(self, x):
        """Single training step; returns the scalar loss value."""
        self.model.train()

        loss = self.denoising_score_matching_loss(x)

        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping for stability.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        self.optimizer.step()

        # BUGFIX: update_ema() initializes the EMA on its first call, so it
        # must run unconditionally. The old guard `if self.ema is not None`
        # meant the EMA was never created and never updated.
        self.update_ema()

        return loss.item()

    def update_ema(self, decay=0.9999):
        """Update (or lazily initialize) the EMA of the model parameters."""
        if self.ema is None:
            self.ema = {k: v.clone().detach() for k, v in self.model.state_dict().items()}
        else:
            for k, v in self.model.state_dict().items():
                self.ema[k] = decay * self.ema[k] + (1 - decay) * v.detach()


# ============================================================================
# Demonstration
# ============================================================================

# ----------------------------------------------------------------------------
# Demo: instantiate the components defined above and compare noise schedules.
# NOTE(review): NoiseConditionalScoreNetwork is defined earlier in this file
# (outside this chunk); this section assumes it is already in scope.
# ----------------------------------------------------------------------------
print("Advanced Score-Based Models Implemented:")
print("=" * 70)
print("1. NoiseConditionalScoreNetwork - NCSN with residual blocks")
print("2. VariancePreservingSDE - VP-SDE (equivalent to DDPM)")
print("3. VarianceExplodingSDE - VE-SDE (equivalent to NCSN)")
print("4. EulerMaruyamaPredictor - Reverse SDE sampler")
print("5. LangevinCorrector - MCMC refinement")
print("6. PredictorCorrectorSampler - Combined sampling")
print("7. SlicedScoreMatching - Alternative training objective")
print("8. LikelihoodComputer - Exact likelihood via ODE")
print("9. ScoreBasedTrainer - Complete training pipeline")
print("=" * 70)

# Example: Create models
print("\nExample: Model Instantiation")
print("-" * 70)

# Re-resolves the device (shadows the module-level `device` set at file top).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NCSN model
ncsn = NoiseConditionalScoreNetwork(channels=[32, 64, 128, 256]).to(device)
print(f"NCSN parameters: {sum(p.numel() for p in ncsn.parameters()):,}")

# SDE
# Both SDE variants with their standard hyperparameters (Song et al., 2021).
vp_sde = VariancePreservingSDE(beta_min=0.1, beta_max=20.0)
ve_sde = VarianceExplodingSDE(sigma_min=0.01, sigma_max=50.0)

print(f"VP-SDE: Ξ²_min={vp_sde.beta_min}, Ξ²_max={vp_sde.beta_max}")
print(f"VE-SDE: Οƒ_min={ve_sde.sigma_min}, Οƒ_max={ve_sde.sigma_max}")

# Predictor-Corrector sampler
def score_fn(x, t):
    """Wrapper: condition the NCSN on the VP-SDE marginal std at time t."""
    return ncsn(x, vp_sde.std(t))

predictor = EulerMaruyamaPredictor(vp_sde, score_fn)
corrector = LangevinCorrector(vp_sde, score_fn, snr=0.16, n_steps=1)
pc_sampler = PredictorCorrectorSampler(predictor, corrector)

print(f"Predictor-Corrector sampler created")

print("\n" + "=" * 70)
print("Key Advantages:")
print("=" * 70)
print("1. Score-based: No adversarial training, stable optimization")
print("2. NCSN: Multiple noise levels for robust learning")
print("3. SDE formulation: Unified continuous framework")
print("4. Predictor-Corrector: Better sample quality with refinement")
print("5. Exact likelihood: Via probability flow ODE")
print("6. Flexible: VP-SDE (DDPM) or VE-SDE (NCSN)")
print("=" * 70)

print("\n" + "=" * 70)
print("Comparison: VP-SDE vs VE-SDE")
print("=" * 70)

# Compare noise schedules
# VP uses the marginal std √(1-ᾱ_t); VE uses σ(t) directly.
t_values = torch.linspace(0, 1, 100)
vp_stds = torch.tensor([vp_sde.std(t).item() for t in t_values])
ve_stds = torch.tensor([ve_sde.sigma(t).item() for t in t_values])

# Plot both schedules on the same axes (VE grows to σ_max=50, VP stays ≤ 1).
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(t_values, vp_stds, label='VP-SDE (variance preserving)', linewidth=2)
ax.plot(t_values, ve_stds, label='VE-SDE (variance exploding)', linewidth=2)
ax.set_xlabel('Time t', fontsize=12)
ax.set_ylabel('Noise Level Οƒ(t)', fontsize=12)
ax.set_title('Noise Schedules: VP-SDE vs VE-SDE', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nVisualization shows different noise evolution strategies!")
print("VP-SDE: Bounded variance, similar to diffusion models")
print("VE-SDE: Unbounded variance, original NCSN formulation")

print("\n" + "=" * 70)
print("When to Use Each Method:")
print("=" * 70)
print("β€’ VP-SDE: When equivalent to DDPM desired, bounded variance")
print("β€’ VE-SDE: When following NCSN, unbounded noise acceptable")
print("β€’ Predictor-Corrector: When highest quality needed, have compute")
print("β€’ Pure Predictor: When speed critical, quality acceptable")
print("β€’ Sliced Score Matching: When memory limited (no Jacobian trace)")
print("β€’ Likelihood Computation: When exact probabilities needed")
print("=" * 70)

Advanced Score-Based Generative Models: Mathematical Foundations and Modern ArchitecturesΒΆ

1. Introduction to Score-Based ModelsΒΆ

Score-based generative models learn to estimate the score function (gradient of the log-density) instead of directly modeling the probability distribution. This approach offers a powerful alternative to traditional generative models by avoiding intractable normalizations and mode collapse.

Core concept: Model the score function \(s_\theta(x) \approx \nabla_x \log p(x) = \frac{\nabla_x p(x)}{p(x)}\)

Key insight: Score doesn’t require knowing the normalizing constant! $\(\nabla_x \log p(x) = \nabla_x \log \frac{p_{\text{unnorm}}(x)}{Z} = \nabla_x \log p_{\text{unnorm}}(x) - \nabla_x \log Z = \nabla_x \log p_{\text{unnorm}}(x)\)$

Generation via Langevin dynamics: $\(x_{t+1} = x_t + \frac{\epsilon}{2} s_\theta(x_t) + \sqrt{\epsilon} \, z_t, \quad z_t \sim \mathcal{N}(0, I)\)$

Starting from noise \(x_0 \sim \mathcal{N}(0, I)\), this converges to \(p(x)\) as \(\epsilon \to 0\) and \(T \to \infty\).

Advantages:

  • Flexible architectures: Any neural network (no invertibility constraints)

  • Mode coverage: Better than GANs (no mode collapse)

  • Training stability: No adversarial dynamics

  • High quality: State-of-the-art results (diffusion models)

Unified view: Score-based models unify:

  • Denoising diffusion models (DDPM)

  • Noise Conditional Score Networks (NCSN)

  • Stochastic differential equations (SDE) framework

2. Score Function and Score MatchingΒΆ

2.1 Score Function DefinitionΒΆ

For probability density \(p(x)\), the score function is: $\(s(x) = \nabla_x \log p(x) = \frac{1}{p(x)} \nabla_x p(x)\)$

Geometric interpretation: Points in direction of increasing probability density.

Properties:

  1. Independent of normalization: \(s(x)\) same for \(p(x)\) and \(c \cdot p(x)\)

  2. Zero at modes: \(\nabla_x p(x) = 0\) at local maxima

  3. Curl-free: \(\nabla \times s(x) = 0\) (gradient field)

Example (Gaussian): $\(p(x) = \mathcal{N}(\mu, \Sigma) \implies s(x) = -\Sigma^{-1}(x - \mu)\)$

2.2 Explicit Score MatchingΒΆ

Goal: Match model score \(s_\theta(x)\) to data score \(s_{\text{data}}(x) = \nabla_x \log p_{\text{data}}(x)\).

Naive objective: $\(\mathcal{L}_{\text{naive}}(\theta) = \frac{1}{2} \mathbb{E}_{x \sim p_{\text{data}}}[\|s_\theta(x) - \nabla_x \log p_{\text{data}}(x)\|^2]\)$

Problem: \(\nabla_x \log p_{\text{data}}(x)\) unknown!

Solution (HyvΓ€rinen, 2005): Integration by parts gives equivalent objective $\(\mathcal{L}_{\text{ESM}}(\theta) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\frac{1}{2}\|s_\theta(x)\|^2 + \text{tr}(\nabla_x s_\theta(x))\right] + \text{const}\)$

where \(\text{tr}(\nabla_x s_\theta(x)) = \sum_{i=1}^D \frac{\partial s_\theta^i(x)}{\partial x_i}\) is the divergence.

Gradient: $\(\nabla_\theta \mathcal{L}_{\text{ESM}} = \mathbb{E}_{x \sim p_{\text{data}}}\left[s_\theta(x) \nabla_\theta s_\theta(x)^T + \nabla_\theta \text{tr}(\nabla_x s_\theta(x))\right]\)$

Computational cost: \(O(D^2)\) for Jacobian trace (expensive for images).

2.3 Denoising Score Matching (DSM)ΒΆ

Key idea (Vincent, 2011): Add noise to data, then match score of noisy distribution.

Noise perturbation: \(q(x | x_0) = \mathcal{N}(x | x_0, \sigma^2 I)\)

Noisy distribution: \(q(x) = \int q(x | x_0) p_{\text{data}}(x_0) dx_0\)

True score of noisy distribution: $\(\nabla_x \log q(x | x_0) = -\frac{x - x_0}{\sigma^2}\)$

DSM objective: $\(\mathcal{L}_{\text{DSM}}(\theta, \sigma) = \frac{1}{2} \mathbb{E}_{x_0 \sim p_{\text{data}}} \mathbb{E}_{x \sim q(x|x_0)}\left[\left\|s_\theta(x) + \frac{x - x_0}{\sigma^2}\right\|^2\right]\)$

Advantages:

  • \(O(D)\) complexity (no Jacobian)

  • Fully differentiable

  • Equivalent to explicit score matching under certain conditions

Implementation:

x_0 ~ p_data
noise = Οƒ * Ξ΅, Ξ΅ ~ N(0, I)
x = x_0 + noise
loss = ||s_ΞΈ(x) - (-noise/σ²)||Β²

3. Multi-Scale Score Matching (Noise Conditioning)ΒΆ

3.1 MotivationΒΆ

Problem with single noise level:

  • Low noise (\(\sigma\) small): Score accurate near data, but unstable in low-density regions

  • High noise (\(\sigma\) large): Score stable everywhere, but data structure lost

Solution: Train score network at multiple noise levels \(\{\sigma_1, \ldots, \sigma_L\}\) where \(\sigma_1 > \sigma_2 > \cdots > \sigma_L\).

3.2 Noise Conditional Score Networks (NCSN)ΒΆ

Score network: \(s_\theta(x, \sigma): \mathbb{R}^D \times \mathbb{R}_+ \to \mathbb{R}^D\)

Objective: $\(\mathcal{L}_{\text{NCSN}}(\theta) = \sum_{i=1}^L \lambda(\sigma_i) \mathbb{E}_{x_0 \sim p_{\text{data}}} \mathbb{E}_{x \sim \mathcal{N}(x_0, \sigma_i^2 I)}\left[\left\|s_\theta(x, \sigma_i) + \frac{x - x_0}{\sigma_i^2}\right\|^2\right]\)$

Weighting: \(\lambda(\sigma_i) = \sigma_i^2\) (variance-weighted) or uniform.

Noise schedule: Geometric progression $\(\sigma_i = \sigma_{\max} \cdot \left(\frac{\sigma_{\min}}{\sigma_{\max}}\right)^{(i-1)/(L-1)}, \quad i = 1, \ldots, L\)$

Typical: \(\sigma_{\max} = 50\), \(\sigma_{\min} = 0.01\), \(L = 10\) for CIFAR-10.

3.3 Annealed Langevin DynamicsΒΆ

Sampling: Gradually decrease noise level during Langevin dynamics.

Algorithm:

x_L ~ N(0, Οƒ_1Β² I)  # Initialize from largest noise
for i = 1 to L:
    Ξ±_i = Ξ΅ Β· Οƒ_iΒ² / Οƒ_LΒ²  # Adaptive step size
    for t = 1 to T:
        z_t ~ N(0, I)
        x ← x + (Ξ±_i/2) s_ΞΈ(x, Οƒ_i) + √α_i z_t
    x_{i-1} ← x
return x_0

Intuition:

  • Large \(\sigma\): Explore global structure

  • Small \(\sigma\): Refine local details

Convergence: Proven for smooth densities and sufficient annealing steps.

4. Score-Based SDEsΒΆ

4.1 Forward SDE (Diffusion Process)ΒΆ

Continuous-time view: Gradually perturb data with SDE

\[dx = f(x, t) dt + g(t) dw\]

where:

  • \(f(x, t)\): Drift coefficient

  • \(g(t)\): Diffusion coefficient

  • \(w\): Brownian motion

Variance-Preserving (VP) SDE: $\(dx = -\frac{1}{2}\beta(t) x \, dt + \sqrt{\beta(t)} \, dw\)$

Variance-Exploding (VE) SDE: $\(dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}} \, dw\)$

sub-VP SDE: $\(dx = -\frac{1}{2}\beta(t) x \, dt + \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s) ds})} \, dw\)$

4.2 Reverse SDEΒΆ

Anderson (1982) theorem: Reverse-time SDE is

\[dx = \left[f(x, t) - g(t)^2 \nabla_x \log p_t(x)\right] dt + g(t) d\bar{w}\]

where \(\bar{w}\) is reverse-time Brownian motion.

Key insight: If we know score \(\nabla_x \log p_t(x)\) at all times \(t\), we can reverse the diffusion!

Score approximation: Replace \(\nabla_x \log p_t(x)\) with \(s_\theta(x, t)\)

\[dx = \left[f(x, t) - g(t)^2 s_\theta(x, t)\right] dt + g(t) d\bar{w}\]

4.3 Probability Flow ODEΒΆ

Alternative: Reverse process as ODE (deterministic, no noise)

\[\frac{dx}{dt} = f(x, t) - \frac{1}{2} g(t)^2 \nabla_x \log p_t(x)\]

Approximation: $\(\frac{dx}{dt} = f(x, t) - \frac{1}{2} g(t)^2 s_\theta(x, t)\)$

Properties:

  • Same marginals as reverse SDE: \(p_t(x)\) identical

  • Deterministic trajectories (useful for inversion, interpolation)

  • Faster sampling (adaptive ODE solvers)

Example (VP-SDE): $\(\frac{dx}{dt} = -\frac{1}{2}\beta(t)[x + s_\theta(x, t)]\)$

5. Training ObjectivesΒΆ

5.1 Denoising Score Matching with TimeΒΆ

Continuous-time objective: $\(\mathcal{L}_{\text{DSM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}(0, T)} \mathbb{E}_{x_0 \sim p_{\text{data}}} \mathbb{E}_{x_t \sim p_t(x_t|x_0)}\left[\lambda(t) \left\|s_\theta(x_t, t) - \nabla_{x_t} \log p_t(x_t | x_0)\right\|^2\right]\)$

where \(p_t(x_t | x_0)\) is the transition kernel of the forward SDE.

VP-SDE: \(p_t(x_t | x_0) = \mathcal{N}(x_t | \alpha_t x_0, \beta_t^2 I)\)

  • \(\alpha_t = e^{-\frac{1}{2}\int_0^t \beta(s) ds}\)

  • \(\beta_t^2 = 1 - e^{-\int_0^t \beta(s) ds}\)

True score: \(\nabla_{x_t} \log p_t(x_t | x_0) = -\frac{x_t - \alpha_t x_0}{\beta_t^2}\)

Loss: $\(\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\lambda(t) \left\|s_\theta(x_t, t) + \frac{x_t - \alpha_t x_0}{\beta_t^2}\right\|^2\right]\)$

where \(x_t = \alpha_t x_0 + \beta_t \epsilon\), \(\epsilon \sim \mathcal{N}(0, I)\).

5.2 Noise Prediction ParameterizationΒΆ

Reparameterization: Predict noise instead of score $\(s_\theta(x_t, t) = -\frac{\epsilon_\theta(x_t, t)}{\beta_t}\)$

Noise prediction objective: $\(\mathcal{L}_{\epsilon}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\left[\lambda(t) \|\epsilon_\theta(x_t, t) - \epsilon\|^2\right]\)$

Weighting:

  • Simple: \(\lambda(t) = 1\)

  • SNR-weighted: \(\lambda(t) = \beta_t^2\)

  • Variance-preserving: \(\lambda(t) = \frac{1}{2\beta_t^2}\)

Equivalence: Score matching ↔ noise prediction (up to weighting).

5.3 Likelihood WeightingΒΆ

Optimal weighting (Song et al., 2021): $\(\lambda(t) = g(t)^2\)$

Corresponds to maximizing ELBO (variational lower bound on likelihood).

Weighted loss: $\(\mathcal{L}_{\text{weighted}}(\theta) = \mathbb{E}_{t \sim p(t)} \left[w(t) \mathbb{E}_{x_0, \epsilon}\left[\|s_\theta(x_t, t) - s_t(x_t; x_0)\|^2\right]\right]\)$

where \(w(t) = g(t)^2 / p(t)\).

6. Sampling AlgorithmsΒΆ

6.1 Predictor-Corrector SamplingΒΆ

Two-step process:

  1. Predictor: Numerical SDE/ODE solver step

  2. Corrector: Langevin MCMC step to refine

Algorithm:

x_T ~ N(0, I)
for i = T-1 to 0:
    # Predictor (e.g., Euler-Maruyama)
    x_i ← x_{i+1} + Ξ”t Β· f(x_{i+1}, t_{i+1}) + βˆšΞ”t Β· g(t_{i+1}) Β· z
    
    # Corrector (Langevin)
    for j = 1 to M:
        x_i ← x_i + Ξ΅ s_ΞΈ(x_i, t_i) + √(2Ξ΅) z
        
return x_0

Advantage: Corrector improves sample quality at each timestep.

6.2 SDE SolversΒΆ

Euler-Maruyama (first-order): $\(x_{t-\Delta t} = x_t + [f(x_t, t) - g(t)^2 s_\theta(x_t, t)] \Delta t + g(t) \sqrt{\Delta t} \, z\)$

Stochastic Runge-Kutta (higher-order):

  • RK2, RK4 adaptations for SDEs

  • Better accuracy with larger timesteps

Adaptive stepping:

  • Error estimation (Richardson extrapolation)

  • Adjust \(\Delta t\) based on local truncation error

6.3 ODE SolversΒΆ

Euler method: $\(x_{t-\Delta t} = x_t + [f(x_t, t) - \frac{1}{2}g(t)^2 s_\theta(x_t, t)] \Delta t\)$

Heun’s method (RK2):

k1 = f(x_t, t) - 0.5 g(t)Β² s_ΞΈ(x_t, t)
x̃ = x_t + Δt k1
k2 = f(x̃, t-Δt) - 0.5 g(t-Δt)² s_θ(x̃, t-Δt)
x_{t-Ξ”t} = x_t + Ξ”t (k1 + k2) / 2

DPM-Solver (Lu et al., 2022):

  • Exponential integrator

  • 10-20Γ— faster than Euler

  • ~10-20 steps for high quality

DDIM (Song et al., 2021):

  • Deterministic sampling (ODE)

  • Skip timesteps (acceleration)

  • Invertible (can encode images to latent)

7. Continuous vs. Discrete TimeΒΆ

7.1 DDPM (Discrete-Time Diffusion)ΒΆ

Forward process (Markov chain): $\(q(x_t | x_{t-1}) = \mathcal{N}(x_t | \sqrt{1 - \beta_t} x_{t-1}, \beta_t I)\)$

Cumulative: $\(q(x_t | x_0) = \mathcal{N}(x_t | \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I)\)$

where \(\bar{\alpha}_t = \prod_{s=1}^t (1 - \beta_s)\).

Reverse process: $\(p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1} | \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)$

Mean prediction: $\(\mu_\theta(x_t, t) = \frac{1}{\sqrt{1 - \beta_t}}\left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right)\)$

Variance: Fixed \(\Sigma_\theta = \beta_t I\) or learned.

7.2 Connection to Score-Based ModelsΒΆ

Score from noise prediction: $\(\nabla_x \log p(x_t) = -\frac{\epsilon_\theta(x_t, t)}{\sqrt{1 - \bar{\alpha}_t}}\)$

Unified framework:

  • DDPM: Discrete-time formulation, variance schedule \(\{\beta_t\}\)

  • Score SDE: Continuous-time formulation, SDE coefficients \(f, g\)

Conversion: DDPM with \(\beta_t = \beta(t) \Delta t\) β†’ VP-SDE as \(\Delta t \to 0\)

8. Architecture DesignΒΆ

8.1 U-Net for Score NetworksΒΆ

Standard choice: U-Net with:

  • Downsampling path (encoder)

  • Upsampling path (decoder)

  • Skip connections (preserve spatial info)

  • Time embedding (condition on \(t\) or \(\sigma\))

Time embedding: Sinusoidal positional encoding $\(\gamma(t) = [\sin(2\pi f_1 t), \cos(2\pi f_1 t), \ldots, \sin(2\pi f_k t), \cos(2\pi f_k t)]\)$

Injected via:

  • FiLM (Feature-wise Linear Modulation): \(\text{FiLM}(h, \gamma) = \gamma_s \odot h + \gamma_b\)

  • Cross-attention

  • Adaptive group normalization

8.2 Attention MechanismsΒΆ

Self-attention layers: Model long-range dependencies

Multi-head self-attention: $\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V\)$

where \(Q = h W_Q\), \(K = h W_K\), \(V = h W_V\).

Placement: Typically at coarser resolutions (16Γ—16, 8Γ—8) due to \(O(n^2)\) cost.

Cross-attention: For conditional generation (text, class labels)

  • Queries from image features

  • Keys, values from conditioning

8.3 Modern ImprovementsΒΆ

EDM (Karras et al., 2022):

  • Preconditioning (input/output scaling)

  • Optimal noise schedule

  • Improved architecture

DiT (Diffusion Transformer):

  • Replace U-Net with Transformer

  • Better scalability (parameter count)

  • State-of-the-art on ImageNet

Latent Diffusion (Stable Diffusion):

  • Apply diffusion in VAE latent space

  • 4-8Γ— faster than pixel-space

  • Maintains quality

9. Conditional GenerationΒΆ

9.1 Class-Conditional GenerationΒΆ

Conditional score: \(s_\theta(x, t, y)\) where \(y\) is class label

Training: Standard DSM with \((x, y)\) pairs

Classifier-Free Guidance: $\(\tilde{s}_\theta(x, t, y) = s_\theta(x, t) + w \cdot [s_\theta(x, t, y) - s_\theta(x, t)]\)$

where:

  • \(s_\theta(x, t)\): Unconditional score (train with \(y = \emptyset\) dropout)

  • \(s_\theta(x, t, y)\): Conditional score

  • \(w\): Guidance weight (e.g., 1.5-7.5)

Effect: Higher \(w\) β†’ stronger conditioning, lower diversity.

9.2 Text-to-Image GenerationΒΆ

Cross-attention conditioning:

  • Text embedding: \(c = \text{CLIP/T5}(\text{prompt})\)

  • Cross-attention: \(\text{Attn}(Q=f(x), K=c, V=c)\)

Classifier-Free Guidance: $\(\tilde{s}_\theta(x, t, c) = s_\theta(x, t) + w \cdot [s_\theta(x, t, c) - s_\theta(x, t)]\)$

Examples:

  • DALL-E 2: CLIP conditioning + diffusion

  • Imagen: T5 text encoder + cascaded diffusion

  • Stable Diffusion: CLIP + latent diffusion

9.3 Image Editing and InpaintingΒΆ

Inpainting: Generate missing region \(x_{\bar{M}}\) given observed \(x_M\)

Method 1 (Repaint): Resample known region at each step

for t = T to 0:
    x_t ← reverse_step(x_t)
    x_t[M] ← forward_step(x_0[M], t)  # Restore known region

Method 2 (Conditioning): Train \(s_\theta(x, t, x_M)\) with masked inputs

Image-to-image: SDEdit (Meng et al., 2021)

  1. Add noise to input: \(x_T = x_{\text{input}} + \sigma_T \epsilon\)

  2. Denoise with score model

  3. Result: Variation of input

10. Likelihood ComputationΒΆ

10.1 Exact Likelihood via ODEΒΆ

Instantaneous change of variables: \(\log p_0(x_0) = \log p_T(x_T) + \int_0^T \text{div}(f_\theta)(x_t, t)\, dt\), where \(x_t\) is obtained by integrating the probability flow ODE forward from \(x_0\) at \(t=0\) to \(x_T\) at \(t=T\).

where \(f_\theta(x, t) = f(x, t) - \frac{1}{2}g(t)^2 s_\theta(x, t)\) is the ODE drift.

Divergence: \(\text{div}(f_\theta) = \text{tr}(\nabla_x f_\theta(x, t))\)

Computational cost: \(O(D^2)\) for Jacobian trace.

Hutchinson’s estimator (unbiased): $\(\text{tr}(\nabla_x f) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}[\epsilon^T (\nabla_x f) \epsilon]\)$

Can be computed via vector-Jacobian product (efficient in autograd).

10.2 Approximate LikelihoodΒΆ

ELBO: Variational lower bound (from DDPM) $\(\log p(x_0) \geq \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log \frac{p_T(x_T) \prod_{t=1}^T p_\theta(x_{t-1}|x_t)}{q(x_{1:T}|x_0)}\right]\)$

Simplified: Sum of denoising scores across timesteps.

11. Advanced TopicsΒΆ

11.1 Riemannian Score-Based ModelsΒΆ

Manifold data: \(x \in \mathcal{M}\) (e.g., SO(3) rotations, protein structures)

Score on manifold: \(s_\theta: \mathcal{M} \times \mathbb{R}_+ \to T\mathcal{M}\) (tangent bundle)

Riemannian diffusion: $\(dx = s_\theta(x, t) dt + \sqrt{2} \, dW_{\mathcal{M}}\)$

where \(W_{\mathcal{M}}\) is Brownian motion on manifold.

Applications:

  • SE(3) diffusion for protein design

  • SO(3) for 3D rotations

  • Hyperbolic space for hierarchical data

11.2 Consistency ModelsΒΆ

Motivation: Distill score-based models to single-step generators.

Consistency function: \(f: (x_t, t) \mapsto x_0\) satisfies $\(f(x_t, t) = f(x_s, s) \quad \forall s, t\)$

Training:

  • Consistency distillation: Use pretrained score model

  • Consistency training: Train from scratch

Advantage: 1-step generation (1000Γ— faster than DDPM)

Trade-off: Slight quality degradation vs. iterative sampling.

11.3 Flow MatchingΒΆ

Alternative: Directly learn vector field instead of score.

Continuous normalizing flow: $\(\frac{dx}{dt} = v_\theta(x, t)\)$

Training: Regression to optimal transport paths $\(\mathcal{L} = \mathbb{E}_{t, x_0, x_1}\left[\|v_\theta(x_t, t) - (x_1 - x_0)\|^2\right]\)$

where \(x_t = (1-t)x_0 + t x_1\) (linear interpolation).

Relation to score models: Flow matching β‰ˆ score matching with specific drift.

12. Theoretical AnalysisΒΆ

12.1 Score Matching ConsistencyΒΆ

Theorem (HyvΓ€rinen, 2005): Explicit score matching recovers true distribution.

If \(\theta^* = \arg\min_\theta \mathcal{L}_{\text{ESM}}(\theta)\), then \(s_{\theta^*}(x) = \nabla_x \log p_{\text{data}}(x)\) (under regularity conditions).

Proof sketch:

  • Objective minimized when \(s_\theta(x) = \nabla_x \log p_{\text{data}}(x)\) for all \(x\)

  • Integration gives \(\log p_\theta(x) = \log p_{\text{data}}(x) + c\)

  • Normalization constraint β†’ \(c = 0\)

12.2 Convergence of Langevin DynamicsΒΆ

Theorem: Under log-concavity and smoothness, Langevin dynamics converges exponentially to target distribution.

\[W_2(p_t, p_{\infty}) \leq C e^{-\lambda t}\]

where \(W_2\) is 2-Wasserstein distance.

Practice: Non-log-concave data, finite steps β†’ approximate samples.

12.3 Sample ComplexityΒΆ

Theorem (Song et al., 2021): Score matching sample complexity is \(\tilde{O}(d^2 / \epsilon^2)\) for \(\epsilon\)-accurate score estimation in dimension \(d\).

Comparison:

  • GANs: Potentially exponential in \(d\) (mode collapse)

  • Normalizing flows: \(O(d)\) architectures only

  • Score models: Polynomial, flexible architectures

13. Training TechniquesΒΆ

13.1 Noise Schedule DesignΒΆ

Linear schedule (DDPM): $\(\beta_t = \beta_{\min} + (\beta_{\max} - \beta_{\min}) \frac{t}{T}\)$

Cosine schedule (Improved DDPM): $\(\bar{\alpha}_t = \frac{f(t)}{f(0)}, \quad f(t) = \cos\left(\frac{t/T + s}{1 + s} \cdot \frac{\pi}{2}\right)^2\)$

Learned schedule: Optimize \(\beta_t\) or \(\sigma_t\) via variational bound.

EDM schedule (Karras et al., 2022): $\(\sigma_i = \left(\sigma_{\max}^{1/\rho} + \frac{i-1}{N-1}(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho})\right)^\rho\)$

with \(\rho = 7\), \(\sigma_{\max} = 80\), \(\sigma_{\min} = 0.002\).

13.2 EMA (Exponential Moving Average)ΒΆ

Maintain EMA of parameters: $\(\theta_{\text{EMA}} \leftarrow \gamma \theta_{\text{EMA}} + (1 - \gamma) \theta\)$

Typical: \(\gamma = 0.9999\)

Benefit: Smoother model, better generation quality.

13.3 Mixed Precision TrainingΒΆ

FP16 training: Reduce memory, accelerate training.

Loss scaling: Prevent gradient underflow $\(\mathcal{L}_{\text{scaled}} = \text{scale} \cdot \mathcal{L}\)$

Gradient clipping: Stabilize training $\(g \leftarrow \frac{g}{\max(1, \|g\| / \text{clip\_norm})}\)$

14. Applications and ResultsΒΆ

14.1 Image GenerationΒΆ

CIFAR-10:

  • NCSN++: FID 2.2

  • DDPM++: FID 2.78

  • EDM: FID 1.97

ImageNet 256×256:

  • Improved DDPM: FID 10.94

  • CDM (Cascaded): FID 4.88

  • DiT-XL/2: FID 2.27

High-resolution:

  • Latent Diffusion (Stable Diffusion): 512×512, FID 12.63 on COCO

  • DALL-E 2: 1024×1024

  • Imagen: 1024×1024, state-of-art text-to-image

14.2 Audio and SpeechΒΆ

WaveGrad: Raw audio waveform generation

  • 24 kHz audio

  • FID competitive with GANs

DiffWave: Vocoder (mel-spectrogram to waveform)

  • MOS (Mean Opinion Score) 4.4+ (near human quality)

Grad-TTS: Text-to-speech

  • End-to-end diffusion

  • Natural prosody

14.3 Video GenerationΒΆ

Video Diffusion Models (Ho et al., 2022):

  • Factorized space-time U-Net

  • 16 frames @ 64×64, FVD 481

Imagen Video: High-resolution text-to-video

  • Cascaded diffusion (24 β†’ 128 β†’ 1024)

  • Temporal attention

14.4 3D and MoleculesΒΆ

Point-E (OpenAI): Text-to-3D point clouds

  • Diffusion on point clouds

  • 1-2 minutes per shape

DreamFusion: Text-to-3D via distillation

  • Score Distillation Sampling (SDS)

  • Neural radiance fields (NeRF)

Molecule generation:

  • EDM for molecular graphs

  • SE(3)-equivariant diffusion

15. ComparisonsΒΆ

15.1 Score-Based vs. GANsΒΆ

Aspect

Score-Based

GANs

Training

Stable

Unstable (mode collapse)

Sampling

Slow (iterative)

Fast (single pass)

Diversity

High

Can be limited

Likelihood

Tractable (ODE)

Intractable

Architecture

Flexible

Generator + Discriminator

Quality

State-of-art

High

Recommendation: Score models for quality/diversity, GANs for speed.

15.2 Score-Based vs. VAEsΒΆ

Aspect

Score-Based

VAEs

Latent space

No explicit latent

Structured latent

Likelihood

Exact (ODE)

ELBO (approximate)

Sampling

Slow

Fast

Interpolation

ODE trajectories

Latent interpolation

Quality

Higher

Moderate (blurry)

Recommendation: Score models for quality, VAEs for latent manipulation.

15.3 Continuous vs. Discrete TimeΒΆ

Aspect

SDE (Continuous)

DDPM (Discrete)

Formulation

Stochastic differential equation

Markov chain

Flexibility

General SDEs (VP, VE, sub-VP)

Fixed variance schedule

Likelihood

ODE change of variables

ELBO

Solvers

Adaptive SDE/ODE solvers

Ancestral sampling

Theory

Anderson theorem

Variational inference

Unified: Both are equivalent in the limit.

16. Limitations and Future DirectionsΒΆ

16.1 Current LimitationsΒΆ

Sampling speed:

  • 50-1000 steps for high quality (vs. 1 for GANs)

  • ~50 seconds for a 256×256 image (vs. 0.1s for a GAN)

Computational cost:

  • Training: 100-1000 GPU-days for large models

  • Inference: Multiple forward passes

Determinism:

  • Stochastic generation (randomness in sampling)

  • Reproducibility requires seed control

16.2 Acceleration TechniquesΒΆ

Faster samplers:

  • DPM-Solver: 10-20 steps (20-50× speedup)

  • DDIM: Deterministic, skip steps

  • Consistency models: 1-4 steps

Distillation:

  • Progressive distillation (Salimans & Ho, 2022)

  • 4 steps β†’ 2 steps β†’ 1 step

Latent diffusion:

  • Diffuse in compressed latent space

  • 4-8× faster than pixel space

16.3 Future DirectionsΒΆ

Unified frameworks:

  • Connect score models, flows, and ODEs

  • Optimal transport theory

Continuous-time models:

  • Neural ODEs/SDEs

  • Infinitely deep networks

Applications:

  • Scientific computing (PDEs, molecular dynamics)

  • Inverse problems (super-resolution, inpainting, CT reconstruction)

  • Controllable generation (fine-grained control)

Theoretical understanding:

  • Sample complexity bounds

  • Approximation theory

  • Convergence guarantees

17. SummaryΒΆ

Key Concepts:

  1. Score function: \(s(x) = \nabla_x \log p(x)\) independent of normalization

  2. Score matching: Train via denoising (avoid intractable score)

  3. Multi-scale: Noise conditioning for stable training across densities

  4. SDE framework: Continuous-time diffusion (forward) and reverse processes

  5. Sampling: Langevin dynamics, predictor-corrector, ODE solvers

Training Recipe:

  1. Choose SDE (VP, VE, sub-VP)

  2. Design U-Net with time embedding

  3. Train with denoising score matching loss

  4. Use EMA, mixed precision, gradient clipping

Sampling Recipe:

  1. Initialize from noise \(x_T \sim \mathcal{N}(0, I)\)

  2. Run reverse SDE/ODE with score network

  3. Use adaptive solvers (DPM-Solver, Heun's method)

  4. Apply classifier-free guidance for conditioning

Advantages:

  • State-of-the-art generation quality

  • Training stability (no adversarial dynamics)

  • Mode coverage (no collapse)

  • Exact likelihood computation (ODE)

Disadvantages:

  • Slow sampling (iterative process)

  • Computational cost (training and inference)

  • Memory requirements (U-Net parameters)

Best for:

  • High-quality image/audio/video generation

  • Text-to-image synthesis

  • Likelihood-based modeling

  • Controllable generation

"""
Advanced Score-Based Models - Production Implementation
Comprehensive PyTorch implementations with SDE solvers and modern architectures
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, List, Callable
from dataclasses import dataclass

# ===========================
# 1. Time Embeddings
# ===========================

class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embeddings for time conditioning.

    Maps a batch of scalar timesteps to `dim`-dimensional vectors whose
    components are sines and cosines at geometrically spaced frequencies.
    """

    def __init__(self, dim: int, max_period: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            t: (batch_size,) tensor of timesteps
        Returns:
            (batch_size, dim) embeddings
        """
        half = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/max_period.
        freqs = torch.exp(
            torch.arange(half, device=t.device) * -(math.log(self.max_period) / (half - 1))
        )
        angles = t[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

# ===========================
# 2. U-Net Building Blocks
# ===========================

class ResidualBlock(nn.Module):
    """Pre-activation residual block with FiLM time conditioning.

    Computes conv(silu(norm(.))) twice, modulating the intermediate
    features with a per-channel scale/shift derived from the time
    embedding, and adds a (possibly 1x1-projected) shortcut.
    """
    
    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, 
                 dropout: float = 0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        
        # Time embedding projection (FiLM parameters)
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, 2 * out_channels)  # scale and bias
        )
        
        # FIX: norm1 is applied to the raw input in forward(), so it must be
        # sized for in_channels. The previous GroupNorm(8, out_channels)
        # crashed whenever in_channels != out_channels.
        self.norm1 = nn.GroupNorm(8, in_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        
        self.dropout = nn.Dropout(dropout)
        
        # 1x1 projection keeps the shortcut shape-compatible.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()
    
    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, in_channels, H, W)
            time_emb: (B, time_emb_dim)
        Returns:
            (B, out_channels, H, W)
        """
        h = self.conv1(F.silu(self.norm1(x)))
        
        # FiLM conditioning: per-channel affine transform from time embedding
        time_out = self.time_mlp(time_emb)
        scale, bias = time_out.chunk(2, dim=1)
        h = h * (1 + scale[:, :, None, None]) + bias[:, :, None, None]
        
        h = self.dropout(h)
        h = self.conv2(F.silu(self.norm2(h)))
        
        return h + self.shortcut(x)


class AttentionBlock(nn.Module):
    """Spatial multi-head self-attention over the H*W positions of a
    feature map, with a residual connection around the attended output."""
    
    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        assert channels % num_heads == 0, "channels must be divisible by num_heads"
        
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj_out = nn.Conv2d(channels, channels, 1)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W)
        Returns:
            (B, C, H, W), same shape as the input
        """
        B, C, H, W = x.shape
        head_dim = C // self.num_heads
        
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)
        
        # (B, C, H, W) -> (B, heads, H*W, head_dim)
        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.view(B, self.num_heads, head_dim, H * W).transpose(2, 3)
        
        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        
        # Scaled dot-product attention over the spatial positions.
        weights = torch.softmax(q @ k.transpose(-2, -1) * head_dim ** -0.5, dim=-1)
        attended = (weights @ v).transpose(2, 3).reshape(B, C, H, W)
        
        return x + self.proj_out(attended)


class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, channels: int):
        super().__init__()
        # Strided conv keeps the channel count while halving H and W.
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, C, H, W) -> (B, C, ceil(H/2), ceil(W/2))"""
        return self.conv(x)


class Upsample(nn.Module):
    """Double the spatial resolution: nearest-neighbor upsample then 3x3 conv."""

    def __init__(self, channels: int):
        super().__init__()
        # The conv after interpolation smooths the blocky nearest-neighbor output.
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, C, H, W) -> (B, C, 2H, 2W)"""
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))


# ===========================
# 3. Score Network (U-Net)
# ===========================

class ScoreNet(nn.Module):
    """
    U-Net architecture for score function estimation.
    Outputs the score s_theta(x, t) approximating grad_x log p_t(x).

    The encoder pushes a skip feature after every residual block and after
    every downsampling layer; the decoder pops them in reverse order, so the
    two code paths must mirror the channel bookkeeping done in __init__.
    """
    
    def __init__(self, 
                 in_channels: int = 3,
                 model_channels: int = 128,
                 out_channels: int = 3,
                 num_res_blocks: int = 2,
                 attention_resolutions: List[int] = [16, 8],
                 dropout: float = 0.1,
                 channel_mult: List[int] = [1, 2, 2, 2],
                 num_heads: int = 4):
        super().__init__()
        
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        
        # Time embedding: sinusoidal features expanded by a 2-layer MLP.
        time_emb_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(model_channels),
            nn.Linear(model_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        
        # Initial convolution
        self.input_conv = nn.Conv2d(in_channels, model_channels, 3, padding=1)
        
        # Downsampling path. `channels` records the channel count of every
        # skip feature so the decoder can size its blocks when popping them.
        self.down_blocks = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        
        channels = [model_channels]
        now_channels = model_channels
        
        for level, mult in enumerate(channel_mult):
            out_ch = model_channels * mult
            for _ in range(num_res_blocks):
                layers = [ResidualBlock(now_channels, out_ch, time_emb_dim, dropout)]
                now_channels = out_ch
                
                # Add attention at specified resolutions.
                # NOTE(review): `2 ** level` is the downsampling *factor*, not
                # an absolute resolution — confirm against the intended
                # meaning of `attention_resolutions`.
                if 2 ** level in attention_resolutions:
                    layers.append(AttentionBlock(now_channels, num_heads))
                
                self.down_blocks.append(nn.ModuleList(layers))
                channels.append(now_channels)
            
            # Downsample at the end of every level except the last.
            if level != len(channel_mult) - 1:
                self.down_samples.append(Downsample(now_channels))
                channels.append(now_channels)
        
        # Middle (bottleneck) blocks
        self.middle_block = nn.ModuleList([
            ResidualBlock(now_channels, now_channels, time_emb_dim, dropout),
            AttentionBlock(now_channels, num_heads),
            ResidualBlock(now_channels, now_channels, time_emb_dim, dropout)
        ])
        
        # Upsampling path
        self.up_blocks = nn.ModuleList()
        self.up_samples = nn.ModuleList()
        
        for level, mult in enumerate(reversed(channel_mult)):
            for i in range(num_res_blocks + 1):
                # Pop the channel count of the matching encoder skip feature.
                skip_ch = channels.pop()
                out_ch = model_channels * mult
                
                layers = [ResidualBlock(now_channels + skip_ch, out_ch, time_emb_dim, dropout)]
                now_channels = out_ch
                
                # Add attention at specified resolutions (mirrors the encoder)
                if 2 ** (len(channel_mult) - 1 - level) in attention_resolutions:
                    layers.append(AttentionBlock(now_channels, num_heads))
                
                # Upsample after the last block of each level (except the final one)
                if i == num_res_blocks and level != len(channel_mult) - 1:
                    layers.append(Upsample(now_channels))
                
                self.up_blocks.append(nn.ModuleList(layers))
        
        # Output projection back to image space
        self.output_conv = nn.Sequential(
            nn.GroupNorm(8, now_channels),
            nn.SiLU(),
            nn.Conv2d(now_channels, out_channels, 3, padding=1)
        )
    
    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) noisy input
            t: (B,) timesteps
        Returns:
            score: (B, C, H, W) estimated score
        """
        # Time embedding
        t_emb = self.time_embed(t)
        
        # Initial conv
        h = self.input_conv(x)
        
        # Encoder: collect skip features in exactly the order __init__
        # recorded their channel counts.
        skip_connections = [h]
        sample_idx = 0
        
        for i, block in enumerate(self.down_blocks):
            for layer in block:
                if isinstance(layer, ResidualBlock):
                    h = layer(h, t_emb)
                else:
                    h = layer(h)
            skip_connections.append(h)
            
            # FIX: downsample only at the end of each resolution level (i.e.
            # every num_res_blocks blocks). The previous `if i <
            # len(self.down_samples)` downsampled after the first few blocks,
            # desynchronizing the skip stack from the decoder's expected
            # channels/resolutions and crashing at the concatenations below.
            if (i + 1) % self.num_res_blocks == 0 and sample_idx < len(self.down_samples):
                h = self.down_samples[sample_idx](h)
                sample_idx += 1
                skip_connections.append(h)
        
        # Bottleneck
        for layer in self.middle_block:
            if isinstance(layer, ResidualBlock):
                h = layer(h, t_emb)
            else:
                h = layer(h)
        
        # Decoder: concatenate the matching skip feature before each block.
        for block in self.up_blocks:
            skip = skip_connections.pop()
            h = torch.cat([h, skip], dim=1)
            
            for layer in block:
                if isinstance(layer, ResidualBlock):
                    h = layer(h, t_emb)
                else:
                    h = layer(h)
        
        # Output
        return self.output_conv(h)


# ===========================
# 4. Noise Schedules (VP, VE, sub-VP)
# ===========================

@dataclass
class NoiseScheduleConfig:
    """Configuration for noise schedule"""
    schedule_type: str  # one of 'VP', 'VE', 'sub-VP'
    beta_min: float = 0.1    # beta(0) for the linear VP/sub-VP schedule
    beta_max: float = 20.0   # beta(T) for the linear VP/sub-VP schedule
    sigma_min: float = 0.01  # sigma(0) for the geometric VE schedule
    sigma_max: float = 50.0  # sigma(T) for the geometric VE schedule
    T: float = 1.0  # Total time


class NoiseSchedule:
    """Forward-process noise schedule for the VP, VE and sub-VP SDEs."""

    def __init__(self, config: NoiseScheduleConfig):
        self.config = config

    def _beta_integral(self, t: torch.Tensor) -> torch.Tensor:
        """Closed form of the integral of beta(s) from 0 to t (linear beta)."""
        c = self.config
        return c.beta_min * t + 0.5 * (c.beta_max - c.beta_min) * t ** 2

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        """Variance schedule beta(t); identically zero for VE (no drift)."""
        c = self.config
        if c.schedule_type in ('VP', 'sub-VP'):
            # Linear interpolation between beta_min and beta_max.
            return c.beta_min + (c.beta_max - c.beta_min) * t
        return torch.zeros_like(t)

    def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
        """Signal scaling alpha_t = exp(-1/2 * integral of beta); 1 for VE."""
        if self.config.schedule_type in ('VP', 'sub-VP'):
            return torch.exp(-0.5 * self._beta_integral(t))
        return torch.ones_like(t)

    def sigma_t(self, t: torch.Tensor) -> torch.Tensor:
        """Noise level sigma(t) of the perturbation kernel."""
        kind = self.config.schedule_type
        if kind == 'VP':
            return torch.sqrt(1 - self.alpha_t(t) ** 2)
        if kind == 'sub-VP':
            return torch.sqrt(1 - torch.exp(-self._beta_integral(t)))
        # VE: geometric interpolation between sigma_min and sigma_max.
        log_lo = torch.log(torch.tensor(self.config.sigma_min))
        log_hi = torch.log(torch.tensor(self.config.sigma_max))
        return torch.exp(log_lo + t * (log_hi - log_lo))

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Marginal distribution p_t(x | x_0), a Gaussian with mean alpha_t * x_0
        and per-sample standard deviation sigma_t.
        Returns: (mean, std), broadcastable against x_0.
        """
        alpha = self.alpha_t(t).view(-1, 1, 1, 1)
        sigma = self.sigma_t(t).view(-1, 1, 1, 1)
        return alpha * x_0, sigma

    def perturbation_kernel(self, x_0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draw x_t ~ p_t(x | x_0) via the reparameterization trick.
        Returns: (x_t, noise) where x_t = mean + std * noise.
        """
        mean, std = self.marginal_prob(x_0, t)
        noise = torch.randn_like(x_0)
        return mean + std * noise, noise


# ===========================
# 5. Denoising Score Matching Trainer
# ===========================

class DenoisingScoreMatching:
    """Trainer implementing the denoising score matching objective."""
    
    def __init__(self, 
                 score_net: ScoreNet,
                 noise_schedule: NoiseSchedule,
                 device: str = 'cuda'):
        self.score_net = score_net.to(device)
        self.noise_schedule = noise_schedule
        self.device = device
    
    def loss(self, x_0: torch.Tensor, t: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, dict]:
        """
        Denoising score matching loss: the expected squared error between the
        predicted score and the conditional score of the Gaussian kernel
        p_t(x_t | x_0), averaged over random timesteps.
        """
        # Draw uniform timesteps in [0, T] when none are supplied.
        if t is None:
            t = torch.rand(x_0.shape[0], device=self.device) * self.noise_schedule.config.T
        
        # Forward-perturb the clean batch: x_t = alpha_t x_0 + sigma_t eps
        x_t, noise = self.noise_schedule.perturbation_kernel(x_0, t)
        
        # Conditional score of the Gaussian kernel:
        # grad log p_t(x_t | x_0) = -(x_t - alpha_t x_0) / sigma_t^2
        alpha = self.noise_schedule.alpha_t(t).view(-1, 1, 1, 1)
        sigma = self.noise_schedule.sigma_t(t).view(-1, 1, 1, 1)
        target = -(x_t - alpha * x_0) / (sigma ** 2)
        
        prediction = self.score_net(x_t, t)
        
        # Unweighted MSE between predicted and true conditional score
        # (equivalent to noise prediction with lambda(t) = sigma^2).
        loss = ((prediction - target) ** 2).mean()
        
        metrics = {
            'loss': loss.item(),
            'mean_t': t.mean().item(),
            'mean_sigma': sigma.mean().item()
        }
        
        return loss, metrics
    
    def train_step(self, x_0: torch.Tensor, optimizer: torch.optim.Optimizer) -> dict:
        """Run one optimization step and return the training metrics."""
        optimizer.zero_grad()
        loss, metrics = self.loss(x_0)
        loss.backward()
        
        # Clip gradients to unit norm for stability.
        torch.nn.utils.clip_grad_norm_(self.score_net.parameters(), max_norm=1.0)
        
        optimizer.step()
        return metrics


# ===========================
# 6. SDE Samplers
# ===========================

class EulerMaruyamaSampler:
    """Euler-Maruyama solver for the reverse-time SDE
    dx = [f(x,t) - g(t)^2 grad_x log p_t(x)] dt + g(t) dw.
    """
    
    def __init__(self, 
                 score_net: ScoreNet,
                 noise_schedule: NoiseSchedule,
                 num_steps: int = 1000,
                 device: str = 'cuda'):
        self.score_net = score_net
        self.noise_schedule = noise_schedule
        self.num_steps = num_steps
        self.device = device
    
    @torch.no_grad()
    def sample(self, batch_size: int, img_shape: Tuple[int, int, int]) -> torch.Tensor:
        """
        Sample from p_0 by integrating the reverse SDE from t=T down to t=0.

        Args:
            batch_size: number of samples to draw
            img_shape: (C, H, W) of each sample
        Returns:
            (batch_size, C, H, W) tensor of samples
        """
        T = self.noise_schedule.config.T
        dt = -T / self.num_steps  # Python float; negative = backward in time
        
        # Start from the terminal (pure-noise) marginal.
        x = torch.randn(batch_size, *img_shape, device=self.device)
        x = x * self.noise_schedule.sigma_t(torch.tensor([T], device=self.device))
        
        timesteps = torch.linspace(T, 0, self.num_steps + 1, device=self.device)
        
        for i in range(self.num_steps):
            t = timesteps[i]
            t_batch = torch.full((batch_size,), t, device=self.device)
            
            score = self.score_net(x, t_batch)
            beta_t = self.noise_schedule.beta(t_batch).view(-1, 1, 1, 1)
            
            if self.noise_schedule.config.schedule_type == 'VP':
                # f = -1/2 beta x, g^2 = beta, so the reverse drift is
                # f - g^2 s = -beta (x/2 + s). (The previous -1/2 beta (x+s)
                # was the probability-flow ODE drift, not the reverse SDE's.)
                drift = -beta_t * (0.5 * x + score)
                diffusion = torch.sqrt(beta_t)
            elif self.noise_schedule.config.schedule_type == 'VE':
                # f = 0, g^2 = d[sigma^2]/dt = 2 sigma(t)^2 ln(sigma_max/sigma_min)
                # for the geometric schedule; reverse drift is -g^2 s.
                sigma = self.noise_schedule.sigma_t(t_batch).view(-1, 1, 1, 1)
                g2 = 2 * sigma ** 2 * math.log(
                    self.noise_schedule.config.sigma_max / self.noise_schedule.config.sigma_min)
                drift = -g2 * score
                diffusion = torch.sqrt(g2)
            else:  # sub-VP
                # g^2 = beta (1 - e^{-2 int beta}) = beta (1 - alpha^4),
                # since alpha = e^{-1/2 int beta} (Song et al., 2021).
                alpha = self.noise_schedule.alpha_t(t_batch).view(-1, 1, 1, 1)
                g2 = beta_t * (1 - alpha ** 4)
                drift = -0.5 * beta_t * x - g2 * score
                diffusion = torch.sqrt(g2)
            
            # Euler-Maruyama update; no noise injected on the final step.
            # math.sqrt: dt is a Python float (torch.sqrt/abs reject scalars).
            z = torch.randn_like(x) if i < self.num_steps - 1 else torch.zeros_like(x)
            x = x + drift * dt + diffusion * math.sqrt(abs(dt)) * z
        
        return x


class ODESampler:
    """Deterministic sampler integrating the probability flow ODE
    dx/dt = f(x,t) - 1/2 g(t)^2 grad_x log p_t(x).
    """
    
    def __init__(self,
                 score_net: ScoreNet,
                 noise_schedule: NoiseSchedule,
                 num_steps: int = 100,
                 device: str = 'cuda'):
        self.score_net = score_net
        self.noise_schedule = noise_schedule
        self.num_steps = num_steps
        self.device = device
    
    @torch.no_grad()
    def sample(self, batch_size: int, img_shape: Tuple[int, int, int], 
               method: str = 'heun') -> torch.Tensor:
        """
        Integrate the probability flow ODE from t=T down to t=0.

        Args:
            batch_size: number of samples to draw
            img_shape: (C, H, W) of each sample
            method: 'euler' (1st order) or 'heun' (2nd order RK)
        Returns:
            (batch_size, C, H, W) tensor of samples
        """
        T = self.noise_schedule.config.T
        dt = -T / self.num_steps  # negative: integrating backward in time
        
        # Start from the terminal (pure-noise) marginal.
        x = torch.randn(batch_size, *img_shape, device=self.device)
        x = x * self.noise_schedule.sigma_t(torch.tensor([T], device=self.device))
        
        timesteps = torch.linspace(T, 0, self.num_steps + 1, device=self.device)
        
        for i in range(self.num_steps):
            t = timesteps[i]
            t_batch = torch.full((batch_size,), t, device=self.device)
            
            if method == 'euler':
                # First-order Euler step
                drift = self._ode_drift(x, t_batch)
                x = x + drift * dt
            
            elif method == 'heun':
                # Heun's method (RK2): average the slope at both endpoints.
                drift1 = self._ode_drift(x, t_batch)
                x_tilde = x + drift1 * dt
                
                t_next = timesteps[i + 1]
                t_next_batch = torch.full((batch_size,), t_next, device=self.device)
                drift2 = self._ode_drift(x_tilde, t_next_batch)
                
                x = x + 0.5 * (drift1 + drift2) * dt
        
        return x
    
    def _ode_drift(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Probability flow drift f(x,t) - 1/2 g(t)^2 s_theta(x,t)."""
        score = self.score_net(x, t)
        beta_t = self.noise_schedule.beta(t).view(-1, 1, 1, 1)
        cfg = self.noise_schedule.config
        
        if cfg.schedule_type == 'VP':
            # f = -1/2 beta x, g^2 = beta  ->  -1/2 beta (x + s)
            drift = -0.5 * beta_t * (x + score)
        elif cfg.schedule_type == 'VE':
            # f = 0, g^2 = d[sigma^2]/dt = 2 sigma(t)^2 ln(sigma_max/sigma_min).
            # FIX: the previous +1/2 sigma^2 s both mis-scaled g^2 and had
            # the wrong sign for backward integration.
            sigma = self.noise_schedule.sigma_t(t).view(-1, 1, 1, 1)
            g2 = 2 * sigma ** 2 * math.log(cfg.sigma_max / cfg.sigma_min)
            drift = -0.5 * g2 * score
        else:  # sub-VP
            # g^2 = beta (1 - e^{-2 int beta}) = beta (1 - alpha^4)
            alpha = self.noise_schedule.alpha_t(t).view(-1, 1, 1, 1)
            g2 = beta_t * (1 - alpha ** 4)
            drift = -0.5 * beta_t * x - 0.5 * g2 * score
        
        return drift


# ===========================
# 7. Predictor-Corrector Sampler
# ===========================

class PredictorCorrectorSampler:
    """Predictor-Corrector sampling: a reverse-SDE Euler-Maruyama step
    (predictor) followed by Langevin MCMC refinement (corrector).

    NOTE(review): the predictor hard-codes the VP coefficients
    (f = -1/2 beta x, g^2 = beta); pair this sampler with a VP schedule.
    """
    
    def __init__(self,
                 score_net: ScoreNet,
                 noise_schedule: NoiseSchedule,
                 num_steps: int = 1000,
                 num_corrector_steps: int = 1,
                 snr: float = 0.16,
                 device: str = 'cuda'):
        self.score_net = score_net
        self.noise_schedule = noise_schedule
        self.num_steps = num_steps
        self.num_corrector_steps = num_corrector_steps
        self.snr = snr  # Signal-to-noise ratio for Langevin
        self.device = device
    
    @torch.no_grad()
    def sample(self, batch_size: int, img_shape: Tuple[int, int, int]) -> torch.Tensor:
        """Generate (batch_size, C, H, W) samples by PC iteration."""
        T = self.noise_schedule.config.T
        dt = -T / self.num_steps  # Python float; negative = reverse time
        
        x = torch.randn(batch_size, *img_shape, device=self.device)
        x = x * self.noise_schedule.sigma_t(torch.tensor([T], device=self.device))
        
        timesteps = torch.linspace(T, 0, self.num_steps + 1, device=self.device)
        
        for i in range(self.num_steps):
            t = timesteps[i]
            t_batch = torch.full((batch_size,), t, device=self.device)
            
            # --- Predictor: Euler-Maruyama step of the reverse SDE ---
            score = self.score_net(x, t_batch)
            beta_t = self.noise_schedule.beta(t_batch).view(-1, 1, 1, 1)
            
            # Reverse drift f - g^2 s = -beta (x/2 + s). (The previous
            # -1/2 beta (x+s) was the probability-flow ODE drift.)
            drift = -beta_t * (0.5 * x + score)
            diffusion = torch.sqrt(beta_t)
            
            # No noise on the final step; math.sqrt because dt is a float.
            z = torch.randn_like(x) if i < self.num_steps - 1 else torch.zeros_like(x)
            x = x + drift * dt + diffusion * math.sqrt(abs(dt)) * z
            
            # --- Corrector: Langevin MCMC at the current noise level ---
            if i < self.num_steps - 1:
                for _ in range(self.num_corrector_steps):
                    score = self.score_net(x, t_batch)
                    
                    # Step size chosen so the update's signal-to-noise ratio
                    # matches self.snr (Song et al. PC sampling heuristic).
                    grad_norm = torch.norm(score.reshape(batch_size, -1), dim=-1).mean()
                    noise_norm = math.sqrt(np.prod(img_shape))
                    step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2
                    
                    z = torch.randn_like(x)
                    x = x + step_size * score + torch.sqrt(2 * step_size) * z
        
        return x


# ===========================
# 8. Classifier-Free Guidance
# ===========================

class ConditionalScoreNet(nn.Module):
    """Wraps a score network with class conditioning and classifier-free
    guidance. Label index `num_classes` is reserved as the "unconditional"
    token."""
    
    def __init__(self, base_net: ScoreNet, num_classes: int, dropout_prob: float = 0.1):
        super().__init__()
        self.base_net = base_net
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob
        
        # Class embedding table; the extra row is the unconditional token.
        self.class_embed = nn.Embedding(num_classes + 1, base_net.model_channels * 4)
    
    def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W)
            t: (B,)
            y: (B,) class labels, or None for unconditional scoring
        """
        # Missing labels mean "unconditional": substitute the reserved index.
        if y is None:
            y = torch.full((x.shape[0],), self.num_classes, device=x.device, dtype=torch.long)
        
        # Random label dropout during training is what makes classifier-free
        # guidance possible at sampling time.
        if self.training:
            drop = torch.rand(x.shape[0], device=x.device) < self.dropout_prob
            y = torch.where(drop, torch.full_like(y, self.num_classes), y)
        
        # NOTE: the class embedding is not yet injected into the base net;
        # a complete implementation would add it to the time embedding or
        # condition via FiLM/cross-attention inside the residual blocks.
        return self.base_net(x, t)
    
    @torch.no_grad()
    def forward_with_guidance(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor, 
                             guidance_scale: float = 1.0) -> torch.Tensor:
        """
        Classifier-free guidance: extrapolate from the unconditional score
        toward the conditional one by the guidance weight w:
        s = s_uncond + w * (s_cond - s_uncond)
        """
        cond = self.forward(x, t, y)
        
        # w = 1 reduces exactly to the conditional score; skip the extra pass.
        if guidance_scale == 1.0:
            return cond
        
        uncond = self.forward(x, t, None)
        
        return uncond + guidance_scale * (cond - uncond)


# ===========================
# 9. Demo Functions
# ===========================

def demo_time_embedding():
    """Demonstrate sinusoidal time embeddings"""
    print("=" * 50)
    print("Demo: Sinusoidal Time Embedding")
    print("=" * 50)
    
    embedder = SinusoidalPosEmb(dim=128)
    timesteps = torch.linspace(0, 1, 10)
    vectors = embedder(timesteps)
    
    print(f"Input timesteps: {timesteps.shape} -> {timesteps[:5].tolist()}...")
    print(f"Embeddings shape: {vectors.shape}")
    print(f"First embedding (t=0): {vectors[0, :8].tolist()}")
    print(f"Last embedding (t=1): {vectors[-1, :8].tolist()}")
    print()


def demo_score_network():
    """Run one forward pass through the U-Net score network and report stats."""
    print("=" * 50)
    print("Demo: Score Network (U-Net)")
    print("=" * 50)

    net = ScoreNet(
        in_channels=3,
        model_channels=64,
        out_channels=3,
        num_res_blocks=2,
        attention_resolutions=[16],
        channel_mult=[1, 2, 2],
    )

    batch = torch.randn(2, 3, 32, 32)
    times = torch.rand(2)
    score_out = net(batch, times)

    # Parameter count drives the rough FP32 size estimate (4 bytes each).
    param_count = sum(p.numel() for p in net.parameters())

    print(f"Input: x={batch.shape}, t={times.shape}")
    print(f"Output score: {score_out.shape}")
    print(f"Score range: [{score_out.min():.3f}, {score_out.max():.3f}]")
    print(f"Total parameters: {param_count:,}")
    print(f"Model size: ~{param_count * 4 / 1024**2:.1f} MB (FP32)")
    print()


def demo_noise_schedules():
    """Tabulate β(t), α_t and σ_t for the VP, VE and sub-VP schedules.

    Evaluates each schedule at a handful of times in [0, 1] and prints a
    small table per schedule type.
    """
    print("=" * 50)
    print("Demo: Noise Schedules (VP, VE, sub-VP)")
    print("=" * 50)

    # Fix vs. original: removed an unused `t = torch.linspace(0, 1, 11)`
    # local, and repaired the mojibake Greek letters in the table header.
    for schedule_type in ['VP', 'VE', 'sub-VP']:
        config = NoiseScheduleConfig(schedule_type=schedule_type)
        schedule = NoiseSchedule(config)

        print(f"\n{schedule_type} Schedule:")
        print(f"{'t':>6} {'β(t)':>8} {'α_t':>8} {'σ_t':>8}")
        print("-" * 32)
        for ti in [0.0, 0.25, 0.5, 0.75, 1.0]:
            ti_tensor = torch.tensor([ti])
            beta = schedule.beta(ti_tensor).item()
            alpha = schedule.alpha_t(ti_tensor).item()
            sigma = schedule.sigma_t(ti_tensor).item()
            print(f"{ti:>6.2f} {beta:>8.3f} {alpha:>8.3f} {sigma:>8.3f}")
    print()


def demo_denoising_score_matching():
    """Run a few DSM optimisation steps on random stand-in data."""
    print("=" * 50)
    print("Demo: Denoising Score Matching")
    print("=" * 50)

    device = 'cpu'
    model = ScoreNet(in_channels=3, model_channels=32, out_channels=3, 
                     num_res_blocks=1, channel_mult=[1, 2])

    schedule = NoiseSchedule(NoiseScheduleConfig(schedule_type='VP'))

    dsm = DenoisingScoreMatching(model, schedule, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Stand-in "image" batch drawn from N(0, I) — enough to exercise the
    # training step without a real dataset.
    clean_batch = torch.randn(4, 3, 32, 32)

    print("Training for 5 steps...")
    for step in range(1, 6):
        metrics = dsm.train_step(clean_batch, optimizer)
        print(f"Step {step}: loss={metrics['loss']:.4f}, mean_t={metrics['mean_t']:.3f}, "
              f"mean_sigma={metrics['mean_sigma']:.3f}")
    print()


def demo_sampling():
    """Generate tiny batches with the SDE, ODE and predictor-corrector samplers."""
    print("=" * 50)
    print("Demo: Sampling Methods")
    print("=" * 50)

    device = 'cpu'
    model = ScoreNet(in_channels=3, model_channels=32, out_channels=3, 
                     num_res_blocks=1, channel_mult=[1, 2])
    model.eval()

    schedule = NoiseSchedule(NoiseScheduleConfig(schedule_type='VP'))

    def _report(samples):
        # Shared pretty-printer used by all three samplers below.
        print(f"   Generated samples: {samples.shape}")
        print(f"   Sample range: [{samples.min():.2f}, {samples.max():.2f}]")

    # 1) Stochastic SDE integration (Euler-Maruyama).
    print("\n1. Euler-Maruyama SDE Sampler:")
    em_sampler = EulerMaruyamaSampler(model, schedule, num_steps=10, device=device)
    _report(em_sampler.sample(batch_size=2, img_shape=(3, 32, 32)))

    # 2) Deterministic probability-flow ODE with Heun's method.
    print("\n2. ODE Sampler (Heun):")
    ode_sampler = ODESampler(model, schedule, num_steps=10, device=device)
    _report(ode_sampler.sample(batch_size=2, img_shape=(3, 32, 32), method='heun'))

    # 3) Predictor step + Langevin corrector step(s).
    print("\n3. Predictor-Corrector Sampler:")
    pc_sampler = PredictorCorrectorSampler(model, schedule, num_steps=10, 
                                          num_corrector_steps=1, device=device)
    _report(pc_sampler.sample(batch_size=2, img_shape=(3, 32, 32)))
    print()


def demo_classifier_free_guidance():
    """Compare guided scores at several guidance scales.

    Wraps a small ScoreNet in a ConditionalScoreNet and evaluates
    classifier-free guidance at w = 1.0, 2.0 and 5.0 on a random batch.

    Fix vs. original: the printed interpretation line contained mojibake
    arrows ("β†’"); replaced with the intended "→".
    """
    print("=" * 50)
    print("Demo: Classifier-Free Guidance")
    print("=" * 50)

    base_net = ScoreNet(in_channels=3, model_channels=32, out_channels=3,
                       num_res_blocks=1, channel_mult=[1, 2])

    model = ConditionalScoreNet(base_net, num_classes=10, dropout_prob=0.1)
    model.eval()  # eval mode disables label dropout, so outputs are deterministic

    x = torch.randn(2, 3, 32, 32)
    t = torch.rand(2)
    y = torch.tensor([3, 7])

    # w = 1.0 is the pure conditional score; larger w extrapolates away
    # from the unconditional score (stronger conditioning).
    for w in (1.0, 2.0, 5.0):
        score = model.forward_with_guidance(x, t, y, guidance_scale=w)
        print(f"Score (w={w}): {score.shape}, range=[{score.min():.3f}, {score.max():.3f}]")

    print("\nInterpretation:")
    print("  Higher guidance scale → stronger conditioning → less diversity")
    print("  Typical values: 1.0 (no guidance) to 7.5 (strong guidance)")
    print()


def print_performance_comparison():
    """Comprehensive performance comparison and decision guide.

    Prints reference tables (generation quality, sampling speed, training
    stability, likelihood, SDE types, hyperparameters, sampling/guidance
    trade-offs, applications) followed by decision, variant-selection and
    troubleshooting guides.  All numbers are literature-reported values,
    not measured by this script.

    Fixes vs. original: (1) eleven copies of the same cell-padding print
    loop are replaced by the `_print_rows` helper; (2) mojibake in output
    strings repaired (×, β, σ, •, →, ≠, ✓, ✗).
    """

    def _print_rows(rows, widths):
        # One row per line; each cell left-aligned to its column width,
        # columns joined by a single space (matches the original
        # hand-written f-string layout, including trailing padding).
        for row in rows:
            print(" ".join(f"{cell:<{w}}" for cell, w in zip(row, widths)))

    print("=" * 80)
    print("PERFORMANCE COMPARISON: Score-Based Generative Models")
    print("=" * 80)

    # 1. Image Generation Quality
    print("\n1. Image Generation Quality (FID ↓, IS ↑)")
    print("-" * 80)
    _print_rows([
        ("Model", "CIFAR-10 FID", "ImageNet 256 FID", "Notes"),
        ("-" * 30, "-" * 12, "-" * 15, "-" * 30),
        ("NCSN (Song 2019)", "25.3", "N/A", "Early score-based model"),
        ("NCSN++ (Song 2020)", "2.2", "N/A", "Improved architecture"),
        ("DDPM (Ho 2020)", "3.17", "N/A", "Discrete-time diffusion"),
        ("Improved DDPM", "2.9", "10.94", "Cosine schedule + hybrid loss"),
        ("Score SDE (VP)", "2.20", "9.89", "Continuous-time VP-SDE"),
        ("Score SDE (VE)", "2.38", "11.3", "Variance-exploding SDE"),
        ("Score SDE (sub-VP)", "2.61", "9.56", "Sub-variance-preserving"),
        ("EDM (Karras 2022)", "1.97", "N/A", "Optimal preconditioning"),
        ("DiT-XL/2 (Peebles 2023)", "N/A", "2.27", "Diffusion Transformer"),
        ("", "", "", ""),
        ("COMPARISON:", "", "", ""),
        ("StyleGAN2", "2.92", "2.71", "Best GAN (fast sampling)"),
        ("BigGAN-deep", "6.95", "6.95", "Class-conditional GAN"),
        ("VAE (β=1)", "~80", "N/A", "Blurry reconstructions"),
    ], (30, 12, 15, 30))

    # 2. Sampling Speed
    print("\n2. Sampling Speed Comparison")
    print("-" * 80)
    _print_rows([
        ("Method", "Steps", "Time (256×256)", "Quality", "Type"),
        ("-" * 25, "-" * 6, "-" * 15, "-" * 10, "-" * 15),
        ("DDPM (ancestral)", "1000", "~50 sec", "Excellent", "Stochastic"),
        ("DDIM (deterministic)", "50", "~5 sec", "Very Good", "Deterministic"),
        ("DPM-Solver++", "10-20", "~1-2 sec", "Excellent", "Deterministic"),
        ("Consistency Model", "1-4", "~0.1-0.5 sec", "Good", "Deterministic"),
        ("EDM (Heun solver)", "35-79", "~3-8 sec", "SOTA", "Hybrid"),
        ("Latent Diffusion", "50", "~2-3 sec", "Excellent", "Latent space"),
        ("", "", "", "", ""),
        ("GAN (StyleGAN2)", "1", "~0.1 sec", "Excellent", "Single-step"),
        ("VAE", "1", "~0.05 sec", "Moderate", "Single-step"),
        ("Normalizing Flow", "1", "~0.2 sec", "Good", "Invertible"),
    ], (25, 6, 15, 10, 15))

    # 3. Training Stability
    print("\n3. Training Stability and Convergence")
    print("-" * 80)
    _print_rows([
        ("Model", "Stability", "Mode Coverage", "Hyperparameter Sensitivity"),
        ("-" * 20, "-" * 12, "-" * 15, "-" * 30),
        ("Score-based/Diffusion", "Very Stable", "Excellent", "Low (robust)"),
        ("GAN", "Unstable", "Poor-Moderate", "Very High (careful tuning)"),
        ("VAE", "Stable", "Good", "Moderate (β-VAE)"),
        ("Normalizing Flow", "Stable", "Good", "Moderate (architecture)"),
        ("Energy-based", "Moderate", "Good", "High (MCMC sensitive)"),
    ], (20, 12, 15, 30))

    # 4. Likelihood Evaluation
    print("\n4. Likelihood Evaluation (bits/dim on CIFAR-10, ↓ better)")
    print("-" * 80)
    _print_rows([
        ("Model", "Likelihood", "Method", "Notes"),
        ("-" * 25, "-" * 12, "-" * 20, "-" * 30),
        ("Score SDE (VP)", "2.99", "ODE (exact)", "Continuous normalizing flow"),
        ("DDPM (improved)", "2.94", "ELBO (lower bound)", "Discrete-time variational"),
        ("Glow (Flow)", "3.35", "Exact", "Invertible architecture"),
        ("VAE (PixelCNN++)", "2.92", "ELBO", "Hybrid VAE + autoregressive"),
        ("PixelCNN++", "2.92", "Exact", "Autoregressive"),
        ("", "", "", ""),
        ("Score models", "Comparable", "ODE tractable", "Flexible architecture"),
        ("Note:", "", "", "Likelihood ≠ sample quality"),
    ], (25, 12, 20, 30))

    # 5. SDE Type Comparison
    print("\n5. SDE Type Comparison")
    print("-" * 80)
    _print_rows([
        ("SDE Type", "Forward SDE", "Best Use Case", "FID (ImageNet 256)"),
        ("-" * 15, "-" * 35, "-" * 25, "-" * 20),
        ("VP", "dx = -½β(t)x dt + √β(t) dw", "General purpose", "9.89"),
        ("VE", "dx = √(d[σ²]/dt) dw", "High-resolution images", "11.3"),
        ("sub-VP", "dx = -½β(t)x dt + √(β(1-α²)) dw", "Better likelihood", "9.56"),
        ("DDPM-equiv", "Discrete Markov chain", "Simple implementation", "10.94"),
    ], (15, 35, 25, 20))

    # 6. Training Hyperparameters
    print("\n6. Recommended Training Hyperparameters")
    print("-" * 80)
    _print_rows([
        ("Parameter", "CIFAR-10", "ImageNet 256×256", "Notes"),
        ("-" * 25, "-" * 15, "-" * 20, "-" * 30),
        ("Model channels", "128", "256", "Base width"),
        ("Channel multipliers", "[1,2,2,2]", "[1,1,2,2,4,4]", "Depth scaling"),
        ("Num res blocks", "2-4", "2-3", "Per resolution"),
        ("Attention res", "[16]", "[32,16,8]", "Apply self-attention"),
        ("Dropout", "0.1", "0.1", "Regularization"),
        ("", "", "", ""),
        ("Batch size", "128", "256-2048", "Larger better"),
        ("Learning rate", "2e-4", "1e-4", "Adam optimizer"),
        ("EMA decay", "0.9999", "0.9999", "For sampling"),
        ("Gradient clip", "1.0", "1.0", "Stability"),
        ("", "", "", ""),
        ("Training steps", "800K", "1M-3M", "Until convergence"),
        ("Noise schedule", "Linear", "Cosine/EDM", "β(t) or σ(t)"),
        ("T (time horizon)", "1.0", "1.0", "Total diffusion time"),
        ("σ_min / σ_max (VE)", "0.01 / 50", "0.002 / 80", "Noise range"),
    ], (25, 15, 20, 30))

    # 7. Sampling Configuration
    print("\n7. Sampling Configuration Trade-offs")
    print("-" * 80)
    _print_rows([
        ("Method", "Steps", "Quality", "Speed", "Deterministic", "Use Case"),
        ("-" * 20, "-" * 6, "-" * 10, "-" * 8, "-" * 12, "-" * 25),
        ("Ancestral (DDPM)", "1000", "Excellent", "Slow", "No", "Best quality"),
        ("DDIM", "50-100", "Very Good", "Medium", "Yes", "Fast + invertible"),
        ("DPM-Solver++", "10-20", "Excellent", "Fast", "Yes", "Production (recommended)"),
        ("Euler-Maruyama", "100-500", "Good", "Medium", "No", "General SDE"),
        ("Heun (RK2)", "35-79", "SOTA", "Medium", "Yes", "EDM (best quality)"),
        ("Predictor-Corrector", "100-500", "Excellent", "Slow", "No", "High quality + diversity"),
        ("ODE (prob flow)", "20-100", "Very Good", "Fast", "Yes", "Likelihood / inversion"),
        ("Consistency", "1-4", "Good", "Very Fast", "Yes", "Real-time applications"),
    ], (20, 6, 10, 8, 12, 25))

    # 8. Guidance Trade-offs
    print("\n8. Classifier-Free Guidance Trade-offs")
    print("-" * 80)
    _print_rows([
        ("Guidance Scale (w)", "Quality", "Diversity", "Condition Strength", "Use Case"),
        ("-" * 18, "-" * 12, "-" * 12, "-" * 18, "-" * 30),
        ("1.0 (no guidance)", "Good", "High", "Weak", "Unconditional / diverse"),
        ("1.5", "Good", "High", "Moderate", "Slight conditioning"),
        ("3.0", "Very Good", "Moderate", "Strong", "Balanced (text-to-image)"),
        ("5.0", "Excellent", "Low", "Very Strong", "Precise control"),
        ("7.5", "SOTA", "Very Low", "Extreme", "DALL-E 2 / Stable Diffusion"),
        ("10.0+", "Saturated", "Minimal", "Maximum", "Overfitting / artifacts"),
    ], (18, 12, 12, 18, 30))
    print("Note: Classifier-free guidance requires dropout_prob=0.1 during training")

    # 9. Application-Specific Results
    print("\n9. Application-Specific Results")
    print("-" * 80)
    _print_rows([
        ("Application", "Model", "Result", "Notes"),
        ("-" * 25, "-" * 25, "-" * 30, "-" * 30),
        ("Text-to-Image", "DALL-E 2", "Human-quality 1024×1024", "CLIP + diffusion"),
        ("", "Imagen", "SOTA photorealism", "T5 + cascaded diffusion"),
        ("", "Stable Diffusion", "512×512, open-source", "Latent diffusion"),
        ("", "Midjourney v5", "Artistic generation", "Commercial"),
        ("", "", "", ""),
        ("Image Editing", "SDEdit", "Stroke-to-image", "Stochastic editing"),
        ("", "Repaint", "Inpainting", "Resample known region"),
        ("", "DiffEdit", "Text-guided editing", "Mask + diffusion"),
        ("", "", "", ""),
        ("Audio", "WaveGrad", "24kHz waveform generation", "Raw audio diffusion"),
        ("", "DiffWave", "MOS 4.4+ vocoder", "Mel-to-waveform"),
        ("", "Grad-TTS", "Natural TTS", "End-to-end diffusion"),
        ("", "", "", ""),
        ("Video", "Video Diffusion", "16 frames @ 64×64", "Factorized space-time"),
        ("", "Imagen Video", "1024p text-to-video", "Cascaded diffusion"),
        ("", "", "", ""),
        ("3D", "Point-E", "Text-to-3D point clouds", "Diffusion on point clouds"),
        ("", "DreamFusion", "Text-to-NeRF", "Score distillation"),
        ("", "", "", ""),
        ("Science", "Molecule generation", "Valid molecules", "Graph diffusion"),
        ("", "Protein design", "SE(3)-equivariant", "Manifold diffusion"),
        ("", "Inverse problems", "CT reconstruction", "Posterior sampling"),
    ], (25, 25, 30, 30))

    # 10. Decision Guide
    print("\n10. DECISION GUIDE: When to Use Score-Based Models")
    print("=" * 80)

    print("\n✓ USE Score-Based Models When:")
    advantages = [
        "• Need state-of-the-art generation quality (images, audio, video)",
        "• Require diverse samples (avoid mode collapse)",
        "• Want training stability (no adversarial dynamics)",
        "• Need exact likelihood (via ODE)",
        "• Flexible architecture constraints (any network)",
        "• Conditional generation (text-to-image, class-conditional)",
        "• Image editing and inpainting applications",
        "• Scientific computing (inverse problems, molecular generation)",
    ]
    for adv in advantages:
        print(adv)

    print("\n✗ AVOID Score-Based Models When:")
    limitations = [
        "• Real-time generation required (GANs 100× faster)",
        "• Limited computational budget (training costly)",
        "• Single-step sampling mandatory (VAEs, GANs, Flows)",
        "• Exact control over latent space needed (VAEs better)",
        "• Very low-resolution images (overkill, simple models sufficient)",
    ]
    for lim in limitations:
        print(lim)

    print("\n→ RECOMMENDED ALTERNATIVE:")
    alternatives = [
        "• Fast sampling → Latent Diffusion (4-8× faster) or DPM-Solver++ (20 steps)",
        "• Real-time → Consistency Models (1-4 steps) or distilled models",
        "• Latent manipulation → VAE or GAN inversion",
        "• Extremely fast → StyleGAN2 or other GANs",
    ]
    for alt in alternatives:
        print(alt)

    # 11. Variant Selection Guide
    print("\n11. VARIANT SELECTION GUIDE")
    print("=" * 80)
    _print_rows([
        ("Variant", "When to Use", "Pros", "Cons"),
        ("-" * 20, "-" * 30, "-" * 30, "-" * 30),
        ("DDPM", "Simple implementation", "Well-documented, stable", "Slower sampling (1000 steps)"),
        ("Score SDE (VP)", "General purpose", "Flexible, continuous-time", "Complex formulation"),
        ("Score SDE (VE)", "High-res images", "Better for large images", "Slightly lower FID"),
        ("DDIM", "Fast deterministic", "50 steps, invertible", "Slight quality loss"),
        ("DPM-Solver++", "Production deployment", "10-20 steps, SOTA", "Newer (less tested)"),
        ("EDM", "Best quality", "SOTA FID, optimal", "Complex preconditioning"),
        ("Latent Diffusion", "Fast high-res", "4-8× faster, 512×512+", "Requires VAE training"),
        ("Consistency", "Real-time", "1-4 steps, very fast", "Quality trade-off"),
    ], (20, 30, 30, 30))

    # 12. Troubleshooting
    print("\n12. TROUBLESHOOTING COMMON ISSUES")
    print("=" * 80)
    _print_rows([
        ("Problem", "Possible Cause", "Solution"),
        ("-" * 30, "-" * 35, "-" * 40),
        ("Poor sample quality", "Insufficient training", "Train longer (1M+ steps)"),
        ("", "Bad noise schedule", "Try cosine or EDM schedule"),
        ("", "Too few sampling steps", "Increase to 50-100 steps"),
        ("", "", ""),
        ("Slow convergence", "Learning rate too low", "Increase LR to 1e-4 or 2e-4"),
        ("", "Small batch size", "Increase to 128+"),
        ("", "", ""),
        ("Training instability", "Exploding gradients", "Gradient clipping (max_norm=1.0)"),
        ("", "Bad initialization", "Use default PyTorch init"),
        ("", "", ""),
        ("Blurry samples", "Score network too small", "Increase model_channels"),
        ("", "Not enough denoising", "More sampling steps"),
        ("", "", ""),
        ("OOM (out of memory)", "Batch size too large", "Reduce batch size or resolution"),
        ("", "Model too large", "Reduce model_channels"),
        ("", "", "Use mixed precision (FP16)"),
        ("", "", ""),
        ("Slow sampling", "Too many steps", "Use DPM-Solver++ (10-20 steps)"),
        ("", "High resolution", "Use latent diffusion"),
    ], (30, 35, 40))

    print("\n" + "=" * 80)
    print("Summary: Score-based models offer SOTA quality with stable training,")
    print("but require iterative sampling. Use DPM-Solver++ or latent diffusion")
    print("for practical deployment. Classifier-free guidance boosts quality.")
    print("=" * 80)
    print()


# ===========================
# Run All Demos
# ===========================

if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("SCORE-BASED GENERATIVE MODELS - COMPREHENSIVE IMPLEMENTATION")
    print("=" * 80 + "\n")
    
    demo_time_embedding()
    demo_score_network()
    demo_noise_schedules()
    demo_denoising_score_matching()
    demo_sampling()
    demo_classifier_free_guidance()
    print_performance_comparison()
    
    print("\n" + "=" * 80)
    print("All demos completed successfully!")
    print("=" * 80)