import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Energy-Based Models

Energy Function:ΒΆ

\[p(x) = \frac{\exp(-E_\theta(x))}{Z_\theta}\]

where \(Z_\theta = \int \exp(-E_\theta(x)) dx\) is the partition function.

Maximum Likelihood:ΒΆ

\[\nabla_\theta \log p(x) = -\nabla_\theta E_\theta(x) + \mathbb{E}_{x \sim p_\theta}[\nabla_\theta E_\theta(x)]\]

📚 Reference Materials:

class EnergyNet(nn.Module):
    """Energy-based model: maps a batch of images to per-sample scalar energies.

    Lower energy corresponds to higher probability under
    p(x) ∝ exp(-E(x)). Input is expected to be (B, 1, 28, 28)
    (MNIST-sized); the two stride-2 convolutions reduce 28x28 -> 7x7,
    matching the 128 * 7 * 7 linear layer.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),    # 28x28 -> 14x14
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),  # 14x14 -> 7x7
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )
    
    def forward(self, x):
        """Compute energy E(x).

        Args:
            x: Tensor of shape (B, 1, 28, 28).

        Returns:
            Tensor of shape (B,) with one energy per sample.
        """
        # squeeze(-1), not squeeze(): a bare squeeze() would also collapse
        # the batch dimension when B == 1, returning a 0-d tensor.
        return self.net(x).squeeze(-1)

print("EnergyNet defined")

2. Contrastive DivergenceΒΆ

Algorithm:ΒΆ

  1. Sample positive examples from data

  2. Generate negative examples via MCMC

  3. Update: decrease energy of positives, increase energy of negatives

def sample_langevin(energy_net, x_init, n_steps=60, step_size=10.0):
    """Draw samples from an energy model via SGLD.

    Starting from ``x_init``, repeatedly takes a step along the negative
    energy gradient plus Gaussian noise (stochastic gradient Langevin
    dynamics), clamping each iterate to the [0, 1] pixel range.

    Args:
        energy_net: callable mapping a batch to per-sample energies.
        x_init: starting tensor for the chains.
        n_steps: number of Langevin updates.
        step_size: SGLD step size epsilon.

    Returns:
        Detached tensor with the same shape as ``x_init``.
    """
    noise_scale = np.sqrt(step_size)
    samples = x_init.clone().requires_grad_(True)

    for _ in range(n_steps):
        # Gradient of the total batch energy w.r.t. the samples.
        total_energy = energy_net(samples).sum()
        (energy_grad,) = torch.autograd.grad(total_energy, samples)

        # x <- x - (eps/2) * grad E(x) + sqrt(eps) * z,  z ~ N(0, I)
        samples = samples - 0.5 * step_size * energy_grad
        samples = samples + noise_scale * torch.randn_like(samples)
        samples = torch.clamp(samples, 0, 1)
        samples = samples.detach().requires_grad_(True)

    return samples.detach()

print("Langevin sampling defined")

Training LoopΒΆ

Energy-based models (EBMs) are trained by contrastive divergence: push down the energy of real data samples and push up the energy of generated (negative) samples. The loss is \(\mathcal{L} = \mathbb{E}_{x \sim p_{\text{data}}}[E_\theta(x)] - \mathbb{E}_{x^- \sim p_\theta}[E_\theta(x^-)]\), where negative samples \(x^-\) are obtained via MCMC (typically Langevin dynamics) from the current model. Maintaining a replay buffer of previously generated negatives and initializing MCMC chains from buffer samples (persistent contrastive divergence) dramatically improves training stability and sample quality.

def train_ebm(model, train_loader, n_epochs=5):
    """Train an energy-based model with contrastive divergence.

    Each step lowers the energy of data batches ("positive phase") and
    raises the energy of SGLD negatives started from uniform noise
    ("negative phase"); a squared-energy term keeps both bounded.

    Args:
        model: energy network mapping images to per-sample energies.
        train_loader: iterable of (images, labels) batches.
        n_epochs: number of passes over the data.

    Returns:
        List of per-epoch average losses.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    losses = []

    for epoch in range(n_epochs):
        running_loss = 0.0

        for x_pos, _ in train_loader:
            x_pos = x_pos.to(device)

            # Energies of real data.
            energy_pos = model(x_pos)

            # Negatives: short SGLD chains from uniform noise.
            x_neg = sample_langevin(model, torch.rand_like(x_pos), n_steps=60)
            energy_neg = model(x_neg)

            # Contrastive divergence plus squared-energy regularization
            # (prevents energies from drifting unboundedly).
            cd_term = energy_pos.mean() - energy_neg.mean()
            reg_term = (energy_pos ** 2).mean() + (energy_neg ** 2).mean()
            loss = cd_term + 0.001 * reg_term

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

    return losses

print("Training function defined")

Train ModelΒΆ

During training, each step involves: (1) sampling a batch of real images, (2) running MCMC for a few steps to produce negative samples (or drawing from the replay buffer), (3) computing the contrastive divergence loss, and (4) taking a gradient step. The balance between positive and negative sample energies must be carefully maintained – if negative samples are too poor, the model learns trivial energy landscapes; if the MCMC runs too long, training becomes prohibitively slow. Spectral normalization and gradient penalties help keep the energy function well-behaved and prevent the loss from diverging.

# Data: MNIST digits scaled to [0, 1] by ToTensor — this matches the
# clamp range used inside the SGLD sampler.
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)

# Model: energy network on the globally selected device.
ebm = EnergyNet().to(device)

# Train: short demonstration run; returns the per-epoch average losses.
losses = train_ebm(ebm, train_loader, n_epochs=3)

Generate SamplesΒΆ

Generating samples from a trained EBM requires running MCMC (Langevin dynamics) for many steps starting from random noise, following the negative gradient of the energy function toward low-energy (high-probability) regions. Unlike GANs and VAEs which generate in a single forward pass, EBM sampling is iterative and can be computationally expensive, typically requiring hundreds to thousands of Langevin steps for high-quality results. The iterative nature is also an advantage: samples can be refined by running more steps, providing a quality-computation trade-off at inference time.

# Generate samples: long SGLD chains (200 steps) starting from uniform
# noise. eval() disables any training-mode behavior in the energy net.
ebm.eval()
n_samples = 16

x_init = torch.rand(n_samples, 1, 28, 28).to(device)
samples = sample_langevin(ebm, x_init, n_steps=200, step_size=10.0)

# Visualize the 16 samples on a 4x4 grid.
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i in range(16):
    ax = axes[i // 4, i % 4]
    ax.imshow(samples[i].cpu().squeeze(), cmap='gray')
    ax.axis('off')

plt.suptitle('EBM Generated Samples', fontsize=12)
plt.tight_layout()
plt.show()

Energy Landscape VisualizationΒΆ

Visualizing the learned energy landscape provides unique insight into what the model has learned about the data distribution. For 2D data, we can plot the energy function as a surface or contour map, revealing the basins (low energy) where the model places probability mass and the barriers (high energy) between modes. For image data, comparing the energy assigned to real images, generated samples, and random noise validates that the model correctly assigns low energy to plausible data and high energy to implausible inputs.

# Get real samples from one training batch.
real_samples, _ = next(iter(train_loader))
real_samples = real_samples[:16].to(device)

# Compute energies for real vs generated images; no_grad since these
# values are for diagnostics only.
with torch.no_grad():
    energy_real = ebm(real_samples)
    energy_generated = ebm(samples)

# Plot overlapping histograms: a trained EBM should assign lower energy
# to real data than to (imperfect) generated samples.
fig, ax = plt.subplots(figsize=(10, 5))

ax.hist(energy_real.cpu().numpy(), bins=20, alpha=0.7, label='Real Data', edgecolor='black')
ax.hist(energy_generated.cpu().numpy(), bins=20, alpha=0.7, label='Generated', edgecolor='black')
ax.set_xlabel('Energy', fontsize=11)
ax.set_ylabel('Count', fontsize=11)
ax.set_title('Energy Distribution', fontsize=12)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Real energy: {energy_real.mean().item():.3f}")
print(f"Generated energy: {energy_generated.mean().item():.3f}")

SummaryΒΆ

Energy-Based Models:ΒΆ

Key Concepts:

  1. Energy function \(E_\theta(x)\)

  2. Probability via Boltzmann: \(p(x) \propto \exp(-E(x))\)

  3. Training via contrastive divergence

  4. Sampling via Langevin dynamics

Training:ΒΆ

  • Decrease energy of data

  • Increase energy of generated samples

  • MCMC for negative samples

Advantages:ΒΆ

  • Flexible energy functions

  • No explicit density modeling

  • Natural for structured prediction

  • Composable models

Challenges:ΒΆ

  • Expensive sampling

  • Mode collapse

  • Partition function intractable

Applications:ΒΆ

  • Image generation

  • Denoising

  • Anomaly detection

  • Compositional generation

Advanced Energy-Based Models TheoryΒΆ

1. Mathematical FoundationsΒΆ

Energy Function and Probability DistributionΒΆ

Definition: An energy-based model defines a probability distribution via an energy function E_ΞΈ(x):

\[p_\theta(x) = \frac{\exp(-E_\theta(x))}{Z_\theta}\]

where Z_ΞΈ is the partition function:

\[Z_\theta = \int \exp(-E_\theta(x)) dx\]

Key Properties:

  • Lower energy β†’ higher probability

  • Z_ΞΈ ensures normalization but is typically intractable

  • E_ΞΈ(x) can be any differentiable function (neural network)

Maximum Likelihood GradientΒΆ

Theorem (MLE Gradient): The gradient of log-likelihood is:

\[\nabla_\theta \log p_\theta(x) = -\nabla_\theta E_\theta(x) + \mathbb{E}_{x' \sim p_\theta}[\nabla_\theta E_\theta(x')]\]

Interpretation:

  • First term: β€œPositive phase” - decrease energy of data

  • Second term: β€œNegative phase” - increase energy of model samples

  • Difference drives learning

Challenge: Computing the expectation requires sampling from \(p_\theta\), which is difficult because \(p_\theta\) is only known up to the intractable normalizing constant \(Z_\theta\).

2. Contrastive Divergence (CD)ΒΆ

CD-k Algorithm (Hinton, 2002)ΒΆ

Key Idea: Approximate negative phase using k steps of MCMC starting from data.

Algorithm:

1. Initialize x⁰ from data distribution
2. Run k Gibbs sampling steps: x⁰ β†’ xΒΉ β†’ ... β†’ xᡏ
3. Gradient approximation:
   βˆ‡_ΞΈ L β‰ˆ -βˆ‡_ΞΈ E_ΞΈ(x⁰) + βˆ‡_ΞΈ E_ΞΈ(xᡏ)
4. Update: ΞΈ ← ΞΈ - Ξ±βˆ‡_ΞΈ L

Theoretical Justification:

  • CD-k minimizes difference between data distribution and k-step reconstruction

  • As k β†’ ∞, CD-k β†’ exact MLE gradient

  • In practice, k=1 or k=5 works well

Bias: CD-k is biased but has lower variance than exact gradient.

Persistent Contrastive Divergence (PCD)ΒΆ

Key Improvement: Maintain persistent β€œfantasy particles” across training iterations.

Algorithm:

1. Initialize persistent chains {x̃ᡒ} randomly
2. For each mini-batch:
   a. Sample x⁺ from data
   b. Run k MCMC steps from {x̃ᡒ}: x̃ᡒ → x̃'ᡒ
   c. Update: βˆ‡_ΞΈ L β‰ˆ -βˆ‡_ΞΈ E_ΞΈ(x⁺) + βˆ‡_ΞΈ E_ΞΈ(xΜƒ')
   d. Keep updated x̃' for next iteration

Advantages:

  • Better mixing: chains don’t restart from data

  • More accurate negative phase approximation

  • Faster convergence in practice

3. Score Matching for EBMsΒΆ

Connection to Score-Based ModelsΒΆ

Theorem: For EBM p(x) ∝ exp(-E(x)), the score is:

\[\nabla_x \log p(x) = -\nabla_x E(x) - \nabla_x \log Z = -\nabla_x E(x)\]

(partition function Z is constant w.r.t. x)

Denoising Score Matching LossΒΆ

Objective: Train energy via denoising instead of contrastive divergence:

\[\mathcal{L}_{DSM}(\theta) = \mathbb{E}_{x_0 \sim p_{data}, \epsilon \sim \mathcal{N}(0,\sigma^2 I)} \left[ \|\nabla_{x_t} E_\theta(x_t) + \frac{\epsilon}{\sigma^2}\|^2 \right]\]

where \(x_t = x_0 + \epsilon\).

Advantages:

  • No MCMC sampling required

  • Avoids partition function

  • More stable training

4. Restricted Boltzmann Machines (RBMs)ΒΆ

ArchitectureΒΆ

Bipartite Graph: Visible units v ∈ ℝⁿᡛ and hidden units h ∈ ℝⁿʰ.

Energy Function:

\[E(v, h) = -v^T W h - b^T v - c^T h\]

where W is weight matrix, b and c are biases.

Joint Distribution:

\[p(v, h) = \frac{\exp(-E(v,h))}{Z}\]

Conditional IndependenceΒΆ

Key Property: Given h, visible units are independent:

\[p(v_i = 1 | h) = \sigma(b_i + \sum_j W_{ij} h_j)\]

Similarly for p(h | v).

Gibbs SamplingΒΆ

Block Gibbs:

1. Sample h ~ p(h | v)
2. Sample v ~ p(v | h)
3. Repeat

Efficiency: Parallel sampling within each layer due to conditional independence.

Training RBMsΒΆ

CD-1 for RBMs:

1. v⁰ ← data sample
2. h⁰ ~ p(h | v⁰)
3. v¹ ~ p(v | h⁰)
4. hΒΉ ~ p(h | vΒΉ)
5. Ξ”W ∝ <vh>_data - <vh>_recon = v⁰h⁰ᡀ - vΒΉhΒΉα΅€

Practical Tips:

  • Learning rate: 0.01 - 0.1

  • Weight decay for regularization

  • Momentum (0.9) for faster convergence

5. Deep Belief Networks (DBNs)ΒΆ

ArchitectureΒΆ

Stack of RBMs: Layer-wise unsupervised pretraining.

Greedy Layer-wise Training:

1. Train RBM₁ on data
2. Use h₁ as input to train RBMβ‚‚
3. Repeat for L layers
4. Fine-tune entire network

Theoretical Justification: Each layer increases lower bound on data likelihood.

Fine-tuningΒΆ

Wake-Sleep Algorithm:

  • Wake phase: Update recognition weights (bottom-up)

  • Sleep phase: Update generative weights (top-down)

Modern Approach: Fine-tune with backpropagation after pretraining.

6. Modern Energy-Based ModelsΒΆ

Joint Energy-Based Models (JEMs)ΒΆ

Key Idea: Single model for both generation and classification.

Energy Function:

\[E_\theta(x, y) = -\log p_\theta(y|x) - \log p_\theta(x)\]

Training:

  • Classification: Standard cross-entropy on p(y|x)

  • Generation: SGLD to sample from p(x)

Advantages:

  • Unified model

  • Better calibration

  • Out-of-distribution detection

Conditional EBMsΒΆ

For Discriminative Tasks:

\[p(y | x) = \frac{\exp(-E_\theta(x, y))}{\sum_{y'} \exp(-E_\theta(x, y'))}\]

Applications:

  • Structured prediction (segmentation, parsing)

  • Image-to-image translation

  • Conditional generation

7. Noise Contrastive Estimation (NCE)ΒΆ

Key IdeaΒΆ

Objective: Distinguish data from noise distribution p_n.

NCE Loss:

