import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.special import gammaln

sns.set_style('whitegrid')
np.random.seed(42)

1. Motivation: Infinite Flexibility

Parametric Models Limitation

Gaussian Mixture Model with \(K\) components: \(p(x) = \sum_{k=1}^K \pi_k \mathcal{N}(x | \mu_k, \Sigma_k)\)

Problem: Must choose \(K\) in advance!

Non-parametric Approach

“Non-parametric” ≠ no parameters

Instead: Number of parameters grows with data

Bayesian Non-parametrics

Prior over infinite-dimensional objects:

  • Infinite mixture models

  • Infinite latent feature models

  • Non-parametric regression

Key advantage: Model complexity adapts to data!

Main Tools

  1. Dirichlet Process (DP): Prior over distributions

  2. Chinese Restaurant Process (CRP): Clustering process

  3. Indian Buffet Process (IBP): Feature allocation

  4. Gaussian Process (GP): Function prior

2. Dirichlet Distribution Review

Definition

For a \(K\)-dimensional probability vector \(\pi = (\pi_1, \ldots, \pi_K)\): \(\pi \sim \text{Dir}(\alpha_1, \ldots, \alpha_K)\)

PDF: \(p(\pi) = \frac{1}{B(\alpha)} \prod_{k=1}^K \pi_k^{\alpha_k - 1}\)

where \(B(\alpha) = \frac{\prod_k \Gamma(\alpha_k)}{\Gamma(\sum_k \alpha_k)}\)

Symmetric Dirichlet

\[\pi \sim \text{Dir}(\alpha/K, \ldots, \alpha/K)\]

Parameter \(\alpha\) controls:

  • \(\alpha < 1\): Sparse (few large components)

  • \(\alpha = 1\): Uniform

  • \(\alpha > 1\): Diffuse (many similar components)

Conjugacy

Prior: \(\pi \sim \text{Dir}(\alpha)\)

Likelihood: \(x_1, \ldots, x_n \sim \text{Categorical}(\pi)\)

Posterior: \(\pi | x \sim \text{Dir}(\alpha_1 + n_1, \ldots, \alpha_K + n_K)\)

where \(n_k\) = count of observations in category \(k\)
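The conjugate update above is just "add the category counts to the prior parameters." A minimal numerical sketch (illustrative values for the prior and the true category probabilities):

```python
import numpy as np

rng = np.random.default_rng(0)

# Prior: Dir(2, 2, 2); true category probabilities we hope to recover
alpha = np.array([2.0, 2.0, 2.0])
true_pi = np.array([0.6, 0.3, 0.1])

# Observe n categorical draws and tally the counts n_k
x = rng.choice(3, size=500, p=true_pi)
counts = np.bincount(x, minlength=3)

# Conjugate update: posterior is Dir(alpha_k + n_k)
posterior_alpha = alpha + counts
posterior_mean = posterior_alpha / posterior_alpha.sum()

print("counts:        ", counts)
print("posterior mean:", np.round(posterior_mean, 3))
```

With 500 observations the posterior mean sits close to the empirical frequencies, and the prior pseudo-counts contribute little.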

# Visualize Dirichlet distribution (3D case)
def plot_dirichlet_samples(alpha, n_samples=1000):
    """Sample and plot from 3D Dirichlet."""
    samples = np.random.dirichlet(alpha, n_samples)
    
    fig = plt.figure(figsize=(15, 5))
    
    # 3D scatter
    ax1 = fig.add_subplot(131, projection='3d')
    ax1.scatter(samples[:, 0], samples[:, 1], samples[:, 2], alpha=0.5, s=10)
    ax1.set_xlabel('π₁', fontsize=11)
    ax1.set_ylabel('π₂', fontsize=11)
    ax1.set_zlabel('π₃', fontsize=11)
    ax1.set_title(f'Dir({alpha})', fontsize=12)
    
    # Ternary-like 2D projection
    ax2 = fig.add_subplot(132)
    ax2.scatter(samples[:, 0], samples[:, 1], alpha=0.3, s=5)
    ax2.set_xlabel('π₁', fontsize=11)
    ax2.set_ylabel('π₂', fontsize=11)
    ax2.set_title('2D Projection', fontsize=12)
    ax2.grid(True, alpha=0.3)
    
    # Histograms
    ax3 = fig.add_subplot(133)
    for i in range(3):
        ax3.hist(samples[:, i], bins=30, alpha=0.5, label=f'π{i+1}')
    ax3.set_xlabel('Value', fontsize=11)
    ax3.set_ylabel('Frequency', fontsize=11)
    ax3.set_title('Marginal Distributions', fontsize=12)
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Different concentration parameters
print("Sparse (α < 1):")
plot_dirichlet_samples([0.1, 0.1, 0.1])

print("\nUniform (α = 1):")
plot_dirichlet_samples([1.0, 1.0, 1.0])

print("\nDiffuse (α > 1):")
plot_dirichlet_samples([10.0, 10.0, 10.0])

3. Dirichlet Process (DP)

From Finite to Infinite

Dirichlet: Distribution over \(K\)-dim probability vectors

Dirichlet Process: Distribution over distributions!

Formal Definition

\(G \sim DP(\alpha, H)\) where:

  • \(\alpha > 0\): Concentration parameter

  • \(H\): Base distribution

Key property: For any partition \(A_1, \ldots, A_K\) of the sample space: \((G(A_1), \ldots, G(A_K)) \sim \text{Dir}(\alpha H(A_1), \ldots, \alpha H(A_K))\)

Stick-Breaking Construction

Explicit construction of \(G \sim DP(\alpha, H)\):

  1. Sample atoms: \(\theta_k^* \sim H\) for \(k = 1, 2, \ldots\)

  2. Sample stick weights: \(\beta_k \sim \text{Beta}(1, \alpha)\)

  3. Define weights: \(\pi_k = \beta_k \prod_{j<k}(1 - \beta_j)\)

  4. Result: \(G = \sum_{k=1}^\infty \pi_k \delta_{\theta_k^*}\)

where \(\delta_{\theta}\) is point mass at \(\theta\).

Discreteness

Crucial: \(G\) is discrete with probability 1!

Samples from \(G\) will repeat (clustering property).

def stick_breaking(alpha, H_sampler, n_atoms=20):
    """Stick-breaking construction of DP.
    
    Args:
        alpha: Concentration parameter
        H_sampler: Function to sample from base distribution
        n_atoms: Number of atoms to generate
    
    Returns:
        atoms: Atom locations
        weights: Atom probabilities
    """
    # Sample atoms from base distribution
    atoms = [H_sampler() for _ in range(n_atoms)]
    
    # Stick-breaking process
    betas = np.random.beta(1, alpha, n_atoms)
    weights = np.zeros(n_atoms)
    
    stick_left = 1.0
    for k in range(n_atoms):
        weights[k] = betas[k] * stick_left
        stick_left *= (1 - betas[k])
    
    return np.array(atoms), weights

# Visualize stick-breaking for different alphas
H_sampler = lambda: np.random.randn()  # Base: N(0,1)

fig, axes = plt.subplots(1, 3, figsize=(16, 4))
alphas = [0.1, 1.0, 10.0]

for idx, alpha in enumerate(alphas):
    atoms, weights = stick_breaking(alpha, H_sampler, n_atoms=50)
    
    # Plot weights
    axes[idx].bar(range(len(weights)), weights, alpha=0.7)
    axes[idx].set_xlabel('Atom Index', fontsize=11)
    axes[idx].set_ylabel('Weight', fontsize=11)
    axes[idx].set_title(f'α = {alpha}', fontsize=12)
    axes[idx].grid(True, alpha=0.3, axis='y')
    
    # Show concentration
    cumsum = np.cumsum(weights)
    n_99 = np.argmax(cumsum > 0.99) + 1
    axes[idx].text(0.6, 0.9, f'99% mass in {n_99} atoms', 
                   transform=axes[idx].transAxes, fontsize=10)

plt.tight_layout()
plt.show()

print("Lower α → more concentrated (few dominant atoms)")
print("Higher α → more diffuse (many small atoms)")
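The discreteness property from Section 3 can be seen directly: drawing from a stick-broken \(G\) means picking atoms in proportion to their weights, so draws repeat. A minimal sketch (truncating the infinite sum at 200 atoms and renormalizing, which is an approximation):

```python
import numpy as np

rng = np.random.default_rng(1)

# Truncated stick-breaking draw of G ~ DP(alpha, N(0,1))
alpha = 1.0
n_atoms = 200
atoms = rng.standard_normal(n_atoms)             # atoms theta*_k ~ H
betas = rng.beta(1, alpha, n_atoms)              # stick proportions
weights = betas * np.cumprod(np.concatenate([[1.0], 1 - betas[:-1]]))
weights /= weights.sum()                         # renormalize the truncation

# Draw 50 samples from G: atom values repeat, inducing clusters
draws = rng.choice(atoms, size=50, p=weights)
n_unique = len(np.unique(draws))
print(f"50 draws from G hit only {n_unique} distinct values")
```

Even though the base distribution \(H\) is continuous, a modest number of draws from \(G\) collapses onto a handful of distinct atoms.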

4. Chinese Restaurant Process (CRP)

Metaphor

Customers enter restaurant sequentially:

  • Customer 1: Sits at table 1

  • Customer \(n+1\):

    • Joins table \(k\) with probability \(\propto n_k\) (# at table \(k\))

    • Sits at new table with probability \(\propto \alpha\)

Formal Definition

Given seating \(z_1, \ldots, z_n\), customer \(n+1\) sits at:

\[\begin{split}P(z_{n+1} = k | z_{1:n}) = \begin{cases} \frac{n_k}{n + \alpha} & \text{existing table } k \\ \frac{\alpha}{n + \alpha} & \text{new table} \end{cases}\end{split}\]

“Rich get richer”

Popular tables more likely to attract new customers!

Connection to DP

CRP is the clustering process induced by a DP mixture: \(\theta_i | G \sim G, \quad G \sim DP(\alpha, H)\)

Tables = clusters, dish at table = cluster parameter \(\theta_k^*\)

def chinese_restaurant_process(n_customers, alpha):
    """Simulate CRP.
    
    Returns:
        table_assignments: Assignment for each customer
        table_counts: Number at each table over time
    """
    tables = [0]  # Customer 1 at table 0
    table_counts = [[1]]  # Track table sizes over time
    
    for n in range(1, n_customers):
        # Current table counts
        counts = np.bincount(tables)
        
        # Probabilities
        probs = np.append(counts, alpha) / (n + alpha)
        
        # Sample table
        table = np.random.choice(len(probs), p=probs)
        tables.append(table)
        
        # Record
        table_counts.append(np.bincount(tables).tolist())
    
    return tables, table_counts

# Simulate CRP
n_customers = 100
alpha_values = [0.1, 1.0, 10.0]

fig, axes = plt.subplots(1, 3, figsize=(16, 4))

for idx, alpha in enumerate(alpha_values):
    tables, counts = chinese_restaurant_process(n_customers, alpha)
    
    # Plot table evolution
    n_tables = len(counts[-1])
    for table_id in range(n_tables):
        table_sizes = []
        for t in range(len(counts)):
            if table_id < len(counts[t]):
                table_sizes.append(counts[t][table_id])
            else:
                table_sizes.append(0)
        
        axes[idx].plot(table_sizes, linewidth=1.5, alpha=0.7)
    
    axes[idx].set_xlabel('Customer', fontsize=11)
    axes[idx].set_ylabel('Table Size', fontsize=11)
    axes[idx].set_title(f'α = {alpha} ({n_tables} tables)', fontsize=12)
    axes[idx].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Lower α → fewer, larger tables")
print("Higher α → more, smaller tables")
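The table counts above grow only logarithmically in the number of customers: the \(i\)-th customer opens a new table with probability \(\alpha/(\alpha + i - 1)\), so \(\mathbb{E}[K_n] = \sum_{i=0}^{n-1} \frac{\alpha}{\alpha + i} \approx \alpha \log(1 + n/\alpha)\). A quick check of this against simulation (self-contained re-implementation of the CRP step, with illustrative \(n\) and \(\alpha\)):

```python
import numpy as np

rng = np.random.default_rng(2)

def crp_num_tables(n, alpha, rng):
    """Simulate one CRP run and return the number of occupied tables."""
    counts = []                                   # customers at each table
    for i in range(n):
        # Existing tables in proportion to n_k; new table in proportion to alpha
        probs = np.array(counts + [alpha]) / (i + alpha)
        k = rng.choice(len(probs), p=probs)
        if k == len(counts):
            counts.append(1)                      # open a new table
        else:
            counts[k] += 1
    return len(counts)

n, alpha = 500, 1.0
sims = [crp_num_tables(n, alpha, rng) for _ in range(200)]
expected = sum(alpha / (alpha + i) for i in range(n))  # exact E[K_n]
print(f"Simulated mean tables: {np.mean(sims):.2f}")
print(f"Exact E[K_n]:          {expected:.2f}")        # ~ alpha*log(1 + n/alpha)
```

For \(\alpha = 1\) and 500 customers, both figures land near 6.8: adding an order of magnitude more data adds only a couple of expected tables.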

5. DP Mixture Model for Clustering

class DPMixtureGaussian:
    """DP Gaussian mixture with collapsed Gibbs sampling."""
    def __init__(self, alpha=1.0, prior_mu=0, prior_sigma=10, likelihood_sigma=1):
        self.alpha = alpha
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.likelihood_sigma = likelihood_sigma
    
    def fit(self, X, n_iterations=100):
        """Collapsed Gibbs sampling."""
        n = len(X)
        
        # Initialize: each point in own cluster
        self.z = np.arange(n)
        
        # Track number of clusters
        self.n_clusters_history = []
        
        for iteration in range(n_iterations):
            for i in range(n):
                # Remove x_i from its cluster
                old_cluster = self.z[i]
                self.z[i] = -1  # Mark as unassigned
                
                # Get cluster counts (excluding x_i)
                unique_clusters = np.unique(self.z[self.z >= 0])
                
                # Compute probabilities for each existing cluster
                log_probs = []
                
                for k in unique_clusters:
                    n_k = np.sum(self.z == k)
                    X_k = X[self.z == k]
                    
                    # CRP prior weight (∝ n_k) times predictive likelihood
                    log_prob_cluster = np.log(n_k)
                    log_prob_cluster += self._log_predictive(X[i], X_k)
                    log_probs.append(log_prob_cluster)
                
                # Probability of new cluster
                log_prob_new = np.log(self.alpha)
                log_prob_new += self._log_predictive(X[i], np.array([]))
                log_probs.append(log_prob_new)
                
                # Normalize and sample
                log_probs = np.array(log_probs)
                log_probs -= log_probs.max()  # Numerical stability
                probs = np.exp(log_probs)
                probs /= probs.sum()
                
                # Sample cluster
                choice = np.random.choice(len(probs), p=probs)
                
                if choice < len(unique_clusters):
                    self.z[i] = unique_clusters[choice]
                else:
                    # New cluster
                    self.z[i] = self.z.max() + 1
            
            # Relabel clusters to be contiguous
            unique = np.unique(self.z)
            for new_label, old_label in enumerate(unique):
                self.z[self.z == old_label] = new_label + 1000  # Temp
            self.z -= 1000
            
            # Record
            self.n_clusters_history.append(len(unique))
        
        return self
    
    def _log_predictive(self, x, X_cluster):
        """Log predictive probability of x under cluster."""
        sigma2 = self.likelihood_sigma**2
        tau2 = self.prior_sigma**2
        
        if len(X_cluster) == 0:
            # Prior predictive
            return stats.norm.logpdf(x, self.prior_mu, 
                                    np.sqrt(sigma2 + tau2))
        else:
            # Posterior predictive (conjugate normal-normal)
            n = len(X_cluster)
            x_mean = X_cluster.mean()
            
            # Posterior parameters
            post_sigma2 = 1 / (1/tau2 + n/sigma2)
            post_mu = post_sigma2 * (self.prior_mu/tau2 + n*x_mean/sigma2)
            
            # Predictive
            pred_var = sigma2 + post_sigma2
            return stats.norm.logpdf(x, post_mu, np.sqrt(pred_var))

# Generate data from 3 Gaussians
np.random.seed(42)
X1 = np.random.randn(50) - 5
X2 = np.random.randn(50)
X3 = np.random.randn(50) + 5
X = np.concatenate([X1, X2, X3])

# Fit DP mixture
model = DPMixtureGaussian(alpha=1.0)
model.fit(X, n_iterations=200)

print(f"Final number of clusters: {len(np.unique(model.z))}")
print(f"True number of clusters: 3")
# Visualize results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Cluster evolution
axes[0].plot(model.n_clusters_history, linewidth=2)
axes[0].axhline(y=3, color='r', linestyle='--', linewidth=2, label='True K=3')
axes[0].set_xlabel('Iteration', fontsize=12)
axes[0].set_ylabel('Number of Clusters', fontsize=12)
axes[0].set_title('Cluster Discovery', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Final clustering
unique_clusters = np.unique(model.z)
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_clusters)))

for k, color in zip(unique_clusters, colors):
    mask = model.z == k
    axes[1].hist(X[mask], bins=20, alpha=0.6, label=f'Cluster {k}', color=color)

axes[1].set_xlabel('Value', fontsize=12)
axes[1].set_ylabel('Frequency', fontsize=12)
axes[1].set_title('Final Clustering', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
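As a sanity check on the conjugate normal-normal predictive used in `_log_predictive`, a Monte Carlo estimate (averaging the likelihood over posterior draws of the cluster mean) should match the closed form. A sketch with a prior mean of 0 and illustrative values for \(\sigma^2\), \(\tau^2\), and the cluster data:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(3)

# Setup: mu ~ N(0, tau^2) prior, x ~ N(mu, sigma^2) likelihood
sigma2, tau2 = 1.0, 100.0
X_cluster = np.array([4.8, 5.1, 5.3])
n, xbar = len(X_cluster), X_cluster.mean()

# Closed-form posterior over the cluster mean (prior mean = 0)
post_var = 1 / (1 / tau2 + n / sigma2)
post_mu = post_var * (n * xbar / sigma2)

# Monte Carlo: integrate the likelihood over posterior draws of mu
x_new = 5.0
mus = rng.normal(post_mu, np.sqrt(post_var), 200_000)
mc = np.log(np.mean(stats.norm.pdf(x_new, mus, np.sqrt(sigma2))))
exact = stats.norm.logpdf(x_new, post_mu, np.sqrt(sigma2 + post_var))
print(f"Monte Carlo: {mc:.4f}, closed form: {exact:.4f}")
```

The two estimates agree to a few decimal places, confirming that the predictive variance is the sum \(\sigma^2 + \sigma^2_{\text{post}}\) of observation noise and posterior uncertainty about the mean.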

Summary

Key Concepts:

  1. Non-parametric: Model complexity grows with data

  2. Dirichlet Process: Prior over distributions, discrete with prob 1

  3. Stick-breaking: Explicit construction via infinite series

  4. CRP: Clustering process with rich-get-richer dynamics

  5. DP Mixture: Infinite mixture model with automatic K selection

Dirichlet Process Properties:

\[G \sim DP(\alpha, H)\]

  • Concentration \(\alpha\): Controls cluster propensity

  • Base \(H\): Prior on cluster parameters

  • Discreteness: \(G\) is discrete almost surely

CRP Probabilities:

\[P(z_{n+1} = k) = \frac{n_k}{n + \alpha}\]
\[P(\text{new cluster}) = \frac{\alpha}{n + \alpha}\]

Inference Methods:

  1. Collapsed Gibbs: Integrate out \(G\), sample \(z\)

  2. Stick-breaking sampler: Explicit \(\pi, \theta\)

  3. Variational inference: Approximate posterior

  4. Slice sampler: Finite approximation

Applications:

  • Clustering: Automatic K selection

  • Topic models: Hierarchical DP for LDA

  • Hidden Markov models: Infinite state space

  • Survival analysis: Non-parametric hazards

Extensions:

  • Hierarchical DP: Shared clusters across groups

  • Pitman-Yor process: Power-law behavior

  • Indian Buffet Process: Infinite latent features

  • Beta Process: Non-parametric factor analysis

Further Reading:

  • Teh (2010) - “Dirichlet Process”

  • Gershman & Blei (2012) - “Tutorial on Bayesian Nonparametric Models”

  • Neal (2000) - “Markov Chain Sampling Methods for DP”

Next Steps:

  • 06_variational_inference.ipynb - Approximate inference

  • 09_expectation_maximization.ipynb - EM for mixtures

  • Phase 6 neural networks for deep learning