import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Mixture of Experts Concept
Gating Network:
MoE Output:
where \(E_i\) are expert networks.
Reference Materials:
foundation_neural_network.pdf - Foundation Neural Network
class Expert(nn.Module):
    """Single expert: a two-layer ReLU MLP."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Hidden projection, nonlinearity, then output projection.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Map a (batch, input_dim) tensor to (batch, output_dim)."""
        return self.net(x)

print("Expert defined")
Gating NetworkΒΆ
The gating network (or router) decides which expert(s) to activate for each input. It typically consists of a small neural network that takes the input and produces a probability distribution over experts via softmax. In top-k routing, only the \(k\) experts with the highest gating scores are activated, providing computational savings proportional to \(k/N\) where \(N\) is the total number of experts. The gating network must balance load (distributing inputs evenly across experts) and specialization (allowing experts to focus on the inputs they handle best). An auxiliary load-balancing loss penalizes uneven expert utilization.
class GatingNetwork(nn.Module):
    """Linear router emitting sparse top-k mixture weights over experts."""

    def __init__(self, input_dim, n_experts, k=2):
        super().__init__()
        self.n_experts = n_experts
        self.k = k  # number of experts activated per input
        self.gate = nn.Linear(input_dim, n_experts)

    def forward(self, x):
        """
        Returns:
            gates: (batch, n_experts) routing weights
            indices: (batch, k) top-k expert indices
        """
        scores = self.gate(x)
        # Keep only the k largest logits per row ...
        kept, kept_idx = scores.topk(self.k, dim=1)
        # ... renormalize them with a softmax ...
        weights = kept.softmax(dim=1)
        # ... and scatter back so non-selected experts get exactly zero.
        dense = torch.zeros_like(scores).scatter_(1, kept_idx, weights)
        return dense, kept_idx

print("GatingNetwork defined")
Mixture of Experts LayerΒΆ
The MoE layer combines the gating network with a collection of expert sub-networks. For a given input, the gate produces routing weights, the top-\(k\) experts process the input, and their outputs are combined as a weighted sum: \(y = \sum_{i \in \text{top-}k} g_i \cdot E_i(x)\), where \(g_i\) are the gating weights and \(E_i\) are the expert outputs. Each expert has the same architecture but independent parameters, allowing different experts to specialize on different parts of the input space. This sparse activation pattern enables the model to scale to very large parameter counts while keeping per-example computation constant β the same principle behind models like Switch Transformer and GShard.
class MixtureOfExperts(nn.Module):
    """MoE layer: a top-k routed blend of several identical-architecture experts."""

    def __init__(self, input_dim, hidden_dim, output_dim, n_experts=4, k=2):
        super().__init__()
        self.n_experts = n_experts
        self.k = k
        # Independent experts with identical architecture.
        self.experts = nn.ModuleList(
            Expert(input_dim, hidden_dim, output_dim) for _ in range(n_experts)
        )
        # Router producing sparse per-expert weights.
        self.gate = GatingNetwork(input_dim, n_experts, k)

    def forward(self, x):
        """Return (output, load_loss): routed output plus a balance penalty."""
        routing, _ = self.gate(x)  # (batch, n_experts), zero outside top-k
        # Run every expert, then blend using the sparse routing weights.
        stacked = torch.stack([e(x) for e in self.experts], dim=1)  # (B, E, O)
        blended = (routing.unsqueeze(-1) * stacked).sum(dim=1)      # (B, O)
        # Load-balancing penalty: smallest when importance is uniform.
        importance = routing.sum(dim=0)  # (n_experts,)
        load_loss = (importance ** 2).sum() / (importance.sum() ** 2 + 1e-8)
        return blended, load_loss

print("MixtureOfExperts defined")
MoE Network for ClassificationΒΆ
Wrapping the MoE layer within a standard classification pipeline demonstrates how sparse expert routing works in practice. The model consists of shared feature extraction layers (e.g., convolutional backbone), the MoE layer that routes features to specialized experts, and a final classification head. During training, the total loss includes the task loss (cross-entropy) plus the load-balancing loss that encourages uniform expert utilization. This architecture naturally scales: adding more experts increases the modelβs capacity without proportionally increasing inference cost.
class MoENet(nn.Module):
    """MNIST classifier: dense encoder -> MoE layer -> linear head."""

    def __init__(self, n_experts=4):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.moe = MixtureOfExperts(256, 128, 128, n_experts=n_experts, k=2)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return (logits, load_loss) for a batch of images."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc1(flat))
        routed, load_loss = self.moe(hidden)
        logits = self.fc2(F.relu(routed))
        return logits, load_loss

print("MoENet defined")
Train MoEΒΆ
Training a Mixture of Experts model requires monitoring not just the task loss but also the expert utilization statistics: which experts are being selected, how evenly the load is distributed, and whether any experts have become dormant (never selected). Early in training, the router may collapse to always selecting the same expert (expert collapse), which the load-balancing loss is designed to prevent. Comparing MoE training speed and final accuracy against a dense model with the same total parameter count reveals the efficiency benefit of sparse computation.
# ---- Data pipeline: MNIST as plain tensors ----
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000)

