Chapter 3: Solving the Heat Equation
Separation of Variables: Turning One Hard Problem into Two Easy Ones
Partial differential equations are notoriously difficult to solve because the unknown function depends on several variables (space and time, in the case of the heat equation). Separation of variables is a classical technique that assumes the solution factors into a product of single-variable functions: \(T(x,t) = X(x) \cdot \Theta(t)\). Substituting this into the PDE and dividing through by the product yields two separate ordinary differential equations, one purely in \(x\) and one purely in \(t\), which are far easier to solve individually.
The magic is that both sides of the separated equation must equal the same constant \(-\lambda\) (called the separation constant), since one side depends only on \(x\) and the other only on \(t\). With the ends of the rod held at zero temperature, \(T(0,t) = T(L,t) = 0\), the spatial equation produces sinusoidal modes \(\sin(n\pi x / L)\), and the temporal equation produces exponential decays \(e^{-\alpha \lambda_n t}\). In machine learning, this factorization idea resonates with matrix factorization (decomposing a space-time data matrix into spatial and temporal factors). It also resonates with factorized (2+1)D convolutions in video models, which replace a joint spatiotemporal filter with a spatial convolution followed by a temporal one.
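To make the matrix-factorization analogy concrete, here is a minimal sketch. It assumes \(L = \pi\), \(\alpha = 1\), and five modes with illustrative amplitudes \(1/n^2\). The sketch samples a sum of separated modes on a space-time grid and checks the rank of the resulting snapshot matrix with an SVD. Each mode contributes one outer product of a spatial column \(X_n(x_i)\) and a temporal column \(\Theta_n(t_j)\), so the numerical rank should equal the number of modes.

```python
import numpy as np

# Assumed setup (illustrative): rod length L = pi, diffusivity alpha = 1,
# five sine modes with amplitudes 1/n^2.
L, alpha, n_modes = np.pi, 1.0, 5
x = np.linspace(0, L, 200)          # spatial grid
t = np.linspace(0, 2.0, 150)        # time grid

n = np.arange(1, n_modes + 1)
lam = (n * np.pi / L) ** 2                       # eigenvalues lambda_n
X = np.sin(np.outer(x, n) * np.pi / L)           # spatial factors, shape (200, 5)
Theta = np.exp(-alpha * np.outer(t, lam))        # temporal factors, shape (150, 5)
b = 1.0 / n**2                                   # mode amplitudes

# Snapshot matrix T[i, j] = sum_n b_n X_n(x_i) Theta_n(t_j): a sum of 5 rank-one terms.
T = (X * b) @ Theta.T                            # shape (200, 150)

s = np.linalg.svd(T, compute_uv=False)           # singular values, descending
rank = int(np.sum(s > 1e-10 * s[0]))             # count singular values above roundoff
print("leading singular values:", s[:7])
print("numerical rank:", rank)
```

The singular values beyond the fifth sit at roundoff level. This is the sense in which separable solutions are "low-rank" data.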
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 10)
np.set_printoptions(precision=3, suppress=True)
```
The Method
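Substitute \(T(x,t) = X(x)\,\Theta(t)\) into the heat equation \(\frac{\partial T}{\partial t} = \alpha \frac{\partial^2 T}{\partial x^2}\):
\[ X(x)\,\Theta'(t) = \alpha\, X''(x)\,\Theta(t) \]
Divide both sides by \(\alpha\, X(x)\,\Theta(t)\):
\[ \frac{\Theta'(t)}{\alpha\,\Theta(t)} = \frac{X''(x)}{X(x)} \]
The left side depends only on \(t\) and the right side only on \(x\). Changing \(x\) with \(t\) held fixed cannot change the left side, so neither side can vary at all. Both sides must therefore equal the same constant, written \(-\lambda\):
\[ \frac{\Theta'(t)}{\alpha\,\Theta(t)} = \frac{X''(x)}{X(x)} = -\lambda \]
This gives us two ODEs:
\[ \Theta'(t) = -\alpha\lambda\,\Theta(t) \]
\[ X''(x) = -\lambda\, X(x) \]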
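As a sanity check on the derivation, the sketch below uses illustrative values (\(L = 1\), \(\alpha = 0.1\), mode \(n = 2\), none of which come from the text). It builds \(X(x)\,\Theta(t)\) from the solutions of the two ODEs. It then uses central finite differences to confirm that \(\partial T/\partial t - \alpha\,\partial^2 T/\partial x^2\) is close to zero at a few sample points. The step size `h` and the sample points are arbitrary choices.

```python
import numpy as np

# Illustrative parameters (assumptions, not from the text).
alpha = 0.1                       # thermal diffusivity
L, n = 1.0, 2
lam = (n * np.pi / L) ** 2        # separation constant lambda

def X(x):
    return np.sin(np.sqrt(lam) * x)       # solves X'' = -lam * X

def Theta(t):
    return np.exp(-alpha * lam * t)       # solves Theta' = -alpha * lam * Theta

def T(x, t):
    return X(x) * Theta(t)                # separated candidate solution

# Central finite differences for T_t and T_xx at paired sample points (xs[i], ts[i]).
h = 1e-4
xs = np.array([0.1, 0.37, 0.8])
ts = np.array([0.0, 0.5, 1.2])
T_t = (T(xs, ts + h) - T(xs, ts - h)) / (2 * h)
T_xx = (T(xs + h, ts) - 2 * T(xs, ts) + T(xs - h, ts)) / h**2
residual = T_t - alpha * T_xx
print("max |T_t - alpha * T_xx| =", np.max(np.abs(residual)))
```

The residual is limited only by finite-difference truncation and rounding error, not by the separated form itself.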
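```python
def show_separation():
    """Demonstrate separation of variables with a five-mode solution T(x, t).

    Implicit choices in this demo: rod length L = pi (so lambda_n = n**2),
    diffusivity alpha = 1, and illustrative initial amplitudes 1/n**2.
    """
    x = np.linspace(0, np.pi, 100)
    t_vals = [0, 0.5, 1.0, 2.0]
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    for idx, t in enumerate(t_vals):
        ax = axes[idx // 2, idx % 2]
        # Solution: superpose the first five modes X_n(x) * Theta_n(t)
        T = np.zeros_like(x)
        for n in range(1, 6):
            lam = n**2                  # eigenvalue (n*pi/L)**2 with L = pi
            X = np.sin(n * x)           # spatial mode sin(n*pi*x/L)
            Theta = np.exp(-lam * t)    # temporal decay exp(-alpha*lam*t), alpha = 1
            T += X * Theta / n**2       # weight mode n by 1/n**2
        ax.plot(x, T, 'b-', linewidth=2)
        ax.set_title(f't = {t:.1f}', fontweight='bold')
        ax.set_xlabel('x')
        ax.set_ylabel('T(x,t)')
        ax.grid(True, alpha=0.3)
        ax.set_ylim(-0.5, 3)
    plt.tight_layout()
    plt.show()
    print("Each mode decays exponentially in time!")
    print("Higher frequencies decay faster.")

show_separation()
```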
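To put numbers on "higher frequencies decay faster": mode \(n\) decays as \(e^{-\alpha\lambda_n t}\) with \(\lambda_n = (n\pi/L)^2\). Its amplitude therefore halves after \(t_{1/2} = \ln 2 / (\alpha \lambda_n)\), which shrinks like \(1/n^2\). This short sketch assumes \(L = \pi\) and \(\alpha = 1\), the same choices implied by `lam = n**2` in the demo above.

```python
import numpy as np

# Assumed parameters: L = pi and alpha = 1, so lambda_n = n**2.
L, alpha = np.pi, 1.0
n = np.arange(1, 6)
lam = (n * np.pi / L) ** 2
rate = alpha * lam                 # exponential decay rate of mode n
half_life = np.log(2) / rate       # time for mode n's amplitude to halve

for k, r, th in zip(n, rate, half_life):
    print(f"n = {k}: decay rate = {r:5.1f}, half-life = {th:.4f}")

# Size of mode n relative to mode 1 at t = 1, if both started at amplitude 1.
t = 1.0
relative = np.exp(-(rate - rate[0]) * t)
print("relative amplitudes at t = 1:", relative)
```

By \(t = 1\) the fifth mode is smaller than the first by a factor of \(e^{-24}\). This is why the profiles in the later panels look like a single sine bump.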
Boundary Conditions
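For a rod of length \(L\) whose ends are held at zero temperature:
\(T(0, t) = 0\)
\(T(L, t) = 0\)
These require \(X(0) = X(L) = 0\).

- For \(\lambda \le 0\), the only solution of \(X'' = -\lambda X\) meeting both conditions is \(X \equiv 0\).
- For \(\lambda > 0\), the general solution is \(X = A\cos(\sqrt{\lambda}\,x) + B\sin(\sqrt{\lambda}\,x)\). The condition \(X(0) = 0\) forces \(A = 0\). The condition \(X(L) = 0\) with \(B \neq 0\) then forces \(\sqrt{\lambda}\,L = n\pi\).

This gives us: \(X_n(x) = \sin(n\pi x/L)\), \(n = 1, 2, 3, \dots\)
Eigenvalues: \(\lambda_n = (n\pi/L)^2\)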
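The eigenvalues can also be checked numerically. This minimal sketch assumes a standard second-order finite-difference discretization of \(-d^2/dx^2\) on \(N\) interior points, with \(X = 0\) imposed at both ends. The values \(N = 400\) and \(L = 2\) are arbitrary. The smallest eigenvalues of the resulting tridiagonal matrix should approach \((n\pi/L)^2\), and its eigenvectors should sample \(\sin(n\pi x/L)\).

```python
import numpy as np

L, N = 2.0, 400                   # illustrative rod length and number of interior points
h = L / (N + 1)                   # grid spacing; x = 0 and x = L are the fixed ends
x = np.linspace(h, L - h, N)      # interior grid points

# Matrix for -d^2/dx^2 with Dirichlet conditions X(0) = X(L) = 0:
# row i computes (-X[i-1] + 2 X[i] - X[i+1]) / h^2
A = (np.diag(2.0 * np.ones(N))
     - np.diag(np.ones(N - 1), 1)
     - np.diag(np.ones(N - 1), -1)) / h**2

evals, evecs = np.linalg.eigh(A)  # ascending eigenvalues, orthonormal eigenvectors

n = np.arange(1, 6)
exact = (n * np.pi / L) ** 2
print("numerical:", evals[:5])
print("exact:    ", exact)

# The first eigenvector should be proportional to sin(pi x / L), up to sign.
v = evecs[:, 0]
s = np.sin(np.pi * x / L)
cos_sim = abs(v @ s) / (np.linalg.norm(v) * np.linalg.norm(s))
print("cosine similarity with sin(pi x / L):", cos_sim)
```

The relative gap between the numerical and exact eigenvalues shrinks as \(N\) grows and widens for larger \(n\). Higher modes need more grid points per wavelength.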
General Solution
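Because the heat equation is linear, any sum of these modes is also a solution:
\[ T(x,t) = \sum_{n=1}^{\infty} b_n \sin\!\left(\frac{n\pi x}{L}\right) e^{-\alpha \lambda_n t}, \qquad \lambda_n = \left(\frac{n\pi}{L}\right)^2 \]
Each term is a mode that decays exponentially at rate \(\alpha\lambda_n\). High-frequency modes therefore fade fastest, and the temperature profile smooths out over time. The coefficients \(b_n\) are fixed by the initial profile through \(T(x,0) = \sum_{n} b_n \sin(n\pi x/L)\).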
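The general solution translates directly into code. The sketch below uses hypothetical helper names `heat_series` and `heat_ftcs`. It also assumes \(L = 1\), \(\alpha = 1\), and illustrative coefficients \(b = (1, 0, 0.5)\). It evaluates the truncated series and compares it against an independent explicit finite-difference (FTCS) simulation started from the same profile. FTCS is a swapped-in cross-check, not part of the separation-of-variables method. It is only stable when \(r = \alpha\,\Delta t / \Delta x^2 \le 1/2\).

```python
import numpy as np

def heat_series(x, t, b, L=1.0, alpha=1.0):
    """Truncated general solution: sum_n b[n-1] * sin(n pi x / L) * exp(-alpha * lambda_n * t)."""
    n = np.arange(1, len(b) + 1)
    lam = (n * np.pi / L) ** 2                   # eigenvalues lambda_n = (n pi / L)^2
    modes = np.sin(np.outer(x, n) * np.pi / L)   # spatial modes, shape (len(x), len(b))
    decay = np.exp(-alpha * lam * t)             # temporal factors at time t
    return modes @ (np.asarray(b, dtype=float) * decay)

def heat_ftcs(T0, L, alpha, t_end, nt):
    """Explicit forward-time, centred-space solver with T = 0 held at both ends."""
    T = np.array(T0, dtype=float)
    dx = L / (len(T) - 1)
    dt = t_end / nt
    r = alpha * dt / dx**2
    if r > 0.5:
        raise ValueError(f"unstable step: r = {r:.3f} > 0.5")
    for _ in range(nt):
        # The right-hand side is evaluated from the old values before assignment.
        T[1:-1] = T[1:-1] + r * (T[2:] - 2 * T[1:-1] + T[:-2])
        T[0] = T[-1] = 0.0
    return T

L, alpha = 1.0, 1.0
b = [1.0, 0.0, 0.5]                   # illustrative mode amplitudes
x = np.linspace(0, L, 51)
T0 = heat_series(x, 0.0, b, L, alpha)

t_end = 0.05
T_exact = heat_series(x, t_end, b, L, alpha)
T_fd = heat_ftcs(T0, L, alpha, t_end, nt=500)   # dx = 0.02, dt = 1e-4, so r = 0.25
print("max |series - FTCS| =", np.max(np.abs(T_exact - T_fd)))
```

The two agree to within the discretization error of FTCS. Most of the remaining gap comes from the faster-decaying \(n = 3\) mode, which the coarse grid resolves least accurately.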
Summary
Separation of variables converts PDE → ODEs
Key steps:
Assume \(T(x,t) = X(x)\,\Theta(t)\)
Separate into two ODEs
Apply boundary conditions
Sum up all modes
Next: Fourier series and complex exponentials!