Neural Ordinary Differential Equations (Neural ODEs)¶
Learning Objectives:
Understand continuous-depth networks
Implement adjoint method for backpropagation
Train Neural ODE models
Apply to time series and generative modeling
Prerequisites: Deep learning, differential equations, numerical methods
Time: 90 minutes
📚 Reference Materials:
neuralODE_Adjoint.pdf - Neural ODE theory and adjoint method derivation
1. From ResNets to ODEs¶
ResNet Blocks¶
Standard ResNet block: \[h_{t+1} = h_t + f_\theta(h_t)\]
where \(f_\theta\) is a neural network (typically conv layers + activation).
Iterating \(T\) blocks: \[h_T = h_0 + \sum_{t=0}^{T-1} f_\theta(h_t, t)\]
Continuous Limit¶
As depth \(T \to \infty\) with step size \(\Delta t \to 0\), the update \(h_{t+1} = h_t + \Delta t \, f_\theta(h_t, t)\)
becomes an ODE: \[\frac{dh(t)}{dt} = f_\theta(h(t), t)\]
with initial condition \(h(0) = h_0\) and solution \(h(T)\) at time \(T\).
Key insight: Neural network depth becomes continuous!
Neural ODE Layer¶
Forward pass: Solve the ODE from \(t=0\) to \(t=T\): \[h(T) = h(0) + \int_0^T f_\theta(h(t), t)\, dt\]
Use ODE solver (e.g., Runge-Kutta, Adams-Bashforth)
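The connection above can be sketched in a few lines: a stack of residual blocks with step size \(\Delta t = 1/T\) is exactly forward Euler on the ODE, so as the number of blocks grows the output converges to the ODE solution. A minimal sketch (the dynamics `f(h) = -h/2` is an illustrative stand-in for \(f_\theta\), not from the text):

```python
import numpy as np

# Illustrative stand-in for the learned dynamics f_theta
def f(h):
    return -0.5 * h

# A stack of T residual blocks with step size dt = 1/T:
#   h_{t+1} = h_t + dt * f(h_t)   (forward Euler on dh/dt = f(h))
def resnet_forward(h0, T):
    h, dt = h0, 1.0 / T
    for _ in range(T):
        h = h + dt * f(h)
    return h

h0 = 1.0
ode_solution = h0 * np.exp(-0.5)  # exact h(1) for dh/dt = -h/2
for T in [1, 10, 100, 1000]:
    print(f"T={T:4d}  h_T={resnet_forward(h0, T):.6f}  (ODE: {ode_solution:.6f})")
```

As `T` increases, the discrete iterate approaches the continuous solution, which is exactly the limit the section describes.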
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)
np.random.seed(42)
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Check if torchdiffeq is available
try:
from torchdiffeq import odeint_adjoint as odeint
print("✅ torchdiffeq is available")
except ImportError:
print("⚠️ Installing torchdiffeq...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'torchdiffeq'])
from torchdiffeq import odeint_adjoint as odeint
print("✅ torchdiffeq installed successfully")
2. The Adjoint Method¶
Problem: Memory Cost¶
Standard backpropagation:
Forward pass: Compute and store all intermediate states
Backward pass: Use stored states to compute gradients
Memory: \(O(T \cdot d)\) where \(T\) is number of steps, \(d\) is state dimension
For very deep networks: Memory becomes prohibitive!
Solution: Adjoint Sensitivity Method¶
Idea: Compute gradients by solving another ODE backward in time
Adjoint state: \[a(t) = \frac{\partial L}{\partial h(t)}\]
where \(L\) is the loss.
Adjoint ODE: \[\frac{da(t)}{dt} = -a(t)^T \frac{\partial f_\theta}{\partial h}\]
Gradient: \[\frac{\partial L}{\partial \theta} = -\int_T^0 a(t)^T \frac{\partial f_\theta}{\partial \theta}\, dt\]
Key advantage: Only need to store \(a(T)\) and integrate backward!
Memory: \(O(1)\) instead of \(O(T)\)! ✅
2.5. Adjoint Method: Complete Derivation¶
Forward Problem¶
Given the ODE initial value problem: \[\frac{dh(t)}{dt} = f_\theta(h(t), t), \quad h(t_0) = h_0\]
Solution at time \(t_1\): \[h(t_1) = h_0 + \int_{t_0}^{t_1} f_\theta(h(t), t)\, dt\]
Loss Function¶
Define a loss \(L\) that depends on the final state: \[L = L(h(t_1))\]
Goal: Compute \(\frac{dL}{d\theta}\) and \(\frac{dL}{dh_0}\) efficiently
Adjoint State Definition¶
Define the adjoint state: \[a(t) \triangleq \frac{\partial L}{\partial h(t)}\]
Physical interpretation: Sensitivity of loss to hidden state at time \(t\)
Deriving the Adjoint ODE¶
Starting from the definition: \[a(t) = \frac{\partial L}{\partial h(t)}\]
Take the derivative w.r.t. time: \[\frac{da(t)}{dt} = \frac{d}{dt}\left(\frac{\partial L}{\partial h(t)}\right)\]
By the chain rule and using \(\frac{dh}{dt} = f_\theta(h, t)\): \[\frac{da(t)}{dt} = -a(t)^T \frac{\partial f_\theta(h(t), t)}{\partial h}\]
This is the adjoint ODE! It evolves backward in time from \(t_1\) to \(t_0\).
Boundary condition: \(a(t_1) = \frac{\partial L}{\partial h(t_1)}\)
Gradient Computation¶
Gradient w.r.t. parameters: \[\frac{\partial L}{\partial \theta} = -\int_{t_1}^{t_0} a(t)^T \frac{\partial f_\theta(h(t), t)}{\partial \theta}\, dt\]
Gradient w.r.t. initial state: \[\frac{\partial L}{\partial h_0} = a(t_0)\]
The Augmented System¶
In practice, solve the augmented ODE backward:
\[\frac{d}{dt}\begin{bmatrix} h(t) \\ a(t) \\ \frac{\partial L}{\partial \theta} \end{bmatrix} = \begin{bmatrix} f_\theta(h(t), t) \\ -a(t)^T \frac{\partial f_\theta}{\partial h} \\ -a(t)^T \frac{\partial f_\theta}{\partial \theta} \end{bmatrix}\]
from \(t=t_1\) back to \(t=t_0\)
Key advantages:
Memory: \(O(1)\) instead of \(O(T)\) where \(T\) is number of forward steps
Flexibility: Can use adaptive ODE solvers without worrying about gradient computation
Accuracy: Gradient accuracy controlled by ODE solver tolerance
Comparison: Standard Backprop vs Adjoint¶
| Method | Memory | Accuracy | Flexibility |
|---|---|---|---|
| Standard Backprop | \(O(T \cdot d)\) | Exact | Limited to fixed discretization |
| Adjoint Method | \(O(d)\) | ODE solver tolerance | Adaptive solvers supported |
where \(T\) = number of steps, \(d\) = state dimension
# Visualize adjoint computation
def visualize_adjoint_method():
"""
Demonstrate forward and backward passes in Neural ODE
Shows memory efficiency of adjoint method
"""
fig, axes = plt.subplots(2, 1, figsize=(14, 10))
# Time points
t = np.linspace(0, 1, 100)
# Forward pass: hidden state trajectory
h_trajectory = np.sin(2 * np.pi * t) + 0.5 * np.cos(4 * np.pi * t)
ax = axes[0]
ax.plot(t, h_trajectory, 'b-', linewidth=3, label='$h(t)$ trajectory')
ax.scatter([0, 1], [h_trajectory[0], h_trajectory[-1]],
s=200, c=['green', 'red'], zorder=5, edgecolors='black', linewidths=2)
ax.annotate('$h(0)$ (initial)', xy=(0, h_trajectory[0]),
xytext=(0.15, h_trajectory[0] + 0.3),
fontsize=12, fontweight='bold',
arrowprops=dict(arrowstyle='->', lw=2, color='green'))
ax.annotate('$h(1)$ (final)', xy=(1, h_trajectory[-1]),
xytext=(0.85, h_trajectory[-1] - 0.4),
fontsize=12, fontweight='bold',
arrowprops=dict(arrowstyle='->', lw=2, color='red'))
# Add arrows showing forward direction
for i in range(0, 90, 20):
ax.annotate('', xy=(t[i+10], h_trajectory[i+10]),
xytext=(t[i], h_trajectory[i]),
arrowprops=dict(arrowstyle='->', lw=2, color='blue', alpha=0.6))
ax.set_xlabel('Time $t$', fontsize=13)
ax.set_ylabel('Hidden State $h(t)$', fontsize=13)
ax.set_title('Forward Pass: Solve ODE $\\frac{dh}{dt} = f_\\theta(h, t)$',
fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.legend(fontsize=12)
# Backward pass: adjoint trajectory
a_trajectory = np.exp(-2 * (t - 1)**2) # Adjoint decays from t=1 to t=0
grad_accumulation = np.cumsum(a_trajectory[::-1])[::-1] / 100
ax = axes[1]
# Plot adjoint trajectory
ax.plot(t, a_trajectory, 'r-', linewidth=3, label='$a(t)$ (adjoint state)')
ax.fill_between(t, 0, a_trajectory, alpha=0.2, color='red')
# Plot gradient accumulation
ax.plot(t, grad_accumulation / grad_accumulation.max() * a_trajectory.max() * 0.8,
'g--', linewidth=2, label='$\\frac{\\partial L}{\\partial \\theta}$ accumulation')
ax.scatter([0, 1], [a_trajectory[-1], a_trajectory[0]],
s=200, c=['green', 'red'], zorder=5, edgecolors='black', linewidths=2)
ax.annotate('$a(1) = \\frac{\\partial L}{\\partial h(1)}$', xy=(1, a_trajectory[0]),
xytext=(0.75, a_trajectory[0] + 0.15),
fontsize=12, fontweight='bold',
arrowprops=dict(arrowstyle='->', lw=2, color='red'))
ax.annotate('$a(0) = \\frac{\\partial L}{\\partial h(0)}$', xy=(0, a_trajectory[-1]),
xytext=(0.15, a_trajectory[-1] + 0.15),
fontsize=12, fontweight='bold',
arrowprops=dict(arrowstyle='->', lw=2, color='green'))
# Add arrows showing backward direction
for i in range(90, 0, -20):
ax.annotate('', xy=(t[i-10], a_trajectory[i-10]),
xytext=(t[i], a_trajectory[i]),
arrowprops=dict(arrowstyle='->', lw=2, color='red', alpha=0.6))
ax.set_xlabel('Time $t$', fontsize=13)
ax.set_ylabel('Adjoint State $a(t)$', fontsize=13)
ax.set_title('Backward Pass: Solve Adjoint ODE $\\frac{da}{dt} = -a^T \\frac{\\partial f}{\\partial h}$ (backward in time)',
fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.legend(fontsize=12)
plt.tight_layout()
plt.show()
print("\n" + "="*70)
print("ADJOINT METHOD EXPLANATION")
print("="*70)
print("Forward Pass (t: 0 → 1):")
print(" • Solve ODE to get h(t) trajectory")
print(" • Only store final state h(1)")
print(" • Memory: O(1) - constant!")
print("\nBackward Pass (t: 1 → 0):")
print(" • Compute a(1) = ∂L/∂h(1) from loss")
print(" • Solve adjoint ODE backward in time")
print(" • Accumulate ∂L/∂θ along the way")
print(" • Result: a(0) = ∂L/∂h(0) and ∂L/∂θ")
print("\nKey Insight:")
print(" • No need to store intermediate states!")
print(" • Recompute h(t) during backward pass if needed")
print(" • Memory efficient for very deep networks")
print("="*70)
visualize_adjoint_method()
# Visualize ResNet vs Neural ODE
def plot_resnet_vs_ode():
"""Compare discrete ResNet with continuous Neural ODE"""
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# ResNet: Discrete layers
t_discrete = np.arange(0, 11)
h_discrete = np.cumsum(np.random.randn(11) * 0.3) # Random walk
axes[0].plot(t_discrete, h_discrete, 'o-', linewidth=2, markersize=10, color='blue')
for i, (x, y) in enumerate(zip(t_discrete, h_discrete)):
axes[0].annotate(f'Layer {i}', (x, y), textcoords="offset points",
xytext=(0,10), ha='center', fontsize=9)
axes[0].set_xlabel('Layer Index', fontsize=12)
axes[0].set_ylabel('Hidden State', fontsize=12)
axes[0].set_title('ResNet: Discrete Layers', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
# Neural ODE: Continuous
t_continuous = np.linspace(0, 10, 100)
h_continuous = np.cumsum(np.random.randn(100) * 0.1) # Smooth trajectory
axes[1].plot(t_continuous, h_continuous, '-', linewidth=2, color='green')
axes[1].fill_between(t_continuous, h_continuous-0.2, h_continuous+0.2,
alpha=0.2, color='green', label='Continuous flow')
axes[1].set_xlabel('Time $t$', fontsize=12)
axes[1].set_ylabel('Hidden State $h(t)$', fontsize=12)
axes[1].set_title('Neural ODE: Continuous Depth', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("ResNet: Discrete transformations at each layer")
print("Neural ODE: Smooth continuous evolution governed by ODE")
plot_resnet_vs_ode()
3. Neural ODE Implementation¶
ODE Function¶
Define \(f_\theta(h, t)\): the dynamics function
Requirements:
Input: state \(h\) and time \(t\)
Output: derivative \(\frac{dh}{dt}\)
ODE Block¶
Components:
ODE function \(f_\theta\)
Integration time \([0, T]\)
ODE solver (e.g., dopri5, rk4)
Usage:
from torchdiffeq import odeint_adjoint
# Forward pass
h_T = odeint_adjoint(odefunc, h_0, torch.tensor([0., T]))[-1]
The _adjoint version uses the adjoint method for backpropagation!
3.5. ODE Solvers: Theory and Practice¶
Common ODE Solver Methods¶
1. Euler Method (1st order): \[h_{n+1} = h_n + \Delta t \cdot f(h_n, t_n)\]
Pros: Simple, fast
Cons: Low accuracy, requires small step size
Error: \(O(\Delta t^2)\) per step, \(O(\Delta t)\) globally
2. Runge-Kutta 4 (RK4, 4th order)
\[k_1 = f(h_n, t_n)\]
\[k_2 = f(h_n + \tfrac{\Delta t}{2}k_1,\; t_n + \tfrac{\Delta t}{2})\]
\[k_3 = f(h_n + \tfrac{\Delta t}{2}k_2,\; t_n + \tfrac{\Delta t}{2})\]
\[k_4 = f(h_n + \Delta t \cdot k_3,\; t_n + \Delta t)\]
\[h_{n+1} = h_n + \tfrac{\Delta t}{6}(k_1 + 2k_2 + 2k_3 + k_4)\]
Pros: More accurate, stable
Cons: 4× function evaluations per step
Error: \(O(\Delta t^5)\) per step, \(O(\Delta t^4)\) globally
3. Adaptive Methods (Dormand-Prince, dopri5)
Automatically adjust step size based on error estimate:
Use two different order methods
Compare results to estimate error
Increase/decrease \(\Delta t\) accordingly
Advantages:
✅ Accuracy: Maintains user-specified tolerance
✅ Efficiency: Large steps where possible, small where needed
✅ Reliability: Detects stiff regions automatically
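The accept/reject logic behind adaptive solvers can be illustrated with a toy step-doubling controller. This is a simplified sketch, not the actual Dormand-Prince embedded pair: it estimates local error by comparing one Euler step against two half steps (the test ODE and tolerance are illustrative choices):

```python
import numpy as np

def f(t, h):
    return -2.0 * h + np.sin(t)

def adaptive_euler(f, h0, t0, t1, tol=1e-4):
    """Toy adaptive stepper: one full Euler step vs. two half steps."""
    t, h, dt = t0, h0, (t1 - t0) / 10
    nfe = 0
    while t < t1:
        dt = min(dt, t1 - t)
        k1 = f(t, h)
        full = h + dt * k1                    # one step of size dt
        half = h + (dt / 2) * k1              # two steps of size dt/2
        half = half + (dt / 2) * f(t + dt / 2, half)
        nfe += 2
        err = abs(full - half)                # local error estimate
        if err < tol:                         # accept the better estimate
            t, h = t + dt, half
        # resize step: grow when err is small, shrink when it is large
        dt = 0.9 * dt * min(2.0, max(0.2, np.sqrt(tol / (err + 1e-12))))
    return h, nfe

h, nfe = adaptive_euler(f, h0=0.0, t0=0.0, t1=5.0)
exact = (2 * np.sin(5) - np.cos(5) + np.exp(-10)) / 5  # analytic solution
print(f"h(5) ≈ {h:.5f}  (exact {exact:.5f}),  NFE = {nfe}")
```

Production solvers such as dopri5 follow the same accept/resize loop, but get the error estimate almost for free from an embedded pair of Runge-Kutta formulas.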
Solver Selection Guidelines¶
| Scenario | Recommended Solver | Reason |
|---|---|---|
| Training | dopri5 (adaptive) | Balance speed/accuracy, handles varying dynamics |
| Inference (speed) | Euler or RK4 (fixed) | Faster, predictable cost |
| Inference (quality) | dopri5 | Best quality generations |
| Stiff problems | Implicit solvers | Stability for fast/slow dynamics |
Trade-offs in Neural ODEs¶
Number of Function Evaluations (NFE):
Measures computational cost
Varies with solver and tolerance
Typical: 20-100 NFE for training, 5-20 for inference
Tolerance:
rtol (relative tolerance): \(10^{-3}\) to \(10^{-7}\)
atol (absolute tolerance): \(10^{-4}\) to \(10^{-9}\)
Lower tolerance → more accurate, more NFE
Higher tolerance → faster, less accurate
Practical Tips:
Start with rtol=1e-3, atol=1e-4 for training
Use rtol=1e-5, atol=1e-6 for high-quality generation
Monitor NFE during training (should be stable)
If NFE explodes, the network might be poorly conditioned
# Demonstrate different ODE solvers
def compare_ode_solvers():
"""
Compare Euler, RK4, and adaptive (dopri5) methods
Demonstrates accuracy vs computational cost trade-off
"""
# Define a simple ODE: dh/dt = -2h + sin(t)
def f(t, h):
return -2 * h + np.sin(t)
    # Analytical solution for comparison: h(t) = (2 sin t - cos t + e^{-2t}) / 5
    def h_exact(t):
        return (1/5) * (2*np.sin(t) - np.cos(t) + np.exp(-2*t))
# Time span
t_span = np.linspace(0, 5, 200)
h0 = 0.0
# Euler method implementation
def euler_solve(f, h0, t_span):
h = [h0]
nfe = 0
for i in range(len(t_span) - 1):
dt = t_span[i+1] - t_span[i]
h_new = h[-1] + dt * f(t_span[i], h[-1])
h.append(h_new)
nfe += 1
return np.array(h), nfe
# RK4 method implementation
def rk4_solve(f, h0, t_span):
h = [h0]
nfe = 0
for i in range(len(t_span) - 1):
dt = t_span[i+1] - t_span[i]
t = t_span[i]
h_curr = h[-1]
k1 = f(t, h_curr)
k2 = f(t + dt/2, h_curr + dt/2 * k1)
k3 = f(t + dt/2, h_curr + dt/2 * k2)
k4 = f(t + dt, h_curr + dt * k3)
h_new = h_curr + (dt/6) * (k1 + 2*k2 + 2*k3 + k4)
h.append(h_new)
nfe += 4 # RK4 uses 4 function evaluations per step
return np.array(h), nfe
# Get exact solution
h_true = h_exact(t_span)
# Solve with different methods
h_euler, nfe_euler = euler_solve(f, h0, t_span)
h_rk4, nfe_rk4 = rk4_solve(f, h0, t_span)
# Compute errors
error_euler = np.abs(h_euler - h_true)
error_rk4 = np.abs(h_rk4 - h_true)
# Visualize
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# Plot 1: Solutions
ax = axes[0]
ax.plot(t_span, h_true, 'k-', linewidth=3, label='Exact solution', alpha=0.7)
ax.plot(t_span, h_euler, 'b--', linewidth=2, label=f'Euler (NFE={nfe_euler})')
ax.plot(t_span, h_rk4, 'r:', linewidth=2, label=f'RK4 (NFE={nfe_rk4})')
ax.set_xlabel('Time $t$', fontsize=12)
ax.set_ylabel('$h(t)$', fontsize=12)
ax.set_title('ODE Solutions Comparison', fontweight='bold', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
# Plot 2: Errors
ax = axes[1]
ax.semilogy(t_span, error_euler, 'b-', linewidth=2, label='Euler error')
ax.semilogy(t_span, error_rk4, 'r-', linewidth=2, label='RK4 error')
ax.set_xlabel('Time $t$', fontsize=12)
ax.set_ylabel('Absolute Error (log scale)', fontsize=12)
ax.set_title('Approximation Error', fontweight='bold', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, which='both')
# Plot 3: Accuracy vs Cost
ax = axes[2]
methods = ['Euler', 'RK4']
nfes = [nfe_euler, nfe_rk4]
max_errors = [error_euler.max(), error_rk4.max()]
colors = ['blue', 'red']
for i, (method, nfe, err) in enumerate(zip(methods, nfes, max_errors)):
ax.scatter(nfe, err, s=300, c=colors[i], alpha=0.7,
edgecolors='black', linewidth=2, label=method, zorder=5)
ax.annotate(method, xy=(nfe, err), xytext=(10, 10),
textcoords='offset points', fontsize=11, fontweight='bold')
ax.set_xlabel('Number of Function Evaluations (NFE)', fontsize=12)
ax.set_ylabel('Maximum Error', fontsize=12)
ax.set_title('Accuracy vs Computational Cost', fontweight='bold', fontsize=13)
ax.set_yscale('log')
ax.grid(True, alpha=0.3, which='both')
plt.tight_layout()
plt.show()
print("\n" + "="*70)
print("ODE SOLVER COMPARISON")
print("="*70)
print(f"Method | NFE | Max Error | Mean Error")
print("-" * 70)
print(f"Euler | {nfe_euler:5d} | {error_euler.max():.2e} | {error_euler.mean():.2e}")
print(f"RK4 | {nfe_rk4:5d} | {error_rk4.max():.2e} | {error_rk4.mean():.2e}")
print("="*70)
print("\nKey Takeaways:")
    print(" • RK4 is orders of magnitude more accurate than Euler")
print(" • RK4 uses 4× more function evaluations per step")
print(" • For Neural ODEs: Adaptive solvers (dopri5) are best")
print(" • They automatically balance accuracy vs cost")
print("="*70)
compare_ode_solvers()
3.6. ODE Solvers: Theory and Comparison¶
Numerical Integration Methods¶
To solve \(\frac{dh}{dt} = f(h, t)\), we need numerical methods.
1. Euler Method (1st Order)¶
Formula: \[h_{n+1} = h_n + \Delta t \cdot f(h_n, t_n)\]
Properties:
✅ Simple, fast
❌ Low accuracy: \(O(\Delta t^2)\) local error
❌ Requires small step size
Global error: \(O(\Delta t)\)
2. Runge-Kutta 4 (RK4) - 4th Order¶
Formula:
\[k_1 = f(h_n, t_n)\]
\[k_2 = f(h_n + \tfrac{\Delta t}{2}k_1,\; t_n + \tfrac{\Delta t}{2})\]
\[k_3 = f(h_n + \tfrac{\Delta t}{2}k_2,\; t_n + \tfrac{\Delta t}{2})\]
\[k_4 = f(h_n + \Delta t \cdot k_3,\; t_n + \Delta t)\]
\[h_{n+1} = h_n + \tfrac{\Delta t}{6}(k_1 + 2k_2 + 2k_3 + k_4)\]
Properties:
✅ High accuracy: \(O(\Delta t^5)\) local error
✅ Good balance of speed/accuracy
❌ Fixed step size
Global error: \(O(\Delta t^4)\)
3. Adaptive Methods (e.g., Dormand-Prince “dopri5”)¶
Key idea: Adjust step size based on local error estimate
Algorithm:
Compute two estimates: \(\hat{h}_{n+1}\) (5th order) and \(h_{n+1}\) (4th order)
Error estimate: \(\text{err} = ||\hat{h}_{n+1} - h_{n+1}||\)
If \(\text{err} < \text{tol}\): accept step, maybe increase \(\Delta t\)
If \(\text{err} > \text{tol}\): reject step, decrease \(\Delta t\)
Properties:
✅ Automatically adjusts precision
✅ Efficient for smooth ODEs
✅ Can handle stiff problems
⚠️ Variable computational cost
Solver Comparison¶
| Solver | Order | Adaptive | Speed | Accuracy | Use Case |
|---|---|---|---|---|---|
| Euler | 1 | ❌ | ⚡⚡⚡ | ⭐ | Quick prototyping |
| RK4 | 4 | ❌ | ⚡⚡ | ⭐⭐⭐ | Fixed timestep, moderate accuracy |
| dopri5 | 5(4) | ✅ | ⚡ | ⭐⭐⭐⭐⭐ | Production, high accuracy |
| adams | Variable | ✅ | ⚡⚡ | ⭐⭐⭐⭐ | Smooth ODEs, long integration |
Choosing a Solver for Neural ODEs¶
For training:
Use dopri5 (default in torchdiffeq)
Adaptive stepping handles varying dynamics
Set reasonable tolerances: rtol=1e-3, atol=1e-4
For inference:
Can use fixed-step rk4 for speed
Pre-determine good step size from training
For memory-constrained:
Use Euler with small steps
Trade accuracy for speed
Tolerance Settings¶
Relative tolerance (rtol): Error relative to state magnitude
Absolute tolerance (atol): Minimum absolute error threshold
Recommendations:
Training: rtol=1e-3, atol=1e-4 (faster, slightly less accurate)
Evaluation: rtol=1e-5, atol=1e-6 (slower, more accurate)
Production: rtol=1e-4, atol=1e-5 (balanced)
# Compare ODE solvers empirically
def compare_ode_solvers():
"""
Compare different ODE solvers on a simple problem
Demonstrates accuracy vs speed trade-offs
"""
# Define a simple ODE: dh/dt = -h (exponential decay)
# True solution: h(t) = h(0) * exp(-t)
def f(t, h):
return -h
# Initial condition and time span
h0 = 1.0
t_span = np.linspace(0, 5, 100)
# True solution
h_true = h0 * np.exp(-t_span)
# Euler method
def euler_solve(f, h0, t_span):
h = np.zeros_like(t_span)
h[0] = h0
for i in range(len(t_span) - 1):
dt = t_span[i+1] - t_span[i]
h[i+1] = h[i] + dt * f(t_span[i], h[i])
return h
# RK4 method
def rk4_solve(f, h0, t_span):
h = np.zeros_like(t_span)
h[0] = h0
for i in range(len(t_span) - 1):
dt = t_span[i+1] - t_span[i]
k1 = f(t_span[i], h[i])
k2 = f(t_span[i] + dt/2, h[i] + dt*k1/2)
k3 = f(t_span[i] + dt/2, h[i] + dt*k2/2)
k4 = f(t_span[i] + dt, h[i] + dt*k3)
h[i+1] = h[i] + dt/6 * (k1 + 2*k2 + 2*k3 + k4)
return h
# Solve with different methods
import time
start = time.time()
h_euler = euler_solve(f, h0, t_span)
euler_time = time.time() - start
start = time.time()
h_rk4 = rk4_solve(f, h0, t_span)
rk4_time = time.time() - start
# Compute errors
euler_error = np.abs(h_euler - h_true)
rk4_error = np.abs(h_rk4 - h_true)
# Visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Plot 1: Solutions
ax = axes[0, 0]
ax.plot(t_span, h_true, 'k-', linewidth=3, label='True solution', alpha=0.7)
ax.plot(t_span, h_euler, 'r--', linewidth=2, label='Euler method')
ax.plot(t_span, h_rk4, 'b--', linewidth=2, label='RK4 method')
ax.set_xlabel('Time $t$', fontsize=12)
ax.set_ylabel('$h(t)$', fontsize=12)
ax.set_title('ODE Solutions: $\\frac{dh}{dt} = -h$', fontweight='bold', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
# Plot 2: Errors (log scale)
ax = axes[0, 1]
ax.semilogy(t_span, euler_error + 1e-10, 'r-', linewidth=2, label='Euler error')
ax.semilogy(t_span, rk4_error + 1e-10, 'b-', linewidth=2, label='RK4 error')
ax.set_xlabel('Time $t$', fontsize=12)
ax.set_ylabel('Absolute Error (log scale)', fontsize=12)
ax.set_title('Solver Accuracy Comparison', fontweight='bold', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, which='both')
# Plot 3: Step visualization for Euler
ax = axes[1, 0]
t_demo = t_span[:10]
h_demo = h_euler[:10]
h_true_demo = h_true[:10]
ax.plot(t_demo, h_true_demo, 'k-', linewidth=3, label='True', alpha=0.7)
ax.plot(t_demo, h_demo, 'ro-', linewidth=2, markersize=8, label='Euler steps')
# Show Euler steps
for i in range(len(t_demo)-1):
# Tangent line
dt = t_demo[i+1] - t_demo[i]
slope = f(t_demo[i], h_demo[i])
t_tangent = np.array([t_demo[i], t_demo[i+1]])
h_tangent = h_demo[i] + slope * (t_tangent - t_demo[i])
ax.plot(t_tangent, h_tangent, 'r--', alpha=0.5, linewidth=1)
ax.set_xlabel('Time $t$', fontsize=12)
ax.set_ylabel('$h(t)$', fontsize=12)
ax.set_title('Euler Method: Following Tangent Lines', fontweight='bold', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
# Plot 4: Error comparison bar chart
ax = axes[1, 1]
methods = ['Euler', 'RK4']
max_errors = [euler_error.max(), rk4_error.max()]
mean_errors = [euler_error.mean(), rk4_error.mean()]
times = [euler_time * 1000, rk4_time * 1000] # Convert to ms
x = np.arange(len(methods))
width = 0.25
ax.bar(x - width, max_errors, width, label='Max error', alpha=0.8, color='red')
ax.bar(x, mean_errors, width, label='Mean error', alpha=0.8, color='blue')
ax.bar(x + width, [t/100 for t in times], width, label='Time (ms/100)', alpha=0.8, color='green')
ax.set_ylabel('Value', fontsize=12)
ax.set_title('Method Comparison', fontweight='bold', fontsize=13)
ax.set_xticks(x)
ax.set_xticklabels(methods, fontsize=12)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
print("\n" + "="*70)
print("ODE SOLVER COMPARISON")
print("="*70)
print(f"Problem: dh/dt = -h, h(0) = {h0}, t ∈ [0, 5]")
print(f"Steps: {len(t_span)}")
print("\nEuler Method (1st order):")
print(f" Max error: {euler_error.max():.6f}")
print(f" Mean error: {euler_error.mean():.6f}")
print(f" Time: {euler_time*1000:.2f} ms")
print("\nRK4 Method (4th order):")
print(f" Max error: {rk4_error.max():.6e}")
print(f" Mean error: {rk4_error.mean():.6e}")
print(f" Time: {rk4_time*1000:.2f} ms")
print("\nAccuracy Improvement: {:.1f}x better".format(euler_error.max() / rk4_error.max()))
print("Time Cost: {:.1f}x slower".format(rk4_time / euler_time))
print("="*70)
compare_ode_solvers()
# Neural ODE implementation
class ODEFunc(nn.Module):
"""ODE function: dh/dt = f(h, t)"""
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, 64),
nn.Tanh(),
nn.Linear(64, dim)
)
# Explicitly depend on time
self.time_net = nn.Linear(1, dim)
def forward(self, t, h):
"""
Args:
t: current time (scalar)
h: current state (batch_size, dim)
Returns:
dh/dt
"""
# Time-dependent dynamics
t_vec = torch.ones(h.shape[0], 1).to(h.device) * t
time_effect = self.time_net(t_vec)
return self.net(h) + time_effect
class ODEBlock(nn.Module):
"""Neural ODE block"""
def __init__(self, odefunc, integration_time=1.0):
super().__init__()
self.odefunc = odefunc
self.integration_time = torch.tensor([0, integration_time])
def forward(self, x):
"""
Args:
x: input (batch_size, dim)
Returns:
output after ODE integration
"""
self.integration_time = self.integration_time.to(x.device)
# Solve ODE from t=0 to t=integration_time
out = odeint(self.odefunc, x, self.integration_time,
method='dopri5', rtol=1e-3, atol=1e-4)
return out[1] # Return state at final time
# Test ODE block
dim = 64
odefunc = ODEFunc(dim).to(device)
ode_block = ODEBlock(odefunc, integration_time=1.0).to(device)
# Forward pass
x_test = torch.randn(16, dim).to(device)
y_test = ode_block(x_test)
print(f"Input shape: {x_test.shape}")
print(f"Output shape: {y_test.shape}")
print(f"✅ Neural ODE block working correctly!")
4. Neural ODE Classifier¶
Architecture¶
Downsample: Reduce spatial dimensions (for images)
ODE Block: Continuous transformation
Classifier: Linear layer for prediction
MNIST example:
Input: 28×28 images
Flatten → 784 dims
Reduce to 64 dims
ODE block
Classify to 10 classes
# Neural ODE Classifier for MNIST
class NeuralODEClassifier(nn.Module):
"""Neural ODE for MNIST classification"""
def __init__(self, input_dim=784, hidden_dim=64, num_classes=10):
super().__init__()
# Downsampling
self.downsample = nn.Sequential(
nn.Linear(input_dim, 128),
nn.ReLU(),
nn.Linear(128, hidden_dim)
)
# Neural ODE
odefunc = ODEFunc(hidden_dim)
self.ode_block = ODEBlock(odefunc, integration_time=1.0)
# Classifier
self.classifier = nn.Linear(hidden_dim, num_classes)
def forward(self, x):
"""Forward pass"""
# Flatten
x = x.view(x.size(0), -1)
# Downsample
h = self.downsample(x)
# ODE integration
h = self.ode_block(h)
# Classify
out = self.classifier(h)
return out
# Initialize model
model = NeuralODEClassifier().to(device)
print("Neural ODE Classifier:")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
# Load MNIST data
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
# Visualize samples
fig, axes = plt.subplots(2, 8, figsize=(14, 4))
for i, ax in enumerate(axes.flat):
img, label = train_dataset[i]
ax.imshow(img.squeeze(), cmap='gray')
ax.set_title(f'{label}', fontsize=12)
ax.axis('off')
plt.suptitle('MNIST Samples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
# Train Neural ODE
def train_model(model, train_loader, test_loader, n_epochs=5, lr=1e-3):
"""Train Neural ODE classifier"""
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
for epoch in range(n_epochs):
# Training
model.train()
train_loss, correct, total = 0, 0, 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item()
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
if (batch_idx + 1) % 100 == 0:
print(f' Batch {batch_idx+1}/{len(train_loader)} | '
f'Loss: {loss.item():.4f}')
train_acc = 100. * correct / total
# Testing
model.eval()
test_correct, test_total = 0, 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
test_correct += pred.eq(target).sum().item()
test_total += target.size(0)
test_acc = 100. * test_correct / test_total
history['train_loss'].append(train_loss / len(train_loader))
history['train_acc'].append(train_acc)
history['test_acc'].append(test_acc)
print(f'Epoch {epoch+1}/{n_epochs} | '
f'Train Loss: {history["train_loss"][-1]:.4f} | '
f'Train Acc: {train_acc:.2f}% | '
f'Test Acc: {test_acc:.2f}%')
print('-' * 70)
return history
print("Training Neural ODE on MNIST...")
print("="*70)
history = train_model(model, train_loader, test_loader, n_epochs=5, lr=1e-3)
# Visualize training
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Loss
axes[0].plot(history['train_loss'], 'o-', linewidth=2, markersize=8, label='Train Loss')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)
# Accuracy
axes[1].plot(history['train_acc'], 'o-', linewidth=2, markersize=8, label='Train Accuracy')
axes[1].plot(history['test_acc'], 's-', linewidth=2, markersize=8, label='Test Accuracy')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Classification Accuracy', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"\n✅ Final Test Accuracy: {history['test_acc'][-1]:.2f}%")
5. Summary¶
Neural ODEs¶
✅ Continuous-depth models: Replace discrete layers with ODE ✅ Memory efficient: Constant memory via adjoint method ✅ Adaptive computation: ODE solver chooses steps automatically
Key Ideas¶
Forward pass: Solve ODE \(\frac{dh}{dt} = f_\theta(h, t)\) from \(t=0\) to \(t=T\)
Backward pass: Adjoint method computes gradients without storing states
Implementation: Use torchdiffeq with odeint_adjoint
Comparison¶
| Property | ResNet | Neural ODE |
|---|---|---|
| Depth | Discrete layers | Continuous |
| Memory | \(O(L)\) | \(O(1)\) |
| Forward cost | \(O(L)\) | Adaptive |
| Evaluation cost | Fixed | Variable |
Implementation¶
✅ Implemented ODE function and ODE block ✅ Trained Neural ODE classifier on MNIST ✅ Achieved competitive accuracy with constant memory
Advantages¶
Memory efficient for very deep models
Adaptive computation (more steps for complex inputs)
Continuous representations (smooth trajectories)
Limitations¶
⚠️ Slower training (backward ODE solve) ⚠️ Numerical errors from ODE solver ⚠️ Harder to debug than discrete models
Next Notebook: 05_3d_vision_introduction.ipynb
References¶
Chen et al. “Neural Ordinary Differential Equations” (NeurIPS 2018)
Grathwohl et al. “FFJORD: Free-form Continuous Dynamics” (ICLR 2019)
5.5. Continuous Normalizing Flows (CNF)¶
Motivation: Change of Variables Formula¶
For an invertible transformation \(z = f(x)\):
\[\log p(x) = \log p(z) + \log \left| \det \frac{\partial f}{\partial x} \right|\]
Problem: Computing Jacobian determinant is \(O(d^3)\) for \(d\)-dimensional data!
Neural ODE Solution¶
Key idea: Use a continuous transformation via an ODE:
\[\frac{dz(t)}{dt} = f_\theta(z(t), t), \quad z(0) = x\]
Instantaneous change of variables:
\[\frac{\partial \log p(z(t))}{\partial t} = -\text{tr}\left(\frac{\partial f_\theta}{\partial z(t)}\right)\]
Final density:
\[\log p(x) = \log p_1(z(1)) + \int_0^1 \text{tr}\left(\frac{\partial f_\theta}{\partial z(t)}\right) dt\]
Advantages of CNF¶
1. Arbitrary Architecture
No invertibility constraint on \(f_\theta\)
Can use any neural network architecture
More expressive than explicit flows (RealNVP, Glow)
2. Efficient Trace Computation
Use Hutchinson’s trace estimator:
For the Jacobian trace: \[\text{tr}\left(\frac{\partial f}{\partial z}\right) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0,I)}\left[\epsilon^T \frac{\partial f}{\partial z} \epsilon\right]\]
Compute via: \[\epsilon^T \frac{\partial f}{\partial z} \epsilon = \left(\frac{\partial}{\partial z}\left[f^T \epsilon\right]\right) \epsilon\]
using vector-Jacobian product (efficient in automatic differentiation!)
Cost: \(O(d)\) instead of \(O(d^3)\)! ✅
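The estimator is easy to verify numerically. A quick NumPy check (the random matrix `J` is an illustrative stand-in for the Jacobian \(\partial f / \partial z\); in a real CNF each \(\epsilon^T J \epsilon\) would come from a vector-Jacobian product, never from forming `J`):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 50
J = rng.standard_normal((d, d))  # stand-in for the Jacobian ∂f/∂z

# Hutchinson estimator: tr(J) = E[eps^T J eps] for eps ~ N(0, I)
n_samples = 20000
eps = rng.standard_normal((n_samples, d))
estimates = np.einsum('ni,ij,nj->n', eps, J, eps)

print(f"true trace:      {np.trace(J):.4f}")
print(f"Hutchinson mean: {estimates.mean():.4f}")
```

During training, one or two samples of \(\epsilon\) per ODE step are typically enough, since the noise averages out over the integral and over minibatches.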
3. Continuous Dynamics
Smooth transformations
Can model complex distributions
Training stability
CNF Training Objective¶
Maximum likelihood: \[\max_\theta \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log p_1(z(1)) + \int_0^1 \text{tr}\left(\frac{\partial f_\theta}{\partial z(t)}\right) dt\right]\]
where \(z(1) = f_\theta(x)\) is obtained by solving ODE forward
Prior: Typically \(p_1(z) = \mathcal{N}(0, I)\)
Comparison with Other Generative Models¶
| Model | Pros | Cons |
|---|---|---|
| VAE | Fast sampling | Approximate inference, blurry samples |
| GAN | Sharp samples | Mode collapse, training instability |
| Explicit Flow | Exact likelihood | Architectural constraints (invertible) |
| CNF | Exact likelihood, flexible architecture | Slow sampling (ODE solve) |
Practical Considerations¶
Training:
Use adaptive ODE solver (dopri5)
Tolerance: rtol=1e-5, atol=1e-7
Hutchinson estimator with 1-2 samples
Sampling:
Can use faster fixed-step solvers (RK4)
Or reduce tolerance for speed
Trade-off: quality vs speed
Regularization:
Add regularization to encourage simple dynamics
Kinetic energy: \(\int_0^1 ||f_\theta(z(t), t)||^2 dt\)
Jacobian Frobenius norm: \(\int_0^1 ||\frac{\partial f}{\partial z}||_F^2 dt\)
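The kinetic-energy term can be approximated by accumulating \(\|f_\theta(z(t), t)\|^2\) along the trajectory. A minimal sketch with illustrative dynamics (in a real CNF, \(f\) would be the learned network and this integral would be tracked as one more component of the augmented ODE state):

```python
import numpy as np

# Illustrative dynamics standing in for the learned f_theta
def f(z, t):
    return -z + 0.3 * np.sin(2 * np.pi * t)

z = np.array([1.0, -0.5])
T, N = 1.0, 1000
dt = T / N
kinetic = 0.0
for n in range(N):
    dz = f(z, n * dt)
    kinetic += dt * float(np.dot(dz, dz))  # accumulate ||f(z(t), t)||^2 dt
    z = z + dt * dz

print(f"kinetic energy regularizer ≈ {kinetic:.4f}")
```

Adding this term to the loss penalizes wiggly trajectories, which in practice keeps the NFE of the adaptive solver low.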
# Continuous Normalizing Flow Implementation
class CNF(nn.Module):
"""
Continuous Normalizing Flow using Neural ODE
    Implements density estimation via the instantaneous change of variables:
    log p(x) = log p_1(z(1)) + ∫_0^1 tr(∂f/∂z) dt
"""
def __init__(self, dim, hidden_dim=64):
super().__init__()
self.dim = dim
# Dynamics network
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, dim)
)
def forward(self, t, states):
"""
ODE dynamics with trace computation
states = [z, log_p]
Returns d[z, log_p]/dt
"""
z = states[0]
batch_size = z.shape[0]
# Compute f(z, t)
with torch.enable_grad():
z.requires_grad_(True)
dz_dt = self.net(z)
# Hutchinson trace estimator
# tr(∂f/∂z) ≈ ε^T (∂f/∂z) ε where ε ~ N(0,I)
epsilon = torch.randn_like(z)
# Compute ε^T (∂f/∂z) ε via vector-Jacobian product
# This is equivalent to ε^T ∂(f^T ε)/∂z
vjp = torch.autograd.grad(
(dz_dt * epsilon).sum(), z,
create_graph=True, retain_graph=True
)[0]
trace_estimate = (vjp * epsilon).sum(dim=1, keepdim=True)
# Change in log probability
dlog_p_dt = -trace_estimate
return (dz_dt, dlog_p_dt)
def log_prob(self, x, integration_time=1.0):
"""
Compute log probability of data x
Returns:
log_prob: log p(x)
z: latent encoding
"""
batch_size = x.shape[0]
# Initial log probability (starts at 0)
log_p0 = torch.zeros(batch_size, 1).to(x.device)
# Solve ODE forward
states = (x, log_p0)
# Note: In practice, use odeint from torchdiffeq
# For demonstration, this shows the concept
# z_T, log_pT = odeint(self, states, torch.tensor([0., integration_time]))
# Here we just return conceptual output
# In real implementation, integrate the ODE
z = x # Would be transformed via ODE
log_p = log_p0 # Would be updated via integral
# Prior: standard Gaussian
log_prior = -0.5 * (z**2).sum(dim=1, keepdim=True) - \
0.5 * self.dim * np.log(2 * np.pi)
# Final log probability
log_prob = log_prior + log_p
return log_prob, z
def sample(self, num_samples, integration_time=1.0):
"""
Generate samples by solving ODE backward from prior
1. Sample z ~ N(0, I)
2. Solve ODE backward: z(0) ← z(T)
3. Return x = z(0)
"""
device = next(self.parameters()).device
# Sample from prior
z = torch.randn(num_samples, self.dim).to(device)
# Note: Solve ODE backward
# In practice: x = odeint(self, z, torch.tensor([integration_time, 0.]))
return z # Would be transformed via backward ODE
# Demonstrate CNF concept
print("="*70)
print("CONTINUOUS NORMALIZING FLOW (CNF)")
print("="*70)
print("\nKey Components:")
print(" 1. Forward ODE: x → z (data to latent)")
print(" • Transform data to simple prior distribution")
print(" • Track log probability change via trace computation")
print("\n 2. Backward ODE: z → x (latent to data)")
print(" • Sample from prior: z ~ N(0, I)")
print(" • Solve ODE backward to generate data")
print("\n 3. Trace Estimation:")
print(" • Use Hutchinson estimator: tr(J) ≈ ε^T J ε")
print(" • Computed efficiently via vector-Jacobian product")
print(" • Cost: O(d) instead of O(d³)")
print("\nAdvantages:")
print(" ✅ Exact likelihood (unlike VAE)")
print(" ✅ Flexible architecture (unlike explicit flows)")
print(" ✅ Stable training (unlike GAN)")
print(" ✅ Continuous transformations")
print("\nDisadvantages:")
print(" ⚠️ Slow sampling (requires ODE solve)")
print(" ⚠️ Memory for adjoint method")
print("="*70)
# Create example CNF
cnf = CNF(dim=2, hidden_dim=32)
print(f"\nCNF Model:")
print(f" Input dimension: 2")
print(f" Hidden dimension: 32")
print(f" Parameters: {sum(p.numel() for p in cnf.parameters())}")
5.6. Advanced Topics and Practical Considerations¶
Stiffness in Neural ODEs¶
Stiff ODEs have dynamics at multiple time scales:
Fast dynamics (small \(\Delta t\) needed)
Slow dynamics (large \(\Delta t\) possible)
Problem: Explicit solvers (Euler, RK4) require tiny steps for stability.
Solution: Use implicit solvers or regularization.
Regularization Techniques:
1. Spectral Normalization
Constrain the Lipschitz constant of \(f_\theta\): $\(||f_\theta(z_1, t) - f_\theta(z_2, t)|| \leq L ||z_1 - z_2||\)$
Prevents explosive gradients and improves ODE solver stability
2. Jacobian Regularization
Add a penalty to encourage smooth dynamics: $\(\mathcal{L}_{reg} = \lambda \mathbb{E}\left[||\frac{\partial f_\theta}{\partial z}||_F^2\right]\)$
3. Kinetic Energy Regularization
Penalize large velocities: $\(\mathcal{L}_{kinetic} = \lambda \int_0^T ||f_\theta(z(t), t)||^2 dt\)$
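PyTorch ships spectral normalization as `torch.nn.utils.spectral_norm`, so constraining the dynamics network is a one-line change per layer. A minimal sketch, checking the empirical Lipschitz ratio on random pairs:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)

# Dynamics net with spectrally normalized linear layers; each weight is
# rescaled to (approximately) unit spectral norm, and tanh is 1-Lipschitz,
# so the whole network is roughly 1-Lipschitz
f_theta = nn.Sequential(
    spectral_norm(nn.Linear(2, 64)),
    nn.Tanh(),
    spectral_norm(nn.Linear(64, 2)),
)

f_theta(torch.randn(1, 2))  # one forward pass runs a power-iteration step
f_theta.eval()              # freeze u so both evaluations share the same weight

z1, z2 = torch.randn(8, 2), torch.randn(8, 2)
with torch.no_grad():
    ratio = ((f_theta(z1) - f_theta(z2)).norm(dim=1) /
             (z1 - z2).norm(dim=1))
print(ratio.max())
```

The spectral norm is estimated by power iteration, so the bound is approximate, but the ratio stays close to 1 rather than growing with the raw weight scales.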
Augmented Neural ODEs¶
Motivation: Standard Neural ODEs may not be expressive enough
Solution: Augment the state with extra dimensions: $\(\frac{d}{dt}\begin{bmatrix} z(t) \\ a(t) \end{bmatrix} = f_\theta\left(\begin{bmatrix} z(t) \\ a(t) \end{bmatrix}, t\right)\)$
where:
\(z(t) \in \mathbb{R}^d\): original state
\(a(t) \in \mathbb{R}^p\): augmented dimensions (typically \(p \ll d\))
Benefits:
Increased expressiveness
Can learn more complex trajectories
Better approximation of diffeomorphisms
Initialization: \(a(0) = \mathbf{0}\) (zeros)
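The augmentation amounts to concatenating zeros before the ODE solve and letting the dynamics act on the larger state. A minimal sketch, with the `AugmentedDynamics` name chosen here for illustration:

```python
import torch
import torch.nn as nn

class AugmentedDynamics(nn.Module):
    """Dynamics on the augmented state [z, a] ∈ R^(d+p)."""
    def __init__(self, d, p, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d + p, hidden), nn.Tanh(), nn.Linear(hidden, d + p)
        )

    def forward(self, t, state):
        return self.net(state)

d, p = 2, 3
dyn = AugmentedDynamics(d, p)
z0 = torch.randn(5, d)

# Augment with zeros: a(0) = 0
state0 = torch.cat([z0, torch.zeros(5, p)], dim=1)
dstate = dyn(0.0, state0)
print(dstate.shape)  # torch.Size([5, 5])
```

After solving the ODE, only the first \(d\) coordinates of the final state are kept; the extra \(p\) coordinates just give the flow room to avoid trajectory crossings.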
Second-Order Neural ODEs¶
Standard (first-order): $\(\frac{dz}{dt} = f_\theta(z, t)\)$
Second-order (Hamiltonian-inspired): $\(\frac{dz}{dt} = v, \quad \frac{dv}{dt} = f_\theta(z, v, t)\)$
Advantages:
More stable dynamics
Better long-term behavior
Physical interpretability (position + velocity)
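The second-order system is handled by any first-order ODE solver once written on the stacked state \((z, v)\). A minimal sketch under that rewriting (class name hypothetical):

```python
import torch
import torch.nn as nn

class SecondOrderDynamics(nn.Module):
    """dz/dt = v, dv/dt = f_theta(z, v, t), expressed as a
    first-order system on the stacked state (z, v)."""
    def __init__(self, dim, hidden=32):
        super().__init__()
        self.f = nn.Sequential(
            nn.Linear(2 * dim, hidden), nn.Tanh(), nn.Linear(hidden, dim)
        )

    def forward(self, t, state):
        z, v = state
        dv_dt = self.f(torch.cat([z, v], dim=1))
        return v, dv_dt  # (dz/dt, dv/dt)

dyn = SecondOrderDynamics(dim=2)
z0, v0 = torch.randn(4, 2), torch.zeros(4, 2)
dz, dv = dyn(0.0, (z0, v0))
print(dz.shape, dv.shape)
```

With zero initial velocity, \(dz/dt\) at \(t=0\) is exactly zero, matching the "position + velocity" reading.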
Latent ODEs for Irregular Time Series¶
Problem: Time series with irregular sampling, missing data
Solution: Encode to latent ODE, then decode
Architecture:
Recognition network: \(h_0 = \text{RNN}(\{(t_i, x_i)\})\)
Latent ODE: \(h(t) = \text{ODESolve}(f_\theta, h_0, [0, T])\)
Decoder: \(\hat{x}_i = \text{Decoder}(h(t_i))\)
Key insight: ODE is continuous, can predict at any \(t\)!
Applications:
Medical time series (irregular patient visits)
Climate data (varying measurement intervals)
Financial data (non-uniform trading times)
Neural CDE (Controlled Differential Equations)¶
Extension: Let the input data control the ODE: $\(dz(t) = f_\theta(z(t)) \, dX(t)\)$
where \(X(t)\) is a continuous path interpolated from the data
Properties:
Naturally handles irregular sampling
Path-dependent dynamics
State-of-the-art for time series
Practical Training Tips¶
1. Gradient Clipping
Clip gradients to prevent explosion
Typical: max_norm=1.0
2. Learning Rate Schedule
Start with a higher LR: 1e-3
Decay to 1e-5 or 1e-6
Cosine annealing works well
3. Monitoring NFE
Track number of function evaluations
Should be stable during training
Sudden increases indicate instability
4. Tolerance Scheduling
Start with a relaxed tolerance: rtol=1e-3
Gradually tighten to a final rtol=1e-5
Balances speed early, accuracy late
5. Batch Size
Smaller batches often better
ODE solver adaptive to each sample
Larger batches → more NFE variation
When to Use Neural ODEs¶
Good for:
✅ Time series with irregular sampling
✅ Continuous normalizing flows
✅ Memory-constrained scenarios
✅ Physically-inspired models
✅ Generative modeling with exact likelihood
Not ideal for:
❌ Need for very fast inference
❌ Extremely high-dimensional data (images)
❌ When discrete depth is sufficient
❌ Limited computational budget for training
4.5. Advanced Application: Continuous Normalizing Flows (CNF)¶
Motivation: Generative Modeling¶
Goal: Learn complex probability distributions \(p(x)\)
Approach: Transform simple base distribution \(p(z)\) to data distribution \(p(x)\)
Normalizing Flows Recap¶
Discrete flow: $\(z_0 \xrightarrow{f_1} z_1 \xrightarrow{f_2} \cdots \xrightarrow{f_T} z_T = x\)$
Change of variables: $\(\log p(x) = \log p(z_0) - \sum_{t=1}^T \log \left|\det \frac{\partial f_t}{\partial z_{t-1}}\right|\)$
Problem: Computing determinant is \(O(d^3)\) for \(d\)-dimensional data!
Continuous Normalizing Flows (FFJORD)¶
Idea: Use Neural ODE for continuous transformation
Instantaneous change of variables formula: $\(\frac{\partial \log p(z(t))}{\partial t} = -\text{tr}\left(\frac{\partial f_\theta}{\partial z(t)}\right)\)$
This is the trace of the Jacobian, much cheaper than determinant!
Complete Density Formula¶
Integrating from \(t=0\) to \(t=1\): $\(\log p(x) = \log p(z(0)) - \int_0^1 \text{tr}\left(\frac{\partial f_\theta}{\partial z(t)}\right) dt\)$
Computing the Trace Efficiently¶
Problem: Computing trace still requires computing all diagonal elements of Jacobian
Solution (Hutchinson’s Trace Estimator):
For random vector \(\epsilon \sim \mathcal{N}(0, I)\): $\(\text{tr}(A) = \mathbb{E}_\epsilon[\epsilon^T A \epsilon]\)$
Implementation: $\(\text{tr}\left(\frac{\partial f}{\partial z}\right) \approx \epsilon^T \frac{\partial f}{\partial z} \epsilon = \epsilon^T \frac{\partial (f^T \epsilon)}{\partial z}\)$
Only requires one vector-Jacobian product (via automatic differentiation)!
Cost: \(O(d)\) instead of \(O(d^3)\)!
Advantages of CNF¶
✅ Exact likelihood: Can compute \(p(x)\) exactly (unlike VAE)
✅ Free-form Jacobians: No architectural constraints (unlike RealNVP, Glow)
✅ Scalable: \(O(d)\) cost via trace estimator
✅ Invertible: Can sample and compute likelihood
CNF Training¶
Objective: Maximize log-likelihood
Algorithm:
Sample data \(x\)
Solve the ODE backward in time: \(x = z(1) \to z(0)\)
Accumulate the trace integral: \(\int_0^1 \text{tr}(\partial f / \partial z) \, dt\)
Compute loss: \(-\log p(z(0)) + \int_0^1 \text{tr}(\partial f / \partial z) \, dt\)
Backprop through adjoint method
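The outer maximum-likelihood loop has the same shape regardless of how `log_prob` is computed. The sketch below uses a toy analytic `log_prob` (a hypothetical stand-in: a real CNF would solve the ODE backward and accumulate the trace integral inside it) so the example stays self-contained:

```python
import math
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy stand-in for a CNF: log_prob is analytic here, whereas a real CNF
# solves the ODE and accumulates the trace integral inside log_prob.
class ToyDensityModel(nn.Module):
    def __init__(self, dim=2):
        super().__init__()
        self.shift = nn.Parameter(torch.zeros(dim))

    def log_prob(self, x):
        z = x - self.shift  # trivially invertible map, zero log-det term
        return (-0.5 * (z ** 2).sum(dim=1)
                - 0.5 * z.shape[1] * math.log(2 * math.pi))

model = ToyDensityModel()
opt = optim.Adam(model.parameters(), lr=5e-2)
data = torch.randn(256, 2) + 2.0  # data centred near (2, 2)

for step in range(300):
    opt.zero_grad()
    loss = -model.log_prob(data).mean()  # maximum likelihood = minimize NLL
    loss.backward()
    opt.step()

print(model.shift.data)  # learned shift should approach the data mean
```

Swapping in a CNF only changes the `log_prob` call; the gradient of the ODE-based likelihood is then obtained through the adjoint method as in step 5.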
Applications¶
Density estimation: Model complex distributions
Generative modeling: Sample realistic data
Variational inference: Flexible posteriors
Anomaly detection: Detect out-of-distribution samples