import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.special import digamma, gammaln
import torch
import torch.nn as nn
import torch.optim as optim

sns.set_style('whitegrid')
np.random.seed(42)
torch.manual_seed(42)

1. Motivation: The Intractability Problem

Bayesian Inference Goal

Given data \(x\) and latent variables \(z\), we want the posterior:

\[p(z | x) = \frac{p(x, z)}{p(x)} = \frac{p(x | z)\, p(z)}{\int p(x, z)\, dz}\]

The Problem

The marginal likelihood (evidence) is intractable:

\[p(x) = \int p(x, z)\, dz\]

In general this integral has no closed form, and numerical approximation becomes exponentially expensive as the dimension of \(z\) grows.

Classical Solutions

  1. MCMC (Markov Chain Monte Carlo)

    • Sampling-based

    • Slow, hard to diagnose convergence

  2. Variational Inference (VI)

    • Optimization-based

    • Fast, scalable

    • Approximate but deterministic

Variational Inference Idea

  1. Choose a family of distributions \(\mathcal{Q}\)

  2. Find \(q^*(z) \in \mathcal{Q}\) closest to \(p(z | x)\)

  3. Use \(q^*(z)\) as a proxy for the true posterior

Key insight: Turn inference into optimization!

2. Evidence Lower Bound (ELBO)

KL Divergence Minimization

Objective:

\[q^*(z) = \arg\min_{q \in \mathcal{Q}} KL(q(z) \| p(z | x))\]

Derivation

\[\begin{aligned}
KL(q(z) \| p(z | x)) &= \mathbb{E}_q[\log q(z)] - \mathbb{E}_q[\log p(z | x)] \\
&= \mathbb{E}_q[\log q(z)] - \mathbb{E}_q[\log p(x, z)] + \log p(x)
\end{aligned}\]

Since \(\log p(x)\) is constant w.r.t. \(q\):

\[\min_q KL(q \| p) \iff \max_q \underbrace{\mathbb{E}_q[\log p(x, z)] - \mathbb{E}_q[\log q(z)]}_{\text{ELBO}(q)}\]

Evidence Lower Bound

\[\mathcal{L}(q) = \mathbb{E}_q[\log p(x, z)] - \mathbb{E}_q[\log q(z)]\]

Key property:

\[\log p(x) = \mathcal{L}(q) + KL(q \| p)\]

Since \(KL \geq 0\):

\[\mathcal{L}(q) \leq \log p(x)\]
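When the evidence is tractable, the bound can be verified directly. A minimal sketch, assuming the conjugate toy model \(z \sim \mathcal{N}(0, 1)\), \(x | z \sim \mathcal{N}(z, 1)\), for which \(p(x) = \mathcal{N}(x; 0, 2)\) and the ELBO of a Gaussian \(q(z) = \mathcal{N}(m, s^2)\) has a closed form:

```python
import numpy as np

def elbo(m, s2, x):
    """Closed-form ELBO for z ~ N(0,1), x|z ~ N(z,1), q(z) = N(m, s2)."""
    e_log_lik = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m)**2 + s2)  # E_q[log p(x|z)]
    e_log_prior = -0.5 * np.log(2 * np.pi) - 0.5 * (m**2 + s2)      # E_q[log p(z)]
    entropy = 0.5 * np.log(2 * np.pi * np.e * s2)                   # -E_q[log q(z)]
    return e_log_lik + e_log_prior + entropy

x = 1.5
log_evidence = -0.5 * np.log(2 * np.pi * 2) - x**2 / 4  # log N(x; 0, 2)

# The bound holds for any (m, s2); equality at the exact posterior N(x/2, 1/2)
assert elbo(0.0, 1.0, x) <= log_evidence
assert abs(elbo(x / 2, 0.5, x) - log_evidence) < 1e-12
```

Equality holds exactly when \(q\) is the true posterior \(\mathcal{N}(x/2, 1/2)\), since then \(KL(q \| p(z|x)) = 0\).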

Alternative Forms

\[\mathcal{L}(q) = \mathbb{E}_q[\log p(x | z)] - KL(q(z) \| p(z))\]

Interpretation:

  • Reconstruction: How well does \(z\) explain \(x\)?

  • Regularization: Stay close to prior \(p(z)\)
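Both forms share the reconstruction term, so they agree exactly when the remaining terms match: \(\mathbb{E}_q[\log p(z)] + \mathbb{H}[q] = -KL(q(z) \| p(z))\). A quick numerical check, assuming a Gaussian \(q(z) = \mathcal{N}(m, s^2)\) and prior \(p(z) = \mathcal{N}(0, \tau^2)\), both of which give closed forms:

```python
import numpy as np

m, s2, tau2 = 0.7, 0.4, 2.0  # arbitrary variational and prior parameters

# First form's non-likelihood terms: E_q[log p(z)] + entropy of q
e_log_prior = -0.5 * np.log(2 * np.pi * tau2) - (m**2 + s2) / (2 * tau2)
entropy = 0.5 * np.log(2 * np.pi * np.e * s2)

# Second form's non-likelihood term: -KL(N(m, s2) || N(0, tau2))
kl = 0.5 * (np.log(tau2 / s2) + (s2 + m**2) / tau2 - 1)

# E_q[log p(x|z)] appears in both forms, so the rest must be identical
assert abs((e_log_prior + entropy) - (-kl)) < 1e-12
```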

3. Mean-Field Approximation

Variational Family

Assume the latent variables factorize:

\[q(z) = \prod_{i=1}^m q_i(z_i)\]

Each \(q_i\) is called a variational factor.

Coordinate Ascent Variational Inference (CAVI)

Optimize each factor holding others fixed:

\[\log q_j^*(z_j) = \mathbb{E}_{-j}[\log p(x, z)] + \text{const}\]

where \(\mathbb{E}_{-j}\) means expectation over all \(z_i\) except \(z_j\).

Key Result

\[q_j^*(z_j) \propto \exp\left(\mathbb{E}_{-j}[\log p(x, z)]\right)\]

If the complete conditional \(p(z_j | z_{-j}, x)\) lies in an exponential family, this update keeps \(q_j^*\) in the same family; this is what makes CAVI tractable for conditionally conjugate models.
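To make the update concrete, here is a minimal sketch of CAVI (separate from the mixture example in Section 4), assuming a bivariate Gaussian target \(p(z) = \mathcal{N}(\mu, \Sigma)\) with mean-field family \(q(z_1)q(z_2)\). Each update sets \(q_j\) to a Gaussian with precision \(\Lambda_{jj}\) and the conditional mean evaluated at \(\mathbb{E}[z_{-j}]\); the fixed point recovers the true means but underestimates the marginal variances, a well-known property of mean-field VI:

```python
import numpy as np

# Target: p(z) = N(mu, Sigma) with correlated components
mu = np.array([1.0, -1.0])
Sigma = np.array([[1.0, 0.8],
                  [0.8, 1.0]])
Lam = np.linalg.inv(Sigma)  # precision matrix

m = np.zeros(2)          # variational means, initialized away from mu
v = 1.0 / np.diag(Lam)   # CAVI fixes each variance at the conditional variance 1/Lam_jj

for _ in range(50):
    # q_1 update: exp(E_{-1}[log p]) is Gaussian with precision Lam_11
    m[0] = mu[0] - Lam[0, 1] / Lam[0, 0] * (m[1] - mu[1])
    # q_2 update, using the freshly updated m[0]
    m[1] = mu[1] - Lam[1, 0] / Lam[1, 1] * (m[0] - mu[0])

print(np.round(m, 6))      # converges to the true means [1, -1]
print(v, np.diag(Sigma))   # 0.36 < 1.0: mean-field underestimates marginal variance
```

The updates contract geometrically at rate \(\rho^2 = 0.64\) here, so 50 sweeps are far more than enough for convergence.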

4. Example: Bayesian Gaussian Mixture (Simple Case)

Model

Data: \(x_1, \ldots, x_n \in \mathbb{R}\)

Generative process:

  1. Choose cluster \(z_i \sim \text{Categorical}(\pi)\)

  2. Generate \(x_i \sim \mathcal{N}(\mu_{z_i}, \sigma^2)\)

Priors:

  • \(\pi \sim \text{Dirichlet}(\alpha)\)

  • \(\mu_k \sim \mathcal{N}(0, \tau^2)\) for \(k = 1, 2\)

Variational Approximation

\[q(z, \pi, \mu) = q(\pi) \prod_{k=1}^2 q(\mu_k) \prod_{i=1}^n q(z_i)\]

  • \(q(\pi) = \text{Dirichlet}(\alpha')\)

  • \(q(\mu_k) = \mathcal{N}(m_k, s_k^2)\)

  • \(q(z_i) = \text{Categorical}(\phi_i)\)

class BayesianGMMVariational:
    """Variational inference for 2-component Gaussian mixture."""
    def __init__(self, n_components=2, alpha_prior=1.0, tau_prior=10.0, sigma=1.0):
        self.K = n_components
        self.alpha_prior = alpha_prior
        self.tau_prior = tau_prior
        self.sigma = sigma
        
    def fit(self, X, max_iter=100, tol=1e-4):
        n = len(X)
        
        # Initialize variational parameters
        self.alpha = np.ones(self.K) * self.alpha_prior  # Dirichlet parameter
        self.m = np.random.randn(self.K)  # Mean of q(mu_k)
        self.s2 = np.ones(self.K)  # Variance of q(mu_k)
        self.phi = np.random.dirichlet(np.ones(self.K), n)  # q(z_i)
        
        elbo_history = []
        
        for iteration in range(max_iter):
            # Update q(z_i) - responsibilities
            log_pi = digamma(self.alpha) - digamma(self.alpha.sum())
            
            for i in range(n):
                for k in range(self.K):
                    # Expected log likelihood
                    log_lik = -0.5 * np.log(2 * np.pi * self.sigma**2)
                    log_lik -= 0.5 / self.sigma**2 * (X[i]**2 - 2*X[i]*self.m[k] + self.m[k]**2 + self.s2[k])
                    self.phi[i, k] = log_pi[k] + log_lik
                
                # Normalize
                self.phi[i] -= self.phi[i].max()  # Numerical stability
                self.phi[i] = np.exp(self.phi[i])
                self.phi[i] /= self.phi[i].sum()
            
            # Update q(mu_k)
            N_k = self.phi.sum(axis=0)
            for k in range(self.K):
                self.s2[k] = 1.0 / (1.0/self.tau_prior**2 + N_k[k]/self.sigma**2)
                self.m[k] = self.s2[k] * (self.phi[:, k] @ X) / self.sigma**2
            
            # Update q(pi)
            self.alpha = self.alpha_prior + N_k
            
            # Compute ELBO
            elbo = self._compute_elbo(X)
            elbo_history.append(elbo)
            
            if iteration > 0 and abs(elbo - elbo_history[-2]) < tol:
                print(f"Converged at iteration {iteration}")
                break
        
        self.elbo_history = elbo_history
        return self
    
    def _compute_elbo(self, X):
        """Compute evidence lower bound."""
        n = len(X)
        elbo = 0
        
        # E[log p(X|z, mu)]
        for i in range(n):
            for k in range(self.K):
                log_lik = -0.5 * np.log(2 * np.pi * self.sigma**2)
                log_lik -= 0.5 / self.sigma**2 * (X[i]**2 - 2*X[i]*self.m[k] + self.m[k]**2 + self.s2[k])
                elbo += self.phi[i, k] * log_lik
        
        # E[log p(z|pi)]
        log_pi = digamma(self.alpha) - digamma(self.alpha.sum())
        elbo += (self.phi * log_pi).sum()
        
        # E[log p(pi)] - E[log q(pi)]
        elbo += gammaln(self.K * self.alpha_prior) - self.K * gammaln(self.alpha_prior)
        elbo += ((self.alpha_prior - 1) * log_pi).sum()
        elbo -= gammaln(self.alpha.sum()) - gammaln(self.alpha).sum()
        elbo -= ((self.alpha - 1) * log_pi).sum()
        
        # E[log p(mu)] - E[log q(mu)]
        for k in range(self.K):
            elbo -= 0.5 * (np.log(2*np.pi*self.tau_prior**2) + (self.m[k]**2 + self.s2[k])/self.tau_prior**2)
            elbo += 0.5 * (1 + np.log(2*np.pi*self.s2[k]))
        
        # -E[log q(z)]
        elbo -= (self.phi * np.log(self.phi + 1e-10)).sum()
        
        return elbo

# Generate data from 2-component mixture
np.random.seed(42)
n_samples = 200
true_mu = np.array([-3, 3])
true_pi = np.array([0.3, 0.7])

z_true = np.random.choice(2, n_samples, p=true_pi)
X = true_mu[z_true] + np.random.randn(n_samples)

# Fit variational model
model = BayesianGMMVariational(n_components=2, sigma=1.0)
model.fit(X, max_iter=100)

print(f"\nTrue means: {true_mu}")
print(f"Estimated means: {model.m}")
print(f"True mixing: {true_pi}")
print(f"Estimated mixing: {model.alpha / model.alpha.sum()}")
# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# ELBO convergence
axes[0].plot(model.elbo_history, linewidth=2)
axes[0].set_xlabel('Iteration', fontsize=12)
axes[0].set_ylabel('ELBO', fontsize=12)
axes[0].set_title('ELBO Convergence', fontsize=13)
axes[0].grid(True, alpha=0.3)

# Data and fitted components
axes[1].hist(X, bins=30, density=True, alpha=0.5, label='Data')
x_range = np.linspace(X.min(), X.max(), 200)
for k in range(2):
    weight = model.alpha[k] / model.alpha.sum()
    component = weight * stats.norm.pdf(x_range, model.m[k], np.sqrt(model.s2[k] + model.sigma**2))
    axes[1].plot(x_range, component, linewidth=2, label=f'Component {k+1}')
axes[1].set_xlabel('x', fontsize=12)
axes[1].set_ylabel('Density', fontsize=12)
axes[1].set_title('Fitted Mixture', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Responsibilities
sort_idx = np.argsort(X)
axes[2].scatter(X[sort_idx], model.phi[sort_idx, 0], alpha=0.6, s=20, label='Resp. Component 1')
axes[2].scatter(X[sort_idx], model.phi[sort_idx, 1], alpha=0.6, s=20, label='Resp. Component 2')
axes[2].set_xlabel('x', fontsize=12)
axes[2].set_ylabel('Responsibility', fontsize=12)
axes[2].set_title('Variational Responsibilities', fontsize=13)
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

5. Black Box Variational Inference (BBVI)

Limitation of CAVI

  • Requires conjugacy

  • Model-specific derivations

  • Not scalable to complex models

Solution: Gradient-Based VI

Parameterize \(q_{\lambda}(z)\) and optimize:

\[\max_{\lambda} \mathcal{L}(\lambda) = \mathbb{E}_{q_{\lambda}}[\log p(x, z) - \log q_{\lambda}(z)]\]

Gradient Estimation

Score function (REINFORCE):

\[\nabla_{\lambda} \mathcal{L} = \mathbb{E}_{q_{\lambda}}[(\log p(x, z) - \log q_{\lambda}(z)) \nabla_{\lambda} \log q_{\lambda}(z)]\]

Reparameterization trick (when applicable):

\[z = g(\epsilon, \lambda), \quad \epsilon \sim p(\epsilon)\]
\[\nabla_{\lambda} \mathcal{L} = \mathbb{E}_{p(\epsilon)}[\nabla_{\lambda}(\log p(x, g(\epsilon, \lambda)) - \log q_{\lambda}(g(\epsilon, \lambda)))]\]
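The two estimators target the same gradient but differ sharply in variance, which is why the reparameterization trick is preferred when \(q\) admits one. A toy sketch, assuming the objective \(\mathbb{E}_{\mathcal{N}(\lambda, 1)}[-z^2]\), whose exact gradient is \(-2\lambda\):

```python
import numpy as np

rng = np.random.default_rng(0)
lam, n = 1.0, 100_000

# Score function: f(z) * d/dlam log q(z) = -z^2 * (z - lam), with z ~ N(lam, 1)
z = rng.normal(lam, 1.0, n)
grad_score = -z**2 * (z - lam)

# Reparameterization: z = lam + eps, so d/dlam f(lam + eps) = -2 * (lam + eps)
eps = rng.normal(0.0, 1.0, n)
grad_reparam = -2 * (lam + eps)

print(grad_score.mean(), grad_reparam.mean())  # both estimate -2 * lam = -2
print(grad_score.var(), grad_reparam.var())    # score variance is much larger
```

In practice the score-function estimator is usually paired with variance-reduction techniques (baselines, control variates) to be usable.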

class BBVIGaussian(nn.Module):
    """Black box VI with Gaussian approximation using reparameterization."""
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.mu = nn.Parameter(torch.randn(latent_dim))
        self.log_sigma = nn.Parameter(torch.zeros(latent_dim))
    
    def sample(self, n_samples=1):
        """Reparameterization: z = mu + sigma * epsilon."""
        epsilon = torch.randn(n_samples, self.latent_dim)
        return self.mu + torch.exp(self.log_sigma) * epsilon
    
    def log_prob(self, z):
        """Log probability of q(z)."""
        sigma = torch.exp(self.log_sigma)
        log_prob = -0.5 * torch.sum(
            torch.log(2 * np.pi * sigma**2) + ((z - self.mu) / sigma)**2,
            dim=-1
        )
        return log_prob

def log_joint(z, X):
    """Log p(X, z) for simple Bayesian linear regression.
    
    Model: y = z^T x + noise, z ~ N(0, I)
    """
    # Prior: p(z) = N(0, I)
    log_prior = -0.5 * torch.sum(z**2, dim=-1)
    
    # Likelihood: p(X|z) for linear regression with unit observation noise
    # (additive constants are dropped; they do not affect the gradient)
    # Here X is a (features, targets) pair
    features, targets = X
    predictions = features @ z.T  # shape (n_data, n_samples)
    log_lik = -0.5 * torch.sum((targets.unsqueeze(-1) - predictions)**2, dim=0)
    
    return log_prior + log_lik

# Generate regression data
torch.manual_seed(42)
n_data = 100
d = 5
X_train = torch.randn(n_data, d)
true_w = torch.randn(d)
y_train = X_train @ true_w + 0.5 * torch.randn(n_data)

# BBVI
q = BBVIGaussian(d)
optimizer = optim.Adam(q.parameters(), lr=0.01)

elbo_list = []
for epoch in range(1000):
    optimizer.zero_grad()
    
    # Sample from q
    z_samples = q.sample(n_samples=10)
    
    # ELBO = E[log p(X, z)] - E[log q(z)]
    log_joint_vals = log_joint(z_samples, (X_train, y_train))
    log_q_vals = q.log_prob(z_samples)
    elbo = (log_joint_vals - log_q_vals).mean()
    
    # Maximize ELBO (minimize negative ELBO)
    loss = -elbo
    loss.backward()
    optimizer.step()
    
    elbo_list.append(elbo.item())
    
    if epoch % 200 == 0:
        print(f"Epoch {epoch}: ELBO = {elbo.item():.2f}")

print(f"\nTrue weights: {true_w.numpy()}")
print(f"Estimated weights (mean): {q.mu.detach().numpy()}")
print(f"Estimated uncertainty (std): {torch.exp(q.log_sigma).detach().numpy()}")
# Plot BBVI convergence
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(elbo_list, linewidth=2)
axes[0].set_xlabel('Iteration', fontsize=12)
axes[0].set_ylabel('ELBO', fontsize=12)
axes[0].set_title('Black Box VI Convergence', fontsize=13)
axes[0].grid(True, alpha=0.3)

# Compare true vs estimated weights
x_pos = np.arange(d)
axes[1].bar(x_pos - 0.2, true_w.numpy(), 0.4, label='True', alpha=0.7)
axes[1].bar(x_pos + 0.2, q.mu.detach().numpy(), 0.4, label='Estimated', alpha=0.7)
axes[1].errorbar(x_pos + 0.2, q.mu.detach().numpy(), 
                 yerr=torch.exp(q.log_sigma).detach().numpy(),
                 fmt='none', ecolor='black', capsize=5)
axes[1].set_xlabel('Weight Index', fontsize=12)
axes[1].set_ylabel('Value', fontsize=12)
axes[1].set_title('Weight Comparison', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

Summary

Key Concepts:

  1. Variational Inference: Approximate intractable posteriors via optimization

  2. ELBO: Evidence lower bound, objective function for VI

  3. Mean-field: Assume factorized posterior for tractability

  4. CAVI: Coordinate ascent for conjugate models

  5. BBVI: Gradient-based VI for general models

ELBO Forms:

\[\mathcal{L} = \mathbb{E}_q[\log p(x, z)] - \mathbb{E}_q[\log q(z)]\]
\[= \mathbb{E}_q[\log p(x | z)] - KL(q(z) \| p(z))\]
\[= \log p(x) - KL(q(z) \| p(z | x))\]

Advantages of VI:

  • Fast: Optimization vs sampling

  • Scalable: Stochastic gradients for big data

  • Deterministic: No MCMC diagnostics

  • Flexible: Any differentiable model

Limitations:

  • Approximation: Not exact (underestimates uncertainty)

  • Local optima: Optimization challenges

  • Family choice: Restricted to \(\mathcal{Q}\)

Modern Extensions:

  • Normalizing flows: Expressive variational families

  • Importance weighted ELBO: Tighter bounds

  • Amortized inference: Neural networks for \(q\)

  • Stochastic VI: Mini-batch optimization

Applications:

  • Variational autoencoders (VAEs)

  • Bayesian neural networks

  • Topic models (LDA)

  • Gaussian processes

Further Reading:

  • Blei et al. (2017) - "Variational Inference: A Review for Statisticians"

  • Jordan et al. (1999) - "An Introduction to Variational Methods for Graphical Models"

  • Kingma & Welling (2014) - "Auto-Encoding Variational Bayes"

Next Steps:

  • 03_variational_autoencoders_advanced.ipynb - VAE with ELBO

  • 04_pac_bayes_theory.ipynb - PAC-Bayesian framework

  • 09_expectation_maximization.ipynb - EM algorithm