import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Pick the compute backend: CUDA when a GPU is present, otherwise CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')
print(f"Device: {device}")

1. Mixture of Experts ConceptΒΆ

Gating Network:ΒΆ

\[g(x) = \text{softmax}(W_g x)\]

MoE Output:ΒΆ

\[y = \sum_{i=1}^n g_i(x) E_i(x)\]

where \(E_i\) are expert networks.

📚 Reference Materials:

class Expert(nn.Module):
    """One expert: a two-layer ReLU MLP applied independently of its siblings."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers live in an nn.Sequential so the state-dict layout stays
        # compatible with checkpoints produced by earlier versions.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to (batch, output_dim)."""
        return self.net(x)

print("Expert defined")

Gating NetworkΒΆ

The gating network (or router) decides which expert(s) to activate for each input. It typically consists of a small neural network that takes the input and produces a probability distribution over experts via softmax. In top-k routing, only the \(k\) experts with the highest gating scores are activated, providing computational savings proportional to \(k/N\) where \(N\) is the total number of experts. The gating network must balance load (distributing inputs evenly across experts) and specialization (allowing experts to focus on the inputs they handle best). An auxiliary load-balancing loss penalizes uneven expert utilization.

class GatingNetwork(nn.Module):
    """Router that assigns each input a sparse weighting over experts."""

    def __init__(self, input_dim, n_experts, k=2):
        super().__init__()
        self.n_experts = n_experts
        # Number of experts activated per input.
        self.k = k

        self.gate = nn.Linear(input_dim, n_experts)

    def forward(self, x):
        """
        Returns:
            gates: (batch, n_experts) routing weights
            indices: (batch, k) top-k expert indices
        """
        logits = self.gate(x)

        # Keep only the k largest logits of each row.
        best_logits, best_idx = logits.topk(self.k, dim=1)

        # Renormalize the survivors so each row sums to one.
        weights = F.softmax(best_logits, dim=1)

        # Expand back to a dense (batch, n_experts) tensor; experts that
        # were not selected receive an exact zero weight.
        gates = torch.zeros_like(logits).scatter(1, best_idx, weights)

        return gates, best_idx

print("GatingNetwork defined")

Mixture of Experts LayerΒΆ

The MoE layer combines the gating network with a collection of expert sub-networks. For a given input, the gate produces routing weights, the top-\(k\) experts process the input, and their outputs are combined as a weighted sum: \(y = \sum_{i \in \text{top-}k} g_i \cdot E_i(x)\), where \(g_i\) are the gating weights and \(E_i\) are the expert outputs. Each expert has the same architecture but independent parameters, allowing different experts to specialize on different parts of the input space. This sparse activation pattern enables the model to scale to very large parameter counts while keeping per-example computation constant – the same principle behind models like Switch Transformer and GShard.

class MixtureOfExperts(nn.Module):
    """MoE layer: several parallel experts combined by a learned router."""

    def __init__(self, input_dim, hidden_dim, output_dim, n_experts=4, k=2):
        super().__init__()
        self.n_experts = n_experts
        self.k = k

        # Identical architectures, independent parameters.
        self.experts = nn.ModuleList(
            Expert(input_dim, hidden_dim, output_dim) for _ in range(n_experts)
        )

        # Router producing sparse per-expert weights.
        self.gate = GatingNetwork(input_dim, n_experts, k)

    def forward(self, x):
        """Return (output, load_loss).

        NOTE: for clarity every expert is evaluated on every input; the
        sparsity shows up only in the gating weights, not in compute.
        """
        gates, _indices = self.gate(x)

        # (batch, n_experts, output_dim): expert i's output in slot i.
        stacked = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Weighted sum over the expert axis.
        output = torch.einsum('be,beo->bo', gates, stacked)

        # Load-balancing penalty: a normalized concentration measure of
        # per-expert importance, minimized when usage is uniform.
        importance = gates.sum(dim=0)  # (n_experts,)
        load_loss = (importance ** 2).sum() / (importance.sum() ** 2 + 1e-8)

        return output, load_loss

print("MixtureOfExperts defined")

MoE Network for ClassificationΒΆ

Wrapping the MoE layer within a standard classification pipeline demonstrates how sparse expert routing works in practice. The model consists of shared feature extraction layers (e.g., convolutional backbone), the MoE layer that routes features to specialized experts, and a final classification head. During training, the total loss includes the task loss (cross-entropy) plus the load-balancing loss that encourages uniform expert utilization. This architecture naturally scales: adding more experts increases the model’s capacity without proportionally increasing inference cost.

class MoENet(nn.Module):
    """MNIST classifier whose middle layer is a sparse mixture of experts."""

    def __init__(self, n_experts=4):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.moe = MixtureOfExperts(256, 128, 128, n_experts=n_experts, k=2)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return (class logits, auxiliary load-balancing loss)."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc1(flat))
        routed, load_loss = self.moe(hidden)
        logits = self.fc2(F.relu(routed))
        return logits, load_loss

print("MoENet defined")

Train MoEΒΆ

Training a Mixture of Experts model requires monitoring not just the task loss but also the expert utilization statistics: which experts are being selected, how evenly the load is distributed, and whether any experts have become dormant (never selected). Early in training, the router may collapse to always selecting the same expert (expert collapse), which the load-balancing loss is designed to prevent. Comparing MoE training speed and final accuracy against a dense model with the same total parameter count reveals the efficiency benefit of sparse computation.

# --- Data ---------------------------------------------------------------
# MNIST digits converted to raw tensors; no normalization for this demo.
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000)

