import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.special import digamma, gammaln
import torch
import torch.nn as nn
import torch.optim as optim
# Global configuration: plot style and fixed RNG seeds so every run of this
# notebook produces the same data, initializations, and figures.
sns.set_style('whitegrid')
np.random.seed(42)
torch.manual_seed(42)
1. Motivation: The Intractability Problem
Bayesian Inference Goal
Given data \(x\) and latent variables \(z\), we want the posterior: $\(p(z | x) = \frac{p(x, z)}{p(x)} = \frac{p(x | z) p(z)}{\int p(x, z) dz}\)$
The Problem
The marginal likelihood (evidence) is intractable: $\(p(x) = \int p(x, z) dz\)$
This integral is exponentially hard in high dimensions!
Classical Solutions
MCMC (Markov Chain Monte Carlo)
Sampling-based
Slow, hard to diagnose convergence
Variational Inference (VI)
Optimization-based
Fast, scalable
Approximate but deterministic
Variational Inference Idea
Choose a family of distributions \(\mathcal{Q}\)
Find \(q^*(z) \in \mathcal{Q}\) closest to \(p(z | x)\)
Use \(q^*(z)\) as a proxy for the true posterior
Key insight: Turn inference into optimization!
2. Evidence Lower Bound (ELBO)
KL Divergence Minimization
Objective: $\(q^*(z) = \arg\min_{q \in \mathcal{Q}} KL(q(z) \| p(z | x))\)$
Derivation
Since \(\log p(x)\) is constant w.r.t. \(q\): $\(\min_q KL(q \| p) \iff \max_q \underbrace{\mathbb{E}_q[\log p(x, z)] - \mathbb{E}_q[\log q(z)]}_{\text{ELBO}(q)}\)$
Evidence Lower Bound
Key property: $\(\log p(x) = \mathcal{L}(q) + KL(q \| p)\)$
Since \(KL \geq 0\): $\(\mathcal{L}(q) \leq \log p(x)\)$
Alternative Forms
Interpretation:
Reconstruction: How well does \(z\) explain \(x\)?
Regularization: Stay close to prior \(p(z)\)
3. Mean-Field Approximation
Variational Family
Assume latent variables factorize: $\(q(z) = \prod_{i=1}^m q_i(z_i)\)$
Each \(q_i\) is called a variational factor.
Coordinate Ascent Variational Inference (CAVI)
Optimize each factor holding others fixed:
where \(\mathbb{E}_{-j}\) means expectation over all \(z_i\) except \(z_j\).
Key Result
This is the exponential family form!
4. Example: Bayesian Gaussian Mixture (Simple Case)
Model
Data: \(x_1, \ldots, x_n \in \mathbb{R}\)
Generative process:
Choose cluster \(z_i \sim \text{Categorical}(\pi)\)
Generate \(x_i \sim \mathcal{N}(\mu_{z_i}, \sigma^2)\)
Priors:
\(\pi \sim \text{Dirichlet}(\alpha)\)
\(\mu_k \sim \mathcal{N}(0, \tau^2)\) for \(k = 1, 2\)
Variational Approximation
\(q(\pi) = \text{Dirichlet}(\alpha')\)
\(q(\mu_k) = \mathcal{N}(m_k, s_k^2)\)
\(q(z_i) = \text{Categorical}(\phi_i)\)
class BayesianGMMVariational:
    """Mean-field (CAVI) variational inference for a K-component
    univariate Gaussian mixture with known, shared noise variance.

    Model:
        pi ~ Dirichlet(alpha_prior * 1_K)
        mu_k ~ N(0, tau_prior^2)           for each component k
        z_i ~ Categorical(pi),  x_i | z_i ~ N(mu_{z_i}, sigma^2)

    Variational family (fully factorized):
        q(pi)   = Dirichlet(alpha)
        q(mu_k) = N(m_k, s2_k)
        q(z_i)  = Categorical(phi_i)
    """

    def __init__(self, n_components=2, alpha_prior=1.0, tau_prior=10.0, sigma=1.0):
        self.K = n_components
        self.alpha_prior = alpha_prior  # symmetric Dirichlet concentration
        self.tau_prior = tau_prior      # prior std of the component means
        self.sigma = sigma              # known observation-noise std

    def _expected_log_lik(self, X):
        """E_q[log N(x_i | mu_k, sigma^2)] as an (n, K) array.

        Uses E[mu_k] = m_k and E[mu_k^2] = m_k^2 + s2_k, so the quadratic
        term expands to x^2 - 2 x m_k + m_k^2 + s2_k.
        """
        return (-0.5 * np.log(2 * np.pi * self.sigma**2)
                - 0.5 / self.sigma**2
                * (X[:, None]**2 - 2 * X[:, None] * self.m
                   + self.m**2 + self.s2))

    def fit(self, X, max_iter=100, tol=1e-4):
        """Run coordinate-ascent updates until the ELBO change is < tol.

        Parameters
        ----------
        X : array-like of shape (n,)
            Observed 1-D data.
        max_iter : int
            Maximum number of coordinate-ascent sweeps.
        tol : float
            Absolute ELBO-change threshold for declaring convergence.

        Returns
        -------
        self, with fitted attributes ``alpha``, ``m``, ``s2``, ``phi``
        and the per-sweep ``elbo_history``.
        """
        X = np.asarray(X, dtype=float)
        n = len(X)
        # Initialize variational parameters (same RNG calls as before).
        self.alpha = np.ones(self.K) * self.alpha_prior     # q(pi) Dirichlet params
        self.m = np.random.randn(self.K)                    # q(mu_k) means
        self.s2 = np.ones(self.K)                           # q(mu_k) variances
        self.phi = np.random.dirichlet(np.ones(self.K), n)  # responsibilities q(z_i)

        elbo_history = []
        for iteration in range(max_iter):
            # --- Update q(z): responsibilities, vectorized over (i, k) ---
            log_pi = digamma(self.alpha) - digamma(self.alpha.sum())
            log_phi = log_pi + self._expected_log_lik(X)
            log_phi -= log_phi.max(axis=1, keepdims=True)   # numerical stability
            self.phi = np.exp(log_phi)
            self.phi /= self.phi.sum(axis=1, keepdims=True)

            # --- Update q(mu_k): Gaussian given the soft assignments ---
            N_k = self.phi.sum(axis=0)                      # expected counts
            self.s2 = 1.0 / (1.0 / self.tau_prior**2 + N_k / self.sigma**2)
            self.m = self.s2 * (self.phi.T @ X) / self.sigma**2

            # --- Update q(pi): Dirichlet with pseudo-counts N_k ---
            self.alpha = self.alpha_prior + N_k

            # --- Track ELBO; CAVI guarantees it is non-decreasing ---
            elbo = self._compute_elbo(X)
            elbo_history.append(elbo)
            if iteration > 0 and abs(elbo - elbo_history[-2]) < tol:
                print(f"Converged at iteration {iteration}")
                break
        self.elbo_history = elbo_history
        return self

    def _compute_elbo(self, X):
        """Compute the evidence lower bound under the current parameters."""
        # E[log p(X | z, mu)] — expected Gaussian log-likelihood
        elbo = (self.phi * self._expected_log_lik(X)).sum()
        # E[log p(z | pi)]
        log_pi = digamma(self.alpha) - digamma(self.alpha.sum())
        elbo += (self.phi * log_pi).sum()
        # E[log p(pi)] - E[log q(pi)]  (symmetric Dirichlet prior normalizer)
        elbo += gammaln(self.K * self.alpha_prior) - self.K * gammaln(self.alpha_prior)
        elbo += ((self.alpha_prior - 1) * log_pi).sum()
        elbo -= gammaln(self.alpha.sum()) - gammaln(self.alpha).sum()
        elbo -= ((self.alpha - 1) * log_pi).sum()
        # E[log p(mu)] - E[log q(mu)]  (Gaussian cross-entropy + entropy)
        elbo -= 0.5 * (np.log(2 * np.pi * self.tau_prior**2)
                       + (self.m**2 + self.s2) / self.tau_prior**2).sum()
        elbo += 0.5 * (1 + np.log(2 * np.pi * self.s2)).sum()
        # -E[log q(z)]  (responsibility entropy; epsilon guards log 0)
        elbo -= (self.phi * np.log(self.phi + 1e-10)).sum()
        return elbo
# --- Synthetic dataset: 200 draws from a 2-component Gaussian mixture ---
np.random.seed(42)
n_samples = 200
true_mu = np.array([-3, 3])
true_pi = np.array([0.3, 0.7])
# Sample the cluster labels first, then add unit-variance observation noise.
z_true = np.random.choice(2, n_samples, p=true_pi)
X = np.random.randn(n_samples) + true_mu[z_true]

# --- Fit the mean-field approximation with CAVI ---
model = BayesianGMMVariational(n_components=2, sigma=1.0)
model.fit(X, max_iter=100)

# --- Compare variational estimates against the generating parameters ---
print(f"\nTrue means: {true_mu}")
print(f"Estimated means: {model.m}")
print(f"True mixing: {true_pi}")
print(f"Estimated mixing: {model.alpha / model.alpha.sum()}")
# --- Visualize CAVI results: ELBO trace, fitted density, responsibilities ---
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
ax_elbo, ax_fit, ax_resp = axes

# ELBO trace over coordinate-ascent sweeps.
ax_elbo.plot(model.elbo_history, linewidth=2)
ax_elbo.set_xlabel('Iteration', fontsize=12)
ax_elbo.set_ylabel('ELBO', fontsize=12)
ax_elbo.set_title('ELBO Convergence', fontsize=13)
ax_elbo.grid(True, alpha=0.3)

# Data histogram with the fitted mixture components overlaid.
ax_fit.hist(X, bins=30, density=True, alpha=0.5, label='Data')
x_range = np.linspace(X.min(), X.max(), 200)
weights = model.alpha / model.alpha.sum()
for k in range(2):
    # Predictive density of component k: N(m_k, s2_k + sigma^2), weighted.
    component = weights[k] * stats.norm.pdf(
        x_range, model.m[k], np.sqrt(model.s2[k] + model.sigma**2))
    ax_fit.plot(x_range, component, linewidth=2, label=f'Component {k+1}')
ax_fit.set_xlabel('x', fontsize=12)
ax_fit.set_ylabel('Density', fontsize=12)
ax_fit.set_title('Fitted Mixture', fontsize=13)
ax_fit.legend()
ax_fit.grid(True, alpha=0.3)

# Responsibilities q(z_i) plotted against the (sorted) data.
sort_idx = np.argsort(X)
for k in range(2):
    ax_resp.scatter(X[sort_idx], model.phi[sort_idx, k],
                    alpha=0.6, s=20, label=f'Resp. Component {k+1}')
ax_resp.set_xlabel('x', fontsize=12)
ax_resp.set_ylabel('Responsibility', fontsize=12)
ax_resp.set_title('Variational Responsibilities', fontsize=13)
ax_resp.legend()
ax_resp.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
5. Black Box Variational Inference (BBVI)
Limitation of CAVI
Requires conjugacy
Model-specific derivations
Not scalable to complex models
Solution: Gradient-Based VI
Parameterize \(q_{\lambda}(z)\) and optimize: $\(\max_{\lambda} \mathcal{L}(\lambda) = \mathbb{E}_{q_{\lambda}}[\log p(x, z) - \log q_{\lambda}(z)]\)$
Gradient Estimation
Score function (REINFORCE): $\(\nabla_{\lambda} \mathcal{L} = \mathbb{E}_{q_{\lambda}}[(\log p(x, z) - \log q_{\lambda}(z)) \nabla_{\lambda} \log q_{\lambda}(z)]\)$
Reparameterization trick (when applicable): $\(z = g(\epsilon, \lambda), \quad \epsilon \sim p(\epsilon)\)$ and $\(\nabla_{\lambda} \mathcal{L} = \mathbb{E}_{p(\epsilon)}[\nabla_{\lambda}(\log p(x, g(\epsilon, \lambda)) - \log q_{\lambda}(g(\epsilon, \lambda)))]\)$
class BBVIGaussian(nn.Module):
    """Diagonal-Gaussian variational family q(z) = N(mu, diag(sigma^2)).

    Samples are drawn with the reparameterization trick so they stay
    differentiable with respect to the variational parameters.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        # Variational parameters: mean and log standard deviation per dim.
        self.mu = nn.Parameter(torch.randn(latent_dim))
        self.log_sigma = nn.Parameter(torch.zeros(latent_dim))

    def sample(self, n_samples=1):
        """Reparameterization: z = mu + sigma * epsilon."""
        noise = torch.randn(n_samples, self.latent_dim)
        scale = torch.exp(self.log_sigma)
        return self.mu + scale * noise

    def log_prob(self, z):
        """Log probability of q(z)."""
        scale = torch.exp(self.log_sigma)
        standardized = (z - self.mu) / scale
        per_dim = torch.log(2 * np.pi * scale**2) + standardized**2
        return -0.5 * per_dim.sum(dim=-1)
def log_joint(z, X):
    """Unnormalized log p(X, z) for Bayesian linear regression.

    Model: y = z^T x + noise, z ~ N(0, I)

    Parameters
    ----------
    z : tensor of shape (n_samples, d)
        Candidate weight vectors.
    X : tuple (features, targets)
        ``features`` is the (n_data, d) design matrix and ``targets``
        the (n_data,) response vector.

    Returns
    -------
    Tensor of shape (n_samples,): log prior plus log likelihood, with
    additive normalization constants dropped.
    """
    features, targets = X
    # Prior: z ~ N(0, I), up to an additive constant.
    log_prior = -0.5 * (z**2).sum(dim=-1)
    # Unit-variance Gaussian likelihood of the residuals, one column
    # of predictions per sample of z.
    residuals = targets.unsqueeze(-1) - features @ z.T
    log_lik = -0.5 * (residuals**2).sum(dim=0)
    return log_prior + log_lik
# --- Synthetic linear-regression data ---
torch.manual_seed(42)
n_data = 100
d = 5
X_train = torch.randn(n_data, d)
true_w = torch.randn(d)
y_train = X_train @ true_w + 0.5 * torch.randn(n_data)

# --- Black-box VI: maximize a Monte Carlo ELBO estimate with Adam ---
q = BBVIGaussian(d)
optimizer = optim.Adam(q.parameters(), lr=0.01)

elbo_list = []
for step in range(1000):
    optimizer.zero_grad()
    # 10 reparameterized samples give a low-variance ELBO estimate.
    draws = q.sample(n_samples=10)
    elbo = (log_joint(draws, (X_train, y_train)) - q.log_prob(draws)).mean()
    # Ascent on the ELBO == descent on its negation.
    (-elbo).backward()
    optimizer.step()
    elbo_list.append(elbo.item())
    if step % 200 == 0:
        print(f"Epoch {step}: ELBO = {elbo.item():.2f}")

print(f"\nTrue weights: {true_w.numpy()}")
print(f"Estimated weights (mean): {q.mu.detach().numpy()}")
print(f"Estimated uncertainty (std): {torch.exp(q.log_sigma).detach().numpy()}")
# --- Diagnostics: ELBO trace and posterior mean vs. true weights ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_trace, ax_weights = axes

ax_trace.plot(elbo_list, linewidth=2)
ax_trace.set_xlabel('Iteration', fontsize=12)
ax_trace.set_ylabel('ELBO', fontsize=12)
ax_trace.set_title('Black Box VI Convergence', fontsize=13)
ax_trace.grid(True, alpha=0.3)

# Grouped bars: true vs. estimated weights, with +/- 1 std error bars.
x_pos = np.arange(d)
w_est = q.mu.detach().numpy()
w_std = torch.exp(q.log_sigma).detach().numpy()
ax_weights.bar(x_pos - 0.2, true_w.numpy(), 0.4, label='True', alpha=0.7)
ax_weights.bar(x_pos + 0.2, w_est, 0.4, label='Estimated', alpha=0.7)
ax_weights.errorbar(x_pos + 0.2, w_est, yerr=w_std,
                    fmt='none', ecolor='black', capsize=5)
ax_weights.set_xlabel('Weight Index', fontsize=12)
ax_weights.set_ylabel('Value', fontsize=12)
ax_weights.set_title('Weight Comparison', fontsize=13)
ax_weights.legend()
ax_weights.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()
Summary
Key Concepts:
Variational Inference: Approximate intractable posteriors via optimization
ELBO: Evidence lower bound, objective function for VI
Mean-field: Assume factorized posterior for tractability
CAVI: Coordinate ascent for conjugate models
BBVI: Gradient-based VI for general models
ELBO Forms:
Advantages of VI:
Fast: Optimization vs sampling
Scalable: Stochastic gradients for big data
Deterministic: No MCMC diagnostics
Flexible: Any differentiable model
Limitations:
Approximation: Not exact (typically underestimates posterior uncertainty)
Local optima: Optimization challenges
Family choice: Restricted to \(\mathcal{Q}\)
Modern Extensions:
Normalizing flows: Expressive variational families
Importance weighted ELBO: Tighter bounds
Amortized inference: Neural networks for \(q\)
Stochastic VI: Mini-batch optimization
Applications:
Variational autoencoders (VAEs)
Bayesian neural networks
Topic models (LDA)
Gaussian processes
Further Reading:
Blei et al. (2017) - "Variational Inference: A Review for Statisticians"
Jordan et al. (1999) - "An Introduction to Variational Methods for Graphical Models"
Kingma & Welling (2014) - "Auto-Encoding Variational Bayes"
Next Steps:
03_variational_autoencoders_advanced.ipynb - VAE with ELBO
04_pac_bayes_theory.ipynb - PAC-Bayesian framework
09_expectation_maximization.ipynb - EM algorithm