Neural Ordinary Differential Equations (Neural ODEs)

Learning Objectives:

  • Understand continuous-depth networks

  • Implement adjoint method for backpropagation

  • Train Neural ODE models

  • Apply to time series and generative modeling

Prerequisites: Deep learning, differential equations, numerical methods

Time: 90 minutes


1. From ResNets to ODEs

ResNet Blocks

Standard ResNet block: \(h_{t+1} = h_t + f_\theta(h_t)\)

where \(f_\theta\) is a neural network (typically conv layers + activation).

Iterating \(T\) blocks (letting the dynamics depend on the block index \(t\)): \(h_T = h_0 + \sum_{t=0}^{T-1} f_\theta(h_t, t)\)

Continuous Limit

As depth \(T \to \infty\) with step size \(\Delta t \to 0\):

\[h_{t+1} = h_t + f_\theta(h_t, t) \cdot \Delta t\]

becomes an ODE:

\[\boxed{\frac{dh(t)}{dt} = f_\theta(h(t), t)}\]

with initial condition \(h(0) = h_0\) and solution \(h(T)\) at time \(T\).

Key insight: Neural network depth becomes continuous!
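This correspondence can be checked numerically: iterating a residual update with step size \(\Delta t\) is exactly explicit Euler integration of the ODE above. A minimal sketch, with an arbitrary scalar function standing in for the learned \(f_\theta\):

```python
import numpy as np

def f(h, t):
    # Arbitrary scalar dynamics standing in for the learned f_theta
    return -h + np.sin(t)

T_blocks, dt = 100, 0.01     # 100 residual blocks, step size 0.01

# "ResNet": iterate h_{t+1} = h_t + f(h_t, t) * dt over discrete blocks
h_resnet = 1.0
for k in range(T_blocks):
    h_resnet = h_resnet + dt * f(h_resnet, k * dt)

# The same computation, read as Euler integration of dh/dt = f(h, t) up to t = 1
h_euler, t = 1.0, 0.0
for _ in range(T_blocks):
    h_euler += dt * f(h_euler, t)
    t += dt

print(h_resnet, h_euler)  # agree up to float round-off: a ResNet is a fixed-step solver
```

Both loops compute the same trajectory, which is the sense in which a ResNet is a discretized Neural ODE.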

Neural ODE Layer

Forward pass: Solve the ODE from \(t=0\) to \(t=T\): \(h(T) = h(0) + \int_0^T f_\theta(h(t), t)\, dt\)

Use ODE solver (e.g., Runge-Kutta, Adams-Bashforth)

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)
np.random.seed(42)
torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Check if torchdiffeq is available
try:
    from torchdiffeq import odeint_adjoint as odeint
    print("✅ torchdiffeq is available")
except ImportError:
    print("⚠️  Installing torchdiffeq...")
    import subprocess, sys
    # Install into the current interpreter's environment, not whatever 'pip' is on PATH
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'torchdiffeq'])
    from torchdiffeq import odeint_adjoint as odeint
    print("✅ torchdiffeq installed successfully")

2. The Adjoint Method

Problem: Memory Cost

Standard backpropagation:

  • Forward pass: Compute and store all intermediate states

  • Backward pass: Use stored states to compute gradients

  • Memory: \(O(T \cdot d)\) where \(T\) is number of steps, \(d\) is state dimension

For very deep networks: Memory becomes prohibitive!

Solution: Adjoint Sensitivity Method

Idea: Compute gradients by solving another ODE backward in time

Adjoint state: \(a(t) = \frac{\partial L}{\partial h(t)}\)

where \(L\) is the loss.

Adjoint ODE: \(\frac{da(t)}{dt} = -a(t)^T \frac{\partial f_\theta}{\partial h}\)

Gradient: \(\frac{\partial L}{\partial \theta} = -\int_T^0 a(t)^T \frac{\partial f_\theta}{\partial \theta}\, dt\)

Key advantage: Only need to store \(a(T)\) and integrate backward!

Memory: \(O(1)\) instead of \(O(T)\)! ✅

2.5. Adjoint Method: Complete Derivation

Forward Problem

Given the ODE initial value problem: \(\frac{dh(t)}{dt} = f_\theta(h(t), t), \quad h(t_0) = h_0\)

Solution at time \(t_1\): \(h(t_1) = h_0 + \int_{t_0}^{t_1} f_\theta(h(t), t)\, dt\)

Loss Function

Define a loss \(L\) that depends on the final state: \(L = L(h(t_1))\)

Goal: Compute \(\frac{dL}{d\theta}\) and \(\frac{dL}{dh_0}\) efficiently

Adjoint State Definition

Define the adjoint state: \(a(t) \triangleq \frac{\partial L}{\partial h(t)}\)

Physical interpretation: Sensitivity of loss to hidden state at time \(t\)

Deriving the Adjoint ODE

Starting from the definition: \(a(t) = \frac{\partial L}{\partial h(t)}\)

Take the derivative with respect to time: \(\frac{da(t)}{dt} = \frac{d}{dt}\left(\frac{\partial L}{\partial h(t)}\right)\)

By the chain rule, \(a(t) = a(t+\varepsilon)\,\frac{\partial h(t+\varepsilon)}{\partial h(t)}\) for any \(\varepsilon > 0\); differentiating in \(\varepsilon\) as \(\varepsilon \to 0\) and using \(\frac{dh}{dt} = f_\theta(h, t)\): \(\frac{da(t)}{dt} = -a(t)^T \frac{\partial f_\theta(h(t), t)}{\partial h}\)

This is the adjoint ODE! It evolves backward in time from \(t_1\) to \(t_0\).

Boundary condition: \(a(t_1) = \frac{\partial L}{\partial h(t_1)}\)

Gradient Computation

Gradient w.r.t. parameters:

\[\frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^T \frac{\partial f_\theta(h(t), t)}{\partial \theta} dt\]

Gradient w.r.t. initial state:

\[\frac{dL}{dh_0} = a(t_0)\]

The Augmented System

In practice, solve the augmented ODE backward from \(t = t_1\) to \(t = t_0\):

\[\begin{split}\frac{d}{dt}\begin{bmatrix} h(t) \\ a(t) \\ \frac{\partial L}{\partial \theta}(t) \end{bmatrix} = \begin{bmatrix} f_\theta(h, t) \\ -a^T \frac{\partial f_\theta}{\partial h} \\ -a^T \frac{\partial f_\theta}{\partial \theta} \end{bmatrix}\end{split}\]

Key advantages:

  1. Memory: \(O(1)\) instead of \(O(T)\) where \(T\) is number of forward steps

  2. Flexibility: Can use adaptive ODE solvers without worrying about gradient computation

  3. Accuracy: Gradient accuracy controlled by ODE solver tolerance
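The gradient formulas can be sanity-checked on a scalar linear ODE where every quantity has a closed form. The following is a toy verification sketch (numbers chosen arbitrarily), not part of the notebook's model: for \(\frac{dh}{dt} = \theta h\) with loss \(L = h(t_1)\), the adjoint is \(a(t) = e^{\theta(t_1 - t)}\), and the adjoint integral should match a finite-difference estimate of \(dL/d\theta\).

```python
import numpy as np

# Closed-form toy problem:
#   dh/dt = theta * h,  h(0) = h0,  loss L = h(t1)
theta, h0, t1 = 0.7, 1.5, 2.0

h = lambda t: h0 * np.exp(theta * t)        # forward solution h(t)
a = lambda t: np.exp(theta * (t1 - t))      # adjoint a(t) = dL/dh(t), known here

# Adjoint gradient: dL/dtheta = integral_0^t1 a(t) * (df/dtheta) dt, with df/dtheta = h(t)
ts = np.linspace(0.0, t1, 10001)
vals = a(ts) * h(ts)
grad_adjoint = float(np.sum((vals[:-1] + vals[1:]) / 2) * (ts[1] - ts[0]))  # trapezoid rule

# Finite-difference check on L(theta) = h0 * exp(theta * t1)
eps = 1e-6
L = lambda th: h0 * np.exp(th * t1)
grad_fd = (L(theta + eps) - L(theta - eps)) / (2 * eps)

print(grad_adjoint, grad_fd)  # both approximate h0 * t1 * exp(theta * t1)
```

For this problem the analytic gradient is \(h_0 t_1 e^{\theta t_1}\), so the two estimates should agree closely.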

Comparison: Standard Backprop vs Adjoint

| Method | Memory | Accuracy | Flexibility |
|---|---|---|---|
| Standard Backprop | \(O(T \cdot d)\) | Exact | Limited to fixed discretization |
| Adjoint Method | \(O(d)\) | ODE solver tolerance | Adaptive solvers supported |

where \(T\) = number of steps, \(d\) = state dimension

# Visualize adjoint computation

def visualize_adjoint_method():
    """
    Demonstrate forward and backward passes in Neural ODE
    Shows memory efficiency of adjoint method
    """
    
    fig, axes = plt.subplots(2, 1, figsize=(14, 10))
    
    # Time points
    t = np.linspace(0, 1, 100)
    
    # Forward pass: hidden state trajectory
    h_trajectory = np.sin(2 * np.pi * t) + 0.5 * np.cos(4 * np.pi * t)
    
    ax = axes[0]
    ax.plot(t, h_trajectory, 'b-', linewidth=3, label='$h(t)$ trajectory')
    ax.scatter([0, 1], [h_trajectory[0], h_trajectory[-1]], 
               s=200, c=['green', 'red'], zorder=5, edgecolors='black', linewidths=2)
    ax.annotate('$h(0)$ (initial)', xy=(0, h_trajectory[0]), 
                xytext=(0.15, h_trajectory[0] + 0.3),
                fontsize=12, fontweight='bold',
                arrowprops=dict(arrowstyle='->', lw=2, color='green'))
    ax.annotate('$h(1)$ (final)', xy=(1, h_trajectory[-1]), 
                xytext=(0.85, h_trajectory[-1] - 0.4),
                fontsize=12, fontweight='bold',
                arrowprops=dict(arrowstyle='->', lw=2, color='red'))
    
    # Add arrows showing forward direction
    for i in range(0, 90, 20):
        ax.annotate('', xy=(t[i+10], h_trajectory[i+10]), 
                   xytext=(t[i], h_trajectory[i]),
                   arrowprops=dict(arrowstyle='->', lw=2, color='blue', alpha=0.6))
    
    ax.set_xlabel('Time $t$', fontsize=13)
    ax.set_ylabel('Hidden State $h(t)$', fontsize=13)
    ax.set_title('Forward Pass: Solve ODE $\\frac{dh}{dt} = f_\\theta(h, t)$', 
                 fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=12)
    
    # Backward pass: adjoint trajectory
    a_trajectory = np.exp(-2 * (t - 1)**2)  # Adjoint decays from t=1 to t=0
    grad_accumulation = np.cumsum(a_trajectory[::-1])[::-1] / 100
    
    ax = axes[1]
    
    # Plot adjoint trajectory
    ax.plot(t, a_trajectory, 'r-', linewidth=3, label='$a(t)$ (adjoint state)')
    ax.fill_between(t, 0, a_trajectory, alpha=0.2, color='red')
    
    # Plot gradient accumulation
    ax.plot(t, grad_accumulation / grad_accumulation.max() * a_trajectory.max() * 0.8, 
            'g--', linewidth=2, label='$\\frac{\\partial L}{\\partial \\theta}$ accumulation')
    
    ax.scatter([0, 1], [a_trajectory[0], a_trajectory[-1]], 
               s=200, c=['green', 'red'], zorder=5, edgecolors='black', linewidths=2)
    ax.annotate('$a(1) = \\frac{\\partial L}{\\partial h(1)}$', xy=(1, a_trajectory[-1]), 
                xytext=(0.75, a_trajectory[-1] + 0.15),
                fontsize=12, fontweight='bold',
                arrowprops=dict(arrowstyle='->', lw=2, color='red'))
    ax.annotate('$a(0) = \\frac{\\partial L}{\\partial h(0)}$', xy=(0, a_trajectory[0]), 
                xytext=(0.15, a_trajectory[0] + 0.15),
                fontsize=12, fontweight='bold',
                arrowprops=dict(arrowstyle='->', lw=2, color='green'))
    
    # Add arrows showing backward direction
    for i in range(90, 0, -20):
        ax.annotate('', xy=(t[i-10], a_trajectory[i-10]), 
                   xytext=(t[i], a_trajectory[i]),
                   arrowprops=dict(arrowstyle='->', lw=2, color='red', alpha=0.6))
    
    ax.set_xlabel('Time $t$', fontsize=13)
    ax.set_ylabel('Adjoint State $a(t)$', fontsize=13)
    ax.set_title('Backward Pass: Solve Adjoint ODE $\\frac{da}{dt} = -a^T \\frac{\\partial f}{\\partial h}$ (backward in time)', 
                 fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=12)
    
    plt.tight_layout()
    plt.show()
    
    print("\n" + "="*70)
    print("ADJOINT METHOD EXPLANATION")
    print("="*70)
    print("Forward Pass (t: 0 → 1):")
    print("  • Solve ODE to get h(t) trajectory")
    print("  • Only store final state h(1)")
    print("  • Memory: O(1) - constant!")
    print("\nBackward Pass (t: 1 → 0):")
    print("  • Compute a(1) = ∂L/∂h(1) from loss")
    print("  • Solve adjoint ODE backward in time")
    print("  • Accumulate ∂L/∂θ along the way")
    print("  • Result: a(0) = ∂L/∂h(0) and ∂L/∂θ")
    print("\nKey Insight:")
    print("  • No need to store intermediate states!")
    print("  • Recompute h(t) during backward pass if needed")
    print("  • Memory efficient for very deep networks")
    print("="*70)

visualize_adjoint_method()
# Visualize ResNet vs Neural ODE

def plot_resnet_vs_ode():
    """Compare discrete ResNet with continuous Neural ODE"""
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # ResNet: Discrete layers
    t_discrete = np.arange(0, 11)
    h_discrete = np.cumsum(np.random.randn(11) * 0.3)  # Random walk
    
    axes[0].plot(t_discrete, h_discrete, 'o-', linewidth=2, markersize=10, color='blue')
    for i, (x, y) in enumerate(zip(t_discrete, h_discrete)):
        axes[0].annotate(f'Layer {i}', (x, y), textcoords="offset points", 
                        xytext=(0,10), ha='center', fontsize=9)
    axes[0].set_xlabel('Layer Index', fontsize=12)
    axes[0].set_ylabel('Hidden State', fontsize=12)
    axes[0].set_title('ResNet: Discrete Layers', fontsize=14, fontweight='bold')
    axes[0].grid(True, alpha=0.3)
    
    # Neural ODE: Continuous
    t_continuous = np.linspace(0, 10, 100)
    h_continuous = np.cumsum(np.random.randn(100) * 0.1)  # Smooth trajectory
    
    axes[1].plot(t_continuous, h_continuous, '-', linewidth=2, color='green')
    axes[1].fill_between(t_continuous, h_continuous-0.2, h_continuous+0.2, 
                        alpha=0.2, color='green', label='Continuous flow')
    axes[1].set_xlabel('Time $t$', fontsize=12)
    axes[1].set_ylabel('Hidden State $h(t)$', fontsize=12)
    axes[1].set_title('Neural ODE: Continuous Depth', fontsize=14, fontweight='bold')
    axes[1].legend(fontsize=11)
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("ResNet: Discrete transformations at each layer")
    print("Neural ODE: Smooth continuous evolution governed by ODE")

plot_resnet_vs_ode()

3. Neural ODE Implementation

ODE Function

Define \(f_\theta(h, t)\): the dynamics function

Requirements:

  • Input: state \(h\) and time \(t\)

  • Output: derivative \(\frac{dh}{dt}\)

ODE Block

Components:

  1. ODE function \(f_\theta\)

  2. Integration time \([0, T]\)

  3. ODE solver (e.g., dopri5, rk4)

Usage:

from torchdiffeq import odeint_adjoint

# Forward pass: odeint returns the solution at every requested time point
t = torch.tensor([0.0, T])
h_T = odeint_adjoint(odefunc, h_0, t)[-1]  # state at final time

The _adjoint version uses the adjoint method for backpropagation!

3.5. ODE Solvers: Theory and Practice

Common ODE Solver Methods

1. Euler Method (1st order): \(h_{n+1} = h_n + \Delta t \cdot f(h_n, t_n)\)

  • Pros: Simple, fast

  • Cons: Low accuracy, requires small step size

  • Error: \(O(\Delta t^2)\) per step, \(O(\Delta t)\) globally

2. Runge-Kutta 4 (RK4, 4th order)

\[\begin{aligned}
k_1 &= f(h_n, t_n) \\
k_2 &= f(h_n + \tfrac{\Delta t}{2}k_1,\; t_n + \tfrac{\Delta t}{2}) \\
k_3 &= f(h_n + \tfrac{\Delta t}{2}k_2,\; t_n + \tfrac{\Delta t}{2}) \\
k_4 &= f(h_n + \Delta t \cdot k_3,\; t_n + \Delta t) \\
h_{n+1} &= h_n + \tfrac{\Delta t}{6}(k_1 + 2k_2 + 2k_3 + k_4)
\end{aligned}\]

  • Pros: More accurate, stable

  • Cons: 4× function evaluations per step

  • Error: \(O(\Delta t^5)\) per step, \(O(\Delta t^4)\) globally

3. Adaptive Methods (Dormand-Prince, dopri5)

Automatically adjust step size based on error estimate:

  • Use two different order methods

  • Compare results to estimate error

  • Increase/decrease \(\Delta t\) accordingly

Advantages:

  • Accuracy: Maintains user-specified tolerance

  • Efficiency: Large steps where possible, small where needed

  • Reliability: Detects stiff regions automatically
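The adapt-the-step idea can be illustrated with step doubling, a simpler error estimator than the embedded pairs dopri5 actually uses. This is a toy sketch (function name, tolerance, and growth factors are illustrative choices):

```python
import numpy as np

def adaptive_euler(f, h0, t0, t1, tol=1e-5, dt=0.1):
    """Toy adaptive integrator (step doubling with Euler, not dopri5's
    embedded pair): compare one full step against two half steps, then
    shrink dt when the discrepancy exceeds tol and grow it otherwise."""
    t, h, n_steps = t0, h0, 0
    while t < t1:
        dt = min(dt, t1 - t)
        full = h + dt * f(t, h)                      # one Euler step of size dt
        half = h + (dt / 2) * f(t, h)                # two half steps
        half = half + (dt / 2) * f(t + dt / 2, half)
        if abs(full - half) < tol:                   # local error estimate
            t, h = t + dt, half                      # accept the better estimate
            dt *= 1.5                                # try a larger step next time
        else:
            dt *= 0.5                                # reject and retry smaller
        n_steps += 1
    return h, n_steps

f = lambda t, h: -2.0 * h                            # known decay: h(t) = exp(-2t)
h_end, n = adaptive_euler(f, 1.0, 0.0, 3.0)
print(h_end, np.exp(-6.0), n)
```

The solver takes small steps early (where \(h\) changes quickly) and larger ones later, which is exactly the efficiency argument above.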

Solver Selection Guidelines

| Scenario | Recommended Solver | Reason |
|---|---|---|
| Training | dopri5 (adaptive) | Balance speed/accuracy, handles varying dynamics |
| Inference (speed) | Euler or RK4 (fixed) | Faster, predictable cost |
| Inference (quality) | dopri5 | Best quality generations |
| Stiff problems | Implicit solvers | Stability for fast/slow dynamics |

Trade-offs in Neural ODEs

Number of Function Evaluations (NFE):

  • Measures computational cost

  • Varies with solver and tolerance

  • Typical: 20-100 NFE for training, 5-20 for inference

Tolerance:

  • rtol (relative tolerance): \(10^{-3}\) to \(10^{-7}\)

  • atol (absolute tolerance): \(10^{-4}\) to \(10^{-9}\)

  • Lower tolerance → More accurate, more NFE

  • Higher tolerance → Faster, less accurate

Practical Tips:

  1. Start with rtol=1e-3, atol=1e-4 for training

  2. Use rtol=1e-5, atol=1e-6 for high-quality generation

  3. Monitor NFE during training (should be stable)

  4. If NFE explodes, network might be poorly conditioned
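One way to monitor NFE is to count calls to the dynamics function; in a torchdiffeq `ODEFunc` the same idea is usually implemented by incrementing an attribute inside `forward`. A self-contained sketch with plain Python (wrapper and helper names are illustrative):

```python
class CountedDynamics:
    """Wrap a dynamics function f(t, h) and count evaluations (NFE)."""
    def __init__(self, f):
        self.f = f
        self.nfe = 0

    def __call__(self, t, h):
        self.nfe += 1
        return self.f(t, h)

def rk4_step(f, t, h, dt):
    # One classical RK4 step: exactly 4 function evaluations
    k1 = f(t, h)
    k2 = f(t + dt / 2, h + dt / 2 * k1)
    k3 = f(t + dt / 2, h + dt / 2 * k2)
    k4 = f(t + dt, h + dt * k3)
    return h + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

f = CountedDynamics(lambda t, h: -h)   # toy dynamics dh/dt = -h
h, t, dt = 1.0, 0.0, 0.1
for _ in range(10):
    h = rk4_step(f, t, h, dt)
    t += dt
print(f.nfe)  # 40: ten RK4 steps at 4 evaluations each
```

Resetting `nfe` before each forward pass and logging it per batch makes the "NFE should be stable" check above concrete.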

# Demonstrate different ODE solvers

def compare_ode_solvers():
    """
    Compare Euler, RK4, and adaptive (dopri5) methods
    Demonstrates accuracy vs computational cost trade-off
    """
    
    # Define a simple ODE: dh/dt = -2h + sin(t)
    def f(t, h):
        return -2 * h + np.sin(t)
    
    # Analytical solution for comparison (with h(0) = 0):
    # h(t) = (2 sin t - cos t + e^{-2t}) / 5
    def h_exact(t):
        return (1/5) * (2*np.sin(t) - np.cos(t) + np.exp(-2*t))
    
    # Time span
    t_span = np.linspace(0, 5, 200)
    h0 = 0.0
    
    # Euler method implementation
    def euler_solve(f, h0, t_span):
        h = [h0]
        nfe = 0
        for i in range(len(t_span) - 1):
            dt = t_span[i+1] - t_span[i]
            h_new = h[-1] + dt * f(t_span[i], h[-1])
            h.append(h_new)
            nfe += 1
        return np.array(h), nfe
    
    # RK4 method implementation
    def rk4_solve(f, h0, t_span):
        h = [h0]
        nfe = 0
        for i in range(len(t_span) - 1):
            dt = t_span[i+1] - t_span[i]
            t = t_span[i]
            h_curr = h[-1]
            
            k1 = f(t, h_curr)
            k2 = f(t + dt/2, h_curr + dt/2 * k1)
            k3 = f(t + dt/2, h_curr + dt/2 * k2)
            k4 = f(t + dt, h_curr + dt * k3)
            
            h_new = h_curr + (dt/6) * (k1 + 2*k2 + 2*k3 + k4)
            h.append(h_new)
            nfe += 4  # RK4 uses 4 function evaluations per step
        return np.array(h), nfe
    
    # Get exact solution
    h_true = h_exact(t_span)
    
    # Solve with different methods
    h_euler, nfe_euler = euler_solve(f, h0, t_span)
    h_rk4, nfe_rk4 = rk4_solve(f, h0, t_span)
    
    # Compute errors
    error_euler = np.abs(h_euler - h_true)
    error_rk4 = np.abs(h_rk4 - h_true)
    
    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    
    # Plot 1: Solutions
    ax = axes[0]
    ax.plot(t_span, h_true, 'k-', linewidth=3, label='Exact solution', alpha=0.7)
    ax.plot(t_span, h_euler, 'b--', linewidth=2, label=f'Euler (NFE={nfe_euler})')
    ax.plot(t_span, h_rk4, 'r:', linewidth=2, label=f'RK4 (NFE={nfe_rk4})')
    ax.set_xlabel('Time $t$', fontsize=12)
    ax.set_ylabel('$h(t)$', fontsize=12)
    ax.set_title('ODE Solutions Comparison', fontweight='bold', fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Errors
    ax = axes[1]
    ax.semilogy(t_span, error_euler, 'b-', linewidth=2, label='Euler error')
    ax.semilogy(t_span, error_rk4, 'r-', linewidth=2, label='RK4 error')
    ax.set_xlabel('Time $t$', fontsize=12)
    ax.set_ylabel('Absolute Error (log scale)', fontsize=12)
    ax.set_title('Approximation Error', fontweight='bold', fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3, which='both')
    
    # Plot 3: Accuracy vs Cost
    ax = axes[2]
    methods = ['Euler', 'RK4']
    nfes = [nfe_euler, nfe_rk4]
    max_errors = [error_euler.max(), error_rk4.max()]
    
    colors = ['blue', 'red']
    for i, (method, nfe, err) in enumerate(zip(methods, nfes, max_errors)):
        ax.scatter(nfe, err, s=300, c=colors[i], alpha=0.7, 
                  edgecolors='black', linewidth=2, label=method, zorder=5)
        ax.annotate(method, xy=(nfe, err), xytext=(10, 10),
                   textcoords='offset points', fontsize=11, fontweight='bold')
    
    ax.set_xlabel('Number of Function Evaluations (NFE)', fontsize=12)
    ax.set_ylabel('Maximum Error', fontsize=12)
    ax.set_title('Accuracy vs Computational Cost', fontweight='bold', fontsize=13)
    ax.set_yscale('log')
    ax.grid(True, alpha=0.3, which='both')
    
    plt.tight_layout()
    plt.show()
    
    print("\n" + "="*70)
    print("ODE SOLVER COMPARISON")
    print("="*70)
    print(f"Method    | NFE   | Max Error | Mean Error")
    print("-" * 70)
    print(f"Euler     | {nfe_euler:5d} | {error_euler.max():.2e} | {error_euler.mean():.2e}")
    print(f"RK4       | {nfe_rk4:5d} | {error_rk4.max():.2e} | {error_rk4.mean():.2e}")
    print("="*70)
    print("\nKey Takeaways:")
    print("  • RK4 is orders of magnitude more accurate than Euler here")
    print("  • RK4 uses 4× more function evaluations per step")
    print("  • For Neural ODEs: Adaptive solvers (dopri5) are best")
    print("  • They automatically balance accuracy vs cost")
    print("="*70)

compare_ode_solvers()

3.6. ODE Solvers: Theory and Comparison

Numerical Integration Methods

To solve \(\frac{dh}{dt} = f(h, t)\), we need numerical methods.

1. Euler Method (1st Order)

Formula: \(h_{n+1} = h_n + \Delta t \cdot f(h_n, t_n)\)

Properties:

  • ✅ Simple, fast

  • ❌ Low accuracy: \(O(\Delta t^2)\) local error

  • ❌ Requires small step size

Global error: \(O(\Delta t)\)

2. Runge-Kutta 4 (RK4) - 4th Order

Formula:

\[\begin{aligned}
k_1 &= f(h_n, t_n) \\
k_2 &= f(h_n + \tfrac{\Delta t}{2}k_1,\; t_n + \tfrac{\Delta t}{2}) \\
k_3 &= f(h_n + \tfrac{\Delta t}{2}k_2,\; t_n + \tfrac{\Delta t}{2}) \\
k_4 &= f(h_n + \Delta t \cdot k_3,\; t_n + \Delta t) \\
h_{n+1} &= h_n + \tfrac{\Delta t}{6}(k_1 + 2k_2 + 2k_3 + k_4)
\end{aligned}\]

Properties:

  • ✅ High accuracy: \(O(\Delta t^5)\) local error

  • ✅ Good balance of speed/accuracy

  • ❌ Fixed step size

Global error: \(O(\Delta t^4)\)

3. Adaptive Methods (e.g., Dormand-Prince “dopri5”)

Key idea: Adjust step size based on local error estimate

Algorithm:

  1. Compute two estimates: \(\hat{h}_{n+1}\) (5th order) and \(h_{n+1}\) (4th order)

  2. Error estimate: \(\text{err} = ||\hat{h}_{n+1} - h_{n+1}||\)

  3. If \(\text{err} < \text{tol}\): accept step, maybe increase \(\Delta t\)

  4. If \(\text{err} > \text{tol}\): reject step, decrease \(\Delta t\)

Properties:

  • ✅ Automatically adjusts precision

  • ✅ Efficient for smooth ODEs

  • ✅ Can handle stiff problems

  • ⚠️ Variable computational cost

Solver Comparison

| Solver | Order | Adaptive | Speed | Accuracy | Use Case |
|---|---|---|---|---|---|
| Euler | 1 | No | ⚡⚡⚡ | ⭐ | Quick prototyping |
| RK4 | 4 | No | ⚡⚡ | ⭐⭐⭐ | Fixed timestep, moderate accuracy |
| dopri5 | 5(4) | Yes | ⚡ | ⭐⭐⭐⭐⭐ | Production, high accuracy |
| adams | Variable | Yes | ⚡⚡ | ⭐⭐⭐⭐ | Smooth ODEs, long integration |

Choosing a Solver for Neural ODEs

For training:

  • Use dopri5 (default in torchdiffeq)

  • Adaptive stepping handles varying dynamics

  • Set reasonable tolerances: rtol=1e-3, atol=1e-4

For inference:

  • Can use fixed-step rk4 for speed

  • Pre-determine good step size from training

For memory-constrained:

  • Use Euler with small steps

  • Trade accuracy for speed

Tolerance Settings

Relative tolerance (rtol): Error relative to the state magnitude

Absolute tolerance (atol): Minimum absolute error threshold

\[\text{error}_i \leq \text{atol} + \text{rtol} \times |h_i|\]

Recommendations:

  • Training: rtol=1e-3, atol=1e-4 (faster, slightly less accurate)

  • Evaluation: rtol=1e-5, atol=1e-6 (slower, more accurate)

  • Production: rtol=1e-4, atol=1e-5 (balanced)
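The mixed-tolerance criterion above can be written directly. A small sketch (the function name is illustrative, and real solvers typically combine the per-component ratios into a single weighted error norm rather than an all-components test):

```python
import numpy as np

def step_accepted(err, h, rtol=1e-3, atol=1e-4):
    """Accept a step if every component's error estimate stays below
    the mixed bound atol + rtol * |h_i|."""
    bound = atol + rtol * np.abs(h)
    return bool(np.all(np.abs(err) <= bound))

h = np.array([0.001, 5.0, -200.0])     # states of very different scales
err = np.array([5e-5, 4e-3, 1e-1])     # per-component error estimates

print(step_accepted(err, h))           # True: all components within their bounds
print(step_accepted(err * 10, h))      # False: the first component exceeds its bound
```

Note how atol dominates for the tiny component and rtol dominates for the large one, which is why both knobs matter.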

# Compare ODE solvers empirically

def compare_ode_solvers():
    """
    Compare different ODE solvers on a simple problem
    Demonstrates accuracy vs speed trade-offs
    """
    
    # Define a simple ODE: dh/dt = -h (exponential decay)
    # True solution: h(t) = h(0) * exp(-t)
    
    def f(t, h):
        return -h
    
    # Initial condition and time span
    h0 = 1.0
    t_span = np.linspace(0, 5, 100)
    
    # True solution
    h_true = h0 * np.exp(-t_span)
    
    # Euler method
    def euler_solve(f, h0, t_span):
        h = np.zeros_like(t_span)
        h[0] = h0
        for i in range(len(t_span) - 1):
            dt = t_span[i+1] - t_span[i]
            h[i+1] = h[i] + dt * f(t_span[i], h[i])
        return h
    
    # RK4 method
    def rk4_solve(f, h0, t_span):
        h = np.zeros_like(t_span)
        h[0] = h0
        for i in range(len(t_span) - 1):
            dt = t_span[i+1] - t_span[i]
            k1 = f(t_span[i], h[i])
            k2 = f(t_span[i] + dt/2, h[i] + dt*k1/2)
            k3 = f(t_span[i] + dt/2, h[i] + dt*k2/2)
            k4 = f(t_span[i] + dt, h[i] + dt*k3)
            h[i+1] = h[i] + dt/6 * (k1 + 2*k2 + 2*k3 + k4)
        return h
    
    # Solve with different methods
    import time
    
    start = time.time()
    h_euler = euler_solve(f, h0, t_span)
    euler_time = time.time() - start
    
    start = time.time()
    h_rk4 = rk4_solve(f, h0, t_span)
    rk4_time = time.time() - start
    
    # Compute errors
    euler_error = np.abs(h_euler - h_true)
    rk4_error = np.abs(h_rk4 - h_true)
    
    # Visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Plot 1: Solutions
    ax = axes[0, 0]
    ax.plot(t_span, h_true, 'k-', linewidth=3, label='True solution', alpha=0.7)
    ax.plot(t_span, h_euler, 'r--', linewidth=2, label='Euler method')
    ax.plot(t_span, h_rk4, 'b--', linewidth=2, label='RK4 method')
    ax.set_xlabel('Time $t$', fontsize=12)
    ax.set_ylabel('$h(t)$', fontsize=12)
    ax.set_title('ODE Solutions: $\\frac{dh}{dt} = -h$', fontweight='bold', fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Errors (log scale)
    ax = axes[0, 1]
    ax.semilogy(t_span, euler_error + 1e-10, 'r-', linewidth=2, label='Euler error')
    ax.semilogy(t_span, rk4_error + 1e-10, 'b-', linewidth=2, label='RK4 error')
    ax.set_xlabel('Time $t$', fontsize=12)
    ax.set_ylabel('Absolute Error (log scale)', fontsize=12)
    ax.set_title('Solver Accuracy Comparison', fontweight='bold', fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3, which='both')
    
    # Plot 3: Step visualization for Euler
    ax = axes[1, 0]
    t_demo = t_span[:10]
    h_demo = h_euler[:10]
    h_true_demo = h_true[:10]
    
    ax.plot(t_demo, h_true_demo, 'k-', linewidth=3, label='True', alpha=0.7)
    ax.plot(t_demo, h_demo, 'ro-', linewidth=2, markersize=8, label='Euler steps')
    
    # Show Euler steps
    for i in range(len(t_demo)-1):
        # Tangent line
        dt = t_demo[i+1] - t_demo[i]
        slope = f(t_demo[i], h_demo[i])
        t_tangent = np.array([t_demo[i], t_demo[i+1]])
        h_tangent = h_demo[i] + slope * (t_tangent - t_demo[i])
        ax.plot(t_tangent, h_tangent, 'r--', alpha=0.5, linewidth=1)
    
    ax.set_xlabel('Time $t$', fontsize=12)
    ax.set_ylabel('$h(t)$', fontsize=12)
    ax.set_title('Euler Method: Following Tangent Lines', fontweight='bold', fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    
    # Plot 4: Error comparison bar chart
    ax = axes[1, 1]
    methods = ['Euler', 'RK4']
    max_errors = [euler_error.max(), rk4_error.max()]
    mean_errors = [euler_error.mean(), rk4_error.mean()]
    times = [euler_time * 1000, rk4_time * 1000]  # Convert to ms
    
    x = np.arange(len(methods))
    width = 0.25
    
    ax.bar(x - width, max_errors, width, label='Max error', alpha=0.8, color='red')
    ax.bar(x, mean_errors, width, label='Mean error', alpha=0.8, color='blue')
    ax.bar(x + width, [t/100 for t in times], width, label='Time (ms/100)', alpha=0.8, color='green')
    
    ax.set_ylabel('Value', fontsize=12)
    ax.set_title('Method Comparison', fontweight='bold', fontsize=13)
    ax.set_xticks(x)
    ax.set_xticklabels(methods, fontsize=12)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
    
    print("\n" + "="*70)
    print("ODE SOLVER COMPARISON")
    print("="*70)
    print(f"Problem: dh/dt = -h, h(0) = {h0}, t ∈ [0, 5]")
    print(f"Steps: {len(t_span)}")
    print("\nEuler Method (1st order):")
    print(f"  Max error:  {euler_error.max():.6f}")
    print(f"  Mean error: {euler_error.mean():.6f}")
    print(f"  Time:       {euler_time*1000:.2f} ms")
    print("\nRK4 Method (4th order):")
    print(f"  Max error:  {rk4_error.max():.6e}")
    print(f"  Mean error: {rk4_error.mean():.6e}")
    print(f"  Time:       {rk4_time*1000:.2f} ms")
    print("\nAccuracy Improvement: {:.1f}x better".format(euler_error.max() / rk4_error.max()))
    print("Time Cost: {:.1f}x slower".format(rk4_time / euler_time))
    print("="*70)

compare_ode_solvers()
# Neural ODE implementation

class ODEFunc(nn.Module):
    """ODE function: dh/dt = f(h, t)"""
    
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 64),
            nn.Tanh(),
            nn.Linear(64, dim)
        )
        
        # Explicitly depend on time
        self.time_net = nn.Linear(1, dim)
    
    def forward(self, t, h):
        """
        Args:
            t: current time (scalar)
            h: current state (batch_size, dim)
        Returns:
            dh/dt
        """
        # Time-dependent dynamics
        t_vec = torch.ones(h.shape[0], 1).to(h.device) * t
        time_effect = self.time_net(t_vec)
        
        return self.net(h) + time_effect

class ODEBlock(nn.Module):
    """Neural ODE block"""
    
    def __init__(self, odefunc, integration_time=1.0):
        super().__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, integration_time])
    
    def forward(self, x):
        """
        Args:
            x: input (batch_size, dim)
        Returns:
            output after ODE integration
        """
        self.integration_time = self.integration_time.to(x.device)
        
        # Solve ODE from t=0 to t=integration_time
        out = odeint(self.odefunc, x, self.integration_time, 
                    method='dopri5', rtol=1e-3, atol=1e-4)
        
        return out[1]  # Return state at final time

# Test ODE block
dim = 64
odefunc = ODEFunc(dim).to(device)
ode_block = ODEBlock(odefunc, integration_time=1.0).to(device)

# Forward pass
x_test = torch.randn(16, dim).to(device)
y_test = ode_block(x_test)

print(f"Input shape:  {x_test.shape}")
print(f"Output shape: {y_test.shape}")
print(f"✅ Neural ODE block working correctly!")

4. Neural ODE Classifier

Architecture

  1. Downsample: Reduce spatial dimensions (for images)

  2. ODE Block: Continuous transformation

  3. Classifier: Linear layer for prediction

MNIST example:

  • Input: 28×28 images

  • Flatten → 784 dims

  • Reduce to 64 dims

  • ODE block

  • Classify to 10 classes

# Neural ODE Classifier for MNIST

class NeuralODEClassifier(nn.Module):
    """Neural ODE for MNIST classification"""
    
    def __init__(self, input_dim=784, hidden_dim=64, num_classes=10):
        super().__init__()
        
        # Downsampling
        self.downsample = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, hidden_dim)
        )
        
        # Neural ODE
        odefunc = ODEFunc(hidden_dim)
        self.ode_block = ODEBlock(odefunc, integration_time=1.0)
        
        # Classifier
        self.classifier = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        """Forward pass"""
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Downsample
        h = self.downsample(x)
        
        # ODE integration
        h = self.ode_block(h)
        
        # Classify
        out = self.classifier(h)
        
        return out

# Initialize model
model = NeuralODEClassifier().to(device)

print("Neural ODE Classifier:")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
# Load MNIST data

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Visualize samples
fig, axes = plt.subplots(2, 8, figsize=(14, 4))
for i, ax in enumerate(axes.flat):
    img, label = train_dataset[i]
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'{label}', fontsize=12)
    ax.axis('off')
plt.suptitle('MNIST Samples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
# Train Neural ODE

def train_model(model, train_loader, test_loader, n_epochs=5, lr=1e-3):
    """Train Neural ODE classifier"""
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    
    for epoch in range(n_epochs):
        # Training
        model.train()
        train_loss, correct, total = 0, 0, 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            
            if (batch_idx + 1) % 100 == 0:
                print(f'  Batch {batch_idx+1}/{len(train_loader)} | '
                      f'Loss: {loss.item():.4f}')
        
        train_acc = 100. * correct / total
        
        # Testing
        model.eval()
        test_correct, test_total = 0, 0
        
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                pred = output.argmax(dim=1)
                test_correct += pred.eq(target).sum().item()
                test_total += target.size(0)
        
        test_acc = 100. * test_correct / test_total
        
        history['train_loss'].append(train_loss / len(train_loader))
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        
        print(f'Epoch {epoch+1}/{n_epochs} | '
              f'Train Loss: {history["train_loss"][-1]:.4f} | '
              f'Train Acc: {train_acc:.2f}% | '
              f'Test Acc: {test_acc:.2f}%')
        print('-' * 70)
    
    return history

print("Training Neural ODE on MNIST...")
print("="*70)
history = train_model(model, train_loader, test_loader, n_epochs=5, lr=1e-3)
# Visualize training

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss
axes[0].plot(history['train_loss'], 'o-', linewidth=2, markersize=8, label='Train Loss')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(history['train_acc'], 'o-', linewidth=2, markersize=8, label='Train Accuracy')
axes[1].plot(history['test_acc'], 's-', linewidth=2, markersize=8, label='Test Accuracy')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Classification Accuracy', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\n✅ Final Test Accuracy: {history['test_acc'][-1]:.2f}%")

5. Summary

Neural ODEs

✅ Continuous-depth models: Replace discrete layers with an ODE

✅ Memory efficient: Constant memory via the adjoint method

✅ Adaptive computation: The ODE solver chooses step sizes automatically

Key Ideas

  1. Forward pass: Solve ODE \(\frac{dh}{dt} = f_\theta(h, t)\) from \(t=0\) to \(t=T\)

  2. Backward pass: Adjoint method computes gradients without storing states

  3. Implementation: Use torchdiffeq with odeint_adjoint

Comparison

| Property | ResNet | Neural ODE |
|---|---|---|
| Depth | Discrete layers | Continuous |
| Memory | \(O(L)\) | \(O(1)\) |
| Forward cost | \(O(L)\) | Adaptive |
| Evaluation cost | Fixed | Variable |

Implementation

✅ Implemented ODE function and ODE block

✅ Trained Neural ODE classifier on MNIST

✅ Achieved competitive accuracy with constant memory

Advantages

  • Memory efficient for very deep models

  • Adaptive computation (more steps for complex inputs)

  • Continuous representations (smooth trajectories)

Limitations

⚠️ Slower training (backward ODE solve)

⚠️ Numerical errors from the ODE solver

⚠️ Harder to debug than discrete models

Next Notebook: 05_3d_vision_introduction.ipynb

References

  • Chen et al. “Neural Ordinary Differential Equations” (NeurIPS 2018)

  • Grathwohl et al. “FFJORD: Free-form Continuous Dynamics” (ICLR 2019)

5.5. Continuous Normalizing Flows (CNF)

Motivation: Change of Variables Formula

For invertible transformation \(z = f(x)\):

\[p_z(z) = p_x(f^{-1}(z)) \left|\det \frac{\partial f^{-1}}{\partial z}\right|\]

Problem: Computing Jacobian determinant is \(O(d^3)\) for \(d\)-dimensional data!

Neural ODE Solution

Key idea: Use continuous transformation via ODE:

\[\frac{dz(t)}{dt} = f_\theta(z(t), t), \quad z(0) = x, \quad z(1) = z\]

Instantaneous change of variables:

\[\log p_t(z(t)) = \log p_0(z(0)) - \int_0^t \text{tr}\left(\frac{\partial f_\theta}{\partial z(s)}\right) ds\]

Final density:

\[\boxed{\log p_1(z(1)) = \log p_0(z(0)) - \int_0^1 \text{tr}\left(\frac{\partial f_\theta}{\partial z(t)}\right) dt}\]

Advantages of CNF

1. Arbitrary Architecture

  • No invertibility constraint on \(f_\theta\)

  • Can use any neural network architecture

  • More expressive than explicit flows (RealNVP, Glow)

2. Efficient Trace Computation

Use Hutchinson’s trace estimator:

\[\text{tr}(A) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0,I)}[\epsilon^T A \epsilon]\]

For Jacobian: $\(\text{tr}\left(\frac{\partial f}{\partial z}\right) = \mathbb{E}_\epsilon\left[\epsilon^T \frac{\partial f}{\partial z} \epsilon\right]\)$

Compute via: $\(\epsilon^T \frac{\partial f}{\partial z} \epsilon = \epsilon^T \frac{\partial}{\partial z}[f^T \epsilon]\)$

using vector-Jacobian product (efficient in automatic differentiation!)

Cost: \(O(d)\) instead of \(O(d^3)\)! ✅
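As a quick numerical sanity check, the estimator can be compared against the exact trace on a tiny map. This is a standalone sketch (the map `f`, the matrix `W`, and the sample count are illustrative, not part of the lesson code):

```python
import torch

torch.manual_seed(0)

# Small test map f: R^d -> R^d whose Jacobian trace we want
d = 5
W = torch.randn(d, d)

def f(z):
    return torch.tanh(z @ W.T)

z = torch.randn(d, requires_grad=True)

# Exact trace from the full Jacobian (O(d) backward passes; fine for tiny d)
J = torch.autograd.functional.jacobian(f, z)
exact_trace = torch.trace(J)

# Hutchinson estimate: average eps^T J eps over Gaussian probes; each probe
# needs only ONE vector-Jacobian product, never the full Jacobian
n_samples = 4000
total = 0.0
for _ in range(n_samples):
    eps = torch.randn(d)
    vjp = torch.autograd.grad(f(z), z, grad_outputs=eps)[0]  # eps^T (df/dz)
    total += (vjp * eps).sum()
est = total / n_samples

print(f"exact trace: {exact_trace.item():.4f}")
print(f"Hutchinson (n={n_samples}): {est.item():.4f}")
```

The estimate concentrates around the exact trace as the number of probes grows.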

3. Continuous Dynamics

  • Smooth transformations

  • Can model complex distributions

  • Training stability

CNF Training Objective

Maximum likelihood:

\[\max_\theta \mathbb{E}_{x \sim p_{\text{data}}}\left[\log p_1(f_\theta(x)) + \int_0^1 \text{tr}\left(\frac{\partial f_\theta}{\partial z(t)}\right) dt\right]\]

where \(z(1) = f_\theta(x)\) is obtained by solving the ODE forward. Note the trace integral enters with a plus sign: with \(z(0) = x\), it equals \(\log\left|\det \frac{\partial z(1)}{\partial x}\right|\) (Liouville's formula), so solving the boxed identity above for \(\log p_0(x)\) adds the integral to the prior log-density.

Prior: Typically \(p_1(z) = \mathcal{N}(0, I)\)

Comparison with Other Generative Models

| Model | Pros | Cons |
|---|---|---|
| VAE | Fast sampling | Approximate inference, blurry samples |
| GAN | Sharp samples | Mode collapse, training instability |
| Explicit Flow | Exact likelihood | Architectural constraints (invertible) |
| CNF | Exact likelihood, flexible architecture | Slow sampling (ODE solve) |

Practical Considerations

Training:

  • Use adaptive ODE solver (dopri5)

  • Tolerance: rtol=1e-5, atol=1e-7

  • Hutchinson estimator with 1-2 samples

Sampling:

  • Can use faster fixed-step solvers (RK4)

  • Or reduce tolerance for speed

  • Trade-off: quality vs speed

Regularization:

  • Add regularization to encourage simple dynamics

  • Kinetic energy: \(\int_0^1 ||f_\theta(z(t), t)||^2 dt\)

  • Jacobian Frobenius norm: \(\int_0^1 ||\frac{\partial f}{\partial z}||_F^2 dt\)
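Both penalties can be approximated with a simple Riemann sum over states along a trajectory. A minimal sketch, assuming a small stand-in dynamics network `f` (all names here are illustrative); `create_graph=True` keeps the penalties differentiable so they can be added to the training loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical dynamics network standing in for f_theta (illustrative)
f = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))

def regularizers(f, z_traj, dt):
    """Riemann-sum estimates of the two penalties along a trajectory;
    z_traj is a list of (batch, dim) states at uniformly spaced times."""
    kinetic, frob = 0.0, 0.0
    for z in z_traj:
        z = z.detach().requires_grad_(True)
        dz = f(z)
        # kinetic energy term: ||f(z(t), t)||^2 dt
        kinetic = kinetic + (dz ** 2).sum(dim=1).mean() * dt
        # Hutchinson estimate of ||df/dz||_F^2 = E_eps[ ||eps^T df/dz||^2 ]
        eps = torch.randn_like(z)
        vjp = torch.autograd.grad(dz, z, grad_outputs=eps, create_graph=True)[0]
        frob = frob + (vjp ** 2).sum(dim=1).mean() * dt
    return kinetic, frob

# Build a short Euler trajectory just to have states to evaluate on
dt, z = 0.1, torch.randn(8, 2)
z_traj = []
for _ in range(10):
    z_traj.append(z)
    z = z + dt * f(z).detach()

kinetic, frob = regularizers(f, z_traj, dt)
print(f"kinetic ~ {kinetic.item():.4f}, Frobenius ~ {frob.item():.4f}")
```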

# Continuous Normalizing Flow Implementation

class CNF(nn.Module):
    """
    Continuous Normalizing Flow using Neural ODE
    
    Implements density estimation via change of variables:
    log p(x) = log p(z) + ∫ tr(∂f/∂z) dt  (forward solve maps x to z)
    """
    
    def __init__(self, dim, hidden_dim=64):
        super().__init__()
        self.dim = dim
        
        # Dynamics network
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, dim)
        )
    
    def forward(self, t, states):
        """
        ODE dynamics with trace computation
        
        states = [z, log_p]
        Returns d[z, log_p]/dt
        """
        z = states[0]
        
        with torch.enable_grad():
            # Ensure we can differentiate f w.r.t. z (z may arrive without
            # grad during the no-grad forward solve)
            z = z if z.requires_grad else z.clone().requires_grad_(True)
            dz_dt = self.net(z)
            
            # Hutchinson trace estimator:
            # tr(∂f/∂z) ≈ ε^T (∂f/∂z) ε where ε ~ N(0,I)
            # (FFJORD samples ε once per ODE solve rather than per function
            # evaluation, keeping the dynamics deterministic for the solver)
            epsilon = torch.randn_like(z)
            
            # Compute ε^T (∂f/∂z) ε via vector-Jacobian product
            # This is equivalent to ε^T ∂(f^T ε)/∂z
            vjp = torch.autograd.grad(
                (dz_dt * epsilon).sum(), z,
                create_graph=True, retain_graph=True
            )[0]
            
            trace_estimate = (vjp * epsilon).sum(dim=1, keepdim=True)
        
        # Change in log probability
        dlog_p_dt = -trace_estimate
        
        return (dz_dt, dlog_p_dt)
    
    def log_prob(self, x, integration_time=1.0):
        """
        Compute log probability of data x
        
        Returns:
            log_prob: log p(x)
            z: latent encoding
        """
        batch_size = x.shape[0]
        
        # Accumulated change in log density (starts at 0)
        log_p0 = torch.zeros(batch_size, 1, device=x.device)
        
        # Solve the augmented ODE forward from t=0 to t=T
        t = torch.tensor([0.0, integration_time], device=x.device)
        z_t, log_p_t = odeint(self, (x, log_p0), t, rtol=1e-5, atol=1e-7)
        z, delta_log_p = z_t[-1], log_p_t[-1]
        
        # Prior: standard Gaussian
        log_prior = -0.5 * (z**2).sum(dim=1, keepdim=True) - \
                    0.5 * self.dim * np.log(2 * np.pi)
        
        # delta_log_p accumulated -∫ tr(∂f/∂z) dt, so
        # log p(x) = log p_prior(z) + ∫ tr dt = log_prior - delta_log_p
        log_prob = log_prior - delta_log_p
        
        return log_prob, z
    
    def sample(self, num_samples, integration_time=1.0):
        """
        Generate samples by solving the ODE backward from the prior
        
        1. Sample z ~ N(0, I)
        2. Solve ODE backward in time: z(T) → z(0)
        3. Return x = z(0)
        """
        device = next(self.parameters()).device
        
        # Sample from the prior at t=T
        z = torch.randn(num_samples, self.dim, device=device)
        log_p = torch.zeros(num_samples, 1, device=device)
        
        # Integrate from t=T down to t=0 to map latents back to data space
        t = torch.tensor([integration_time, 0.0], device=device)
        x_t, _ = odeint(self, (z, log_p), t, rtol=1e-5, atol=1e-7)
        
        return x_t[-1]

# Demonstrate CNF concept
print("="*70)
print("CONTINUOUS NORMALIZING FLOW (CNF)")
print("="*70)
print("\nKey Components:")
print("  1. Forward ODE: x → z (data to latent)")
print("     • Transform data to simple prior distribution")
print("     • Track log probability change via trace computation")
print("\n  2. Backward ODE: z → x (latent to data)")
print("     • Sample from prior: z ~ N(0, I)")
print("     • Solve ODE backward to generate data")
print("\n  3. Trace Estimation:")
print("     • Use Hutchinson estimator: tr(J) ≈ ε^T J ε")
print("     • Computed efficiently via vector-Jacobian product")
print("     • Cost: O(d) instead of O(d³)")
print("\nAdvantages:")
print("  ✅ Exact likelihood (unlike VAE)")
print("  ✅ Flexible architecture (unlike explicit flows)")
print("  ✅ Stable training (unlike GAN)")
print("  ✅ Continuous transformations")
print("\nDisadvantages:")
print("  ⚠️  Slow sampling (requires ODE solve)")
print("  ⚠️  Memory for adjoint method")
print("="*70)

# Create example CNF
cnf = CNF(dim=2, hidden_dim=32)
print(f"\nCNF Model:")
print(f"  Input dimension: 2")
print(f"  Hidden dimension: 32")
print(f"  Parameters: {sum(p.numel() for p in cnf.parameters())}")

5.6. Advanced Topics and Practical Considerations

Stiffness in Neural ODEs

Stiff ODEs have dynamics at multiple time scales:

  • Fast dynamics (small \(\Delta t\) needed)

  • Slow dynamics (large \(\Delta t\) possible)

Problem: Explicit solvers (Euler, RK4) require tiny steps for stability.

Solution: Use implicit solvers or regularization.
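The stability gap is easy to see on the classic 1-D test problem \(dy/dt = -\lambda y\), whose exact solution decays to zero. A plain-Python sketch with an illustrative step size:

```python
# Stiff 1-D test problem dy/dt = -lam * y, exact solution y(t) = e^{-lam t} -> 0
lam, h, steps = 50.0, 0.05, 200

# Explicit (forward) Euler: y <- y * (1 - lam*h); unstable when |1 - lam*h| > 1
y_exp = 1.0
for _ in range(steps):
    y_exp *= (1.0 - lam * h)

# Implicit (backward) Euler: y <- y / (1 + lam*h); unconditionally stable
y_imp = 1.0
for _ in range(steps):
    y_imp /= (1.0 + lam * h)

print(f"explicit Euler after {steps} steps: {y_exp:.3e}")  # blows up
print(f"implicit Euler after {steps} steps: {y_imp:.3e}")  # decays toward 0
```

With \(h = 0.05\), \(|1 - \lambda h| = 1.5 > 1\), so forward Euler diverges while backward Euler decays, matching the true solution's behavior.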

Regularization Techniques:

1. Spectral Normalization

Constrain the Lipschitz constant of \(f_\theta\): $\(||f_\theta(z_1, t) - f_\theta(z_2, t)|| \leq L ||z_1 - z_2||\)$

Prevents exploding gradients and improves ODE solver stability.

2. Jacobian Regularization

Add a penalty to encourage smooth dynamics: $\(\mathcal{L}_{reg} = \lambda \mathbb{E}\left[||\frac{\partial f_\theta}{\partial z}||_F^2\right]\)$

3. Kinetic Energy Regularization

Penalize large velocities: $\(\mathcal{L}_{kinetic} = \lambda \int_0^T ||f_\theta(z(t), t)||^2 dt\)$

Augmented Neural ODEs

Motivation: ODE trajectories cannot cross, so a standard Neural ODE is a homeomorphism of its state space; some simple maps (e.g., \(x \mapsto -x\) in 1-D) cannot be represented without extra dimensions.

Solution: Augment state with extra dimensions:

\[\begin{split}\frac{d}{dt}\begin{bmatrix} z(t) \\ a(t) \end{bmatrix} = f_\theta\left(\begin{bmatrix} z(t) \\ a(t) \end{bmatrix}, t\right)\end{split}\]

where:

  • \(z(t) \in \mathbb{R}^d\): original state

  • \(a(t) \in \mathbb{R}^p\): augmented dimensions (typically \(p \ll d\))

Benefits:

  • Increased expressiveness

  • Can learn more complex trajectories

  • Better approximation of diffeomorphisms

Initialization: \(a(0) = \mathbf{0}\) (zeros)
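A minimal sketch of the augmentation step, assuming a hypothetical `AugmentedODEFunc` wrapper (the class and its names are illustrative, not the lesson's API):

```python
import torch
import torch.nn as nn

class AugmentedODEFunc(nn.Module):
    """Sketch: dynamics on the augmented state [z, a] (illustrative names)."""
    def __init__(self, dim, aug_dim, hidden=32):
        super().__init__()
        total = dim + aug_dim
        self.aug_dim = aug_dim
        self.net = nn.Sequential(
            nn.Linear(total, hidden), nn.Tanh(), nn.Linear(hidden, total))

    def forward(self, t, state):
        # One vector field over the full augmented state
        return self.net(state)

    def augment(self, x):
        # a(0) = 0: append zero-initialized extra channels
        a0 = torch.zeros(x.shape[0], self.aug_dim, device=x.device)
        return torch.cat([x, a0], dim=1)

func = AugmentedODEFunc(dim=2, aug_dim=3)
x = torch.randn(4, 2)
state0 = func.augment(x)
print(state0.shape)  # torch.Size([4, 5])
```

After the ODE solve, only the first `dim` channels of the final state are read out as the output.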

Second-Order Neural ODEs

Standard (first-order): $\(\frac{dz}{dt} = f_\theta(z, t)\)$

Second-order (Hamiltonian-inspired): $\(\frac{dz}{dt} = v, \quad \frac{dv}{dt} = f_\theta(z, v, t)\)$

Advantages:

  • More stable dynamics

  • Better long-term behavior

  • Physical interpretability (position + velocity)
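The second-order system can be handed to any first-order ODE solver by stacking \([z, v]\) into one state. A hedged sketch with illustrative names:

```python
import torch
import torch.nn as nn

class SecondOrderODE(nn.Module):
    """Sketch: second-order dynamics rewritten as a first-order system
    on the stacked state [z, v] (illustrative, not the lesson's API)."""
    def __init__(self, dim, hidden=32):
        super().__init__()
        self.dim = dim
        # acceleration network f_theta(z, v)
        self.accel = nn.Sequential(
            nn.Linear(2 * dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))

    def forward(self, t, state):
        z, v = state[:, :self.dim], state[:, self.dim:]
        dz_dt = v                  # dz/dt = v
        dv_dt = self.accel(state)  # dv/dt = f_theta(z, v)
        return torch.cat([dz_dt, dv_dt], dim=1)

func = SecondOrderODE(dim=2)
state = torch.randn(4, 4)  # [z, v] stacked
out = func(0.0, state)
print(out.shape)  # torch.Size([4, 4])
```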

Latent ODEs for Irregular Time Series

Problem: Time series with irregular sampling, missing data

Solution: Encode to latent ODE, then decode

Architecture:

  1. Recognition network: \(h_0 = \text{RNN}(\{(t_i, x_i)\})\)

  2. Latent ODE: \(h(t) = \text{ODESolve}(f_\theta, h_0, [0, T])\)

  3. Decoder: \(\hat{x}_i = \text{Decoder}(h(t_i))\)

Key insight: ODE is continuous, can predict at any \(t\)!
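The three-stage pipeline can be sketched end to end with a fixed-step Euler loop standing in for ODESolve (everything below, including the module names, is an illustrative toy, not the lesson's implementation):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class LatentODE(nn.Module):
    """Minimal encode -> latent-ODE -> decode sketch (illustrative names)."""
    def __init__(self, obs_dim=1, latent_dim=8):
        super().__init__()
        # 1. Recognition network over (x_i, t_i) pairs
        self.rnn = nn.GRU(obs_dim + 1, latent_dim, batch_first=True)
        # 2. Latent dynamics f_theta
        self.dynamics = nn.Sequential(
            nn.Linear(latent_dim, 32), nn.Tanh(), nn.Linear(32, latent_dim))
        # 3. Decoder
        self.decoder = nn.Linear(latent_dim, obs_dim)

    def forward(self, x, t_obs, t_query, n_steps=20):
        inp = torch.cat([x, t_obs.unsqueeze(-1)], dim=-1)
        _, h0 = self.rnn(inp)
        h = h0.squeeze(0)
        preds = []
        t_prev = torch.zeros_like(t_query[0])
        for tq in t_query:                   # ODE reaches ANY query time
            dt = (tq - t_prev) / n_steps
            for _ in range(n_steps):         # Euler stand-in for ODESolve
                h = h + dt * self.dynamics(h)
            t_prev = tq
            preds.append(self.decoder(h))
        return torch.stack(preds, dim=1)

model = LatentODE()
x = torch.randn(4, 6, 1)                     # 6 irregular observations per series
t_obs = torch.sort(torch.rand(4, 6), dim=1).values
t_query = torch.tensor([0.25, 0.5, 1.0])     # predict at arbitrary times
y = model(x, t_obs, t_query)
print(y.shape)  # torch.Size([4, 3, 1])
```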

Applications:

  • Medical time series (irregular patient visits)

  • Climate data (varying measurement intervals)

  • Financial data (non-uniform trading times)

Neural CDE (Controlled Differential Equations)

Extension: Let input data control the ODE

\[dh(t) = f_\theta(h(t)) dX(t)\]

where \(X(t)\) is a path (interpolated from data)

Properties:

  • Naturally handles irregular sampling

  • Path-dependent dynamics

  • State-of-the-art for time series
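A hedged Euler discretization of \(dh = f_\theta(h)\, dX\): the dynamics network outputs a (hidden × input) matrix that multiplies each path increment \(\Delta X\). This is only a sketch; libraries such as torchcde implement Neural CDEs properly with interpolated paths and adaptive solvers:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class CDEFunc(nn.Module):
    """Sketch: f_theta(h) returns a (hidden x input) matrix so that
    dh = f_theta(h) dX (illustrative names)."""
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, 32), nn.Tanh(),
            nn.Linear(32, hidden_dim * input_dim))

    def forward(self, h):
        return self.net(h).view(-1, self.hidden_dim, self.input_dim)

# Euler discretization of dh = f(h) dX along an observed path X
input_dim, hidden_dim = 3, 8
f = CDEFunc(input_dim, hidden_dim)
X = torch.randn(4, 10, input_dim)       # batch of observed paths
h = torch.zeros(4, hidden_dim)
for k in range(X.shape[1] - 1):
    dX = X[:, k + 1] - X[:, k]          # path increment (irregular steps are fine)
    h = h + torch.bmm(f(h), dX.unsqueeze(-1)).squeeze(-1)
print(h.shape)  # torch.Size([4, 8])
```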

Practical Training Tips

1. Gradient Clipping

  • Clip gradients to prevent explosion

  • Typical: max_norm=1.0

2. Learning Rate Schedule

  • Start with higher LR: 1e-3

  • Decay to 1e-5 or 1e-6

  • Cosine annealing works well

3. Monitoring NFE

  • Track number of function evaluations

  • Should be stable during training

  • Sudden increases indicate instability

4. Tolerance Scheduling

  • Start with relaxed tolerance: rtol=1e-3

  • Gradually tighten: final rtol=1e-5

  • Balances speed early, accuracy late

5. Batch Size

  • Smaller batches are often better

  • The adaptive solver picks one step size for the whole batch

  • Larger batches → steps dictated by the hardest sample (more NFE variation)
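Tips 1 and 2 combine into a few lines of standard PyTorch. A sketch on a toy model (the model itself is just a placeholder for the ODE network):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy model standing in for the ODE network (illustrative)
model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Cosine annealing from 1e-3 down toward 1e-6 over 50 steps
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

for step in range(50):
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    # Tip 1: clip gradients before the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    # Tip 2: anneal the learning rate after each step
    scheduler.step()

print(f"final lr: {optimizer.param_groups[0]['lr']:.2e}")
```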

When to Use Neural ODEs

Good for:

✅ Time series with irregular sampling

✅ Continuous normalizing flows

✅ Memory-constrained scenarios

✅ Physically-inspired models

✅ Generative modeling with exact likelihood

Not ideal for:

❌ Applications that need very fast inference

❌ Extremely high-dimensional data (images)

❌ When discrete depth is sufficient

❌ Limited computational budget for training

4.5. Advanced Application: Continuous Normalizing Flows (CNF)

Motivation: Generative Modeling

Goal: Learn complex probability distributions \(p(x)\)

Approach: Transform simple base distribution \(p(z)\) to data distribution \(p(x)\)

Normalizing Flows Recap

Discrete flow: $\(z_0 \xrightarrow{f_1} z_1 \xrightarrow{f_2} \cdots \xrightarrow{f_T} z_T = x\)$

Change of variables: $\(\log p(x) = \log p(z_0) - \sum_{t=1}^T \log \left|\det \frac{\partial f_t}{\partial z_{t-1}}\right|\)$

Problem: Computing determinant is \(O(d^3)\) for \(d\)-dimensional data!

Continuous Normalizing Flows (FFJORD)

Idea: Use Neural ODE for continuous transformation

\[\frac{dz(t)}{dt} = f_\theta(z(t), t), \quad z(0) \sim p_0, \quad z(1) = x\]

Instantaneous change of variables formula:

\[\frac{d \log p(z(t))}{dt} = -\text{tr}\left(\frac{\partial f_\theta}{\partial z(t)}\right)\]

This is the trace of the Jacobian, much cheaper than determinant!

Complete Density Formula

Integrating from \(t=0\) to \(t=1\):

\[\log p(x) = \log p(z(0)) - \int_0^1 \text{tr}\left(\frac{\partial f_\theta(z(t), t)}{\partial z(t)}\right) dt\]
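The formula can be sanity-checked on linear dynamics \(dz/dt = Az\), where everything is available in closed form: the time-1 flow map is \(e^A\), and Liouville's formula predicts its log-determinant equals \(\int_0^1 \text{tr}(A)\, dt = \text{tr}(A)\):

```python
import torch

torch.manual_seed(0)

# Linear dynamics dz/dt = A z: the time-1 flow map is exp(A), and
# log|det exp(A)| should equal the trace integral tr(A)
A = torch.randn(3, 3) * 0.3
flow_jacobian = torch.matrix_exp(A)

log_det = torch.logdet(flow_jacobian)
trace_integral = torch.trace(A)

print(f"log|det exp(A)| = {log_det.item():.6f}")
print(f"tr(A)           = {trace_integral.item():.6f}")
```

The two quantities agree to numerical precision, which is exactly why the trace integral can replace the determinant in the density formula.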

Computing the Trace Efficiently

Problem: Computing trace still requires computing all diagonal elements of Jacobian

Solution (Hutchinson’s Trace Estimator):

For random vector \(\epsilon \sim \mathcal{N}(0, I)\): $\(\text{tr}(A) = \mathbb{E}_\epsilon[\epsilon^T A \epsilon]\)$

Implementation: $\(\text{tr}\left(\frac{\partial f}{\partial z}\right) \approx \epsilon^T \frac{\partial f}{\partial z} \epsilon = \epsilon^T \frac{\partial (f^T \epsilon)}{\partial z}\)$

Only requires one vector-Jacobian product (via automatic differentiation)!

Cost: \(O(d)\) instead of \(O(d^3)\)!

Advantages of CNF

Exact likelihood: Can compute \(p(x)\) exactly (unlike VAE)

Free-form Jacobians: No architectural constraints (unlike RealNVP, Glow)

Scalable: \(O(d)\) cost via trace estimator

Invertible: Can sample and compute likelihood

CNF Training

Objective: Maximize log-likelihood

\[\max_\theta \mathbb{E}_{x \sim p_{\text{data}}}[\log p_\theta(x)]\]

Algorithm:

  1. Sample data \(x\)

  2. Solve ODE backward: \(x = z(1) \to z(0)\)

  3. Accumulate the trace integral \(\int_0^1 \text{tr}(\partial f / \partial z)\, dt\) along the way

  4. Compute loss: \(-\log p_0(z(0)) + \text{trace integral}\)

  5. Backprop through the adjoint method
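The five steps above can be sketched as a self-contained toy: 2-D data, the exact Jacobian trace (cheap in 2-D), and fixed-step Euler in place of an adaptive solver with the adjoint method (all names are illustrative):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy dynamics network and optimizer (illustrative)
f = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))
opt = torch.optim.Adam(f.parameters(), lr=1e-3)

def exact_trace(dz, z):
    """tr(df/dz) per sample, one autograd.grad call per output dim."""
    tr = 0.0
    for i in range(dz.shape[1]):
        g = torch.autograd.grad(dz[:, i].sum(), z,
                                create_graph=True, retain_graph=True)[0]
        tr = tr + g[:, i]
    return tr

def neg_log_prob(x, n_steps=10):
    """Steps 2-4: Euler-solve backward from z(1) = x to z(0),
    accumulating the trace integral on the way."""
    dt = 1.0 / n_steps
    z = x.clone().requires_grad_(True)
    trace_int = torch.zeros(x.shape[0])
    for _ in range(n_steps):
        dz = f(z)
        trace_int = trace_int + exact_trace(dz, z) * dt
        z = z - dz * dt                    # backward-in-time Euler step
    # base density p_0 = N(0, I) in 2-D
    log_p0 = -0.5 * (z ** 2).sum(dim=1) - math.log(2 * math.pi)
    return trace_int - log_p0              # = -log p(x)

# Steps 1 and 5: sample data, one maximum-likelihood update
x = torch.randn(64, 2) * 0.5 + 1.0
loss = neg_log_prob(x).mean()
opt.zero_grad()
loss.backward()
opt.step()
print(f"NLL after one step: {loss.item():.3f}")
```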

Applications

  • Density estimation: Model complex distributions

  • Generative modeling: Sample realistic data

  • Variational inference: Flexible posteriors

  • Anomaly detection: Detect out-of-distribution samples