import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import seaborn as sns

sns.set_style('whitegrid')
np.random.seed(42)
torch.manual_seed(42)

1. Motivation: Neural Networks as Kernel Methods

Classical Kernel Methods

Linear model in feature space: \(f(x) = \langle w, \phi(x) \rangle\)

Kernel: \(k(x, x') = \langle \phi(x), \phi(x') \rangle\)
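
As a quick sanity check on this identity, an explicit feature map reproduces its kernel exactly. A minimal sketch with a degree-2 polynomial map (the map and the inputs here are purely illustrative):

```python
import numpy as np

def phi(x):
    """Explicit degree-2 polynomial feature map for 2-D inputs (illustrative)."""
    x1, x2 = x
    return np.array([x1**2, x2**2, np.sqrt(2) * x1 * x2])

x, xp = np.array([1.0, 2.0]), np.array([3.0, 1.0])

# Inner product of explicit features equals the closed-form kernel (x^T x')^2
k_features = phi(x) @ phi(xp)
k_closed = (x @ xp) ** 2
print(k_features, k_closed)  # both 25.0
```

The NTK construction below plays the same game in reverse: the network supplies the feature map, and we ask which kernel it induces.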

Neural Network at Initialization

Randomly initialized network \(f(x; \theta_0)\) defines implicit feature map.

Key Question: What kernel does it correspond to?

NTK Insight (Jacot et al., 2018)

In the infinite width limit, neural networks:

  1. Define a fixed kernel (NTK)

  2. Train via kernel gradient descent

  3. Converge to global minimum

Revolutionary: In this limit, deep learning is a kernel method in disguise!

2. Linearized Training Dynamics

Gradient Descent

Update rule: \(\theta_t = \theta_{t-1} - \eta \nabla_{\theta} L(\theta_{t-1})\)

First-Order Taylor Expansion

Network output at time \(t\): \(f(x; \theta_t) \approx f(x; \theta_0) + \langle \nabla_{\theta} f(x; \theta_0), \theta_t - \theta_0 \rangle\)
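
The quality of this linearization is easy to check numerically: perturb the parameters of a randomly initialized network by a small step and compare the true output against the first-order prediction. A sketch (the architecture and step size are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 64), nn.Tanh(), nn.Linear(64, 1))
x = torch.randn(1, 3)

# f(x; theta_0) and its gradient with respect to all parameters
y0 = net(x).squeeze()
grads = torch.autograd.grad(y0, net.parameters())
flat_grad = torch.cat([g.flatten() for g in grads])

# Apply a small random parameter perturbation delta
delta = 1e-3 * torch.randn_like(flat_grad)
offset = 0
with torch.no_grad():
    for p in net.parameters():
        n = p.numel()
        p += delta[offset:offset + n].view(p.shape)
        offset += n

# True output at theta_0 + delta vs. the first-order Taylor prediction
y1 = net(x).squeeze()
y1_taylor = y0 + flat_grad @ delta
print(float(y1), float(y1_taylor))  # nearly identical for a small step
```

The two numbers agree to several decimal places; the gap is the second-order term the NTK analysis discards.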

Neural Tangent Kernel Definition

\[\Theta(x, x') = \langle \nabla_{\theta} f(x; \theta_0), \nabla_{\theta} f(x'; \theta_0) \rangle\]

Key Property: In infinite width, \(\Theta\) stays constant during training!

Training Dynamics (Continuous Time)

\[\frac{d}{dt} f(x; \theta_t) = -\eta \sum_{i=1}^{n} \Theta(x, x_i) \frac{\partial L}{\partial f(x_i)}\]

where the sum runs over the \(n\) training points. This is kernel gradient descent with kernel \(\Theta\).
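
For MSE loss, these dynamics restricted to the training points become a linear recursion on the vector of training outputs. A discrete-time sketch, with a well-conditioned stand-in PSD matrix playing the role of \(\Theta\):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in PSD kernel matrix and targets for n = 20 training points
A = rng.normal(size=(20, 100))
K = A @ A.T / 100
y = rng.normal(size=20)

# Kernel gradient descent on the vector of training outputs:
#   f_{t+1} = f_t - eta * K @ (f_t - y)
f = np.zeros(20)
eta = 0.1 / np.linalg.eigvalsh(K).max()  # conservative, stable step size
for _ in range(3000):
    f -= eta * K @ (f - y)

print(np.linalg.norm(f - y))  # residual driven toward 0
```

Because the kernel is fixed, nothing about the network architecture enters these dynamics except through \(K\): training is an entirely linear process.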

3. Computing NTK for Simple Networks

Two-Layer Network

\[f(x; W, a) = \frac{1}{\sqrt{m}} \sum_{j=1}^m a_j \sigma(w_j^T x)\]

where \(m\) = width, \(\sigma\) = activation.

NTK (Infinite Width)

\[\Theta(x, x') = \mathbb{E}_{w}\big[\sigma(w^T x)\, \sigma(w^T x')\big] + x^T x' \cdot \mathbb{E}_{w}\big[\sigma'(w^T x)\, \sigma'(w^T x')\big]\]

where the first term comes from training the output weights \(a\) and the second from the hidden weights \(w\).

For ReLU with \(w \sim \mathcal{N}(0, I)\): \(\Theta_{ReLU}(x, x') = \frac{\|x\| \|x'\|}{2\pi} \left( \sin\theta + (\pi - \theta)\cos\theta \right) + \frac{x^T x'}{2\pi}\,(\pi - \theta)\)

where \(\theta = \arccos\left(\frac{x^T x'}{\|x\| \|x'\|}\right)\)

def ntk_relu_2layer(X1, X2=None):
    """NTK of a 2-layer ReLU network (both layers trained, infinite width)."""
    if X2 is None:
        X2 = X1
    
    # Pairwise dot products and norm products
    dot = X1 @ X2.T
    norm1 = np.linalg.norm(X1, axis=1, keepdims=True)
    norm2 = np.linalg.norm(X2, axis=1, keepdims=True)
    norms = norm1 @ norm2.T
    
    # Angles between inputs
    cos_theta = np.clip(dot / (norms + 1e-8), -1.0, 1.0)
    theta = np.arccos(cos_theta)
    
    # Output-layer term + hidden-layer term
    ntk = norms / (2 * np.pi) * (np.sin(theta) + (np.pi - theta) * cos_theta)
    ntk += dot * (np.pi - theta) / (2 * np.pi)
    
    return ntk

# Test on simple data
X = np.random.randn(5, 3)
K = ntk_relu_2layer(X)

print("NTK matrix:")
print(K)
print(f"\nShape: {K.shape}")
print(f"Symmetric: {np.allclose(K, K.T)}")
print(f"Positive semidefinite: {np.all(np.linalg.eigvalsh(K) > -1e-10)}")

4. Empirical NTK via Jacobian

def compute_ntk_empirical(model, X1, X2=None):
    """Compute empirical NTK via Jacobian."""
    if X2 is None:
        X2 = X1
    
    X1 = torch.FloatTensor(X1)
    X2 = torch.FloatTensor(X2)
    
    # Compute Jacobians
    def get_jacobian(x):
        """Per-example gradients of the scalar output w.r.t. all parameters."""
        jac = []
        for xi in x:
            y = model(xi.unsqueeze(0)).squeeze()
            grads = torch.autograd.grad(y, model.parameters())
            jac.append(torch.cat([g.flatten() for g in grads]))
        return torch.stack(jac)
    
    J1 = get_jacobian(X1)
    J2 = get_jacobian(X2)
    
    # NTK = J1 @ J2^T
    ntk = (J1 @ J2.T).detach().numpy()
    
    return ntk

# Create simple network
class SimpleNet(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # He initialization
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)
    
    def forward(self, x):
        return self.fc(x)

# Compare theoretical vs empirical NTK
X_test = np.random.randn(10, 3)

# Theoretical
K_theory = ntk_relu_2layer(X_test)

# Empirical (average over multiple initializations)
K_emp_list = []
for _ in range(10):
    net = SimpleNet(3, 1000)  # Large width
    K_emp_list.append(compute_ntk_empirical(net, X_test))
K_emp = np.mean(K_emp_list, axis=0)

# Compare
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

im1 = axes[0].imshow(K_theory, cmap='viridis')
axes[0].set_title('Theoretical NTK', fontsize=13)
plt.colorbar(im1, ax=axes[0])

im2 = axes[1].imshow(K_emp, cmap='viridis')
axes[1].set_title('Empirical NTK (width=1000)', fontsize=13)
plt.colorbar(im2, ax=axes[1])

diff = np.abs(K_theory - K_emp)
im3 = axes[2].imshow(diff, cmap='Reds')
axes[2].set_title('Absolute Difference', fontsize=13)
plt.colorbar(im3, ax=axes[2])

plt.tight_layout()
plt.show()

print(f"Mean absolute error: {diff.mean():.6f}")
# Exact agreement with the closed form requires NTK parametrization
# (1/sqrt(m) output scaling, standard Gaussian weights); with He-initialized
# layers the empirical kernel matches the theory only up to an overall scale.
print("As width → ∞, the empirical NTK concentrates around a fixed kernel")

5. Training Dynamics: Kernel Gradient Descent

# Generate regression data
n_train = 50
X_train = np.random.uniform(-3, 3, (n_train, 1))
y_train = np.sin(X_train) + 0.1 * np.random.randn(n_train, 1)

X_test = np.linspace(-4, 4, 100).reshape(-1, 1)
y_test = np.sin(X_test)

# Compute NTK
K_train = ntk_relu_2layer(X_train)
K_test_train = ntk_relu_2layer(X_test, X_train)

# Kernel gradient descent prediction (closed form)
def kernel_prediction(K_train, K_test_train, y_train, reg=1e-6):
    """Prediction via kernel ridge regression."""
    alpha = np.linalg.solve(K_train + reg * np.eye(len(K_train)), y_train)
    y_pred = K_test_train @ alpha
    return y_pred

y_pred_kernel = kernel_prediction(K_train, K_test_train, y_train)

# Train actual neural network
net = SimpleNet(1, 5000).train()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

X_train_t = torch.FloatTensor(X_train)
y_train_t = torch.FloatTensor(y_train)
X_test_t = torch.FloatTensor(X_test)

predictions = []
for epoch in range(500):
    optimizer.zero_grad()
    loss = ((net(X_train_t) - y_train_t) ** 2).mean()
    loss.backward()
    optimizer.step()
    
    if epoch % 50 == 0:
        with torch.no_grad():
            pred = net(X_test_t).numpy()
        predictions.append(pred)

# Plot comparison
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.scatter(X_train, y_train, s=50, alpha=0.6, label='Training data')
plt.plot(X_test, y_test, 'g--', linewidth=2, label='True function')
plt.plot(X_test, y_pred_kernel, 'r-', linewidth=2, label='NTK prediction')
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.title('Kernel Prediction', fontsize=13)
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.scatter(X_train, y_train, s=50, alpha=0.6, label='Training data')
plt.plot(X_test, y_test, 'g--', linewidth=2, label='True function')
for i, pred in enumerate(predictions):
    alpha = 0.3 + 0.7 * (i / len(predictions))
    plt.plot(X_test, pred, 'b-', alpha=alpha, linewidth=1.5)
plt.plot(X_test, predictions[-1], 'b-', linewidth=2, label='NN (final)')
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.title('Neural Network Training', fontsize=13)
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Wide networks behave like kernel methods!")

6. NTK Evolution During Training

# Track NTK throughout training
def track_ntk_evolution(width, n_epochs=1000):
    net = SimpleNet(1, width)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    
    X_sample = np.random.randn(20, 1)
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train)
    
    # Initial NTK
    K_init = compute_ntk_empirical(net, X_sample)
    ntk_changes = []
    
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        loss = ((net(X_train_t) - y_train_t) ** 2).mean()
        loss.backward()
        optimizer.step()
        
        if epoch % 100 == 0:
            K_current = compute_ntk_empirical(net, X_sample)
            change = np.linalg.norm(K_current - K_init, 'fro') / np.linalg.norm(K_init, 'fro')
            ntk_changes.append(change)
    
    return ntk_changes

# Compare different widths
widths = [50, 200, 1000, 5000]
results = {}

for width in widths:
    print(f"Training width={width}...")
    results[width] = track_ntk_evolution(width, n_epochs=500)

# Plot
plt.figure(figsize=(10, 6))
epochs = np.arange(0, 500, 100)
for width in widths:
    plt.plot(epochs, results[width], 'o-', linewidth=2, markersize=6, label=f'width={width}')

plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Relative NTK Change', fontsize=12)
plt.title('NTK Stability During Training', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nKey observation: As width increases, NTK becomes more stable (lazy training regime)")

7. Theoretical Implications

Convergence Guarantee

For MSE loss, gradient descent converges to the global minimum at rate: \(\|f_t - f^*\| \leq e^{-\eta \lambda_{\min}(\Theta) t} \|f_0 - f^*\|\)

where \(\lambda_{\min}(\Theta)\) is the smallest eigenvalue of the NTK Gram matrix on the training set.
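
This rate can be verified mode-by-mode: diagonalizing \(\Theta\), the residual along eigenvector \(v_i\) decays geometrically with factor \(1 - \eta \lambda_i\) per discrete step. A small numerical check with a stand-in PSD kernel:

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-in PSD kernel matrix and targets for n = 10 points
A = rng.normal(size=(10, 50))
K = A @ A.T / 50
y = rng.normal(size=10)

lam, V = np.linalg.eigh(K)   # eigenvalues and eigenvectors of the kernel
eta, T = 0.01, 200

# Discrete-time kernel gradient descent from f_0 = 0
f = np.zeros(10)
for _ in range(T):
    f -= eta * K @ (f - y)

# Residual along each eigenvector vs. the predicted geometric decay
measured = V.T @ (f - y)
predicted = (1 - eta * lam) ** T * (V.T @ (-y))
print(np.max(np.abs(measured - predicted)))  # agreement up to float error
```

Large-eigenvalue directions are fit almost immediately; the smallest eigenvalue sets the overall convergence rate, exactly as the bound above says.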

Generalization

NTK connects to kernel ridge regression, so standard kernel theory applies:

  • RKHS norm bounds generalization

  • Implicit regularization from kernel structure
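
For example, the minimum-norm interpolant of the training data has RKHS norm \(\|f\|^2 = y^T K^{-1} y\), which quantifies how "hard" the targets are for the kernel. A sketch using a Gaussian RBF kernel as a stand-in (any PSD kernel, including the NTK, works identically):

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.uniform(-3, 3, size=(30, 1))
y = np.sin(X).ravel()

# Gaussian RBF kernel as a stand-in PSD kernel
K = np.exp(-0.5 * (X - X.T) ** 2)

# RKHS norm of the minimum-norm interpolant: ||f||^2 = y^T K^{-1} y
# (tiny ridge for numerical stability)
alpha = np.linalg.solve(K + 1e-8 * np.eye(30), y)
rkhs_norm_sq = y @ alpha

print(f"RKHS norm^2 of the interpolant: {rkhs_norm_sq:.3f}")
```

Targets aligned with the kernel's top eigenfunctions yield small norms and hence tighter generalization bounds; noisy or misaligned targets inflate the norm.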

Lazy Training

In infinite width:

  • Parameters barely move from initialization

  • Network linearizes around initialization

  • All power comes from kernel, not feature learning
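
A complementary diagnostic to the NTK-change experiment above is the relative parameter movement \(\|\theta_T - \theta_0\| / \|\theta_0\|\), which shrinks as width grows. A sketch (the \(1/\text{width}\) learning-rate scaling, chosen here to keep the effective kernel step comparable across widths, and all other hyperparameters are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(32, 1)
y = torch.sin(3 * X)

def relative_movement(width, steps=500):
    """Train briefly; return ||theta_T - theta_0|| / ||theta_0||."""
    net = nn.Sequential(nn.Linear(1, width), nn.ReLU(), nn.Linear(width, 1))
    theta0 = torch.cat([p.detach().flatten().clone() for p in net.parameters()])
    # Learning rate scaled like 1/width (illustrative choice) so the
    # effective kernel step size is comparable across widths
    opt = torch.optim.SGD(net.parameters(), lr=1.0 / width)
    for _ in range(steps):
        opt.zero_grad()
        ((net(X) - y) ** 2).mean().backward()
        opt.step()
    thetaT = torch.cat([p.detach().flatten() for p in net.parameters()])
    return (torch.norm(thetaT - theta0) / torch.norm(theta0)).item()

movements = {w: relative_movement(w) for w in [50, 500, 2000]}
for w, r in movements.items():
    print(f"width={w}: relative movement {r:.4f}")  # shrinks with width
```

Wide networks fit the data while barely leaving their initialization, which is precisely why the first-order Taylor expansion (and hence the fixed kernel) stays accurate.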

Feature Learning (Finite Width)

Finite width networks:

  • NTK changes during training

  • Learn features (non-linear regime)

  • Can outperform fixed kernel methods

Summary

Key Takeaways:

  1. NTK = kernel induced by neural network at initialization

  2. Infinite width limit: NTK constant, training = kernel gradient descent

  3. Convergence: Global minimum guaranteed (MSE loss)

  4. Lazy training: Wide networks don’t learn features, rely on initialization

  5. Feature learning: Finite width networks escape kernel regime

When NTK Applies:

  • Very wide networks

  • Small learning rates

  • Early training stages

When NTK Breaks Down:

  • Moderate/small width

  • Large learning rates

  • Deep networks at moderate width (finite-width corrections accumulate with depth)

Practical Insights:

  • Width-depth tradeoff: Wide shallow vs narrow deep

  • Initialization matters: Sets the kernel

  • Feature learning vs kernel: Finite width crucial for performance

Further Reading:

  • Jacot et al. (2018) - “Neural Tangent Kernel”

  • Lee et al. (2019) - “Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent”

  • Chizat & Bach (2018) - “On Lazy Training in Differentiable Programming”

Next Steps:

  • 03_rademacher_complexity.ipynb - Generalization bounds

  • 04_pac_bayes_theory.ipynb - Bayesian perspective

  • Phase 6 neural networks for practical deep learning