# ---- Model & optimizer ----
model = MoENet(n_experts=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ---- Training: cross-entropy plus weighted load-balancing term ----
losses = []
for epoch in range(5):
    model.train()
    running = 0.0
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        logits, balance = model(batch_x)
        # Auxiliary loss keeps the router from collapsing onto few experts.
        loss = F.cross_entropy(logits, batch_y) + 0.01 * balance
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running += loss.item()
    losses.append(running / len(train_loader))
    print(f"Epoch {epoch+1}, Loss: {losses[-1]:.4f}")

# ---- Evaluation: top-1 accuracy on the held-out test set ----
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        logits, _ = model(batch_x)
        correct += (logits.argmax(dim=1) == batch_y).sum().item()
        total += batch_y.size(0)
print(f"\nTest Accuracy: {100 * correct / total:.2f}%")
Analyze Expert UsageΒΆ
After training, analyzing which inputs each expert specializes in reveals the learned division of labor. For image classification, different experts may specialize in different classes, different visual styles, or different difficulty levels. Visualizing the gating weights as a function of the input class or feature space provides interpretable insight into the modelβs internal routing strategy. Healthy expert usage shows all experts being utilized with some degree of specialization, whereas pathological usage shows a few experts handling most inputs while others remain idle.
# Tally the total routing weight each of the 4 experts receives on the test set.
usage = torch.zeros(4)
model.eval()
with torch.no_grad():
    for batch_x, _ in test_loader:
        batch_x = batch_x.to(device)
        # Re-run the shared encoder, then query the MoE router directly.
        features = F.relu(model.fc1(batch_x.view(-1, 784)))
        routing, _ = model.moe.gate(features)
        usage += routing.sum(dim=0).cpu()

# Side-by-side plots: expert utilization and the training-loss curve.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].bar(range(4), usage.numpy())
axes[0].set_xlabel('Expert', fontsize=11)
axes[0].set_ylabel('Usage Count', fontsize=11)
axes[0].set_title('Expert Usage Distribution', fontsize=12)
axes[0].grid(True, alpha=0.3, axis='y')

axes[1].plot(losses, 'b-o', markersize=5)
axes[1].set_xlabel('Epoch', fontsize=11)
axes[1].set_ylabel('Loss', fontsize=11)
axes[1].set_title('Training Loss', fontsize=12)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
SummaryΒΆ
Mixture of Experts:ΒΆ
Components:
Multiple expert networks
Gating network for routing
Top-k sparse selection
Load balancing loss
Key Ideas:ΒΆ
Conditional computation: Only activate subset of parameters
Specialization: Experts learn different sub-tasks
Scalability: Add capacity without proportional compute
Load Balancing:ΒΆ
Prevent expert collapse:
Importance loss
Auxiliary load loss
Random routing
Applications:ΒΆ
Large language models (Switch Transformer, GShard)
Multi-task learning
Vision models (V-MoE)
Recommendation systems
Advantages:ΒΆ
Efficient scaling
Task specialization
Sparse activation
Challenges:ΒΆ
Load imbalance
Training instability
Communication overhead (distributed)
Advanced Mixture of Experts: Mathematical Foundations and Modern ArchitecturesΒΆ
Table of ContentsΒΆ
Introduction to Mixture of Experts
Mathematical Formulation
Gating Mechanisms
Sparsity and Load Balancing
Switch Transformers
Expert Choice Routing
Hierarchical MoE
Training Dynamics and Challenges
Theoretical Analysis
Modern Applications
Implementation Considerations
1. Introduction to Mixture of Experts (MoE)ΒΆ
Core ConceptΒΆ
Mixture of Experts is a machine learning technique that divides a complex problem into smaller sub-problems, each handled by a specialized βexpertβ network. A gating network learns to route inputs to the most appropriate experts.
Key Idea:
Conditional Computation: Only activate a subset of parameters per input
Specialization: Each expert learns different aspects of the data distribution
Scalability: Increase model capacity without proportional computation cost
Historical ContextΒΆ
1991: Jacobs et al. introduce MoE for regression
2017: Shazeer et al. apply MoE to neural machine translation (1000+ experts)
2021: Googleβs Switch Transformers achieve trillion-parameter models
2022-2024: MoE becomes standard in large language models (GPT-4, Mixtral)
Why MoE?ΒΆ
Benefits:
Model Capacity: \(N\) experts → \(N\)× parameters without \(N\)× computation
Sample Efficiency: Experts specialize on different data patterns
Transfer Learning: Experts can encode different domains/tasks
Inference Speed: Sparse activation reduces compute
Challenges:
Load Balancing: Prevent all inputs routing to same expert
Training Instability: Gating can collapse to use few experts
Communication Overhead: In distributed settings, routing is expensive
Representation Collapse: Experts may learn redundant functions
2. Mathematical FormulationΒΆ
Standard MoE LayerΒΆ
Given input \(\mathbf{x} \in \mathbb{R}^d\), \(N\) expert networks \(E_1, \ldots, E_N\), and a gating network \(G\):
where the gating function produces weights:
with \(g_i(\mathbf{x}) \in [0, 1]\) and \(\sum_{i=1}^N g_i(\mathbf{x}) = 1\).
Top-K Sparse GatingΒΆ
To reduce computation, only activate top-\(k\) experts:
Common Choice: \(k=1\) (single expert) or \(k=2\) (two experts)
Computation Savings:
Dense: \(O(N \cdot C_{\text{expert}})\)
Sparse: \(O(k \cdot C_{\text{expert}})\)
Speedup: \(N/k\) (e.g., 8× with \(N=8, k=1\))
Expert NetworksΒΆ
Each expert \(E_i\) is typically a feed-forward network:
Parameter Count:
Standard FFN: \(d \cdot d_{\text{ff}} + d_{\text{ff}} \cdot d\)
MoE with \(N\) experts: \(N \cdot (d \cdot d_{\text{ff}} + d_{\text{ff}} \cdot d)\)
Per-token compute: Same as standard FFN if \(k=1\)
3. Gating MechanismsΒΆ
3.1 Softmax Gating (Dense)ΒΆ
Pros: Smooth, differentiable Cons: All experts activated (expensive)
3.2 Top-K Gating with Noisy Top-KΒΆ
Add noise to logits to encourage exploration:
Then select top-\(k\) from \(H(\mathbf{x})\) and renormalize:
Benefit: Noise prevents gating from deterministically choosing same experts
3.3 Switch Routing (Top-1 Simplified)ΒΆ
Simplest form: route each token to single expert with highest logit:
Advantages:
Minimal routing computation
Deterministic (during inference)
Highest sparsity
Gradient Flow: Use Straight-Through Estimator (STE) for backprop
3.4 Expert Choice RoutingΒΆ
Instead of tokens choosing experts, experts choose tokens:
Each expert \(i\) selects top-\(k\) tokens with highest affinity: $\( \text{Tokens}_i = \text{Top-K}_{\text{tokens}}\left\{(W_g \mathbf{x}_j)_i \mid j=1,\ldots,T\right\} \)$
Each selected token is processed by that expert
Benefit: Perfect load balancing (each expert processes exactly \(k\) tokens)
4. Sparsity and Load BalancingΒΆ
The Load Balancing ProblemΒΆ
Issue: Gating may route most tokens to few experts (collapse)
Consequences:
Underutilized experts (waste of capacity)
Overloaded experts (slow training, limited batch parallelism)
Poor generalization (experts donβt specialize)
Load Balancing LossΒΆ
Importance Loss (Shazeer et al., 2017):
Define importance of expert \(i\) over batch \(B\):
We want uniform distribution: \(\text{Importance}(i) \approx |B| / N\) for all \(i\).
Load Balancing Loss:
where \(\mu = |B| / N\) and \(\text{CV}\) is coefficient of variation.
Alternative Formulation (Switch Transformers):
where:
\(f_i\) = fraction of tokens dispatched to expert \(i\)
\(P_i\) = mean gating probability for expert \(i\)
Optimal Balance: \(f_i = P_i = 1/N\) \(\Rightarrow\) \(\mathcal{L}_{\text{aux}} = \alpha / N\)
Typical \(\alpha\): 0.01 to 0.1
Capacity FactorΒΆ
Limit tokens per expert to prevent overload:
where \(c\) is capacity factor (typically 1.0 to 1.5).
Overflow Handling:
Dropout: Discard overflow tokens (Switch Transformers)
Residual: Pass overflow through residual connection
Secondary Expert: Route to second-choice expert
5. Switch TransformersΒΆ
ArchitectureΒΆ
Key Innovation: Simplify MoE to top-1 routing for maximum sparsity
Switch Layer:
Replace FFN in Transformer with MoE-FFN
Use deterministic top-1 routing (no noise at inference)
Add auxiliary load balancing loss
Scaling:
Switch-Base: 7B parameters (95% in MoE layers)
Switch-Large: 26B parameters
Switch-C: 1.6 trillion parameters
Expert Capacity and Token DroppingΒΆ
With \(T\) tokens, \(N\) experts, capacity factor \(c\):
Example: \(T=1024\), \(N=8\), \(c=1.25\) \(\Rightarrow\) 160 tokens per expert
If expert receives \(>160\) tokens, overflow is dropped (skips expert processing).
Trade-off:
Low \(c\): More drops, faster training, potential quality loss
High \(c\): Fewer drops, slower training, better quality
Empirical Finding: \(c=1.0\) to \(1.25\) optimal
Routing ImprovementsΒΆ
1. Router Z-Loss
Penalize large logits (prevents numerical instability):
where \(w_i\) are router weights.
2. Per-Expert Gradients
Compute gradients separately per expert to prevent interference.
3. Lower Precision Training
Use bfloat16 for expert computations (faster, memory efficient).
6. Expert Choice RoutingΒΆ
MotivationΒΆ
Problem with Token Choice: Tokens independently choose experts → unbalanced load
Solution: Experts choose tokens → perfect balance
AlgorithmΒΆ
For \(T\) tokens, \(N\) experts, expert capacity \(C = T/N\):
Compute Affinities: $\( S_{ij} = \mathbf{x}_j^T \mathbf{w}_i \quad \text{(token \)j\( affinity to expert \)i\()} \)$
Expert Selection: For each expert \(i\): $\( \text{Selected}(i) = \text{Top-C}_{j} \{S_{ij}\} \)$
Gating Weights: Normalize affinities for selected tokens: $\( g_{ij} = \frac{\exp(S_{ij})}{\sum_{k \in \text{Selected}(i)} \exp(S_{ik})} \)$
Expert Output: $\( \mathbf{y}_j = \sum_{i: j \in \text{Selected}(i)} g_{ij} \cdot E_i(\mathbf{x}_j) \)$
AdvantagesΒΆ
Perfect Load Balance: Each expert processes exactly \(C\) tokens
No Auxiliary Loss: No need for load balancing regularization
Predictable Performance: No token dropping
Better Specialization: Experts have global view
ImplementationΒΆ
Efficient Top-K Selection:
Use per-expert priority queues
Complexity: \(O(T \cdot N \log C)\)
Memory Layout:
Gather tokens into expert-specific buffers
Process all experts in parallel
Scatter results back to token positions
7. Hierarchical MoEΒΆ
Two-Level HierarchyΒΆ
Motivation: Coarse-grained routing reduces routing overhead
Architecture:
Primary Router: Route to expert group \(G_k\) (\(K\) groups of \(M\) experts) $\( k^* = \arg\max_k \mathbf{x}^T \mathbf{w}_k^{\text{primary}} \)$
Secondary Router: Within group \(G_{k^*}\), route to expert \(E_i\) $\( i^* = \arg\max_{i \in G_{k^*}} \mathbf{x}^T \mathbf{w}_i^{\text{secondary}} \)$
Output: $\( \text{MoE}(\mathbf{x}) = E_{i^*}(\mathbf{x}) \)$
Parameter Count: \(K + K \cdot M = K(1 + M)\) routing parameters vs. \(K \cdot M\) for flat routing
Use Case: Routing in distributed settings (primary = machine, secondary = expert on machine)
Multi-Path Hierarchical MoEΒΆ
Route to top-\(k_1\) groups, then top-\(k_2\) experts per group:
where \(w_{gi}\) combines primary and secondary routing weights.
8. Training Dynamics and ChallengesΒΆ
Gradient FlowΒΆ
Challenge: Only \(k\) out of \(N\) experts receive gradients per example
Consequence: Experts update at rate \(\approx k/N\) of dense model
Solutions:
Larger Batch Sizes: Ensure all experts see gradients frequently
Higher Learning Rates: Compensate for sparse gradients
Gradient Accumulation: Accumulate over multiple batches
Router CollapseΒΆ
Problem: Router converges to use only few experts
Indicators:
High load balancing loss
Few experts have high importance
Validation loss diverges from training
Prevention:
Auxiliary Loss: Penalize imbalance
Router Initialization: Random/orthogonal initialization
Warmup: Start with lower sparsity, gradually increase
Expert Dropout: Randomly drop expert activations
Fine-Tuning MoEΒΆ
Challenge: Fine-tuning on small datasets can cause catastrophic forgetting
Strategies:
Freeze Router: Only update expert parameters
LoRA on Experts: Add low-rank adapters to each expert
Distillation: Distill MoE to dense model for deployment
Task-Specific Experts: Add new experts for new task, keep old frozen
Communication in Distributed TrainingΒΆ
All-to-All Communication: In distributed MoE, tokens are scattered to experts on different GPUs
Cost: \(O(\text{model size} / \text{num devices})\) per routing step
Optimizations:
Expert Parallelism: Partition experts across devices
Local Experts: Constrain routing to same device
2D Parallelism: Combine expert + data parallelism
Asynchronous Routing: Overlap communication and computation
Example (DeepSpeed-MoE):
128 GPUs, 512 experts
4 experts per GPU
All-to-all brings tokens for each expertβs 4 experts
Bandwidth: ~100 GB/s per GPU
9. Theoretical AnalysisΒΆ
ExpressivenessΒΆ
Theorem (Jacobs et al., 1991): A mixture of \(N\) linear experts can represent any piecewise linear function with at most \(N\) pieces.
Extension to Neural Experts: With ReLU activations, MoE can approximate arbitrary continuous functions with fewer parameters than dense network.
Sample ComplexityΒΆ
Divide-and-Conquer Benefit:
If data distribution has \(K\) distinct modes, and each expert specializes in one mode:
vs. dense model:
Intuition: Each expert learns simpler function (lower capacity) on specialized data subset.
GeneralizationΒΆ
Generalization Bound (simplified):
With probability \(\geq 1 - \delta\):
Key Insight: Effective capacity is \(N \cdot \text{Capacity}(\text{expert})\), but with sparse routing (\(k \ll N\)), generalization improves.
10. Modern ApplicationsΒΆ
Large Language ModelsΒΆ
GPT-4 (rumored): 8 experts, 220B parameters each, ~1.76T total Mixtral 8x7B: 8 experts, 7B each, 47B total, 13B active GLaM: 64 experts, 1.2T parameters
Benefits:
Faster inference than dense models of similar quality
Better multilingual performance (experts specialize by language)
Task specialization (coding, math, creative writing)
Vision ModelsΒΆ
V-MoE (Vision MoE): Apply MoE to Vision Transformers
Replace every other FFN layer with MoE
30× parameter increase, 2× training time
SOTA on ImageNet with fewer FLOPs
BEiT-3: Unified vision-language model with MoE
Experts specialize in vision vs. language vs. cross-modal
Multimodal ModelsΒΆ
Multimodal MoE: Different experts for different modalities
Example Architecture:
Expert 1-4: Text processing
Expert 5-8: Image processing
Expert 9-12: Video processing
Router learns modality-aware routing
Benefits: Efficient scaling to many modalities
Recommendation SystemsΒΆ
Challenges: Billions of users, items, need real-time inference
MoE Solution:
Experts specialize by user segment (geography, behavior cluster)
Low latency (top-1 routing)
Personalization (user-aware routing)
11. Implementation ConsiderationsΒΆ
Memory ManagementΒΆ
Expert Parameters: Typically stored on same device, but can be distributed:
# Single-GPU: All experts on same GPU
experts = nn.ModuleList([Expert(d) for _ in range(N)])
# Multi-GPU: Experts distributed across GPUs
experts = [Expert(d).to(f'cuda:{i % num_gpus}') for i in range(N)]
Activation Memory: With batch size \(B\), sequence length \(T\):
Dense FFN: \(O(B \cdot T \cdot d_{\text{ff}})\)
MoE: \(O(k \cdot B \cdot T \cdot d_{\text{ff}})\) (same if \(k=1\))
Inference OptimizationΒΆ
Batching: Group requests by selected expert to maximize GPU utilization
Caching: For static routing (e.g., language-specific experts), cache expert selection
Quantization: Quantize experts to int8/int4 (60-75% compression)
Distillation: Distill MoE to smaller dense model: $\( \mathcal{L}_{\text{distill}} = \text{KL}\left(P_{\text{student}} \| P_{\text{MoE}}\right) \)$
DebuggingΒΆ
Common Issues:
Router Collapse: Check expert utilization histogram
NaN Losses: Reduce learning rate, add gradient clipping
Poor Load Balance: Increase auxiliary loss weight
Memory OOM: Reduce capacity factor or batch size
Monitoring:
Expert utilization per batch
Router entropy: \(H = -\sum_i p_i \log p_i\) (should be \(\approx \log N\))
Auxiliary loss value
Per-expert gradient norms
12. Recent Advances (2022-2024)ΒΆ
Soft Merging of Experts (MoE-Merge)ΒΆ
Idea: Instead of discrete routing, merge expert parameters:
where \(\alpha_i\) learned via meta-learning.
Benefit: Single-model inference (no routing overhead)
Sparse UpcyclingΒΆ
Convert pretrained dense model to MoE:
Initialize each expert with pretrained weights
Add small random perturbations to break symmetry
Train router from scratch, freeze experts initially
Gradually unfreeze experts
Result: Faster convergence than training MoE from scratch
MoE with RetrievalΒΆ
Combine MoE with retrieval-augmented generation:
Retrieve relevant documents
Route query + documents to specialized expert
Experts fine-tuned on different knowledge domains
Example: Medical expert, legal expert, scientific expert
Continuous Expert ModelsΒΆ
Replace discrete experts with continuous expert space:
where \(\boldsymbol{\theta}(\mathbf{x})\) is a continuous function of input.
Implementation: Hypernetwork generates expert parameters conditioned on \(\mathbf{x}\).
Best PracticesΒΆ
When to Use MoEΒΆ
Good Fit:
Large datasets with diverse distributions
Need to scale model capacity without proportional compute
Multi-task, multi-domain, or multilingual scenarios
Have distributed training infrastructure
Poor Fit:
Small datasets (experts wonβt specialize)
Real-time inference with strict latency (routing overhead)
Limited training resources (complex to tune)
Need interpretability (routing is learned, not explicit)
Hyperparameter TuningΒΆ
Priority Order:
Number of Experts \(N\): Start with 4-8, increase if underutilized
Top-K: Usually 1-2 (higher = more compute, less specialization)
Capacity Factor \(c\): 1.0-1.5 (higher = fewer drops, slower)
Auxiliary Loss \(\alpha\): 0.01-0.1 (tune to balance load)
Expert Size: Often same as original FFN
Tuning Strategy:
Train small MoE on subset of data
Monitor expert utilization and auxiliary loss
Adjust \(\alpha\) if utilization imbalanced
Scale to full model
EvaluationΒΆ
Metrics:
Perplexity/Accuracy: Standard task metrics
Router Entropy: \(H = -\sum p_i \log p_i\) (high = good)
Expert Utilization: Fraction of tokens per expert
Efficiency: FLOPs per token vs. dense model
Load Balance: Coefficient of variation in expert load
Ablations:
MoE vs. dense with same FLOPs
MoE vs. dense with same parameters
Top-1 vs. top-2 routing
With vs. without auxiliary loss
Common PitfallsΒΆ
Insufficient Batch Size: Experts donβt see enough gradients
Fix: Increase batch size or gradient accumulation
Router Overfitting: Router uses same experts for all inputs
Fix: Add noise to routing, increase auxiliary loss
Ignoring Communication Costs: All-to-all can dominate runtime
Fix: Use expert parallelism, profile communication
Forgetting Capacity Factor: Dropping too many tokens
Fix: Monitor drop rate, increase \(c\) if needed
Equal Expert Sizes: Some tasks need bigger experts
Fix: Use heterogeneous expert sizes
Future DirectionsΒΆ
Adaptive Sparsity: Learn \(k\) per input (easy vs. hard examples)
Hierarchical Specialization: Experts within experts
Neurosymbolic MoE: Symbolic reasoning experts + neural experts
Privacy-Preserving MoE: Federated experts (user data stays local)
Dynamic Expert Addition: Continual learning with new experts for new tasks
Energy-Efficient MoE: Hardware-aware routing for edge deployment
SummaryΒΆ
Mixture of Experts enables scaling neural networks to trillions of parameters while maintaining inference efficiency through sparse activation. Key innovations include:
Switch Transformers: Simplified top-1 routing with load balancing
Expert Choice: Inverting routing for perfect load balance
Hierarchical MoE: Reducing routing overhead in distributed settings
Soft Merging: Combining experts for deployment
Success Stories:
GPT-4, Mixtral, GLaM: SOTA language models
V-MoE: Efficient vision models
DeepSpeed-MoE: Distributed training framework
Key Takeaway: MoE is not just about scaling—it's about learning to specialize and conditionally compute, which is fundamental to how biological brains work.
ReferencesΒΆ
Foundational Papers:
Jacobs et al. (1991): Adaptive Mixtures of Local Experts
Shazeer et al. (2017): Outrageously Large Neural Networks
Lepikhin et al. (2020): GShard
Fedus et al. (2021): Switch Transformers
Zhou et al. (2022): Mixture-of-Experts with Expert Choice Routing
Modern Applications: 6. Riquelme et al. (2021): Scaling Vision with Sparse MoE 7. Mustafa et al. (2022): Multimodal Contrastive Learning with MoE 8. Zoph (2022): ST-MoE
Analysis: 9. Artetxe et al. (2021): Specialization in MoE 10. Clark et al. (2022): Unified Scaling Laws for MoE
"""
Advanced Mixture of Experts - Production Implementation
Comprehensive implementations of modern MoE architectures
Methods Implemented:
1. Standard MoE with Top-K Gating
2. Noisy Top-K Gating (Shazeer et al.)
3. Switch Transformer (Top-1)
4. Expert Choice Routing
5. Hierarchical MoE
6. Load Balancing Mechanisms
Features:
- Efficient sparse routing
- Load balancing with auxiliary losses
- Expert capacity and token dropping
- Visualization of expert utilization
- Distributed MoE support foundations
Author: Advanced Deep Learning Course
Date: 2024
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, List
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
# ============================================================================
# 1. Expert Networks
# ============================================================================
class Expert(nn.Module):
    """Single expert: feed-forward block (Linear -> ReLU -> Dropout -> Linear)."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, input_dim]

        Returns:
            Output tensor [batch_size, output_dim]
        """
        hidden = self.dropout(self.activation(self.fc1(x)))
        return self.fc2(hidden)
# ============================================================================
# 2. Standard MoE with Top-K Gating
# ============================================================================
class TopKGate(nn.Module):
    """
    Top-K gating: route each input to its k highest-scoring experts.
    """

    def __init__(self, input_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k
        # Bias-free linear scorer over experts.
        self.gate = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input [batch_size, input_dim]

        Returns:
            gates: Gating weights [batch_size, top_k]
            indices: Expert indices [batch_size, top_k]
            load: Total routing weight per expert [num_experts]
        """
        scores = self.gate(x)
        # Select top-k per row and renormalize those scores with a softmax.
        kept, kept_idx = scores.topk(self.top_k, dim=-1)
        weights = kept.softmax(dim=-1)
        # Scatter into a dense matrix so per-expert load can be summed.
        dense = torch.zeros_like(scores).scatter_(1, kept_idx, weights)
        return weights, kept_idx, dense.sum(0)
class MixtureOfExperts(nn.Module):
    """
    Standard Mixture of Experts layer with Top-K gating.

    Each input is routed to its top-k experts; their outputs are combined
    with the gating weights. An auxiliary coefficient-of-variation loss
    encourages balanced expert utilization.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        load_balance_weight: float = 0.01
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.load_balance_weight = load_balance_weight
        # Independent expert FFNs with identical architecture.
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)
        ])
        # Router producing per-token top-k expert assignments.
        self.gate = TopKGate(input_dim, num_experts, top_k)
        # Running count of tokens dispatched to each expert (monitoring only,
        # no gradients); saved with state_dict via register_buffer.
        self.register_buffer('expert_counts', torch.zeros(num_experts))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input [batch_size, input_dim]

        Returns:
            output: MoE output [batch_size, output_dim]
            aux_loss: Load balancing auxiliary loss (scalar)
        """
        batch_size = x.size(0)
        gates, indices, load = self.gate(x)  # [B, k], [B, k], [E]
        # Output width comes from the experts' final linear layer.
        output = torch.zeros(batch_size, self.experts[0].fc2.out_features, device=x.device)
        # Dispatch: for each of the k routing slots, send the matching rows
        # through their chosen expert and accumulate the weighted result.
        for slot in range(self.top_k):
            slot_experts = indices[:, slot]        # [batch_size]
            slot_gates = gates[:, slot:slot + 1]   # [batch_size, 1]
            for expert_id in range(self.num_experts):
                mask = (slot_experts == expert_id)
                if mask.any():
                    expert_out = self.experts[expert_id](x[mask])
                    output[mask] += slot_gates[mask] * expert_out
                    # Track usage for utilization diagnostics.
                    self.expert_counts[expert_id] += mask.sum()
        aux_loss = self.compute_load_balance_loss(load, batch_size)
        return output, aux_loss

    def compute_load_balance_loss(self, load: torch.Tensor, batch_size: int) -> torch.Tensor:
        """
        Importance loss: encourages equal expert utilization.

        Penalizes the squared coefficient of variation of per-expert load,
        which is zero exactly when every expert receives the same total
        gating weight (Shazeer et al., 2017).

        Args:
            load: Expert loads [num_experts]
            batch_size: Number of tokens (unused; kept for interface stability)

        Returns:
            aux_loss: Load balancing loss (scalar)
        """
        # Coefficient of variation squared (dropped the previously computed
        # but unused `expected_load` value).
        mean_load = load.mean()
        std_load = load.std()
        cv_squared = (std_load / (mean_load + 1e-10)) ** 2
        return self.load_balance_weight * self.num_experts * cv_squared
# ============================================================================
# 3. Noisy Top-K Gating
# ============================================================================
class NoisyTopKGate(nn.Module):
    """
    Noisy Top-K Gating (Shazeer et al., 2017).

    During training, input-dependent Gaussian noise is added to the router
    logits to encourage exploration; at eval time routing is deterministic.
    """

    def __init__(self, input_dim: int, num_experts: int, top_k: int = 2, noise_std: float = 1.0):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        # Two heads: one for clean scores, one controlling noise magnitude.
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        self.noise_gate = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input [batch_size, input_dim]

        Returns:
            gates: Gating weights [batch_size, top_k]
            indices: Expert indices [batch_size, top_k]
            load: Expert load [num_experts]
        """
        scores = self.gate(x)
        if self.training:
            # Unit Gaussian noise, scaled by a learned softplus-positive factor.
            scale = F.softplus(self.noise_gate(x))
            scores = scores + torch.randn_like(scores) * self.noise_std * scale
        kept, kept_idx = scores.topk(self.top_k, dim=-1)
        weights = kept.softmax(dim=-1)
        dense = torch.zeros_like(scores).scatter_(1, kept_idx, weights)
        return weights, kept_idx, dense.sum(0)
# ============================================================================
# 4. Switch Transformer (Top-1 Routing)
# ============================================================================
@dataclass
class SwitchConfig:
    """Configuration for Switch Transformer top-1 routing."""
    # Multiplier on the even-split bucket size (batch / num_experts);
    # values > 1 leave headroom for mildly imbalanced routing.
    capacity_factor: float = 1.25
    # Whether tokens that overflow an expert's capacity are dropped.
    # NOTE(review): this flag is not consulted anywhere in the visible
    # code — overflow tokens are always dropped. Confirm intended use.
    drop_tokens: bool = True
    # Coefficient on the load-balancing auxiliary loss.
    load_balance_weight: float = 0.01
class SwitchGate(nn.Module):
    """
    Switch routing (top-1 gating): each token is routed to exactly one
    expert, subject to a per-expert capacity limit. Tokens that overflow an
    expert's buffer are dropped.
    """
    def __init__(self, input_dim: int, num_experts: int, config: "SwitchConfig"):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.config = config
        # Bias-free router producing one logit per expert.
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Args:
            x: Input [batch_size, input_dim]
        Returns:
            dispatch_mask: Routing decisions [num_experts, capacity, batch_size]
            combine_weights: Combining weights [batch_size, num_experts, capacity]
            metadata: Dict with auxiliary loss and routing statistics
        """
        batch_size = x.size(0)
        # Compute routing logits
        logits = self.gate(x)  # [batch_size, num_experts]
        # Hard top-1 assignment; soft probabilities are still needed for the
        # combine weights and the auxiliary loss.
        expert_indices = torch.argmax(logits, dim=-1)  # [batch_size]
        expert_gates = F.softmax(logits, dim=-1)  # [batch_size, num_experts]
        # Per-expert buffer size. max(1, ...) fixes the degenerate case
        # batch_size < num_experts / capacity_factor, where int() truncation
        # would yield capacity 0 and silently drop every token.
        capacity = max(1, int((batch_size / self.num_experts) * self.config.capacity_factor))
        # Create dispatch mask and combine weights
        dispatch_mask, combine_weights, dropped_tokens = self._create_dispatch_and_combine(
            expert_indices, expert_gates, batch_size, capacity
        )
        # Compute auxiliary loss
        aux_loss = self._compute_switch_loss(expert_indices, expert_gates, batch_size)
        metadata = {
            'aux_loss': aux_loss,
            'dropped_tokens': dropped_tokens,
            'expert_counts': torch.bincount(expert_indices, minlength=self.num_experts)
        }
        return dispatch_mask, combine_weights, metadata
    def _create_dispatch_and_combine(
        self,
        expert_indices: torch.Tensor,
        expert_gates: torch.Tensor,
        batch_size: int,
        capacity: int
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Create dispatch mask and combine weights, enforcing capacity.

        Tokens are placed into their chosen expert's buffer in batch order;
        once an expert's `capacity` slots are full, further tokens routed to
        it are dropped (counted, not rerouted).
        """
        device = expert_indices.device
        dispatch_mask = torch.zeros(self.num_experts, capacity, batch_size, device=device)
        combine_weights = torch.zeros(batch_size, self.num_experts, capacity, device=device)
        # Next free slot in each expert's buffer.
        expert_positions = torch.zeros(self.num_experts, dtype=torch.long, device=device)
        dropped_tokens = 0
        for i in range(batch_size):
            expert = expert_indices[i].item()
            pos = expert_positions[expert].item()
            if pos < capacity:
                # Dispatch token i to slot `pos` of its expert.
                dispatch_mask[expert, pos, i] = 1
                combine_weights[i, expert, pos] = expert_gates[i, expert]
                expert_positions[expert] += 1
            else:
                # Buffer full: drop the token.
                dropped_tokens += 1
        return dispatch_mask, combine_weights, dropped_tokens
    def _compute_switch_loss(
        self,
        expert_indices: torch.Tensor,
        expert_gates: torch.Tensor,
        batch_size: int
    ) -> torch.Tensor:
        """
        Switch auxiliary loss.
        L_aux = alpha * N * sum_i (f_i * P_i)
        where f_i = fraction of tokens routed to expert i and P_i = mean gate
        probability for expert i. Minimized when routing is uniform.
        """
        # Fraction of tokens routed to each expert (no gradient through counts).
        counts = torch.bincount(expert_indices, minlength=self.num_experts).float()
        f = counts / batch_size  # [num_experts]
        # Mean gate probability per expert (differentiable path to the router).
        P = expert_gates.mean(dim=0)  # [num_experts]
        aux_loss = self.config.load_balance_weight * self.num_experts * (f * P).sum()
        return aux_loss
class SwitchFFN(nn.Module):
    """
    Switch Feed-Forward Network.

    Drop-in MoE replacement for a dense FFN layer: top-1 (switch) routing
    across `num_experts` expert FFNs, each mapping
    input_dim -> hidden_dim -> input_dim.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_experts: int = 8,
        config: Optional[SwitchConfig] = None
    ):
        super().__init__()
        self.num_experts = num_experts
        self.config = config or SwitchConfig()
        # One Expert FFN per slot.
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, input_dim) for _ in range(num_experts)
        ])
        # Switch gate (top-1 router with capacity constraints).
        self.gate = SwitchGate(input_dim, num_experts, self.config)
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            x: Input [batch_size, input_dim]
        Returns:
            output: MoE output [batch_size, input_dim]
            metadata: Routing statistics and aux loss from the gate
        """
        # Routing decisions. The gate computes capacity internally, so the
        # previously duplicated (and unused) local capacity was removed.
        dispatch_mask, combine_weights, metadata = self.gate(x)
        # Dispatch: gather each expert's tokens into a [capacity, input_dim]
        # buffer via the 0/1 dispatch mask, then run the expert.
        expert_outputs = []
        for expert_id in range(self.num_experts):
            mask = dispatch_mask[expert_id]  # [capacity, batch_size]
            expert_input = torch.einsum('cb,bd->cd', mask, x)  # [capacity, input_dim]
            expert_outputs.append(self.experts[expert_id](expert_input))
        # Combine: weighted sum of expert outputs back into token order.
        output = torch.zeros_like(x)
        for expert_id in range(self.num_experts):
            weights = combine_weights[:, expert_id, :]  # [batch_size, capacity]
            output += torch.einsum('bc,cd->bd', weights, expert_outputs[expert_id])
        return output, metadata
# ============================================================================
# 5. Expert Choice Routing
# ============================================================================
class ExpertChoiceGate(nn.Module):
    """
    Expert Choice routing: experts select tokens instead of tokens selecting
    experts, which guarantees perfectly even expert load by construction.
    """
    def __init__(self, input_dim: int, num_experts: int, capacity_per_expert: int):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.capacity = capacity_per_expert
        # Affinity vector per expert: score[e, b] = w_e . x_b
        self.expert_weights = nn.Parameter(torch.randn(num_experts, input_dim))
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input [batch_size, input_dim]
        Returns:
            dispatch_mask: [num_experts, capacity, batch_size] 0/1 routing mask
            combine_weights: [batch_size, num_experts, capacity] gate weights
                (softmax over the (expert, slot) pairs that selected the
                token; all-zero row for tokens no expert selected)
        """
        batch_size = x.size(0)
        # Affinity scores: S[i, j] = expert_i^T token_j
        scores = torch.matmul(self.expert_weights, x.t())  # [num_experts, batch_size]
        # Each expert selects its top-`capacity` tokens.
        selected_scores, selected_indices = torch.topk(scores, self.capacity, dim=1)
        # One-hot over tokens yields the dispatch mask directly.
        dispatch_mask = F.one_hot(selected_indices, num_classes=batch_size).to(x.dtype)
        # Combine weights, vectorized. The previous implementation used
        # O(num_experts * capacity * batch) Python loops and rebuilt scores
        # with torch.tensor(...), which DETACHED the gradient to
        # expert_weights; this version keeps the autograd path intact.
        picked = dispatch_mask.permute(2, 0, 1).bool()  # [batch, experts, capacity]
        per_token_scores = selected_scores.unsqueeze(0).expand(batch_size, -1, -1)
        masked = per_token_scores.masked_fill(~picked, float('-inf'))
        weights = F.softmax(masked.reshape(batch_size, -1), dim=1)
        weights = weights.view(batch_size, self.num_experts, self.capacity)
        # Tokens selected by no expert have an all -inf row -> NaN softmax;
        # zero those out so unselected tokens contribute nothing, as before.
        combine_weights = torch.nan_to_num(weights, nan=0.0) * picked.to(x.dtype)
        return dispatch_mask, combine_weights
