import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
# Notebook-wide plot styling and fixed RNG seeds so every run reproduces
# the same data, initializations, and figures.
sns.set_style('whitegrid')
np.random.seed(42)
torch.manual_seed(42)
1. Motivation: Neural Networks as Kernel Methods¶
Classical Kernel Methods¶
Linear model in feature space: \(f(x) = \langle w, \phi(x) \rangle\)
Kernel: \(k(x, x') = \langle \phi(x), \phi(x') \rangle\)
Neural Network at Initialization¶
Randomly initialized network \(f(x; \theta_0)\) defines implicit feature map.
Key Question: What kernel does it correspond to?
NTK Insight (Jacot et al., 2018)¶
In the infinite width limit, neural networks:
Define a fixed kernel (NTK)
Train via kernel gradient descent
Converge to global minimum
Revolutionary: Deep learning = kernel method in disguise!
2. Linearized Training Dynamics¶
Gradient Descent¶
Update rule: \(\theta_t = \theta_{t-1} - \eta \nabla_{\theta} L(\theta_{t-1})\)
First-Order Taylor Expansion¶
Network output at time \(t\): \(f(x; \theta_t) \approx f(x; \theta_0) + \langle \nabla_{\theta} f(x; \theta_0), \theta_t - \theta_0 \rangle\)
Neural Tangent Kernel Definition¶
Key Property: In infinite width, \(\Theta\) stays constant during training!
Training Dynamics (Continuous Time)¶
This is kernel gradient descent with kernel \(\Theta\).
3. Computing NTK for Simple Networks¶
Two-Layer Network¶
where \(m\) = width, \(\sigma\) = activation.
NTK (Infinite Width)¶
For ReLU: \(\Theta_{\mathrm{ReLU}}(x, x') = \frac{\|x\| \|x'\|}{2\pi} \left( \cos(\theta)(\pi - \theta) + \sin(\theta) \right)\)
where \(\theta = \arccos\left(\frac{x^T x'}{\|x\| \|x'\|}\right)\)
def ntk_relu_2layer(X1, X2=None):
    """Closed-form NTK of an infinitely wide 2-layer ReLU network.

    Implements
        Theta(x, x') = ||x|| ||x'|| / (2*pi) * (sin(t) + (pi - t) * cos(t)),
    where t is the angle between x and x'.

    Parameters
    ----------
    X1 : ndarray, shape (n1, d)
    X2 : ndarray, shape (n2, d), optional
        Defaults to X1, yielding the symmetric Gram matrix.

    Returns
    -------
    ndarray, shape (n1, n2)
    """
    if X2 is None:
        X2 = X1
    # Pairwise inner products and outer product of row norms
    inner = X1 @ X2.T
    row_norms = np.linalg.norm(X1, axis=1, keepdims=True)
    col_norms = np.linalg.norm(X2, axis=1, keepdims=True)
    norm_prod = row_norms @ col_norms.T
    # Small epsilon guards against division by zero; clip keeps arccos in-domain
    cos_angle = np.clip(inner / (norm_prod + 1e-8), -1, 1)
    angle = np.arccos(cos_angle)
    # Arc-cosine kernel formula for the ReLU NTK
    return norm_prod / (2 * np.pi) * (np.sin(angle) + (np.pi - angle) * cos_angle)
# Sanity-check the analytic kernel on a small random sample:
# the Gram matrix should be symmetric and positive semi-definite.
X = np.random.randn(5, 3)
K = ntk_relu_2layer(X)
print("NTK matrix:")
print(K)
print(f"\nShape: {K.shape}")
is_symmetric = np.allclose(K, K.T)
eigenvalues = np.linalg.eigvals(K)
print(f"Symmetric: {is_symmetric}")
print(f"Positive definite: {np.all(eigenvalues > -1e-10)}")
4. Empirical NTK via Jacobian¶
def compute_ntk_empirical(model, X1, X2=None):
    """Empirical (finite-width) NTK of ``model`` via explicit Jacobians.

    Computes Theta_hat(x, x') = <grad_theta f(x), grad_theta f(x')>
    as J1 @ J2^T, where each Jacobian row is the flattened gradient of
    the scalar network output with respect to all parameters.

    Parameters
    ----------
    model : nn.Module mapping a (1, d) input to a scalar output
    X1, X2 : array-like, shape (n, d); X2 defaults to X1.

    Returns
    -------
    ndarray, shape (n1, n2)
    """
    if X2 is None:
        X2 = X1
    X1 = torch.FloatTensor(X1)
    X2 = torch.FloatTensor(X2)

    def jacobian_rows(batch):
        # One row per input point: d f(x) / d theta, flattened.
        rows = []
        for sample in batch:
            sample = sample.unsqueeze(0).requires_grad_()
            output = model(sample).squeeze()
            grads = torch.autograd.grad(output, model.parameters(), create_graph=False)
            rows.append(torch.cat([g.flatten() for g in grads]))
        return torch.stack(rows)

    jac1 = jacobian_rows(X1)
    jac2 = jacobian_rows(X2)
    # Gram matrix of parameter gradients = empirical NTK
    return (jac1 @ jac2.T).detach().numpy()