# --- Model --------------------------------------------------------------
# Four experts with top-2 routing; Adam at the standard 1e-3 learning rate.
model = MoENet(n_experts=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train for a few epochs, tracking the mean total loss per epoch.
losses = []
for epoch in range(5):
    model.train()
    running = 0.0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        logits, load_loss = model(x)
        ce = F.cross_entropy(logits, y)

        # The auxiliary load-balancing term keeps all experts in use.
        total = ce + 0.01 * load_loss

        optimizer.zero_grad()
        total.backward()
        optimizer.step()

        running += total.item()

    avg = running / len(train_loader)
    losses.append(avg)
    print(f"Epoch {epoch+1}, Loss: {avg:.4f}")

# Measure classification accuracy on the held-out test set.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        logits, _ = model(x)
        predictions = logits.argmax(dim=1)
        correct += (predictions == y).sum().item()
        total += y.size(0)

print(f"\nTest Accuracy: {100 * correct / total:.2f}%")

Analyze Expert UsageΒΆ

After training, analyzing which inputs each expert specializes in reveals the learned division of labor. For image classification, different experts may specialize in different classes, different visual styles, or different difficulty levels. Visualizing the gating weights as a function of the input class or feature space provides interpretable insight into the model’s internal routing strategy. Healthy expert usage shows all experts being utilized with some degree of specialization, whereas pathological usage shows a few experts handling most inputs while others remain idle.

# Analyze which experts are used.
# Generalized: read the expert count from the model instead of hard-coding 4,
# so this cell keeps working if MoENet is built with a different n_experts.
n_experts = model.moe.n_experts
expert_usage = torch.zeros(n_experts)

model.eval()
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        # Re-run the shared trunk, then query the router directly to get
        # the (batch, n_experts) soft routing weights.
        x_feat = F.relu(model.fc1(x.view(-1, 784)))
        gates, _ = model.moe.gate(x_feat)
        # Accumulate soft routing mass per expert over the test set.
        expert_usage += gates.sum(dim=0).cpu()

# Plot usage distribution alongside the training curve.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Expert usage
axes[0].bar(range(n_experts), expert_usage.numpy())
axes[0].set_xlabel('Expert', fontsize=11)
axes[0].set_ylabel('Usage Count', fontsize=11)
axes[0].set_title('Expert Usage Distribution', fontsize=12)
axes[0].grid(True, alpha=0.3, axis='y')

# Training loss
axes[1].plot(losses, 'b-o', markersize=5)
axes[1].set_xlabel('Epoch', fontsize=11)
axes[1].set_ylabel('Loss', fontsize=11)
axes[1].set_title('Training Loss', fontsize=12)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

SummaryΒΆ

Mixture of Experts:ΒΆ

Components:

  1. Multiple expert networks

  2. Gating network for routing

  3. Top-k sparse selection

  4. Load balancing loss

Key Ideas:ΒΆ

  • Conditional computation: Only activate subset of parameters

  • Specialization: Experts learn different sub-tasks

  • Scalability: Add capacity without proportional compute

Load Balancing:ΒΆ

Prevent expert collapse:

  • Importance loss

  • Auxiliary load loss

  • Random routing

Applications:ΒΆ

  • Large language models (Switch Transformer, GShard)

  • Multi-task learning

  • Vision models (V-MoE)

  • Recommendation systems

Advantages:ΒΆ

  • Efficient scaling

  • Task specialization

  • Sparse activation

Challenges:ΒΆ

  • Load imbalance

  • Training instability

  • Communication overhead (distributed)

Advanced Mixture of Experts: Mathematical Foundations and Modern ArchitecturesΒΆ

Table of ContentsΒΆ

  1. Introduction to Mixture of Experts

  2. Mathematical Formulation

  3. Gating Mechanisms

  4. Sparsity and Load Balancing

  5. Switch Transformers

  6. Expert Choice Routing

  7. Hierarchical MoE

  8. Training Dynamics and Challenges

  9. Theoretical Analysis

  10. Modern Applications

  11. Implementation Considerations

1. Introduction to Mixture of Experts (MoE)ΒΆ

Core ConceptΒΆ

Mixture of Experts is a machine learning technique that divides a complex problem into smaller sub-problems, each handled by a specialized β€œexpert” network. A gating network learns to route inputs to the most appropriate experts.

Key Idea:

  • Conditional Computation: Only activate a subset of parameters per input

  • Specialization: Each expert learns different aspects of the data distribution

  • Scalability: Increase model capacity without proportional computation cost

Historical ContextΒΆ

  • 1991: Jacobs et al. introduce MoE for regression

  • 2017: Shazeer et al. apply MoE to neural machine translation (1000+ experts)

  • 2021: Google’s Switch Transformers achieve trillion-parameter models

  • 2022-2024: MoE becomes standard in large language models (GPT-4, Mixtral)

Why MoE?ΒΆ

Benefits:

  1. Model Capacity: \(N\) experts ≈ \(N\)× parameters without \(N\)× computation

  2. Sample Efficiency: Experts specialize on different data patterns

  3. Transfer Learning: Experts can encode different domains/tasks

  4. Inference Speed: Sparse activation reduces compute

Challenges:

  1. Load Balancing: Prevent all inputs routing to same expert

  2. Training Instability: Gating can collapse to use few experts

  3. Communication Overhead: In distributed settings, routing is expensive

  4. Representation Collapse: Experts may learn redundant functions

2. Mathematical FormulationΒΆ

Standard MoE LayerΒΆ

Given input \(\mathbf{x} \in \mathbb{R}^d\), \(N\) expert networks \(E_1, \ldots, E_N\), and a gating network \(G\):

\[ \text{MoE}(\mathbf{x}) = \sum_{i=1}^{N} g_i(\mathbf{x}) \cdot E_i(\mathbf{x}) \]

where the gating function produces weights:

\[ \mathbf{g}(\mathbf{x}) = \text{Softmax}(G(\mathbf{x})) = \text{Softmax}(W_g \mathbf{x} + \mathbf{b}_g) \]

with \(g_i(\mathbf{x}) \in [0, 1]\) and \(\sum_{i=1}^N g_i(\mathbf{x}) = 1\).

Top-K Sparse GatingΒΆ

To reduce computation, only activate top-\(k\) experts:

\[ \text{MoE}_{\text{sparse}}(\mathbf{x}) = \sum_{i \in \text{Top-K}(\mathbf{g}(\mathbf{x}))} \frac{g_i(\mathbf{x})}{\sum_{j \in \text{Top-K}} g_j(\mathbf{x})} \cdot E_i(\mathbf{x}) \]

Common Choice: \(k=1\) (single expert) or \(k=2\) (two experts)

Computation Savings:

  • Dense: \(O(N \cdot C_{\text{expert}})\)

  • Sparse: \(O(k \cdot C_{\text{expert}})\)

  • Speedup: \(N/k\) (e.g., 8× with \(N=8, k=1\))

Expert NetworksΒΆ

Each expert \(E_i\) is typically a feed-forward network:

\[ E_i(\mathbf{x}) = W_i^{(2)} \sigma(W_i^{(1)} \mathbf{x} + \mathbf{b}_i^{(1)}) + \mathbf{b}_i^{(2)} \]

Parameter Count:

  • Standard FFN: \(d \cdot d_{\text{ff}} + d_{\text{ff}} \cdot d\)

  • MoE with \(N\) experts: \(N \cdot (d \cdot d_{\text{ff}} + d_{\text{ff}} \cdot d)\)

  • Per-token compute: Same as standard FFN if \(k=1\)

3. Gating MechanismsΒΆ

3.1 Softmax Gating (Dense)ΒΆ

\[ g_i(\mathbf{x}) = \frac{\exp(w_i^T \mathbf{x})}{\sum_{j=1}^N \exp(w_j^T \mathbf{x})} \]

Pros: Smooth, differentiable Cons: All experts activated (expensive)

3.2 Top-K Gating with Noisy Top-KΒΆ

Add noise to logits to encourage exploration:

\[ H(\mathbf{x})_i = (W_g \mathbf{x})_i + \text{StandardNormal}() \cdot \text{Softplus}((W_{\text{noise}} \mathbf{x})_i) \]

Then select top-\(k\) from \(H(\mathbf{x})\) and renormalize:

\[\begin{split} g_i(\mathbf{x}) = \begin{cases} \frac{\exp(H(\mathbf{x})_i)}{\sum_{j \in \text{Top-K}} \exp(H(\mathbf{x})_j)} & \text{if } i \in \text{Top-K}(H(\mathbf{x})) \\ 0 & \text{otherwise} \end{cases} \end{split}\]

Benefit: Noise prevents gating from deterministically choosing same experts

3.3 Switch Routing (Top-1 Simplified)ΒΆ

Simplest form: route each token to single expert with highest logit:

\[ i^* = \arg\max_{i} (W_g \mathbf{x})_i \]
\[ \text{MoE}_{\text{switch}}(\mathbf{x}) = E_{i^*}(\mathbf{x}) \]

Advantages:

  • Minimal routing computation

  • Deterministic (during inference)

  • Highest sparsity

Gradient Flow: Use Straight-Through Estimator (STE) for backprop

3.4 Expert Choice RoutingΒΆ

Instead of tokens choosing experts, experts choose tokens:

  1. Each expert \(i\) selects top-\(k\) tokens with highest affinity: $\( \text{Tokens}_i = \text{Top-K}_{\text{tokens}}\left\{(W_g \mathbf{x}_j)_i \mid j=1,\ldots,T\right\} \)$

  2. Each selected token is processed by that expert

Benefit: Perfect load balancing (each expert processes exactly \(k\) tokens)

4. Sparsity and Load BalancingΒΆ

The Load Balancing ProblemΒΆ

Issue: Gating may route most tokens to few experts (collapse)

Consequences:

  • Underutilized experts (waste of capacity)

  • Overloaded experts (slow training, limited batch parallelism)

  • Poor generalization (experts don’t specialize)

Load Balancing LossΒΆ

Importance Loss (Shazeer et al., 2017):

Define importance of expert \(i\) over batch \(B\):

\[ \text{Importance}(i) = \sum_{\mathbf{x} \in B} g_i(\mathbf{x}) \]

We want uniform distribution: \(\text{Importance}(i) \approx |B| / N\) for all \(i\).

Load Balancing Loss:

\[ \mathcal{L}_{\text{load}} = \alpha \cdot \text{CV}^2(\text{Importance}) = \alpha \cdot \frac{N \cdot \sum_i (\text{Importance}(i) - \mu)^2}{\mu^2} \]

where \(\mu = |B| / N\) and \(\text{CV}\) is coefficient of variation.

Alternative Formulation (Switch Transformers):

\[ \mathcal{L}_{\text{aux}} = \alpha \cdot N \sum_{i=1}^N f_i \cdot P_i \]

where:

  • \(f_i\) = fraction of tokens dispatched to expert \(i\)

  • \(P_i\) = mean gating probability for expert \(i\)

Optimal Balance: \(f_i = P_i = 1/N\) ⟹ \(\mathcal{L}_{\text{aux}} = \alpha / N\)

Typical \(\alpha\): 0.01 to 0.1

Capacity FactorΒΆ

Limit tokens per expert to prevent overload:

\[ \text{Capacity}_i = \left(\frac{|B|}{N}\right) \cdot c \]

where \(c\) is capacity factor (typically 1.0 to 1.5).

Overflow Handling:

  1. Dropout: Discard overflow tokens (Switch Transformers)

  2. Residual: Pass overflow through residual connection

  3. Secondary Expert: Route to second-choice expert

5. Switch TransformersΒΆ

ArchitectureΒΆ

Key Innovation: Simplify MoE to top-1 routing for maximum sparsity

Switch Layer:

  1. Replace FFN in Transformer with MoE-FFN

  2. Use deterministic top-1 routing (no noise at inference)

  3. Add auxiliary load balancing loss

Scaling:

  • Switch-Base: 7B parameters (95% in MoE layers)

  • Switch-Large: 26B parameters

  • Switch-C: 1.6 trillion parameters

Expert Capacity and Token DroppingΒΆ

With \(T\) tokens, \(N\) experts, capacity factor \(c\):

\[ \text{Buffer size per expert} = \frac{T}{N} \cdot c \]

Example: \(T=1024\), \(N=8\), \(c=1.25\) ⟹ 160 tokens per expert

If expert receives \(>160\) tokens, overflow is dropped (skips expert processing).

Trade-off:

  • Low \(c\): More drops, faster training, potential quality loss

  • High \(c\): Fewer drops, slower training, better quality

Empirical Finding: \(c=1.0\) to \(1.25\) is optimal

Routing ImprovementsΒΆ

1. Router Z-Loss

Penalize large logits (prevents numerical instability):

\[ \mathcal{L}_{z} = \frac{1}{T} \sum_{t=1}^T \left(\log \sum_{i=1}^N \exp(x_t^T w_i)\right)^2 \]

where \(w_i\) are router weights.

2. Per-Expert Gradients

Compute gradients separately per expert to prevent interference.

3. Lower Precision Training

Use bfloat16 for expert computations (faster, memory efficient).

6. Expert Choice RoutingΒΆ

MotivationΒΆ

Problem with Token Choice: Tokens independently choose experts β†’ unbalanced load

Solution: Experts choose tokens β†’ perfect balance

AlgorithmΒΆ

For \(T\) tokens, \(N\) experts, expert capacity \(C = T/N\):

  1. Compute Affinities: $\( S_{ij} = \mathbf{x}_j^T \mathbf{w}_i \quad \text{(token \)j\( affinity to expert \)i\()} \)$

  2. Expert Selection: For each expert \(i\): $\( \text{Selected}(i) = \text{Top-C}_{j} \{S_{ij}\} \)$

  3. Gating Weights: Normalize affinities for selected tokens: $\( g_{ij} = \frac{\exp(S_{ij})}{\sum_{k \in \text{Selected}(i)} \exp(S_{ik})} \)$

  4. Expert Output: $\( \mathbf{y}_j = \sum_{i: j \in \text{Selected}(i)} g_{ij} \cdot E_i(\mathbf{x}_j) \)$

AdvantagesΒΆ

  1. Perfect Load Balance: Each expert processes exactly \(C\) tokens

  2. No Auxiliary Loss: No need for load balancing regularization

  3. Predictable Performance: No token dropping

  4. Better Specialization: Experts have global view

ImplementationΒΆ

Efficient Top-K Selection:

  • Use per-expert priority queues

  • Complexity: \(O(T \cdot N \log C)\)

Memory Layout:

  • Gather tokens into expert-specific buffers

  • Process all experts in parallel

  • Scatter results back to token positions

7. Hierarchical MoEΒΆ

Two-Level HierarchyΒΆ

Motivation: Coarse-grained routing reduces routing overhead

Architecture:

  1. Primary Router: Route to expert group \(G_k\) (\(K\) groups of \(M\) experts) $\( k^* = \arg\max_k \mathbf{x}^T \mathbf{w}_k^{\text{primary}} \)$

  2. Secondary Router: Within group \(G_{k^*}\), route to expert \(E_i\) $\( i^* = \arg\max_{i \in G_{k^*}} \mathbf{x}^T \mathbf{w}_i^{\text{secondary}} \)$

  3. Output: $\( \text{MoE}(\mathbf{x}) = E_{i^*}(\mathbf{x}) \)$

Parameter Count: \(K + K \cdot M = K(1 + M)\) routing parameters vs. \(K \cdot M\) for flat routing

Use Case: Routing in distributed settings (primary = machine, secondary = expert on machine)

Multi-Path Hierarchical MoEΒΆ

Route to top-\(k_1\) groups, then top-\(k_2\) experts per group:

\[ \text{MoE}(\mathbf{x}) = \sum_{g \in \text{Top-}k_1} \sum_{i \in \text{Top-}k_2(g)} w_{gi} \cdot E_{gi}(\mathbf{x}) \]

where \(w_{gi}\) combines primary and secondary routing weights.

8. Training Dynamics and ChallengesΒΆ

Gradient FlowΒΆ

Challenge: Only \(k\) out of \(N\) experts receive gradients per example

Consequence: Experts update at rate \(\approx k/N\) of dense model

Solutions:

  1. Larger Batch Sizes: Ensure all experts see gradients frequently

  2. Higher Learning Rates: Compensate for sparse gradients

  3. Gradient Accumulation: Accumulate over multiple batches

Router CollapseΒΆ

Problem: Router converges to use only few experts

Indicators:

  • High load balancing loss

  • Few experts have high importance

  • Validation loss diverges from training

Prevention:

  1. Auxiliary Loss: Penalize imbalance

  2. Router Initialization: Random/orthogonal initialization

  3. Warmup: Start with lower sparsity, gradually increase

  4. Expert Dropout: Randomly drop expert activations

Fine-Tuning MoEΒΆ

Challenge: Fine-tuning on small datasets can cause catastrophic forgetting

Strategies:

  1. Freeze Router: Only update expert parameters

  2. LoRA on Experts: Add low-rank adapters to each expert

  3. Distillation: Distill MoE to dense model for deployment

  4. Task-Specific Experts: Add new experts for new task, keep old frozen

Communication in Distributed TrainingΒΆ

All-to-All Communication: In distributed MoE, tokens are scattered to experts on different GPUs

Cost: \(O(\text{model size} / \text{num devices})\) per routing step

Optimizations:

  1. Expert Parallelism: Partition experts across devices

  2. Local Experts: Constrain routing to same device

  3. 2D Parallelism: Combine expert + data parallelism

  4. Asynchronous Routing: Overlap communication and computation

Example (DeepSpeed-MoE):

  • 128 GPUs, 512 experts

  • 4 experts per GPU

  • All-to-all communication delivers each token to the GPU hosting its selected expert

  • Bandwidth: ~100 GB/s per GPU

9. Theoretical AnalysisΒΆ

ExpressivenessΒΆ

Theorem (Jacobs et al., 1991): A mixture of \(N\) linear experts can represent any piecewise linear function with at most \(N\) pieces.

Extension to Neural Experts: With ReLU activations, MoE can approximate arbitrary continuous functions with fewer parameters than dense network.

Sample ComplexityΒΆ

Divide-and-Conquer Benefit:

If data distribution has \(K\) distinct modes, and each expert specializes in one mode:

\[ \text{Sample Complexity}(MoE) \approx O\left(\frac{d \log K}{K}\right) \]

vs. dense model:

\[ \text{Sample Complexity}(Dense) \approx O(d \log K) \]

Intuition: Each expert learns simpler function (lower capacity) on specialized data subset.

GeneralizationΒΆ

Generalization Bound (simplified):

With probability \(\geq 1 - \delta\):

\[ \text{Error}_{\text{test}} \leq \text{Error}_{\text{train}} + O\left(\sqrt{\frac{N \cdot \text{Capacity}(\text{expert}) \log(1/\delta)}{n}}\right) \]

Key Insight: Effective capacity is \(N \cdot \text{Capacity}(\text{expert})\), but with sparse routing (\(k \ll N\)), generalization improves.

10. Modern ApplicationsΒΆ

Large Language ModelsΒΆ

GPT-4 (rumored): 8 experts, 220B parameters each, ~1.76T total Mixtral 8x7B: 8 experts, 7B each, 47B total, 13B active GLaM: 64 experts, 1.2T parameters

Benefits:

  • Faster inference than dense models of similar quality

  • Better multilingual performance (experts specialize by language)

  • Task specialization (coding, math, creative writing)

Vision ModelsΒΆ

V-MoE (Vision MoE): Apply MoE to Vision Transformers

  • Replace every other FFN layer with MoE

  • 30× parameter increase, 2× training time

  • SOTA on ImageNet with fewer FLOPs

BEiT-3: Unified vision-language model with MoE

  • Experts specialize in vision vs. language vs. cross-modal

Multimodal ModelsΒΆ

Multimodal MoE: Different experts for different modalities

Example Architecture:

  • Expert 1-4: Text processing

  • Expert 5-8: Image processing

  • Expert 9-12: Video processing

  • Router learns modality-aware routing

Benefits: Efficient scaling to many modalities

Recommendation SystemsΒΆ

Challenges: Billions of users, items, need real-time inference

MoE Solution:

  • Experts specialize by user segment (geography, behavior cluster)

  • Low latency (top-1 routing)

  • Personalization (user-aware routing)

11. Implementation ConsiderationsΒΆ

Memory ManagementΒΆ

Expert Parameters: Typically stored on same device, but can be distributed:

# Single-GPU: All experts on same GPU
experts = nn.ModuleList([Expert(d) for _ in range(N)])

# Multi-GPU: Experts distributed across GPUs
experts = [Expert(d).to(f'cuda:{i % num_gpus}') for i in range(N)]

Activation Memory: With batch size \(B\), sequence length \(T\):

  • Dense FFN: \(O(B \cdot T \cdot d_{\text{ff}})\)

  • MoE: \(O(k \cdot B \cdot T \cdot d_{\text{ff}})\) (same if \(k=1\))

Inference OptimizationΒΆ

Batching: Group requests by selected expert to maximize GPU utilization

Caching: For static routing (e.g., language-specific experts), cache expert selection

Quantization: Quantize experts to int8/int4 (60-75% compression)

Distillation: Distill MoE to smaller dense model: $\( \mathcal{L}_{\text{distill}} = \text{KL}\left(P_{\text{student}} \| P_{\text{MoE}}\right) \)$

DebuggingΒΆ

Common Issues:

  1. Router Collapse: Check expert utilization histogram

  2. NaN Losses: Reduce learning rate, add gradient clipping

  3. Poor Load Balance: Increase auxiliary loss weight

  4. Memory OOM: Reduce capacity factor or batch size

Monitoring:

  • Expert utilization per batch

  • Router entropy: \(H = -\sum_i p_i \log p_i\) (should be \(\approx \log N\))

  • Auxiliary loss value

  • Per-expert gradient norms

12. Recent Advances (2022-2024)ΒΆ

Soft Merging of Experts (MoE-Merge)ΒΆ

Idea: Instead of discrete routing, merge expert parameters:

\[ W_{\text{merged}} = \sum_{i=1}^N \alpha_i W_i \]

where \(\alpha_i\) learned via meta-learning.

Benefit: Single-model inference (no routing overhead)

Sparse UpcyclingΒΆ

Convert pretrained dense model to MoE:

  1. Initialize each expert with pretrained weights

  2. Add small random perturbations to break symmetry

  3. Train router from scratch, freeze experts initially

  4. Gradually unfreeze experts

Result: Faster convergence than training MoE from scratch

MoE with RetrievalΒΆ

Combine MoE with retrieval-augmented generation:

  1. Retrieve relevant documents

  2. Route query + documents to specialized expert

  3. Experts fine-tuned on different knowledge domains

Example: Medical expert, legal expert, scientific expert

Continuous Expert ModelsΒΆ

Replace discrete experts with continuous expert space:

\[ E(\mathbf{x}, \boldsymbol{\theta}(\mathbf{x})) \]

where \(\boldsymbol{\theta}(\mathbf{x})\) is a continuous function of input.

Implementation: Hypernetwork generates expert parameters conditioned on \(\mathbf{x}\).

Best PracticesΒΆ

When to Use MoEΒΆ

Good Fit:

  • Large datasets with diverse distributions

  • Need to scale model capacity without proportional compute

  • Multi-task, multi-domain, or multilingual scenarios

  • Have distributed training infrastructure

Poor Fit:

  • Small datasets (experts won’t specialize)

  • Real-time inference with strict latency (routing overhead)

  • Limited training resources (complex to tune)

  • Need interpretability (routing is learned, not explicit)

Hyperparameter TuningΒΆ

Priority Order:

  1. Number of Experts \(N\): Start with 4-8, increase if underutilized

  2. Top-K: Usually 1-2 (higher = more compute, less specialization)

  3. Capacity Factor \(c\): 1.0-1.5 (higher = fewer drops, slower)

  4. Auxiliary Loss \(\alpha\): 0.01-0.1 (tune to balance load)

  5. Expert Size: Often same as original FFN

Tuning Strategy:

  1. Train small MoE on subset of data

  2. Monitor expert utilization and auxiliary loss

  3. Adjust \(\alpha\) if utilization imbalanced

  4. Scale to full model

EvaluationΒΆ

Metrics:

  • Perplexity/Accuracy: Standard task metrics

  • Router Entropy: \(H = -\sum p_i \log p_i\) (high = good)

  • Expert Utilization: Fraction of tokens per expert

  • Efficiency: FLOPs per token vs. dense model

  • Load Balance: Coefficient of variation in expert load

Ablations:

  • MoE vs. dense with same FLOPs

  • MoE vs. dense with same parameters

  • Top-1 vs. top-2 routing

  • With vs. without auxiliary loss

Common PitfallsΒΆ

  1. Insufficient Batch Size: Experts don’t see enough gradients

    • Fix: Increase batch size or gradient accumulation

  2. Router Overfitting: Router uses same experts for all inputs

    • Fix: Add noise to routing, increase auxiliary loss

  3. Ignoring Communication Costs: All-to-all can dominate runtime

    • Fix: Use expert parallelism, profile communication

  4. Forgetting Capacity Factor: Dropping too many tokens

    • Fix: Monitor drop rate, increase \(c\) if needed

  5. Equal Expert Sizes: Some tasks need bigger experts

    • Fix: Use heterogeneous expert sizes

Future DirectionsΒΆ

  1. Adaptive Sparsity: Learn \(k\) per input (easy vs. hard examples)

  2. Hierarchical Specialization: Experts within experts

  3. Neurosymbolic MoE: Symbolic reasoning experts + neural experts

  4. Privacy-Preserving MoE: Federated experts (user data stays local)

  5. Dynamic Expert Addition: Continual learning with new experts for new tasks

  6. Energy-Efficient MoE: Hardware-aware routing for edge deployment

SummaryΒΆ

Mixture of Experts enables scaling neural networks to trillions of parameters while maintaining inference efficiency through sparse activation. Key innovations include:

  • Switch Transformers: Simplified top-1 routing with load balancing

  • Expert Choice: Inverting routing for perfect load balance

  • Hierarchical MoE: Reducing routing overhead in distributed settings

  • Soft Merging: Combining experts for deployment

Success Stories:

  • GPT-4, Mixtral, GLaM: SOTA language models

  • V-MoE: Efficient vision models

  • DeepSpeed-MoE: Distributed training framework

Key Takeaway: MoE is not just about scalingβ€”it’s about learning to specialize and conditionally compute, which is fundamental to how biological brains work.

ReferencesΒΆ

Foundational Papers:

  1. Jacobs et al. (1991): Adaptive Mixtures of Local Experts

  2. Shazeer et al. (2017): Outrageously Large Neural Networks

  3. Lepikhin et al. (2020): GShard

  4. Fedus et al. (2021): Switch Transformers

  5. Zhou et al. (2022): Mixture-of-Experts with Expert Choice Routing

Modern Applications: 6. Riquelme et al. (2021): Scaling Vision with Sparse MoE 7. Mustafa et al. (2022): Multimodal Contrastive Learning with MoE 8. Zoph (2022): ST-MoE

Analysis: 9. Artetxe et al. (2021): Specialization in MoE 10. Clark et al. (2022): Unified Scaling Laws for MoE

"""
Advanced Mixture of Experts - Production Implementation
Comprehensive implementations of modern MoE architectures

Methods Implemented:
1. Standard MoE with Top-K Gating
2. Noisy Top-K Gating (Shazeer et al.)
3. Switch Transformer (Top-1)
4. Expert Choice Routing
5. Hierarchical MoE
6. Load Balancing Mechanisms

Features:
- Efficient sparse routing
- Load balancing with auxiliary losses
- Expert capacity and token dropping
- Visualization of expert utilization
- Distributed MoE support foundations

Author: Advanced Deep Learning Course
Date: 2024
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, List
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass


# ============================================================================
# 1. Expert Networks
# ============================================================================

class Expert(nn.Module):
    """Single expert: a dropout-regularized two-layer feed-forward network."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, input_dim]
        Returns:
            Output tensor [batch_size, output_dim]
        """
        # fc1 -> ReLU -> dropout -> fc2, expressed as one pipeline.
        return self.fc2(self.dropout(self.activation(self.fc1(x))))


# ============================================================================
# 2. Standard MoE with Top-K Gating
# ============================================================================

class TopKGate(nn.Module):
    """
    Learned top-k router.

    Scores every expert with a bias-free linear projection, keeps the k
    highest-scoring experts per token, and renormalizes their scores with
    a softmax to produce the mixing weights.
    """

    def __init__(self, input_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # One logit per expert; no bias term.
        self.gate = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input [batch_size, input_dim]

        Returns:
            gates: Routing weights over selected experts [batch_size, top_k]
            indices: Selected expert ids [batch_size, top_k]
            load: Total gate mass received by each expert [num_experts]
        """
        logits = self.gate(x)  # [batch_size, num_experts]

        # Keep only the k best-scoring experts per token.
        best_logits, best_experts = logits.topk(self.top_k, dim=-1)
        weights = best_logits.softmax(dim=-1)

        # Scatter the sparse weights into a dense [batch, num_experts]
        # matrix; its column sums are the per-expert load statistic used
        # by the balancing loss.
        dense = torch.zeros_like(logits)
        dense.scatter_(1, best_experts, weights)
        return weights, best_experts, dense.sum(dim=0)


class MixtureOfExperts(nn.Module):
    """
    Standard Mixture of Experts layer with Top-K gating.

    Each token is processed by its top-k experts and the expert outputs
    are blended with the softmax-normalized gate weights. An auxiliary
    coefficient-of-variation loss discourages routing collapse.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        load_balance_weight: float = 0.01
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.load_balance_weight = load_balance_weight

        # Expert FFNs, all identically shaped.
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)
        ])

        # Gating network
        self.gate = TopKGate(input_dim, num_experts, top_k)

        # Cumulative tokens routed to each expert. A buffer so it follows
        # .to(device) but never receives gradients.
        self.register_buffer('expert_counts', torch.zeros(num_experts))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input [batch_size, input_dim]

        Returns:
            output: MoE output [batch_size, output_dim]
            aux_loss: Load balancing auxiliary loss (scalar)
        """
        batch_size = x.size(0)

        # Get gating decisions
        gates, indices, load = self.gate(x)  # gates: [B, k], indices: [B, k], load: [E]

        # Match the input dtype so mixed-precision inputs round-trip
        # instead of silently promoting to float32.
        output = torch.zeros(
            batch_size, self.experts[0].fc2.out_features,
            device=x.device, dtype=x.dtype
        )

        # Route to experts and aggregate: for each of the k routing slots,
        # batch together the tokens assigned to each expert so every expert
        # runs at most one forward pass per slot.
        for i in range(self.top_k):
            expert_indices = indices[:, i]  # [batch_size]
            expert_gates = gates[:, i:i+1]  # [batch_size, 1]

            for expert_id in range(self.num_experts):
                mask = (expert_indices == expert_id)
                if mask.any():
                    expert_input = x[mask]
                    expert_output = self.experts[expert_id](expert_input)

                    # Weighted aggregate
                    output[mask] += expert_gates[mask] * expert_output

                    # Track usage statistics
                    self.expert_counts[expert_id] += mask.sum()

        # Compute load balancing loss
        aux_loss = self.compute_load_balance_loss(load, batch_size)

        return output, aux_loss

    def compute_load_balance_loss(self, load: torch.Tensor, batch_size: int) -> torch.Tensor:
        """
        Importance loss: squared coefficient of variation of the expert
        loads, scaled by `load_balance_weight * num_experts`. Zero when
        every expert receives identical gate mass.

        Args:
            load: Per-expert gate mass [num_experts]
            batch_size: Number of tokens (unused; kept for interface stability)

        Returns:
            aux_loss: Load balancing loss (scalar)
        """
        mean_load = load.mean()
        std_load = load.std()
        # Epsilon guards division when all gate mass is zero.
        cv_squared = (std_load / (mean_load + 1e-10)) ** 2

        return self.load_balance_weight * self.num_experts * cv_squared


# ============================================================================
# 3. Noisy Top-K Gating
# ============================================================================

class NoisyTopKGate(nn.Module):
    """
    Noisy Top-K Gating (Shazeer et al., 2017).

    During training, Gaussian noise with a learned, input-dependent scale
    is added to the routing logits so near-tied experts are occasionally
    explored. At eval time routing is deterministic.
    """

    def __init__(self, input_dim: int, num_experts: int, top_k: int = 2, noise_std: float = 1.0):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std

        # Two heads: one scores experts, one predicts the noise scale.
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        self.noise_gate = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input [batch_size, input_dim]

        Returns:
            gates: Routing weights [batch_size, top_k]
            indices: Expert indices [batch_size, top_k]
            load: Per-expert gate mass [num_experts]
        """
        logits = self.gate(x)  # [batch_size, num_experts]

        if self.training:
            # softplus keeps the learned per-expert noise scale non-negative.
            scale = F.softplus(self.noise_gate(x))
            logits = logits + torch.randn_like(logits) * self.noise_std * scale

        # Sparse selection followed by renormalization over the chosen k.
        best_logits, best_experts = logits.topk(self.top_k, dim=-1)
        weights = best_logits.softmax(dim=-1)

        dense = torch.zeros_like(logits)
        dense.scatter_(1, best_experts, weights)
        return weights, best_experts, dense.sum(dim=0)


# ============================================================================
# 4. Switch Transformer (Top-1 Routing)
# ============================================================================

@dataclass
class SwitchConfig:
    """Configuration for Switch Transformer top-1 routing."""
    capacity_factor: float = 1.25  # per-expert buffer = (batch / num_experts) * this factor
    drop_tokens: bool = True  # NOTE(review): not currently read by SwitchGate — overflow tokens are always dropped
    load_balance_weight: float = 0.01  # coefficient on the auxiliary load-balancing loss


class SwitchGate(nn.Module):
    """
    Switch routing: simplified top-1 gating.
    
    Each token routed to exactly one expert.
    """
    
    def __init__(self, input_dim: int, num_experts: int, config: SwitchConfig):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.config = config
        
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Args:
            x: Input [batch_size, input_dim]
        
        Returns:
            dispatch_mask: Routing decisions [num_experts, capacity, batch_size]
            combine_weights: Combining weights [batch_size, num_experts, capacity]
            metadata: Dict with auxiliary loss and statistics
        """
        batch_size = x.size(0)
        
        # Compute routing logits
        logits = self.gate(x)  # [batch_size, num_experts]
        
        # Top-1 routing
        expert_indices = torch.argmax(logits, dim=-1)  # [batch_size]
        expert_gates = F.softmax(logits, dim=-1)  # [batch_size, num_experts]
        
        # Capacity per expert
        capacity = int((batch_size / self.num_experts) * self.config.capacity_factor)
        
        # Create dispatch mask and combine weights
        dispatch_mask, combine_weights, dropped_tokens = self._create_dispatch_and_combine(
            expert_indices, expert_gates, batch_size, capacity
        )
        
        # Compute auxiliary loss
        aux_loss = self._compute_switch_loss(expert_indices, expert_gates, batch_size)
        
        metadata = {
            'aux_loss': aux_loss,
            'dropped_tokens': dropped_tokens,
            'expert_counts': torch.bincount(expert_indices, minlength=self.num_experts)
        }
        
        return dispatch_mask, combine_weights, metadata
    
    def _create_dispatch_and_combine(
        self,
        expert_indices: torch.Tensor,
        expert_gates: torch.Tensor,
        batch_size: int,
        capacity: int
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Create dispatch mask and combine weights with capacity constraints."""
        device = expert_indices.device
        
        # Initialize
        dispatch_mask = torch.zeros(self.num_experts, capacity, batch_size, device=device)
        combine_weights = torch.zeros(batch_size, self.num_experts, capacity, device=device)
        
        # Track position in each expert's buffer
        expert_positions = torch.zeros(self.num_experts, dtype=torch.long, device=device)
        dropped_tokens = 0
        
        for i in range(batch_size):
            expert = expert_indices[i].item()
            pos = expert_positions[expert].item()
            
            if pos < capacity:
                # Dispatch to expert
                dispatch_mask[expert, pos, i] = 1
                combine_weights[i, expert, pos] = expert_gates[i, expert]
                expert_positions[expert] += 1
            else:
                # Overflow: drop token
                dropped_tokens += 1
        
        return dispatch_mask, combine_weights, dropped_tokens
    
    def _compute_switch_loss(
        self,
        expert_indices: torch.Tensor,
        expert_gates: torch.Tensor,
        batch_size: int
    ) -> torch.Tensor:
        """
        Switch auxiliary loss.
        
        L_aux = Ξ± * N * Ξ£_i (f_i * P_i)
        where f_i = fraction routed to expert i, P_i = mean gate probability for expert i
        """
        # Fraction of tokens routed to each expert
        counts = torch.bincount(expert_indices, minlength=self.num_experts).float()
        f = counts / batch_size  # [num_experts]
        
        # Mean gate probability per expert
        P = expert_gates.mean(dim=0)  # [num_experts]
        
        # Auxiliary loss
        aux_loss = self.config.load_balance_weight * self.num_experts * (f * P).sum()
        
        return aux_loss


class SwitchFFN(nn.Module):
    """
    Switch Feed-Forward Network.

    Replaces a dense FFN with `num_experts` expert FFNs plus top-1 switch
    routing; each token is processed by at most one expert.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_experts: int = 8,
        config: Optional[SwitchConfig] = None
    ):
        super().__init__()
        self.num_experts = num_experts
        self.config = config or SwitchConfig()

        # Experts map input_dim -> input_dim so the layer is shape-preserving.
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, input_dim) for _ in range(num_experts)
        ])

        # Switch gate
        self.gate = SwitchGate(input_dim, num_experts, self.config)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            x: Input [batch_size, input_dim]

        Returns:
            output: MoE output [batch_size, input_dim]
            metadata: Routing statistics and aux loss (from SwitchGate)
        """
        # Routing decisions. The gate owns the capacity computation and the
        # mask/weight shapes already encode it, so it is not recomputed here
        # (the previous local `capacity` was dead code that could drift from
        # the gate's actual value).
        dispatch_mask, combine_weights, metadata = self.gate(x)

        # Dispatch: gather each expert's tokens into a [capacity, input_dim]
        # buffer. The one-hot mask rows select tokens; unused slots are zero.
        expert_outputs = []
        for expert_id in range(self.num_experts):
            mask = dispatch_mask[expert_id]  # [capacity, batch_size]
            expert_input = torch.einsum('cb,bd->cd', mask, x)  # [capacity, input_dim]
            expert_outputs.append(self.experts[expert_id](expert_input))

        # Combine: weight each expert's slot outputs back to token positions.
        # Dropped tokens have all-zero combine weights and thus a zero output.
        output = torch.zeros_like(x)
        for expert_id in range(self.num_experts):
            weights = combine_weights[:, expert_id, :]  # [batch_size, capacity]
            output += torch.einsum('bc,cd->bd', weights, expert_outputs[expert_id])

        return output, metadata


# ============================================================================
# 5. Expert Choice Routing
# ============================================================================

class ExpertChoiceGate(nn.Module):
    """
    Expert Choice routing.

    Inverts token-choice routing: each expert picks its top-`capacity`
    tokens by affinity score, so expert load is perfectly balanced by
    construction. A token picked by several experts splits its combine
    weight across them via a softmax over those experts' scores; a token
    picked by no expert receives zero combine weight everywhere.
    """

    def __init__(self, input_dim: int, num_experts: int, capacity_per_expert: int):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.capacity = capacity_per_expert

        # Expert embedding vectors; affinity = dot(expert, token).
        self.expert_weights = nn.Parameter(torch.randn(num_experts, input_dim))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input [batch_size, input_dim]; batch_size must be >= capacity
               so every expert can select a full buffer.

        Returns:
            dispatch_mask: [num_experts, capacity, batch_size]
            combine_weights: [batch_size, num_experts, capacity]
        """
        batch_size = x.size(0)

        # Affinity scores: S[e, t] = expert_e . token_t
        scores = torch.matmul(self.expert_weights, x.t())  # [num_experts, batch_size]

        # Each expert selects its top-`capacity` tokens.
        top_scores, selected = torch.topk(scores, self.capacity, dim=1)  # [num_experts, capacity]

        # One-hot dispatch mask: slot (e, c) holds token selected[e, c].
        dispatch_mask = torch.zeros(
            self.num_experts, self.capacity, batch_size, device=x.device, dtype=x.dtype
        )
        dispatch_mask.scatter_(2, selected.unsqueeze(-1), 1.0)

        # Combine weights: for each token, softmax over the scores of the
        # (expert, slot) pairs that selected it, done as a masked softmax
        # over the flattened slot axis. This is fully vectorized (replacing
        # the previous O(E*C*B) Python loops) and keeps the scores as live
        # tensors so gradients reach expert_weights — the old per-token loop
        # rebuilt them with torch.tensor(), detaching them from the graph.
        slot_scores = top_scores.reshape(-1)               # [E*C]
        slot_mask = dispatch_mask.reshape(-1, batch_size)  # [E*C, B]
        masked_scores = slot_scores.unsqueeze(1).masked_fill(slot_mask == 0, float('-inf'))
        weights = F.softmax(masked_scores, dim=0)          # [E*C, B]
        # Tokens selected by no expert yield an all -inf column -> NaN;
        # zero those entries to match the original behavior.
        weights = torch.nan_to_num(weights, nan=0.0) * slot_mask

        combine_weights = weights.t().reshape(batch_size, self.num_experts, self.capacity)
        return dispatch_mask, combine_weights