class ExpertChoiceMoE(nn.Module):
    """MoE layer using Expert Choice routing (experts pick their tokens)."""
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_experts: int = 8,
        tokens_per_expert: int = 32
    ):
        super().__init__()
        self.num_experts = num_experts
        # One feed-forward expert per slot.
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, input_dim) for _ in range(num_experts)
        ])
        # Router where each expert selects its top tokens.
        self.gate = ExpertChoiceGate(input_dim, num_experts, tokens_per_expert)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input [batch_size, input_dim]
        Returns:
            output: [batch_size, input_dim]
        """
        dispatch_mask, combine_weights = self.gate(x)
        combined = torch.zeros_like(x)
        # Dispatch, process, and re-combine per expert in a single pass.
        for expert_id, expert in enumerate(self.experts):
            routed_in = torch.einsum('cb,bd->cd', dispatch_mask[expert_id], x)
            routed_out = expert(routed_in)
            slot_weights = combine_weights[:, expert_id, :]  # [batch_size, capacity]
            combined += torch.einsum('bc,cd->bd', slot_weights, routed_out)
        return combined
# ============================================================================
# 6. Hierarchical MoE
# ============================================================================
class HierarchicalMoE(nn.Module):
    """
    Two-level hierarchical mixture of experts.

    Routing happens in two stages: a primary gate assigns each token to one
    expert group, then a per-group secondary gate assigns it to one expert
    inside that group.

    NOTE(review): both stages use hard argmax and the gate probabilities are
    never multiplied into the output, so the primary/secondary gate weights
    receive no gradient from this forward pass — confirm this is intended.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_groups: int = 4,
        experts_per_group: int = 4
    ):
        super().__init__()
        self.num_groups = num_groups
        self.experts_per_group = experts_per_group
        self.total_experts = num_groups * experts_per_group
        # Stage 1 router: token -> group.
        self.primary_gate = nn.Linear(input_dim, num_groups, bias=False)
        # Stage 2 routers: token -> expert, one router per group.
        self.secondary_gates = nn.ModuleList([
            nn.Linear(input_dim, experts_per_group, bias=False)
            for _ in range(num_groups)
        ])
        # Experts arranged as num_groups x experts_per_group.
        self.expert_groups = nn.ModuleList([
            nn.ModuleList([
                Expert(input_dim, hidden_dim, input_dim)
                for _ in range(experts_per_group)
            ])
            for _ in range(num_groups)
        ])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input [batch_size, input_dim]
        Returns:
            output: [batch_size, input_dim]
        """
        chosen_group = torch.argmax(self.primary_gate(x), dim=-1)  # [batch_size]
        result = torch.zeros_like(x)
        for group_id, (router, experts) in enumerate(zip(self.secondary_gates, self.expert_groups)):
            members = chosen_group == group_id
            if not members.any():
                continue
            member_x = x[members]  # [group_size, input_dim]
            # Stage 2: pick one expert per token within this group.
            chosen_expert = torch.argmax(router(member_x), dim=-1)
            member_out = torch.zeros_like(member_x)
            for expert_id, expert in enumerate(experts):
                hits = chosen_expert == expert_id
                if hits.any():
                    member_out[hits] = expert(member_x[hits])
            result[members] = member_out
        return result
# ============================================================================
# 7. Visualization and Analysis
# ============================================================================
def visualize_expert_utilization(
    expert_counts: torch.Tensor,
    num_experts: int,
    title: str = "Expert Utilization"
):
    """Bar-plot how many tokens each expert handled, with the mean marked."""
    usage = expert_counts.cpu().numpy()
    plt.figure(figsize=(10, 4))
    plt.bar(range(num_experts), usage)
    plt.xlabel('Expert ID')
    plt.ylabel('Number of Tokens')
    plt.title(title)
    # Dashed line marks perfectly balanced utilization.
    plt.axhline(y=usage.mean(), color='r', linestyle='--', label='Mean')
    plt.legend()
    plt.tight_layout()
    plt.show()
def compute_routing_entropy(gates: torch.Tensor) -> float:
    """
    Shannon entropy (nats) of the batch-averaged routing distribution.

    High entropy means routing is spread evenly across experts; entropy near
    zero indicates the router has collapsed onto a few experts.

    Args:
        gates: Routing weights [batch_size, num_experts].

    Returns:
        Entropy of the mean gate distribution as a Python float.
    """
    avg_probs = gates.mean(dim=0)  # [num_experts]
    # Small epsilon avoids log(0) for experts receiving no traffic.
    return float(-(avg_probs * torch.log(avg_probs + 1e-10)).sum())
# ============================================================================
# 8. Demo
# ============================================================================
def demo_moe_comparison():
    """Compare different MoE routing strategies.

    Runs one random batch through four MoE variants (top-2, switch/top-1,
    expert choice, hierarchical), prints output shapes and auxiliary losses,
    plots expert utilization, and reports rough parameter-efficiency numbers.
    """
    print("=" * 80)
    print("Mixture of Experts - Routing Strategy Comparison")
    print("=" * 80)
    # Setup
    batch_size = 128
    input_dim = 64
    hidden_dim = 256
    num_experts = 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")
    # Random input (no real data — only routing behavior is of interest here)
    x = torch.randn(batch_size, input_dim).to(device)
    # Models
    models = {
        'Top-2 MoE': MixtureOfExperts(input_dim, hidden_dim, input_dim, num_experts, top_k=2),
        'Switch (Top-1)': SwitchFFN(input_dim, hidden_dim, num_experts),
        'Expert Choice': ExpertChoiceMoE(input_dim, hidden_dim, num_experts, tokens_per_expert=16),
        'Hierarchical': HierarchicalMoE(input_dim, hidden_dim, num_groups=4, experts_per_group=2)
    }
    print(f"\nInput: {x.shape}")
    print(f"Number of experts: {num_experts}")
    results = {}
    for name, model in models.items():
        model = model.to(device)
        model.eval()
        print(f"\n{'='*60}")
        print(f"Model: {name}")
        print(f"{'='*60}")
        # Forward pass — the models have different return signatures, hence
        # the per-name branching below.
        with torch.no_grad():
            if name == 'Top-2 MoE':
                output, aux_loss = model(x)
                print(f"Output shape: {output.shape}")
                print(f"Auxiliary loss: {aux_loss.item():.6f}")
                print(f"Expert utilization:")
                visualize_expert_utilization(model.expert_counts, num_experts, f"{name} - Expert Usage")
            elif name == 'Switch (Top-1)':
                output, metadata = model(x)
                print(f"Output shape: {output.shape}")
                print(f"Auxiliary loss: {metadata['aux_loss'].item():.6f}")
                print(f"Dropped tokens: {metadata['dropped_tokens']}")
                print(f"Expert counts: {metadata['expert_counts'].cpu().numpy()}")
                visualize_expert_utilization(metadata['expert_counts'], num_experts, f"{name} - Expert Usage")
            else:
                output = model(x)
                print(f"Output shape: {output.shape}")
        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        # NOTE(review): rough estimate only — assumes ~1/num_experts of the
        # weights are active per token; top-2 routing actually activates
        # about twice this, and the gate parameters are always active.
        active_params = total_params // num_experts if 'Hierarchical' not in name else total_params // (4 * 2)
        print(f"\nTotal parameters: {total_params:,}")
        print(f"Active parameters per token: ~{active_params:,}")
        print(f"Sparsity ratio: {total_params / active_params:.1f}x")
        results[name] = {
            'output': output,
            'params': total_params,
            'active_params': active_params
        }
    # Summary
    print("\n" + "=" * 80)
    print("Summary")
    print("=" * 80)
    for name, result in results.items():
        print(f"\n{name}:")
        print(f" Total parameters: {result['params']:,}")
        print(f" Active per token: {result['active_params']:,}")
        print(f" Efficiency: {result['params'] / result['active_params']:.1f}x")
def demo_load_balancing():
    """Demonstrate importance of load balancing.

    Trains two identical top-2 MoE models on random regression data — one
    with a load-balancing auxiliary loss (weight 0.1), one without — and
    compares the resulting expert-utilization spread.
    """
    print("\n" + "=" * 80)
    print("Load Balancing Impact Demo")
    print("=" * 80)
    batch_size = 256
    input_dim = 32
    hidden_dim = 128
    num_experts = 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Create two models: with and without load balancing
    model_with_lb = MixtureOfExperts(
        input_dim, hidden_dim, input_dim,
        num_experts, top_k=2, load_balance_weight=0.1
    ).to(device)
    model_without_lb = MixtureOfExperts(
        input_dim, hidden_dim, input_dim,
        num_experts, top_k=2, load_balance_weight=0.0
    ).to(device)
    # Train briefly to see effect
    optimizer_with = torch.optim.Adam(model_with_lb.parameters(), lr=0.001)
    optimizer_without = torch.optim.Adam(model_without_lb.parameters(), lr=0.001)
    num_steps = 100
    print("\nTraining for {} steps...".format(num_steps))
    for step in range(num_steps):
        # Fresh random batch each step; targets are random too, so the task
        # exercises routing dynamics rather than any learnable mapping.
        x = torch.randn(batch_size, input_dim).to(device)
        target = torch.randn(batch_size, input_dim).to(device)
        # With load balancing: auxiliary loss added to the task loss
        optimizer_with.zero_grad()
        output_with, aux_loss_with = model_with_lb(x)
        loss_with = F.mse_loss(output_with, target) + aux_loss_with
        loss_with.backward()
        optimizer_with.step()
        # Without load balancing: aux term omitted (its weight is 0.0 anyway)
        optimizer_without.zero_grad()
        output_without, aux_loss_without = model_without_lb(x)
        loss_without = F.mse_loss(output_without, target)
        loss_without.backward()
        optimizer_without.step()
    print("\nFinal Expert Utilization:")
    print("\nWith Load Balancing:")
    print(model_with_lb.expert_counts.cpu().numpy())
    visualize_expert_utilization(
        model_with_lb.expert_counts, num_experts,
        "With Load Balancing"
    )
    print("\nWithout Load Balancing:")
    print(model_without_lb.expert_counts.cpu().numpy())
    visualize_expert_utilization(
        model_without_lb.expert_counts, num_experts,
        "Without Load Balancing"
    )
    # Compute utilization statistics (lower std dev = more balanced routing)
    with_lb_std = model_with_lb.expert_counts.float().std().item()
    without_lb_std = model_without_lb.expert_counts.float().std().item()
    print(f"\nUtilization Std Dev:")
    print(f" With LB: {with_lb_std:.2f}")
    print(f" Without LB: {without_lb_std:.2f}")
    print(f" Improvement: {without_lb_std / with_lb_std:.2f}x more balanced")
if __name__ == "__main__":
    print("\nAdvanced Mixture of Experts - Comprehensive Implementation\n")
    # Run both demonstrations end to end.
    demo_moe_comparison()
    demo_load_balancing()
    banner = "=" * 80
    print("\n" + banner)
    print("Demo complete! Key insights:")
    print(banner)
    takeaways = (
        "1. MoE enables scaling to many parameters with sparse activation",
        "2. Switch routing (top-1) is simplest and most efficient",
        "3. Expert Choice provides perfect load balancing",
        "4. Hierarchical MoE reduces routing overhead",
        "5. Load balancing loss is critical for training stability",
        "6. Different routing strategies trade off complexity vs. performance",
    )
    for line in takeaways:
        print(line)
    print(banner)