import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Bayesian Neural Networks

Posterior:

\[p(w | \mathcal{D}) = \frac{p(\mathcal{D} | w) p(w)}{p(\mathcal{D})}\]

Predictive:

\[p(y^* | x^*, \mathcal{D}) = \int p(y^* | x^*, w) p(w | \mathcal{D}) dw\]

Variational Inference:

Approximate \(p(w|\mathcal{D})\) with \(q(w|\theta)\):

\[\mathcal{L} = \mathbb{E}_{q(w)}[\log p(\mathcal{D}|w)] - \text{KL}(q(w) \| p(w))\]

📚 Reference Materials:

Bayesian Neural Networks: Deep Theory and Variational Inference

1. Bayesian Inference for Neural Networks

The core idea of Bayesian Neural Networks (BNNs) is to treat network weights as random variables with distributions rather than point estimates.

Posterior Distribution:

Given dataset \(\mathcal{D} = \{(x_i, y_i)\}_{i=1}^N\), the posterior over weights is:

\[p(w | \mathcal{D}) = \frac{p(\mathcal{D} | w) p(w)}{p(\mathcal{D})} = \frac{\prod_{i=1}^N p(y_i | x_i, w) \cdot p(w)}{\int \prod_{i=1}^N p(y_i | x_i, w') \cdot p(w') dw'}\]

The Challenge: The denominator (evidence) requires integrating over all possible weight configurations, which is intractable for neural networks with millions of parameters.

Predictive Distribution:

For a new input \(x^*\), we want:

\[p(y^* | x^*, \mathcal{D}) = \int p(y^* | x^*, w) p(w | \mathcal{D}) dw\]

This marginalizes over the posterior, naturally providing uncertainty quantification.
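In practice this integral is approximated by Monte Carlo: sample weights from the posterior, push each sample through the model, and average. A minimal sketch for a 1-D linear model with a pretend Gaussian posterior over its single weight (all numbers here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model: y* = w * x* + noise, with a pretend posterior w | D ~ N(1.0, 0.1^2)
# and observation noise eps ~ N(0, 0.05^2)
post_mean, post_std = 1.0, 0.1
noise_std = 0.05
x_star = 2.0

# Monte Carlo approximation of p(y* | x*, D): sample weights, then outputs
S = 200_000
w_samples = rng.normal(post_mean, post_std, size=S)
y_samples = w_samples * x_star + rng.normal(0.0, noise_std, size=S)

mc_mean = y_samples.mean()
mc_var = y_samples.var()

# Analytic check: mean = x* E[w], var = x*^2 Var[w] + noise variance
true_mean = x_star * post_mean
true_var = x_star**2 * post_std**2 + noise_std**2
print(f"MC mean {mc_mean:.3f} vs {true_mean:.3f}, MC var {mc_var:.4f} vs {true_var:.4f}")
```

For a BNN the exact posterior is unavailable, so the samples come from an approximation (variational posterior, dropout masks, or ensemble members), as developed below.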

2. Variational Inference for BNNs

Since exact inference is intractable, we use variational inference to approximate \(p(w | \mathcal{D})\) with a simpler distribution \(q(w | \theta)\) parameterized by \(\theta\) (e.g., Gaussian with mean \(\mu\) and variance \(\sigma^2\)).

Evidence Lower Bound (ELBO):

We maximize the ELBO instead of the marginal likelihood:

\[\mathcal{L}(\theta) = \mathbb{E}_{q(w|\theta)}[\log p(\mathcal{D} | w)] - \text{KL}(q(w|\theta) \| p(w))\]

Derivation:

Starting from the log evidence:

\[\log p(\mathcal{D}) = \log \int p(\mathcal{D}, w) dw = \log \int \frac{p(\mathcal{D}, w)}{q(w|\theta)} q(w|\theta) dw\]

By Jensen's inequality (log is concave):

\[\log p(\mathcal{D}) \geq \mathbb{E}_{q(w|\theta)}[\log \frac{p(\mathcal{D}, w)}{q(w|\theta)}] = \mathbb{E}_q[\log p(\mathcal{D}|w)] + \mathbb{E}_q[\log \frac{p(w)}{q(w|\theta)}]\]
\[= \mathbb{E}_q[\log p(\mathcal{D}|w)] - \text{KL}(q(w|\theta) \| p(w)) = \mathcal{L}(\theta)\]

Components:

  1. Likelihood Term: \(\mathbb{E}_{q(w|\theta)}[\log p(\mathcal{D} | w)]\) - Data fit (reconstruction)

  2. KL Regularizer: \(\text{KL}(q(w|\theta) \| p(w))\) - Complexity penalty (keep weights close to prior)

3. Bayes by Backprop Algorithm

Reparameterization Trick:

To compute gradients w.r.t. \(\theta\), we use:

\[w = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]

where \(\sigma = \log(1 + \exp(\rho))\) (softplus to ensure positivity).
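A minimal sketch of this trick in PyTorch: softplus keeps σ positive, and because the sample is a deterministic function of (μ, ρ, ε), gradients flow back to both variational parameters (the loss here is an arbitrary stand-in for the ELBO terms):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

mu = torch.zeros(5, requires_grad=True)   # variational mean
rho = torch.zeros(5, requires_grad=True)  # unconstrained; softplus(rho) = sigma

eps = torch.randn(5)                 # epsilon ~ N(0, I), no gradient needed
sigma = F.softplus(rho)              # sigma = log(1 + exp(rho)) > 0
w = mu + sigma * eps                 # reparameterized weight sample

loss = (w ** 2).sum()                # stand-in for the ELBO terms
loss.backward()                      # gradients reach mu AND rho through w
```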

Loss Function (per minibatch):

\[\mathcal{L}_{\text{BB}} = \frac{1}{M} \sum_{i=1}^M \log q(w^{(i)}|\theta) - \log p(w^{(i)}) - \log p(\mathcal{D} | w^{(i)})\]

where \(w^{(i)} \sim q(w|\theta)\) are sampled weights.

Algorithm:

For each minibatch:
  1. Sample ε ~ N(0, I)
  2. Compute w = μ + σ ⊙ ε
  3. Forward pass: ŷ = f(x; w)
  4. Compute NLL: -log p(D|w) = MSE or cross-entropy loss
  5. Compute KL: KL(q(w|θ) || p(w))
  6. Total loss: L = NLL + λ·KL
  7. Backpropagate and update μ, ρ
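The steps above can be sketched end-to-end for a single scalar weight, using the closed-form Gaussian KL to a standard normal prior (the data and the KL weight λ are made up for illustration):

```python
import torch

torch.manual_seed(0)

# Variational parameters for one weight; prior p(w) = N(0, 1)
mu = torch.tensor(0.0, requires_grad=True)
rho = torch.tensor(-3.0, requires_grad=True)
opt = torch.optim.Adam([mu, rho], lr=1e-2)

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.1, 3.9, 6.2])     # roughly y = 2x
lam = 0.1                              # KL weight (lambda)

for step in range(500):
    eps = torch.randn(())                          # 1. sample epsilon
    sigma = torch.log1p(torch.exp(rho))            #    softplus
    w = mu + sigma * eps                           # 2. reparameterize
    y_hat = w * x                                  # 3. forward pass
    nll = ((y_hat - y) ** 2).mean()                # 4. NLL (MSE)
    kl = 0.5 * (sigma**2 + mu**2 - torch.log(sigma**2) - 1)  # 5. KL to N(0, 1)
    loss = nll + lam * kl                          # 6. total loss
    opt.zero_grad()
    loss.backward()                                # 7. update mu, rho
    opt.step()

print(f"posterior mean {mu.item():.2f}, std {torch.log1p(torch.exp(rho)).item():.3f}")
```

The posterior mean ends up near the least-squares slope (about 2), while the data term keeps the posterior standard deviation small.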

KL Divergence (Gaussian case):

For \(q(w_j|\theta) = \mathcal{N}(\mu_j, \sigma_j^2)\) and \(p(w_j) = \mathcal{N}(0, \sigma_p^2)\):

\[\text{KL}(q \| p) = \frac{1}{2} \sum_j \left[ \frac{\sigma_j^2 + \mu_j^2}{\sigma_p^2} - \log \frac{\sigma_j^2}{\sigma_p^2} - 1 \right]\]
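This closed form is easy to check against torch.distributions (a small sanity sketch; the numbers are arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(10)            # posterior means
sigma = torch.rand(10) + 0.1    # posterior stds (kept positive)
prior_std = 1.0

# Closed form: 0.5 * sum[(sigma^2 + mu^2)/sigma_p^2 - log(sigma^2/sigma_p^2) - 1]
kl_manual = 0.5 * torch.sum(
    (sigma**2 + mu**2) / prior_std**2
    - torch.log(sigma**2 / prior_std**2)
    - 1.0
)

# Reference value from the registered Gaussian-Gaussian KL
kl_ref = kl_divergence(Normal(mu, sigma), Normal(torch.zeros(10), prior_std)).sum()
print(kl_manual.item(), kl_ref.item())
```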

4. Uncertainty Decomposition

BNNs provide two types of uncertainty:

A. Aleatoric Uncertainty (Data Uncertainty):

  • Inherent noise in the data

  • Cannot be reduced with more data or model capacity

  • Example: Sensor noise, label ambiguity

  • Modeled by output noise: \(y = f(x; w) + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma_{\text{noise}}^2)\)

B. Epistemic Uncertainty (Model Uncertainty):

  • Uncertainty about model parameters \(w\)

  • Can be reduced with more training data

  • Example: Uncertainty far from training data

  • Captured by weight distribution \(p(w | \mathcal{D})\)

Total Predictive Variance:

\[\mathbb{V}[y^*] = \underbrace{\mathbb{E}_w[\sigma_{\text{noise}}^2(x^*; w)]}_{\text{Aleatoric}} + \underbrace{\mathbb{V}_w[\mu(x^*; w)]}_{\text{Epistemic}}\]

Law of Total Variance:

\[\mathbb{V}[y^*] = \mathbb{E}_w[\mathbb{V}[y^* | w]] + \mathbb{V}_w[\mathbb{E}[y^* | w]]\]
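A quick numeric check of this identity for the toy model y = w·x + ε, with w ~ N(0, σ_w²) and ε ~ N(0, σ_n²) (illustrative numbers):

```python
import numpy as np

rng = np.random.default_rng(1)
x, sigma_w, sigma_n = 2.0, 0.3, 0.1
S = 500_000

w = rng.normal(0.0, sigma_w, size=S)          # "posterior" samples of w
y = w * x + rng.normal(0.0, sigma_n, size=S)  # predictive samples

total = y.var()
aleatoric = sigma_n**2                         # E_w[Var[y|w]] (constant here)
epistemic = (w * x).var()                      # Var_w[E[y|w]]

# Law of total variance: the two pieces add up to the total
print(total, aleatoric + epistemic)
```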

5. MC Dropout as Approximate Bayesian Inference

Gal & Ghahramani (2016) showed that dropout training approximates variational inference.

Training: Apply dropout with rate \(p\) during training:

\[\hat{y} = \text{Softmax}(W_2 \cdot (m \odot \text{ReLU}(W_1 x)))\]

where \(m \sim \text{Bernoulli}(1-p)\) is the dropout mask.

Inference: Keep dropout active during test time and average predictions:

\[\mathbb{E}[y^*] \approx \frac{1}{T} \sum_{t=1}^T f(x^*; w^{(t)})\]

Uncertainty Estimation:

\[\mathbb{V}[y^*] \approx \frac{1}{T} \sum_{t=1}^T (f(x^*; w^{(t)}) - \bar{y})^2\]

Connection to Variational Inference:

Dropout approximates \(q(w|\theta)\) with a Bernoulli distribution over weight masking. The variational distribution is:

\[q(w|\theta) = \prod_l \text{Bernoulli}(m_l; 1-p)\]

The ELBO corresponds to minimizing:

\[\mathcal{L} = -\frac{1}{N} \sum_i \log p(y_i | x_i, w) + \lambda \|w\|_2^2\]

where \(\lambda = \frac{p l^2}{2N\tau}\) links the weight-decay coefficient to the dropout rate \(p\), prior length-scale \(l\), and model precision \(\tau\).

6. Deep Ensembles

An alternative to full Bayesian inference: train \(M\) neural networks with different random initializations.

Predictive Mean:

\[\mu_{\text{ens}}(x) = \frac{1}{M} \sum_{m=1}^M f_m(x)\]

Predictive Variance:

\[\sigma_{\text{ens}}^2(x) = \frac{1}{M} \sum_{m=1}^M [f_m(x) - \mu_{\text{ens}}(x)]^2\]
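With predictions stacked as an (M, N) array, both quantities are one-liners (a small sketch with made-up values):

```python
import numpy as np

# Predictions of M=3 models at N=4 test points (made-up values)
preds = np.array([
    [1.0, 2.0, 3.0, 4.0],
    [1.2, 1.8, 3.3, 3.9],
    [0.8, 2.2, 2.7, 4.1],
])

mu_ens = preds.mean(axis=0)                     # predictive mean per test point
var_ens = ((preds - mu_ens) ** 2).mean(axis=0)  # predictive variance per test point

print(mu_ens, var_ens)
```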

Advantages:

  • Simple to implement (no architecture changes)

  • Embarrassingly parallel training

  • Often outperforms BNNs in practice

Disadvantages:

  • Computationally expensive (\(M \times\) cost)

  • Not truly Bayesian (no prior, no marginalization)

7. Comparison: BNN vs MC Dropout vs Ensembles

Method      | Training Cost        | Inference Cost            | Uncertainty Quality | Calibration
----------- | -------------------- | ------------------------- | ------------------- | -----------
BNN (VI)    | High (KL term)       | Medium (T samples)        | High (principled)   | Good
MC Dropout  | Low (standard)       | Medium (T forward passes) | Medium              | Fair
Ensembles   | Very high (M models) | High (M forward passes)   | High (empirical)    | Excellent

When to use:

  • BNN: Need principled uncertainty, small models, interpretability

  • MC Dropout: Quick uncertainty estimates, existing models

  • Ensembles: Best performance, resources available, calibration critical

8. Advanced Topics

A. Last-Layer Bayesian Approximation:

  • Only make the final layer Bayesian (computationally cheaper)

  • Feature extractor remains deterministic

  • Works well when representation learning is more important

B. Structured Variational Distributions:

  • Diagonal Gaussian: \(q(w) = \prod_j \mathcal{N}(w_j | \mu_j, \sigma_j^2)\) (independent weights)

  • Matrix Gaussian: \(q(W) = \mathcal{N}(W | M, \Sigma)\) (correlated weights)

  • Normalizing Flows: \(q(w) = q_0(z) \left| \det \frac{\partial f}{\partial z} \right|^{-1}\) (flexible)

C. Temperature Scaling for Calibration:

After training, scale logits by temperature \(T\):

\[p(y=k|x) = \frac{\exp(z_k / T)}{\sum_j \exp(z_j / T)}\]

Optimize \(T\) on validation set to minimize NLL. Improves calibration without retraining.
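A sketch of the effect: dividing the logits by T > 1 softens the distribution, lowering the top-class confidence without changing the argmax (the logits are made up):

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax for a 1-D logit vector
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([4.0, 1.0, 0.0])   # overconfident logits (illustrative)

p1 = softmax(logits)                 # T = 1: original probabilities
p2 = softmax(logits / 2.0)           # T = 2: softened probabilities

print(p1.max(), p2.max())
```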

class BayesianLinear(nn.Module):
    """Bayesian linear layer with Gaussian weights."""
    
    def __init__(self, in_features, out_features, prior_std=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std
        
        # Weight parameters
        self.weight_mu = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.weight_rho = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        
        # Bias parameters
        self.bias_mu = nn.Parameter(torch.randn(out_features) * 0.1)
        self.bias_rho = nn.Parameter(torch.randn(out_features) * 0.1)
    
    def forward(self, x):
        # Sample weights
        weight_std = torch.log1p(torch.exp(self.weight_rho))
        weight = self.weight_mu + weight_std * torch.randn_like(self.weight_mu)
        
        # Sample bias
        bias_std = torch.log1p(torch.exp(self.bias_rho))
        bias = self.bias_mu + bias_std * torch.randn_like(self.bias_mu)
        
        return F.linear(x, weight, bias)
    
    def kl_divergence(self):
        """KL divergence to prior."""
        weight_std = torch.log1p(torch.exp(self.weight_rho))
        bias_std = torch.log1p(torch.exp(self.bias_rho))
        
        # KL for weights
        kl_weight = 0.5 * torch.sum(
            (self.weight_mu ** 2 + weight_std ** 2) / (self.prior_std ** 2)
            - torch.log(weight_std ** 2 / (self.prior_std ** 2))
            - 1
        )
        
        # KL for bias
        kl_bias = 0.5 * torch.sum(
            (self.bias_mu ** 2 + bias_std ** 2) / (self.prior_std ** 2)
            - torch.log(bias_std ** 2 / (self.prior_std ** 2))
            - 1
        )
        
        return kl_weight + kl_bias

print("BayesianLinear defined")
# Advanced BNN Implementation: Comparison of Methods


# ============================================================
# 1. MC Dropout Network
# ============================================================

class MCDropoutNN(nn.Module):
    """Network with MC Dropout for uncertainty estimation."""
    
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)  # Apply dropout
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x
    
    def predict_with_uncertainty(self, x, n_samples=100):
        """MC Dropout inference: keep dropout active."""
        self.train()  # Enable dropout during inference
        predictions = []
        
        with torch.no_grad():
            for _ in range(n_samples):
                pred = self.forward(x)
                predictions.append(pred.cpu().numpy())
        
        predictions = np.array(predictions)
        mean = predictions.mean(axis=0)
        std = predictions.std(axis=0)
        
        return mean, std

# ============================================================
# 2. Deep Ensemble
# ============================================================

class EnsembleNN:
    """Ensemble of neural networks for uncertainty."""
    
    def __init__(self, input_dim, hidden_dim, output_dim, n_models=5):
        self.models = []
        for _ in range(n_models):
            model = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim)
            )
            self.models.append(model)
        
        self.n_models = n_models
    
    def to(self, device):
        for model in self.models:
            model.to(device)
        return self
    
    def train_model(self, idx, X, y, optimizer, n_epochs=1000):
        """Train individual model in ensemble."""
        model = self.models[idx]
        model.train()
        
        for epoch in range(n_epochs):
            y_pred = model(X)
            loss = F.mse_loss(y_pred, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    def predict_with_uncertainty(self, x):
        """Ensemble prediction."""
        predictions = []
        
        for model in self.models:
            model.eval()
            with torch.no_grad():
                pred = model(x)
                predictions.append(pred.cpu().numpy())
        
        predictions = np.array(predictions)
        mean = predictions.mean(axis=0)
        std = predictions.std(axis=0)
        
        return mean, std

# ============================================================
# 3. Uncertainty Decomposition Utilities
# ============================================================

def decompose_uncertainty(model, x_test, y_test, n_samples=100, noise_std=0.1):
    """
    Decompose total uncertainty into aleatoric and epistemic.
    
    For regression: y = f(x; w) + ε, ε ~ N(0, σ²)
    
    Total variance: Var[y*] = E_w[σ²] + Var_w[μ(x*; w)]
                              ^^^^^^^^   ^^^^^^^^^^^^^^^^
                              Aleatoric   Epistemic
    
    Note: y_test is unused; it is kept for API symmetry.
    """
    model.eval()
    predictions = []
    
    with torch.no_grad():
        for _ in range(n_samples):
            pred = model(x_test)
            predictions.append(pred.cpu().numpy())
    
    predictions = np.array(predictions).squeeze()
    
    # Epistemic uncertainty: variance of predictions across weight samples
    epistemic = predictions.var(axis=0)
    
    # Aleatoric uncertainty: inherent noise (known or estimated)
    aleatoric = noise_std ** 2 * np.ones_like(epistemic)
    
    # Total uncertainty
    total = aleatoric + epistemic
    
    return {
        'total': total,
        'aleatoric': aleatoric,
        'epistemic': epistemic,
        'predictions': predictions
    }

# ============================================================
# 4. Calibration Metrics
# ============================================================

def compute_calibration_curve(confidences, accuracies, n_bins=10):
    """
    Expected Calibration Error (ECE).
    
    ECE = Σ_m |acc(B_m) - conf(B_m)| · |B_m| / N
    
    Perfect calibration: confidence = accuracy
    """
    bins = np.linspace(0, 1, n_bins + 1)
    bin_indices = np.digitize(confidences, bins) - 1
    bin_indices = np.clip(bin_indices, 0, n_bins - 1)  # keep confidence == 1.0 in the last bin
    
    ece = 0.0
    bin_accs = []
    bin_confs = []
    bin_counts = []
    
    for i in range(n_bins):
        mask = bin_indices == i
        if mask.sum() > 0:
            bin_acc = accuracies[mask].mean()
            bin_conf = confidences[mask].mean()
            bin_count = mask.sum()
            
            ece += np.abs(bin_acc - bin_conf) * bin_count / len(confidences)
            
            bin_accs.append(bin_acc)
            bin_confs.append(bin_conf)
            bin_counts.append(bin_count)
        else:
            bin_accs.append(0)
            bin_confs.append(0)
            bin_counts.append(0)
    
    return ece, bin_accs, bin_confs, bin_counts
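The binning logic can be sanity-checked by hand on a tiny example (this mirrors the computation rather than calling the function above; all values are made up):

```python
import numpy as np

# Four predictions: confidences and whether each was correct (1 = correct)
confidences = np.array([0.6, 0.7, 0.9, 1.0])
accuracies  = np.array([1.0, 0.0, 1.0, 1.0])

# Split by hand into two bins at 0.8 for illustration
low = confidences < 0.8
ece = 0.0
for mask in (low, ~low):
    acc = accuracies[mask].mean()     # bin accuracy
    conf = confidences[mask].mean()   # bin confidence
    ece += abs(acc - conf) * mask.sum() / len(confidences)

# Low bin: |0.5 - 0.65| * 2/4 = 0.075; high bin: |1.0 - 0.95| * 2/4 = 0.025
print(ece)
```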

def temperature_scaling(logits, labels, T_range=(0.1, 5.0), n_trials=50):
    """
    Find optimal temperature T to minimize NLL.
    
    Calibrated probabilities: p(y|x) = softmax(z/T)
    """
    best_T = 1.0
    best_nll = float('inf')
    
    temperatures = np.linspace(T_range[0], T_range[1], n_trials)
    
    for T in temperatures:
        scaled_logits = logits / T
        log_probs = F.log_softmax(torch.FloatTensor(scaled_logits), dim=1)
        nll = F.nll_loss(log_probs, torch.LongTensor(labels)).item()
        
        if nll < best_nll:
            best_nll = nll
            best_T = T
    
    return best_T
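A self-contained sketch of the same grid search on synthetic overconfident logits (the data-generation scheme is made up for illustration): scaling up the logit magnitudes makes wrong predictions confidently wrong, so a temperature above 1 improves the NLL.

```python
import numpy as np
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Synthetic 3-class problem: often right, but far too confident
n, k = 500, 3
labels = torch.randint(0, k, (n,))
noise = torch.randn(n, k)
logits = 10.0 * (F.one_hot(labels, k).float() + 0.8 * noise)  # inflated logits

def nll_at(T):
    # NLL of temperature-scaled probabilities softmax(z / T)
    return F.nll_loss(F.log_softmax(logits / T, dim=1), labels).item()

temps = np.linspace(0.1, 5.0, 50)
best_T = min(temps, key=nll_at)

print(f"NLL at T=1: {nll_at(1.0):.3f}, at T={best_T:.2f}: {nll_at(best_T):.3f}")
```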

print("Advanced BNN implementations loaded: MC Dropout, Ensembles, Uncertainty Decomposition, Calibration")

BNN for Regression

A Bayesian Neural Network places probability distributions over the network weights instead of learning single point estimates. Each weight \(w\) is represented by a distribution \(q(w) = \mathcal{N}(\mu_w, \sigma_w^2)\) where both the mean and variance are learnable parameters. During a forward pass, weights are sampled from their distributions, making each forward pass stochastic. Training optimizes the ELBO (Evidence Lower Bound): \(\mathcal{L} = \mathbb{E}_{q(w)}[\log p(y|x, w)] - \text{KL}(q(w) \| p(w))\), balancing data fit with a prior regularizer. The result is a model that provides principled uncertainty estimates for every prediction, distinguishing between confident and uncertain regions of the input space.

class BayesianNN(nn.Module):
    """Bayesian neural network."""
    
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = BayesianLinear(input_dim, hidden_dim)
        self.fc2 = BayesianLinear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
    def kl_divergence(self):
        return self.fc1.kl_divergence() + self.fc2.kl_divergence()

print("BayesianNN defined")

Generate Data and Train

We generate synthetic regression data with clear regions of data density (where the model should be confident) and gaps (where the model should be uncertain). Training a BNN on this data with variational inference (using the reparameterization trick to draw differentiable weight samples) converges much like standard training but yields a distribution over functions rather than a single function. The KL divergence term acts as a regularizer, preventing the posterior from collapsing to a point estimate and ensuring meaningful uncertainty quantification.

# Data
def f(x):
    return np.sin(3*x)

np.random.seed(42)
X_train = np.random.uniform(-1, 1, 30).reshape(-1, 1)
y_train = f(X_train) + 0.1 * np.random.randn(30, 1)

X_train_t = torch.FloatTensor(X_train).to(device)
y_train_t = torch.FloatTensor(y_train).to(device)

# Model
model = BayesianNN(1, 64, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Train
n_epochs = 2000
losses = []

for epoch in range(n_epochs):
    # Forward (sample weights)
    y_pred = model(X_train_t)
    
    # ELBO loss
    nll = F.mse_loss(y_pred, y_train_t)
    kl = model.kl_divergence() / len(X_train)
    loss = nll + 0.01 * kl
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    losses.append(loss.item())
    
    if (epoch + 1) % 500 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, NLL: {nll.item():.4f}, KL: {kl.item():.4f}")

Predictive Uncertainty

Predictive uncertainty in a BNN is obtained by running multiple forward passes (each with different weight samples) and computing statistics of the outputs. The mean of the forward passes gives the expected prediction, while the variance decomposes into aleatoric uncertainty (inherent data noise, here treated as a known constant, though it can also be predicted by a dedicated noise output head) and epistemic uncertainty (model uncertainty due to limited data, captured by the spread across weight samples). This decomposition is uniquely valuable: epistemic uncertainty decreases with more training data, while aleatoric uncertainty does not; knowing which type dominates guides data collection and model improvement strategies.

# Test data
X_test = np.linspace(-1.5, 1.5, 200).reshape(-1, 1)
X_test_t = torch.FloatTensor(X_test).to(device)

# MC samples (BayesianLinear resamples weights on every forward pass, even in eval mode)
model.eval()
n_samples = 100
predictions = []

with torch.no_grad():
    for _ in range(n_samples):
        y_pred = model(X_test_t)
        predictions.append(y_pred.cpu().numpy())

predictions = np.array(predictions).squeeze()

# Statistics
mean_pred = predictions.mean(axis=0)
std_pred = predictions.std(axis=0)

print(f"Predictions: {predictions.shape}, Mean: {mean_pred.shape}")

Visualize Results

Plotting multiple sampled functions from the BNN posterior, along with the mean prediction and uncertainty bands, provides a visual analogue to Gaussian Process regression. Near training data, the sampled functions agree closely (low epistemic uncertainty); far from training data, they diverge (high epistemic uncertainty). Comparing the BNN's uncertainty estimates to a standard neural network (which produces only point predictions) and to a Gaussian Process (which provides exact uncertainty) highlights the trade-offs: BNNs scale better than GPs to large datasets and complex architectures while providing uncertainty estimates that standard networks lack.

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Predictions with uncertainty
axes[0].plot(X_test, f(X_test), 'k--', label='True', linewidth=2)
axes[0].plot(X_test, mean_pred, 'b-', label='BNN mean', linewidth=2)
axes[0].fill_between(X_test.ravel(), mean_pred - 2*std_pred, mean_pred + 2*std_pred, 
                      alpha=0.3, label='±2σ')
axes[0].scatter(X_train, y_train, c='r', s=50, zorder=10, label='Data')
axes[0].set_xlabel('x', fontsize=12)
axes[0].set_ylabel('y', fontsize=12)
axes[0].set_title('Bayesian NN Predictions', fontsize=13)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Training loss
axes[1].plot(losses)
axes[1].set_xlabel('Iteration', fontsize=12)
axes[1].set_ylabel('Loss', fontsize=12)
axes[1].set_title('Training Loss', fontsize=13)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
# Uncertainty Visualization and Method Comparison

# Generate comparison data
X_comp = np.linspace(-1.5, 1.5, 200).reshape(-1, 1)
X_comp_t = torch.FloatTensor(X_comp).to(device)

# ============================================================
# Train MC Dropout model
# ============================================================
mc_model = MCDropoutNN(1, 64, 1, dropout_rate=0.1).to(device)
mc_optimizer = torch.optim.Adam(mc_model.parameters(), lr=1e-2)

print("Training MC Dropout model...")
for epoch in range(1000):
    mc_model.train()
    y_pred = mc_model(X_train_t)
    loss = F.mse_loss(y_pred, y_train_t)
    
    mc_optimizer.zero_grad()
    loss.backward()
    mc_optimizer.step()

mc_mean, mc_std = mc_model.predict_with_uncertainty(X_comp_t, n_samples=100)

# ============================================================
# Train Ensemble
# ============================================================
ensemble = EnsembleNN(1, 64, 1, n_models=5).to(device)

print("Training ensemble (5 models)...")
for idx in range(ensemble.n_models):
    optimizer = torch.optim.Adam(ensemble.models[idx].parameters(), lr=1e-2)
    ensemble.train_model(idx, X_train_t, y_train_t, optimizer, n_epochs=1000)

ens_mean, ens_std = ensemble.predict_with_uncertainty(X_comp_t)

# ============================================================
# BNN predictions (already trained)
# ============================================================
bnn_predictions = []
model.eval()
with torch.no_grad():
    for _ in range(100):
        pred = model(X_comp_t)
        bnn_predictions.append(pred.cpu().numpy())

bnn_predictions = np.array(bnn_predictions).squeeze()
bnn_mean = bnn_predictions.mean(axis=0)
bnn_std = bnn_predictions.std(axis=0)

# ============================================================
# Visualization: Compare all methods
# ============================================================
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

methods = [
    ('BNN (Variational)', bnn_mean, bnn_std, 'blue'),
    ('MC Dropout', mc_mean.squeeze(), mc_std.squeeze(), 'green'),
    ('Ensemble (5 models)', ens_mean.squeeze(), ens_std.squeeze(), 'red')
]

for idx, (name, mean, std, color) in enumerate(methods):
    ax = axes[idx // 2, idx % 2]
    
    # True function
    ax.plot(X_comp, f(X_comp), 'k--', label='True function', linewidth=2, alpha=0.7)
    
    # Predictions
    ax.plot(X_comp, mean, color=color, label=f'{name} mean', linewidth=2)
    ax.fill_between(X_comp.ravel(), mean - 2*std, mean + 2*std, 
                     alpha=0.3, color=color, label='±2σ')
    
    # Training data
    ax.scatter(X_train, y_train, c='black', s=80, zorder=10, 
               edgecolors='white', linewidths=2, label='Training data')
    
    ax.set_xlabel('x', fontsize=13)
    ax.set_ylabel('y', fontsize=13)
    ax.set_title(f'{name} - Uncertainty Quantification', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11, loc='upper left')
    ax.grid(True, alpha=0.3)
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)

# ============================================================
# Uncertainty comparison plot
# ============================================================
ax = axes[1, 1]
ax.plot(X_comp, bnn_std, 'b-', label='BNN', linewidth=2)
ax.plot(X_comp, mc_std.squeeze(), 'g-', label='MC Dropout', linewidth=2)
ax.plot(X_comp, ens_std.squeeze(), 'r-', label='Ensemble', linewidth=2)

# Highlight extrapolation regions
ax.axvspan(-1.5, -1.0, alpha=0.2, color='gray', label='Extrapolation')
ax.axvspan(1.0, 1.5, alpha=0.2, color='gray')

ax.set_xlabel('x', fontsize=13)
ax.set_ylabel('Predictive Std Dev (σ)', fontsize=13)
ax.set_title('Uncertainty Comparison Across Methods', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('bnn_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# ============================================================
# Uncertainty Decomposition for BNN
# ============================================================
noise_std = 0.1
uncertainty_data = decompose_uncertainty(model, X_comp_t, None, n_samples=100, noise_std=noise_std)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Total uncertainty
axes[0].plot(X_comp, np.sqrt(uncertainty_data['total']), 'purple', linewidth=2)
axes[0].fill_between(X_comp.ravel(), 0, np.sqrt(uncertainty_data['total']), alpha=0.3, color='purple')
axes[0].scatter(X_train, np.zeros_like(X_train), c='red', s=50, zorder=10, label='Training data')
axes[0].set_xlabel('x', fontsize=12)
axes[0].set_ylabel('Total Uncertainty (σ)', fontsize=12)
axes[0].set_title('Total Predictive Uncertainty', fontsize=13, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].legend()

# Plot 2: Aleatoric vs Epistemic
axes[1].plot(X_comp, np.sqrt(uncertainty_data['aleatoric']), 'orange', 
             linewidth=2, label='Aleatoric (data noise)')
axes[1].plot(X_comp, np.sqrt(uncertainty_data['epistemic']), 'blue', 
             linewidth=2, label='Epistemic (model)')
axes[1].fill_between(X_comp.ravel(), 0, np.sqrt(uncertainty_data['aleatoric']), 
                     alpha=0.2, color='orange')
axes[1].fill_between(X_comp.ravel(), 0, np.sqrt(uncertainty_data['epistemic']), 
                     alpha=0.2, color='blue')
axes[1].set_xlabel('x', fontsize=12)
axes[1].set_ylabel('Uncertainty (σ)', fontsize=12)
axes[1].set_title('Uncertainty Decomposition', fontsize=13, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

# Plot 3: Stacked uncertainty
axes[2].fill_between(X_comp.ravel(), 0, np.sqrt(uncertainty_data['aleatoric']), 
                     alpha=0.5, color='orange', label='Aleatoric')
axes[2].fill_between(X_comp.ravel(), 
                     np.sqrt(uncertainty_data['aleatoric']), 
                     np.sqrt(uncertainty_data['aleatoric']) + np.sqrt(uncertainty_data['epistemic']), 
                     alpha=0.5, color='blue', label='Epistemic')
axes[2].scatter(X_train, np.zeros_like(X_train), c='red', s=50, zorder=10)
axes[2].set_xlabel('x', fontsize=12)
axes[2].set_ylabel('Cumulative Uncertainty (σ)', fontsize=12)
axes[2].set_title('Stacked Uncertainty Components', fontsize=13, fontweight='bold')
axes[2].legend(fontsize=11, loc='upper left')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('uncertainty_decomposition.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*70)
print("UNCERTAINTY ANALYSIS")
print("="*70)
print(f"Mean Aleatoric Uncertainty:  {np.sqrt(uncertainty_data['aleatoric']).mean():.4f}")
print(f"Mean Epistemic Uncertainty:  {np.sqrt(uncertainty_data['epistemic']).mean():.4f}")
print(f"Mean Total Uncertainty:      {np.sqrt(uncertainty_data['total']).mean():.4f}")
print("\nNote: Epistemic uncertainty is HIGH in extrapolation regions (|x| > 1.0)")
print("      Aleatoric uncertainty is CONSTANT (inherent data noise)")
print("="*70)

MC Dropout

MC (Monte Carlo) Dropout is a practical approximation to Bayesian inference: simply keep dropout active at test time and run multiple forward passes. Gal & Ghahramani (2016) showed that dropout at test time is mathematically equivalent to approximate variational inference in a deep Gaussian process. The mean of multiple stochastic forward passes estimates the predictive mean, and the variance estimates the predictive uncertainty. MC Dropout requires no architectural changes beyond standard dropout, making it the most accessible method for adding uncertainty estimation to any existing neural network. The trade-off is that the uncertainty estimates are less calibrated than full variational inference or ensemble methods.

class MCDropoutNN(nn.Module):
    """Network with MC Dropout for uncertainty."""
    
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_p=0.3):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout_p)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Train MC Dropout model
mc_model = MCDropoutNN(1, 64, 1).to(device)
optimizer = torch.optim.Adam(mc_model.parameters(), lr=1e-2)

for epoch in range(1000):
    mc_model.train()
    y_pred = mc_model(X_train_t)
    loss = F.mse_loss(y_pred, y_train_t)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# MC predictions (keep dropout on)
mc_model.train()  # Keep dropout enabled
mc_predictions = []

with torch.no_grad():
    for _ in range(100):
        y_pred = mc_model(X_test_t)
        mc_predictions.append(y_pred.cpu().numpy())

mc_predictions = np.array(mc_predictions).squeeze()
mc_mean = mc_predictions.mean(axis=0)
mc_std = mc_predictions.std(axis=0)

# Plot
plt.figure(figsize=(12, 6))
plt.plot(X_test, f(X_test), 'k--', label='True', linewidth=2)
plt.plot(X_test, mc_mean, 'g-', label='MC Dropout mean', linewidth=2)
plt.fill_between(X_test.ravel(), mc_mean - 2*mc_std, mc_mean + 2*mc_std, 
                 alpha=0.3, color='g', label='±2σ')
plt.scatter(X_train, y_train, c='r', s=50, zorder=10, label='Data')
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.title('MC Dropout Uncertainty', fontsize=13)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

Summary

Bayesian Neural Networks:

Key Ideas:

  1. Probability distribution over weights

  2. Predictive uncertainty via integration

  3. Variational inference for tractability

  4. ELBO = likelihood - KL divergence

Uncertainty Types:

  • Epistemic: Model uncertainty (reducible with data)

  • Aleatoric: Data noise (irreducible)

Methods:

Variational Inference:

  • Gaussian posterior q(w)

  • Reparameterization trick

  • KL to prior regularization

MC Dropout:

  • Dropout as Bayesian approximation

  • Enable dropout at test time

  • Multiple forward passes

Advantages:

  • Calibrated uncertainty

  • Active learning

  • Out-of-distribution detection

  • Safety-critical applications

Applications:

  • Medical diagnosis

  • Autonomous vehicles

  • Reinforcement learning (exploration)

  • Bayesian optimization

Variants:

  • Laplace approximation

  • Ensemble methods

  • Deep ensembles

  • SWAG (Stochastic Weight Averaging-Gaussian)

Challenges:

  • Computational cost

  • Hyperparameter tuning

  • Posterior approximation quality

Advanced Bayesian Neural Networks Theory

1. Introduction to Bayesian Deep Learning

1.1 Motivation: Uncertainty Quantification

Traditional neural networks provide point estimates θ̂ for parameters, leading to:

  • Overconfidence on out-of-distribution data

  • No uncertainty quantification (cannot distinguish "don't know" from "sure but wrong")

  • Poor calibration (predicted probabilities ≠ true probabilities)

Bayesian Neural Networks (BNNs) maintain distributions over weights p(ΞΈ|D), enabling:

  • Epistemic uncertainty (model uncertainty, reducible with more data)

  • Aleatoric uncertainty (data noise, irreducible)

  • Principled decision-making under uncertainty

  • Calibrated predictions with confidence intervals

1.2 Bayesian FrameworkΒΆ

Prior: p(ΞΈ) - belief before seeing data
Likelihood: p(D|θ) = ∏ᡒ p(yᡒ|xᡒ, θ)
Posterior: p(ΞΈ|D) = p(D|ΞΈ)p(ΞΈ) / p(D) via Bayes’ theorem

Prediction for new input x*:

p(y*|x*, D) = ∫ p(y*|x*, θ) p(θ|D) dθ

Challenge: Posterior p(ΞΈ|D) is intractable for neural networks (millions of parameters)

2. Bayesian Inference MethodsΒΆ

2.1 Variational Inference (VI)ΒΆ

Idea: Approximate intractable posterior p(ΞΈ|D) with tractable q_Ο†(ΞΈ)

Objective: Minimize KL divergence

KL(q_Ο†(ΞΈ) || p(ΞΈ|D)) = ∫ q_Ο†(ΞΈ) log[q_Ο†(ΞΈ) / p(ΞΈ|D)] dΞΈ

Equivalent to maximizing ELBO (Evidence Lower Bound):

ELBO(Ο†) = E_{q_Ο†(ΞΈ)}[log p(D|ΞΈ)] - KL(q_Ο†(ΞΈ) || p(ΞΈ))
         = Ξ£α΅’ E_{q_Ο†(ΞΈ)}[log p(yα΅’|xα΅’, ΞΈ)] - KL(q_Ο†(ΞΈ) || p(ΞΈ))

Bayes by Backprop (Blundell et al., 2015):

  • Parameterize q_Ο†(ΞΈ) = N(ΞΌ, σ²) (mean-field Gaussian)

  • Reparameterization trick: ΞΈ = ΞΌ + Οƒ βŠ™ Ξ΅, Ξ΅ ~ N(0, I)

  • Gradient: βˆ‡_Ο† ELBO = βˆ‡_Ο† E_Ξ΅[log p(D|ΞΈ(Ξ΅)) - log q_Ο†(ΞΈ(Ξ΅)) + log p(ΞΈ(Ξ΅))]

Advantages: Scalable, differentiable, GPU-friendly
Disadvantages: Mean-field assumption (independence), local optima

2.2 Monte Carlo DropoutΒΆ

Observation (Gal & Ghahramani, 2016): Dropout is approximate Bayesian inference!

Standard dropout:

y = Wβ‚‚ Β· dropout(ReLU(W₁x))

Interpretation: Each dropout mask ~ sample from posterior q(ΞΈ)

Inference:

  1. Train with dropout (rate p)

  2. At test time, keep dropout active

  3. Sample T predictions: ŷ₁, …, Ε·_T (different masks)

  4. Mean: E[y*|x*] β‰ˆ 1/T Ξ£β‚œ Ε·β‚œ

  5. Variance: Var[y*|x*] β‰ˆ 1/T Ξ£β‚œ (Ε·β‚œ - mean)Β²

Advantages: Zero training overhead, works with any architecture
Disadvantages: Limited expressiveness, heuristic connection to VI

2.3 Deep EnsemblesΒΆ

Idea: Train M independent models with different initializations

Prediction:

p(y*|x*, D) β‰ˆ 1/M Ξ£β‚˜ p(y*|x*, ΞΈβ‚˜)

Training:

  • Different random seeds

  • Different data subsets (bagging)

  • Adversarial training for diversity

Advantages: Simple, strong empirical performance, diverse hypotheses
Disadvantages: MΓ— training cost, not truly Bayesian (no prior)

2.4 Markov Chain Monte Carlo (MCMC)ΒΆ

Goal: Sample ΞΈ ~ p(ΞΈ|D) exactly (asymptotically)

Stochastic Gradient Langevin Dynamics (SGLD):

ΞΈβ‚œβ‚Šβ‚ = ΞΈβ‚œ + (Ξ·/2)βˆ‡log p(ΞΈβ‚œ|D) + N(0, Ξ·)
     = ΞΈβ‚œ + (Ξ·/2)[βˆ‡log p(D|ΞΈβ‚œ) + βˆ‡log p(ΞΈβ‚œ)] + N(0, Ξ·)
  • Gradient ascent + Langevin noise

  • Converges to true posterior as Ξ· β†’ 0

Hamiltonian Monte Carlo (HMC):

  • Introduce momentum variables

  • Leapfrog integrator for proposals

  • Higher acceptance rate than SGLD

Advantages: Asymptotically exact, no variational gap
Disadvantages: Slow convergence, high computational cost, tuning required
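The SGLD update above can be sketched on a toy one-dimensional target; the quadratic log-posterior (equivalent to an exact N(2, 1) posterior) and the fixed step size are illustrative assumptions, whereas a real BNN would plug in a minibatch gradient estimate:

```python
import statistics

import torch

# SGLD sketch (assumption: toy log-posterior of N(2, 1), so βˆ‡log p(ΞΈ|D) = -(ΞΈ - 2))
def grad_log_post(theta):
    return -(theta - 2.0)

torch.manual_seed(0)
theta = torch.tensor(0.0)
eta = 0.1                                  # step size (also the injected noise variance)
samples = []
for t in range(5000):
    noise = torch.randn(()) * eta ** 0.5   # Langevin noise ~ N(0, Ξ·)
    theta = theta + 0.5 * eta * grad_log_post(theta) + noise
    if t >= 1000:                          # discard burn-in iterations
        samples.append(theta.item())

print(statistics.mean(samples))            # posterior-mean estimate, near 2.0
```

With a constant step size the chain has a small bias, which is why SGLD is usually run with a decreasing schedule Ξ· β†’ 0.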

2.5 Laplace ApproximationΒΆ

Idea: Approximate posterior with Gaussian at MAP estimate

Procedure:

  1. Find MAP: ΞΈ_MAP = argmax p(ΞΈ|D)

  2. Compute Hessian: H = -βˆ‡Β²log p(ΞΈ|D)|_{ΞΈ_MAP}

  3. Approximate: p(ΞΈ|D) β‰ˆ N(ΞΈ_MAP, H⁻¹)

Challenges for NNs:

  • Hessian is huge (millions Γ— millions)

  • Expensive to compute and invert

Modern approaches:

  • KFAC (Kronecker-Factored Approximate Curvature): Block-diagonal Hessian

  • Diagonal Laplace: Only diagonal of H

  • Last-layer Laplace: Only linearize last layer (cheap!)

Advantages: Post-hoc (apply to pretrained models), principled
Disadvantages: Gaussian assumption, expensive for full network

3. Structured Variational InferenceΒΆ

3.1 Mean-Field vs. Structured ApproximationsΒΆ

Mean-field: q(θ) = ∏ᡒ q(θᡒ) (fully factorized)

  • Simple, scalable

  • Too restrictive: Ignores correlations

Matrix Variate Gaussian (Louizos & Welling, 2016):

q(W) = MN(M, U, V)  (W is the weight matrix)

  • Captures row/column correlations

  • Kronecker structure for efficiency

Normalizing Flows (Rezende & Mohamed, 2015):

ΞΈ = f_K(...fβ‚‚(f₁(Ξ΅))...)  where Ξ΅ ~ pβ‚€(Ξ΅)
Change of variables: q(ΞΈ) = pβ‚€(Ξ΅) |det J_f|⁻¹

  • Expressive posteriors via invertible transformations

  • Planar flows, RealNVP, MAF for BNNs

3.2 Weight Uncertainty in Neural NetworksΒΆ

Hierarchical priors for automatic relevance determination:

p(W) = ∏ᡒⱼ N(Wᡒⱼ|0, αᡒⱼ⁻¹)
p(α) = ∏ᡒⱼ Gamma(αᡒⱼ|a, b)

Sparse variational dropout (Molchanov et al., 2017):

  • Learn dropout rate per weight

  • Prune weights with high dropout (Ξ±α΅’β±Ό β†’ ∞)

  • Achieves sparsity + uncertainty

4. Scalable BNN TrainingΒΆ

4.1 Minibatch Scaling for ELBOΒΆ

Full ELBO:

L(Ο†) = Ξ£α΅’β‚Œβ‚α΄Ί E_{q_Ο†}[log p(yα΅’|xα΅’, ΞΈ)] - KL(q_Ο† || p)

Minibatch estimate:

LΜƒ(Ο†) = (N/B) Σᡒ∈batch E_{q_Ο†}[log p(yα΅’|xα΅’, ΞΈ)] - KL(q_Ο† || p)

  • KL term computed once per step (not rescaled by the batch)

  • Likelihood term scaled by N/B
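Assembling this estimate can be sketched as follows; the interface is hypothetical (a `model` exposing a summed `kl_divergence()` over its Bayesian layers, as the `BayesianMLP` implementation later in this notebook does) and a unit-variance Gaussian likelihood is assumed so the NLL reduces to a sum of squared errors:

```python
import torch
import torch.nn.functional as F

# Minibatch negative-ELBO sketch: likelihood term rescaled by N/B, KL added once.
def minibatch_negative_elbo(model, batch_x, batch_y, dataset_size):
    B = batch_x.size(0)
    output = model(batch_x)                              # one posterior sample per batch
    nll = F.mse_loss(output, batch_y, reduction='sum')   # -log p(batch|ΞΈ) up to constants
    return (dataset_size / B) * nll + model.kl_divergence()
```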

4.2 Local Reparameterization TrickΒΆ

Standard reparameterization: Sample full ΞΈ per minibatch

  • High variance gradients

Local reparameterization (Kingma et al., 2015):

  • Sample activations instead of weights

  • For linear layer: z = Wx where W ~ q(W)

    E[z] = E[W]x = ΞΌx
    Var[z] = x^T diag(σ²) x
    z ~ N(ΞΌx, Var[z])
    
  • Lower variance, faster
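The moment computations above can be sketched for a single mean-field Bayesian linear layer; `mu_w` and `sigma_w` here are stand-in variational parameters, not trained values:

```python
import torch
import torch.nn.functional as F

# Local reparameterization sketch: sample pre-activations z ~ N(E[z], Var[z])
# instead of sampling the full weight matrix W.
def local_reparam_linear(x, mu_w, sigma_w):
    mu_z = F.linear(x, mu_w)                    # E[z] = x ΞΌ_Wα΅€
    var_z = F.linear(x ** 2, sigma_w ** 2)      # Var[z_j] = Ξ£α΅’ xᡒ² σᡒⱼ²
    eps = torch.randn_like(mu_z)
    return mu_z + var_z.sqrt() * eps

x = torch.randn(4, 10)
mu_w = torch.randn(5, 10) * 0.1
sigma_w = torch.full((5, 10), 0.05)
z = local_reparam_linear(x, mu_w, sigma_w)
print(z.shape)   # torch.Size([4, 5])
```

Each example in the batch now gets its own independent noise, which is what lowers the gradient variance relative to sharing one sampled W across the batch.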

4.3 Natural Gradient VIΒΆ

Natural gradient: Adjust for parameter space curvature

Ο†β‚œβ‚Šβ‚ = Ο†β‚œ + Ξ· F⁻¹ βˆ‡_Ο† ELBO

where F is Fisher information matrix

Practical: Use Adam (approximates natural gradient)

5. Uncertainty DecompositionΒΆ

5.1 Epistemic vs. Aleatoric UncertaintyΒΆ

Epistemic (model uncertainty):

  • Reducible with more data

  • Captured by posterior variance p(ΞΈ|D)

  • Example: Insufficient data in region

Aleatoric (data noise):

  • Irreducible (inherent randomness)

  • Captured by likelihood variance p(y|x, ΞΈ)

  • Example: Sensor noise, label ambiguity

5.2 Heteroscedastic Aleatoric UncertaintyΒΆ

Homoscedastic: σ² is constant
Heteroscedastic: σ²(x) varies with input

Model:

NN: x β†’ (ΞΌ(x), σ²(x))
p(y|x, ΞΈ) = N(y|ΞΌ(x), σ²(x))

Loss (negative log-likelihood):

L = (y - ΞΌ(x))Β² / (2σ²(x)) + (1/2) log σ²(x)

  • First term: Precision-weighted MSE

  • Second term: Regularization (prevents Οƒ β†’ ∞)

Combined epistemic + aleatoric:

Var[y*] = E_ΞΈ[Var[y|x, ΞΈ]] + Var_ΞΈ[E[y|x, ΞΈ]]
        = E[σ²(x)] + Var[ΞΌ(x)]
        = aleatoric + epistemic
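The decomposition can be sketched with stand-in Monte Carlo samples (random tensors in place of real posterior draws of ΞΌ(x) and σ²(x) from a heteroscedastic BNN):

```python
import torch

# Variance decomposition sketch over T posterior samples per input.
torch.manual_seed(0)
T, batch = 50, 8
mus = torch.randn(T, batch)                  # stand-ins for ΞΌ_t(x) samples
sigma2s = torch.rand(T, batch)               # stand-ins for Οƒ_tΒ²(x) samples

aleatoric = sigma2s.mean(dim=0)              # E_ΞΈ[σ²(x)]
epistemic = mus.var(dim=0, unbiased=False)   # Var_ΞΈ[ΞΌ(x)]
total = aleatoric + epistemic                # Var[y*] per input

print(total.shape)                           # torch.Size([8])
```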

6. Priors for Neural NetworksΒΆ

6.1 Weight PriorsΒΆ

Gaussian prior: p(ΞΈ) = N(0, σ²_p I)

  • Corresponds to L2 regularization (MAP = ridge)

  • Induces smoothness

Laplace prior: p(ΞΈ) = Laplace(0, b)

  • Corresponds to L1 regularization (MAP = lasso)

  • Induces sparsity

Horseshoe prior:

p(wⱼ) = N(0, λⱼ²τ²)
p(λⱼ) = C⁺(0, 1)  (half-Cauchy)

  • Sparse but keeps important weights large

  • Better than Laplace for high-dimensional data
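A sampling sketch of this prior, with the global scale Ο„ fixed to 1 purely for illustration, shows the characteristic shape (sharp spike at zero plus heavy tails):

```python
import torch
from torch.distributions import HalfCauchy, Normal

# Horseshoe prior sampling sketch (Ο„ = 1 assumed for illustration).
torch.manual_seed(0)
tau = 1.0
lam = HalfCauchy(scale=torch.ones(10000)).sample()   # local scales λⱼ ~ C⁺(0, 1)
w = Normal(0.0, lam * tau).sample()                  # wβ±Ό ~ N(0, λⱼ²τ²)

# Many near-zero weights, a few very large ones.
print((w.abs() < 0.1).float().mean().item(), w.abs().max().item())
```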

6.2 Functional Priors (Neural Network Gaussian Processes)ΒΆ

Observation: Infinite-width NN β†’ Gaussian Process (Neal, 1996)

For single hidden layer with width H β†’ ∞:

f(x) = (b/√H) Ξ£β‚• vβ‚• Ο†(wβ‚•α΅€x)

where wβ‚• ~ N(0, I), vβ‚• ~ N(0, 1)

Limit: f(x) ~ GP(0, K(x, x’))
Kernel (NNGP):

K(x, x') = E_w[Ο†(wα΅€x) Ο†(wα΅€x')]

For ReLU: Kernel has closed form (Cho & Saul, 2009)

Deep GPs: Multi-layer limit (Lee et al., 2018)

  • Provides prior over functions

  • Useful for architecture search, initialization
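For ReLU, the expectation above has the arc-cosine closed form K(x, x') = (1/2Ο€) Β· ||x|| ||x'|| Β· (sin ΞΈ + (Ο€ βˆ’ ΞΈ) cos ΞΈ) with ΞΈ the angle between the inputs; a small sketch with a sanity check against E[relu(z)Β²] = ||x||Β²/2 for x = x':

```python
import math

import torch

# NNGP / arc-cosine (order-1) kernel sketch for ReLU, assuming w ~ N(0, I):
# K(x, x') = E_w[relu(wα΅€x) relu(wα΅€x')]
def relu_nngp_kernel(x, y):
    nx, ny = x.norm(), y.norm()
    cos_t = torch.clamp((x @ y) / (nx * ny), -1.0, 1.0)
    theta = torch.acos(cos_t)
    return (nx * ny / (2 * math.pi)) * (torch.sin(theta) + (math.pi - theta) * torch.cos(theta))

x = torch.tensor([1.0, 0.0])
print(relu_nngp_kernel(x, x).item())   # 0.5 = ||x||Β² / 2
```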

7. BNN ApplicationsΒΆ

7.1 Active LearningΒΆ

Goal: Select most informative points to label

BALD (Bayesian Active Learning by Disagreement):

I(y; ΞΈ|x, D) = H[y|x, D] - E_ΞΈ[H[y|x, ΞΈ]]
             = H[E_ΞΈ[p(y|x, ΞΈ)]] - E_ΞΈ[H[p(y|x, ΞΈ)]]

  • High when models disagree (epistemic uncertainty)

  • Query points with maximum mutual information
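The two-term form above translates directly to MC samples; a sketch with a hypothetical `probs` tensor of shape [T, batch, classes] standing in for dropout or posterior samples:

```python
import torch

# BALD sketch: entropy of the mean prediction minus mean entropy of predictions.
def bald_score(probs, eps=1e-9):
    mean_p = probs.mean(dim=0)                                   # E_ΞΈ[p(y|x, ΞΈ)]
    H_mean = -(mean_p * (mean_p + eps).log()).sum(dim=-1)        # H[E_ΞΈ[p]]
    mean_H = -(probs * (probs + eps).log()).sum(dim=-1).mean(0)  # E_ΞΈ[H[p]]
    return H_mean - mean_H                                       # β‰₯ 0, high on disagreement

agree = torch.tensor([[[0.9, 0.1]], [[0.9, 0.1]]])      # models agree β†’ score β‰ˆ 0
disagree = torch.tensor([[[0.9, 0.1]], [[0.1, 0.9]]])   # models disagree β†’ score β‰ˆ 0.368
print(bald_score(agree).item(), bald_score(disagree).item())
```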

7.2 Continual LearningΒΆ

Catastrophic forgetting: New tasks overwrite old knowledge

Elastic Weight Consolidation (EWC) (Kirkpatrick et al., 2017):

L_new = L_task(ΞΈ) + (Ξ»/2) Ξ£α΅’ Fα΅’(ΞΈα΅’ - ΞΈα΅’*)Β²

  • Fα΅’: Fisher information (importance of weight i for old task)

  • Prevents large changes to important weights
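The penalty can be sketched as a loop over named parameters; the diagonal Fisher and the stored old-task parameters are assumed precomputed (identity Fisher used here as a stand-in):

```python
import torch
import torch.nn as nn

# EWC penalty sketch: quadratic pull toward the old-task parameters ΞΈ*,
# weighted per-parameter by the (diagonal) Fisher information.
def ewc_penalty(model, fisher, old_params, lam=1.0):
    loss = torch.tensor(0.0)
    for name, p in model.named_parameters():
        loss = loss + (fisher[name] * (p - old_params[name]) ** 2).sum()
    return 0.5 * lam * loss

model = nn.Linear(3, 1)
old = {n: p.detach().clone() for n, p in model.named_parameters()}
fisher = {n: torch.ones_like(p) for n, p in model.named_parameters()}
print(ewc_penalty(model, fisher, old).item())   # 0.0 before any parameter drift
```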

Variational Continual Learning (Nguyen et al., 2018):

  • Posterior of task t becomes prior for task t+1

  • Maintains memory of all tasks

7.3 Out-of-Distribution DetectionΒΆ

Predictive entropy:

H[y|x, D] = -Ξ£_c p(y=c|x, D) log p(y=c|x, D)

  • High entropy β†’ uncertain β†’ likely OOD

Predictive variance: Var[y|x, D]

  • High variance β†’ epistemic uncertainty β†’ OOD

Threshold: Reject if H or Var > threshold
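The entropy-threshold rule can be sketched directly; the threshold value and the two example probability vectors are arbitrary stand-ins:

```python
import torch

# OOD-detection sketch: flag inputs whose predictive entropy exceeds a threshold.
def predictive_entropy(mean_probs, eps=1e-9):
    return -(mean_probs * (mean_probs + eps).log()).sum(dim=-1)

confident = torch.tensor([[0.98, 0.01, 0.01]])   # in-distribution-looking prediction
uncertain = torch.tensor([[0.34, 0.33, 0.33]])   # near-uniform, OOD-looking prediction
threshold = 0.5

print(predictive_entropy(confident) > threshold)  # tensor([False])
print(predictive_entropy(uncertain) > threshold)  # tensor([True])
```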

7.4 Safety-Critical ApplicationsΒΆ

  • Medical diagnosis: Uncertainty for β€œrefer to specialist”

  • Autonomous driving: Detect novel scenarios

  • Reinforcement learning: Risk-aware exploration

8. Computational ComplexityΒΆ

8.1 Training CostΒΆ

Method             | Forward Pass | Backward Pass | Memory
-------------------|--------------|---------------|-------
Standard NN        | O(W)         | O(W)          | O(W)
Bayes by Backprop  | O(SW)        | O(SW)         | O(2W)
MC Dropout         | O(W)         | O(W)          | O(W)
Deep Ensembles (M) | O(MW)        | O(MW)         | O(MW)
SGLD               | O(W)         | O(W)          | O(TW)

  • W: Number of weights

  • S: Samples per batch (typically 1-3)

  • M: Ensemble size (typically 5-10)

  • T: MCMC samples

8.2 Inference CostΒΆ

Single prediction:

  • Standard: 1 forward pass

  • BNN: T forward passes (sample θ₁, …, ΞΈ_T)

Typical T: 10-100 for uncertainty, 1000+ for calibration

9. Evaluation MetricsΒΆ

9.1 CalibrationΒΆ

Perfect calibration: Predicted probability = actual frequency

Expected Calibration Error (ECE):

ECE = Ξ£β‚˜ (|Bβ‚˜|/N) |acc(Bβ‚˜) - conf(Bβ‚˜)|

  • Partition predictions into M bins by confidence

  • Compare accuracy vs. confidence per bin

Reliability diagram: Plot accuracy vs. confidence

  • Ideal: Diagonal line

9.2 Negative Log-Likelihood (NLL)ΒΆ

NLL = -(1/N) Ξ£α΅’ log p(yα΅’|xα΅’, D)

  • Measures quality of predicted distribution

  • Lower is better

9.3 Brier ScoreΒΆ

BS = (1/N) Ξ£α΅’ Ξ£_c (p(y=c|xα΅’) - πŸ™[yα΅’=c])Β²

  • Squared error of predicted probabilities

  • Lower is better

10. Recent Advances (2017-2024)ΒΆ

10.1 Function-Space InferenceΒΆ

Neural Tangent Kernel (NTK) (Jacot et al., 2018):

  • Infinite-width limit at initialization

  • Kernel remains constant during training

  • Exact GP inference in function space

Limitations: Requires infinite width, doesn’t capture feature learning

10.2 Stochastic Weight Averaging Gaussian (SWAG)ΒΆ

Idea (Maddox et al., 2019): Approximate posterior from SGD trajectory

  1. Run SGD for T iterations

  2. Collect θ₁, …, ΞΈ_T in later epochs

  3. Fit Gaussian: ΞΌ = mean(ΞΈ), Ξ£ = cov(ΞΈ)

Advantages: Post-hoc, uses existing training, cheap
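The three steps can be sketched with a diagonal covariance; the "trajectory" here is a fake array of flattened parameters standing in for real SGD snapshots:

```python
import torch

# Diagonal-SWAG sketch: fit N(ΞΌ, diag(σ²)) to a stream of SGD iterates.
torch.manual_seed(0)
trajectory = 2.0 + 0.1 * torch.randn(30, 5)   # 30 snapshots of 5 parameters

mu = trajectory.mean(dim=0)                    # SWA mean
second = (trajectory ** 2).mean(dim=0)         # running second moment
var = (second - mu ** 2).clamp_min(1e-12)      # diagonal variance estimate

# Draw one weight sample from the fitted Gaussian posterior approximation.
sample = mu + var.sqrt() * torch.randn(5)
print(mu.shape, sample.shape)
```

The full SWAG method also keeps a low-rank deviation matrix on top of this diagonal; the sketch keeps only the diagonal part.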

10.3 Neural ProcessesΒΆ

Meta-learning BNNs:

  • Learn prior p(ΞΈ) from related tasks

  • Fast adaptation to new tasks

  • Combines benefits of GPs and NNs

10.4 Predictive Uncertainty with Deep Kernel LearningΒΆ

Combine NN feature extractor with GP:

k(x, x') = k_GP(Ο†_NN(x), Ο†_NN(x'))

  • NN learns features

  • GP provides calibrated uncertainty

10.5 Generalized Variational InferenceΒΆ

RΓ©nyi divergence instead of KL:

D_α(q||p) = (1/(α-1)) log ∫ q^α p^{1-α}
  • Ξ± = 1: Recovers KL

  • Ξ± > 1: More robust, mode-seeking

  • Ξ± < 1: Mass-covering

11. Practical GuidelinesΒΆ

11.1 Method SelectionΒΆ

Use MC Dropout if:

  • Have pretrained model

  • Need quick uncertainty estimate

  • Limited compute budget

Use Bayes by Backprop if:

  • Training from scratch

  • Want principled Bayesian inference

  • Have moderate data

Use Deep Ensembles if:

  • Need best empirical performance

  • Can afford MΓ— training

  • Want diversity

Use Last-Layer Laplace if:

  • Have pretrained model

  • Want post-hoc uncertainty

  • Need calibrated predictions

11.2 Hyperparameter TuningΒΆ

Prior variance σ²_p: Controls regularization

  • Too small: Underfitting

  • Too large: Overfitting

  • Tune via validation NLL

Posterior learning rate: Typically lower than standard

  • Adam with lr = 1e-4 to 1e-3

Number of samples T: Trade-off accuracy vs. speed

  • Training: 1-3 samples

  • Evaluation: 10-100 samples

12. Limitations and Open ProblemsΒΆ

12.1 Current ChallengesΒΆ

  1. Computational cost: TΓ— slower inference

  2. Scalability: Difficult for huge models (GPT-3)

  3. Variance underestimation: mean-field VI tends to underestimate posterior uncertainty

  4. Hyperprior selection: Sensitive to prior choice

  5. Calibration: Not guaranteed even for BNNs

12.2 Open Research QuestionsΒΆ

  • Scalable exact inference: MCMC for billions of parameters

  • Better posteriors: Beyond mean-field, tractable structured VI

  • Functional priors: Specify p(f) instead of p(ΞΈ)

  • Multi-modal posteriors: Capture symmetries, local optima

  • Uncertainty in generative models: BNNs for GANs, diffusion

14. Software LibrariesΒΆ

14.1 Python LibrariesΒΆ

  • Pyro: Probabilistic programming (PyTorch backend)

  • TensorFlow Probability: Bayesian layers, VI, MCMC

  • Blitz: Bayes by Backprop in PyTorch

  • Laplace: Last-layer Laplace approximation

  • GPyTorch: Scalable GPs for DNN features

14.2 Example: Bayes by Backprop LayerΒΆ

class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.mu_w = nn.Parameter(torch.randn(out_features, in_features))
        self.rho_w = nn.Parameter(torch.randn(out_features, in_features))
        self.mu_b = nn.Parameter(torch.randn(out_features))
        self.rho_b = nn.Parameter(torch.randn(out_features))
    
    def forward(self, x):
        # Οƒ = softplus(ρ) keeps the standard deviation positive
        sigma_w = torch.log1p(torch.exp(self.rho_w))
        sigma_b = torch.log1p(torch.exp(self.rho_b))
        
        # Reparameterization trick: ΞΈ = ΞΌ + Οƒ βŠ™ Ξ΅, Ξ΅ ~ N(0, I)
        w = self.mu_w + sigma_w * torch.randn_like(sigma_w)
        b = self.mu_b + sigma_b * torch.randn_like(sigma_b)
        
        return F.linear(x, w, b)

15. Benchmarks and ResultsΒΆ

15.1 Classification TasksΒΆ

CIFAR-10 (Accuracy / NLL / ECE):

  • Standard ResNet: 95.5% / 0.18 / 0.05

  • MC Dropout (p=0.1): 95.2% / 0.16 / 0.04

  • Deep Ensemble (M=5): 96.1% / 0.14 / 0.02

  • SWAG: 95.8% / 0.15 / 0.03

ImageNet:

  • Standard: 76.1% / 1.02

  • Ensemble: 77.3% / 0.91

  • Temp scaling: 76.1% / 0.89

15.2 Regression TasksΒΆ

UCI datasets (Avg. RMSE / Avg. NLL):

  • Standard MLP: 0.52 / 1.21

  • MC Dropout: 0.49 / 0.98

  • Variational: 0.47 / 0.91

  • Ensemble: 0.46 / 0.88

15.3 Active LearningΒΆ

MNIST (Accuracy with 100 labels):

  • Random: 85%

  • Entropy: 88%

  • BALD (BNN): 92%

16. Key TakeawaysΒΆ

  1. BNNs provide uncertainty: Epistemic + aleatoric via posterior

  2. Trade-offs: Computational cost vs. uncertainty quality

  3. Practical methods: MC Dropout, ensembles, last-layer Laplace

  4. Calibration is crucial: Measure with ECE, NLL, reliability plots

  5. Applications: Active learning, OOD detection, safety-critical systems

  6. Open problems: Scalability, multi-modal posteriors, functional priors

When to use BNNs:

  • Safety-critical applications (medical, autonomous)

  • Small data regimes (active learning)

  • Need confidence intervals

  • Out-of-distribution detection

When NOT to use:

  • Computational budget limited

  • Only accuracy matters (not uncertainty)

  • Data is abundant and clean

17. ReferencesΒΆ

Foundational:

  • Neal (1996): Bayesian Learning for Neural Networks

  • MacKay (1992): Practical Bayesian Framework for Backprop

Variational Inference:

  • Blundell et al. (2015): Weight Uncertainty in Neural Networks

  • Kingma et al. (2015): Variational Dropout and Local Reparameterization

  • Louizos & Welling (2016): Structured and Efficient Variational Inference

MC Dropout:

  • Gal & Ghahramani (2016): Dropout as a Bayesian Approximation

Ensembles:

  • Lakshminarayanan et al. (2017): Simple and Scalable Predictive Uncertainty

Laplace:

  • Ritter et al. (2018): Scalable Laplace Approximation

  • Daxberger et al. (2021): Laplace Redux

Recent:

  • Maddox et al. (2019): SWAG

  • Wilson & Izmailov (2020): Bayesian Deep Learning and a Probabilistic Perspective

  • Fortuin (2022): Priors in Bayesian Deep Learning (survey)

18. Connection to Other TopicsΒΆ

Gaussian Processes: BNNs at infinite width
Meta-Learning: Neural Processes, learned priors
Continual Learning: EWC uses BNN principles
Generative Models: VAEs = BNNs for latent variables
Reinforcement Learning: Uncertainty-aware exploration

"""
Complete Bayesian Neural Networks Implementations
==================================================
Includes: Bayes by Backprop, MC Dropout, Deep Ensembles, Laplace approximation,
uncertainty quantification, calibration metrics.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal, kl_divergence
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# ============================================================================
# 1. Bayes by Backprop
# ============================================================================

class BayesianLinear(nn.Module):
    """
    Bayesian linear layer with weight uncertainty.
    
    Implements variational inference with mean-field Gaussian posterior.
    
    Args:
        in_features: Input dimension
        out_features: Output dimension
        prior_sigma: Prior standard deviation (regularization strength)
    """
    def __init__(self, in_features, out_features, prior_sigma=1.0):
        super(BayesianLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_sigma = prior_sigma
        
        # Variational parameters: q(W) = N(ΞΌ_w, Οƒ_wΒ²)
        self.mu_w = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.rho_w = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        
        # Bias parameters
        self.mu_b = nn.Parameter(torch.zeros(out_features))
        self.rho_b = nn.Parameter(torch.randn(out_features) * 0.01)
        
        # Register KL divergence
        self.kl_divergence = 0
    
    @property
    def sigma_w(self):
        """Compute Οƒ from ρ using softplus: Οƒ = log(1 + exp(ρ))"""
        return torch.log1p(torch.exp(self.rho_w))
    
    @property
    def sigma_b(self):
        return torch.log1p(torch.exp(self.rho_b))
    
    def forward(self, x):
        """
        Forward pass with reparameterization trick.
        
        Sample W ~ q(W) = N(ΞΌ, σ²) via W = ΞΌ + Οƒ βŠ™ Ξ΅, Ξ΅ ~ N(0, I)
        """
        # Sample weights
        epsilon_w = torch.randn_like(self.sigma_w)
        w = self.mu_w + self.sigma_w * epsilon_w
        
        # Sample bias
        epsilon_b = torch.randn_like(self.sigma_b)
        b = self.mu_b + self.sigma_b * epsilon_b
        
        # Compute KL divergence: KL(q(W) || p(W))
        # For q = N(ΞΌ, σ²), p = N(0, Οƒ_pΒ²):
        # KL = 0.5 * [σ²/Οƒ_pΒ² + ΞΌΒ²/Οƒ_pΒ² - 1 - log(σ²/Οƒ_pΒ²)]
        prior_var = self.prior_sigma ** 2
        
        kl_w = 0.5 * (
            (self.sigma_w ** 2 + self.mu_w ** 2) / prior_var
            - 1 
            - torch.log(self.sigma_w ** 2 / prior_var)
        ).sum()
        
        kl_b = 0.5 * (
            (self.sigma_b ** 2 + self.mu_b ** 2) / prior_var
            - 1
            - torch.log(self.sigma_b ** 2 / prior_var)
        ).sum()
        
        self.kl_divergence = kl_w + kl_b
        
        return F.linear(x, w, b)


class BayesianMLP(nn.Module):
    """
    Bayesian MLP with variational inference.
    
    Args:
        input_dim: Input dimension
        hidden_dims: List of hidden dimensions
        output_dim: Output dimension
        prior_sigma: Prior standard deviation
    """
    def __init__(self, input_dim, hidden_dims, output_dim, prior_sigma=1.0):
        super(BayesianMLP, self).__init__()
        
        layers = []
        dims = [input_dim] + hidden_dims + [output_dim]
        
        for i in range(len(dims) - 1):
            layers.append(BayesianLinear(dims[i], dims[i+1], prior_sigma))
            if i < len(dims) - 2:  # No activation after last layer
                layers.append(nn.ReLU())
        
        self.layers = nn.ModuleList(layers)
    
    def forward(self, x):
        """Forward pass."""
        for layer in self.layers:
            x = layer(x)
        return x
    
    def kl_divergence(self):
        """Sum KL divergence from all Bayesian layers."""
        kl = 0
        for layer in self.layers:
            if isinstance(layer, BayesianLinear):
                kl += layer.kl_divergence
        return kl
    
    def predict_with_uncertainty(self, x, num_samples=100):
        """
        Predict with uncertainty quantification.
        
        Args:
            x: Input [batch, input_dim]
            num_samples: Number of posterior samples
        
        Returns:
            mean: Predictive mean [batch, output_dim]
            std: Predictive std (epistemic uncertainty) [batch, output_dim]
        """
        self.eval()
        predictions = []
        
        with torch.no_grad():
            for _ in range(num_samples):
                pred = self.forward(x)
                predictions.append(pred)
        
        predictions = torch.stack(predictions)  # [num_samples, batch, output_dim]
        
        mean = predictions.mean(dim=0)
        std = predictions.std(dim=0)
        
        return mean, std


# ============================================================================
# 2. MC Dropout
# ============================================================================

class MCDropoutMLP(nn.Module):
    """
    MLP with Monte Carlo Dropout for uncertainty.
    
    Args:
        input_dim: Input dimension
        hidden_dims: List of hidden dimensions
        output_dim: Output dimension
        dropout_rate: Dropout probability
    """
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.1):
        super(MCDropoutMLP, self).__init__()
        
        layers = []
        dims = [input_dim] + hidden_dims + [output_dim]
        
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            if i < len(dims) - 2:
                layers.append(nn.ReLU())
                layers.append(nn.Dropout(p=dropout_rate))
        
        self.network = nn.Sequential(*layers)
        self.dropout_rate = dropout_rate
    
    def forward(self, x):
        """Forward pass."""
        return self.network(x)
    
    def predict_with_uncertainty(self, x, num_samples=100):
        """
        MC Dropout inference.
        
        Keep dropout active at test time to sample from approximate posterior.
        """
        self.train()  # Enable dropout!
        predictions = []
        
        with torch.no_grad():
            for _ in range(num_samples):
                pred = self.forward(x)
                predictions.append(pred)
        
        predictions = torch.stack(predictions)
        
        mean = predictions.mean(dim=0)
        std = predictions.std(dim=0)
        
        return mean, std


# ============================================================================
# 3. Deep Ensembles
# ============================================================================

class Ensemble(nn.Module):
    """
    Deep ensemble for uncertainty quantification.
    
    Args:
        model_class: Class of base model
        num_models: Number of ensemble members
        **model_kwargs: Arguments for model constructor
    """
    def __init__(self, model_class, num_models, **model_kwargs):
        super(Ensemble, self).__init__()
        
        self.models = nn.ModuleList([
            model_class(**model_kwargs) for _ in range(num_models)
        ])
        self.num_models = num_models
    
    def forward(self, x, model_idx=None):
        """
        Forward pass.
        
        Args:
            x: Input
            model_idx: If specified, use single model. Else, average all.
        """
        if model_idx is not None:
            return self.models[model_idx](x)
        else:
            outputs = [model(x) for model in self.models]
            return torch.stack(outputs).mean(dim=0)
    
    def predict_with_uncertainty(self, x):
        """
        Ensemble prediction with uncertainty.
        
        Returns:
            mean: Average prediction
            std: Ensemble disagreement (uncertainty)
        """
        self.eval()
        predictions = []
        
        with torch.no_grad():
            for model in self.models:
                pred = model(x)
                predictions.append(pred)
        
        predictions = torch.stack(predictions)
        
        mean = predictions.mean(dim=0)
        std = predictions.std(dim=0)
        
        return mean, std
    
    def train_ensemble(self, train_loader, optimizer_class, num_epochs, device='cpu'):
        """
        Train all ensemble members independently.
        
        Different initializations + stochasticity β†’ diverse models
        """
        optimizers = [optimizer_class(model.parameters()) for model in self.models]
        
        for epoch in range(num_epochs):
            for model_idx, (model, optimizer) in enumerate(zip(self.models, optimizers)):
                model.train()
                
                for batch_x, batch_y in train_loader:
                    batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                    
                    optimizer.zero_grad()
                    output = model(batch_x)
                    loss = F.mse_loss(output, batch_y)
                    loss.backward()
                    optimizer.step()
            
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Model losses: "
                      f"{[F.mse_loss(m(batch_x), batch_y).item() for m in self.models]}")


# ============================================================================
# 4. Heteroscedastic Aleatoric Uncertainty
# ============================================================================

class HeteroscedasticMLP(nn.Module):
    """
    MLP that predicts mean AND variance (aleatoric uncertainty).
    
    Output: (ΞΌ(x), σ²(x))
    Loss: -log N(y|ΞΌ(x), σ²(x)) = (y - ΞΌ)Β²/(2σ²) + log Οƒ
    
    Args:
        input_dim: Input dimension
        hidden_dims: Hidden dimensions
        output_dim: Output dimension
    """
    def __init__(self, input_dim, hidden_dims, output_dim):
        super(HeteroscedasticMLP, self).__init__()
        
        # Shared trunk
        trunk_layers = []
        dims = [input_dim] + hidden_dims
        for i in range(len(dims) - 1):
            trunk_layers.append(nn.Linear(dims[i], dims[i+1]))
            trunk_layers.append(nn.ReLU())
        self.trunk = nn.Sequential(*trunk_layers)
        
        # Heads: mean and log_variance
        self.mean_head = nn.Linear(hidden_dims[-1], output_dim)
        self.log_var_head = nn.Linear(hidden_dims[-1], output_dim)
    
    def forward(self, x):
        """
        Returns:
            mean: ΞΌ(x)
            log_var: log σ²(x) (for numerical stability)
        """
        features = self.trunk(x)
        mean = self.mean_head(features)
        log_var = self.log_var_head(features)
        return mean, log_var
    
    def loss(self, x, y):
        """
        Negative log-likelihood loss.
        
        -log N(y|ΞΌ, σ²) = 0.5 * [(y - ΞΌ)Β²/σ² + log σ² + log(2Ο€)]
        """
        mean, log_var = self.forward(x)
        variance = torch.exp(log_var)
        
        # NLL (drop constant term)
        loss = 0.5 * ((y - mean) ** 2 / variance + log_var)
        return loss.mean()
    
    def predict_with_uncertainty(self, x):
        """
        Predict with aleatoric uncertainty.
        
        Returns:
            mean: ΞΌ(x)
            std: Οƒ(x) (aleatoric uncertainty from data noise)
        """
        self.eval()
        with torch.no_grad():
            mean, log_var = self.forward(x)
            std = torch.exp(0.5 * log_var)
        return mean, std


# ============================================================================
# 5. Last-Layer Laplace Approximation
# ============================================================================

class LastLayerLaplace:
    """
    Laplace approximation for last layer only.
    
    Fast post-hoc uncertainty for pretrained models.
    
    Args:
        model: Pretrained neural network
        prior_precision: Prior precision (1/Οƒ_pΒ²)
    """
    def __init__(self, model, prior_precision=1.0):
        self.model = model
        self.prior_precision = prior_precision
        
        # Extract last layer
        self.last_layer = None
        for module in model.modules():
            if isinstance(module, nn.Linear):
                self.last_layer = module
        
        assert self.last_layer is not None, "No linear layer found"
        
        # Hessian approximation
        self.H_inv = None
    
    def fit(self, train_loader, device='cpu'):
        """
        Compute Hessian approximation using Fisher information.
        
        H = Ξ£α΅’ βˆ‡log p(yα΅’|xα΅’) βˆ‡log p(yα΅’|xα΅’)α΅€ + Ξ»I
        
        For regression: H β‰ˆ Xα΅€ X + Ξ»I (Gauss-Newton)
        """
        self.model.eval()
        
        # Collect features and targets
        features_list = []
        
        with torch.no_grad():
            for batch_x, batch_y in train_loader:
                batch_x = batch_x.to(device)
                
                # Forward through leaf modules up to (not including) the last layer;
                # iterating over modules() directly would re-apply container modules
                for module in self.model.modules():
                    if module is self.last_layer:
                        break
                    if len(list(module.children())) == 0:  # skip containers
                        batch_x = module(batch_x)
                
                features_list.append(batch_x)
        
        features = torch.cat(features_list, dim=0)  # [N, hidden_dim]
        
        # Compute H = Xα΅€X + Ξ»I (Gauss-Newton with unit observation noise)
        H = features.T @ features + self.prior_precision * torch.eye(features.size(1))
        
        # Invert (use Cholesky for stability)
        try:
            L = torch.linalg.cholesky(H)
            self.H_inv = torch.cholesky_inverse(L)
        except RuntimeError:
            # Fall back to a direct inverse if Cholesky fails
            self.H_inv = torch.inverse(H)
        
        print(f"Fitted Laplace approximation. Hessian shape: {H.shape}")
    
    def predict_with_uncertainty(self, x, num_samples=100):
        """
        Predictive distribution via linearization.
        
        p(f*|x*, D) β‰ˆ N(f_MAP(x*), βˆ‡f_MAP(x*)α΅€ H⁻¹ βˆ‡f_MAP(x*))
        """
        self.model.eval()
        
        with torch.no_grad():
            # Extract features using leaf modules up to (not including) the last layer
            features = x
            for module in self.model.modules():
                if module is self.last_layer:
                    break
                if len(list(module.children())) == 0:  # skip containers
                    features = module(features)
            
            # MAP prediction
            mean = self.last_layer(features)
            
            # Predictive variance
            J = features  # Jacobian (for linear layer)
            variance = (J @ self.H_inv @ J.T).diagonal().unsqueeze(1)
            std = torch.sqrt(variance)
        
        return mean, std


# ============================================================================
# 6. Calibration Metrics
# ============================================================================

def expected_calibration_error(y_true, y_pred, y_conf, num_bins=10):
    """
    Compute Expected Calibration Error (ECE).
    
    Args:
        y_true: True labels [N]
        y_pred: Predicted labels [N]
        y_conf: Predicted confidence [N]
        num_bins: Number of bins
    
    Returns:
        ece: Expected calibration error
    """
    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    ece = 0.0
    
    for i in range(num_bins):
        # Samples in bin
        mask = (y_conf >= bin_boundaries[i]) & (y_conf < bin_boundaries[i + 1])
        
        if mask.sum() > 0:
            bin_acc = (y_true[mask] == y_pred[mask]).float().mean()
            bin_conf = y_conf[mask].mean()
            bin_size = mask.sum().float()
            
            ece += (bin_size / len(y_true)) * torch.abs(bin_acc - bin_conf)
    
    return ece.item()
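A quick standalone check of the binning logic (the helper and toy tensors below are illustrative, with the last bin closed so conf == 1.0 is counted):

```python
import torch

def ece_sketch(y_true, y_pred, y_conf, num_bins=5):
    # Weighted |accuracy - confidence| gap, averaged over confidence bins
    edges = torch.linspace(0, 1, num_bins + 1)
    ece = torch.zeros(())
    for i in range(num_bins):
        hi = edges[i + 1] if i < num_bins - 1 else 1.0 + 1e-8
        mask = (y_conf >= edges[i]) & (y_conf < hi)
        if mask.sum() > 0:
            acc = (y_true[mask] == y_pred[mask]).float().mean()
            conf = y_conf[mask].mean()
            ece += mask.float().mean() * (acc - conf).abs()
    return ece.item()

# An overconfident classifier: 90% confidence but 50% accuracy -> ECE = 0.4
y_true = torch.tensor([1, 1, 0, 0])
y_pred = torch.tensor([1, 1, 1, 1])
y_conf = torch.tensor([0.9, 0.9, 0.9, 0.9])
print(round(ece_sketch(y_true, y_pred, y_conf), 2))  # 0.4
```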


def plot_reliability_diagram(y_true, y_pred, y_conf, num_bins=10):
    """
    Plot reliability diagram (calibration curve).
    
    Ideal: Points lie on diagonal (confidence = accuracy)
    """
    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    bin_accs = []
    bin_confs = []
    
    for i in range(num_bins):
        upper = bin_boundaries[i + 1] if i < num_bins - 1 else 1.0 + 1e-8
        mask = (y_conf >= bin_boundaries[i]) & (y_conf < upper)
        
        if mask.sum() > 0:
            bin_acc = (y_true[mask] == y_pred[mask]).float().mean().item()
            bin_conf = y_conf[mask].mean().item()
            bin_accs.append(bin_acc)
            bin_confs.append(bin_conf)
        # Empty bins are skipped rather than plotted as zero accuracy
    
    plt.figure(figsize=(6, 6))
    plt.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
    plt.plot(bin_confs, bin_accs, 'o-', label='Model')
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.title('Reliability Diagram')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# ============================================================================
# 7. Demonstrations
# ============================================================================

def demo_bayes_by_backprop():
    """Demonstrate Bayes by Backprop."""
    print("="*70)
    print("Bayes by Backprop Demo")
    print("="*70)
    
    # Create Bayesian MLP
    model = BayesianMLP(input_dim=10, hidden_dims=[20, 20], output_dim=1, prior_sigma=1.0)
    
    # Sample data
    x = torch.randn(32, 10)
    y = torch.randn(32, 1)
    
    # ELBO loss
    output = model(x)
    likelihood = F.mse_loss(output, y, reduction='sum')
    kl = model.kl_divergence()
    
    # Loss = NLL + KL (the negative ELBO; minimizing it maximizes the ELBO)
    loss = likelihood + kl / len(x)  # scale KL per data point (a common heuristic)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Likelihood term: {likelihood.item():.4f}")
    print(f"KL divergence: {kl.item():.4f}")
    print(f"Negative ELBO (loss): {loss.item():.4f}")
    print()
    
    # Uncertainty quantification
    x_test = torch.randn(5, 10)
    mean, std = model.predict_with_uncertainty(x_test, num_samples=100)
    
    print("Predictions with uncertainty:")
    for i in range(5):
        print(f"  Sample {i+1}: {mean[i, 0].item():.4f} Β± {std[i, 0].item():.4f}")
    print()
    
    # Count parameters
    params_mu = sum(p.numel() for n, p in model.named_parameters() if 'mu' in n)
    params_rho = sum(p.numel() for n, p in model.named_parameters() if 'rho' in n)
    print(f"Parameters:")
    print(f"  Mean (ΞΌ): {params_mu:,}")
    print(f"  Variance (ρ): {params_rho:,}")
    print(f"  Total: {params_mu + params_rho:,}")
    print()
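The ΞΌ/ρ split counted above comes from the reparameterization trick. A minimal sketch for a single variational weight vector, assuming the common Οƒ = softplus(ρ) parameterization (the exact parameterization inside `BayesianMLP` may differ):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# One variational weight vector: q(w) = N(mu, sigma^2), sigma = softplus(rho)
mu = torch.zeros(3, requires_grad=True)
rho = torch.full((3,), -3.0, requires_grad=True)

sigma = F.softplus(rho)            # softplus keeps sigma positive
eps = torch.randn_like(mu)
w = mu + sigma * eps               # reparameterized sample; gradients reach mu, rho

# Closed-form KL(q || N(0, I)), summed over weights
kl = (0.5 * (sigma ** 2 + mu ** 2 - 1.0) - torch.log(sigma)).sum()
(kl + w.sum()).backward()          # w.sum() stands in for a likelihood term
assert mu.grad is not None and rho.grad is not None
```

Because `eps` carries all the randomness, the sample `w` is a differentiable function of ΞΌ and ρ, which is what lets the ELBO be optimized by ordinary backprop.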


def demo_mc_dropout():
    """Demonstrate MC Dropout."""
    print("="*70)
    print("MC Dropout Demo")
    print("="*70)
    
    # Create MC Dropout model
    model = MCDropoutMLP(input_dim=10, hidden_dims=[20, 20], output_dim=1, dropout_rate=0.1)
    
    x_test = torch.randn(5, 10)
    
    # Standard prediction (dropout off)
    model.eval()
    with torch.no_grad():
        std_pred = model(x_test)
    
    # MC Dropout prediction (dropout on)
    mean, std = model.predict_with_uncertainty(x_test, num_samples=100)
    
    print("Standard vs. MC Dropout predictions:")
    for i in range(5):
        print(f"  Sample {i+1}: Standard={std_pred[i, 0].item():.4f}, "
              f"MC={mean[i, 0].item():.4f} Β± {std[i, 0].item():.4f}")
    print()
    
    print("Key insight: MC Dropout provides uncertainty at zero training cost")
    print("Interpretation: Each dropout mask ~ sample from approximate posterior")
    print()
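The whole trick fits in a few lines. A standalone sketch (independent of `MCDropoutMLP`; the architecture below is made up): keep dropout stochastic at prediction time and average T forward passes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# .train() leaves dropout ON; .eval() would disable it and collapse the samples
net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(p=0.2), nn.Linear(16, 1))
net.train()

x = torch.randn(3, 4)
with torch.no_grad():
    samples = torch.stack([net(x) for _ in range(100)])  # shape [T, 3, 1]

mean, std = samples.mean(dim=0), samples.std(dim=0)
assert mean.shape == (3, 1) and (std >= 0).all()
```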


def demo_ensemble():
    """Demonstrate deep ensemble."""
    print("="*70)
    print("Deep Ensemble Demo")
    print("="*70)
    
    # Define simple MLP
    class SimpleMLP(nn.Module):
        def __init__(self, input_dim, hidden_dim, output_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim)
            )
        
        def forward(self, x):
            return self.net(x)
    
    # Create ensemble
    ensemble = Ensemble(SimpleMLP, num_models=5, input_dim=10, hidden_dim=20, output_dim=1)
    
    x_test = torch.randn(5, 10)
    
    # Individual model predictions
    print("Individual model predictions:")
    for i in range(ensemble.num_models):
        ensemble.eval()
        with torch.no_grad():
            pred = ensemble.forward(x_test, model_idx=i)
        print(f"  Model {i+1}: {pred[:3, 0].tolist()}")
    print()
    
    # Ensemble prediction
    mean, std = ensemble.predict_with_uncertainty(x_test)
    print("Ensemble prediction:")
    for i in range(3):
        print(f"  Sample {i+1}: {mean[i, 0].item():.4f} Β± {std[i, 0].item():.4f}")
    print()
    
    print("Advantage: Ensemble captures model uncertainty via disagreement")
    print("Cost: Train M models independently")
    print()
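The disagreement-as-uncertainty idea reduces to a stack-and-moment computation. A minimal sketch with hypothetical members (plain linear layers, not the `Ensemble` class above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# M independently initialized models; their spread estimates epistemic uncertainty
models = [nn.Linear(10, 1) for _ in range(5)]

x = torch.randn(3, 10)
with torch.no_grad():
    preds = torch.stack([m(x) for m in models])  # shape [M, 3, 1]

mean, std = preds.mean(dim=0), preds.std(dim=0)  # std measures disagreement
assert mean.shape == (3, 1) and (std > 0).all()
```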


def demo_heteroscedastic():
    """Demonstrate heteroscedastic aleatoric uncertainty."""
    print("="*70)
    print("Heteroscedastic Aleatoric Uncertainty Demo")
    print("="*70)
    
    # Create heteroscedastic model
    model = HeteroscedasticMLP(input_dim=1, hidden_dims=[50, 50], output_dim=1)
    
    # Synthetic data: y = x + noise, noise increases with |x|
    x_train = torch.linspace(-3, 3, 100).unsqueeze(1)
    noise_std = 0.1 + 0.5 * torch.abs(x_train)
    y_train = x_train + noise_std * torch.randn_like(x_train)
    
    # Train
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(500):
        optimizer.zero_grad()
        loss = model.loss(x_train, y_train)
        loss.backward()
        optimizer.step()
        
        if (epoch + 1) % 100 == 0:
            print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")
    print()
    
    # Predict
    x_test = torch.linspace(-4, 4, 50).unsqueeze(1)
    mean, std = model.predict_with_uncertainty(x_test)
    
    print("Predictions with aleatoric uncertainty:")
    for i in range(0, 50, 10):
        print(f"  x={x_test[i, 0].item():.2f}: y={mean[i, 0].item():.4f} Β± {std[i, 0].item():.4f}")
    print()
    
    print("Key insight: Uncertainty increases with |x| (captures heteroscedastic noise)")
    print()
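The loss being minimized here is the heteroscedastic Gaussian NLL. A sketch of what `HeteroscedasticMLP.loss` presumably computes, up to constants (the network outputs log σ²(x) for numerical stability; the tensors below are made up), checked against PyTorch's built-in `gaussian_nll_loss`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# NLL = 0.5 * (log sigma^2 + (y - mu)^2 / sigma^2), averaged over the batch
mu = torch.randn(8, 1)        # predicted mean
log_var = torch.randn(8, 1)   # predicted log sigma^2(x)
y = torch.randn(8, 1)

nll = 0.5 * (log_var + (y - mu) ** 2 / log_var.exp()).mean()

# PyTorch ships the same loss (without the 0.5*log(2*pi) constant when full=False)
ref = F.gaussian_nll_loss(mu, y, log_var.exp(), full=False)
assert torch.allclose(nll, ref, atol=1e-5)
```

The Οƒ² term in the denominator lets the model down-weight residuals in noisy regions, which is exactly why the learned uncertainty grows with |x| above.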


def print_method_comparison():
    """Print comparison of BNN methods."""
    print("="*70)
    print("Bayesian Neural Network Method Comparison")
    print("="*70)
    print()
    
    comparison = """
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Method              β”‚ Training     β”‚ Inference    β”‚ Uncertainty β”‚ Best For     β”‚
β”‚                     β”‚ Cost         β”‚ Cost         β”‚ Type        β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Bayes by Backprop   β”‚ ~2Γ— standard β”‚ TΓ— forward   β”‚ Epistemic   β”‚ Principled   β”‚
β”‚                     β”‚ (sample ΞΈ)   β”‚ (sample ΞΈ)   β”‚             β”‚ Bayesian     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ MC Dropout          β”‚ 1Γ— standard  β”‚ TΓ— forward   β”‚ Epistemic   β”‚ Pretrained   β”‚
β”‚                     β”‚ (just drop.) β”‚ (drop. on)   β”‚ (approx)    β”‚ models       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Deep Ensembles      β”‚ MΓ— standard  β”‚ MΓ— forward   β”‚ Epistemic   β”‚ Best         β”‚
β”‚                     β”‚ (M models)   β”‚              β”‚ (diversity) β”‚ performance  β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Laplace (last lay.) β”‚ 1Γ— standard  β”‚ 1Γ— forward   β”‚ Epistemic   β”‚ Post-hoc,    β”‚
β”‚                     β”‚ + Hessian    β”‚ + covariance β”‚             β”‚ fast         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Heteroscedastic     β”‚ 1Γ— standard  β”‚ 1Γ— forward   β”‚ Aleatoric   β”‚ Noisy data   β”‚
β”‚                     β”‚              β”‚              β”‚             β”‚              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ SGLD                β”‚ 1Γ— standard  β”‚ TΓ— forward   β”‚ Epistemic   β”‚ Research     β”‚
β”‚                     β”‚ + Langevin   β”‚ (MCMC)       β”‚ (asympt.)   β”‚              β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

**Epistemic vs. Aleatoric:**

- **Epistemic**: Model uncertainty (lack of data)
  - Reducible: More data β†’ less uncertainty
  - Captured by: Weight distributions p(ΞΈ|D)
  
- **Aleatoric**: Data noise (inherent randomness)
  - Irreducible: More data doesn't help
  - Captured by: Predictive variance p(y|x, ΞΈ)

**Total uncertainty** = Epistemic + Aleatoric

**Decision guide:**

1. **Use MC Dropout if:**
   - Have pretrained model
   - Need quick uncertainty estimate
   - Limited computational budget

2. **Use Bayes by Backprop if:**
   - Training from scratch
   - Want principled Bayesian inference
   - Can afford ~2Γ— training cost

3. **Use Deep Ensembles if:**
   - Need best empirical performance
   - Can afford MΓ— training (M=5-10)
   - Want diverse predictions

4. **Use Last-Layer Laplace if:**
   - Have pretrained model
   - Want post-hoc uncertainty
   - Need fast inference

5. **Use Heteroscedastic if:**
   - Data has varying noise levels
   - Need aleatoric uncertainty
   - Input-dependent noise

**Illustrative performance (test NLL on a regression benchmark; lower is better):**

- Standard NN: 1.21
- MC Dropout: 0.98
- Bayes by Backprop: 0.91
- Deep Ensemble (M=5): 0.88
"""
    
    print(comparison)
    print()


# ============================================================================
# Run Demonstrations
# ============================================================================

if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    
    demo_bayes_by_backprop()
    demo_mc_dropout()
    demo_ensemble()
    demo_heteroscedastic()
    print_method_comparison()
    
    print("="*70)
    print("Bayesian Neural Networks Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    print("  β€’ Bayes by Backprop: Variational inference with weight distributions")
    print("  β€’ MC Dropout: Dropout as Bayesian approximation (zero training cost)")
    print("  β€’ Deep Ensembles: M independent models for diversity")
    print("  β€’ Heteroscedastic: Model aleatoric uncertainty σ²(x)")
    print("  β€’ Last-Layer Laplace: Post-hoc Gaussian approximation")
    print()
    print("Key insight: BNNs provide uncertainty quantification")
    print("Applications: Active learning, OOD detection, safety-critical systems")
    print("Trade-off: Computational cost vs. uncertainty quality")
    print()