class ExpertChoiceMoE(nn.Module):
    """MoE layer using Expert Choice routing (experts pick their tokens)."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_experts: int = 8,
        tokens_per_expert: int = 32
    ):
        super().__init__()
        self.num_experts = num_experts

        # Shape-preserving FFN experts (input_dim -> input_dim).
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, input_dim) for _ in range(num_experts)
        ])

        # Router in which experts select their own token buffers.
        self.gate = ExpertChoiceGate(input_dim, num_experts, tokens_per_expert)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input [batch_size, input_dim]

        Returns:
            output: [batch_size, input_dim]
        """
        dispatch_mask, combine_weights = self.gate(x)

        # Pass 1: every expert processes the tokens it selected
        # (gathered into a [capacity, input_dim] buffer via the mask).
        per_expert = [
            expert(torch.einsum('cb,bd->cd', dispatch_mask[eid], x))
            for eid, expert in enumerate(self.experts)
        ]

        # Pass 2: scatter slot outputs back to token positions, weighted
        # by the router's combine weights.
        output = torch.zeros_like(x)
        for eid, slots_out in enumerate(per_expert):
            output += torch.einsum('bc,cd->bd', combine_weights[:, eid, :], slots_out)

        return output


# ============================================================================
# 6. Hierarchical MoE
# ============================================================================