\[\mathcal{L}_{NCE} = \mathbb{E}_{x \sim p_{data}} [\log h_\theta(x)] + k \cdot \mathbb{E}_{x \sim p_n} [\log(1 - h_\theta(x))]\]

where h_ΞΈ(x) = p_ΞΈ(x)/(p_ΞΈ(x) + kΒ·p_n(x)) and k is noise samples per data sample.

Advantages:

  • Approximate partition function as learnable parameter

  • Easier than MCMC sampling

  • Scales to high dimensions

Connection to GANsΒΆ

NCE vs GANs:

  • NCE: Fixed noise distribution

  • GANs: Learned noise (generator)

  • Both avoid explicit density modeling

8. MCMC Sampling MethodsΒΆ

Langevin DynamicsΒΆ

Update Rule:

\[x_{t+1} = x_t - \frac{\epsilon}{2} \nabla_x E_\theta(x_t) + \sqrt{\epsilon} z_t\]

where z_t ~ N(0, I).

Convergence: As Ξ΅ β†’ 0 and T β†’ ∞, samples converge to p_ΞΈ.

Hamiltonian Monte Carlo (HMC)ΒΆ

Augmented State: (x, v) where v is momentum.

Hamiltonian:

\[H(x, v) = E_\theta(x) + \frac{1}{2}v^T v\]

Leapfrog Integration:

1. v ← v - (Ξ΅/2)βˆ‡_x E(x)
2. x ← x + Ξ΅ v
3. v ← v - (Ξ΅/2)βˆ‡_x E(x)

Advantages:

  • Better mixing than Langevin

  • Fewer rejections

  • Explores energy landscape efficiently

Replica Exchange Monte CarloΒΆ

Parallel Tempering: Run chains at different temperatures T₁ < Tβ‚‚ < … < Tβ‚™.

Exchange Step: Swap states between adjacent temperatures with probability:

\[\alpha = \min\left(1, \exp\left[-(E(x_i) - E(x_j))\left(\frac{1}{T_i} - \frac{1}{T_j}\right)\right]\right)\]

Benefit: High-temperature chains escape local modes, low-temperature chains sample target.

9. Training ImprovementsΒΆ

Spectral NormalizationΒΆ

Objective: Constrain Lipschitz constant of energy function.

Method: Normalize weights by largest singular value:

\[W_{SN} = W / \sigma(W)\]

where Οƒ(W) is estimated via power iteration.

Benefit: Stabilizes training, prevents energy collapse.

Energy RegularizationΒΆ

Objectives:

  • Pull-away term: Encourage diversity in generated samples

  • Entropy maximization: Spread probability mass

  • Squared energy: Prevent unbounded energies

Combined Loss:

\[\mathcal{L} = \mathcal{L}_{CD} + \lambda_1 \mathbb{E}[(E(x))^2] + \lambda_2 H[p_\theta]\]

10. Evaluation MetricsΒΆ

Inception Score (IS)ΒΆ

Definition:

\[IS = \exp\left(\mathbb{E}_x [KL(p(y|x) \| p(y))]\right)\]

Interpretation:

  • High if samples are diverse (high H[p(y)])

  • High if samples are discriminative (low H[p(y|x)])

FrΓ©chet Inception Distance (FID)ΒΆ

Definition: Distance between data and generated distributions in Inception feature space:

\[FID = \|\mu_{data} - \mu_{gen}\|^2 + \text{Tr}(\Sigma_{data} + \Sigma_{gen} - 2(\Sigma_{data} \Sigma_{gen})^{1/2})\]

Lower is better: FID < 10 is excellent.

Log-Likelihood via Annealed Importance Sampling (AIS)ΒΆ

Method: Bridge between prior pβ‚€ (tractable) and target p_ΞΈ (intractable).

Estimator:

\[\log Z_\theta \approx \log Z_0 + \frac{1}{K} \sum_{k=1}^K \log w_k\]

where weights w_k come from intermediate distributions.

Use: Estimate log p_ΞΈ(x) = -E_ΞΈ(x) - log Z_ΞΈ.

11. State-of-the-Art ModelsΒΆ

EBM for Image GenerationΒΆ

Best Results (as of 2023):

  • JEM (Joint Energy-based Model): CIFAR-10 FID ~38

  • Conjugate Energy-Based Models: FID ~15-20

  • Improved SGLD: Competitive with GANs on small datasets

Limitation: Still lag behind diffusion models and GANs on large-scale generation.

EBM for Compositional ReasoningΒΆ

Key Advantage: Energy composition E_total = E₁ + Eβ‚‚ + … + E_n.

Applications:

  • Multi-attribute generation (combine β€œred” + β€œcar” + β€œconvertible”)

  • Concept combination without retraining

  • Out-of-distribution generalization

Example (Du et al., 2020): Compose classifiers as energy functions for zero-shot tasks.

12. Connections to Other ModelsΒΆ

Energy vs Score-Based ModelsΒΆ

Aspect

Energy-Based

Score-Based

Core function

E(x)

s(x) = βˆ‡log p(x)

Relation

s(x) = -βˆ‡E(x)

E(x) = -log p(x) + C

Partition

Explicit challenge

Avoided

Training

CD, NCE

Score matching

Sampling

MCMC

Langevin, SDE

Energy vs GANsΒΆ

Similarities:

  • Implicit generation (no explicit p(x))

  • Adversarial training dynamics

Differences:

  • EBM: Single network + MCMC

  • GAN: Two networks (G, D), no MCMC

  • EBM: Energy landscape interpretation

  • GAN: Min-max game interpretation

Energy vs Normalizing FlowsΒΆ

Flows: Explicit density via invertible transforms.

  • Exact likelihood

  • Single forward/backward pass

  • Architectural constraints

EBMs: Flexible energy functions.

  • Intractable likelihood

  • MCMC sampling required

  • Any architecture

13. ApplicationsΒΆ

Image SynthesisΒΆ

Method: Train EBM on images, sample via SGLD.

Enhancements:

  • Multi-scale generation

  • Hierarchical sampling

  • Classifier guidance

DenoisingΒΆ

Approach: Energy as denoising autoencoder objective.

Advantage: Direct connection to score matching.

Anomaly DetectionΒΆ

Principle: Anomalies have high energy.

Method:

1. Train EBM on normal data
2. Threshold energy: E(x) > Ο„ β†’ anomaly

Benefit: No need for anomaly examples.

3D Shape GenerationΒΆ

Energy over Point Clouds:

\[E(X) = E_{local}(X) + E_{global}(X)\]

Sampling: MCMC over point positions.

Application: Generate novel 3D objects, complete partial scans.

Molecule DesignΒΆ

Energy as Chemical Property Predictor:

\[E(m) = -\text{stability}(m) - \text{solubility}(m) + \text{toxicity}(m)\]

Optimization: MCMC or gradient-based search in molecular space.

14. Practical ConsiderationsΒΆ

Hyperparameter TuningΒΆ

Parameter

Typical Range

Effect

Learning rate

1e-5 to 1e-3

Higher β†’ faster but unstable

SGLD steps

20-200

More β†’ better samples, slower

SGLD step size

0.1-10.0

Larger β†’ faster mixing

CD-k

k=1, 5, 10

Higher k β†’ better gradient

PCD chains

100-1000

More β†’ lower variance

Computational CostΒΆ

Training:

  • Forward pass: O(N) (N = network parameters)

  • MCMC sampling: O(KΒ·N) (K = SGLD steps)

  • Typical: 5-10x slower than GANs

Sampling:

  • Single sample: O(KΒ·N), K ~ 100-1000

  • Diffusion models: Often faster (20-50 steps)

Debugging TipsΒΆ

Check:

  • Energy distribution: Should separate data vs random

  • SGLD trajectory: Should converge to low energy

  • Gradient norms: Clip if exploding

  • Sample diversity: Avoid mode collapse

Common Issues:

  • Energy collapse: All samples have same energy β†’ use regularization

  • Poor mixing: SGLD stuck β†’ tune step size, use HMC

  • Slow convergence: β†’ PCD, spectral normalization

15. LimitationsΒΆ

Computational CostΒΆ

Sampling: MCMC is slow, especially in high dimensions.

  • 100-1000 gradient evaluations per sample

  • Diffusion models: 10-50 steps

  • GANs: 1 forward pass

Mode CoverageΒΆ

Challenge: MCMC may not explore all modes in limited steps.

Solutions:

  • Parallel tempering

  • Multiple chains

  • Hybrid MCMC methods

Training InstabilityΒΆ

Issues:

  • Energy can diverge

  • MCMC chains may not converge

  • Gradient variance high with CD

Mitigations:

  • Spectral normalization

  • Gradient clipping

  • PCD instead of CD

Theoretical GapsΒΆ

Open Questions:

  • When does CD-k converge?

  • How many MCMC steps sufficient?

  • Optimal architecture for energy function?

16. Recent Advances (2020-2024)ΒΆ

Diffusion Recovery Likelihood (DRL)ΒΆ

Idea: Combine diffusion and EBM for hybrid model.

Method: Use diffusion to initialize, EBM to refine.

Continuous-Time EBMsΒΆ

Formulation: Energy evolves continuously via SDE.

Advantage: Unifies discrete (CD) and continuous (Langevin) perspectives.

Energy-Based Priors for Inverse ProblemsΒΆ

Application: Solve inverse problems with EBM prior:

\[\arg\min_x E_{prior}(x) + \|Ax - y\|^2\]

Use Cases: MRI reconstruction, deblurring, super-resolution.

Self-Supervised EBMsΒΆ

Training: Use contrastive learning objectives (SimCLR, MoCo) as energy.

Benefit: Leverage large unlabeled datasets.

17. Key PapersΒΆ

FoundationsΒΆ

  • Hinton (2002): β€œTraining Products of Experts by Minimizing Contrastive Divergence”

  • Tieleman (2008): β€œTraining Restricted Boltzmann Machines using Approximations to the Likelihood Gradient”

Modern EBMsΒΆ

  • Du & Mordatch (2019): β€œImplicit Generation and Modeling with Energy Based Models”

  • Grathwohl et al. (2020): β€œYour Classifier is Secretly an Energy Based Model (JEM)”

  • Nijkamp et al. (2020): β€œLearning Energy-Based Models by Diffusion Recovery Likelihood”

Theoretical AnalysisΒΆ

  • Bengio et al. (2013): β€œEstimating or Propagating Gradients Through Stochastic Neurons”

  • Song & Kingma (2021): β€œHow to Train Your Energy-Based Models”

ApplicationsΒΆ

  • Du et al. (2020): β€œEnergy-Based Models for Atomic-Resolution Protein Conformations”

  • Xie et al. (2021): β€œLearning Energy-Based Models in High-Dimensional Spaces with Multi-Scale Denoising”

18. Comparison: EBMs vs Other Generative ModelsΒΆ

Model

Training

Sampling

Likelihood

Flexibility

EBM

CD, NCE, Score

MCMC (slow)

Intractable

High

GAN

Min-max

Fast (1 pass)

No

High

VAE

ELBO

Fast (1 pass)

Approx.

Medium

Flow

Exact MLE

Fast (1 pass)

Exact

Low

Diffusion

Score match

Medium (20-50)

Via ODE

High

When to use EBMs:

  • Compositional generation (energy addition)

  • Flexible energy functions needed

  • Theoretical interpretability important

  • Small-scale tasks (computational cost acceptable)

When to avoid:

  • Need fast sampling (use GANs or flows)

  • Large-scale generation (use diffusion models)

  • Exact likelihood required (use flows or VAEs)

# Advanced Energy-Based Models Implementations

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple, List

