import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Select GPU when available; every model and tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Energy-Based Models
Energy Function: \(p_\theta(x) = \frac{\exp(-E_\theta(x))}{Z_\theta}\)
where \(Z_\theta = \int \exp(-E_\theta(x)) dx\) is partition function.
Maximum Likelihood: \(\nabla_\theta \log p_\theta(x) = -\nabla_\theta E_\theta(x) + \mathbb{E}_{x' \sim p_\theta}[\nabla_\theta E_\theta(x')]\)
Reference Materials:
generative_models.pdf - Generative Models
rbm_cd.pdf - RBM Contrastive Divergence
class EnergyNet(nn.Module):
"""Energy-based model."""
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(1, 64, 4, 2, 1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, 4, 2, 1),
nn.LeakyReLU(0.2),
nn.Flatten(),
nn.Linear(128 * 7 * 7, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1)
)
def forward(self, x):
"""Compute energy E(x)."""
return self.net(x).squeeze()
print("EnergyNet defined")
2. Contrastive Divergence
Algorithm:
Sample positive examples from data
Generate negative examples via MCMC
Update: decrease energy of positives, increase energy of negatives
def sample_langevin(energy_net, x_init, n_steps=60, step_size=10.0):
    """Draw samples from an energy model via SGLD.

    Repeats: x <- clamp(x - (step_size/2) * grad E(x) + sqrt(step_size) * z, 0, 1)
    with z ~ N(0, I). Returns the final, detached chain state.
    """
    noise_scale = np.sqrt(step_size)
    x = x_init.clone().requires_grad_(True)
    for _ in range(n_steps):
        # Gradient of the summed energy w.r.t. the chain state only.
        total_energy = energy_net(x).sum()
        (grad,) = torch.autograd.grad(total_energy, x)
        # Langevin proposal, then clamp to the valid pixel range.
        proposal = x - 0.5 * step_size * grad + noise_scale * torch.randn_like(x)
        x = torch.clamp(proposal, 0, 1).detach().requires_grad_(True)
    return x.detach()

print("Langevin sampling defined")
Training LoopΒΆ
Energy-based models (EBMs) are trained by contrastive divergence: push down the energy of real data samples and push up the energy of generated (negative) samples. The loss is \(\mathcal{L} = \mathbb{E}_{x \sim p_{\text{data}}}[E_\theta(x)] - \mathbb{E}_{x^- \sim p_\theta}[E_\theta(x^-)]\), where negative samples \(x^-\) are obtained via MCMC (typically Langevin dynamics) from the current model. Maintaining a replay buffer of previously generated negatives and initializing MCMC chains from buffer samples (persistent contrastive divergence) dramatically improves training stability and sample quality.
def train_ebm(model, train_loader, n_epochs=5):
    """Train an EBM with contrastive divergence.

    Each step lowers the energy of real data and raises the energy of
    SGLD-generated negatives; a small squared-energy penalty keeps the
    energies bounded. Returns the per-epoch average losses.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    losses = []
    for epoch in range(n_epochs):
        running = 0.0
        for x_pos, _ in train_loader:
            x_pos = x_pos.to(device)
            # Negatives: SGLD chains started from uniform noise.
            x_neg = sample_langevin(model, torch.rand_like(x_pos), n_steps=60)
            energy_pos = model(x_pos)
            energy_neg = model(x_neg)
            # Contrastive divergence objective plus energy-magnitude regularizer.
            cd = energy_pos.mean() - energy_neg.mean()
            reg = (energy_pos ** 2).mean() + (energy_neg ** 2).mean()
            loss = cd + 0.001 * reg
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
        avg_loss = running / len(train_loader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
    return losses

print("Training function defined")
Train ModelΒΆ
During training, each step involves: (1) sampling a batch of real images, (2) running MCMC for a few steps to produce negative samples (or drawing from the replay buffer), (3) computing the contrastive divergence loss, and (4) taking a gradient step. The balance between positive and negative sample energies must be carefully maintained β if negative samples are too poor, the model learns trivial energy landscapes; if the MCMC runs too long, training becomes prohibitively slow. Spectral normalization and gradient penalties help keep the energy function well-behaved and prevent the loss from diverging.
# Data: MNIST in [0, 1] via ToTensor — matches the [0, 1] clamp inside sample_langevin.
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
# Model
ebm = EnergyNet().to(device)
# Train (short run for demonstration)
losses = train_ebm(ebm, train_loader, n_epochs=3)
Generate SamplesΒΆ
Generating samples from a trained EBM requires running MCMC (Langevin dynamics) for many steps starting from random noise, following the negative gradient of the energy function toward low-energy (high-probability) regions. Unlike GANs and VAEs which generate in a single forward pass, EBM sampling is iterative and can be computationally expensive, typically requiring hundreds to thousands of Langevin steps for high-quality results. The iterative nature is also an advantage: samples can be refined by running more steps, providing a quality-computation trade-off at inference time.
# Generate samples: long SGLD chains (200 steps) starting from uniform noise.
ebm.eval()
n_samples = 16
x_init = torch.rand(n_samples, 1, 28, 28).to(device)
samples = sample_langevin(ebm, x_init, n_steps=200, step_size=10.0)
# Visualize the 16 samples on a 4x4 grid
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i in range(16):
    ax = axes[i // 4, i % 4]
    ax.imshow(samples[i].cpu().squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle('EBM Generated Samples', fontsize=12)
plt.tight_layout()
plt.show()
Energy Landscape VisualizationΒΆ
Visualizing the learned energy landscape provides unique insight into what the model has learned about the data distribution. For 2D data, we can plot the energy function as a surface or contour map, revealing the basins (low energy) where the model places probability mass and the barriers (high energy) between modes. For image data, comparing the energy assigned to real images, generated samples, and random noise validates that the model correctly assigns low energy to plausible data and high energy to implausible inputs.
# Sanity check: a trained EBM should assign lower energy to real digits than
# to its own (imperfect) samples, so the two histograms should separate.
# Get real samples
real_samples, _ = next(iter(train_loader))
real_samples = real_samples[:16].to(device)
# Compute energies (no gradients needed for evaluation)
with torch.no_grad():
    energy_real = ebm(real_samples)
    energy_generated = ebm(samples)
# Plot overlaid energy histograms
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(energy_real.cpu().numpy(), bins=20, alpha=0.7, label='Real Data', edgecolor='black')
ax.hist(energy_generated.cpu().numpy(), bins=20, alpha=0.7, label='Generated', edgecolor='black')
ax.set_xlabel('Energy', fontsize=11)
ax.set_ylabel('Count', fontsize=11)
ax.set_title('Energy Distribution', fontsize=12)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"Real energy: {energy_real.mean().item():.3f}")
print(f"Generated energy: {energy_generated.mean().item():.3f}")
SummaryΒΆ
Energy-Based Models:ΒΆ
Key Concepts:
Energy function \(E_\theta(x)\)
Probability via Boltzmann: \(p(x) \propto \exp(-E(x))\)
Training via contrastive divergence
Sampling via Langevin dynamics
Training:ΒΆ
Decrease energy of data
Increase energy of generated samples
MCMC for negative samples
Advantages:ΒΆ
Flexible energy functions
No explicit density modeling
Natural for structured prediction
Composable models
Challenges:ΒΆ
Expensive sampling
Mode collapse
Partition function intractable
Applications:ΒΆ
Image generation
Denoising
Anomaly detection
Compositional generation
Advanced Energy-Based Models TheoryΒΆ
1. Mathematical FoundationsΒΆ
Energy Function and Probability DistributionΒΆ
Definition: An energy-based model defines a probability distribution via an energy function E_ΞΈ(x):
where Z_ΞΈ is the partition function:
Key Properties:
Lower energy → higher probability
Z_ΞΈ ensures normalization but is typically intractable
E_ΞΈ(x) can be any differentiable function (neural network)
Maximum Likelihood GradientΒΆ
Theorem (MLE Gradient): The gradient of log-likelihood is:
Interpretation:
First term: βPositive phaseβ - decrease energy of data
Second term: βNegative phaseβ - increase energy of model samples
Difference drives learning
Challenge: Computing expectation requires sampling from p_ΞΈ, which requires knowing Z_ΞΈ.
2. Contrastive Divergence (CD)ΒΆ
CD-k Algorithm (Hinton, 2002)ΒΆ
Key Idea: Approximate negative phase using k steps of MCMC starting from data.
Algorithm:
1. Initialize x⁰ from the data distribution
2. Run k Gibbs sampling steps: x⁰ → x¹ → ... → xᵏ
3. Gradient approximation:
∇_θ L ≈ -∇_θ E_θ(x⁰) + ∇_θ E_θ(xᵏ)
4. Update: θ ← θ - α ∇_θ L
Theoretical Justification:
CD-k minimizes difference between data distribution and k-step reconstruction
As k β β, CD-k β exact MLE gradient
In practice, k=1 or k=5 works well
Bias: CD-k is biased but has lower variance than exact gradient.
Persistent Contrastive Divergence (PCD)ΒΆ
Key Improvement: Maintain persistent βfantasy particlesβ across training iterations.
Algorithm:
1. Initialize persistent chains {xΜα΅’} randomly
2. For each mini-batch:
a. Sample xβΊ from data
b. Run k MCMC steps from {xΜα΅’}: xΜα΅’ β xΜ'α΅’
c. Update: β_ΞΈ L β -β_ΞΈ E_ΞΈ(xβΊ) + β_ΞΈ E_ΞΈ(xΜ')
d. Keep updated xΜ' for next iteration
Advantages:
Better mixing: chains donβt restart from data
More accurate negative phase approximation
Faster convergence in practice
3. Score Matching for EBMsΒΆ
Connection to Score-Based ModelsΒΆ
Theorem: For EBM p(x) β exp(-E(x)), the score is:
(partition function Z is constant w.r.t. x)
Denoising Score Matching LossΒΆ
Objective: Train energy via denoising instead of contrastive divergence:
where x_t = x_0 + Ξ΅.
Advantages:
No MCMC sampling required
Avoids partition function
More stable training
4. Restricted Boltzmann Machines (RBMs)ΒΆ
ArchitectureΒΆ
Bipartite Graph: Visible units v β ββΏα΅ and hidden units h β ββΏΚ°.
Energy Function:
where W is weight matrix, b and c are biases.
Joint Distribution:
Conditional IndependenceΒΆ
Key Property: Given h, visible units are independent:
Similarly for p(h | v).
Gibbs SamplingΒΆ
Block Gibbs:
1. Sample h ~ p(h | v)
2. Sample v ~ p(v | h)
3. Repeat
Efficiency: Parallel sampling within each layer due to conditional independence.
Training RBMsΒΆ
CD-1 for RBMs:
1. vβ° β data sample
2. hβ° ~ p(h | vβ°)
3. vΒΉ ~ p(v | hβ°)
4. hΒΉ ~ p(h | vΒΉ)
5. ΞW β <vh>_data - <vh>_recon = vβ°hβ°α΅ - vΒΉhΒΉα΅
Practical Tips:
Learning rate: 0.01 - 0.1
Weight decay for regularization
Momentum (0.9) for faster convergence
5. Deep Belief Networks (DBNs)ΒΆ
ArchitectureΒΆ
Stack of RBMs: Layer-wise unsupervised pretraining.
Greedy Layer-wise Training:
1. Train RBMβ on data
2. Use hβ as input to train RBMβ
3. Repeat for L layers
4. Fine-tune entire network
Theoretical Justification: Each layer increases lower bound on data likelihood.
Fine-tuningΒΆ
Wake-Sleep Algorithm:
Wake phase: Update recognition weights (bottom-up)
Sleep phase: Update generative weights (top-down)
Modern Approach: Fine-tune with backpropagation after pretraining.
6. Modern Energy-Based ModelsΒΆ
Joint Energy-Based Models (JEMs)ΒΆ
Key Idea: Single model for both generation and classification.
Energy Function:
Training:
Classification: Standard cross-entropy on p(y|x)
Generation: SGLD to sample from p(x)
Advantages:
Unified model
Better calibration
Out-of-distribution detection
Conditional EBMsΒΆ
For Discriminative Tasks:
Applications:
Structured prediction (segmentation, parsing)
Image-to-image translation
Conditional generation
7. Noise Contrastive Estimation (NCE)ΒΆ
Key IdeaΒΆ
Objective: Distinguish data from noise distribution p_n.
NCE Loss:
where h_ΞΈ(x) = p_ΞΈ(x)/(p_ΞΈ(x) + kΒ·p_n(x)) and k is noise samples per data sample.
Advantages:
Approximate partition function as learnable parameter
Easier than MCMC sampling
Scales to high dimensions
Connection to GANsΒΆ
NCE vs GANs:
NCE: Fixed noise distribution
GANs: Learned noise (generator)
Both avoid explicit density modeling
8. MCMC Sampling MethodsΒΆ
Langevin DynamicsΒΆ
Update Rule: \(x_{t+1} = x_t - \frac{\epsilon}{2} \nabla_x E_\theta(x_t) + \sqrt{\epsilon}\, z_t\)
where \(z_t \sim \mathcal{N}(0, I)\).
Convergence: As \(\epsilon \to 0\) and \(T \to \infty\), samples converge to \(p_\theta\).
Hamiltonian Monte Carlo (HMC)ΒΆ
Augmented State: (x, v) where v is momentum.
Hamiltonian:
Leapfrog Integration:
1. v β v - (Ξ΅/2)β_x E(x)
2. x β x + Ξ΅ v
3. v β v - (Ξ΅/2)β_x E(x)
Advantages:
Better mixing than Langevin
Fewer rejections
Explores energy landscape efficiently
Replica Exchange Monte CarloΒΆ
Parallel Tempering: Run chains at different temperatures Tβ < Tβ < β¦ < Tβ.
Exchange Step: Swap states between adjacent temperatures with probability:
Benefit: High-temperature chains escape local modes, low-temperature chains sample target.
9. Training ImprovementsΒΆ
Spectral NormalizationΒΆ
Objective: Constrain Lipschitz constant of energy function.
Method: Normalize weights by largest singular value:
where Ο(W) is estimated via power iteration.
Benefit: Stabilizes training, prevents energy collapse.
Energy RegularizationΒΆ
Objectives:
Pull-away term: Encourage diversity in generated samples
Entropy maximization: Spread probability mass
Squared energy: Prevent unbounded energies
Combined Loss:
10. Evaluation MetricsΒΆ
Inception Score (IS)ΒΆ
Definition:
Interpretation:
High if samples are diverse (high H[p(y)])
High if samples are discriminative (low H[p(y|x)])
FrΓ©chet Inception Distance (FID)ΒΆ
Definition: Distance between data and generated distributions in Inception feature space:
Lower is better: FID < 10 is excellent.
Log-Likelihood via Annealed Importance Sampling (AIS)ΒΆ
Method: Bridge between prior pβ (tractable) and target p_ΞΈ (intractable).
Estimator:
where weights w_k come from intermediate distributions.
Use: Estimate log p_ΞΈ(x) = -E_ΞΈ(x) - log Z_ΞΈ.
11. State-of-the-Art ModelsΒΆ
EBM for Image GenerationΒΆ
Best Results (as of 2023):
JEM (Joint Energy-based Model): CIFAR-10 FID ~38
Conjugate Energy-Based Models: FID ~15-20
Improved SGLD: Competitive with GANs on small datasets
Limitation: Still lag behind diffusion models and GANs on large-scale generation.
EBM for Compositional ReasoningΒΆ
Key Advantage: Energy composition E_total = Eβ + Eβ + β¦ + E_n.
Applications:
Multi-attribute generation (combine βredβ + βcarβ + βconvertibleβ)
Concept combination without retraining
Out-of-distribution generalization
Example (Du et al., 2020): Compose classifiers as energy functions for zero-shot tasks.
12. Connections to Other ModelsΒΆ
Energy vs Score-Based ModelsΒΆ
Aspect |
Energy-Based |
Score-Based |
|---|---|---|
Core function |
E(x) |
s(x) = βlog p(x) |
Relation |
s(x) = -βE(x) |
E(x) = -log p(x) + C |
Partition |
Explicit challenge |
Avoided |
Training |
CD, NCE |
Score matching |
Sampling |
MCMC |
Langevin, SDE |
Energy vs GANsΒΆ
Similarities:
Implicit generation (no explicit p(x))
Adversarial training dynamics
Differences:
EBM: Single network + MCMC
GAN: Two networks (G, D), no MCMC
EBM: Energy landscape interpretation
GAN: Min-max game interpretation
Energy vs Normalizing FlowsΒΆ
Flows: Explicit density via invertible transforms.
Exact likelihood
Single forward/backward pass
Architectural constraints
EBMs: Flexible energy functions.
Intractable likelihood
MCMC sampling required
Any architecture
13. ApplicationsΒΆ
Image SynthesisΒΆ
Method: Train EBM on images, sample via SGLD.
Enhancements:
Multi-scale generation
Hierarchical sampling
Classifier guidance
DenoisingΒΆ
Approach: Energy as denoising autoencoder objective.
Advantage: Direct connection to score matching.
Anomaly DetectionΒΆ
Principle: Anomalies have high energy.
Method:
1. Train EBM on normal data
2. Threshold energy: E(x) > Ο β anomaly
Benefit: No need for anomaly examples.
3D Shape GenerationΒΆ
Energy over Point Clouds:
Sampling: MCMC over point positions.
Application: Generate novel 3D objects, complete partial scans.
Molecule DesignΒΆ
Energy as Chemical Property Predictor:
Optimization: MCMC or gradient-based search in molecular space.
14. Practical ConsiderationsΒΆ
Hyperparameter TuningΒΆ
Parameter |
Typical Range |
Effect |
|---|---|---|
Learning rate |
1e-5 to 1e-3 |
Higher β faster but unstable |
SGLD steps |
20-200 |
More β better samples, slower |
SGLD step size |
0.1-10.0 |
Larger β faster mixing |
CD-k |
k=1, 5, 10 |
Higher k β better gradient |
PCD chains |
100-1000 |
More β lower variance |
Computational CostΒΆ
Training:
Forward pass: O(N) (N = network parameters)
MCMC sampling: O(KΒ·N) (K = SGLD steps)
Typical: 5-10x slower than GANs
Sampling:
Single sample: O(KΒ·N), K ~ 100-1000
Diffusion models: Often faster (20-50 steps)
Debugging TipsΒΆ
Check:
Energy distribution: Should separate data vs random
SGLD trajectory: Should converge to low energy
Gradient norms: Clip if exploding
Sample diversity: Avoid mode collapse
Common Issues:
Energy collapse: All samples have same energy β use regularization
Poor mixing: SGLD stuck β tune step size, use HMC
Slow convergence: β PCD, spectral normalization
15. LimitationsΒΆ
Computational CostΒΆ
Sampling: MCMC is slow, especially in high dimensions.
100-1000 gradient evaluations per sample
Diffusion models: 10-50 steps
GANs: 1 forward pass
Mode CoverageΒΆ
Challenge: MCMC may not explore all modes in limited steps.
Solutions:
Parallel tempering
Multiple chains
Hybrid MCMC methods
Training InstabilityΒΆ
Issues:
Energy can diverge
MCMC chains may not converge
Gradient variance high with CD
Mitigations:
Spectral normalization
Gradient clipping
PCD instead of CD
Theoretical GapsΒΆ
Open Questions:
When does CD-k converge?
How many MCMC steps sufficient?
Optimal architecture for energy function?
16. Recent Advances (2020-2024)ΒΆ
Diffusion Recovery Likelihood (DRL)ΒΆ
Idea: Combine diffusion and EBM for hybrid model.
Method: Use diffusion to initialize, EBM to refine.
Continuous-Time EBMsΒΆ
Formulation: Energy evolves continuously via SDE.
Advantage: Unifies discrete (CD) and continuous (Langevin) perspectives.
Energy-Based Priors for Inverse ProblemsΒΆ
Application: Solve inverse problems with EBM prior:
Use Cases: MRI reconstruction, deblurring, super-resolution.
Self-Supervised EBMsΒΆ
Training: Use contrastive learning objectives (SimCLR, MoCo) as energy.
Benefit: Leverage large unlabeled datasets.
17. Key PapersΒΆ
FoundationsΒΆ
Hinton (2002): βTraining Products of Experts by Minimizing Contrastive Divergenceβ
Tieleman (2008): βTraining Restricted Boltzmann Machines using Approximations to the Likelihood Gradientβ
Modern EBMsΒΆ
Du & Mordatch (2019): βImplicit Generation and Modeling with Energy Based Modelsβ
Grathwohl et al. (2020): βYour Classifier is Secretly an Energy Based Model (JEM)β
Nijkamp et al. (2020): βLearning Energy-Based Models by Diffusion Recovery Likelihoodβ
Theoretical AnalysisΒΆ
Bengio et al. (2013): βEstimating or Propagating Gradients Through Stochastic Neuronsβ
Song & Kingma (2021): βHow to Train Your Energy-Based Modelsβ
ApplicationsΒΆ
Du et al. (2020): βEnergy-Based Models for Atomic-Resolution Protein Conformationsβ
Xie et al. (2021): βLearning Energy-Based Models in High-Dimensional Spaces with Multi-Scale Denoisingβ
18. Comparison: EBMs vs Other Generative ModelsΒΆ
Model |
Training |
Sampling |
Likelihood |
Flexibility |
|---|---|---|---|---|
EBM |
CD, NCE, Score |
MCMC (slow) |
Intractable |
High |
GAN |
Min-max |
Fast (1 pass) |
No |
High |
VAE |
ELBO |
Fast (1 pass) |
Approx. |
Medium |
Flow |
Exact MLE |
Fast (1 pass) |
Exact |
Low |
Diffusion |
Score match |
Medium (20-50) |
Via ODE |
High |
When to use EBMs:
Compositional generation (energy addition)
Flexible energy functions needed
Theoretical interpretability important
Small-scale tasks (computational cost acceptable)
When to avoid:
Need fast sampling (use GANs or flows)
Large-scale generation (use diffusion models)
Exact likelihood required (use flows or VAEs)
# Advanced Energy-Based Models Implementations
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple, List
class SpectralNormConv2d(nn.Module):
    """Conv2d with spectral normalization for Lipschitz constraint.

    The conv weight is divided by an estimate of its largest singular value
    sigma, obtained by power iteration on the flattened
    (out_channels x fan_in) weight matrix. The singular-vector estimates
    u, v live in buffers and are refined by `n_iter` iterations on every
    training-mode forward pass; in eval mode the stored estimates are
    reused as-is.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, n_iter=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.n_iter = n_iter
        # Power-iteration state: u approximates the leading left singular
        # vector, v the right one. Buffers (not Parameters): saved with the
        # state_dict but never trained.
        # NOTE(review): v's length assumes a square kernel (kernel_size**2).
        self.register_buffer('u', torch.randn(1, out_channels))
        self.register_buffer('v', torch.randn(1, in_channels * kernel_size * kernel_size))
    def forward(self, x):
        """Apply spectrally normalized convolution."""
        # Flatten 4D conv weight to (out_channels, fan_in) for power iteration.
        weight = self.conv.weight.view(self.conv.out_channels, -1)
        if self.training:
            # Power iteration to estimate largest singular value
            u, v = self.u, self.v
            for _ in range(self.n_iter):
                v = F.normalize(u @ weight, dim=1, eps=1e-12)
                u = F.normalize(v @ weight.t(), dim=1, eps=1e-12)
            # Update buffers (detached: the iteration itself is not trained)
            self.u.copy_(u.detach())
            self.v.copy_(v.detach())
            # Compute spectral norm. `.item()` turns sigma into a plain float,
            # so gradients flow only through `weight`, not through sigma.
            sigma = (u @ weight @ v.t()).item()
        else:
            sigma = (self.u @ weight @ self.v.t()).item()
        # Normalize weight so its spectral norm is ~1 (eps avoids div-by-zero)
        weight_sn = weight / (sigma + 1e-12)
        weight_sn = weight_sn.view_as(self.conv.weight)
        # Apply conv with the normalized weight; bias/stride/padding come
        # from the wrapped conv module.
        return F.conv2d(x, weight_sn, self.conv.bias,
                        self.conv.stride, self.conv.padding)