class HierarchicalMoE(nn.Module):
    """
    Two-level hierarchical MoE.

    A primary router hard-assigns each token to one expert group (argmax),
    then a group-specific secondary router hard-assigns it to one expert
    inside that group. Routing cost scales with num_groups plus
    experts_per_group rather than the full expert count.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_groups: int = 4,
        experts_per_group: int = 4
    ):
        super().__init__()
        self.num_groups = num_groups
        self.experts_per_group = experts_per_group
        self.total_experts = num_groups * experts_per_group

        # Level 1: choose a group.
        self.primary_gate = nn.Linear(input_dim, num_groups, bias=False)

        # Level 2: one router per group chooses an expert inside it.
        self.secondary_gates = nn.ModuleList([
            nn.Linear(input_dim, experts_per_group, bias=False)
            for _ in range(num_groups)
        ])

        # Experts laid out as num_groups lists of experts_per_group FFNs.
        self.expert_groups = nn.ModuleList([
            nn.ModuleList([
                Expert(input_dim, hidden_dim, input_dim)
                for _ in range(experts_per_group)
            ])
            for _ in range(num_groups)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input [batch_size, input_dim]

        Returns:
            output: [batch_size, input_dim]
        """
        # Hard primary routing: one group id per token.
        chosen_group = torch.argmax(self.primary_gate(x), dim=-1)  # [batch_size]

        output = torch.zeros_like(x)

        for gid, (router, experts) in enumerate(zip(self.secondary_gates, self.expert_groups)):
            in_group = (chosen_group == gid)
            if not in_group.any():
                continue

            tokens = x[in_group]  # [n_group, input_dim]

            # Hard secondary routing inside the group.
            chosen_expert = torch.argmax(router(tokens), dim=-1)

            # Each expert processes its subset of the group's tokens.
            processed = torch.zeros_like(tokens)
            for eid, expert in enumerate(experts):
                hits = (chosen_expert == eid)
                if hits.any():
                    processed[hits] = expert(tokens[hits])

            # Scatter the group's results back to original token positions.
            output[in_group] = processed

        return output


