import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.special import gammaln
sns.set_style('whitegrid')
np.random.seed(42)
1. Motivation: Infinite Flexibility¶
Parametric Models Limitation¶
Gaussian Mixture Model with \(K\) components: \(p(x) = \sum_{k=1}^K \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)\)
Problem: Must choose \(K\) in advance!
Non-parametric Approach¶
“Non-parametric” ≠ no parameters
Instead: Number of parameters grows with data
Bayesian Non-parametrics¶
Prior over infinite-dimensional objects:
Infinite mixture models
Infinite latent feature models
Non-parametric regression
Key advantage: Model complexity adapts to data!
Main Tools¶
Dirichlet Process (DP): Prior over distributions
Chinese Restaurant Process (CRP): Clustering process
Indian Buffet Process (IBP): Feature allocation
Gaussian Process (GP): Function prior
2. Dirichlet Distribution Review¶
Definition¶
For a \(K\)-dimensional probability vector \(\pi = (\pi_1, \ldots, \pi_K)\): \(\pi \sim \text{Dir}(\alpha_1, \ldots, \alpha_K)\)
PDF: \(p(\pi) = \frac{1}{B(\alpha)} \prod_{k=1}^K \pi_k^{\alpha_k - 1}\)
where \(B(\alpha) = \frac{\prod_k \Gamma(\alpha_k)}{\Gamma(\sum_k \alpha_k)}\)
Symmetric Dirichlet¶
Parameter \(\alpha\) controls:
\(\alpha < 1\): Sparse (mass piles onto a few components)
\(\alpha = 1\): Uniform over the simplex
\(\alpha > 1\): Concentrated near uniform (all components similar in size)
Conjugacy¶
Prior: \(\pi \sim \text{Dir}(\alpha)\)
Likelihood: \(x_1, \ldots, x_n \sim \text{Categorical}(\pi)\)
Posterior: \(\pi \mid x \sim \text{Dir}(\alpha_1 + n_1, \ldots, \alpha_K + n_K)\)
where \(n_k\) = count of observations in category \(k\)
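As a quick numeric check of this conjugate update (a minimal sketch; the seed, sample size, and `true_pi` values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative setup: symmetric Dir(1, 1, 1) prior over 3 categories
alpha = np.array([1.0, 1.0, 1.0])
true_pi = np.array([0.6, 0.3, 0.1])
x = rng.choice(3, size=500, p=true_pi)

# Conjugate update: add the category counts n_k to the prior parameters
counts = np.bincount(x, minlength=3)
post_alpha = alpha + counts
post_mean = post_alpha / post_alpha.sum()

print("counts:        ", counts)
print("posterior mean:", np.round(post_mean, 3))  # close to true_pi
```

With 500 observations the posterior mean is dominated by the data, so it lands near the generating proportions regardless of the (weak) prior.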
# Visualize Dirichlet distribution (3D case)
def plot_dirichlet_samples(alpha, n_samples=1000):
    """Sample and plot from 3D Dirichlet."""
    samples = np.random.dirichlet(alpha, n_samples)
    fig = plt.figure(figsize=(15, 5))

    # 3D scatter
    ax1 = fig.add_subplot(131, projection='3d')
    ax1.scatter(samples[:, 0], samples[:, 1], samples[:, 2], alpha=0.5, s=10)
    ax1.set_xlabel('π₁', fontsize=11)
    ax1.set_ylabel('π₂', fontsize=11)
    ax1.set_zlabel('π₃', fontsize=11)
    ax1.set_title(f'Dir({alpha})', fontsize=12)

    # Ternary-like 2D projection
    ax2 = fig.add_subplot(132)
    ax2.scatter(samples[:, 0], samples[:, 1], alpha=0.3, s=5)
    ax2.set_xlabel('π₁', fontsize=11)
    ax2.set_ylabel('π₂', fontsize=11)
    ax2.set_title('2D Projection', fontsize=12)
    ax2.grid(True, alpha=0.3)

    # Histograms of the marginals
    ax3 = fig.add_subplot(133)
    for i in range(3):
        ax3.hist(samples[:, i], bins=30, alpha=0.5, label=f'π{i+1}')
    ax3.set_xlabel('Value', fontsize=11)
    ax3.set_ylabel('Frequency', fontsize=11)
    ax3.set_title('Marginal Distributions', fontsize=12)
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
# Different concentration parameters
print("Sparse (α < 1):")
plot_dirichlet_samples([0.1, 0.1, 0.1])
print("\nUniform (α = 1):")
plot_dirichlet_samples([1.0, 1.0, 1.0])
print("\nDiffuse (α > 1):")
plot_dirichlet_samples([10.0, 10.0, 10.0])
3. Dirichlet Process (DP)¶
From Finite to Infinite¶
Dirichlet: Distribution over \(K\)-dim probability vectors
Dirichlet Process: Distribution over distributions!
Formal Definition¶
\(G \sim DP(\alpha, H)\) where:
\(\alpha > 0\): Concentration parameter
\(H\): Base distribution
Key property: For any partition \(A_1, \ldots, A_K\) of the sample space: \((G(A_1), \ldots, G(A_K)) \sim \text{Dir}(\alpha H(A_1), \ldots, \alpha H(A_K))\)
Stick-Breaking Construction¶
Explicit construction of \(G \sim DP(\alpha, H)\):
Sample atoms: \(\theta_k^* \sim H\) for \(k = 1, 2, \ldots\)
Sample stick weights: \(\beta_k \sim \text{Beta}(1, \alpha)\)
Define weights: \(\pi_k = \beta_k \prod_{j<k}(1 - \beta_j)\)
Result: \(G = \sum_{k=1}^\infty \pi_k \delta_{\theta_k^*}\)
where \(\delta_{\theta}\) is point mass at \(\theta\).
Discreteness¶
Crucial: \(G\) is discrete with probability 1!
Samples from \(G\) will repeat (clustering property).
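The clustering property can be seen directly by sampling from \(G\) via the Pólya-urn predictive rule \(\theta_{n+1} \mid \theta_{1:n} \sim \frac{1}{\alpha + n}\big(\alpha H + \sum_{i=1}^n \delta_{\theta_i}\big)\). A short sketch, assuming \(H = \mathcal{N}(0, 1)\) and an illustrative \(\alpha\):

```python
import numpy as np

rng = np.random.default_rng(1)
alpha, n = 2.0, 200  # illustrative values

# Pólya-urn predictive: draw a fresh atom from H w.p. alpha/(alpha + i),
# otherwise repeat a uniformly chosen earlier draw
draws = []
for i in range(n):
    if rng.random() < alpha / (alpha + i):
        draws.append(rng.normal())            # new atom from H = N(0, 1)
    else:
        draws.append(draws[rng.integers(i)])  # exact repeat of a previous value
print(f"{n} draws from G, but only {len(set(draws))} distinct values")
```

Because repeats are exact copies of earlier atoms, the number of distinct values stays small (it grows only logarithmically in \(n\)), which is precisely the clustering behavior exploited below.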
def stick_breaking(alpha, H_sampler, n_atoms=20):
    """Stick-breaking construction of DP (truncated to n_atoms).

    Args:
        alpha: Concentration parameter
        H_sampler: Function to sample from base distribution
        n_atoms: Number of atoms to generate

    Returns:
        atoms: Atom locations
        weights: Atom probabilities
    """
    # Sample atoms from base distribution
    atoms = [H_sampler() for _ in range(n_atoms)]
    # Stick-breaking process
    betas = np.random.beta(1, alpha, n_atoms)
    weights = np.zeros(n_atoms)
    stick_left = 1.0
    for k in range(n_atoms):
        weights[k] = betas[k] * stick_left
        stick_left *= (1 - betas[k])
    return np.array(atoms), weights
# Visualize stick-breaking for different alphas
H_sampler = lambda: np.random.randn() # Base: N(0,1)
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
alphas = [0.1, 1.0, 10.0]
for idx, alpha in enumerate(alphas):
    atoms, weights = stick_breaking(alpha, H_sampler, n_atoms=50)
    # Plot weights
    axes[idx].bar(range(len(weights)), weights, alpha=0.7)
    axes[idx].set_xlabel('Atom Index', fontsize=11)
    axes[idx].set_ylabel('Weight', fontsize=11)
    axes[idx].set_title(f'α = {alpha}', fontsize=12)
    axes[idx].grid(True, alpha=0.3, axis='y')
    # Show how concentrated the mass is
    cumsum = np.cumsum(weights)
    n_99 = np.argmax(cumsum > 0.99) + 1
    axes[idx].text(0.6, 0.9, f'99% mass in {n_99} atoms',
                   transform=axes[idx].transAxes, fontsize=10)
plt.tight_layout()
plt.show()
print("Lower α → more concentrated (few dominant atoms)")
print("Higher α → more diffuse (many small atoms)")
4. Chinese Restaurant Process (CRP)¶
Metaphor¶
Customers enter restaurant sequentially:
Customer 1: Sits at table 1
Customer \(n+1\):
Joins table \(k\) with probability \(\propto n_k\) (# at table \(k\))
Sits at new table with probability \(\propto \alpha\)
Formal Definition¶
Given seating \(z_1, \ldots, z_n\), customer \(n+1\) joins existing table \(k\) with probability \(\frac{n_k}{n + \alpha}\) and opens a new table with probability \(\frac{\alpha}{n + \alpha}\).
“Rich get richer”¶
Popular tables more likely to attract new customers!
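One consequence of rich-get-richer dynamics: the expected number of occupied tables grows only logarithmically, \(\mathbb{E}[K_n] = \sum_{i=0}^{n-1} \frac{\alpha}{\alpha + i} \approx \alpha \log(1 + n/\alpha)\). A quick check, using the fact that the new-table events are independent Bernoulli variables with these probabilities (parameter values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)
alpha, n, trials = 1.0, 500, 200  # illustrative values

# Customer i+1 opens a new table with probability alpha / (alpha + i)
new_table_probs = alpha / (alpha + np.arange(n))
expected = new_table_probs.sum()

# Monte Carlo: the new-table indicators are independent in the CRP, so the
# table count can be simulated without tracking the full seating arrangement
k_counts = [(rng.random(n) < new_table_probs).sum() for _ in range(trials)]

print(f"E[K_n] formula: {expected:.2f}, simulation mean: {np.mean(k_counts):.2f}")
```

The two numbers agree up to Monte Carlo noise, confirming that even 500 customers occupy only a handful of tables when \(\alpha = 1\).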
Connection to DP¶
CRP is the clustering process induced by a DP mixture: \(\theta_i \mid G \sim G, \quad G \sim DP(\alpha, H)\)
Tables = clusters, dish at table = cluster parameter \(\theta_k^*\)
def chinese_restaurant_process(n_customers, alpha):
    """Simulate CRP.

    Returns:
        table_assignments: Assignment for each customer
        table_counts: Number at each table over time
    """
    tables = [0]  # Customer 1 at table 0
    table_counts = [[1]]  # Track table sizes over time
    for n in range(1, n_customers):
        # Current table counts
        counts = np.bincount(tables)
        # Probabilities: existing tables proportional to size, new table to alpha
        probs = np.append(counts, alpha) / (n + alpha)
        # Sample table
        table = np.random.choice(len(probs), p=probs)
        tables.append(table)
        # Record
        table_counts.append(np.bincount(tables).tolist())
    return tables, table_counts
# Simulate CRP
n_customers = 100
alpha_values = [0.1, 1.0, 10.0]
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
for idx, alpha in enumerate(alpha_values):
    tables, counts = chinese_restaurant_process(n_customers, alpha)
    # Plot the evolution of each table's size
    n_tables = len(counts[-1])
    for table_id in range(n_tables):
        table_sizes = []
        for t in range(len(counts)):
            if table_id < len(counts[t]):
                table_sizes.append(counts[t][table_id])
            else:
                table_sizes.append(0)
        axes[idx].plot(table_sizes, linewidth=1.5, alpha=0.7)
    axes[idx].set_xlabel('Customer', fontsize=11)
    axes[idx].set_ylabel('Table Size', fontsize=11)
    axes[idx].set_title(f'α = {alpha} ({n_tables} tables)', fontsize=12)
    axes[idx].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("Lower α → fewer, larger tables")
print("Higher α → more, smaller tables")
5. DP Mixture Model for Clustering¶
class DPMixtureGaussian:
    """DP Gaussian mixture with collapsed Gibbs sampling."""

    def __init__(self, alpha=1.0, prior_mu=0, prior_sigma=10, likelihood_sigma=1):
        self.alpha = alpha
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.likelihood_sigma = likelihood_sigma

    def fit(self, X, n_iterations=100):
        """Collapsed Gibbs sampling."""
        n = len(X)
        # Initialize: each point in own cluster
        self.z = np.arange(n)
        # Track number of clusters
        self.n_clusters_history = []
        for iteration in range(n_iterations):
            for i in range(n):
                # Remove x_i from its cluster
                old_cluster = self.z[i]
                self.z[i] = -1  # Mark as unassigned
                # Get cluster counts (excluding x_i)
                unique_clusters = np.unique(self.z[self.z >= 0])
                # Compute probabilities for each existing cluster
                log_probs = []
                for k in unique_clusters:
                    n_k = np.sum(self.z == k)
                    X_k = X[self.z == k]
                    # CRP weight times likelihood under cluster k
                    log_prob_cluster = np.log(n_k)
                    log_prob_cluster += self._log_predictive(X[i], X_k)
                    log_probs.append(log_prob_cluster)
                # Probability of new cluster
                log_prob_new = np.log(self.alpha)
                log_prob_new += self._log_predictive(X[i], np.array([]))
                log_probs.append(log_prob_new)
                # Normalize and sample
                log_probs = np.array(log_probs)
                log_probs -= log_probs.max()  # Numerical stability
                probs = np.exp(log_probs)
                probs /= probs.sum()
                # Sample cluster
                choice = np.random.choice(len(probs), p=probs)
                if choice < len(unique_clusters):
                    self.z[i] = unique_clusters[choice]
                else:
                    # New cluster
                    self.z[i] = self.z.max() + 1
            # Relabel clusters to be contiguous
            unique = np.unique(self.z)
            for new_label, old_label in enumerate(unique):
                self.z[self.z == old_label] = new_label + 1000  # Temporary offset
            self.z -= 1000
            # Record
            self.n_clusters_history.append(len(unique))
        return self

    def _log_predictive(self, x, X_cluster):
        """Log predictive probability of x under cluster."""
        sigma2 = self.likelihood_sigma**2
        tau2 = self.prior_sigma**2
        if len(X_cluster) == 0:
            # Prior predictive
            return stats.norm.logpdf(x, self.prior_mu, np.sqrt(sigma2 + tau2))
        else:
            # Posterior predictive (conjugate normal-normal)
            n = len(X_cluster)
            x_mean = X_cluster.mean()
            # Posterior parameters
            post_sigma2 = 1 / (1/tau2 + n/sigma2)
            post_mu = post_sigma2 * (self.prior_mu/tau2 + n*x_mean/sigma2)
            # Predictive
            pred_var = sigma2 + post_sigma2
            return stats.norm.logpdf(x, post_mu, np.sqrt(pred_var))
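As a sanity check on the normal-normal posterior predictive formula used in `_log_predictive`, one can compare the closed form against direct numerical integration over the cluster mean (the hyperparameters and data below are illustrative, chosen to match the model's defaults):

```python
import numpy as np
from scipy import stats

# Illustrative hyperparameters in the model's notation
sigma2, tau2, mu0 = 1.0, 100.0, 0.0
X_cluster = np.array([4.8, 5.1, 5.3])
n, xbar = len(X_cluster), X_cluster.mean()

# Conjugate posterior over the cluster mean
post_var = 1 / (1 / tau2 + n / sigma2)
post_mu = post_var * (mu0 / tau2 + n * xbar / sigma2)

# Closed-form predictive density at a new point
x_new = 5.0
closed_form = stats.norm.pdf(x_new, post_mu, np.sqrt(sigma2 + post_var))

# Numerical check: integrate N(x_new | mu, sigma2) against the posterior over mu
mu_grid = np.linspace(-20, 30, 200_001)
h = mu_grid[1] - mu_grid[0]
integrand = stats.norm.pdf(x_new, mu_grid, np.sqrt(sigma2)) \
    * stats.norm.pdf(mu_grid, post_mu, np.sqrt(post_var))
numeric = integrand.sum() * h

print(f"closed form: {closed_form:.6f}, numeric: {numeric:.6f}")
```

The two values agree to high precision, confirming that marginalizing the Gaussian mean against its Gaussian posterior yields a Gaussian predictive with variance \(\sigma^2 + \sigma_{\text{post}}^2\).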
# Generate data from 3 Gaussians
np.random.seed(42)
X1 = np.random.randn(50) - 5
X2 = np.random.randn(50)
X3 = np.random.randn(50) + 5
X = np.concatenate([X1, X2, X3])
# Fit DP mixture
model = DPMixtureGaussian(alpha=1.0)
model.fit(X, n_iterations=200)
print(f"Final number of clusters: {len(np.unique(model.z))}")
print(f"True number of clusters: 3")
# Visualize results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Cluster evolution
axes[0].plot(model.n_clusters_history, linewidth=2)
axes[0].axhline(y=3, color='r', linestyle='--', linewidth=2, label='True K=3')
axes[0].set_xlabel('Iteration', fontsize=12)
axes[0].set_ylabel('Number of Clusters', fontsize=12)
axes[0].set_title('Cluster Discovery', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Final clustering
unique_clusters = np.unique(model.z)
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_clusters)))
for k, color in zip(unique_clusters, colors):
    mask = model.z == k
    axes[1].hist(X[mask], bins=20, alpha=0.6, label=f'Cluster {k}', color=color)
axes[1].set_xlabel('Value', fontsize=12)
axes[1].set_ylabel('Frequency', fontsize=12)
axes[1].set_title('Final Clustering', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Summary¶
Key Concepts:¶
Non-parametric: Model complexity grows with data
Dirichlet Process: Prior over distributions, discrete with prob 1
Stick-breaking: Explicit construction via infinite series
CRP: Clustering process with rich-get-richer dynamics
DP Mixture: Infinite mixture model with automatic K selection
Dirichlet Process Properties:¶
Concentration \(\alpha\): Controls cluster propensity
Base \(H\): Prior on cluster parameters
Discreteness: \(G\) is discrete almost surely
CRP Probabilities:¶
Existing table \(k\): \(\frac{n_k}{n + \alpha}\); new table: \(\frac{\alpha}{n + \alpha}\)
Inference Methods:¶
Collapsed Gibbs: Integrate out \(G\), sample \(z\)
Stick-breaking sampler: Explicit \(\pi, \theta\)
Variational inference: Approximate posterior
Slice sampler: Finite approximation
Applications:¶
Clustering: Automatic K selection
Topic models: Hierarchical DP for LDA
Hidden Markov models: Infinite state space
Survival analysis: Non-parametric hazards
Extensions:¶
Hierarchical DP: Shared clusters across groups
Pitman-Yor process: Power-law behavior
Indian Buffet Process: Infinite latent features
Beta Process: Non-parametric factor analysis
Further Reading:¶
Teh (2010) - “Dirichlet Process”
Gershman & Blei (2012) - “Tutorial on Bayesian Nonparametric Models”
Neal (2000) - “Markov Chain Sampling Methods for DP”
Next Steps:¶
06_variational_inference.ipynb - Approximate inference
09_expectation_maximization.ipynb - EM for mixtures
Phase 6 neural networks for deep learning