class EnergyFunction(nn.Module):
    """
    Convolutional energy function E(x), optionally spectrally normalized.

    Four stride-2 convolutions downsample the input 16x (e.g. 32x32 -> 2x2);
    a small MLP head then emits one scalar energy per batch element.
    """

    def __init__(self, input_channels=1, base_channels=64, use_spectral_norm=True):
        super().__init__()
        Conv = SpectralNormConv2d if use_spectral_norm else nn.Conv2d
        widths = [input_channels, base_channels, base_channels * 2,
                  base_channels * 4, base_channels * 8]
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks.append(Conv(c_in, c_out, 4, 2, 1))
            blocks.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*blocks)
        # Scalar energy head over the flattened 2x2 feature map.
        self.energy_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(base_channels * 8 * 2 * 2, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Return per-sample energy E(x), shape (batch,)."""
        features = self.encoder(x)
        return self.energy_head(features).squeeze(-1)
class ContrastiveDivergenceTrainer:
    """
    Contrastive-divergence (CD-k) trainer for an energy function.

    Negatives come from k SGLD steps started at uniform noise; the loss
    lowers the energy of data and raises the energy of the negatives, with
    a squared-energy regularizer for stability.
    """

    def __init__(self, energy_fn, optimizer, k_steps=1, sgld_lr=10.0,
                 sgld_noise=True, energy_reg=0.001):
        self.energy_fn = energy_fn
        self.optimizer = optimizer
        self.k_steps = k_steps
        self.sgld_lr = sgld_lr
        self.sgld_noise = sgld_noise
        self.energy_reg = energy_reg

    def sgld_step(self, x, add_noise=True):
        """One SGLD update: x <- clamp(x - (eps/2) * grad E(x) [+ sqrt(eps) * z], 0, 1)."""
        x = x.clone().requires_grad_(True)
        total = self.energy_fn(x).sum()
        (grad,) = torch.autograd.grad(total, x)
        stepped = x - 0.5 * self.sgld_lr * grad
        if add_noise and self.sgld_noise:
            stepped = stepped + torch.randn_like(x) * np.sqrt(self.sgld_lr)
        return torch.clamp(stepped.detach(), 0, 1)

    def sample_negative(self, x_init):
        """Run k SGLD steps from x_init to produce negative samples."""
        chain = x_init.clone()
        for _ in range(self.k_steps):
            chain = self.sgld_step(chain)
        return chain

    def train_step(self, x_pos):
        """One CD update on a data batch; returns loss and energy statistics."""
        energy_pos = self.energy_fn(x_pos)
        # Negatives are sampled fresh from uniform noise every step (CD, not PCD).
        negatives = self.sample_negative(torch.rand_like(x_pos))
        energy_neg = self.energy_fn(negatives)
        # Widen the energy gap between data and model samples.
        cd_loss = energy_pos.mean() - energy_neg.mean()
        # Keep energy magnitudes bounded.
        reg_loss = (energy_pos ** 2).mean() + (energy_neg ** 2).mean()
        loss = cd_loss + self.energy_reg * reg_loss
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.energy_fn.parameters(), 1.0)
        self.optimizer.step()
        return {
            'loss': loss.item(),
            'energy_pos': energy_pos.mean().item(),
            'energy_neg': energy_neg.mean().item()
        }
class PersistentContrastiveDivergence:
    """
    Persistent CD (PCD): maintain "fantasy particle" MCMC chains across
    training iterations instead of restarting from noise each step.
    Yields better negative samples with lower gradient variance.
    """
    def __init__(self, energy_fn, optimizer, n_persistent=100, k_steps=1,
                 sgld_lr=10.0, energy_reg=0.001):
        self.energy_fn = energy_fn
        self.optimizer = optimizer
        self.k_steps = k_steps
        self.sgld_lr = sgld_lr
        self.energy_reg = energy_reg
        # Chains are created lazily: their shape is unknown until the
        # first data batch arrives.
        self.persistent_chains = None
        self.n_persistent = n_persistent
    def initialize_chains(self, shape):
        """Initialize the n_persistent chains with uniform noise in [0, 1)."""
        self.persistent_chains = torch.rand(self.n_persistent, *shape)
    def sgld_step(self, x):
        """Single SGLD step: x <- clamp(x - (eps/2) * grad E(x) + sqrt(eps) * z, 0, 1)."""
        x = x.clone().requires_grad_(True)
        energy = self.energy_fn(x).sum()
        grad = torch.autograd.grad(energy, x, create_graph=False)[0]
        x_new = x - 0.5 * self.sgld_lr * grad
        x_new = x_new + torch.randn_like(x) * np.sqrt(self.sgld_lr)
        return torch.clamp(x_new.detach(), 0, 1)
    def train_step(self, x_pos):
        """PCD training step: evolve a random subset of the persistent chains
        by k SGLD steps, use the evolved states as negatives, then write them
        back so MCMC continues from there next iteration."""
        batch_size = x_pos.size(0)
        # Initialize chains on first call
        if self.persistent_chains is None:
            self.initialize_chains(x_pos.shape[1:])
        # Keep chains on the same device as the incoming batch
        self.persistent_chains = self.persistent_chains.to(x_pos.device)
        # Pick a random subset of chains as this batch's negatives
        indices = torch.randperm(self.n_persistent)[:batch_size]
        x_neg = self.persistent_chains[indices].clone()
        # Run k SGLD steps
        for _ in range(self.k_steps):
            x_neg = self.sgld_step(x_neg)
        # Persist the evolved states for the next iteration
        self.persistent_chains[indices] = x_neg.detach()
        # Compute energies
        energy_pos = self.energy_fn(x_pos)
        energy_neg = self.energy_fn(x_neg)
        # PCD loss: positive phase minus negative phase + energy regularizer
        cd_loss = energy_pos.mean() - energy_neg.mean()
        reg_loss = (energy_pos ** 2).mean() + (energy_neg ** 2).mean()
        loss = cd_loss + self.energy_reg * reg_loss
        # Backprop with gradient clipping for stability
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.energy_fn.parameters(), 1.0)
        self.optimizer.step()
        return {
            'loss': loss.item(),
            'energy_pos': energy_pos.mean().item(),
            'energy_neg': energy_neg.mean().item()
        }
class ScoreMatchingEBM:
    """
    Train an EBM with denoising score matching (DSM).

    The model score -grad_x E(x) is regressed onto the DSM target
    -(x_noisy - x_clean) / sigma^2, which avoids MCMC sampling and the
    partition function entirely.
    """

    def __init__(self, energy_fn, optimizer, sigma=0.1):
        """
        Args:
            energy_fn: module/callable mapping a batch to per-sample energies.
            optimizer: optimizer over energy_fn's parameters.
            sigma: std of the Gaussian noise added to clean inputs.
        """
        self.energy_fn = energy_fn
        self.optimizer = optimizer
        self.sigma = sigma

    def train_step(self, x_clean):
        """One DSM step on a clean batch; returns {'loss': float}."""
        # Corrupt the data with Gaussian noise
        noise = torch.randn_like(x_clean) * self.sigma
        x_noisy = x_clean + noise
        # Model score: s(x) = -grad_x E(x); create_graph=True so the score
        # itself stays differentiable w.r.t. the energy parameters.
        x_noisy_grad = x_noisy.clone().requires_grad_(True)
        energy = self.energy_fn(x_noisy_grad).sum()
        score_pred = -torch.autograd.grad(energy, x_noisy_grad, create_graph=True)[0]
        # DSM target score of the noise kernel: -noise / sigma^2
        score_target = -noise / (self.sigma ** 2)
        # Sum the squared error over all non-batch dims. flatten(1) makes this
        # work for any input rank (the original hard-coded NCHW dims [1,2,3]).
        sq_err = ((score_pred - score_target) ** 2).flatten(1).sum(dim=1)
        loss = 0.5 * sq_err.mean()
        # Backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item()}
class RestrictedBoltzmannMachine(nn.Module):
"""
RBM: p(v,h) = exp(-E(v,h))/Z
E(v,h) = -v^T W h - b^T v - c^T h
"""
def __init__(self, n_visible, n_hidden):
super().__init__()
# Parameters
self.W = nn.Parameter(torch.randn(n_visible, n_hidden) * 0.01)
self.b = nn.Parameter(torch.zeros(n_visible))
self.c = nn.Parameter(torch.zeros(n_hidden))
def energy(self, v, h):
"""Compute E(v,h)."""
interaction = -(v @ self.W @ h.t()).diagonal()
visible_bias = -(v @ self.b)
hidden_bias = -(h @ self.c)
return interaction + visible_bias + hidden_bias
def sample_h_given_v(self, v):
"""p(h=1|v) = Ο(c + W^T v)."""
activation = v @ self.W + self.c
prob_h = torch.sigmoid(activation)
h = torch.bernoulli(prob_h)
return h, prob_h
def sample_v_given_h(self, h):
"""p(v=1|h) = Ο(b + W h)."""
activation = h @ self.W.t() + self.b
prob_v = torch.sigmoid(activation)
v = torch.bernoulli(prob_v)
return v, prob_v
def gibbs_step(self, v):
"""Single Gibbs sampling step."""
h, prob_h = self.sample_h_given_v(v)
v_recon, prob_v = self.sample_v_given_h(h)
return v_recon, prob_v, h, prob_h
def cd_update(self, v_data, k=1, lr=0.01):
"""CD-k update rule."""
batch_size = v_data.size(0)
# Positive phase
h0, prob_h0 = self.sample_h_given_v(v_data)
# Negative phase: k Gibbs steps
v_k = v_data.clone()
for _ in range(k):
v_k, prob_v_k, h_k, prob_h_k = self.gibbs_step(v_k)
# Compute gradients
# ΞW β <vh>_data - <vh>_recon
positive_grad = (v_data.t() @ prob_h0) / batch_size
negative_grad = (prob_v_k.t() @ prob_h_k) / batch_size
grad_W = positive_grad - negative_grad
grad_b = (v_data - prob_v_k).mean(0)
grad_c = (prob_h0 - prob_h_k).mean(0)
# Update parameters
with torch.no_grad():
self.W += lr * grad_W
self.b += lr * grad_b
self.c += lr * grad_c
# Reconstruction error
recon_error = ((v_data - prob_v_k) ** 2).mean().item()
return {'recon_error': recon_error}
class LangevinSampler:
    """
    Unadjusted Langevin (SGLD) sampler:
    x_{t+1} = x_t - (eps/2) * grad E(x_t) + sqrt(eps) * z_t,  z_t ~ N(0, I),
    with the state clamped to `clip_range` after every update.
    """

    def __init__(self, energy_fn, step_size=1.0, n_steps=100,
                 noise=True, clip_range=(0, 1)):
        self.energy_fn = energy_fn
        self.step_size = step_size
        self.n_steps = n_steps
        self.noise = noise
        self.clip_range = clip_range

    def sample(self, x_init):
        """Run Langevin dynamics from x_init.

        Returns (final state, list of all visited states including x_init,
        list of summed energies recorded before each update).
        """
        state = x_init.clone()
        trajectory = [state.clone()]
        energies = []
        for _ in range(self.n_steps):
            state = state.requires_grad_(True)
            # Energy and its gradient at the current state
            total_energy = self.energy_fn(state).sum()
            (grad,) = torch.autograd.grad(total_energy, state)
            # Deterministic drift, optional diffusion noise
            proposal = state - 0.5 * self.step_size * grad
            if self.noise:
                proposal = proposal + torch.randn_like(state) * np.sqrt(self.step_size)
            # Clamp to the valid range and cut the graph
            state = torch.clamp(proposal.detach(), *self.clip_range)
            trajectory.append(state.clone())
            energies.append(total_energy.item())
        return state, trajectory, energies
class HamiltonianMonteCarlo:
    """
    HMC sampler with leapfrog integration.
    Better mixing than Langevin dynamics.

    Each proposal draws fresh Gaussian momentum, integrates Hamiltonian
    dynamics for `n_leapfrog` leapfrog steps, and accepts/rejects with a
    Metropolis-Hastings test to correct for discretization error.
    """
    def __init__(self, energy_fn, step_size=0.1, n_leapfrog=10):
        self.energy_fn = energy_fn
        self.step_size = step_size    # leapfrog integration step size
        self.n_leapfrog = n_leapfrog  # leapfrog steps per proposal
    def hamiltonian(self, x, v):
        """H(x, v) = E(x) + 0.5 * v^T v (unit-mass kinetic energy)."""
        potential = self.energy_fn(x).sum()
        kinetic = 0.5 * (v ** 2).sum()
        return potential + kinetic
    def leapfrog_step(self, x, v):
        """One leapfrog step: half momentum, full position, half momentum."""
        # Half step for momentum
        x_grad = x.clone().requires_grad_(True)
        energy = self.energy_fn(x_grad).sum()
        grad = torch.autograd.grad(energy, x_grad)[0]
        v = v - 0.5 * self.step_size * grad
        # Full step for position
        x = x + self.step_size * v
        # Half step for momentum (gradient re-evaluated at the new position)
        x_grad = x.clone().requires_grad_(True)
        energy = self.energy_fn(x_grad).sum()
        grad = torch.autograd.grad(energy, x_grad)[0]
        v = v - 0.5 * self.step_size * grad
        return x.detach(), v.detach()
    def sample_step(self, x):
        """Single HMC step with Metropolis-Hastings.

        Returns (next state, accepted flag). On rejection the current
        state x is returned unchanged.
        """
        # Fresh momentum each proposal
        v = torch.randn_like(x)
        # Current Hamiltonian
        H_current = self.hamiltonian(x, v)
        # Leapfrog integration of the proposal trajectory
        x_new, v_new = x.clone(), v.clone()
        for _ in range(self.n_leapfrog):
            x_new, v_new = self.leapfrog_step(x_new, v_new)
        # Proposed Hamiltonian
        H_proposed = self.hamiltonian(x_new, v_new)
        # Metropolis-Hastings acceptance: accept with prob min(1, exp(H - H'))
        accept_prob = torch.exp(H_current - H_proposed).item()
        if np.random.rand() < min(1.0, accept_prob):
            return x_new, True
        else:
            return x, False
# ============================================================================
# Demonstrations
# ============================================================================
# Console smoke-tests for the EBM components defined earlier in this file.
# NOTE(review): relies on `EnergyFunction` and `RestrictedBoltzmannMachine`
# being defined above this section; nothing is trained here, these are
# shape/summary printouts only.
print("=" * 60)
print("Energy-Based Models - Advanced Implementations")
print("=" * 60)
# 1. Energy function with spectral normalization
# Forward a random batch through the energy network and display raw energies.
print("\n1. Energy Function:")
energy_fn = EnergyFunction(input_channels=1, base_channels=32, use_spectral_norm=True)
x_test = torch.randn(4, 1, 32, 32)
energy_test = energy_fn(x_test)
print(f" Input shape: {x_test.shape}")
print(f" Output (energies): {energy_test.shape}")
# detach() before numpy(): the forward pass built an autograd graph.
print(f" Energy values: {energy_test.detach().numpy()}")
# 2. CD vs PCD comparison
# Summary text only - no computation performed.
print("\n2. Contrastive Divergence Methods:")
print(" CD-k: Restart MCMC from data each iteration")
print(" PCD: Maintain persistent fantasy particles")
print(f" CD variance: Higher (fresh chains)")
print(f" PCD variance: Lower (evolved chains)")
print(f" Typical k: 1 (CD-1) or 5 (CD-5)")
# 3. RBM architecture
# Parameter count = W (784*128) + visible bias (784) + hidden bias (128).
print("\n3. Restricted Boltzmann Machine:")
rbm = RestrictedBoltzmannMachine(n_visible=784, n_hidden=128)
print(f" Visible units: 784")
print(f" Hidden units: 128")
print(f" Parameters: {784*128 + 784 + 128:,}")
print(f" Energy: E(v,h) = -v^T W h - b^T v - c^T h")
# 4. Sampling comparison
print("\n4. MCMC Sampling Methods:")
print(" Langevin: x β x - Ξ΅/2Β·βE + βΡ·z")
print(" HMC: Augment with momentum, leapfrog integration")
print(" Advantages:")
print(" - Langevin: Simple, gradient-based")
print(" - HMC: Better mixing, fewer rejections")
# 5. Training methods comparison
print("\n5. Training Method Comparison:")
print(" βββββββββββββββββββ¬βββββββββββ¬ββββββββββββ¬ββββββββββββββ")
print(" β Method β Sampling β Stability β Speed β")
print(" βββββββββββββββββββΌβββββββββββΌββββββββββββΌββββββββββββββ€")
print(" β CD-k β Yes β Medium β Slow β")
print(" β PCD β Yes β Higher β Slow β")
print(" β Score Matching β No β High β Fast β")
print(" β NCE β Partial β High β Medium β")
print(" βββββββββββββββββββ΄βββββββββββ΄ββββββββββββ΄ββββββββββββββ")
# 6. When to use guide
print("\n6. When to Use Each Method:")
print(" Use CD-1:")
print(" β Small models, limited compute")
print(" β Quick prototyping")
print(" β RBMs and simple architectures")
print("\n Use PCD:")
print(" β Better samples needed")
print(" β Larger models")
print(" β Can maintain persistent chains")
print("\n Use Score Matching:")
print(" β Avoid MCMC during training")
print(" β Faster training")
print(" β Partition function intractable")
print("\n Use NCE:")
print(" β High-dimensional data")
print(" β Approximate partition function acceptable")
print(" β Connection to contrastive learning")
print("\n" + "=" * 60)
Advanced Energy-Based Models: Mathematical Foundations and Modern ArchitecturesΒΆ
1. Introduction to Energy-Based ModelsΒΆ
Energy-based models (EBMs) provide a unified framework for learning probability distributions by defining an energy function \(E_\theta(x)\) that assigns low energy to likely data points and high energy to unlikely ones. Unlike explicit generative models (GANs, VAEs, normalizing flows), EBMs model the data distribution implicitly through the Boltzmann distribution:
where \(Z_\theta = \int e^{-E_\theta(x)} dx\) is the partition function (intractable in general).
Key properties:
Flexibility: Energy function can be any neural network without architectural constraints
Expressiveness: Can model complex multimodal distributions
Compositionality: Energies can be added/composed for structured modeling
Implicit normalization: No need for explicit normalizing flows or sampling networks
Advantages over other generative models:
No generator network: Unlike GANs, no need for adversarial training
No encoder bottleneck: Unlike VAEs, no variational approximation
No invertibility constraints: Unlike normalizing flows, arbitrary architectures allowed
Challenges:
Intractable partition function: Computing \(Z_\theta\) requires integration over all possible \(x\)
Expensive sampling: MCMC methods slow for high-dimensional data
Training instability: Contrastive divergence and score matching can be unstable
2. Mathematical FrameworkΒΆ
2.1 Energy Function and Probability DistributionΒΆ
The energy function \(E_\theta: \mathcal{X} \to \mathbb{R}\) maps inputs to scalar energies. The probability distribution is:
Properties:
Lower energy β higher probability: \(E_\theta(x_1) < E_\theta(x_2) \implies p_\theta(x_1) > p_\theta(x_2)\)
Energy is defined up to a constant: \(E_\theta(x) + c\) gives same distribution
Partition function ensures normalization: \(\int p_\theta(x) dx = 1\)
Log-likelihood:
Gradient with respect to parameters:
The second term (gradient of log partition function) requires sampling from the model distribution, which is the main computational challenge.
2.2 Maximum Likelihood TrainingΒΆ
Given dataset \(\{x_i\}_{i=1}^N\), maximize log-likelihood:
Gradient:
This is a positive-negative gradient:
Positive phase: Pull down energy on data samples
Negative phase: Push up energy on model samples
Challenge: Computing \(\mathbb{E}_{x \sim p_\theta}[\nabla_\theta E_\theta(x)]\) requires sampling from \(p_\theta\), which is intractable.
Solutions:
Contrastive Divergence (CD): Approximate with short MCMC chains
Score Matching: Avoid partition function entirely
Noise Contrastive Estimation (NCE): Compare data vs noise distribution
3. Contrastive Divergence (CD)ΒΆ
3.1 AlgorithmΒΆ
Contrastive Divergence (Hinton, 2002) approximates the negative phase gradient by running short MCMC chains starting from data:
CD-k algorithm:
Initialize \(x^{(0)} \sim p_{\text{data}}\) (data sample)
Run \(k\) steps of Gibbs/Langevin sampling: \(x^{(k)} \sim p_\theta^{(k)}\)
Approximate gradient:
Intuition:
\(x^{(0)}\) is data (low energy desired)
\(x^{(k)}\) is βfantasyβ sample after \(k\) MCMC steps (high energy desired)
Gradient pushes down data energy, pushes up fantasy energy
Common choices:
CD-1: Single Gibbs/Langevin step (fast but biased)
CD-10: 10 steps (slower but less biased)
Persistent CD (PCD): Maintain persistent MCMC chains across batches
3.2 Langevin Dynamics SamplingΒΆ
For continuous data, use Langevin dynamics to sample from \(p_\theta(x)\):
where:
\(\epsilon\) is the step size
\(\nabla_x E_\theta(x_t)\) is the energy gradient (drives to low energy regions)
\(\sqrt{\epsilon} \, z_t\) is noise (ensures proper exploration)
Convergence: As \(\epsilon \to 0\) and \(T \to \infty\), \(x_T \sim p_\theta(x)\).
Practical: Use finite \(\epsilon\) and \(T\) (e.g., \(\epsilon=0.01\), \(T=100\) steps).
3.3 Persistent Contrastive Divergence (PCD)ΒΆ
Maintain persistent MCMC chains \(\{x_i^{\text{chain}}\}\) across batches:
Algorithm:
Initialize chains randomly
Each iteration:
Sample data batch \(\{x_i^{\text{data}}\}\)
Update chains: \(x_i^{\text{chain}} \leftarrow\) Langevin step from \(x_i^{\text{chain}}\)
Gradient: \(\nabla_\theta \mathcal{L} \approx -\nabla_\theta E_\theta(x^{\text{data}}) + \nabla_\theta E_\theta(x^{\text{chain}})\)
Advantage: Chains better approximate \(p_\theta\) over time (less bias than CD-k).
Disadvantage: Chains can become βstuckβ in local modes.
4. Score MatchingΒΆ
4.1 MotivationΒΆ
Score matching (HyvΓ€rinen, 2005) avoids the partition function by matching the score (gradient of log-density):
Note: \(\nabla_x \log Z_\theta = 0\) because \(Z_\theta\) doesnβt depend on \(x\).
Key insight: Score doesnβt require computing \(Z_\theta\)!
4.2 Explicit Score MatchingΒΆ
Match model score to data score:
Problem: \(\nabla_x \log p_{\text{data}}(x)\) is unknown.
Solution: Integration by parts gives equivalent objective (HyvΓ€rinen, 2005):
where \(\text{tr}(\nabla_x^2 E_\theta(x)) = \sum_{i=1}^D \frac{\partial^2 E_\theta(x)}{\partial x_i^2}\) is the trace of the Hessian.
Gradient:
Computational cost: Requires computing Hessian trace (expensive for high-dimensional data).
4.3 Denoising Score Matching (DSM)ΒΆ
Vincent (2011) proposed a computationally efficient alternative:
Setup: Perturb data with noise \(q(x|x_0) = \mathcal{N}(x | x_0, \sigma^2 I)\).
Objective:
Interpretation:
True score under noise perturbation: \(\nabla_x \log q(x|x_0) = -(x - x_0)/\sigma^2\)
Match model score \(-\nabla_x E_\theta(x)\) to this
Advantage: No Hessian computation! Only first-order gradient \(\nabla_x E_\theta(x)\).
Equivalence: minimizing the DSM objective matches the score of the noise-perturbed data distribution; as \(\sigma \to 0\) this approaches explicit score matching on the clean data.
4.4 Sliced Score Matching (SSM)ΒΆ
Song et al. (2019) proposed sliced score matching to further reduce computational cost:
where \(v \sim p_v\) is a random direction (e.g., \(p_v = \mathcal{N}(0, I)\)).
Advantage:
Hessian-vector product \(\nabla_x^2 E_\theta(x) v\) can be computed efficiently via automatic differentiation
Cost: \(O(D)\) instead of \(O(D^2)\) for full Hessian
Implementation: Use forward-mode AD or double backward pass.
5. Noise Contrastive Estimation (NCE)ΒΆ
5.1 PrincipleΒΆ
Noise Contrastive Estimation (Gutmann & HyvΓ€rinen, 2010) treats density estimation as binary classification:
Setup:
Data distribution: \(p_{\text{data}}(x)\)
Noise distribution: \(p_n(x)\) (e.g., uniform or Gaussian)
Sample ratio: \(\nu\) noise samples per data sample
Binary classification:
Label \(y=1\) for data: \(x \sim p_{\text{data}}\)
Label \(y=0\) for noise: \(x \sim p_n\)
Posterior:
Key trick: Treat \(\log Z_\theta\) as a learnable parameter \(c\):
where \(\sigma\) is the sigmoid function.
5.2 NCE LossΒΆ
Binary cross-entropy:
Expanding:
Optimization: Jointly optimize \(\theta\) and \(c\) via gradient descent.
Advantage: Avoids computing partition function or sampling from model.
Disadvantage: Requires good noise distribution \(p_n\) (poor choice degrades performance).
6. Modern EBM ArchitecturesΒΆ
6.1 Joint Energy-Based Models (JEMs)ΒΆ
Grathwohl et al. (2020) proposed using a single network for both classification and generation:
Energy function: \(E_\theta(x, y) = -f_\theta(x)[y]\)
where \(f_\theta(x) \in \mathbb{R}^C\) is a classifier logit vector.
Distributions:
Conditional: \(p_\theta(y|x) = \frac{\exp(f_\theta(x)[y])}{\sum_{y'} \exp(f_\theta(x)[y'])}\) (standard classifier)
Marginal: \(p_\theta(x) = \frac{\sum_y \exp(f_\theta(x)[y])}{Z_\theta} = \frac{\exp(\text{LogSumExp}(f_\theta(x)))}{Z_\theta}\)
Training:
Classification: Standard cross-entropy on \((x, y)\) pairs
Generation: Contrastive divergence on marginal \(p_\theta(x)\)
Benefits:
Single network for both tasks
Improved robustness (adversarial examples have high energy)
Out-of-distribution detection (OOD samples have high energy)
6.2 IGEBM (Implicit Generation and Likelihood Estimation)ΒΆ
Du & Mordatch (2019) trained EBMs for high-resolution image generation:
Architecture: ResNet-based energy function \(E_\theta(x)\)
Training:
Sample data: \(x^+ \sim p_{\text{data}}\)
Sample negatives via Langevin: \(x^- \sim p_\theta\) (100-200 steps)
Update: \(\mathcal{L} = E_\theta(x^+) - E_\theta(x^-)\) (hinge loss variant)
Improvements:
Spectral normalization: Stabilize training (constrain Lipschitz constant)
MCMC initialization: Initialize chains from replay buffer (mix of old samples + noise)
Multiscale training: Train on progressively higher resolutions
Results: Generated high-quality 128Γ128 and 256Γ256 images (comparable to GANs).
6.3 EBM for Compositional GenerationΒΆ
EBMs naturally support compositional reasoning via energy addition:
Concept: Combine multiple energy functions $\(p(x | c_1, c_2) \propto \exp(-E_1(x, c_1) - E_2(x, c_2))\)$
Applications:
Multi-attribute generation: \(E(x, \text{color}, \text{shape})\)
Logical composition: AND (\(E_1 + E_2\)), OR (\(-\log(e^{-E_1} + e^{-E_2})\)), NOT (\(-E\))
Language-guided generation: \(E_{\text{img}}(x) + \lambda E_{\text{CLIP}}(x, \text{caption})\)
Example (Du et al., 2020):
Train separate classifiers \(f_1(x), f_2(x), \ldots\) for different attributes
Compose: \(p(x | y_1=1, y_2=1) \propto \exp(f_1(x)[1] + f_2(x)[1])\)
Sample via Langevin dynamics with combined energy
7. Score-Based Generative Models (Diffusion Models Connection)ΒΆ
7.1 Score-Based ModelsΒΆ
Song & Ermon (2019, 2020) proposed training models to estimate the score function directly:
Score network: \(s_\theta(x, t): \mathbb{R}^D \times \mathbb{R}_+ \to \mathbb{R}^D\)
Objective: Denoising score matching across noise levels $\(\mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}(0,T)} \mathbb{E}_{x_0 \sim p_{\text{data}}} \mathbb{E}_{x_t \sim q(x_t|x_0)}\left[\lambda(t) \left\|s_\theta(x_t, t) - \nabla_{x_t} \log q(x_t | x_0)\right\|^2\right]\)$
where:
\(q(x_t | x_0) = \mathcal{N}(x_t | \alpha_t x_0, \sigma_t^2 I)\) is noise perturbation
\(\nabla_{x_t} \log q(x_t | x_0) = -(x_t - \alpha_t x_0)/\sigma_t^2\) is the true score
\(\lambda(t)\) is a weighting function
Sampling: Reverse-time SDE (Langevin dynamics) $\(x_{t-\Delta t} = x_t + s_\theta(x_t, t) \Delta t + \sqrt{2\Delta t} \, z_t, \quad z_t \sim \mathcal{N}(0, I)\)$
7.2 Connection to Diffusion ModelsΒΆ
Equivalence: Denoising diffusion probabilistic models (DDPMs) are discrete-time score-based models.
DDPM: \(p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1} | \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)
Reparameterization: Instead of predicting \(\mu_\theta\), predict noise \(\epsilon_\theta(x_t, t)\): $\(\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t))\)$
Score equivalence: $\(s_\theta(x_t, t) = -\frac{\epsilon_\theta(x_t, t)}{\sigma_t}\)$
Unified perspective (Song et al., 2021):
Both are trained via denoising score matching
Sampling: Ancestral sampling (DDPM) vs. SDE/ODE solvers (score-based)
State-of-the-art: Combine best of both (e.g., DDIM, DPM-Solver)
7.3 Noise ConditioningΒΆ
Challenge: Single noise level insufficient for high-quality generation.
Solution: Noise conditioning (multiple noise levels)
Perturb data at multiple scales: \(\sigma_1 > \sigma_2 > \cdots > \sigma_T\)
Train score network: \(s_\theta(x, \sigma_t)\) for all \(t\)
Sample: Gradually denoise from \(\sigma_1\) to \(\sigma_T\)
Annealed Langevin Dynamics:
x ~ N(0, Ο_1^2 I)
for t = 1 to T:
for k = 1 to K:
x β x + Ξ΅_t s_ΞΈ(x, Ο_t) + sqrt(2Ξ΅_t) z,  z ~ N(0, I)
return x
Improvement: Coarse-to-fine generation (large noise captures global structure, small noise refines details).
8. Training Techniques and TricksΒΆ
8.1 Stability TechniquesΒΆ
Spectral Normalization:
Divide each layerβs weights by their largest singular value
Constrains Lipschitz constant: \(\|E_\theta(x_1) - E_\theta(x_2)\| \leq L \|x_1 - x_2\|\)
Stabilizes training, prevents energy function from becoming too steep
Gradient Clipping:
Clip Langevin gradient: \(\nabla_x E_\theta(x) \leftarrow \text{clip}(\nabla_x E_\theta(x), -c, c)\)
Prevents exploding gradients during sampling
Batch Normalization (use carefully):
Can help, but breaks translation equivariance
Alternative: Group normalization, Layer normalization
8.2 MCMC InitializationΒΆ
Replay Buffer:
Maintain buffer of past negative samples
Initialize new chains from buffer (95%) + noise (5%)
Accelerates convergence (chains start closer to \(p_\theta\))
Multi-scale Sampling:
Sample at low resolution first (faster MCMC)
Upsample to high resolution and refine
Reduces computational cost
8.3 Loss FunctionsΒΆ
Hinge Loss (Xie et al., 2016): $\(\mathcal{L}_{\text{hinge}} = \mathbb{E}_{x^+ \sim p_{\text{data}}}[\max(0, E_\theta(x^+) - m^+)] + \mathbb{E}_{x^- \sim p_\theta}[\max(0, m^- - E_\theta(x^-))]\)$
Regularization:
\(L_2\) penalty on energy: \(\mathbb{E}[E_\theta(x)^2]\) (prevents energy drift)
Gradient penalty: \(\mathbb{E}[\|\nabla_x E_\theta(x)\|^2]\) (smoothness)
9. ApplicationsΒΆ
9.1 Image GenerationΒΆ
IGEBM: High-resolution unconditional generation (256Γ256 images)
JEM: Joint classification and generation (CIFAR-10, ImageNet-32)
Compositional generation: Combine attributes, logical operations
9.2 Out-of-Distribution DetectionΒΆ
Motivation: EBMs assign high energy to OOD samples.
Method:
Train EBM on in-distribution data
Compute energy \(E_\theta(x)\) for test sample
Threshold: OOD if \(E_\theta(x) > \tau\)
Performance: Strong OOD detection on CIFAR-10 vs SVHN, CIFAR-100, etc.
9.3 Adversarial RobustnessΒΆ
Observation: Adversarial examples have higher energy than clean examples.
Robust training:
Generate adversarial examples via PGD on \(E_\theta\)
Train to assign high energy to adversarial examples
Results: JEM achieves better robust accuracy than standard classifiers.
9.4 Controllable GenerationΒΆ
Conditional sampling: \(p(x | y) \propto \exp(-E_\theta(x, y))\)
Attribute editing:
Modify energy function: \(E'(x) = E(x) + \lambda E_{\text{attr}}(x)\)
Sample via Langevin to find low-energy \(x\) satisfying attribute
Inpainting:
Observed pixels: \(x_O\)
Energy: \(E(x) + \frac{\lambda}{2}\|x_O - x_O^{\text{target}}\|^2\)
Sample to fill in missing pixels \(x_{\bar{O}}\)
9.5 Inverse ProblemsΒΆ
General formulation: $\(p(x | y) \propto p(y | x) p(x) \propto \exp(-\|A(x) - y\|^2 / 2\sigma^2) \exp(-E_\theta(x))\)$
Sample via Langevin: $\(x_{t+1} = x_t - \frac{\epsilon}{2}[\nabla_x \|A(x_t) - y\|^2 / \sigma^2 + \nabla_x E_\theta(x_t)] + \sqrt{\epsilon} z_t\)$
Examples:
Super-resolution: \(A\) is downsampling operator
Denoising: \(A\) is identity, \(y = x + \text{noise}\)
Compressed sensing: \(A\) is measurement matrix
10. Theoretical PropertiesΒΆ
10.1 ExpressivenessΒΆ
Universal approximation:
With sufficient capacity, \(E_\theta\) can approximate the energy of any distribution
Proof: Neural networks are universal function approximators
Multimodality:
EBMs naturally handle multimodal distributions
Multiple local minima in energy landscape β multiple modes
10.2 Convergence of MCMCΒΆ
Langevin dynamics convergence:
Under smoothness and strong convexity, Langevin converges to \(p_\theta\) exponentially fast
In practice: Non-convex energies, finite steps β approximate samples
Mixing time:
Time for MCMC to converge to stationary distribution
Depends on energy landscape (flatter β faster mixing)
10.3 Score Matching ConsistencyΒΆ
Theorem (HyvΓ€rinen, 2005): Score matching is consistent: \(\theta^* = \arg\min \mathcal{L}_{\text{ESM}}(\theta)\) satisfies \(p_{\theta^*} = p_{\text{data}}\) (up to parameter identifiability).
Proof sketch:
ESM objective is minimized when \(\nabla_x \log p_\theta(x) = \nabla_x \log p_{\text{data}}(x)\) for all \(x\)
Integration gives \(\log p_\theta(x) - \log p_{\text{data}}(x) = c\)
Normalization implies \(c = 0\), so \(p_\theta = p_{\text{data}}\)
11. Comparison with Other Generative ModelsΒΆ
Aspect |
EBM |
GAN |
VAE |
Flow |
Diffusion |
|---|---|---|---|---|---|
Training |
CD/Score matching |
Adversarial |
ELBO |
Exact likelihood |
Denoising |
Sampling |
MCMC (slow) |
Single pass (fast) |
Single pass (fast) |
Single pass (fast) |
Iterative (slow) |
Likelihood |
Intractable |
Intractable |
Approximate |
Exact |
Tractable |
Architecture |
Flexible |
Flexible |
Encoder-decoder |
Invertible |
Flexible |
Stability |
Moderate |
Low |
High |
High |
High |
Quality |
High (with tricks) |
High |
Moderate |
Moderate |
Highest |
Compositionality |
β (energy addition) |
β |
β |
β |
β (score addition) |
OOD detection |
β |
β |
β |
β |
β |
When to use EBMs:
Compositional reasoning required
Flexibility in architecture
OOD detection or robustness important
Can afford slow sampling
When to avoid:
Real-time generation needed
Limited computational budget
Standard generative tasks (diffusion models better)
12. Recent Advances (2020-2024)ΒΆ
12.1 Cooperative TrainingΒΆ
Cooperative Training (Xie et al., 2020):
Train EBM and generator jointly
Generator initializes MCMC chains (faster convergence)
EBM guides generator training
Algorithm:
Generator: \(x^- \sim q_\phi(x)\)
Short MCMC: \(x^- \leftarrow\) Langevin refinement
Update generator: \(\phi \leftarrow \phi - \alpha \nabla_\phi \mathbb{E}_{x \sim q_\phi}[E_\theta(x)]\)
Update EBM: \(\theta \leftarrow \theta - \beta(E_\theta(x^+) - E_\theta(x^-))\)
12.2 Flow Contrastive Estimation (FCEM)ΒΆ
Idea: Use normalizing flows as proposal for EBM.
Joint distribution: $\(p_{\theta,\phi}(x) = \frac{1}{Z_\theta} e^{-E_\theta(x)} q_\phi(x)\)$
where \(q_\phi\) is a flow with tractable density.
Advantage:
Flow provides good initialization for MCMC
Reduces MCMC steps needed
Combines flexibility (EBM) with tractability (flow)
12.3 Implicit Contrastive LearningΒΆ
Contrastive learning connection:
InfoNCE loss in contrastive learning is a form of NCE
Representation learning via EBM: \(E_\theta(x, x^+) = -\langle f_\theta(x), f_\theta(x^+) \rangle\)
Applications:
Self-supervised learning (SimCLR, MoCo)
Metric learning
Anomaly detection
12.4 Discrete EBMsΒΆ
Structured prediction:
Variables: \(x \in \mathcal{X}\) (discrete, e.g., graphs, sequences)
Energy: \(E_\theta(x)\) (e.g., GNN for graphs)
Inference: \(\arg\min_x E_\theta(x)\) (combinatorial optimization)
Applications:
Molecule generation (graph EBM)
Protein design
Circuit design
13. Implementation ConsiderationsΒΆ
13.1 HyperparametersΒΆ
Langevin sampling:
Step size: \(\epsilon \in [0.001, 0.1]\) (larger for faster sampling, smaller for stability)
Steps: \(T \in [20, 200]\) (more steps β better samples but slower)
Noise temperature: \(\tau \in [0.5, 2.0]\) (controls exploration)
Training:
Learning rate: \(10^{-4}\) to \(10^{-3}\) (Adam optimizer)
Batch size: 128-256 (larger helps negative sampling)
Noise levels: 10-1000 (for score matching)
Replay buffer:
Size: 10,000 samples
Refresh rate: 5% new samples each batch
13.2 Architectural ChoicesΒΆ
Energy network:
ResNet (standard for images)
U-Net (score-based models)
Transformers (sequences, long-range dependencies)
Output:
Scalar energy (no activation)
NO softmax/sigmoid (energy can be any real value)
Normalization:
Spectral normalization (recommended)
Group normalization (better than BatchNorm for EBMs)
13.3 Debugging TipsΒΆ
Check energy landscape:
Data samples should have lower energy than noise
Plot energy histogram: data vs. MCMC samples
Verify gradients:
\(\nabla_x E_\theta(x)\) should point away from data manifold
Visualize energy gradients as vector field
Monitor MCMC quality:
Visual inspection of samples
Acceptance rate (for Metropolis-Hastings)
Effective sample size (ESS)
14. Benchmarks and ResultsΒΆ
14.1 Image Generation (CIFAR-10)ΒΆ
Model |
FID β |
IS β |
Sampling Speed |
|---|---|---|---|
IGEBM |
38.2 |
6.02 |
100 steps (~10s) |
JEM |
40.5 |
6.8 |
100 steps (~10s) |
DDPM |
3.17 |
9.46 |
1000 steps (~50s) |
GAN (StyleGAN2) |
2.92 |
9.18 |
1 step (~0.1s) |
Note: EBMs lag behind GANs/diffusion on standard metrics, but offer unique advantages (compositionality, OOD detection).
14.2 OOD Detection (CIFAR-10 vs SVHN)ΒΆ
Model |
AUROC β |
FPR@95TPR β |
|---|---|---|
Softmax Baseline |
0.890 |
0.421 |
ODIN |
0.921 |
0.336 |
Mahalanobis |
0.937 |
0.298 |
JEM (Energy) |
0.964 |
0.182 |
Observation: Energy-based OOD detection significantly outperforms classifier-based methods.
14.3 Adversarial Robustness (CIFAR-10)ΒΆ
Model |
Clean Acc |
PGD-20 Acc |
AutoAttack Acc |
|---|---|---|---|
Standard ResNet |
95.2% |
0.0% |
0.0% |
Adversarial Training |
84.7% |
53.1% |
48.2% |
JEM + Adversarial |
82.3% |
56.4% |
51.7% |
Benefit: EBM training improves robustness by ~3-5% over standard adversarial training.
15. Limitations and Future DirectionsΒΆ
15.1 Current LimitationsΒΆ
Computational cost:
MCMC sampling expensive (100-1000Γ slower than GANs)
Requires many energy evaluations per sample
Scalability:
Challenging for high-resolution images (512Γ512+)
MCMC mixing poor in high dimensions
Training instability:
Contrastive divergence can diverge
Requires careful tuning of hyperparameters
15.2 Future DirectionsΒΆ
Faster sampling:
Learned MCMC samplers (neural samplers)
Hybrid models (flow + EBM, generator + EBM)
Amortized inference
Theoretical understanding:
Convergence guarantees for CD
Sample complexity analysis
Connection to diffusion models
Applications:
Scientific domains (protein folding, molecule design)
Structured prediction (graphs, programs)
Multimodal reasoning (vision + language)
Hardware acceleration:
Specialized chips for MCMC
Parallel tempering on GPUs
16. SummaryΒΆ
Key Takeaways:
Flexibility: EBMs can use any architecture, no invertibility or generator constraints
Compositionality: Energy functions naturally compose for structured reasoning
Robustness: Strong OOD detection and adversarial robustness
Training: Contrastive divergence (CD), score matching, NCE
Sampling: Langevin dynamics (slow but flexible)
Modern variants: JEM (joint classification/generation), IGEBM (high-res images), score-based (diffusion connection)
Trade-offs: High quality and flexibility vs. computational cost
Recommended reading:
Hinton (2002): βTraining Products of Experts by Minimizing Contrastive Divergenceβ
HyvΓ€rinen (2005): βEstimation of Non-Normalized Statistical Models by Score Matchingβ
Du & Mordatch (2019): βImplicit Generation and Likelihood Estimation (IGEBM)β
Grathwohl et al. (2020): βYour Classifier is Secretly an Energy Based Model (JEM)β
Song & Ermon (2019): βGenerative Modeling by Estimating Gradients of the Data Distributionβ
Song et al. (2021): βScore-Based Generative Modeling through SDEsβ
When to use EBMs:
β Compositional reasoning required
β OOD detection or robustness critical
β Flexible architecture needed
β Slow sampling acceptable
When to avoid:
β Real-time generation required
β Limited computational budget
β Standard image generation (use diffusion/GANs)
# ============================================================
# ADVANCED ENERGY-BASED MODELS: PRODUCTION IMPLEMENTATIONS
# Complete PyTorch implementations with modern training techniques
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List, Dict
import math
# ============================================================
# 1. ENERGY NETWORK ARCHITECTURES
# ============================================================
class SpectralNorm:
    """Spectral normalization wrapper for constraining Lipschitz constant.

    Estimates the largest singular value Ο of `module.<name>` via power
    iteration and divides the weight by Ο. Designed to run as a forward
    pre-hook (see `__call__`).

    NOTE(review): the `spectral_norm()` helper below constructs this object
    but never registers it via `module.register_forward_pre_hook`, so as the
    file currently stands `__call__` is never invoked and the weight is
    never actually normalized.
    """

    def __init__(self, module: nn.Module, name: str = 'weight', power_iterations: int = 1):
        # Keep references to the wrapped module and target parameter name.
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        # Register u and v vectors
        weight = getattr(module, name)
        with torch.no_grad():
            # u approximates the leading left singular vector of the
            # flattened weight matrix.
            u = torch.randn(weight.size(0), device=weight.device)
            u = u / u.norm()
            # NOTE(review): weight.size(1) is only the second dim of the raw
            # weight; for conv kernels the flattened fan-in used in
            # `_compute_sigma` is larger. The first power iteration replaces
            # v with a correctly shaped vector, so only step 0 is affected.
            v = torch.randn(weight.size(1), device=weight.device)
            v = v / v.norm()
        self.register_buffer('u', u)
        self.register_buffer('v', v)

    def register_buffer(self, name: str, tensor: torch.Tensor):
        """Register buffer in module."""
        # Buffers are stored on the wrapped module under a '_sn_' prefix so
        # they follow .to()/.cuda() and appear in its state_dict.
        self.module.register_buffer(f'_sn_{name}', tensor)

    def _compute_sigma(self, weight: torch.Tensor) -> torch.Tensor:
        """Compute largest singular value via power iteration."""
        u = getattr(self.module, f'_sn_u')
        v = getattr(self.module, f'_sn_v')
        # Reshape weight to 2D
        weight_mat = weight.view(weight.size(0), -1)
        # Power iterations
        for _ in range(self.power_iterations):
            v = F.normalize(torch.mv(weight_mat.t(), u), dim=0)
            u = F.normalize(torch.mv(weight_mat, v), dim=0)
        # Compute sigma = u^T W v
        sigma = torch.dot(u, torch.mv(weight_mat, v))
        # Update u, v
        # Persist the improved singular-vector estimates for the next call.
        setattr(self.module, f'_sn_u', u.detach())
        setattr(self.module, f'_sn_v', v.detach())
        return sigma

    def __call__(self, module: nn.Module, inputs):
        """Apply spectral normalization before forward pass."""
        weight = getattr(module, self.name)
        sigma = self._compute_sigma(weight)
        # Normalize weight
        # NOTE(review): if `name` refers to an nn.Parameter, assigning a
        # plain tensor over it via setattr is rejected by
        # nn.Module.__setattr__; the standard fix is the `weight_orig`
        # re-parameterization used by torch.nn.utils.spectral_norm. Confirm
        # before registering this as a hook.
        weight_sn = weight / sigma
        setattr(module, self.name, weight_sn)
        return None
def spectral_norm(module: nn.Module, name: str = 'weight', power_iterations: int = 1) -> nn.Module:
    """Apply spectral normalization to `module`'s parameter `name`.

    Constrains the parameter's largest singular value to ~1 via power
    iteration, which bounds the layer's Lipschitz constant and stabilizes
    EBM training.

    Fix: the previous implementation instantiated `SpectralNorm` and
    discarded it without ever registering its `__call__` as a forward
    pre-hook, so the weight was never normalized. Delegate to the built-in
    PyTorch implementation, which re-parameterizes the weight (adds
    `weight_orig` plus a pre-hook) and keeps gradients flowing correctly.

    Args:
        module: layer whose parameter should be normalized (in place).
        name: name of the parameter to normalize (default 'weight').
        power_iterations: power-iteration steps per forward pass.

    Returns:
        The same module, modified in place.
    """
    return nn.utils.spectral_norm(module, name=name, n_power_iterations=power_iterations)
class ResidualBlock(nn.Module):
    """Residual block for energy networks.

    Two 3x3 convolutions with GroupNorm + SiLU, an optional spectrally
    normalized projection shortcut when the spatial size or channel count
    changes, and post-addition activation.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, use_spectral_norm: bool = True):
        super().__init__()
        # Main path: two 3x3 convolutions (the first may downsample).
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if use_spectral_norm:
            # Lipschitz control for stable EBM training.
            self.conv1 = spectral_norm(self.conv1)
            self.conv2 = spectral_norm(self.conv2)
        # GroupNorm instead of BatchNorm (better suited to EBMs).
        self.gn1 = nn.GroupNorm(min(32, out_channels), out_channels)
        self.gn2 = nn.GroupNorm(min(32, out_channels), out_channels)
        self.swish = nn.SiLU()
        # Identity shortcut unless the shape changes, then a 1x1 projection.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            if use_spectral_norm:
                self.shortcut = spectral_norm(self.shortcut)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """conv-norm-act, conv-norm, add shortcut, act."""
        h = self.conv1(x)
        h = self.swish(self.gn1(h))
        h = self.gn2(self.conv2(h))
        return self.swish(h + self.shortcut(x))