# ============================================================================
# 7. Visualization and Analysis
# ============================================================================

def visualize_expert_utilization(
    expert_counts: torch.Tensor,
    num_experts: int,
    title: str = "Expert Utilization"
):
    """Plot tokens-handled-per-expert as a bar chart, with the mean marked."""
    counts = expert_counts.cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(range(num_experts), counts)
    ax.set_xlabel('Expert ID')
    ax.set_ylabel('Number of Tokens')
    ax.set_title(title)
    # Dashed red line at the mean makes load imbalance visible at a glance.
    ax.axhline(y=counts.mean(), color='r', linestyle='--', label='Mean')
    ax.legend()
    fig.tight_layout()
    plt.show()


def compute_routing_entropy(gates: torch.Tensor) -> float:
    """
    Entropy of the batch-averaged routing distribution.

    High entropy means tokens are spread evenly across experts; entropy
    near zero indicates routing has collapsed onto a few experts.

    Args:
        gates: Routing probabilities [batch_size, num_experts]

    Returns:
        Entropy (nats) as a Python float.
    """
    avg_probs = gates.mean(dim=0)             # [num_experts]
    log_probs = torch.log(avg_probs + 1e-10)  # epsilon guards log(0)
    return float(-(avg_probs * log_probs).sum())


# ============================================================================
# 8. Demo
# ============================================================================