class SpectralNormConv2d(nn.Module):
    """Conv2d whose weight is divided by its largest singular value.

    Spectral normalization constrains the layer's Lipschitz constant,
    which stabilizes EBM training and prevents energy collapse. The
    leading singular vectors are tracked with power iteration and stored
    in buffers so the estimates persist across steps and checkpoints.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, n_iter=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.n_iter = n_iter  # power-iteration steps per training forward

        # Running estimates of the left/right singular vectors of the
        # flattened (out_channels x in_channels*k*k) weight matrix.
        self.register_buffer('u', torch.randn(1, out_channels))
        self.register_buffer('v', torch.randn(1, in_channels * kernel_size * kernel_size))

    def forward(self, x):
        """Apply the convolution with a spectrally normalized weight."""
        # Flatten the kernel to a 2-D matrix for the power iteration.
        weight = self.conv.weight.view(self.conv.out_channels, -1)

        if self.training:
            # Power iteration runs outside the autograd graph: u and v
            # are treated as constants w.r.t. the weight, as in standard
            # spectral-norm implementations.
            with torch.no_grad():
                u, v = self.u, self.v
                for _ in range(self.n_iter):
                    v = F.normalize(u @ weight, dim=1, eps=1e-12)
                    u = F.normalize(v @ weight.t(), dim=1, eps=1e-12)
                self.u.copy_(u)
                self.v.copy_(v)

        # Keep sigma as a tensor (previously `.item()` detached it),
        # so gradients flow through the normalization denominator.
        sigma = self.u @ weight @ self.v.t()

        # Normalize the weight by its estimated spectral norm.
        weight_sn = weight / (sigma + 1e-12)
        weight_sn = weight_sn.view_as(self.conv.weight)

        # Apply the convolution with the normalized weight.
        return F.conv2d(x, weight_sn, self.conv.bias,
                        self.conv.stride, self.conv.padding)


class EnergyFunction(nn.Module):
    """
    Convolutional energy function mapping x to a scalar energy E(x).

    Four stride-2 convolutions halve the spatial resolution each time,
    followed by a small MLP head that emits one energy value per sample.
    The linear head expects 2x2 feature maps (e.g. 32x32 inputs for the
    default widths). Convolutions are optionally spectrally normalized
    to enforce a Lipschitz constraint.
    """

    def __init__(self, input_channels=1, base_channels=64, use_spectral_norm=True):
        super().__init__()

        conv_cls = SpectralNormConv2d if use_spectral_norm else nn.Conv2d

        # Channel progression: in -> c -> 2c -> 4c -> 8c.
        widths = [input_channels] + [base_channels * m for m in (1, 2, 4, 8)]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(conv_cls(c_in, c_out, 4, 2, 1))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)

        # MLP head: flattened 2x2 feature maps -> single energy value.
        self.energy_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(base_channels * 8 * 2 * 2, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """Return per-sample energies of shape (B,)."""
        return self.energy_head(self.encoder(x)).squeeze(-1)


class ContrastiveDivergenceTrainer:
    """
    Contrastive Divergence (CD-k) trainer for an energy function.

    Each step lowers the energy of data samples and raises the energy of
    negatives produced by k SGLD steps from uniform noise, with a
    squared-energy penalty to keep both energies bounded.
    """

    def __init__(self, energy_fn, optimizer, k_steps=1, sgld_lr=10.0,
                 sgld_noise=True, energy_reg=0.001):
        self.energy_fn = energy_fn
        self.optimizer = optimizer
        self.k_steps = k_steps        # MCMC steps per negative sample
        self.sgld_lr = sgld_lr        # SGLD step size epsilon
        self.sgld_noise = sgld_noise  # whether to inject Gaussian noise
        self.energy_reg = energy_reg  # weight of the squared-energy term

    def sgld_step(self, x, add_noise=True):
        """One SGLD update: x - (eps/2)*grad E(x) + sqrt(eps)*z."""
        x = x.clone().requires_grad_(True)

        # Gradient of the total batch energy w.r.t. the samples.
        total_energy = self.energy_fn(x).sum()
        (grad,) = torch.autograd.grad(total_energy, x)

        updated = x - 0.5 * self.sgld_lr * grad
        if add_noise and self.sgld_noise:
            updated = updated + torch.randn_like(x) * np.sqrt(self.sgld_lr)

        # Keep iterates in the valid pixel range.
        return torch.clamp(updated.detach(), 0, 1)

    def sample_negative(self, x_init):
        """Advance a chain by k SGLD steps to produce negative samples."""
        chain = x_init.clone()
        for _ in range(self.k_steps):
            chain = self.sgld_step(chain)
        return chain

    def train_step(self, x_pos):
        """One CD update on a data batch; returns loss/energy stats."""
        # Positive phase: energies of real data.
        energy_pos = self.energy_fn(x_pos)

        # Negative phase: SGLD chains initialized from uniform noise.
        x_neg = self.sample_negative(torch.rand_like(x_pos))
        energy_neg = self.energy_fn(x_neg)

        # CD objective: widen the gap between data and model samples.
        cd_loss = energy_pos.mean() - energy_neg.mean()

        # Regularization against unbounded energies.
        reg_loss = (energy_pos ** 2).mean() + (energy_neg ** 2).mean()

        loss = cd_loss + self.energy_reg * reg_loss

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.energy_fn.parameters(), 1.0)
        self.optimizer.step()

        return {
            'loss': loss.item(),
            'energy_pos': energy_pos.mean().item(),
            'energy_neg': energy_neg.mean().item()
        }


class PersistentContrastiveDivergence:
    """
    PCD trainer: keeps a pool of persistent SGLD chains ("fantasy
    particles") that continue across training steps rather than being
    re-initialized, yielding better-mixed negatives with lower variance.
    """

    def __init__(self, energy_fn, optimizer, n_persistent=100, k_steps=1,
                 sgld_lr=10.0, energy_reg=0.001):
        self.energy_fn = energy_fn
        self.optimizer = optimizer
        self.k_steps = k_steps        # SGLD steps per training step
        self.sgld_lr = sgld_lr        # SGLD step size epsilon
        self.energy_reg = energy_reg  # squared-energy penalty weight

        # Chain pool is allocated lazily once the data shape is known
        # (see train_step).
        self.persistent_chains = None
        self.n_persistent = n_persistent

    def initialize_chains(self, shape):
        """Allocate the chain pool with uniform noise in [0, 1)."""
        self.persistent_chains = torch.rand(self.n_persistent, *shape)

    def sgld_step(self, x):
        """One SGLD update: x - (eps/2)*grad E(x) + sqrt(eps)*z."""
        x = x.clone().requires_grad_(True)

        total_energy = self.energy_fn(x).sum()
        (grad,) = torch.autograd.grad(total_energy, x)

        updated = x - 0.5 * self.sgld_lr * grad
        updated = updated + torch.randn_like(x) * np.sqrt(self.sgld_lr)

        # Keep iterates in the valid pixel range.
        return torch.clamp(updated.detach(), 0, 1)

    def train_step(self, x_pos):
        """One PCD update on a data batch; returns loss/energy stats."""
        batch_size = x_pos.size(0)

        # Lazily build the chain pool from the observed sample shape.
        if self.persistent_chains is None:
            self.initialize_chains(x_pos.shape[1:])

        # Keep the pool on the same device as the data.
        self.persistent_chains = self.persistent_chains.to(x_pos.device)

        # Draw a random subset of chains as this step's negatives.
        indices = torch.randperm(self.n_persistent)[:batch_size]
        x_neg = self.persistent_chains[indices].clone()

        # Advance the selected chains by k SGLD steps.
        for _ in range(self.k_steps):
            x_neg = self.sgld_step(x_neg)

        # Write the advanced states back so the chains persist.
        self.persistent_chains[indices] = x_neg.detach()

        energy_pos = self.energy_fn(x_pos)
        energy_neg = self.energy_fn(x_neg)

        # CD objective plus regularization against unbounded energies.
        cd_loss = energy_pos.mean() - energy_neg.mean()
        reg_loss = (energy_pos ** 2).mean() + (energy_neg ** 2).mean()

        loss = cd_loss + self.energy_reg * reg_loss

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.energy_fn.parameters(), 1.0)
        self.optimizer.step()

        return {
            'loss': loss.item(),
            'energy_pos': energy_pos.mean().item(),
            'energy_neg': energy_neg.mean().item()
        }


class ScoreMatchingEBM:
    """
    Denoising score matching (DSM) trainer for an energy function.

    Trains E_theta so the model score -grad_x E_theta(x_noisy) matches
    the score of the Gaussian noising kernel, -noise/sigma^2. Requires
    no MCMC during training and avoids the partition function.
    """

    def __init__(self, energy_fn, optimizer, sigma=0.1):
        self.energy_fn = energy_fn
        self.optimizer = optimizer
        self.sigma = sigma  # std of the Gaussian corruption

    def train_step(self, x_clean):
        """One DSM step on a batch of clean samples.

        Args:
            x_clean: batch tensor of shape (B, ...); any number of
                trailing feature dimensions is supported (previously
                only 4-D image batches were accepted).

        Returns:
            dict with the scalar training loss.
        """
        # Corrupt the data with Gaussian noise.
        noise = torch.randn_like(x_clean) * self.sigma
        x_noisy = x_clean + noise

        # Predicted score: -grad_x E(x). create_graph=True so the loss
        # can backprop through this gradient computation.
        x_noisy_grad = x_noisy.clone().requires_grad_(True)
        energy = self.energy_fn(x_noisy_grad).sum()
        score_pred = -torch.autograd.grad(energy, x_noisy_grad, create_graph=True)[0]

        # Target score of the Gaussian kernel: -noise / sigma^2.
        score_target = -noise / (self.sigma ** 2)

        # Sum the squared error over all non-batch dimensions (works for
        # any input rank, not just NCHW), then average over the batch.
        sq_err = (score_pred - score_target) ** 2
        loss = 0.5 * sq_err.flatten(start_dim=1).sum(dim=1).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'loss': loss.item()}


class RestrictedBoltzmannMachine(nn.Module):
    """
    Bernoulli-Bernoulli RBM: p(v,h) = exp(-E(v,h))/Z with
    E(v,h) = -v^T W h - b^T v - c^T h.

    Trained with the CD-k update rule (Hinton, 2002).
    """

    def __init__(self, n_visible, n_hidden):
        """
        Args:
            n_visible: number of visible (input) units.
            n_hidden: number of hidden units.
        """
        super().__init__()

        # Small random weights, zero biases (standard RBM initialization).
        self.W = nn.Parameter(torch.randn(n_visible, n_hidden) * 0.01)
        self.b = nn.Parameter(torch.zeros(n_visible))
        self.c = nn.Parameter(torch.zeros(n_hidden))

    def energy(self, v, h):
        """Per-sample energy E(v,h) for paired batches v (B,V) and h (B,H).

        Computed with an elementwise product instead of the previous
        (v @ W @ h.t()).diagonal(), which materialized a full BxB matrix
        (O(B^2) time/memory) just to read its diagonal.
        """
        interaction = -((v @ self.W) * h).sum(dim=1)
        visible_bias = -(v @ self.b)
        hidden_bias = -(h @ self.c)
        return interaction + visible_bias + hidden_bias

    def sample_h_given_v(self, v):
        """Sample hidden units: p(h=1|v) = sigmoid(c + W^T v).

        Returns (binary sample, activation probabilities).
        """
        prob_h = torch.sigmoid(v @ self.W + self.c)
        h = torch.bernoulli(prob_h)
        return h, prob_h

    def sample_v_given_h(self, h):
        """Sample visible units: p(v=1|h) = sigmoid(b + W h).

        Returns (binary sample, activation probabilities).
        """
        prob_v = torch.sigmoid(h @ self.W.t() + self.b)
        v = torch.bernoulli(prob_v)
        return v, prob_v

    def gibbs_step(self, v):
        """One alternating Gibbs step v -> h -> v'; returns (v', p(v'), h, p(h))."""
        h, prob_h = self.sample_h_given_v(v)
        v_recon, prob_v = self.sample_v_given_h(h)
        return v_recon, prob_v, h, prob_h

    def cd_update(self, k=1, lr=0.01, *, v_data=None):
        """CD-k manual parameter update from a batch of binary visible data.

        Args:
            v_data: (B, n_visible) batch of {0,1} visible vectors.
            k: number of Gibbs steps for the negative phase (must be >= 1).
            lr: learning rate for the in-place gradient-ascent update.

        Returns:
            dict with 'recon_error' (mean squared reconstruction error).

        Raises:
            ValueError: if k < 1 (the negative-phase statistics would be
                undefined; previously this failed with a NameError).
        """
        if k < 1:
            raise ValueError(f"CD-k requires k >= 1, got k={k}")
        batch_size = v_data.size(0)

        # Positive phase: hidden statistics driven by the data. The binary
        # sample is discarded; probabilities give lower-variance gradients.
        _, prob_h0 = self.sample_h_given_v(v_data)

        # Negative phase: k steps of alternating Gibbs sampling from the data.
        v_k = v_data.clone()
        for _ in range(k):
            v_k, prob_v_k, h_k, prob_h_k = self.gibbs_step(v_k)

        # CD gradient: <v h>_data - <v h>_model, normalized by batch size.
        positive_grad = (v_data.t() @ prob_h0) / batch_size
        negative_grad = (prob_v_k.t() @ prob_h_k) / batch_size

        grad_W = positive_grad - negative_grad
        grad_b = (v_data - prob_v_k).mean(0)
        grad_c = (prob_h0 - prob_h_k).mean(0)

        # Gradient ASCENT on the CD approximation of the log-likelihood.
        with torch.no_grad():
            self.W += lr * grad_W
            self.b += lr * grad_b
            self.c += lr * grad_c

        # Monitoring signal (not the training objective).
        recon_error = ((v_data - prob_v_k) ** 2).mean().item()

        return {'recon_error': recon_error}


class LangevinSampler:
    """
    Unadjusted Langevin sampler for an energy model:
    x_{t+1} = x_t - (eps/2) * grad E(x_t) + sqrt(eps) * z_t.

    Each iterate is clamped into `clip_range`; with `noise=False` the update
    degenerates to plain gradient descent on the energy.
    """

    def __init__(self, energy_fn, step_size=1.0, n_steps=100,
                 noise=True, clip_range=(0, 1)):
        self.energy_fn = energy_fn
        self.step_size = step_size
        self.n_steps = n_steps
        self.noise = noise
        self.clip_range = clip_range

    def sample(self, x_init):
        """Run `n_steps` of Langevin dynamics from `x_init`.

        Returns (final state, list of iterates incl. x_init, per-step energies).
        """
        state = x_init.clone()

        path = [state.clone()]
        energy_log = []

        for _ in range(self.n_steps):
            state = state.requires_grad_(True)

            # Energy of the current iterate and its gradient w.r.t. x.
            total_energy = self.energy_fn(state).sum()
            drift = torch.autograd.grad(total_energy, state)[0]

            # Deterministic drift toward low-energy regions...
            state = state - 0.5 * self.step_size * drift

            # ...plus optional diffusion noise for proper exploration.
            if self.noise:
                state = state + torch.randn_like(state) * np.sqrt(self.step_size)

            # Project back into the valid data range and drop the graph.
            state = torch.clamp(state.detach(), *self.clip_range)

            path.append(state.clone())
            energy_log.append(total_energy.item())

        return state, path, energy_log


class HamiltonianMonteCarlo:
    """
    Hamiltonian Monte Carlo with leapfrog integration and a
    Metropolis-Hastings accept/reject step. Mixes better than plain
    Langevin dynamics at the cost of extra gradient evaluations.
    """

    def __init__(self, energy_fn, step_size=0.1, n_leapfrog=10):
        self.energy_fn = energy_fn
        self.step_size = step_size
        self.n_leapfrog = n_leapfrog

    def hamiltonian(self, x, v):
        """Total energy H(x,v) = E(x) + 0.5 * ||v||^2 (unit mass matrix)."""
        potential = self.energy_fn(x).sum()
        kinetic = 0.5 * (v ** 2).sum()
        return potential + kinetic

    def _potential_grad(self, x):
        """Gradient of the potential E at x (fresh leaf, graph not retained)."""
        leaf = x.clone().requires_grad_(True)
        total = self.energy_fn(leaf).sum()
        return torch.autograd.grad(total, leaf)[0]

    def leapfrog_step(self, x, v):
        """One leapfrog step: half momentum kick, full position drift, half kick."""
        v = v - 0.5 * self.step_size * self._potential_grad(x)
        x = x + self.step_size * v
        v = v - 0.5 * self.step_size * self._potential_grad(x)
        return x.detach(), v.detach()

    def sample_step(self, x):
        """One HMC transition from x; returns (next_state, accepted_flag)."""
        # Resample momentum from a standard normal.
        v = torch.randn_like(x)

        # Hamiltonian at the current point.
        H_current = self.hamiltonian(x, v)

        # Simulate the dynamics for n_leapfrog leapfrog steps.
        x_new, v_new = x.clone(), v.clone()
        for _ in range(self.n_leapfrog):
            x_new, v_new = self.leapfrog_step(x_new, v_new)

        H_proposed = self.hamiltonian(x_new, v_new)

        # Metropolis-Hastings: accept with prob min(1, exp(H_cur - H_prop)),
        # which corrects for leapfrog discretization error.
        accept_prob = torch.exp(H_current - H_proposed).item()

        if np.random.rand() < min(1.0, accept_prob):
            return x_new, True
        else:
            return x, False


# ============================================================================
# Demonstrations
# ============================================================================
# Smoke-tests and reference tables for the classes defined above. Fix: the
# placeholder-free f-strings (ruff F541) are now plain string literals —
# printed output is unchanged.

print("=" * 60)
print("Energy-Based Models - Advanced Implementations")
print("=" * 60)

# 1. Energy function with spectral normalization: forward a random batch
#    and show that the network maps images to per-sample scalar energies.
print("\n1. Energy Function:")
energy_fn = EnergyFunction(input_channels=1, base_channels=32, use_spectral_norm=True)
x_test = torch.randn(4, 1, 32, 32)
energy_test = energy_fn(x_test)
print(f"   Input shape: {x_test.shape}")
print(f"   Output (energies): {energy_test.shape}")
print(f"   Energy values: {energy_test.detach().numpy()}")

# 2. CD vs PCD comparison (no computation; summary of trade-offs).
print("\n2. Contrastive Divergence Methods:")
print("   CD-k: Restart MCMC from data each iteration")
print("   PCD: Maintain persistent fantasy particles")
print("   CD variance: Higher (fresh chains)")
print("   PCD variance: Lower (evolved chains)")
print("   Typical k: 1 (CD-1) or 5 (CD-5)")

# 3. RBM architecture: instantiate an MNIST-sized RBM (28*28 = 784 visible).
print("\n3. Restricted Boltzmann Machine:")
rbm = RestrictedBoltzmannMachine(n_visible=784, n_hidden=128)
print("   Visible units: 784")
print("   Hidden units: 128")
print(f"   Parameters: {784*128 + 784 + 128:,}")
print("   Energy: E(v,h) = -v^T W h - b^T v - c^T h")

# 4. Sampling comparison (summary only).
print("\n4. MCMC Sampling Methods:")
print("   Langevin: x ← x - Ξ΅/2Β·βˆ‡E + √Ρ·z")
print("   HMC: Augment with momentum, leapfrog integration")
print("   Advantages:")
print("     - Langevin: Simple, gradient-based")
print("     - HMC: Better mixing, fewer rejections")

# 5. Training methods comparison table.
print("\n5. Training Method Comparison:")
print("   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”")
print("   β”‚ Method          β”‚ Sampling β”‚ Stability β”‚ Speed       β”‚")
print("   β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€")
print("   β”‚ CD-k            β”‚ Yes      β”‚ Medium    β”‚ Slow        β”‚")
print("   β”‚ PCD             β”‚ Yes      β”‚ Higher    β”‚ Slow        β”‚")
print("   β”‚ Score Matching  β”‚ No       β”‚ High      β”‚ Fast        β”‚")
print("   β”‚ NCE             β”‚ Partial  β”‚ High      β”‚ Medium      β”‚")
print("   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")

# 6. When-to-use guide.
print("\n6. When to Use Each Method:")
print("   Use CD-1:")
print("     βœ“ Small models, limited compute")
print("     βœ“ Quick prototyping")
print("     βœ“ RBMs and simple architectures")
print("\n   Use PCD:")
print("     βœ“ Better samples needed")
print("     βœ“ Larger models")
print("     βœ“ Can maintain persistent chains")
print("\n   Use Score Matching:")
print("     βœ“ Avoid MCMC during training")
print("     βœ“ Faster training")
print("     βœ“ Partition function intractable")
print("\n   Use NCE:")
print("     βœ“ High-dimensional data")
print("     βœ“ Approximate partition function acceptable")
print("     βœ“ Connection to contrastive learning")

print("\n" + "=" * 60)

Advanced Energy-Based Models: Mathematical Foundations and Modern ArchitecturesΒΆ

1. Introduction to Energy-Based ModelsΒΆ

Energy-based models (EBMs) provide a unified framework for learning probability distributions by defining an energy function \(E_\theta(x)\) that assigns low energy to likely data points and high energy to unlikely ones. Unlike explicit generative models (GANs, VAEs, normalizing flows), EBMs model the data distribution implicitly through the Boltzmann distribution:

\[p_\theta(x) = \frac{e^{-E_\theta(x)}}{Z_\theta}\]

where \(Z_\theta = \int e^{-E_\theta(x)} dx\) is the partition function (intractable in general).

Key properties:

  • Flexibility: Energy function can be any neural network without architectural constraints

  • Expressiveness: Can model complex multimodal distributions

  • Compositionality: Energies can be added/composed for structured modeling

  • Implicit normalization: No need for explicit normalizing flows or sampling networks

Advantages over other generative models:

  • No generator network: Unlike GANs, no need for adversarial training

  • No encoder bottleneck: Unlike VAEs, no variational approximation

  • No invertibility constraints: Unlike normalizing flows, arbitrary architectures allowed

Challenges:

  • Intractable partition function: Computing \(Z_\theta\) requires integration over all possible \(x\)

  • Expensive sampling: MCMC methods slow for high-dimensional data

  • Training instability: Contrastive divergence and score matching can be unstable

2. Mathematical FrameworkΒΆ

2.1 Energy Function and Probability DistributionΒΆ

The energy function \(E_\theta: \mathcal{X} \to \mathbb{R}\) maps inputs to scalar energies. The probability distribution is:

\[p_\theta(x) = \frac{\exp(-E_\theta(x))}{Z_\theta}, \quad Z_\theta = \int_{\mathcal{X}} \exp(-E_\theta(x)) dx\]

Properties:

  • Lower energy β†’ higher probability: \(E_\theta(x_1) < E_\theta(x_2) \implies p_\theta(x_1) > p_\theta(x_2)\)

  • Energy is defined up to a constant: \(E_\theta(x) + c\) gives same distribution

  • Partition function ensures normalization: \(\int p_\theta(x) dx = 1\)

Log-likelihood:

\[\log p_\theta(x) = -E_\theta(x) - \log Z_\theta\]

Gradient with respect to parameters:

\[\nabla_\theta \log p_\theta(x) = -\nabla_\theta E_\theta(x) + \mathbb{E}_{x' \sim p_\theta}[\nabla_\theta E_\theta(x')]\]

The second term (gradient of log partition function) requires sampling from the model distribution, which is the main computational challenge.

2.2 Maximum Likelihood TrainingΒΆ

Given dataset \(\{x_i\}_{i=1}^N\), maximize log-likelihood:

\[\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^N \log p_\theta(x_i) = -\frac{1}{N} \sum_{i=1}^N E_\theta(x_i) - \log Z_\theta\]

Gradient:

\[\nabla_\theta \mathcal{L}(\theta) = -\frac{1}{N} \sum_{i=1}^N \nabla_\theta E_\theta(x_i) + \mathbb{E}_{x \sim p_\theta}[\nabla_\theta E_\theta(x)]\]

This is a positive-negative gradient:

  • Positive phase: Pull down energy on data samples

  • Negative phase: Push up energy on model samples

Challenge: Computing \(\mathbb{E}_{x \sim p_\theta}[\nabla_\theta E_\theta(x)]\) requires sampling from \(p_\theta\), which is intractable.

Solutions:

  1. Contrastive Divergence (CD): Approximate with short MCMC chains

  2. Score Matching: Avoid partition function entirely

  3. Noise Contrastive Estimation (NCE): Compare data vs noise distribution

3. Contrastive Divergence (CD)ΒΆ

3.1 AlgorithmΒΆ

Contrastive Divergence (Hinton, 2002) approximates the negative phase gradient by running short MCMC chains starting from data:

CD-k algorithm:

  1. Initialize \(x^{(0)} \sim p_{\text{data}}\) (data sample)

  2. Run \(k\) steps of Gibbs/Langevin sampling: \(x^{(k)} \sim p_\theta^{(k)}\)

  3. Approximate gradient:

\[\nabla_\theta \mathcal{L}(\theta) \approx -\nabla_\theta E_\theta(x^{(0)}) + \nabla_\theta E_\theta(x^{(k)})\]

Intuition:

  • \(x^{(0)}\) is data (low energy desired)

  • \(x^{(k)}\) is β€œfantasy” sample after \(k\) MCMC steps (high energy desired)

  • Gradient pushes down data energy, pushes up fantasy energy

Common choices:

  • CD-1: Single Gibbs/Langevin step (fast but biased)

  • CD-10: 10 steps (slower but less biased)

  • Persistent CD (PCD): Maintain persistent MCMC chains across batches

3.2 Langevin Dynamics SamplingΒΆ

For continuous data, use Langevin dynamics to sample from \(p_\theta(x)\):

\[x_{t+1} = x_t - \frac{\epsilon}{2} \nabla_x E_\theta(x_t) + \sqrt{\epsilon} \, z_t, \quad z_t \sim \mathcal{N}(0, I)\]

where:

  • \(\epsilon\) is the step size

  • \(\nabla_x E_\theta(x_t)\) is the energy gradient (drives to low energy regions)

  • \(\sqrt{\epsilon} \, z_t\) is noise (ensures proper exploration)

Convergence: As \(\epsilon \to 0\) and \(T \to \infty\), \(x_T \sim p_\theta(x)\).

Practical: Use finite \(\epsilon\) and \(T\) (e.g., \(\epsilon=0.01\), \(T=100\) steps).

3.3 Persistent Contrastive Divergence (PCD)ΒΆ

Maintain persistent MCMC chains \(\{x_i^{\text{chain}}\}\) across batches:

Algorithm:

  1. Initialize chains randomly

  2. Each iteration:

    • Sample data batch \(\{x_i^{\text{data}}\}\)

    • Update chains: \(x_i^{\text{chain}} \leftarrow\) Langevin step from \(x_i^{\text{chain}}\)

    • Gradient: \(\nabla_\theta \mathcal{L} \approx -\nabla_\theta E_\theta(x^{\text{data}}) + \nabla_\theta E_\theta(x^{\text{chain}})\)

Advantage: Chains better approximate \(p_\theta\) over time (less bias than CD-k).

Disadvantage: Chains can become β€œstuck” in local modes.

4. Score MatchingΒΆ

4.1 MotivationΒΆ

Score matching (HyvΓ€rinen, 2005) avoids the partition function by matching the score (gradient of log-density):

\[s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x) - \nabla_x \log Z_\theta = -\nabla_x E_\theta(x)\]