# Create simple network
class SimpleNet(nn.Module):
    """Two-layer MLP with a scalar output: Linear -> ReLU -> Linear.

    Weights use He (Kaiming) initialization and all biases start at
    zero, so the freshly initialized network maps the zero vector to 0.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # He init on weights, zero biases (standard for ReLU nets)
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        return self.fc(x)
# Compare the infinite-width (theoretical) NTK against the empirical
# NTK of a wide finite network, averaged over random initializations.
X_test = np.random.randn(10, 3)

K_theory = ntk_relu_2layer(X_test)

# Average the empirical kernel over several fresh initializations to
# reduce finite-width fluctuations.
K_emp_runs = []
for _ in range(10):
    net = SimpleNet(3, 1000)  # large width -> close to the NTK limit
    K_emp_runs.append(compute_ntk_empirical(net, X_test))
K_emp = np.mean(K_emp_runs, axis=0)

# Visualize both kernels and their elementwise discrepancy
diff = np.abs(K_theory - K_emp)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
panels = [
    (K_theory, 'viridis', 'Theoretical NTK'),
    (K_emp, 'viridis', 'Empirical NTK (width=1000)'),
    (diff, 'Reds', 'Absolute Difference'),
]
for ax, (matrix, cmap, title) in zip(axes, panels):
    image = ax.imshow(matrix, cmap=cmap)
    ax.set_title(title, fontsize=13)
    plt.colorbar(image, ax=ax)
plt.tight_layout()
plt.show()
print(f"Mean absolute error: {diff.mean():.6f}")
print("As width → ∞, empirical → theoretical")
5. Training Dynamics: Kernel Gradient Descent¶
# Toy 1-D regression problem: noisy observations of sin(x).
n_train = 50
X_train = np.random.uniform(-3, 3, (n_train, 1))
y_train = np.sin(X_train) + 0.1 * np.random.randn(n_train, 1)
X_test = np.linspace(-4, 4, 100)[:, None]
y_test = np.sin(X_test)

# Gram matrices needed by the closed-form kernel predictor below
K_train = ntk_relu_2layer(X_train)
K_test_train = ntk_relu_2layer(X_test, X_train)
# Kernel gradient descent prediction (closed form)
def kernel_prediction(K_train, K_test_train, y_train, reg=1e-6):
    """Kernel ridge regression predictor.

    Solves (K + reg * I) alpha = y for the dual coefficients and
    returns the test predictions K_test_train @ alpha.

    Parameters
    ----------
    K_train : (n, n) training Gram matrix
    K_test_train : (m, n) test-by-train kernel matrix
    y_train : (n, 1) regression targets
    reg : float, small ridge term for numerical stability

    Returns
    -------
    (m, 1) predictions
    """
    n = len(K_train)
    dual_coef = np.linalg.solve(K_train + reg * np.eye(n), y_train)
    return K_test_train @ dual_coef
# Closed-form NTK prediction on the test grid
y_pred_kernel = kernel_prediction(K_train, K_test_train, y_train)

# Train an actual (very wide) network on the same data with plain SGD
net = SimpleNet(1, 5000).train()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
X_train_t = torch.FloatTensor(X_train)
y_train_t = torch.FloatTensor(y_train)
X_test_t = torch.FloatTensor(X_test)

predictions = []  # test-set prediction snapshots, one every 50 epochs
for epoch in range(500):
    optimizer.zero_grad()
    mse = ((net(X_train_t) - y_train_t) ** 2).mean()
    mse.backward()
    optimizer.step()
    if epoch % 50 == 0:
        with torch.no_grad():
            predictions.append(net(X_test_t).numpy())
# Side-by-side: closed-form kernel predictor vs the trained network
plt.figure(figsize=(12, 5))

# Left panel: kernel ridge regression with the analytic NTK
plt.subplot(1, 2, 1)
plt.scatter(X_train, y_train, s=50, alpha=0.6, label='Training data')
plt.plot(X_test, y_test, 'g--', linewidth=2, label='True function')
plt.plot(X_test, y_pred_kernel, 'r-', linewidth=2, label='NTK prediction')
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.title('Kernel Prediction', fontsize=13)
plt.legend()
plt.grid(True, alpha=0.3)

# Right panel: network snapshots, fading from early to late training
plt.subplot(1, 2, 2)
plt.scatter(X_train, y_train, s=50, alpha=0.6, label='Training data')
plt.plot(X_test, y_test, 'g--', linewidth=2, label='True function')
n_snapshots = len(predictions)
for idx, snapshot in enumerate(predictions):
    fade = 0.3 + 0.7 * (idx / n_snapshots)
    plt.plot(X_test, snapshot, 'b-', alpha=fade, linewidth=1.5)
plt.plot(X_test, predictions[-1], 'b-', linewidth=2, label='NN (final)')
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.title('Neural Network Training', fontsize=13)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("Wide networks behave like kernel methods!")
6. NTK Evolution During Training¶
# Track NTK throughout training
def track_ntk_evolution(width, n_epochs=1000):
    """Train a SimpleNet of the given width and record relative NTK drift.

    NOTE(review): trains on the module-level X_train / y_train; the NTK
    itself is probed on a fresh random 20-point sample.

    Returns
    -------
    list of float
        One relative Frobenius-norm change ||K_t - K_0|| / ||K_0||
        per 100 epochs.
    """
    net = SimpleNet(1, width)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    X_sample = np.random.randn(20, 1)
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train)

    # Reference kernel at initialization
    K_init = compute_ntk_empirical(net, X_sample)
    drift = []
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        mse = ((net(X_train_t) - y_train_t) ** 2).mean()
        mse.backward()
        optimizer.step()
        if epoch % 100 == 0:
            K_now = compute_ntk_empirical(net, X_sample)
            drift.append(np.linalg.norm(K_now - K_init, 'fro') / np.linalg.norm(K_init, 'fro'))
    return drift
# Sweep over widths: wider networks should show less NTK drift
widths = [50, 200, 1000, 5000]
results = {}
for width in widths:
    print(f"Training width={width}...")
    results[width] = track_ntk_evolution(width, n_epochs=500)

# One curve per width, sampled at the 100-epoch checkpoints
plt.figure(figsize=(10, 6))
checkpoint_epochs = np.arange(0, 500, 100)
for width in widths:
    plt.plot(checkpoint_epochs, results[width], 'o-', linewidth=2, markersize=6, label=f'width={width}')
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Relative NTK Change', fontsize=12)
plt.title('NTK Stability During Training', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\nKey observation: As width increases, NTK becomes more stable (lazy training regime)")
7. Theoretical Implications¶
Convergence Guarantee¶
For MSE loss, gradient descent converges to the global minimum at rate: \(\|f_t - f^*\| \leq e^{-\eta \lambda_{\min}(\Theta) t} \|f_0 - f^*\|\)
where \(\lambda_{min}(\Theta)\) is smallest eigenvalue of NTK.
Generalization¶
NTK connects to kernel ridge regression, so standard kernel theory applies:
RKHS norm bounds generalization
Implicit regularization from kernel structure
Lazy Training¶
In infinite width:
Parameters barely move from initialization
Network linearizes around initialization
All power comes from kernel, not feature learning
Feature Learning (Finite Width)¶
Finite width networks:
NTK changes during training
Learn features (non-linear regime)
Can outperform fixed kernel methods
Summary¶
Key Takeaways:¶
NTK = kernel induced by neural network at initialization
Infinite width limit: NTK constant, training = kernel gradient descent
Convergence: Global minimum guaranteed (MSE loss)
Lazy training: Wide networks don’t learn features, rely on initialization
Feature learning: Finite width networks escape kernel regime
When NTK Applies:¶
Very wide networks
Small learning rates
Early training stages
When NTK Breaks Down:¶
Moderate/small width
Large learning rates
Deep networks (NTK changes across layers)
Practical Insights:¶
Width-depth tradeoff: Wide shallow vs narrow deep
Initialization matters: Sets the kernel
Feature learning vs kernel: Finite width crucial for performance
Further Reading:¶
Jacot et al. (2018) - “Neural Tangent Kernel”
Lee et al. (2019) - “Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent”
Chizat & Bach (2018) - “On Lazy Training”
Next Steps:¶
03_rademacher_complexity.ipynb - Generalization bounds
04_pac_bayes_theory.ipynb - Bayesian perspective
Phase 6 neural networks for practical deep learning