def demo_moe_comparison():
    """Compare different MoE routing strategies.

    Builds four MoE variants on random data, runs one forward pass each,
    and reports output shapes, auxiliary losses, expert utilization plots,
    and a rough total-vs-active parameter comparison.
    """
    print("=" * 80)
    print("Mixture of Experts - Routing Strategy Comparison")
    print("=" * 80)
    
    # Setup
    batch_size = 128
    input_dim = 64
    hidden_dim = 256
    num_experts = 8
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")
    
    # Random input
    x = torch.randn(batch_size, input_dim).to(device)
    
    # Models under comparison. The hierarchical variant uses 4 groups x 2
    # experts, so its total expert count also equals num_experts (8).
    models = {
        'Top-2 MoE': MixtureOfExperts(input_dim, hidden_dim, input_dim, num_experts, top_k=2),
        'Switch (Top-1)': SwitchFFN(input_dim, hidden_dim, num_experts),
        'Expert Choice': ExpertChoiceMoE(input_dim, hidden_dim, num_experts, tokens_per_expert=16),
        'Hierarchical': HierarchicalMoE(input_dim, hidden_dim, num_groups=4, experts_per_group=2)
    }
    
    print(f"\nInput: {x.shape}")
    print(f"Number of experts: {num_experts}")
    
    results = {}
    
    for name, model in models.items():
        model = model.to(device)
        model.eval()
        
        print(f"\n{'='*60}")
        print(f"Model: {name}")
        print(f"{'='*60}")
        
        # Forward pass. Unpacking differs per model: Top-2 returns
        # (output, aux_loss), Switch returns (output, metadata dict),
        # the remaining two return just the output tensor.
        with torch.no_grad():
            if name == 'Top-2 MoE':
                output, aux_loss = model(x)
                print(f"Output shape: {output.shape}")
                print(f"Auxiliary loss: {aux_loss.item():.6f}")
                print(f"Expert utilization:")
                visualize_expert_utilization(model.expert_counts, num_experts, f"{name} - Expert Usage")
                
            elif name == 'Switch (Top-1)':
                output, metadata = model(x)
                print(f"Output shape: {output.shape}")
                print(f"Auxiliary loss: {metadata['aux_loss'].item():.6f}")
                print(f"Dropped tokens: {metadata['dropped_tokens']}")
                print(f"Expert counts: {metadata['expert_counts'].cpu().numpy()}")
                visualize_expert_utilization(metadata['expert_counts'], num_experts, f"{name} - Expert Usage")
                
            else:
                output = model(x)
                print(f"Output shape: {output.shape}")
        
        # Count parameters.
        # NOTE(review): `active_params` is a rough illustrative heuristic
        # (total divided by the number of experts, or by 4*2=8 for the
        # hierarchical model), not an exact per-token count — it ignores
        # the shared gating parameters and the top-k activation count.
        total_params = sum(p.numel() for p in model.parameters())
        active_params = total_params // num_experts if 'Hierarchical' not in name else total_params // (4 * 2)
        
        print(f"\nTotal parameters: {total_params:,}")
        print(f"Active parameters per token: ~{active_params:,}")
        print(f"Sparsity ratio: {total_params / active_params:.1f}x")
        
        results[name] = {
            'output': output,
            'params': total_params,
            'active_params': active_params
        }
    
    # Summary table across all models.
    print("\n" + "=" * 80)
    print("Summary")
    print("=" * 80)
    
    for name, result in results.items():
        print(f"\n{name}:")
        print(f"  Total parameters: {result['params']:,}")
        print(f"  Active per token: {result['active_params']:,}")
        print(f"  Efficiency: {result['params'] / result['active_params']:.1f}x")