Note: \(\nabla_x \log Z_\theta = 0\) because \(Z_\theta\) doesn’t depend on \(x\).

Key insight: Score doesn’t require computing \(Z_\theta\)!

4.2 Explicit Score MatchingΒΆ

Match model score to data score:

\[\mathcal{L}_{\text{ESM}}(\theta) = \frac{1}{2} \mathbb{E}_{x \sim p_{\text{data}}}[\|\nabla_x \log p_\theta(x) - \nabla_x \log p_{\text{data}}(x)\|^2]\]

Problem: \(\nabla_x \log p_{\text{data}}(x)\) is unknown.

Solution: Integration by parts gives equivalent objective (HyvΓ€rinen, 2005):

\[\mathcal{L}_{\text{ESM}}(\theta) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\frac{1}{2}\|\nabla_x E_\theta(x)\|^2 + \text{tr}(\nabla_x^2 E_\theta(x))\right] + \text{const}\]

where \(\text{tr}(\nabla_x^2 E_\theta(x)) = \sum_{i=1}^D \frac{\partial^2 E_\theta(x)}{\partial x_i^2}\) is the trace of the Hessian.

Gradient:

\[\nabla_\theta \mathcal{L}_{\text{ESM}}(\theta) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\nabla_x E_\theta(x)^T \nabla_\theta \nabla_x E_\theta(x) + \nabla_\theta \text{tr}(\nabla_x^2 E_\theta(x))\right]\]

Computational cost: Requires computing Hessian trace (expensive for high-dimensional data).

4.3 Denoising Score Matching (DSM)ΒΆ

Vincent (2011) proposed a computationally efficient alternative:

Setup: Perturb data with noise \(q(x|x_0) = \mathcal{N}(x | x_0, \sigma^2 I)\).

Objective:

\[\mathcal{L}_{\text{DSM}}(\theta) = \frac{1}{2} \mathbb{E}_{x_0 \sim p_{\text{data}}} \mathbb{E}_{x \sim q(x|x_0)}\left[\left\|\nabla_x E_\theta(x) + \frac{x - x_0}{\sigma^2}\right\|^2\right]\]

Interpretation:

  • True score under noise perturbation: \(\nabla_x \log q(x|x_0) = -(x - x_0)/\sigma^2\)

  • Match model score \(-\nabla_x E_\theta(x)\) to this

Advantage: No Hessian computation! Only first-order gradient \(\nabla_x E_\theta(x)\).

Equivalence: DSM is equivalent to explicit score matching under certain conditions.

4.4 Sliced Score Matching (SSM)ΒΆ

Song et al. (2019) proposed sliced score matching to further reduce computational cost:

\[\mathcal{L}_{\text{SSM}}(\theta) = \frac{1}{2} \mathbb{E}_{x \sim p_{\text{data}}} \mathbb{E}_{v \sim p_v}\left[v^T \nabla_x^2 E_\theta(x) v + \frac{1}{2}(v^T \nabla_x E_\theta(x))^2\right]\]

where \(v \sim p_v\) is a random direction (e.g., \(p_v = \mathcal{N}(0, I)\)).

Advantage:

  • Hessian-vector product \(\nabla_x^2 E_\theta(x) v\) can be computed efficiently via automatic differentiation

  • Cost: \(O(D)\) instead of \(O(D^2)\) for full Hessian

Implementation: Use forward-mode AD or double backward pass.

5. Noise Contrastive Estimation (NCE)ΒΆ

5.1 PrincipleΒΆ

Noise Contrastive Estimation (Gutmann & HyvΓ€rinen, 2010) treats density estimation as binary classification:

Setup:

  • Data distribution: \(p_{\text{data}}(x)\)

  • Noise distribution: \(p_n(x)\) (e.g., uniform or Gaussian)

  • Sample ratio: \(\nu\) noise samples per data sample

Binary classification:

  • Label \(y=1\) for data: \(x \sim p_{\text{data}}\)

  • Label \(y=0\) for noise: \(x \sim p_n\)

Posterior:

\[P(y=1|x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + \nu p_n(x)} = \frac{p_\theta(x)}{p_\theta(x) + \nu p_n(x)} = \frac{e^{-E_\theta(x)}}{e^{-E_\theta(x)} + \nu Z_\theta p_n(x)}\]

Key trick: Treat \(\log Z_\theta\) as a learnable parameter \(c\):

\[P(y=1|x) = \frac{1}{1 + \nu e^{E_\theta(x) + c} p_n(x)} = \sigma(-E_\theta(x) - c - \log(\nu p_n(x)))\]

where \(\sigma\) is the sigmoid function.

5.2 NCE LossΒΆ

Binary cross-entropy:

\[\mathcal{L}_{\text{NCE}}(\theta, c) = -\mathbb{E}_{x \sim p_{\text{data}}}[\log P(y=1|x)] - \nu \mathbb{E}_{x \sim p_n}[\log P(y=0|x)]\]

Expanding:

\[\mathcal{L}_{\text{NCE}} = -\mathbb{E}_{x \sim p_{\text{data}}}\left[\log \sigma(-E_\theta(x) - c - \log(\nu p_n(x)))\right] - \nu \mathbb{E}_{x \sim p_n}\left[\log \sigma(E_\theta(x) + c + \log(\nu p_n(x)))\right]\]

Optimization: Jointly optimize \(\theta\) and \(c\) via gradient descent.

Advantage: Avoids computing partition function or sampling from model.

Disadvantage: Requires good noise distribution \(p_n\) (poor choice degrades performance).

6. Modern EBM ArchitecturesΒΆ

6.1 Joint Energy-Based Models (JEMs)ΒΆ

Grathwohl et al. (2020) proposed using a single network for both classification and generation:

Energy function: \(E_\theta(x, y) = -f_\theta(x)[y]\)

where \(f_\theta(x) \in \mathbb{R}^C\) is a classifier logit vector.

Distributions:

  • Conditional: \(p_\theta(y|x) = \frac{\exp(f_\theta(x)[y])}{\sum_{y'} \exp(f_\theta(x)[y'])}\) (standard classifier)

  • Marginal: \(p_\theta(x) = \frac{\sum_y \exp(f_\theta(x)[y])}{Z_\theta} = \frac{\exp(\text{LogSumExp}(f_\theta(x)))}{Z_\theta}\)

Training:

  1. Classification: Standard cross-entropy on \((x, y)\) pairs

  2. Generation: Contrastive divergence on marginal \(p_\theta(x)\)

Benefits:

  • Single network for both tasks

  • Improved robustness (adversarial examples have high energy)

  • Out-of-distribution detection (OOD samples have high energy)

6.2 IGEBM (Implicit Generation and Likelihood Estimation)ΒΆ

Du & Mordatch (2019) trained EBMs for high-resolution image generation:

Architecture: ResNet-based energy function \(E_\theta(x)\)

Training:

  1. Sample data: \(x^+ \sim p_{\text{data}}\)

  2. Sample negatives via Langevin: \(x^- \sim p_\theta\) (100-200 steps)

  3. Update: \(\mathcal{L} = E_\theta(x^+) - E_\theta(x^-)\) (hinge loss variant)

Improvements:

  • Spectral normalization: Stabilize training (constrain Lipschitz constant)

  • MCMC initialization: Initialize chains from replay buffer (mix of old samples + noise)

  • Multiscale training: Train on progressively higher resolutions

Results: Generated high-quality 128Γ—128 and 256Γ—256 images (comparable to GANs).

6.3 EBM for Compositional GenerationΒΆ

EBMs naturally support compositional reasoning via energy addition:

Concept: Combine multiple energy functions: \(p(x | c_1, c_2) \propto \exp(-E_1(x, c_1) - E_2(x, c_2))\)