class EnergyNetwork(nn.Module):
    """ResNet-based energy network with spectral normalization.

    Maps an input image to a scalar energy E_ΞΈ(x): four residual stages
    (width doubles, resolution halves after the first), global average
    pooling, and a linear head with no output activation — an energy may
    be any real value.
    """

    def __init__(self, in_channels: int = 3, base_channels: int = 64,
                 num_blocks: Tuple[int, ...] = (2, 2, 2, 2),
                 use_spectral_norm: bool = True):
        """
        Args:
            in_channels: number of input image channels.
            base_channels: channel width of the first stage.
            num_blocks: residual blocks per stage. Fix: was a mutable list
                default ([2, 2, 2, 2]); now an immutable tuple with the
                same values — callers may still pass lists.
            use_spectral_norm: wrap conv/linear layers with spectral norm.
        """
        super().__init__()
        self.in_channels = in_channels
        # Stem convolution (no downsampling).
        self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if use_spectral_norm:
            self.conv1 = spectral_norm(self.conv1)
        self.gn1 = nn.GroupNorm(min(32, base_channels), base_channels)
        self.swish = nn.SiLU()
        # Residual stages: width doubles, resolution halves per stage
        # (after the first).
        channels = base_channels
        self.layer1 = self._make_layer(channels, channels, num_blocks[0], stride=1, use_spectral_norm=use_spectral_norm)
        self.layer2 = self._make_layer(channels, channels*2, num_blocks[1], stride=2, use_spectral_norm=use_spectral_norm)
        self.layer3 = self._make_layer(channels*2, channels*4, num_blocks[2], stride=2, use_spectral_norm=use_spectral_norm)
        self.layer4 = self._make_layer(channels*4, channels*8, num_blocks[3], stride=2, use_spectral_norm=use_spectral_norm)
        # Pool to 1x1 and project to one scalar energy per sample.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels*8, 1)
        if use_spectral_norm:
            self.fc = spectral_norm(self.fc)

    def _make_layer(self, in_channels: int, out_channels: int, num_blocks: int, stride: int,
                    use_spectral_norm: bool) -> nn.Sequential:
        """Stack `num_blocks` residual blocks; only the first may downsample."""
        layers = [ResidualBlock(in_channels, out_channels, stride, use_spectral_norm)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels, 1, use_spectral_norm))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, in_channels, H, W]
        Returns:
            energy: Scalar energy values [batch_size, 1]
        """
        out = self.swish(self.gn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.global_pool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)
class JointEnergyModel(nn.Module):
    """Joint Energy-Based Model for classification and generation.

    The classifier logits double as energies:
        E(x, y) = -f_θ(x)[y]
        p(y|x)  = softmax(f_θ(x))
        p(x)    ∝ exp(LogSumExp_y f_θ(x)[y])
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 10, base_channels: int = 64,
                 num_blocks: List[int] = [2, 2, 2, 2], use_spectral_norm: bool = True):
        super().__init__()
        self.num_classes = num_classes
        # Stem: 3x3 conv + GroupNorm + SiLU.
        stem = nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = spectral_norm(stem) if use_spectral_norm else stem
        self.gn1 = nn.GroupNorm(min(32, base_channels), base_channels)
        self.swish = nn.SiLU()
        # Four residual stages; spatial resolution halves in stages 2-4.
        c = base_channels
        self.layer1 = self._make_layer(c, c, num_blocks[0], stride=1, use_spectral_norm=use_spectral_norm)
        self.layer2 = self._make_layer(c, c * 2, num_blocks[1], stride=2, use_spectral_norm=use_spectral_norm)
        self.layer3 = self._make_layer(c * 2, c * 4, num_blocks[2], stride=2, use_spectral_norm=use_spectral_norm)
        self.layer4 = self._make_layer(c * 4, c * 8, num_blocks[3], stride=2, use_spectral_norm=use_spectral_norm)
        # Head: global average pool + linear classifier.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        head = nn.Linear(c * 8, num_classes)
        self.fc = spectral_norm(head) if use_spectral_norm else head

    def _make_layer(self, in_channels: int, out_channels: int, num_blocks: int, stride: int,
                    use_spectral_norm: bool) -> nn.Sequential:
        """Build one residual stage: a strided block plus num_blocks-1 stride-1 blocks."""
        blocks = [ResidualBlock(in_channels, out_channels, stride, use_spectral_norm)]
        blocks.extend(ResidualBlock(out_channels, out_channels, 1, use_spectral_norm)
                      for _ in range(1, num_blocks))
        return nn.Sequential(*blocks)

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolutional trunk; returns pooled features [batch_size, 8*base_channels]."""
        h = self.swish(self.gn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        h = self.global_pool(h)
        return torch.flatten(h, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, in_channels, H, W]
        Returns:
            logits: Class logits [batch_size, num_classes]
        """
        return self.fc(self.extract_features(x))

    def energy(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute marginal energy E(x) (y is None) or conditional energy E(x, y).

        Args:
            x: Input tensor [batch_size, in_channels, H, W]
            y: Optional class labels [batch_size]
        Returns:
            energy: Energy values [batch_size, 1]
        """
        logits = self.forward(x)
        if y is not None:
            # Conditional: E(x, y) = -f_θ(x)[y]
            return -logits.gather(1, y.unsqueeze(1))
        # Marginal: E(x) = -LogSumExp_y f_θ(x)[y]
        return -torch.logsumexp(logits, dim=1, keepdim=True)
# ============================================================
# 2. LANGEVIN DYNAMICS SAMPLING
# ============================================================
class LangevinSampler:
    """Langevin dynamics sampler for EBMs.

    Update rule:
        x_{t+1} = x_t - (ε/2) ∇_x E_θ(x_t) + √ε z_t,  z_t ~ N(0, I)
    """

    def __init__(self, energy_model: nn.Module, step_size: float = 0.01, num_steps: int = 100,
                 noise_scale: float = 1.0, clip_grad: Optional[float] = 0.1):
        """
        Args:
            energy_model: Module returning E(x) from forward(), or exposing .energy(x, y)
            step_size: Langevin step size ε
            num_steps: Number of MCMC steps
            noise_scale: Multiplier on the injected Gaussian noise
            clip_grad: Element-wise gradient clipping bound (None disables)
        """
        self.energy_model = energy_model
        self.step_size = step_size
        self.num_steps = num_steps
        self.noise_scale = noise_scale
        self.clip_grad = clip_grad

    def sample(self, x_init: torch.Tensor, y: Optional[torch.Tensor] = None,
               verbose: bool = False) -> torch.Tensor:
        """Sample from p_θ(x) or p_θ(x|y) using Langevin dynamics.

        Args:
            x_init: Initial samples [batch_size, channels, H, W]
            y: Optional conditioning labels [batch_size]
            verbose: Print energy during sampling
        Returns:
            x: Final (detached) samples [batch_size, channels, H, W]
        """
        x = x_init.clone().detach().requires_grad_(True)
        for step in range(self.num_steps):
            # BUGFIX: gradient computation must work even when the caller runs
            # the sampler under torch.no_grad() (standard in CD training loops),
            # so the graph-building part is explicitly wrapped in enable_grad().
            with torch.enable_grad():
                # Duck-typed dispatch: joint models expose .energy(x, y); plain
                # energy networks return E(x) directly from forward().
                if hasattr(self.energy_model, 'energy'):
                    energy = self.energy_model.energy(x, y)
                else:
                    energy = self.energy_model(x)
                grad = torch.autograd.grad(energy.sum(), x, create_graph=False)[0]
            # Element-wise gradient clipping for stability.
            if self.clip_grad is not None:
                grad = torch.clamp(grad, -self.clip_grad, self.clip_grad)
            # Langevin update: descend the energy, inject Gaussian noise.
            noise = torch.randn_like(x) * self.noise_scale
            x = x - (self.step_size / 2) * grad + np.sqrt(self.step_size) * noise
            # Keep samples in the valid image range.
            x = torch.clamp(x, 0, 1)
            x = x.detach().requires_grad_(True)
            if verbose and step % 20 == 0:
                print(f"Step {step}: Energy = {energy.mean().item():.3f}")
        return x.detach()
class ReplayBuffer:
    """Fixed-capacity store of past negative samples (persistent CD)."""

    def __init__(self, buffer_size: int = 10000, image_shape: Tuple[int, int, int] = (3, 32, 32)):
        self.buffer_size = buffer_size
        self.image_shape = image_shape
        self.buffer = []

    def push(self, samples: torch.Tensor):
        """Add samples to buffer."""
        for item in samples.detach().cpu():
            if len(self.buffer) >= self.buffer_size:
                # Buffer full: overwrite a uniformly chosen slot.
                self.buffer[np.random.randint(0, self.buffer_size)] = item
            else:
                self.buffer.append(item)

    def sample(self, batch_size: int, device: torch.device, reinit_prob: float = 0.05) -> torch.Tensor:
        """Sample from buffer with probability of reinitialization.

        Args:
            batch_size: Number of samples
            device: Device to place samples on
            reinit_prob: Probability of sampling from noise instead of buffer
        Returns:
            samples: Sampled images [batch_size, channels, H, W]
        """
        picks = []
        for _ in range(batch_size):
            # Fall back to uniform noise when empty or with reinit_prob chance.
            if not self.buffer or np.random.rand() < reinit_prob:
                picks.append(torch.rand(self.image_shape))
            else:
                picks.append(self.buffer[np.random.randint(0, len(self.buffer))])
        return torch.stack(picks).to(device)
# ============================================================
# 3. CONTRASTIVE DIVERGENCE TRAINING
# ============================================================
class ContrastiveDivergenceTrainer:
    """Trainer for EBMs using Contrastive Divergence.

    Loss: L = E_θ(x⁺) - E_θ(x⁻)
    where x⁺ ~ p_data and x⁻ ~ p_θ (via MCMC).
    """

    def __init__(self, energy_model: nn.Module, langevin_sampler: "LangevinSampler",
                 replay_buffer: "ReplayBuffer", optimizer: optim.Optimizer,
                 energy_reg: float = 0.0, grad_reg: float = 0.0):
        """
        Args:
            energy_model: Energy network E_θ(x)
            langevin_sampler: MCMC sampler producing negatives
            replay_buffer: Persistent store of past negatives
            optimizer: Optimizer for the energy network
            energy_reg: Weight of the L2 penalty on positive energies
            grad_reg: Weight of the input-gradient penalty
        """
        self.energy_model = energy_model
        self.langevin_sampler = langevin_sampler
        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.energy_reg = energy_reg
        self.grad_reg = grad_reg

    def train_step(self, x_pos: torch.Tensor) -> Dict[str, float]:
        """Single training step.

        Args:
            x_pos: Positive samples (data) [batch_size, channels, H, W]
        Returns:
            metrics: Dictionary of training metrics
        """
        batch_size = x_pos.size(0)
        device = x_pos.device
        # Initialize MCMC chains from the replay buffer (persistent CD).
        x_init = self.replay_buffer.sample(batch_size, device)
        # BUGFIX: do NOT wrap this call in torch.no_grad() — the sampler needs
        # autograd to compute ∇_x E(x); it already returns a detached tensor,
        # so no graph leaks into the loss below.
        x_neg = self.langevin_sampler.sample(x_init)
        self.replay_buffer.push(x_neg)
        # Energies of data (positive) and model (negative) samples.
        energy_pos = self.energy_model(x_pos)
        energy_neg = self.energy_model(x_neg)
        # Contrastive divergence: push data energy down, sample energy up.
        cd_loss = energy_pos.mean() - energy_neg.mean()
        loss = cd_loss
        # L2 penalty on positive energies prevents unbounded energy drift.
        if self.energy_reg > 0:
            loss = loss + self.energy_reg * (energy_pos ** 2).mean()
        # Input-gradient penalty smooths the energy landscape.  BUGFIX: use a
        # detached clone so the caller's tensor is not mutated by requires_grad_.
        if self.grad_reg > 0:
            x_pos_rg = x_pos.detach().clone().requires_grad_(True)
            grad = torch.autograd.grad(self.energy_model(x_pos_rg).sum(),
                                       x_pos_rg, create_graph=True)[0]
            loss = loss + self.grad_reg * (grad ** 2).sum(dim=(1, 2, 3)).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {
            'cd_loss': cd_loss.item(),
            'energy_pos': energy_pos.mean().item(),
            'energy_neg': energy_neg.mean().item(),
            'loss': loss.item()
        }
# ============================================================
# 4. SCORE MATCHING TRAINING
# ============================================================
class DenoisingScoreMatching:
"""Denoising score matching for EBMs.
Loss: E_{x_0, x}[||β_x E_ΞΈ(x) + (x - x_0)/ΟΒ²||Β²]
where x = x_0 + Ο Ξ΅, Ξ΅ ~ N(0, I)
"""
def __init__(self, energy_model: nn.Module, optimizer: optim.Optimizer, noise_std: float = 0.1):
self.energy_model = energy_model
self.optimizer = optimizer
self.noise_std = noise_std
def train_step(self, x_clean: torch.Tensor) -> Dict[str, float]:
"""Single training step.
Args:
x_clean: Clean data samples [batch_size, channels, H, W]
Returns:
metrics: Dictionary of training metrics
"""
# Add noise
noise = torch.randn_like(x_clean) * self.noise_std
x_noisy = x_clean + noise
# Compute energy gradient (score)
x_noisy.requires_grad_(True)
energy = self.energy_model(x_noisy)
score = torch.autograd.grad(energy.sum(), x_noisy, create_graph=True)[0]
# True score: β_x log q(x|x_0) = -(x - x_0)/ΟΒ²
true_score = -noise / (self.noise_std ** 2)
# Score matching loss
loss = 0.5 * ((score - true_score) ** 2).sum(dim=(1, 2, 3)).mean()
# Backward and optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# Metrics
metrics = {
'sm_loss': loss.item(),
'score_norm': score.norm(dim=(1, 2, 3)).mean().item()
}
return metrics
class MultiScaleScoreMatching:
"""Multi-scale denoising score matching with noise conditioning.
Trains score network s_ΞΈ(x, Ο) for multiple noise levels.
"""
def __init__(self, energy_model: nn.Module, optimizer: optim.Optimizer,
noise_levels: List[float] = [1.0, 0.5, 0.25, 0.1, 0.05]):
self.energy_model = energy_model
self.optimizer = optimizer
self.noise_levels = noise_levels
def train_step(self, x_clean: torch.Tensor) -> Dict[str, float]:
"""Single training step.
Args:
x_clean: Clean data samples [batch_size, channels, H, W]
Returns:
metrics: Dictionary of training metrics
"""
batch_size = x_clean.size(0)
# Sample noise level for each sample
noise_level_idx = torch.randint(0, len(self.noise_levels), (batch_size,))
noise_stds = torch.tensor([self.noise_levels[i] for i in noise_level_idx],
device=x_clean.device, dtype=x_clean.dtype)
# Add noise
noise = torch.randn_like(x_clean)
x_noisy = x_clean + noise * noise_stds.view(-1, 1, 1, 1)
# Compute energy gradient (score)
x_noisy.requires_grad_(True)
energy = self.energy_model(x_noisy)
score = torch.autograd.grad(energy.sum(), x_noisy, create_graph=True)[0]
# True score: -(x - x_0)/ΟΒ²
true_score = -noise / noise_stds.view(-1, 1, 1, 1)
# Weighted score matching loss (weight by ΟΒ²)
weights = noise_stds ** 2
loss = 0.5 * (weights.view(-1, 1, 1, 1) * (score - true_score) ** 2).sum(dim=(1, 2, 3)).mean()
# Backward and optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# Metrics
metrics = {
'msm_loss': loss.item(),
'score_norm': score.norm(dim=(1, 2, 3)).mean().item()
}
return metrics
# ============================================================
# 5. JOINT ENERGY-BASED MODEL (JEM) TRAINING
# ============================================================
class JEMTrainer:
    """Trainer for Joint Energy-Based Models (classification + generation).

    Total loss: α · CrossEntropy(f_θ(x), y) + β · (E_θ(x⁺) - E_θ(x⁻)).
    """

    def __init__(self, jem_model: "JointEnergyModel", langevin_sampler: "LangevinSampler",
                 replay_buffer: "ReplayBuffer", optimizer: optim.Optimizer,
                 alpha: float = 1.0, beta: float = 1.0):
        """
        Args:
            jem_model: Joint energy-based model
            langevin_sampler: Sampler for generating negative examples
            replay_buffer: Buffer for storing negative examples
            optimizer: Optimizer
            alpha: Weight for classification loss
            beta: Weight for generation loss
        """
        self.jem_model = jem_model
        self.langevin_sampler = langevin_sampler
        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.alpha = alpha
        self.beta = beta

    def train_step(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, float]:
        """Single training step.

        Args:
            x: Input images [batch_size, channels, H, W]
            y: Class labels [batch_size]
        Returns:
            metrics: Dictionary of training metrics
        """
        batch_size = x.size(0)
        device = x.device
        # ----- Classification loss -----
        logits = self.jem_model(x)
        class_loss = F.cross_entropy(logits, y)
        # ----- Generation loss (contrastive divergence) -----
        x_init = self.replay_buffer.sample(batch_size, device)
        # BUGFIX: do NOT wrap sampling in torch.no_grad() — the Langevin
        # sampler differentiates the energy w.r.t. x via autograd; it already
        # returns a detached tensor, so no graph leaks into the loss.
        x_neg = self.langevin_sampler.sample(x_init)
        self.replay_buffer.push(x_neg)
        # Marginal energies E(x) = -LogSumExp_y f_θ(x)[y].
        energy_pos = self.jem_model.energy(x)
        energy_neg = self.jem_model.energy(x_neg)
        gen_loss = energy_pos.mean() - energy_neg.mean()
        # ----- Combined objective -----
        loss = self.alpha * class_loss + self.beta * gen_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        acc = (logits.argmax(dim=1) == y).float().mean()
        return {
            'loss': loss.item(),
            'class_loss': class_loss.item(),
            'gen_loss': gen_loss.item(),
            'accuracy': acc.item(),
            'energy_pos': energy_pos.mean().item(),
            'energy_neg': energy_neg.mean().item()
        }
# ============================================================
# 6. DEMONSTRATIONS
# ============================================================
def demo_energy_network():
    """Demonstrate energy network architecture."""
    print("=" * 60)
    print("DEMO: Energy Network Architecture")
    print("=" * 60)
    # Full-size network: four stages of two residual blocks each.
    energy_net = EnergyNetwork(in_channels=3, base_channels=64, num_blocks=[2, 2, 2, 2])
    # CIFAR-sized dummy batch.
    x = torch.randn(8, 3, 32, 32)
    energy = energy_net(x)
    print(f"Input shape: {x.shape}")
    print(f"Energy shape: {energy.shape}")
    print(f"Energy values: {energy.squeeze()}")
    print(f"Total parameters: {sum(p.numel() for p in energy_net.parameters()):,}")
    print()
def demo_langevin_sampling():
    """Demonstrate Langevin dynamics sampling."""
    print("=" * 60)
    print("DEMO: Langevin Dynamics Sampling")
    print("=" * 60)
    # Small untrained energy model, frozen in eval mode.
    energy_net = EnergyNetwork(in_channels=3, base_channels=32, num_blocks=[1, 1, 1, 1])
    energy_net.eval()
    sampler = LangevinSampler(energy_net, step_size=0.01, num_steps=50, clip_grad=0.1)
    # Chains start from uniform noise in [0, 1].
    x_init = torch.rand(4, 3, 32, 32)
    print("Running Langevin sampling (50 steps)...")
    x_samples = sampler.sample(x_init, verbose=True)
    print(f"\nInitial samples range: [{x_init.min():.3f}, {x_init.max():.3f}]")
    print(f"Final samples range: [{x_samples.min():.3f}, {x_samples.max():.3f}]")
    # Report energies before vs after sampling (no gradients needed here).
    with torch.no_grad():
        energy_init = energy_net(x_init)
        energy_final = energy_net(x_samples)
    print(f"\nInitial energy: {energy_init.mean():.3f} ± {energy_init.std():.3f}")
    print(f"Final energy: {energy_final.mean():.3f} ± {energy_final.std():.3f}")
    print(f"Energy reduction: {(energy_init - energy_final).mean():.3f}")
    print()
def demo_contrastive_divergence():
    """Demonstrate contrastive divergence training."""
    print("=" * 60)
    print("DEMO: Contrastive Divergence Training")
    print("=" * 60)
    # Small model plus the three CD ingredients: sampler, buffer, optimizer.
    energy_net = EnergyNetwork(in_channels=3, base_channels=32, num_blocks=[1, 1, 1, 1])
    sampler = LangevinSampler(energy_net, step_size=0.01, num_steps=20)
    buffer = ReplayBuffer(buffer_size=1000, image_shape=(3, 32, 32))
    trainer = ContrastiveDivergenceTrainer(
        energy_net, sampler, buffer,
        optim.Adam(energy_net.parameters(), lr=1e-4),
        energy_reg=0.01, grad_reg=0.01,
    )
    # One training step on random "data".
    metrics = trainer.train_step(torch.rand(16, 3, 32, 32))
    print("Training metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value:.4f}")
    print()
def demo_score_matching():
    """Demonstrate denoising score matching."""
    print("=" * 60)
    print("DEMO: Denoising Score Matching")
    print("=" * 60)
    # Small energy model; DSM needs no sampler or buffer.
    energy_net = EnergyNetwork(in_channels=3, base_channels=32, num_blocks=[1, 1, 1, 1])
    trainer = DenoisingScoreMatching(
        energy_net, optim.Adam(energy_net.parameters(), lr=1e-4), noise_std=0.1
    )
    # One training step on random "data".
    metrics = trainer.train_step(torch.rand(16, 3, 32, 32))
    print("Score matching metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value:.4f}")
    print()
def demo_joint_energy_model():
    """Demonstrate Joint Energy-Based Model."""
    print("=" * 60)
    print("DEMO: Joint Energy-Based Model (JEM)")
    print("=" * 60)
    # Small JEM and a random labelled batch.
    jem = JointEnergyModel(in_channels=3, num_classes=10, base_channels=32, num_blocks=[1, 1, 1, 1])
    x = torch.rand(8, 3, 32, 32)
    y = torch.randint(0, 10, (8,))
    # Classifier view of the model.
    logits = jem(x)
    probs = F.softmax(logits, dim=1)
    print(f"Input shape: {x.shape}")
    print(f"Logits shape: {logits.shape}")
    print(f"Class probabilities (first sample): {probs[0]}")
    # Energy view: the same logits reinterpreted as (negative) energies.
    energy_marginal = jem.energy(x)
    energy_conditional = jem.energy(x, y)
    print(f"\nMarginal energy E(x): {energy_marginal.squeeze()}")
    print(f"Conditional energy E(x,y): {energy_conditional.squeeze()}")
    print(f"\nTotal parameters: {sum(p.numel() for p in jem.parameters()):,}")
    print()
def demo_jem_training():
    """Demonstrate JEM training."""
    print("=" * 60)
    print("DEMO: JEM Training Step")
    print("=" * 60)
    # Small JEM plus sampler, buffer, and optimizer.
    jem = JointEnergyModel(in_channels=3, num_classes=10, base_channels=32, num_blocks=[1, 1, 1, 1])
    sampler = LangevinSampler(jem, step_size=0.01, num_steps=20)
    buffer = ReplayBuffer(buffer_size=1000, image_shape=(3, 32, 32))
    trainer = JEMTrainer(jem, sampler, buffer,
                         optim.Adam(jem.parameters(), lr=1e-4),
                         alpha=1.0, beta=1.0)
    # One training step on a random labelled batch.
    metrics = trainer.train_step(torch.rand(16, 3, 32, 32), torch.randint(0, 10, (16,)))
    print("JEM training metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value:.4f}")
    print()
def print_performance_comparison():
    """Print comprehensive performance comparison and decision guide.

    All numbers are literature-reported reference values; the function only
    formats and prints them (no computation).
    """

    def _print_table(rows, widths):
        # One left-justified, space-separated line per row; ljust(w) matches
        # the f"{cell:<w}" formatting previously repeated per table.
        for row in rows:
            print(" ".join(cell.ljust(width) for cell, width in zip(row, widths)))

    def _section(title):
        # Banner printed between numbered sections.
        print("\n" + "=" * 80)
        print(title)
        print("=" * 80)

    print("=" * 80)
    print("ENERGY-BASED MODELS: COMPREHENSIVE PERFORMANCE ANALYSIS")
    print("=" * 80)

    _section("1. IMAGE GENERATION BENCHMARKS (CIFAR-10)")
    _print_table([
        ["Model", "FID ↓", "IS ↑", "Sampling Time", "Training Stable?"],
        ["-" * 20, "-" * 10, "-" * 10, "-" * 15, "-" * 15],
        ["IGEBM", "38.2", "6.02", "~10s (100 steps)", "Moderate"],
        ["JEM", "40.5", "6.8", "~10s (100 steps)", "Moderate"],
        ["Score-based", "3.2", "9.5", "~50s (1000 steps)", "High"],
        ["DDPM", "3.17", "9.46", "~50s (1000 steps)", "High"],
        ["StyleGAN2", "2.92", "9.18", "~0.1s (1 step)", "Low"],
        ["BigGAN", "6.95", "9.22", "~0.1s (1 step)", "Low"],
    ], [20, 10, 10, 15, 15])
    print("\n📊 Key Observations:")
    print("  • EBMs (IGEBM, JEM) lag GANs/diffusion on standard metrics")
    print("  • Sampling 100-1000× slower than GANs due to MCMC")
    print("  • More stable training than GANs, less stable than diffusion")
    print("  • Score-based models bridge gap (diffusion connection)")

    _section("2. OUT-OF-DISTRIBUTION (OOD) DETECTION (CIFAR-10 vs SVHN)")
    _print_table([
        ["Method", "AUROC ↑", "FPR@95TPR ↓", "Approach"],
        ["-" * 25, "-" * 10, "-" * 15, "-" * 30],
        ["Softmax Baseline", "0.890", "0.421", "Max softmax probability"],
        ["ODIN", "0.921", "0.336", "Temperature + input perturbation"],
        ["Mahalanobis", "0.937", "0.298", "Feature space distance"],
        ["JEM (Energy)", "0.964", "0.182", "Energy threshold"],
        ["EBM + Ensembles", "0.978", "0.145", "Energy + multiple models"],
    ], [25, 10, 15, 30])
    print("\n🎯 Key Observations:")
    print("  • Energy-based OOD detection significantly outperforms classifier-based methods")
    print("  • JEM energy threshold: +7.4 AUROC vs softmax baseline")
    print("  • Lower FPR@95TPR: More reliable rejection of OOD samples")
    print("  • Ensembles further improve OOD detection")

    _section("3. ADVERSARIAL ROBUSTNESS (CIFAR-10, PGD-20 Attack)")
    _print_table([
        ["Model", "Clean Acc", "PGD-20 Acc", "AutoAttack", "Training Cost"],
        ["-" * 25, "-" * 12, "-" * 13, "-" * 13, "-" * 15],
        ["Standard ResNet", "95.2%", "0.0%", "0.0%", "1× baseline"],
        ["Adversarial Training", "84.7%", "53.1%", "48.2%", "3× baseline"],
        ["JEM + Adversarial", "82.3%", "56.4%", "51.7%", "5× baseline"],
        ["TRADES", "84.9%", "54.4%", "49.8%", "3× baseline"],
    ], [25, 12, 13, 13, 15])
    print("\n🛡️ Key Observations:")
    print("  • JEM improves robust accuracy by ~3-4% over standard adversarial training")
    print("  • Energy landscape provides additional robustness signal")
    print("  • Higher computational cost (5× vs 3× for standard adversarial training)")
    print("  • Trade-off: Clean accuracy drops ~3% vs standard training")

    _section("4. COMPUTATIONAL COMPLEXITY ANALYSIS")
    _print_table([
        ["Operation", "Time Complexity", "Space", "Bottleneck"],
        ["-" * 30, "-" * 20, "-" * 15, "-" * 30],
        ["Energy forward", "O(D)", "O(D)", "Network evaluation"],
        ["Score computation", "O(D)", "O(D)", "Backprop through network"],
        ["Langevin sampling (T steps)", "O(T·D)", "O(D)", "MCMC iterations"],
        ["CD training (batch B)", "O(B·T·D)", "O(B·D)", "Negative sampling"],
        ["Score matching", "O(B·D)", "O(D)", "Score gradient computation"],
    ], [30, 20, 15, 30])
    print("\n⏱️ Typical Values:")
    print("  • D = 3×32×32 = 3,072 (CIFAR-10)")
    print("  • T = 20-200 Langevin steps")
    print("  • B = 128-256 batch size")
    print("  • CD: ~10s per batch (100 steps), Score matching: ~0.5s per batch")

    _section("5. TRAINING METHOD COMPARISON")
    _print_table([
        ["Method", "Partition Fn?", "MCMC?", "Stability", "Speed", "Quality"],
        ["-" * 25, "-" * 15, "-" * 10, "-" * 12, "-" * 10, "-" * 10],
        ["Contrastive Divergence", "Avoided", "Yes", "Moderate", "Slow", "High"],
        ["Persistent CD", "Avoided", "Yes", "Low", "Slow", "Higher"],
        ["Score Matching", "Avoided", "No", "High", "Fast", "High"],
        ["Denoising Score", "Avoided", "No", "High", "Fast", "High"],
        ["Multi-scale Score", "Avoided", "No", "High", "Moderate", "Highest"],
        ["NCE", "Learned", "No", "High", "Fast", "Moderate"],
    ], [25, 15, 10, 12, 10, 10])
    print("\n🔧 Recommendations:")
    print("  • Denoising Score Matching: Best default choice (fast, stable, high quality)")
    print("  • Multi-scale Score: State-of-the-art quality (diffusion connection)")
    print("  • Contrastive Divergence: When explicit sampling needed")
    print("  • NCE: Fast but requires good noise distribution")

    _section("6. COMPARISON WITH OTHER GENERATIVE MODELS")
    _print_table([
        ["Model", "Likelihood", "Sampling", "Architecture", "Composable?", "Best Use Case"],
        ["-" * 12, "-" * 12, "-" * 12, "-" * 15, "-" * 12, "-" * 35],
        ["EBM", "Intractable", "Slow (MCMC)", "Flexible", "✓", "OOD, robustness, composition"],
        ["GAN", "Intractable", "Fast", "Flexible", "✗", "High-quality images, fast sampling"],
        ["VAE", "Approximate", "Fast", "Enc-Dec", "✗", "Latent space, fast inference"],
        ["Flow", "Exact", "Fast", "Invertible", "✗", "Exact likelihood, density est"],
        ["Diffusion", "Tractable", "Slow", "Flexible", "✓", "Highest quality, controllable"],
    ], [12, 12, 12, 15, 12, 35])

    _section("7. DECISION GUIDE: WHEN TO USE ENERGY-BASED MODELS")
    # NOTE(fix): this header string was previously split across two physical
    # lines inside the literal, which is a syntax error.
    print("\n✅ USE EBMs WHEN:")
    print("  1. Compositional reasoning required (combine concepts: red AND square)")
    print("  2. Out-of-distribution detection critical (medical, autonomous systems)")
    print("  3. Adversarial robustness important (security applications)")
    print("  4. Flexible architecture needed (no invertibility constraints)")
    print("  5. Joint modeling beneficial (JEM: classification + generation)")
    print("  6. Interpretability valued (energy landscape visualization)")
    print("  7. Slow sampling acceptable (offline generation, design)")
    print("\n❌ AVOID EBMs WHEN:")
    print("  1. Real-time generation required (use GANs or cached diffusion)")
    print("  2. Limited computational budget (training + sampling expensive)")
    print("  3. Standard generative tasks (diffusion models better FID/IS)")
    print("  4. Large-scale images (512×512+) without specialized hardware")
    print("  5. Exact likelihood needed (use normalizing flows)")
    print("  6. Fast prototyping (less mature libraries than GANs/VAEs)")

    _section("8. HYPERPARAMETER RECOMMENDATIONS")
    _print_table([
        ["Parameter", "Typical Range", "Recommended Start", "Impact"],
        ["-" * 25, "-" * 20, "-" * 20, "-" * 35],
        ["Langevin step size ε", "0.001 - 0.1", "0.01", "Larger → faster, less stable"],
        ["Langevin steps T", "20 - 200", "100", "More → better samples, slower"],
        ["Learning rate", "1e-5 - 1e-3", "1e-4", "Standard impact"],
        ["Batch size", "64 - 256", "128", "Larger helps negative sampling"],
        ["Replay buffer size", "1K - 100K", "10K", "Larger → better diversity"],
        ["Reinit probability", "0.01 - 0.2", "0.05", "Higher → more exploration"],
        ["Energy reg λ_E", "0.0 - 0.1", "0.01", "Prevents energy drift"],
        ["Gradient reg λ_G", "0.0 - 0.1", "0.01", "Smooths energy landscape"],
        ["Noise std (DSM)", "0.01 - 0.5", "0.1", "Match to data scale"],
    ], [25, 20, 20, 35])

    _section("9. RECENT ADVANCES AND FUTURE DIRECTIONS")
    print("\n🔬 Recent Advances (2020-2024):")
    print("  • Cooperative Training: Train EBM + generator jointly (faster MCMC)")
    print("  • Flow Contrastive Estimation: Use flows as MCMC proposals")
    print("  • Score-based diffusion: Unified framework (SDE/ODE solvers)")
    print("  • Discrete EBMs: Graph generation, protein design")
    print("  • Hardware acceleration: Custom chips for MCMC")
    print("\n🚀 Future Directions:")
    print("  • Learned samplers: Neural networks to amortize MCMC")
    print("  • Theoretical analysis: Convergence guarantees, sample complexity")
    print("  • Scalability: High-resolution images (1024×1024+)")
    print("  • Scientific applications: Molecule design, physics simulation")
    print("  • Multimodal: Vision + language energy functions")
    print("\n" + "=" * 80)
# ============================================================
# RUN ALL DEMONSTRATIONS
# ============================================================
if __name__ == "__main__":
    # Run every demonstration in sequence, then print the comparison tables.
    demo_energy_network()
    demo_langevin_sampling()
    demo_contrastive_divergence()
    demo_score_matching()
    demo_joint_energy_model()
    demo_jem_training()
    print_performance_comparison()