def demo_load_balancing():
    """Demonstrate importance of load balancing.

    Trains two identical Top-2 MoE models on a random regression task —
    one with the auxiliary balancing loss enabled, one with its weight set
    to zero — then compares the spread of their expert utilization counts.
    """
    print("\n" + "=" * 80)
    print("Load Balancing Impact Demo")
    print("=" * 80)
    
    batch_size = 256
    input_dim = 32
    hidden_dim = 128
    num_experts = 8
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Create two models: with and without load balancing
    model_with_lb = MixtureOfExperts(
        input_dim, hidden_dim, input_dim,
        num_experts, top_k=2, load_balance_weight=0.1
    ).to(device)
    
    model_without_lb = MixtureOfExperts(
        input_dim, hidden_dim, input_dim,
        num_experts, top_k=2, load_balance_weight=0.0
    ).to(device)
    
    # Train briefly to see effect
    optimizer_with = torch.optim.Adam(model_with_lb.parameters(), lr=0.001)
    optimizer_without = torch.optim.Adam(model_without_lb.parameters(), lr=0.001)
    
    num_steps = 100
    
    print("\nTraining for {} steps...".format(num_steps))
    
    for step in range(num_steps):
        # Fresh random regression batch each step.
        x = torch.randn(batch_size, input_dim).to(device)
        target = torch.randn(batch_size, input_dim).to(device)
        
        # With load balancing: aux loss added to the task loss.
        optimizer_with.zero_grad()
        output_with, aux_loss_with = model_with_lb(x)
        loss_with = F.mse_loss(output_with, target) + aux_loss_with
        loss_with.backward()
        optimizer_with.step()
        
        # Without load balancing: aux loss returned but deliberately
        # excluded from the objective (its weight is 0.0 anyway).
        optimizer_without.zero_grad()
        output_without, aux_loss_without = model_without_lb(x)
        loss_without = F.mse_loss(output_without, target)
        loss_without.backward()
        optimizer_without.step()
    
    # expert_counts accumulates over all training steps (top_k assignments).
    print("\nFinal Expert Utilization:")
    print("\nWith Load Balancing:")
    print(model_with_lb.expert_counts.cpu().numpy())
    visualize_expert_utilization(
        model_with_lb.expert_counts, num_experts,
        "With Load Balancing"
    )
    
    print("\nWithout Load Balancing:")
    print(model_without_lb.expert_counts.cpu().numpy())
    visualize_expert_utilization(
        model_without_lb.expert_counts, num_experts,
        "Without Load Balancing"
    )
    
    # Compute utilization statistics: lower std dev = more balanced routing.
    with_lb_std = model_with_lb.expert_counts.float().std().item()
    without_lb_std = model_without_lb.expert_counts.float().std().item()
    
    print(f"\nUtilization Std Dev:")
    print(f"  With LB: {with_lb_std:.2f}")
    print(f"  Without LB: {without_lb_std:.2f}")
    print(f"  Improvement: {without_lb_std / with_lb_std:.2f}x more balanced")


if __name__ == "__main__":
    print("\nAdvanced Mixture of Experts - Comprehensive Implementation\n")
    
    # Run both demos; each opens matplotlib windows for utilization plots.
    demo_moe_comparison()
    demo_load_balancing()
    
    print("\n" + "=" * 80)
    print("Demo complete! Key insights:")
    print("=" * 80)
    print("1. MoE enables scaling to many parameters with sparse activation")
    print("2. Switch routing (top-1) is simplest and most efficient")
    print("3. Expert Choice provides perfect load balancing")
    print("4. Hierarchical MoE reduces routing overhead")
    print("5. Load balancing loss is critical for training stability")
    print("6. Different routing strategies trade off complexity vs. performance")
    print("=" * 80)