Applications:

  • Multi-attribute generation: \(E(x, \text{color}, \text{shape})\)

  • Logical composition: AND (\(E_1 + E_2\)), OR (\(-\log(e^{-E_1} + e^{-E_2})\)), NOT (\(-E\))

  • Language-guided generation: \(E_{\text{img}}(x) + \lambda E_{\text{CLIP}}(x, \text{caption})\)

Example (Du et al., 2020):

  • Train separate classifiers \(f_1(x), f_2(x), \ldots\) for different attributes

  • Compose: \(p(x | y_1=1, y_2=1) \propto \exp(f_1(x)[1] + f_2(x)[1])\)

  • Sample via Langevin dynamics with combined energy

7. Score-Based Generative Models (Diffusion Models Connection)ΒΆ

7.1 Score-Based ModelsΒΆ

Song & Ermon (2019, 2020) proposed training models to estimate the score function directly:

Score network: \(s_\theta(x, t): \mathbb{R}^D \times \mathbb{R}_+ \to \mathbb{R}^D\)

Objective: Denoising score matching across noise levels: \[\mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}(0,T)} \mathbb{E}_{x_0 \sim p_{\text{data}}} \mathbb{E}_{x_t \sim q(x_t|x_0)}\left[\lambda(t) \left\|s_\theta(x_t, t) - \nabla_{x_t} \log q(x_t | x_0)\right\|^2\right]\]

where:

  • \(q(x_t | x_0) = \mathcal{N}(x_t | \alpha_t x_0, \sigma_t^2 I)\) is noise perturbation

  • \(\nabla_{x_t} \log q(x_t | x_0) = -(x_t - \alpha_t x_0)/\sigma_t^2\) is the true score

  • \(\lambda(t)\) is a weighting function

Sampling: Reverse-time SDE (Langevin dynamics): \[x_{t-\Delta t} = x_t + s_\theta(x_t, t) \Delta t + \sqrt{2\Delta t} \, z_t, \quad z_t \sim \mathcal{N}(0, I)\]

7.2 Connection to Diffusion ModelsΒΆ

Equivalence: Denoising diffusion probabilistic models (DDPMs) are discrete-time score-based models.

DDPM: \(p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1} | \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)

Reparameterization: Instead of predicting \(\mu_\theta\), predict noise \(\epsilon_\theta(x_t, t)\): \[\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right)\]

Score equivalence: \[s_\theta(x_t, t) = -\frac{\epsilon_\theta(x_t, t)}{\sigma_t}\]

Unified perspective (Song et al., 2021):

  • Both are trained via denoising score matching

  • Sampling: Ancestral sampling (DDPM) vs. SDE/ODE solvers (score-based)

  • State-of-the-art: Combine best of both (e.g., DDIM, DPM-Solver)

7.3 Noise ConditioningΒΆ

Challenge: Single noise level insufficient for high-quality generation.

Solution: Noise conditioning (multiple noise levels)

  • Perturb data at multiple scales: \(\sigma_1 > \sigma_2 > \cdots > \sigma_T\)

  • Train score network: \(s_\theta(x, \sigma_t)\) for all \(t\)

  • Sample: Gradually denoise from \(\sigma_1\) to \(\sigma_T\)

Annealed Langevin Dynamics:

x_T ~ N(0, Οƒ_1^2 I)
for t = 1 to T:
    for k = 1 to K:
        x_t ← x_t + Ξ΅_t s_ΞΈ(x_t, Οƒ_t) + sqrt(2Ξ΅_t) z_k
    x_{t-1} ← x_t
return x_0

Improvement: Coarse-to-fine generation (large noise captures global structure, small noise refines details).

8. Training Techniques and TricksΒΆ

8.1 Stability TechniquesΒΆ

Spectral Normalization:

  • Divide each layer’s weights by their largest singular value

  • Constrains Lipschitz constant: \(\|E_\theta(x_1) - E_\theta(x_2)\| \leq L \|x_1 - x_2\|\)

  • Stabilizes training, prevents energy function from becoming too steep

Gradient Clipping:

  • Clip Langevin gradient: \(\nabla_x E_\theta(x) \leftarrow \text{clip}(\nabla_x E_\theta(x), -c, c)\)

  • Prevents exploding gradients during sampling

Batch Normalization (use carefully):

  • Can help, but breaks translation equivariance

  • Alternative: Group normalization, Layer normalization

8.2 MCMC InitializationΒΆ

Replay Buffer:

  • Maintain buffer of past negative samples

  • Initialize new chains from buffer (95%) + noise (5%)

  • Accelerates convergence (chains start closer to \(p_\theta\))

Multi-scale Sampling:

  • Sample at low resolution first (faster MCMC)

  • Upsample to high resolution and refine

  • Reduces computational cost

8.3 Loss FunctionsΒΆ

Hinge Loss (Xie et al., 2016): \[\mathcal{L}_{\text{hinge}} = \mathbb{E}_{x^+ \sim p_{\text{data}}}[\max(0, E_\theta(x^+) - m^+)] + \mathbb{E}_{x^- \sim p_\theta}[\max(0, m^- - E_\theta(x^-))]\]

Regularization:

  • \(L_2\) penalty on energy: \(\mathbb{E}[E_\theta(x)^2]\) (prevents energy drift)

  • Gradient penalty: \(\mathbb{E}[\|\nabla_x E_\theta(x)\|^2]\) (smoothness)

9. ApplicationsΒΆ

9.1 Image GenerationΒΆ

IGEBM: High-resolution unconditional generation (256Γ—256 images)

JEM: Joint classification and generation (CIFAR-10, ImageNet-32)

Compositional generation: Combine attributes, logical operations

9.2 Out-of-Distribution DetectionΒΆ

Motivation: EBMs assign high energy to OOD samples.

Method:

  1. Train EBM on in-distribution data

  2. Compute energy \(E_\theta(x)\) for test sample

  3. Threshold: OOD if \(E_\theta(x) > \tau\)

Performance: Strong OOD detection on CIFAR-10 vs SVHN, CIFAR-100, etc.

9.3 Adversarial RobustnessΒΆ

Observation: Adversarial examples have higher energy than clean examples.

Robust training:

  1. Generate adversarial examples via PGD on \(E_\theta\)

  2. Train to assign high energy to adversarial examples

Results: JEM achieves better robust accuracy than standard classifiers.

9.4 Controllable GenerationΒΆ

Conditional sampling: \(p(x | y) \propto \exp(-E_\theta(x, y))\)

Attribute editing:

  • Modify energy function: \(E'(x) = E(x) + \lambda E_{\text{attr}}(x)\)

  • Sample via Langevin to find low-energy \(x\) satisfying attribute

Inpainting:

  • Observed pixels: \(x_O\)

  • Energy: \(E(x) + \frac{\lambda}{2}\|x_O - x_O^{\text{target}}\|^2\)

  • Sample to fill in missing pixels \(x_{\bar{O}}\)

9.5 Inverse ProblemsΒΆ

General formulation: \[p(x | y) \propto p(y | x) p(x) \propto \exp(-\|A(x) - y\|^2 / 2\sigma^2) \exp(-E_\theta(x))\]

Sample via Langevin: \[x_{t+1} = x_t - \frac{\epsilon}{2}\left[\nabla_x \|A(x_t) - y\|^2 / \sigma^2 + \nabla_x E_\theta(x_t)\right] + \sqrt{\epsilon} z_t\]

Examples:

  • Super-resolution: \(A\) is downsampling operator

  • Denoising: \(A\) is identity, \(y = x + \text{noise}\)

  • Compressed sensing: \(A\) is measurement matrix

10. Theoretical PropertiesΒΆ

10.1 ExpressivenessΒΆ

Universal approximation:

  • Sufficient capacity \(E_\theta\) can approximate any distribution

  • Proof: Neural networks are universal function approximators

Multimodality:

  • EBMs naturally handle multimodal distributions

  • Multiple local minima in energy landscape β†’ multiple modes

10.2 Convergence of MCMCΒΆ

Langevin dynamics convergence:

  • Under smoothness and strong convexity, Langevin converges to \(p_\theta\) exponentially fast

  • In practice: Non-convex energies, finite steps β†’ approximate samples

Mixing time:

  • Time for MCMC to converge to stationary distribution

  • Depends on energy landscape (flatter β†’ faster mixing)

10.3 Score Matching ConsistencyΒΆ

Theorem (HyvΓ€rinen, 2005): Score matching is consistent: \(\theta^* = \arg\min \mathcal{L}_{\text{ESM}}(\theta)\) satisfies \(p_{\theta^*} = p_{\text{data}}\) (up to parameter identifiability).

Proof sketch:

  • ESM objective is minimized when \(\nabla_x \log p_\theta(x) = \nabla_x \log p_{\text{data}}(x)\) for all \(x\)

  • Integration gives \(\log p_\theta(x) - \log p_{\text{data}}(x) = c\)

  • Normalization implies \(c = 0\), so \(p_\theta = p_{\text{data}}\)

11. Comparison with Other Generative ModelsΒΆ

Aspect

EBM

GAN

VAE

Flow

Diffusion

Training

CD/Score matching

Adversarial

ELBO

Exact likelihood

Denoising

Sampling

MCMC (slow)

Single pass (fast)

Single pass (fast)

Single pass (fast)

Iterative (slow)

Likelihood

Intractable

Intractable

Approximate

Exact

Tractable

Architecture

Flexible

Flexible

Encoder-decoder

Invertible

Flexible

Stability

Moderate

Low

High

High

High

Quality

High (with tricks)

High

Moderate

Moderate

Highest

Compositionality

βœ“ (energy addition)

βœ—

βœ—

βœ—

βœ“ (score addition)

OOD detection

βœ“

βœ—

βœ“

βœ“

βœ“

When to use EBMs:

  • Compositional reasoning required

  • Flexibility in architecture

  • OOD detection or robustness important

  • Can afford slow sampling

When to avoid:

  • Real-time generation needed

  • Limited computational budget

  • Standard generative tasks (diffusion models better)

12. Recent Advances (2020-2024)ΒΆ

12.1 Cooperative TrainingΒΆ

Cooperative Training (Xie et al., 2020):

  • Train EBM and generator jointly

  • Generator initializes MCMC chains (faster convergence)

  • EBM guides generator training

Algorithm:

  1. Generator: \(x^- \sim q_\phi(x)\)

  2. Short MCMC: \(x^- \leftarrow\) Langevin refinement

  3. Update generator: \(\phi \leftarrow \phi - \alpha \nabla_\phi \mathbb{E}_{x \sim q_\phi}[E_\theta(x)]\)

  4. Update EBM: \(\theta \leftarrow \theta - \beta(E_\theta(x^+) - E_\theta(x^-))\)

12.2 Flow Contrastive Estimation (FCEM)ΒΆ

Idea: Use normalizing flows as proposal for EBM.

Joint distribution: \[p_{\theta,\phi}(x) = \frac{1}{Z_\theta} e^{-E_\theta(x)} q_\phi(x)\]

where \(q_\phi\) is a flow with tractable density.

Advantage:

  • Flow provides good initialization for MCMC

  • Reduces MCMC steps needed

  • Combines flexibility (EBM) with tractability (flow)

12.3 Implicit Contrastive LearningΒΆ

Contrastive learning connection:

  • InfoNCE loss in contrastive learning is a form of NCE

  • Representation learning via EBM: \(E_\theta(x, x^+) = -\langle f_\theta(x), f_\theta(x^+) \rangle\)

Applications:

  • Self-supervised learning (SimCLR, MoCo)

  • Metric learning

  • Anomaly detection

12.4 Discrete EBMsΒΆ

Structured prediction:

  • Variables: \(x \in \mathcal{X}\) (discrete, e.g., graphs, sequences)

  • Energy: \(E_\theta(x)\) (e.g., GNN for graphs)

  • Inference: \(\arg\min_x E_\theta(x)\) (combinatorial optimization)

Applications:

  • Molecule generation (graph EBM)

  • Protein design

  • Circuit design

13. Implementation ConsiderationsΒΆ

13.1 HyperparametersΒΆ

Langevin sampling:

  • Step size: \(\epsilon \in [0.001, 0.1]\) (larger for faster sampling, smaller for stability)

  • Steps: \(T \in [20, 200]\) (more steps β†’ better samples but slower)

  • Noise temperature: \(\tau \in [0.5, 2.0]\) (controls exploration)

Training:

  • Learning rate: \(10^{-4}\) to \(10^{-3}\) (Adam optimizer)

  • Batch size: 128-256 (larger helps negative sampling)

  • Noise levels: 10-1000 (for score matching)

Replay buffer:

  • Size: 10,000 samples

  • Refresh rate: 5% new samples each batch

13.2 Architectural ChoicesΒΆ

Energy network:

  • ResNet (standard for images)

  • U-Net (score-based models)

  • Transformers (sequences, long-range dependencies)

Output:

  • Scalar energy (no activation)

  • NO softmax/sigmoid (energy can be any real value)

Normalization:

  • Spectral normalization (recommended)

  • Group normalization (better than BatchNorm for EBMs)

13.3 Debugging TipsΒΆ

Check energy landscape:

  • Data samples should have lower energy than noise

  • Plot energy histogram: data vs. MCMC samples

Verify gradients:

  • \(\nabla_x E_\theta(x)\) should point away from data manifold

  • Visualize energy gradients as vector field

Monitor MCMC quality:

  • Visual inspection of samples

  • Acceptance rate (for Metropolis-Hastings)

  • Effective sample size (ESS)

14. Benchmarks and ResultsΒΆ

14.1 Image Generation (CIFAR-10)ΒΆ

Model

FID ↓

IS ↑

Sampling Speed

IGEBM

38.2

6.02

100 steps (~10s)

JEM

40.5

6.8

100 steps (~10s)

DDPM

3.17

9.46

1000 steps (~50s)

GAN (StyleGAN2)

2.92

9.18

1 step (~0.1s)

Note: EBMs lag behind GANs/diffusion on standard metrics, but offer unique advantages (compositionality, OOD detection).

14.2 OOD Detection (CIFAR-10 vs SVHN)ΒΆ

Model

AUROC ↑

FPR@95TPR ↓

Softmax Baseline

0.890

0.421

ODIN

0.921

0.336

Mahalanobis

0.937

0.298

JEM (Energy)

0.964

0.182

Observation: Energy-based OOD detection significantly outperforms classifier-based methods.

14.3 Adversarial Robustness (CIFAR-10)ΒΆ

Model

Clean Acc

PGD-20 Acc

AutoAttack Acc

Standard ResNet

95.2%

0.0%

0.0%

Adversarial Training

84.7%

53.1%

48.2%

JEM + Adversarial

82.3%

56.4%

51.7%

Benefit: EBM training improves robustness by ~3-5% over standard adversarial training.

15. Limitations and Future DirectionsΒΆ

15.1 Current LimitationsΒΆ

Computational cost:

  • MCMC sampling expensive (100-1000Γ— slower than GANs)

  • Requires many energy evaluations per sample

Scalability:

  • Challenging for high-resolution images (512Γ—512+)

  • MCMC mixing poor in high dimensions

Training instability:

  • Contrastive divergence can diverge

  • Requires careful tuning of hyperparameters

15.2 Future DirectionsΒΆ

Faster sampling:

  • Learned MCMC samplers (neural samplers)

  • Hybrid models (flow + EBM, generator + EBM)

  • Amortized inference

Theoretical understanding:

  • Convergence guarantees for CD

  • Sample complexity analysis

  • Connection to diffusion models

Applications:

  • Scientific domains (protein folding, molecule design)

  • Structured prediction (graphs, programs)

  • Multimodal reasoning (vision + language)

Hardware acceleration:

  • Specialized chips for MCMC

  • Parallel tempering on GPUs

16. SummaryΒΆ

Key Takeaways:

  1. Flexibility: EBMs can use any architecture, no invertibility or generator constraints

  2. Compositionality: Energy functions naturally compose for structured reasoning

  3. Robustness: Strong OOD detection and adversarial robustness

  4. Training: Contrastive divergence (CD), score matching, NCE

  5. Sampling: Langevin dynamics (slow but flexible)

  6. Modern variants: JEM (joint classification/generation), IGEBM (high-res images), score-based (diffusion connection)

  7. Trade-offs: High quality and flexibility vs. computational cost

Recommended reading:

  • Hinton (2002): β€œTraining Products of Experts by Minimizing Contrastive Divergence”

  • HyvΓ€rinen (2005): β€œEstimation of Non-Normalized Statistical Models by Score Matching”

  • Du & Mordatch (2019): β€œImplicit Generation and Likelihood Estimation (IGEBM)”

  • Grathwohl et al. (2020): β€œYour Classifier is Secretly an Energy Based Model (JEM)”

  • Song & Ermon (2019): β€œGenerative Modeling by Estimating Gradients of the Data Distribution”

  • Song et al. (2021): β€œScore-Based Generative Modeling through SDEs”

When to use EBMs:

  • βœ“ Compositional reasoning required

  • βœ“ OOD detection or robustness critical

  • βœ“ Flexible architecture needed

  • βœ“ Slow sampling acceptable

When to avoid:

  • βœ— Real-time generation required

  • βœ— Limited computational budget

  • βœ— Standard image generation (use diffusion/GANs)

# ============================================================
# ADVANCED ENERGY-BASED MODELS: PRODUCTION IMPLEMENTATIONS
# Complete PyTorch implementations with modern training techniques
# ============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List, Dict
import math

# ============================================================
# 1. ENERGY NETWORK ARCHITECTURES
# ============================================================

class SpectralNorm:
    """Spectral normalization: constrains a layer's Lipschitz constant.

    Divides the wrapped module's weight by an estimate of its largest
    singular value (power iteration) before every forward pass. Construct
    via ``spectral_norm(module)``; the instance installs itself as a
    forward pre-hook on ``module``.

    Bug fixes vs. the previous version:
      * the hook was never registered, so normalization was a silent no-op;
      * assigning a plain Tensor back to a registered Parameter name raises
        TypeError in nn.Module — the trainable weight is now moved to
        ``_sn_{name}_orig`` and the normalized tensor exposed under ``name``;
      * ``v`` was sized by ``weight.size(1)``, which is wrong for conv
        weights whose trailing dims must be flattened.
    """

    def __init__(self, module: nn.Module, name: str = 'weight', power_iterations: int = 1):
        self.module = module
        self.name = name
        self.power_iterations = power_iterations

        weight = module._parameters[name]
        weight_mat = weight.view(weight.size(0), -1)

        # Power-iteration vectors sized for the 2-D reshaped weight.
        with torch.no_grad():
            u = F.normalize(torch.randn(weight_mat.size(0), device=weight.device), dim=0)
            v = F.normalize(torch.randn(weight_mat.size(1), device=weight.device), dim=0)
        self.register_buffer('u', u)
        self.register_buffer('v', v)

        # nn.Module refuses to rebind a registered Parameter to a plain
        # tensor, so keep the trainable weight under `_sn_{name}_orig` and
        # expose the (re)normalized weight under the original name.
        del module._parameters[name]
        module.register_parameter(f'_sn_{name}_orig', weight)
        setattr(module, name, weight.detach())

        # Install the hook so every forward pass renormalizes the weight.
        module.register_forward_pre_hook(self)

    def register_buffer(self, name: str, tensor: torch.Tensor):
        """Register a buffer on the wrapped module under an `_sn_` prefix."""
        self.module.register_buffer(f'_sn_{name}', tensor)

    def _compute_sigma(self, weight: torch.Tensor) -> torch.Tensor:
        """Estimate the largest singular value of ``weight`` by power iteration."""
        u = getattr(self.module, '_sn_u')
        v = getattr(self.module, '_sn_v')

        weight_mat = weight.view(weight.size(0), -1)

        # Power iteration runs grad-free; only the final sigma keeps a graph
        # through `weight` so the normalization stays differentiable.
        with torch.no_grad():
            for _ in range(self.power_iterations):
                v = F.normalize(torch.mv(weight_mat.t(), u), dim=0)
                u = F.normalize(torch.mv(weight_mat, v), dim=0)
            setattr(self.module, '_sn_u', u)
            setattr(self.module, '_sn_v', v)

        # sigma = u^T W v
        sigma = torch.dot(u, torch.mv(weight_mat, v))
        return sigma

    def __call__(self, module: nn.Module, inputs):
        """Forward pre-hook: recompute the normalized weight for this pass."""
        weight = getattr(module, f'_sn_{self.name}_orig')
        sigma = self._compute_sigma(weight)
        setattr(module, self.name, weight / sigma)
        return None

def spectral_norm(module: nn.Module, name: str = 'weight', power_iterations: int = 1) -> nn.Module:
    """Apply spectral normalization to a module.

    Constructs a SpectralNorm wrapper for ``module``'s parameter ``name``
    and returns the module itself, so calls can be chained inline
    (e.g. ``self.conv = spectral_norm(self.conv)``).

    NOTE(review): this is only effective if SpectralNorm installs itself as
    a forward pre-hook on ``module``; if it does not, the wrapper object
    built here is discarded and the weight is never actually normalized —
    confirm against the SpectralNorm implementation.
    """
    SpectralNorm(module, name, power_iterations)
    return module


class ResidualBlock(nn.Module):
    """Residual block for energy networks.

    Main path: conv -> GroupNorm -> SiLU -> conv -> GroupNorm; the input is
    added back through an (optionally strided 1x1) shortcut, followed by a
    final SiLU. Spectral normalization can be applied to every conv.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, use_spectral_norm: bool = True):
        super().__init__()

        # Two 3x3 convolutions; only the first may downsample.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if use_spectral_norm:
            self.conv1 = spectral_norm(self.conv1)
            self.conv2 = spectral_norm(self.conv2)

        # GroupNorm (capped at 32 groups) + SiLU ("swish") nonlinearity.
        groups = min(32, out_channels)
        self.gn1 = nn.GroupNorm(groups, out_channels)
        self.gn2 = nn.GroupNorm(groups, out_channels)
        self.swish = nn.SiLU()

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 projection conv matches them.
        if stride != 1 or in_channels != out_channels:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            self.shortcut = spectral_norm(projection) if use_spectral_norm else projection
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.swish(self.gn1(self.conv1(x)))
        h = self.gn2(self.conv2(h))
        return self.swish(h + self.shortcut(x))


class EnergyNetwork(nn.Module):
    """ResNet-based energy network with optional spectral normalization.

    Maps an image batch to scalar energies E_theta(x); lower energy
    corresponds to higher unnormalized likelihood.
    """

    def __init__(self, in_channels: int = 3, base_channels: int = 64,
                 num_blocks: Optional[List[int]] = None, use_spectral_norm: bool = True):
        """
        Args:
            in_channels: Number of input image channels.
            base_channels: Width of the first stage; doubled at each later stage.
            num_blocks: Residual blocks per stage; defaults to [2, 2, 2, 2].
                (Defaulting to None fixes the shared-mutable-default pitfall
                of ``num_blocks=[2, 2, 2, 2]``.)
            use_spectral_norm: Apply spectral norm to conv/linear layers.
        """
        super().__init__()
        if num_blocks is None:
            num_blocks = [2, 2, 2, 2]

        self.in_channels = in_channels

        # Initial convolution (no downsampling).
        self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if use_spectral_norm:
            self.conv1 = spectral_norm(self.conv1)
        self.gn1 = nn.GroupNorm(min(32, base_channels), base_channels)
        self.swish = nn.SiLU()

        # Four residual stages; stages 2-4 halve the spatial resolution.
        channels = base_channels
        self.layer1 = self._make_layer(channels, channels, num_blocks[0], stride=1, use_spectral_norm=use_spectral_norm)
        self.layer2 = self._make_layer(channels, channels * 2, num_blocks[1], stride=2, use_spectral_norm=use_spectral_norm)
        self.layer3 = self._make_layer(channels * 2, channels * 4, num_blocks[2], stride=2, use_spectral_norm=use_spectral_norm)
        self.layer4 = self._make_layer(channels * 4, channels * 8, num_blocks[3], stride=2, use_spectral_norm=use_spectral_norm)

        # Global pooling and scalar energy head.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels * 8, 1)
        if use_spectral_norm:
            self.fc = spectral_norm(self.fc)

    def _make_layer(self, in_channels: int, out_channels: int, num_blocks: int, stride: int,
                    use_spectral_norm: bool) -> nn.Sequential:
        """Build one stage: a (possibly strided) block plus unit-stride blocks."""
        layers = [ResidualBlock(in_channels, out_channels, stride, use_spectral_norm)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels, 1, use_spectral_norm))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, in_channels, H, W]

        Returns:
            energy: Scalar energy values [batch_size, 1]
        """
        out = self.swish(self.gn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.global_pool(out)
        out = out.view(out.size(0), -1)
        energy = self.fc(out)
        return energy


class JointEnergyModel(nn.Module):
    """Joint Energy-Based Model (JEM) for classification and generation.

    Reinterprets a classifier's logits f_theta(x) as energies:
        E(x, y) = -f_theta(x)[y]
        p(y|x)  = softmax(f_theta(x))
        p(x)    is proportional to exp(LogSumExp_y f_theta(x)[y])
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 10, base_channels: int = 64,
                 num_blocks: Optional[List[int]] = None, use_spectral_norm: bool = True):
        """
        Args:
            in_channels: Number of input image channels.
            num_classes: Number of output classes.
            base_channels: Width of the first stage; doubled at each later stage.
            num_blocks: Residual blocks per stage; defaults to [2, 2, 2, 2].
                (None default fixes the shared-mutable-default pitfall.)
            use_spectral_norm: Apply spectral norm to conv/linear layers.
        """
        super().__init__()
        if num_blocks is None:
            num_blocks = [2, 2, 2, 2]

        self.num_classes = num_classes

        # Feature extractor (ResNet), mirroring EnergyNetwork's backbone.
        self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if use_spectral_norm:
            self.conv1 = spectral_norm(self.conv1)
        self.gn1 = nn.GroupNorm(min(32, base_channels), base_channels)
        self.swish = nn.SiLU()

        # Four residual stages; stages 2-4 halve the spatial resolution.
        channels = base_channels
        self.layer1 = self._make_layer(channels, channels, num_blocks[0], stride=1, use_spectral_norm=use_spectral_norm)
        self.layer2 = self._make_layer(channels, channels * 2, num_blocks[1], stride=2, use_spectral_norm=use_spectral_norm)
        self.layer3 = self._make_layer(channels * 2, channels * 4, num_blocks[2], stride=2, use_spectral_norm=use_spectral_norm)
        self.layer4 = self._make_layer(channels * 4, channels * 8, num_blocks[3], stride=2, use_spectral_norm=use_spectral_norm)

        # Global pooling and class-logit head (doubles as the energy head).
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels * 8, num_classes)
        if use_spectral_norm:
            self.fc = spectral_norm(self.fc)

    def _make_layer(self, in_channels: int, out_channels: int, num_blocks: int, stride: int,
                    use_spectral_norm: bool) -> nn.Sequential:
        """Build one stage: a (possibly strided) block plus unit-stride blocks."""
        layers = [ResidualBlock(in_channels, out_channels, stride, use_spectral_norm)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels, 1, use_spectral_norm))
        return nn.Sequential(*layers)

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and return pooled features [batch_size, channels*8]."""
        out = self.swish(self.gn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.global_pool(out)
        out = out.view(out.size(0), -1)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, in_channels, H, W]

        Returns:
            logits: Class logits [batch_size, num_classes]
        """
        features = self.extract_features(x)
        logits = self.fc(features)
        return logits

    def energy(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the marginal energy E(x) or conditional energy E(x, y).

        Args:
            x: Input tensor [batch_size, in_channels, H, W]
            y: Optional class labels [batch_size]

        Returns:
            energy: Energy values [batch_size, 1]
        """
        logits = self.forward(x)

        if y is None:
            # Marginal energy: E(x) = -LogSumExp_y f_theta(x)[y]
            energy = -torch.logsumexp(logits, dim=1, keepdim=True)
        else:
            # Conditional energy: E(x, y) = -f_theta(x)[y]
            energy = -logits.gather(1, y.unsqueeze(1))

        return energy


# ============================================================
# 2. LANGEVIN DYNAMICS SAMPLING
# ============================================================

class LangevinSampler:
    """Langevin dynamics (SGLD) sampler for energy-based models.

    Update rule:
        x_{t+1} = x_t - (eps/2) * grad_x E_theta(x_t) + sqrt(eps) * z_t,
        z_t ~ N(0, I)

    Args:
        energy_model: Module mapping images to scalar energies. If it
            exposes an ``energy(x, y)`` method (as a joint energy-based
            model does), that method is used so conditional sampling works;
            otherwise the model's output is treated as the energy. This
            duck typing replaces a hard isinstance() dependency on a
            specific sibling class.
        step_size: Langevin step size eps.
        num_steps: Number of MCMC steps per call to ``sample``.
        noise_scale: Multiplier on the Gaussian noise term.
        clip_grad: If not None, element-wise clamp applied to the energy
            gradient (stabilizes sampling on untrained/sharp landscapes).
    """

    def __init__(self, energy_model: nn.Module, step_size: float = 0.01, num_steps: int = 100,
                 noise_scale: float = 1.0, clip_grad: Optional[float] = 0.1):
        self.energy_model = energy_model
        self.step_size = step_size
        self.num_steps = num_steps
        self.noise_scale = noise_scale
        self.clip_grad = clip_grad

    def sample(self, x_init: torch.Tensor, y: Optional[torch.Tensor] = None,
               verbose: bool = False) -> torch.Tensor:
        """Sample from p_theta(x) or p_theta(x|y) using Langevin dynamics.

        Args:
            x_init: Initial samples [batch_size, channels, H, W]
            y: Optional conditioning labels [batch_size]
            verbose: Print mean energy every 20 steps

        Returns:
            x: Final (detached) samples [batch_size, channels, H, W]
        """
        x = x_init.clone().detach().requires_grad_(True)

        for step in range(self.num_steps):
            # Bug fix: training loops call sample() inside torch.no_grad(),
            # under which no graph is recorded and torch.autograd.grad
            # raises. enable_grad() re-enables graph building locally
            # regardless of the caller's grad mode.
            with torch.enable_grad():
                if hasattr(self.energy_model, 'energy'):
                    # Conditional-capable model (e.g. a JEM-style classifier).
                    energy = self.energy_model.energy(x, y)
                else:
                    energy = self.energy_model(x)
                grad = torch.autograd.grad(energy.sum(), x, create_graph=False)[0]

            # Clip gradient element-wise for stability.
            if self.clip_grad is not None:
                grad = torch.clamp(grad, -self.clip_grad, self.clip_grad)

            # Langevin update: gradient descent on energy plus injected noise.
            noise = torch.randn_like(x) * self.noise_scale
            x = x - (self.step_size / 2) * grad + math.sqrt(self.step_size) * noise

            # Clamp to valid pixel range (e.g., [0, 1] for images).
            x = torch.clamp(x, 0, 1)
            x = x.detach().requires_grad_(True)

            if verbose and step % 20 == 0:
                print(f"Step {step}: Energy = {energy.mean().item():.3f}")

        return x.detach()


class ReplayBuffer:
    """Stores past MCMC samples so new chains can be warm-started from them
    (persistent contrastive divergence)."""

    def __init__(self, buffer_size: int = 10000, image_shape: Tuple[int, int, int] = (3, 32, 32)):
        self.buffer_size = buffer_size
        self.image_shape = image_shape
        self.buffer = []

    def push(self, samples: torch.Tensor):
        """Insert a batch of samples; once full, evict random entries."""
        for img in samples.detach().cpu():
            if len(self.buffer) < self.buffer_size:
                self.buffer.append(img)
            else:
                # Overwrite a uniformly chosen slot.
                slot = np.random.randint(0, self.buffer_size)
                self.buffer[slot] = img

    def sample(self, batch_size: int, device: torch.device, reinit_prob: float = 0.05) -> torch.Tensor:
        """Draw a batch, occasionally reinitializing chains from noise.

        Args:
            batch_size: Number of samples to draw.
            device: Device to place the batch on.
            reinit_prob: Per-sample probability of drawing fresh uniform
                noise instead of a buffer entry.

        Returns:
            samples: Batch of images [batch_size, channels, H, W]
        """
        drawn = []
        for _ in range(batch_size):
            # Fall back to noise when the buffer is empty; otherwise
            # reinitialize with probability `reinit_prob`.
            if not self.buffer or np.random.rand() < reinit_prob:
                drawn.append(torch.rand(self.image_shape))
            else:
                slot = np.random.randint(0, len(self.buffer))
                drawn.append(self.buffer[slot])
        return torch.stack(drawn).to(device)


# ============================================================
# 3. CONTRASTIVE DIVERGENCE TRAINING
# ============================================================

class ContrastiveDivergenceTrainer:
    """Trainer for EBMs using Contrastive Divergence.

    Loss: L = E_theta(x+) - E_theta(x-), with x+ ~ p_data and x- ~ p_theta
    obtained via MCMC (Langevin dynamics, warm-started from a replay buffer).

    Fixes vs. the previous version:
      * the Langevin call is no longer wrapped in torch.no_grad(), which
        made torch.autograd.grad fail inside sampling (no graph recorded);
        the sampler manages its own graph and returns detached samples;
      * the gradient penalty no longer flips the caller's x_pos to
        requires_grad in place;
      * energy L2 regularization covers BOTH phases (IGEBM-style), so
        negative energies cannot drift unboundedly.
    """

    def __init__(self, energy_model: nn.Module, langevin_sampler: 'LangevinSampler',
                 replay_buffer: 'ReplayBuffer', optimizer: optim.Optimizer,
                 energy_reg: float = 0.0, grad_reg: float = 0.0):
        """
        Args:
            energy_model: Module mapping images to scalar energies.
            langevin_sampler: Sampler providing ``sample(x_init)``.
            replay_buffer: Buffer providing ``sample(batch, device)``/``push``.
                (Sibling classes are referenced as string annotations so this
                class does not require them at definition time.)
            optimizer: Optimizer over the energy model's parameters.
            energy_reg: Weight of the L2 penalty on energy magnitudes.
            grad_reg: Weight of the input-gradient smoothness penalty.
        """
        self.energy_model = energy_model
        self.langevin_sampler = langevin_sampler
        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.energy_reg = energy_reg
        self.grad_reg = grad_reg

    def train_step(self, x_pos: torch.Tensor) -> Dict[str, float]:
        """Run one CD update.

        Args:
            x_pos: Positive samples (data) [batch_size, channels, H, W]

        Returns:
            metrics: Dict with 'cd_loss', 'energy_pos', 'energy_neg', 'loss'.
        """
        batch_size = x_pos.size(0)
        device = x_pos.device

        # Warm-start negative chains from the replay buffer.
        x_init = self.replay_buffer.sample(batch_size, device)

        # Run Langevin dynamics. No torch.no_grad() wrapper: the sampler
        # needs grad mode for autograd.grad and already returns detached
        # samples; .detach() here is a cheap belt-and-braces guard.
        x_neg = self.langevin_sampler.sample(x_init).detach()

        # Persist the new negatives for future chains.
        self.replay_buffer.push(x_neg)

        # Energies of both phases.
        energy_pos = self.energy_model(x_pos)
        energy_neg = self.energy_model(x_neg)

        # Contrastive divergence: push positives down, negatives up.
        cd_loss = energy_pos.mean() - energy_neg.mean()
        loss = cd_loss

        # L2 energy regularization on both phases keeps the energy scale
        # bounded (as in IGEBM, Du & Mordatch 2019).
        if self.energy_reg > 0:
            loss = loss + self.energy_reg * ((energy_pos ** 2).mean() + (energy_neg ** 2).mean())

        # Input-gradient smoothness penalty. Clone so the caller's tensor
        # is not mutated to requires_grad in place.
        if self.grad_reg > 0:
            x_gp = x_pos.detach().clone().requires_grad_(True)
            energy_gp = self.energy_model(x_gp)
            grad = torch.autograd.grad(energy_gp.sum(), x_gp, create_graph=True)[0]
            loss = loss + self.grad_reg * (grad ** 2).sum(dim=(1, 2, 3)).mean()

        # Backward and optimize.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        metrics = {
            'cd_loss': cd_loss.item(),
            'energy_pos': energy_pos.mean().item(),
            'energy_neg': energy_neg.mean().item(),
            'loss': loss.item()
        }

        return metrics


# ============================================================
# 4. SCORE MATCHING TRAINING
# ============================================================

class DenoisingScoreMatching:
    """Denoising score matching for EBMs.
    
    Loss: E_{x_0, x}[||βˆ‡_x E_ΞΈ(x) + (x - x_0)/σ²||Β²]
    where x = x_0 + Οƒ Ξ΅, Ξ΅ ~ N(0, I)
    """
    
    def __init__(self, energy_model: nn.Module, optimizer: optim.Optimizer, noise_std: float = 0.1):
        self.energy_model = energy_model
        self.optimizer = optimizer
        self.noise_std = noise_std
    
    def train_step(self, x_clean: torch.Tensor) -> Dict[str, float]:
        """Single training step.
        
        Args:
            x_clean: Clean data samples [batch_size, channels, H, W]
        
        Returns:
            metrics: Dictionary of training metrics
        """
        # Add noise
        noise = torch.randn_like(x_clean) * self.noise_std
        x_noisy = x_clean + noise
        
        # Compute energy gradient (score)
        x_noisy.requires_grad_(True)
        energy = self.energy_model(x_noisy)
        score = torch.autograd.grad(energy.sum(), x_noisy, create_graph=True)[0]
        
        # True score: βˆ‡_x log q(x|x_0) = -(x - x_0)/σ²
        true_score = -noise / (self.noise_std ** 2)
        
        # Score matching loss
        loss = 0.5 * ((score - true_score) ** 2).sum(dim=(1, 2, 3)).mean()
        
        # Backward and optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Metrics
        metrics = {
            'sm_loss': loss.item(),
            'score_norm': score.norm(dim=(1, 2, 3)).mean().item()
        }
        
        return metrics


class MultiScaleScoreMatching:
    """Multi-scale denoising score matching with noise conditioning.
    
    Trains score network s_ΞΈ(x, Οƒ) for multiple noise levels.
    """
    
    def __init__(self, energy_model: nn.Module, optimizer: optim.Optimizer, 
                 noise_levels: List[float] = [1.0, 0.5, 0.25, 0.1, 0.05]):
        self.energy_model = energy_model
        self.optimizer = optimizer
        self.noise_levels = noise_levels
    
    def train_step(self, x_clean: torch.Tensor) -> Dict[str, float]:
        """Single training step.
        
        Args:
            x_clean: Clean data samples [batch_size, channels, H, W]
        
        Returns:
            metrics: Dictionary of training metrics
        """
        batch_size = x_clean.size(0)
        
        # Sample noise level for each sample
        noise_level_idx = torch.randint(0, len(self.noise_levels), (batch_size,))
        noise_stds = torch.tensor([self.noise_levels[i] for i in noise_level_idx], 
                                   device=x_clean.device, dtype=x_clean.dtype)
        
        # Add noise
        noise = torch.randn_like(x_clean)
        x_noisy = x_clean + noise * noise_stds.view(-1, 1, 1, 1)
        
        # Compute energy gradient (score)
        x_noisy.requires_grad_(True)
        energy = self.energy_model(x_noisy)
        score = torch.autograd.grad(energy.sum(), x_noisy, create_graph=True)[0]
        
        # True score: -(x - x_0)/σ²
        true_score = -noise / noise_stds.view(-1, 1, 1, 1)
        
        # Weighted score matching loss (weight by σ²)
        weights = noise_stds ** 2
        loss = 0.5 * (weights.view(-1, 1, 1, 1) * (score - true_score) ** 2).sum(dim=(1, 2, 3)).mean()
        
        # Backward and optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Metrics
        metrics = {
            'msm_loss': loss.item(),
            'score_norm': score.norm(dim=(1, 2, 3)).mean().item()
        }
        
        return metrics


# ============================================================
# 5. JOINT ENERGY-BASED MODEL (JEM) TRAINING
# ============================================================

class JEMTrainer:
    """Trainer for Joint Energy-Based Models (classification + generation).

    Total loss: alpha * CrossEntropy(f(x), y) + beta * (E(x+) - E(x-)),
    with negatives drawn via MCMC from a replay buffer.

    Bug fix: the Langevin call is no longer wrapped in torch.no_grad(),
    which made torch.autograd.grad fail inside sampling (no graph is
    recorded under no_grad); the sampler returns detached samples.
    """

    def __init__(self, jem_model: 'JointEnergyModel', langevin_sampler: 'LangevinSampler',
                 replay_buffer: 'ReplayBuffer', optimizer: optim.Optimizer,
                 alpha: float = 1.0, beta: float = 1.0):
        """
        Args:
            jem_model: Joint energy-based model (callable for logits, with
                an ``energy(x, y=None)`` method). Sibling classes are
                referenced as string annotations so this class does not
                require them at definition time.
            langevin_sampler: Sampler providing ``sample(x_init)``.
            replay_buffer: Buffer providing ``sample(batch, device)``/``push``.
            optimizer: Optimizer over the model's parameters.
            alpha: Weight for the classification loss.
            beta: Weight for the generation (CD) loss.
        """
        self.jem_model = jem_model
        self.langevin_sampler = langevin_sampler
        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.alpha = alpha
        self.beta = beta

    def train_step(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, float]:
        """Run one joint classification + generation update.

        Args:
            x: Input images [batch_size, channels, H, W]
            y: Class labels [batch_size]

        Returns:
            metrics: Dict with 'loss', 'class_loss', 'gen_loss', 'accuracy',
                'energy_pos', 'energy_neg'.
        """
        batch_size = x.size(0)
        device = x.device

        # ===== Classification loss =====
        logits = self.jem_model(x)
        class_loss = F.cross_entropy(logits, y)

        # ===== Generation loss (contrastive divergence) =====
        # Warm-start unconditional negative chains from the replay buffer.
        x_init = self.replay_buffer.sample(batch_size, device)

        # No torch.no_grad() wrapper: the sampler needs grad mode for
        # autograd.grad and already returns detached samples.
        x_neg = self.langevin_sampler.sample(x_init).detach()

        # Persist the new negatives for future chains.
        self.replay_buffer.push(x_neg)

        # Marginal energies E(x) = -LogSumExp of the logits.
        energy_pos = self.jem_model.energy(x)
        energy_neg = self.jem_model.energy(x_neg)

        gen_loss = energy_pos.mean() - energy_neg.mean()

        # ===== Total loss =====
        loss = self.alpha * class_loss + self.beta * gen_loss

        # Backward and optimize.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        acc = (logits.argmax(dim=1) == y).float().mean()
        metrics = {
            'loss': loss.item(),
            'class_loss': class_loss.item(),
            'gen_loss': gen_loss.item(),
            'accuracy': acc.item(),
            'energy_pos': energy_pos.mean().item(),
            'energy_neg': energy_neg.mean().item()
        }

        return metrics


# ============================================================
# 6. DEMONSTRATIONS
# ============================================================

def demo_energy_network() -> None:
    """Smoke-test EnergyNetwork: forward a random batch and report shapes,
    per-image energies, and the parameter count. Uses random weights and
    random data, so the energy values themselves are arbitrary."""
    print("=" * 60)
    print("DEMO: Energy Network Architecture")
    print("=" * 60)
    
    # Create energy network (full-size config: 64 base channels, 2-2-2-2 blocks)
    energy_net = EnergyNetwork(in_channels=3, base_channels=64, num_blocks=[2, 2, 2, 2])
    
    # Dummy input: 8 random CIFAR-sized images
    x = torch.randn(8, 3, 32, 32)
    
    # Forward pass: one scalar energy per image, shape [8, 1]
    energy = energy_net(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Energy shape: {energy.shape}")
    print(f"Energy values: {energy.squeeze()}")
    print(f"Total parameters: {sum(p.numel() for p in energy_net.parameters()):,}")
    print()


def demo_langevin_sampling() -> None:
    """Smoke-test Langevin dynamics: run SGLD from uniform noise against an
    untrained energy network and report the energy change.

    NOTE(review): the network is untrained, so the energy landscape — and
    hence the printed "reduction" — is arbitrary; this only verifies the
    sampling machinery runs end to end.
    """
    print("=" * 60)
    print("DEMO: Langevin Dynamics Sampling")
    print("=" * 60)
    
    # Create energy network (small config to keep the demo fast)
    energy_net = EnergyNetwork(in_channels=3, base_channels=32, num_blocks=[1, 1, 1, 1])
    energy_net.eval()
    
    # Create sampler; gradient clipping stabilizes the untrained landscape
    sampler = LangevinSampler(energy_net, step_size=0.01, num_steps=50, clip_grad=0.1)
    
    # Initialize chains from uniform noise in [0, 1]
    x_init = torch.rand(4, 3, 32, 32)
    
    print("Running Langevin sampling (50 steps)...")
    x_samples = sampler.sample(x_init, verbose=True)
    
    print(f"\nInitial samples range: [{x_init.min():.3f}, {x_init.max():.3f}]")
    print(f"Final samples range: [{x_samples.min():.3f}, {x_samples.max():.3f}]")
    
    # Compute initial vs final energies (no grad needed for evaluation)
    with torch.no_grad():
        energy_init = energy_net(x_init)
        energy_final = energy_net(x_samples)
    
    print(f"\nInitial energy: {energy_init.mean():.3f} Β± {energy_init.std():.3f}")
    print(f"Final energy: {energy_final.mean():.3f} Β± {energy_final.std():.3f}")
    print(f"Energy reduction: {(energy_init - energy_final).mean():.3f}")
    print()


def demo_contrastive_divergence() -> None:
    """Smoke-test one contrastive-divergence training step on random data.

    NOTE(review): if the trainer invokes the Langevin sampler inside
    torch.no_grad(), torch.autograd.grad needs grad mode re-enabled inside
    the sampler (torch.enable_grad) for this demo to run — verify the
    sampler/trainer pair before relying on it.
    """
    print("=" * 60)
    print("DEMO: Contrastive Divergence Training")
    print("=" * 60)
    
    # Create energy network (small config for a fast demo)
    energy_net = EnergyNetwork(in_channels=3, base_channels=32, num_blocks=[1, 1, 1, 1])
    
    # Create sampler and replay buffer for negative-phase MCMC
    sampler = LangevinSampler(energy_net, step_size=0.01, num_steps=20)
    buffer = ReplayBuffer(buffer_size=1000, image_shape=(3, 32, 32))
    
    # Create trainer with both energy and gradient regularization enabled
    optimizer = optim.Adam(energy_net.parameters(), lr=1e-4)
    trainer = ContrastiveDivergenceTrainer(energy_net, sampler, buffer, optimizer, 
                                           energy_reg=0.01, grad_reg=0.01)
    
    # Dummy training data (random images stand in for a real dataset)
    x_data = torch.rand(16, 3, 32, 32)
    
    # Single training step
    metrics = trainer.train_step(x_data)
    
    print("Training metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value:.4f}")
    print()


def demo_score_matching() -> None:
    """Smoke-test one denoising score matching update on random data and
    print the resulting loss/score-norm metrics."""
    print("=" * 60)
    print("DEMO: Denoising Score Matching")
    print("=" * 60)
    
    # Create energy network (small config for a fast demo)
    energy_net = EnergyNetwork(in_channels=3, base_channels=32, num_blocks=[1, 1, 1, 1])
    
    # Create DSM trainer with a single fixed noise std
    optimizer = optim.Adam(energy_net.parameters(), lr=1e-4)
    trainer = DenoisingScoreMatching(energy_net, optimizer, noise_std=0.1)
    
    # Dummy training data (random images stand in for a real dataset)
    x_data = torch.rand(16, 3, 32, 32)
    
    # Single training step
    metrics = trainer.train_step(x_data)
    
    print("Score matching metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value:.4f}")
    print()


def demo_joint_energy_model() -> None:
    """Smoke-test JointEnergyModel: classification forward pass plus both
    marginal E(x) and conditional E(x, y) energy computations."""
    print("=" * 60)
    print("DEMO: Joint Energy-Based Model (JEM)")
    print("=" * 60)
    
    # Create JEM (small config for a fast demo)
    jem = JointEnergyModel(in_channels=3, num_classes=10, base_channels=32, num_blocks=[1, 1, 1, 1])
    
    # Dummy input: random images and random labels
    x = torch.rand(8, 3, 32, 32)
    y = torch.randint(0, 10, (8,))
    
    # Forward pass (classification head)
    logits = jem(x)
    probs = F.softmax(logits, dim=1)
    
    print(f"Input shape: {x.shape}")
    print(f"Logits shape: {logits.shape}")
    print(f"Class probabilities (first sample): {probs[0]}")
    
    # Energy computation: marginal uses LogSumExp over logits,
    # conditional picks out the labeled logit
    energy_marginal = jem.energy(x)
    energy_conditional = jem.energy(x, y)
    
    print(f"\nMarginal energy E(x): {energy_marginal.squeeze()}")
    print(f"Conditional energy E(x,y): {energy_conditional.squeeze()}")
    
    print(f"\nTotal parameters: {sum(p.numel() for p in jem.parameters()):,}")
    print()


def demo_jem_training() -> None:
    """Smoke-test one JEM training step (classification + generation losses).

    NOTE(review): as with the CD demo, this only runs if the Langevin
    sampler re-enables grad mode internally when the trainer calls it
    under torch.no_grad() — confirm before relying on this demo.
    """
    print("=" * 60)
    print("DEMO: JEM Training Step")
    print("=" * 60)
    
    # Create JEM (small config for a fast demo)
    jem = JointEnergyModel(in_channels=3, num_classes=10, base_channels=32, num_blocks=[1, 1, 1, 1])
    
    # Create sampler and replay buffer for the negative phase
    sampler = LangevinSampler(jem, step_size=0.01, num_steps=20)
    buffer = ReplayBuffer(buffer_size=1000, image_shape=(3, 32, 32))
    
    # Create trainer with equal classification/generation weighting
    optimizer = optim.Adam(jem.parameters(), lr=1e-4)
    trainer = JEMTrainer(jem, sampler, buffer, optimizer, alpha=1.0, beta=1.0)
    
    # Dummy training data: random images and random labels
    x = torch.rand(16, 3, 32, 32)
    y = torch.randint(0, 10, (16,))
    
    # Single training step
    metrics = trainer.train_step(x, y)
    
    print("JEM training metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value:.4f}")
    print()


def print_performance_comparison():
    """Print comprehensive performance comparison and decision guide.

    Pure console output: renders benchmark tables, key observations, and
    a decision guide for energy-based models. No arguments, returns None.
    """
    # NOTE(review): the non-ASCII glyphs below (e.g. "β€’", "πŸ“Š", "Γ—") look
    # like UTF-8 text decoded as cp1252 (mojibake) — reproduced verbatim so
    # the printed output is byte-identical; consider re-encoding the file.

    def _banner(title, first=False):
        # 80-char rule, title, rule; every section after the first is
        # preceded by a blank line.
        rule = "=" * 80
        print(rule if first else "\n" + rule)
        print(title)
        print(rule)

    def _table(rows, widths):
        # Left-align each cell to its column width, columns separated by a
        # single space (matches the original per-row f-string formatting,
        # including trailing padding on the last column).
        for row in rows:
            print(" ".join(f"{cell:<{width}}" for cell, width in zip(row, widths)))

    def _bullets(heading, lines):
        # Blank-line-prefixed heading followed by one print per line.
        print("\n" + heading)
        for line in lines:
            print(line)

    _banner("ENERGY-BASED MODELS: COMPREHENSIVE PERFORMANCE ANALYSIS", first=True)

    _banner("1. IMAGE GENERATION BENCHMARKS (CIFAR-10)")
    _table(
        [
            ["Model", "FID ↓", "IS ↑", "Sampling Time", "Training Stable?"],
            ["-" * 20, "-" * 10, "-" * 10, "-" * 15, "-" * 15],
            ["IGEBM", "38.2", "6.02", "~10s (100 steps)", "Moderate"],
            ["JEM", "40.5", "6.8", "~10s (100 steps)", "Moderate"],
            ["Score-based", "3.2", "9.5", "~50s (1000 steps)", "High"],
            ["DDPM", "3.17", "9.46", "~50s (1000 steps)", "High"],
            ["StyleGAN2", "2.92", "9.18", "~0.1s (1 step)", "Low"],
            ["BigGAN", "6.95", "9.22", "~0.1s (1 step)", "Low"],
        ],
        (20, 10, 10, 15, 15),
    )
    _bullets("πŸ“Š Key Observations:", (
        "  β€’ EBMs (IGEBM, JEM) lag GANs/diffusion on standard metrics",
        "  β€’ Sampling 100-1000Γ— slower than GANs due to MCMC",
        "  β€’ More stable training than GANs, less stable than diffusion",
        "  β€’ Score-based models bridge gap (diffusion connection)",
    ))

    _banner("2. OUT-OF-DISTRIBUTION (OOD) DETECTION (CIFAR-10 vs SVHN)")
    _table(
        [
            ["Method", "AUROC ↑", "FPR@95TPR ↓", "Approach"],
            ["-" * 25, "-" * 10, "-" * 15, "-" * 30],
            ["Softmax Baseline", "0.890", "0.421", "Max softmax probability"],
            ["ODIN", "0.921", "0.336", "Temperature + input perturbation"],
            ["Mahalanobis", "0.937", "0.298", "Feature space distance"],
            ["JEM (Energy)", "0.964", "0.182", "Energy threshold"],
            ["EBM + Ensembles", "0.978", "0.145", "Energy + multiple models"],
        ],
        (25, 10, 15, 30),
    )
    _bullets("🎯 Key Observations:", (
        "  β€’ Energy-based OOD detection significantly outperforms classifier-based methods",
        "  β€’ JEM energy threshold: +7.4 AUROC vs softmax baseline",
        "  β€’ Lower FPR@95TPR: More reliable rejection of OOD samples",
        "  β€’ Ensembles further improve OOD detection",
    ))

    _banner("3. ADVERSARIAL ROBUSTNESS (CIFAR-10, PGD-20 Attack)")
    _table(
        [
            ["Model", "Clean Acc", "PGD-20 Acc", "AutoAttack", "Training Cost"],
            ["-" * 25, "-" * 12, "-" * 13, "-" * 13, "-" * 15],
            ["Standard ResNet", "95.2%", "0.0%", "0.0%", "1Γ— baseline"],
            ["Adversarial Training", "84.7%", "53.1%", "48.2%", "3Γ— baseline"],
            ["JEM + Adversarial", "82.3%", "56.4%", "51.7%", "5Γ— baseline"],
            ["TRADES", "84.9%", "54.4%", "49.8%", "3Γ— baseline"],
        ],
        (25, 12, 13, 13, 15),
    )
    _bullets("πŸ›‘οΈ Key Observations:", (
        "  β€’ JEM improves robust accuracy by ~3-4% over standard adversarial training",
        "  β€’ Energy landscape provides additional robustness signal",
        "  β€’ Higher computational cost (5Γ— vs 3Γ— for standard adversarial training)",
        "  β€’ Trade-off: Clean accuracy drops ~3% vs standard training",
    ))

    _banner("4. COMPUTATIONAL COMPLEXITY ANALYSIS")
    _table(
        [
            ["Operation", "Time Complexity", "Space", "Bottleneck"],
            ["-" * 30, "-" * 20, "-" * 15, "-" * 30],
            ["Energy forward", "O(D)", "O(D)", "Network evaluation"],
            ["Score computation", "O(D)", "O(D)", "Backprop through network"],
            ["Langevin sampling (T steps)", "O(TΒ·D)", "O(D)", "MCMC iterations"],
            ["CD training (batch B)", "O(BΒ·TΒ·D)", "O(BΒ·D)", "Negative sampling"],
            ["Score matching", "O(BΒ·D)", "O(D)", "Score gradient computation"],
        ],
        (30, 20, 15, 30),
    )
    _bullets("⏱️ Typical Values:", (
        "  β€’ D = 3Γ—32Γ—32 = 3,072 (CIFAR-10)",
        "  β€’ T = 20-200 Langevin steps",
        "  β€’ B = 128-256 batch size",
        "  β€’ CD: ~10s per batch (100 steps), Score matching: ~0.5s per batch",
    ))

    _banner("5. TRAINING METHOD COMPARISON")
    _table(
        [
            ["Method", "Partition Fn?", "MCMC?", "Stability", "Speed", "Quality"],
            ["-" * 25, "-" * 15, "-" * 10, "-" * 12, "-" * 10, "-" * 10],
            ["Contrastive Divergence", "Avoided", "Yes", "Moderate", "Slow", "High"],
            ["Persistent CD", "Avoided", "Yes", "Low", "Slow", "Higher"],
            ["Score Matching", "Avoided", "No", "High", "Fast", "High"],
            ["Denoising Score", "Avoided", "No", "High", "Fast", "High"],
            ["Multi-scale Score", "Avoided", "No", "High", "Moderate", "Highest"],
            ["NCE", "Learned", "No", "High", "Fast", "Moderate"],
        ],
        (25, 15, 10, 12, 10, 10),
    )
    _bullets("πŸ”§ Recommendations:", (
        "  β€’ Denoising Score Matching: Best default choice (fast, stable, high quality)",
        "  β€’ Multi-scale Score: State-of-the-art quality (diffusion connection)",
        "  β€’ Contrastive Divergence: When explicit sampling needed",
        "  β€’ NCE: Fast but requires good noise distribution",
    ))

    _banner("6. COMPARISON WITH OTHER GENERATIVE MODELS")
    _table(
        [
            ["Model", "Likelihood", "Sampling", "Architecture", "Composable?", "Best Use Case"],
            ["-" * 12, "-" * 12, "-" * 12, "-" * 15, "-" * 12, "-" * 35],
            ["EBM", "Intractable", "Slow (MCMC)", "Flexible", "βœ“", "OOD, robustness, composition"],
            ["GAN", "Intractable", "Fast", "Flexible", "βœ—", "High-quality images, fast sampling"],
            ["VAE", "Approximate", "Fast", "Enc-Dec", "βœ—", "Latent space, fast inference"],
            ["Flow", "Exact", "Fast", "Invertible", "βœ—", "Exact likelihood, density est"],
            ["Diffusion", "Tractable", "Slow", "Flexible", "βœ“", "Highest quality, controllable"],
        ],
        (12, 12, 12, 15, 12, 35),
    )

    _banner("7. DECISION GUIDE: WHEN TO USE ENERGY-BASED MODELS")
    _bullets("βœ… USE EBMs WHEN:", (
        "  1. Compositional reasoning required (combine concepts: red AND square)",
        "  2. Out-of-distribution detection critical (medical, autonomous systems)",
        "  3. Adversarial robustness important (security applications)",
        "  4. Flexible architecture needed (no invertibility constraints)",
        "  5. Joint modeling beneficial (JEM: classification + generation)",
        "  6. Interpretability valued (energy landscape visualization)",
        "  7. Slow sampling acceptable (offline generation, design)",
    ))
    _bullets("❌ AVOID EBMs WHEN:", (
        "  1. Real-time generation required (use GANs or cached diffusion)",
        "  2. Limited computational budget (training + sampling expensive)",
        "  3. Standard generative tasks (diffusion models better FID/IS)",
        "  4. Large-scale images (512Γ—512+) without specialized hardware",
        "  5. Exact likelihood needed (use normalizing flows)",
        "  6. Fast prototyping (less mature libraries than GANs/VAEs)",
    ))

    _banner("8. HYPERPARAMETER RECOMMENDATIONS")
    _table(
        [
            ["Parameter", "Typical Range", "Recommended Start", "Impact"],
            ["-" * 25, "-" * 20, "-" * 20, "-" * 35],
            ["Langevin step size Ξ΅", "0.001 - 0.1", "0.01", "Larger β†’ faster, less stable"],
            ["Langevin steps T", "20 - 200", "100", "More β†’ better samples, slower"],
            ["Learning rate", "1e-5 - 1e-3", "1e-4", "Standard impact"],
            ["Batch size", "64 - 256", "128", "Larger helps negative sampling"],
            ["Replay buffer size", "1K - 100K", "10K", "Larger β†’ better diversity"],
            ["Reinit probability", "0.01 - 0.2", "0.05", "Higher β†’ more exploration"],
            ["Energy reg Ξ»_E", "0.0 - 0.1", "0.01", "Prevents energy drift"],
            ["Gradient reg Ξ»_G", "0.0 - 0.1", "0.01", "Smooths energy landscape"],
            ["Noise std (DSM)", "0.01 - 0.5", "0.1", "Match to data scale"],
        ],
        (25, 20, 20, 35),
    )

    _banner("9. RECENT ADVANCES AND FUTURE DIRECTIONS")
    _bullets("πŸ”¬ Recent Advances (2020-2024):", (
        "  β€’ Cooperative Training: Train EBM + generator jointly (faster MCMC)",
        "  β€’ Flow Contrastive Estimation: Use flows as MCMC proposals",
        "  β€’ Score-based diffusion: Unified framework (SDE/ODE solvers)",
        "  β€’ Discrete EBMs: Graph generation, protein design",
        "  β€’ Hardware acceleration: Custom chips for MCMC",
    ))
    _bullets("πŸš€ Future Directions:", (
        "  β€’ Learned samplers: Neural networks to amortize MCMC",
        "  β€’ Theoretical analysis: Convergence guarantees, sample complexity",
        "  β€’ Scalability: High-resolution images (1024Γ—1024+)",
        "  β€’ Scientific applications: Molecule design, physics simulation",
        "  β€’ Multimodal: Vision + language energy functions",
    ))

    # Closing rule.
    print("\n" + "=" * 80)


# ============================================================
# RUN ALL DEMONSTRATIONS
# ============================================================

if __name__ == "__main__":
    # Run each demo in sequence (architecture, sampling, training methods,
    # then JEM), ending with the static performance-comparison report.
    # Each demo is self-contained and prints its own banner and results.
    demo_energy_network()
    demo_langevin_sampling()
    demo_contrastive_divergence()
    demo_score_matching()
    demo_joint_energy_model()
    demo_jem_training()
    print_performance_comparison()