import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Capsule Networks ConceptΒΆ
Capsule:ΒΆ
Group of neurons representing:
Length: probability entity exists
Orientation: instantiation parameters
Squashing Function:ΒΆ
Keeps direction, squashes magnitude to [0,1].
π Reference Materials:
cnn_beyond.pdf - Cnn Beyond
def squash(s, dim=-1):
"""Squashing non-linearity."""
squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
scale = squared_norm / (1 + squared_norm)
return scale * s / torch.sqrt(squared_norm + 1e-8)
# Test squashing
s = torch.randn(10, 16)
v = squash(s)
norms = torch.norm(v, dim=1)
print(f"Squashed norms - min: {norms.min():.3f}, max: {norms.max():.3f}")
2. Dynamic RoutingΒΆ
Routing by Agreement:ΒΆ
where \(\hat{u}_{j|i} = W_{ij} u_i\) is prediction vector.
Update Rule:ΒΆ
class DynamicRouting(nn.Module):
"""Dynamic routing between capsule layers."""
def __init__(self, in_caps, out_caps, in_dim, out_dim, num_iterations=3):
super().__init__()
self.in_caps = in_caps
self.out_caps = out_caps
self.num_iterations = num_iterations
# Transformation matrix for each capsule pair
self.W = nn.Parameter(torch.randn(in_caps, out_caps, out_dim, in_dim))
def forward(self, u):
"""
Args:
u: (batch, in_caps, in_dim)
Returns:
v: (batch, out_caps, out_dim)
"""
batch_size = u.size(0)
# Compute predictions
u_expand = u.unsqueeze(2).unsqueeze(4) # (batch, in_caps, 1, in_dim, 1)
W_expand = self.W.unsqueeze(0) # (1, in_caps, out_caps, out_dim, in_dim)
# u_hat = W @ u for each capsule pair
u_hat = torch.matmul(W_expand, u_expand) # (batch, in_caps, out_caps, out_dim, 1)
u_hat = u_hat.squeeze(-1) # (batch, in_caps, out_caps, out_dim)
# Initialize routing logits
b = torch.zeros(batch_size, self.in_caps, self.out_caps, 1).to(u.device)
# Routing iterations
for iteration in range(self.num_iterations):
# Coupling coefficients
c = F.softmax(b, dim=2) # (batch, in_caps, out_caps, 1)
# Weighted sum
s = (c * u_hat).sum(dim=1) # (batch, out_caps, out_dim)
# Squashing
v = squash(s, dim=-1) # (batch, out_caps, out_dim)
# Update routing logits (except last iteration)
if iteration < self.num_iterations - 1:
# Agreement: dot product
agreement = torch.matmul(u_hat, v.unsqueeze(-1)) # (batch, in_caps, out_caps, 1)
b = b + agreement
return v
print("DynamicRouting defined")
Primary CapsulesΒΆ
Primary capsules form the first capsule layer, converting traditional convolutional features into capsule format. Instead of scalar activations, each capsule outputs a vector whose length represents the probability that a particular entity exists and whose orientation encodes the entityβs properties (pose, deformation, texture). The squashing function \(v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2} \frac{s_j}{\|s_j\|}\) ensures the vector length stays between 0 and 1 (interpretable as a probability) while preserving direction. Primary capsules detect basic visual patterns (edges, textures, simple shapes) and pass their pose information up to higher-level capsules via dynamic routing.
class PrimaryCaps(nn.Module):
"""Primary capsule layer."""
def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride):
super().__init__()
self.dim_caps = dim_caps
self.num_caps = out_channels // dim_caps
# Convolutional layer
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x):
"""
Args:
x: (batch, in_channels, H, W)
Returns:
capsules: (batch, num_caps, dim_caps)
"""
# Convolution
out = self.conv(x) # (batch, out_channels, H', W')
batch_size = out.size(0)
# Reshape into capsules
out = out.view(batch_size, self.num_caps, self.dim_caps, -1)
out = out.view(batch_size, self.num_caps * out.size(-1), self.dim_caps)
# Squashing
return squash(out, dim=-1)
print("PrimaryCaps defined")
CapsNet ArchitectureΒΆ
The full CapsNet architecture consists of: (1) a standard convolutional layer for initial feature extraction, (2) primary capsules that convert features to capsule vectors, and (3) digit capsules (or class capsules) connected via dynamic routing. Dynamic routing iteratively refines the coupling coefficients between lower-level and higher-level capsules: capsules that agree on the predicted pose of a higher-level entity strengthen their connection, while disagreeing capsules weaken theirs. The length of each digit capsule vector gives the predicted probability for that class, and a reconstruction network can decode the capsule vector back to an image, serving as a regularizer.
class CapsNet(nn.Module):
"""Capsule Network for MNIST."""
def __init__(self, num_classes=10):
super().__init__()
# Initial conv layer
self.conv1 = nn.Conv2d(1, 256, 9, stride=1)
# Primary capsules
self.primary_caps = PrimaryCaps(
in_channels=256,
out_channels=256,
dim_caps=8,
kernel_size=9,
stride=2
)
# Digit capsules
self.digit_caps = DynamicRouting(
in_caps=32 * 6 * 6, # Primary caps
out_caps=num_classes,
in_dim=8,
out_dim=16,
num_iterations=3
)
# Decoder for regularization
self.decoder = nn.Sequential(
nn.Linear(16 * num_classes, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 784),
nn.Sigmoid()
)
def forward(self, x, y=None):
# Conv1
x = F.relu(self.conv1(x))
# Primary capsules
x = self.primary_caps(x)
# Digit capsules
digit_caps = self.digit_caps(x)
# Length as class scores
lengths = torch.sqrt((digit_caps ** 2).sum(dim=-1))
# Reconstruction
if y is None:
# Use predicted class
index = lengths.argmax(dim=1)
else:
# Use true class
index = y
# Mask
mask = torch.zeros_like(digit_caps)
mask[torch.arange(digit_caps.size(0)), index] = 1
# Decode
masked = (digit_caps * mask).view(digit_caps.size(0), -1)
reconstruction = self.decoder(masked)
return lengths, reconstruction
print("CapsNet defined")
5. Margin LossΒΆ
where \(T_k=1\) if class \(k\) present, \(m^+=0.9\), \(m^-=0.1\).
def margin_loss(lengths, labels, m_plus=0.9, m_minus=0.1, lambda_=0.5):
"""Margin loss for capsule networks."""
# One-hot labels
T = F.one_hot(labels, num_classes=10).float()
# Loss for present classes
loss_present = T * F.relu(m_plus - lengths) ** 2
# Loss for absent classes
loss_absent = (1 - T) * F.relu(lengths - m_minus) ** 2
# Total
loss = loss_present + lambda_ * loss_absent
return loss.sum(dim=1).mean()
print("Margin loss defined")
Train CapsNetΒΆ
CapsNet training uses a margin loss per class: \(L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^-)^2\), where \(T_k = 1\) if class \(k\) is present and \(m^+, m^-\) are the positive and negative margins. An optional reconstruction loss (MSE between the original image and the reconstruction from the correct class capsule) encourages the capsule vectors to encode meaningful pose information. The total loss balances classification accuracy with reconstruction quality, with a small weight on the reconstruction term to prevent it from dominating.
# Data
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000)
# Model
model = CapsNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
def train_epoch(model, loader):
model.train()
total_loss = 0
for x, y in loader:
x, y = x.to(device), y.to(device)
# Forward
lengths, reconstruction = model(x, y)
# Margin loss
loss_margin = margin_loss(lengths, y)
# Reconstruction loss
loss_recon = F.mse_loss(reconstruction, x.view(x.size(0), -1))
# Total loss
loss = loss_margin + 0.0005 * loss_recon
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(loader)
# Train
losses = []
for epoch in range(5):
loss = train_epoch(model, train_loader)
losses.append(loss)
print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
Evaluate and VisualizeΒΆ
Beyond classification accuracy, CapsNet evaluation includes dimension perturbation analysis: varying each dimension of the class capsule vector independently and observing the effect on the reconstructed image. Each dimension should control a distinct visual attribute (stroke width, skew, translation, etc.), demonstrating that capsules learn disentangled representations of pose and appearance. This interpretability is a key advantage over standard CNNs, where individual neuron activations rarely correspond to single, interpretable factors of variation.
model.eval()
correct = 0
total = 0
with torch.no_grad():
for x, y in test_loader:
x, y = x.to(device), y.to(device)
lengths, _ = model(x)
pred = lengths.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
print(f"\nTest Accuracy: {100 * correct / total:.2f}%")
# Visualize reconstructions
x_test, y_test = next(iter(test_loader))
x_test = x_test[:8].to(device)
y_test = y_test[:8].to(device)
with torch.no_grad():
_, recon = model(x_test, y_test)
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i in range(8):
# Original
axes[0, i].imshow(x_test[i].cpu().squeeze(), cmap='gray')
axes[0, i].axis('off')
# Reconstruction
axes[1, i].imshow(recon[i].cpu().view(28, 28), cmap='gray')
axes[1, i].axis('off')
axes[0, 0].set_ylabel('Original', fontsize=11)
axes[1, 0].set_ylabel('Reconstructed', fontsize=11)
plt.suptitle('CapsNet Reconstructions', fontsize=12)
plt.tight_layout()
plt.show()
SummaryΒΆ
Capsule Networks:ΒΆ
Key Ideas:
Capsules = groups of neurons
Length = existence probability
Orientation = instantiation parameters
Dynamic routing by agreement
Advantages:
Viewpoint equivariance
Part-whole relationships
Better generalization to novel viewpoints
Interpretable representations
Challenges:
Computationally expensive
Harder to scale
Limited to small images
Applications:ΒΆ
Overlapping digits
3D object recognition
Medical imaging
Video understanding
Advanced Capsule Networks TheoryΒΆ
1. Mathematical FoundationsΒΆ
Capsule RepresentationΒΆ
Definition: A capsule is a group of neurons whose activity vector represents instantiation parameters of an entity.
Capsule Output: v_j β βα΅ where:
Magnitude ||v_j||: Probability entity exists (0 to 1)
Direction: Instantiation parameters (pose, deformation, texture)
Key Difference from CNNs:
CNN scalar neuron: Binary feature detection
Capsule vector: Rich entity representation with properties
Squashing Non-linearityΒΆ
Function: Maps any vector to unit length while preserving direction:
where s_j is total input to capsule j.
Properties:
||v_j|| β [0, 1): Short vectors β near 0, long vectors β near 1
Direction preserved: v_j β₯ s_j
Differentiable: Smooth for gradient descent
Gradient:
2. Dynamic Routing by AgreementΒΆ
Routing AlgorithmΒΆ
Intuition: Lower-level capsules βvoteβ for higher-level capsules based on agreement.
Prediction Vectors: Each lower capsule i predicts higher capsule jβs output:
where W_ij β β^{d_out Γ d_in} is transformation matrix.
Routing Coefficients: Probability capsule i sends output to capsule j:
where b_ij are routing logits (initially 0).
Total Input: Weighted sum of predictions:
Update Rule (iterative):
Agreement: Dot product measures how well prediction aligns with actual output.
Routing Algorithm (Detailed)ΒΆ
Procedure: DynamicRouting(u, r, W)
Input: u_i (lower capsule outputs), r (iterations), W_ij (weights)
Output: v_j (higher capsule outputs)
1. Initialize: b_ij β 0 for all i, j
2. For iteration in 1 to r:
a. c_ij β softmax(b_ij) over j
b. s_j β Ξ£_i c_ij (W_ij u_i)
c. v_j β squash(s_j)
d. b_ij β b_ij + (W_ij u_i) Β· v_j (except last iteration)
3. Return v_j
Computational Cost: O(r Β· I Β· J Β· dΒ²) where r = iterations, I = input caps, J = output caps, d = dimension.
Theoretical JustificationΒΆ
EM Perspective: Routing is similar to EM algorithm:
E-step: Compute c_ij (soft assignment)
M-step: Update v_j (cluster means)
Convergence: Routing maximizes agreement between predictions and outputs.
3. Margin LossΒΆ
Classification LossΒΆ
Objective: Separate present and absent classes with margin.
Loss for Class k:
where:
T_k β {0, 1}: Indicator if class k present
mβΊ = 0.9: Target magnitude for present classes
mβ» = 0.1: Max magnitude for absent classes
Ξ» = 0.5: Down-weight absent class loss
Total Loss:
Intuition:
Present class: Push ||v_k|| above 0.9
Absent class: Push ||v_k|| below 0.1
Down-weight absent to prevent initial learning stagnation
Reconstruction RegularizationΒΆ
Decoder Network: Reconstruct input from capsule outputs.
Objective: Force capsules to encode useful information.
Loss:
where X is input image, XΜ is reconstruction.
Combined:
Typical: Ξ± = 0.0005 (small to avoid drowning out margin loss).
4. EM Routing (Matrix Capsules)ΒΆ
MotivationΒΆ
Limitation of Dynamic Routing: Scalar weights c_ij canβt represent uncertainty.
Solution: Model each capsule as Gaussian with mean ΞΌ_j and variance Ο_jΒ².
Capsule as GaussianΒΆ
Activation a_j: Probability capsule j is active.
Pose Matrix M_j β β^{4Γ4}: Viewpoint-invariant representation.
Distribution: p(M_j) = N(ΞΌ_j, Ο_jΒ²I)
EM Routing AlgorithmΒΆ
E-Step: Compute assignment probabilities R_ij:
where V_i = W_ij M_i is vote from capsule i.
M-Step: Update higher-level capsule parameters:
Cost Function: Negative log-likelihood with activation cost:
where Ξ² controls sparsity.
5. Equivariance PropertiesΒΆ
Viewpoint EquivarianceΒΆ
Definition: Transformation of input causes corresponding transformation of capsule output.
Mathematical Formulation: For transformation T:
where Tβ is corresponding transformation in capsule space.
Example: Rotate image 15Β° β capsule orientation rotates 15Β°.
Advantage over CNNs: CNNs only have translation equivariance (via convolution).
Part-Whole RelationshipsΒΆ
Compositional Structure: Higher capsules represent wholes, lower capsules represent parts.
Coordinate Frame Agreement: Routing ensures parts agree on wholeβs pose.
Example: Two eye capsules route to face capsule if they agree face is present.
6. Capsule Architecture VariantsΒΆ
Original CapsNet (Sabour et al., 2017)ΒΆ
Structure:
Input (28Γ28)
β Conv(256, 9Γ9, stride=1) + ReLU
β PrimaryCaps(32 caps Γ 8D, 9Γ9, stride=2)
β DigitCaps(10 caps Γ 16D, dynamic routing)
β Length as class probability
Parameters: ~8.2M (similar to baseline CNN)
Performance:
MNIST: 99.75% (SOTA at time)
MultiMNIST (overlapping): 95.7% (vs 93.3% baseline)
Matrix Capsules (Hinton et al., 2018)ΒΆ
Key Changes:
4Γ4 pose matrices instead of vectors
EM routing instead of dynamic routing
Coordinate addition for translation
Structure:
Input β Conv β Primary Capsules (pose + activation)
β Class Capsules (EM routing, K iterations)
β Spread loss
Performance: SmallNORB 97.8% (viewpoint generalization)
Stacked Capsule Autoencoders (Kosiorek et al., 2019)ΒΆ
Unsupervised Learning: Learn capsule representations without labels.
Components:
Part Capsule Encoder: Detect object parts
Object Capsule Encoder: Compose parts into objects
Decoder: Reconstruct from object capsules
Set-Based Routing: Treat capsules as sets, use attention.
7. Training TechniquesΒΆ
Learning Rate SchedulingΒΆ
Strategy: Exponential decay with warmup.
Schedule:
lr(t) = lr_base * min(1, t/T_warmup) * decay^(t/T_decay)
Typical: lr_base=1e-3, T_warmup=5000 steps, decay=0.96, T_decay=2000 steps.
RegularizationΒΆ
Techniques:
Reconstruction loss (Ξ±=0.0005)
Weight decay (1e-4)
Dropout on decoder (p=0.5)
No BatchNorm: Interferes with capsule magnitudes.
Data AugmentationΒΆ
Affine Transformations:
Random rotation: Β±15Β°
Translation: Β±2 pixels
Scaling: 0.9-1.1Γ
Cutout: Randomly mask image patches (improves robustness).
8. Computational EfficiencyΒΆ
Memory ComplexityΒΆ
Dynamic Routing: O(B Β· I Β· J Β· d) per iteration
B: Batch size
I: Input capsules
J: Output capsules
d: Capsule dimension
Bottleneck: Storing u_hat for all pairs (IΓJΓd).
Speed OptimizationsΒΆ
1. Fewer Routing Iterations:
r=3 is standard, r=1 often sufficient
Diminishing returns beyond r=3
2. Sparse Routing:
Route each input to top-k output capsules only
Reduces O(IΒ·J) to O(IΒ·k)
3. Grouped Capsules:
Split capsules into independent groups
Route within groups only
4. Parallel Implementation:
Routing iterations sequential, but batch parallel
GPU-friendly matrix operations
9. Limitations and ChallengesΒΆ
Scalability IssuesΒΆ
Problem: Quadratic complexity in number of capsules.
Impact: Hard to scale to ImageNet (224Γ224).
Solutions:
Multi-scale architecture
Local routing (spatial locality)
Attention-based routing
Adversarial VulnerabilityΒΆ
Finding: Capsule networks as vulnerable as CNNs to adversarial examples.
Reason: Margin loss doesnβt inherently provide robustness.
Mitigation: Adversarial training + capsules.
Training InstabilityΒΆ
Issues:
Gradient explosion in early training
Routing coefficients may collapse to one capsule
Solutions:
Careful initialization (Xavier/He)
Gradient clipping (max norm 1.0)
Warmup learning rate schedule
10. Recent Advances (2020-2024)ΒΆ
Self-Routing CapsulesΒΆ
Idea: Learn routing mechanism instead of iterative algorithm.
Method: Attention-based routing in single forward pass.
Advantage: Faster, differentiable routing weights.
Efficient Capsules (Mazzia et al., 2021)ΒΆ
Improvements:
Depthwise separable convolutions in capsule layers
Squash activation only at final layer
Faster than original CapsNet by 3Γ
Capsule Networks for Vision TransformersΒΆ
Hybrid Architecture: Combine ViT attention with capsule routing.
Benefit: Part-whole relationships + long-range dependencies.
11. ApplicationsΒΆ
Medical ImagingΒΆ
Use Case: Lesion detection, organ segmentation.
Advantage: Handles overlapping structures better than CNNs.
Example: Brain tumor segmentation (BraTS dataset).
3D Object RecognitionΒΆ
Task: Classify objects under different viewpoints.
Dataset: ShapeNet, ModelNet40.
Performance: Better viewpoint generalization than CNNs.
Action RecognitionΒΆ
Method: Temporal capsules for video understanding.
Structure: 3D convolution β temporal capsules β action class.
Benefit: Model hierarchical action parts.
Adversarial DefenseΒΆ
Approach: Capsule networks as certified defense.
Mechanism: Reconstruction loss detects adversarial perturbations.
Limitation: Still vulnerable to sophisticated attacks.
12. Comparison with Other ArchitecturesΒΆ
Capsules vs CNNsΒΆ
Aspect |
CNN |
Capsule |
|---|---|---|
Feature |
Scalar |
Vector |
Pooling |
Max pool (loses info) |
Routing (preserves) |
Equivariance |
Translation only |
Viewpoint |
Part-whole |
No |
Yes (via routing) |
Overlapping |
Struggles |
Handles well |
Speed |
Fast |
Slower (routing) |
Capsules vs TransformersΒΆ
Similarities:
Both use attention/routing mechanisms
Both model relationships between entities
Differences:
Capsules: Hierarchical part-whole (compositional)
Transformers: Flat self-attention (relational)
Capsules: Viewpoint equivariance
Transformers: Permutation equivariance
Capsules vs Graph Neural NetworksΒΆ
Connection: Both propagate information between entities.
Difference:
GNN: General graph structure
Capsules: Hierarchical (layer-wise routing)
13. Theoretical InsightsΒΆ
Routing as ClusteringΒΆ
Interpretation: Dynamic routing clusters lower capsules to higher capsules.
Objective: Minimize within-cluster variance (maximize agreement).
Connection to k-means: Soft assignment via softmax (vs hard assignment).
Information BottleneckΒΆ
View: Capsule network as layered information bottleneck.
Compression: Routing compresses lower-level information.
Preservation: Maintains task-relevant information in higher capsules.
Generalization BoundsΒΆ
Analysis: VC dimension and Rademacher complexity for capsule networks.
Finding: Similar generalization to CNNs with comparable parameters.
Advantage: Better generalization on viewpoint variations (empirical).
14. Implementation Best PracticesΒΆ
InitializationΒΆ
Weight Matrices W_ij:
Xavier/Glorot: $\(W \sim \mathcal{N}(0, \frac{2}{d_{in} + d_{out}})\)$
He initialization: $\(W \sim \mathcal{N}(0, \frac{2}{d_{in}})\)$
Routing Logits b_ij: Initialize to 0 (uniform routing).
Hyperparameter TuningΒΆ
Parameter |
Typical Value |
Sensitivity |
|---|---|---|
Routing iterations |
3 |
Low (2-5 works) |
Capsule dimension |
8-16 |
Medium |
Reconstruction weight |
0.0005 |
High |
Margin mβΊ |
0.9 |
Medium |
Margin mβ» |
0.1 |
Low |
Down-weight Ξ» |
0.5 |
Low |
Debugging TipsΒΆ
Check:
Capsule norms: Should be in [0, 1]
Routing coefficients: Should sum to 1 across j
Gradient norms: Clip if >1.0
Loss breakdown: Margin vs reconstruction ratio
Common Issues:
Routing collapse: All c_ij route to one capsule β use entropy regularization
Vanishing capsules: ||v_j|| always near 0 β increase learning rate, check initialization
Reconstruction dominates: Scale down Ξ±
15. Key PapersΒΆ
FoundationalΒΆ
Sabour, Frosst, Hinton (2017): βDynamic Routing Between Capsulesβ (Original CapsNet)
Hinton, Sabour, Frosst (2018): βMatrix Capsules with EM Routingβ
ImprovementsΒΆ
Mazzia et al. (2021): βEfficient-CapsNet: Capsule Network with Self-Attention Routingβ
Kosiorek et al. (2019): βStacked Capsule Autoencodersβ (Unsupervised)
Wang & Liu (2020): βGroup Equivariant Capsule Networksβ
ApplicationsΒΆ
LaLonde & Bagci (2018): βCapsules for Object Segmentationβ (Medical imaging)
Zhao et al. (2019): β3D Point Capsule Networksβ (Point cloud classification)
TheoryΒΆ
Paik et al. (2019): βCapsule Networks Need an Improved Routing Algorithmβ
Zhang et al. (2020): βRethinking the Inception Architecture for Capsule Networksβ
16. Future DirectionsΒΆ
Open ProblemsΒΆ
Scalability: How to scale capsules to ImageNet-scale images?
Architecture Search: Optimal capsule layer configurations?
Theoretical Understanding: Why do capsules work? Formal guarantees?
Routing Alternatives: Better than dynamic routing?
Promising ApproachesΒΆ
Self-Attention Routing: Replace iterative routing with learned attention.
Sparse Capsules: Activate only relevant capsules (conditional computation).
Continuous Capsules: Infinite-dimensional capsules (functional representation).
Hybrid Models: Combine capsules with transformers/GNNs.
17. When to Use Capsule NetworksΒΆ
Choose Capsules When:
Viewpoint equivariance critical (3D object recognition)
Overlapping/occluded objects common (medical imaging)
Part-whole relationships important (compositional reasoning)
Interpretable representations desired
Avoid Capsules When:
Speed critical (real-time inference)
Large-scale images (computational cost)
Standard classification sufficient (CNNs faster)
Limited compute resources (capsules memory-intensive)
# Advanced Capsule Networks Implementations
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional
class SquashActivation(nn.Module):
"""
Squashing non-linearity: v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)
Maps vectors to [0,1) while preserving direction.
"""
def __init__(self, dim=-1, eps=1e-8):
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, s):
"""Apply squashing to input vectors."""
squared_norm = (s ** 2).sum(dim=self.dim, keepdim=True)
scale = squared_norm / (1 + squared_norm)
unit_vector = s / torch.sqrt(squared_norm + self.eps)
return scale * unit_vector
class DynamicRoutingLayer(nn.Module):
"""
Dynamic routing between capsule layers (Sabour et al., 2017).
Iteratively routes lower capsules to higher capsules based on agreement.
"""
def __init__(self, in_capsules, out_capsules, in_dim, out_dim,
num_iterations=3, share_weights=False):
super().__init__()
self.in_capsules = in_capsules
self.out_capsules = out_capsules
self.in_dim = in_dim
self.out_dim = out_dim
self.num_iterations = num_iterations
# Transformation matrices
if share_weights:
# Share weights across all input capsules
self.W = nn.Parameter(torch.randn(1, out_capsules, out_dim, in_dim))
else:
# Separate weights for each (input, output) capsule pair
self.W = nn.Parameter(torch.randn(in_capsules, out_capsules, out_dim, in_dim))
self.squash = SquashActivation(dim=-1)
def forward(self, u):
"""
Args:
u: (batch, in_capsules, in_dim) - lower capsule outputs
Returns:
v: (batch, out_capsules, out_dim) - higher capsule outputs
"""
batch_size = u.size(0)
# Compute prediction vectors: u_hat = W @ u
# u shape: (batch, in_caps, in_dim)
# W shape: (in_caps, out_caps, out_dim, in_dim)
u_expand = u.unsqueeze(2).unsqueeze(4) # (batch, in_caps, 1, in_dim, 1)
W_expand = self.W.unsqueeze(0) # (1, in_caps, out_caps, out_dim, in_dim)
# Matrix multiplication: (out_dim, in_dim) @ (in_dim, 1) = (out_dim, 1)
u_hat = torch.matmul(W_expand, u_expand) # (batch, in_caps, out_caps, out_dim, 1)
u_hat = u_hat.squeeze(-1) # (batch, in_caps, out_caps, out_dim)
# Stop gradient on predictions (don't backprop through routing)
u_hat_detached = u_hat.detach()
# Initialize routing logits to zero
b = torch.zeros(batch_size, self.in_capsules, self.out_capsules, 1,
device=u.device, dtype=u.dtype)
# Routing iterations
for iteration in range(self.num_iterations):
# Softmax to get coupling coefficients: c_ij = softmax(b_ij)
c = F.softmax(b, dim=2) # (batch, in_caps, out_caps, 1)
# Weighted sum of predictions: s_j = sum_i c_ij * u_hat_j|i
if iteration == self.num_iterations - 1:
# Last iteration: use actual predictions (with gradient)
s = (c * u_hat).sum(dim=1) # (batch, out_caps, out_dim)
else:
# Other iterations: use detached predictions (no gradient)
s = (c * u_hat_detached).sum(dim=1)
# Squashing non-linearity: v_j = squash(s_j)
v = self.squash(s) # (batch, out_caps, out_dim)
# Update routing logits based on agreement (except last iteration)
if iteration < self.num_iterations - 1:
# Agreement: a_ij = u_hat_j|i Β· v_j
v_expand = v.unsqueeze(1) # (batch, 1, out_caps, out_dim)
agreement = (u_hat_detached * v_expand).sum(dim=-1, keepdim=True)
# b_ij += a_ij
b = b + agreement
return v
class EMRoutingLayer(nn.Module):
"""
EM routing between capsule layers (Hinton et al., 2018).
Models capsules as Gaussians, uses EM for routing.
"""
def __init__(self, in_capsules, out_capsules, pose_dim=4, num_iterations=3):
super().__init__()
self.in_capsules = in_capsules
self.out_capsules = out_capsules
self.pose_dim = pose_dim
self.num_iterations = num_iterations
# Transformation matrices for pose
self.W = nn.Parameter(torch.randn(in_capsules, out_capsules,
pose_dim, pose_dim))
# Learned parameters
self.beta_v = nn.Parameter(torch.randn(out_capsules)) # Activation cost
self.beta_a = nn.Parameter(torch.randn(out_capsules)) # Activation threshold
def forward(self, a_in, M_in):
"""
Args:
a_in: (batch, in_capsules) - input activations
M_in: (batch, in_capsules, pose_dim) - input pose matrices
Returns:
a_out: (batch, out_capsules) - output activations
M_out: (batch, out_capsules, pose_dim) - output poses
"""
batch_size = a_in.size(0)
# Compute votes: V_ij = W_ij @ M_i
M_expand = M_in.unsqueeze(2).unsqueeze(4) # (batch, in_caps, 1, pose_dim, 1)
W_expand = self.W.unsqueeze(0) # (1, in_caps, out_caps, pose_dim, pose_dim)
V = torch.matmul(W_expand, M_expand).squeeze(-1) # (batch, in_caps, out_caps, pose_dim)
# Initialize parameters
R = torch.ones(batch_size, self.in_capsules, self.out_capsules,
device=a_in.device) / self.out_capsules # Uniform assignment
# EM iterations
for iteration in range(self.num_iterations):
# M-step: Compute capsule parameters
# Weighted votes
R_expand = R.unsqueeze(-1) # (batch, in_caps, out_caps, 1)
a_expand = a_in.unsqueeze(2).unsqueeze(3) # (batch, in_caps, 1, 1)
# Activation
sum_R_a = (R_expand * a_expand).sum(dim=1) # (batch, out_caps, 1)
a_out = torch.sigmoid(sum_R_a.squeeze(-1) - self.beta_a) # (batch, out_caps)
# Mean pose
weighted_votes = R_expand * a_expand * V # (batch, in_caps, out_caps, pose_dim)
sum_weighted = weighted_votes.sum(dim=1) # (batch, out_caps, pose_dim)
M_out = sum_weighted / (sum_R_a + 1e-8) # (batch, out_caps, pose_dim)
# Variance (assume isotropic)
M_expand = M_out.unsqueeze(1) # (batch, 1, out_caps, pose_dim)
diff = V - M_expand # (batch, in_caps, out_caps, pose_dim)
variance = ((R_expand * a_expand * diff ** 2).sum(dim=1) /
(sum_R_a + 1e-8)) # (batch, out_caps, pose_dim)
sigma_sq = variance.mean(dim=-1, keepdim=True) # (batch, out_caps, 1)
# E-step: Update assignment probabilities
if iteration < self.num_iterations - 1:
# Log probability under Gaussian
log_p = -0.5 * ((diff ** 2) / (sigma_sq.unsqueeze(1) + 1e-8)).sum(dim=-1)
log_p = log_p - 0.5 * self.pose_dim * torch.log(2 * np.pi * sigma_sq.unsqueeze(1))
# Weighted by input activation
log_p_weighted = log_p + torch.log(a_expand.squeeze(-1) + 1e-8)
# Softmax to get new assignments
R = F.softmax(log_p_weighted, dim=2) # (batch, in_caps, out_caps)
return a_out, M_out
class PrimaryCapsuleLayer(nn.Module):
"""
Primary capsule layer: converts CNN features to capsules.
"""
def __init__(self, in_channels, num_capsules, capsule_dim,
kernel_size=9, stride=2, padding=0):
super().__init__()
self.num_capsules = num_capsules
self.capsule_dim = capsule_dim
# Convolution to produce capsule features
out_channels = num_capsules * capsule_dim
self.conv = nn.Conv2d(in_channels, out_channels,
kernel_size, stride, padding)
self.squash = SquashActivation(dim=-1)
def forward(self, x):
"""
Args:
x: (batch, in_channels, H, W)
Returns:
capsules: (batch, num_capsules * H' * W', capsule_dim)
"""
# Convolution
conv_out = self.conv(x) # (batch, num_caps * cap_dim, H', W')
batch_size = conv_out.size(0)
H_out, W_out = conv_out.size(2), conv_out.size(3)
# Reshape to capsules
# (batch, num_caps, cap_dim, H', W')
conv_out = conv_out.view(batch_size, self.num_capsules,
self.capsule_dim, H_out, W_out)
# Merge spatial dimensions: (batch, num_caps * H' * W', cap_dim)
conv_out = conv_out.permute(0, 1, 3, 4, 2).contiguous()
capsules = conv_out.view(batch_size, -1, self.capsule_dim)
# Apply squashing
return self.squash(capsules)
class CapsuleNetworkClassifier(nn.Module):
"""
Complete CapsNet for classification with reconstruction.
Based on Sabour et al. (2017).
"""
def __init__(self, input_channels=1, num_classes=10,
routing_iterations=3, reconstruction=True):
super().__init__()
self.num_classes = num_classes
self.reconstruction = reconstruction
# Initial convolution
self.conv1 = nn.Conv2d(input_channels, 256, kernel_size=9, stride=1)
# Primary capsules: 32 capsules of 8D
self.primary_caps = PrimaryCapsuleLayer(
in_channels=256,
num_capsules=32,
capsule_dim=8,
kernel_size=9,
stride=2
)
# Calculate number of primary capsules (depends on input size)
# For 28x28 input: conv1(28->20) -> primary(20->6) -> 32*6*6 = 1152 capsules
self.num_primary_caps = 32 * 6 * 6 # For MNIST 28x28
# Digit capsules: 10 capsules of 16D (one per class)
self.digit_caps = DynamicRoutingLayer(
in_capsules=self.num_primary_caps,
out_capsules=num_classes,
in_dim=8,
out_dim=16,
num_iterations=routing_iterations
)
# Decoder for reconstruction regularization
if self.reconstruction:
self.decoder = nn.Sequential(
nn.Linear(16 * num_classes, 512),
nn.ReLU(inplace=True),
nn.Linear(512, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 28 * 28),
nn.Sigmoid()
)
def forward(self, x, labels=None):
"""
Args:
x: (batch, channels, H, W) - input images
labels: (batch,) - true labels (for reconstruction)
Returns:
class_probs: (batch, num_classes) - classification probabilities
reconstruction: (batch, H*W) - reconstructed images (if enabled)
"""
# Initial convolution with ReLU
conv_out = F.relu(self.conv1(x)) # (batch, 256, 20, 20)
# Primary capsules
primary_caps = self.primary_caps(conv_out) # (batch, 1152, 8)
# Digit capsules via dynamic routing
digit_caps = self.digit_caps(primary_caps) # (batch, 10, 16)
# Class probabilities: length of capsule vectors
class_probs = torch.sqrt((digit_caps ** 2).sum(dim=-1)) # (batch, 10)
# Reconstruction
if self.reconstruction:
# Mask: select capsule for correct class (training) or predicted class (testing)
if labels is not None:
mask_indices = labels
else:
mask_indices = class_probs.argmax(dim=1)
# Create mask
mask = torch.zeros_like(digit_caps)
mask[torch.arange(digit_caps.size(0)), mask_indices] = 1
# Masked capsules
masked_caps = (digit_caps * mask).view(digit_caps.size(0), -1)
# Decode
reconstruction = self.decoder(masked_caps)
else:
reconstruction = None
return class_probs, reconstruction
class MarginLoss(nn.Module):
"""
Margin loss for capsule networks.
L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1-T_k) * max(0, ||v_k|| - m-)^2
"""
def __init__(self, m_plus=0.9, m_minus=0.1, lambda_=0.5):
super().__init__()
self.m_plus = m_plus
self.m_minus = m_minus
self.lambda_ = lambda_
def forward(self, lengths, labels):
"""
Args:
lengths: (batch, num_classes) - capsule lengths
labels: (batch,) - true class labels
Returns:
loss: scalar
"""
# One-hot encode labels
T = F.one_hot(labels, num_classes=lengths.size(1)).float()
# Loss for present classes
loss_present = T * F.relu(self.m_plus - lengths) ** 2
# Loss for absent classes
loss_absent = (1 - T) * F.relu(lengths - self.m_minus) ** 2
# Total loss
loss = loss_present + self.lambda_ * loss_absent
return loss.sum(dim=1).mean()
class CapsuleTrainer:
"""
Trainer for capsule networks with margin loss and reconstruction.
"""
def __init__(self, model, optimizer, reconstruction_weight=0.0005):
self.model = model
self.optimizer = optimizer
self.margin_loss = MarginLoss()
self.reconstruction_weight = reconstruction_weight
def train_step(self, images, labels):
"""Single training step."""
self.model.train()
# Forward pass
class_probs, reconstruction = self.model(images, labels)
# Margin loss
loss_margin = self.margin_loss(class_probs, labels)
# Reconstruction loss
if reconstruction is not None:
target = images.view(images.size(0), -1)
loss_recon = F.mse_loss(reconstruction, target)
loss = loss_margin + self.reconstruction_weight * loss_recon
else:
loss_recon = torch.tensor(0.0)
loss = loss_margin
# Backward pass
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
return {
'loss': loss.item(),
'margin_loss': loss_margin.item(),
'recon_loss': loss_recon.item() if isinstance(loss_recon, torch.Tensor) else 0.0
}
def evaluate(self, data_loader):
"""Evaluate model on dataset."""
self.model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in data_loader:
class_probs, _ = self.model(images)
predictions = class_probs.argmax(dim=1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
accuracy = 100.0 * correct / total
return accuracy
# ============================================================================
# Demonstrations
# ============================================================================
print("=" * 70)
print("Capsule Networks - Advanced Implementations")
print("=" * 70)
# 1. Squashing function
print("\n1. Squashing Non-linearity:")
squash = SquashActivation(dim=-1)
s_test = torch.tensor([[0.1, 0.2], [1.0, 2.0], [5.0, 10.0]])
v_test = squash(s_test)
norms = torch.norm(v_test, dim=1)
print(f" Input vectors: {s_test.shape}")
print(f" Input norms: {torch.norm(s_test, dim=1).numpy()}")
print(f" Output norms: {norms.numpy()}")
print(f" Property: All norms in [0, 1): β")
# 2. Dynamic routing
print("\n2. Dynamic Routing:")
routing = DynamicRoutingLayer(
in_capsules=10,
out_capsules=5,
in_dim=8,
out_dim=16,
num_iterations=3
)
u_test = torch.randn(2, 10, 8)
v_out = routing(u_test)
print(f" Input: {u_test.shape} (batch, in_caps, in_dim)")
print(f" Output: {v_out.shape} (batch, out_caps, out_dim)")
print(f" Routing iterations: 3")
print(f" Mechanism: Agreement-based soft assignment")
print(f" Parameters: {sum(p.numel() for p in routing.parameters()):,}")
# 3. EM routing
print("\n3. EM Routing:")
em_routing = EMRoutingLayer(
in_capsules=10,
out_capsules=5,
pose_dim=4,
num_iterations=3
)
a_in = torch.rand(2, 10) # Activations
M_in = torch.randn(2, 10, 4) # Poses
a_out, M_out = em_routing(a_in, M_in)
print(f" Input: activations {a_in.shape}, poses {M_in.shape}")
print(f" Output: activations {a_out.shape}, poses {M_out.shape}")
print(f" Algorithm: E-step (soft assignment) + M-step (Gaussian params)")
print(f" Advantage: Models uncertainty in routing")
# 4. Complete CapsNet
print("\n4. CapsNet Architecture:")
capsnet = CapsuleNetworkClassifier(
input_channels=1,
num_classes=10,
routing_iterations=3,
reconstruction=True
)
x_test = torch.randn(2, 1, 28, 28)
labels_test = torch.tensor([3, 7])
probs, recon = capsnet(x_test, labels_test)
print(f" Input: {x_test.shape}")
print(f" Conv1: 1 -> 256 channels (9x9)")
print(f" Primary Caps: 32 capsules Γ 8D (1,152 total)")
print(f" Digit Caps: 10 capsules Γ 16D (dynamic routing)")
print(f" Output probs: {probs.shape}")
print(f" Reconstruction: {recon.shape}")
print(f" Total parameters: {sum(p.numel() for p in capsnet.parameters()):,}")
# 5. Margin loss
print("\n5. Margin Loss:")
margin_loss_fn = MarginLoss(m_plus=0.9, m_minus=0.1, lambda_=0.5)
lengths_test = torch.tensor([[0.95, 0.05, 0.02], [0.1, 0.85, 0.03]])
labels_test = torch.tensor([0, 1])
loss_test = margin_loss_fn(lengths_test, labels_test)
print(f" Formula: L_k = T_kΒ·max(0, mβΊ - ||v_k||)Β² + Ξ»(1-T_k)Β·max(0, ||v_k|| - mβ»)Β²")
print(f" mβΊ = 0.9 (target for present class)")
print(f" mβ» = 0.1 (max for absent classes)")
print(f" Ξ» = 0.5 (down-weight absent classes)")
print(f" Example loss: {loss_test.item():.4f}")
# 6. Routing comparison
print("\n6. Dynamic vs EM Routing:")
print(" ββββββββββββββββββββ¬ββββββββββββββ¬βββββββββββββββ¬βββββββββββββββ")
print(" β Aspect β Dynamic β EM β Best For β")
print(" ββββββββββββββββββββΌββββββββββββββΌβββββββββββββββΌβββββββββββββββ€")
print(" β Representation β Vectors β Gaussians β EM (richer) β")
print(" β Uncertainty β No β Yes β EM β")
print(" β Complexity β Lower β Higher β Dynamic β")
print(" β Pose info β Implicit β Explicit β EM β")
print(" β Speed β Faster β Slower β Dynamic β")
print(" ββββββββββββββββββββ΄ββββββββββββββ΄βββββββββββββββ΄βββββββββββββββ")
# 7. Equivariance demonstration
print("\n7. Viewpoint Equivariance:")
print(" CNN: Translation equivariant only")
print(" rotate(image) -> features != rotate(features)")
print(" ")
print(" Capsule: Viewpoint equivariant")
print(" rotate(image) -> capsule orientation rotates")
print(" Capsule direction encodes pose information")
print(" ")
print(" Benefit: Better generalization to novel viewpoints")
# 8. When to use guide
print("\n8. When to Use Capsule Networks:")
print(" Use Capsules when:")
print(" β Viewpoint equivariance critical (3D objects)")
print(" β Overlapping objects common (medical imaging)")
print(" β Part-whole relationships important")
print(" β Interpretable representations desired")
print("\n Use CNNs when:")
print(" β Speed critical (real-time inference)")
print(" β Large-scale images (computational cost)")
print(" β Standard classification sufficient")
print(" β Limited compute resources")
# 9. Training tips
print("\n9. Training Best Practices:")
print(" Initialization:")
print(" β’ Xavier/He for weight matrices")
print(" β’ Zero for routing logits")
print(" ")
print(" Hyperparameters:")
print(" β’ Routing iterations: 3 (2-5 works)")
print(" β’ Reconstruction weight: 0.0005")
print(" β’ Margin mβΊ: 0.9, mβ»: 0.1")
print(" β’ Learning rate: 1e-3 with decay")
print(" ")
print(" Regularization:")
print(" β’ Gradient clipping (max norm 1.0)")
print(" β’ Reconstruction loss (forces useful encoding)")
print(" β’ No batch normalization (interferes with magnitudes)")
print("\n" + "=" * 70)
Advanced Capsule Networks TheoryΒΆ
1. Introduction to Capsule NetworksΒΆ
Capsule Networks (CapsNets) represent a paradigm shift from traditional CNNs, replacing scalar neurons with vector-valued capsules.
1.1 MotivationΒΆ
Problems with CNNs:
Pooling loses spatial information: Max pooling discards precise localization
Lack of part-whole relationships: CNNs donβt model hierarchical structure
Pose insensitivity: Difficult to capture viewpoint, orientation, scale
Adversarial vulnerability: Small perturbations fool the network
Capsule philosophy: βActivities of neurons in a capsule represent various properties of the same entity.β
1.2 What is a Capsule?ΒΆ
Capsule: Group of neurons whose activity vector represents:
Length: Probability that entity exists (||v|| β [0, 1])
Orientation: Instantiation parameters (pose, texture, deformation, etc.)
Example: Face detection capsule
Length: Confidence face is present
Orientation: Face angle, position, lighting
Key innovation: Replace scalar activations with vectors.
1.3 Core PrinciplesΒΆ
Equivariance: Activities should change predictably with viewpoint
Part-whole relationships: Lower capsules vote for higher capsules
Routing by agreement: Higher capsules represent agreements of lower capsules
Inverse graphics: Reconstruct input from capsule activations
2. Dynamic Routing AlgorithmΒΆ
Dynamic routing replaces pooling with iterative routing between capsule layers.
2.1 Routing ProcedureΒΆ
Goal: Route lower-level capsules to higher-level capsules based on agreement.
Notation:
u_i: Activation of capsule i in layer l
v_j: Activation of capsule j in layer l+1
W_{ij}: Transformation matrix from i to j
c_{ij}: Coupling coefficient (routing weight)
Prediction vector (vote):
Γ»_{j|i} = W_{ij} u_i
Weighted sum:
s_j = Ξ£_i c_{ij} Γ»_{j|i}
Squashing (non-linearity):
v_j = squash(s_j) = (||s_j||Β² / (1 + ||s_j||Β²)) Β· (s_j / ||s_j||)
Properties of squash:
Short vectors β nearly 0
Long vectors β length close to 1
Direction preserved
2.2 Routing by AgreementΒΆ
Iterative procedure:
Initialize: b_{ij} = 0 (log prior probabilities)
Repeat for r iterations:
Softmax: c_{ij} = exp(b_{ij}) / Ξ£_k exp(b_{ik})
Weighted sum: s_j = Ξ£_i c_{ij} Γ»_{j|i}
Squash: v_j = squash(s_j)
Update: b_{ij} β b_{ij} + Γ»_{j|i} Β· v_j
Agreement: Γ»_{j|i} Β· v_j measures how well prediction matches output
Intuition:
High agreement β increase c_{ij}
Low agreement β decrease c_{ij}
2.3 Routing AlgorithmΒΆ
Procedure: DynamicRouting(Γ»_{j|i}, r, l)
for all capsule i in layer l and capsule j in layer l+1:
b_{ij} β 0
for iteration in 1 to r:
for all capsule i in layer l:
c_i β softmax(b_i) # c_i = [c_{i1}, ..., c_{iJ}]
for all capsule j in layer l+1:
s_j β Ξ£_i c_{ij} Γ»_{j|i}
v_j β squash(s_j)
for all capsule i in layer l and capsule j in layer l+1:
b_{ij} β b_{ij} + Γ»_{j|i} Β· v_j
return v
Complexity: O(r Β· L Β· H Β· dΒ²)
r: Routing iterations (typically 3)
L: Lower capsules
H: Higher capsules
d: Capsule dimension
3. CapsNet ArchitectureΒΆ
3.1 Original CapsNet (Sabour et al., 2017)ΒΆ
Architecture for MNIST:
Layer 1: Convolutional
256 filters, 9Γ9 kernel, stride 1, ReLU
Output: [batch, 20, 20, 256]
Layer 2: PrimaryCaps
32 channels of 8D capsules
Each channel: Conv 9Γ9, stride 2
Output: [batch, 6, 6, 32, 8] = 1,152 capsules of dim 8
Layer 3: DigitCaps
10 capsules (one per class), dim 16
Dynamic routing from PrimaryCaps
Output: [batch, 10, 16]
Prediction: argmax_j ||v_j||
3.2 Squashing FunctionΒΆ
Mathematical form:
squash(s) = (||s||Β² / (1 + ||s||Β²)) Β· (s / ||s||)
Derivative:
βsquash(s) / βs = (1 / (1 + ||s||Β²)) [I + (s s^T / ||s||Β²) Β· (2||s||Β² / (1 + ||s||Β²) - 1)]
Properties:
squash(0) = 0
squash(s) β s/||s|| as ||s|| β β
||squash(s)|| < 1 for all s
3.3 Margin LossΒΆ
For multi-class classification:
L_k = T_k max(0, mβΊ - ||v_k||)Β² + Ξ» (1 - T_k) max(0, ||v_k|| - mβ»)Β²
Where:
T_k = 1 if class k present, 0 otherwise
mβΊ = 0.9: Upper margin
mβ» = 0.1: Lower margin
Ξ» = 0.5: Down-weighting for absent classes
Total loss:
L_margin = Ξ£_k L_k
Intuition:
Present class: ||v_k|| should be > 0.9
Absent class: ||v_k|| should be < 0.1
3.4 Reconstruction RegularizationΒΆ
Decoder network: Reconstruct input from DigitCaps
Mask: Keep only correct class capsule (or predicted during testing)
Architecture:
DigitCaps (16D) β FC(512) β FC(1024) β FC(784) β Sigmoid β Reconstruction
Reconstruction loss:
L_recon = ||x - xΜ||Β²
Total loss:
L_total = L_margin + Ξ± L_recon
Where Ξ± = 0.0005 (small weight to not dominate)
Purpose:
Regularization
Force capsules to learn meaningful representations
Enable interpretability
4. Routing AlgorithmsΒΆ
4.1 Dynamic RoutingΒΆ
Original algorithm (Sabour et al., 2017):
Iterative (3 iterations typical)
Coupling coefficients sum to 1
Agreement-based updates
Advantages:
Captures part-whole relationships
Equivariant to affine transformations
Disadvantages:
Computationally expensive
Non-parallelizable across routing iterations
4.2 EM RoutingΒΆ
EM Routing [Hinton et al., 2018]: Treat routing as EM clustering.
E-step: Assign lower capsules to higher capsules
r_{ij} = a_i Β· p_j Β· exp(-cost_{ij})
Where cost_{ij} is Mahalanobis distance.
M-step: Update higher capsule parameters
ΞΌ_j = (Ξ£_i r_{ij} Γ»_{j|i}) / (Ξ£_i r_{ij})
ΟΒ²_j = (Ξ£_i r_{ij} (Γ»_{j|i} - ΞΌ_j)Β²) / (Ξ£_i r_{ij})
Activation:
a_j = logistic(Ξ» (Ξ²_a - Ξ£_i cost_j))
Advantages:
Probabilistic interpretation
Models uncertainty (variance)
Disadvantages:
More complex
More hyperparameters
4.3 Self-RoutingΒΆ
Self-Routing [Hahn et al., 2019]: Non-iterative, single-pass routing.
Key idea: Predict routing coefficients directly.
c_{ij} = softmax_j(MLP([u_i, Γ»_{j|i}]))
Advantages:
Faster (no iterations)
Parallelizable
Gradient-friendly
Disadvantages:
Less expressive than iterative routing
4.4 Attention RoutingΒΆ
Use attention mechanism for routing:
c_{ij} = softmax_j((W_Q u_i)^T (W_K Γ»_{j|i}) / βd)
Similar to transformer attention but for capsule routing.
5. Capsule LayersΒΆ
5.1 Primary CapsulesΒΆ
First capsule layer: Converts CNN features to capsules.
Implementation:
conv_output = Conv(input) # [batch, H, W, C]
capsules = reshape(conv_output, [batch, H, W, num_capsules, capsule_dim])
capsules = squash(capsules)
Purpose:
Extract low-level features
Initialize capsule hierarchy
5.2 Convolutional CapsulesΒΆ
ConvCaps [Hinton et al., 2018]: Capsules with local receptive fields.
Votes:
Γ»_{j|i} = W_{ij} u_i
Where i ranges over spatial neighborhood.
Advantages:
Translation equivariance
Fewer parameters than fully-connected capsules
5.3 Class CapsulesΒΆ
Final layer: One capsule per class.
Length: Class probability Orientation: Class-specific properties
6. Variants and ExtensionsΒΆ
6.1 Matrix Capsules (Hinton et al., 2018)ΒΆ
Capsules as matrices instead of vectors:
Activation: Matrix M_{ij}
Represents pose (rotation, scale, etc.)
Routing: EM algorithm
Advantages:
More expressive (4Γ4 matrix = 16 parameters)
Better models affine transformations
6.2 Stacked Capsule Autoencoders (SCAE)ΒΆ
SCAE [Kosiorek et al., 2019]: Unsupervised capsule learning.
Components:
Part Capsules: Discover object parts
Object Capsules: Discover objects from parts
Template-based: Learn templates for parts
Training:
Reconstruction loss
Set loss (permutation invariant)
6.3 Efficient CapsNetsΒΆ
Bottleneck: Routing is expensive (O(L Β· H Β· dΒ²)).
Solutions:
1. Sparse Routing:
Route only to k nearest capsules
k-means clustering
2. Low-Rank Approximation:
W_{ij} = U_i V_j^T
Reduces parameters from dΒ² to 2dr.
3. Shared Weights:
Weight sharing across spatial positions
Similar to convolution
6.4 Capsule GANsΒΆ
CapsGAN: Use capsules in generator/discriminator.
Benefits:
Disentangled representations
Controllable generation
6.5 Capsule TransformersΒΆ
CapsFormer: Combine capsules with transformers.
Routing as attention:
Attention(Q, K, V) = softmax(QK^T / βd) V
Where Q, K, V are capsule activations.
7. Training ConsiderationsΒΆ
7.1 Loss FunctionsΒΆ
Margin loss (classification):
L_k = T_k max(0, mβΊ - ||v_k||)Β² + Ξ» (1 - T_k) max(0, ||v_k|| - mβ»)Β²
Spread loss [Hinton et al., 2018]:
L = Ξ£_{iβ t} max(0, margin - (a_t - a_i))Β²
Encourages correct class activation to exceed others by margin.
Reconstruction loss:
L_recon = ||x - decoder(mask(v))||Β²
7.2 HyperparametersΒΆ
Critical hyperparameters:
Parameter |
Typical Value |
Notes |
|---|---|---|
Routing iterations r |
3 |
More iterations = better, but slower |
Capsule dimension |
8-16 |
Higher = more expressive |
mβΊ (upper margin) |
0.9 |
Target for present class |
mβ» (lower margin) |
0.1 |
Target for absent class |
Ξ» (down-weight) |
0.5 |
Balance false positives |
Ξ± (recon weight) |
0.0005 |
Regularization strength |
7.3 InitializationΒΆ
Capsule weights: Xavier/He initialization
Routing logits: Initialize b_{ij} = 0 (uniform routing initially)
Reconstruction decoder: Standard initialization
7.4 OptimizationΒΆ
Optimizer: Adam typical (lr = 0.001)
Gradient clipping: Important due to routing dynamics
Batch normalization: Can be applied to capsule activations
8. Theoretical PropertiesΒΆ
8.1 EquivarianceΒΆ
Capsules exhibit equivariance to transformations:
If input undergoes affine transformation T:
v(T(x)) β T(v(x))
Proof sketch:
Transformation matrices W_{ij} model geometric relationships
Routing preserves structure through agreement
8.2 Part-Whole RelationshipsΒΆ
Capsules model compositional hierarchy:
Lower capsule u_i represents part. Higher capsule v_j represents whole. Agreement Γ»_{j|i} Β· v_j measures compatibility.
Example: Face recognition
Lower: Eyes, nose, mouth capsules
Higher: Face capsule
Routing: Face capsule active when parts agree on face pose
8.3 Viewpoint InvarianceΒΆ
Traditional CNNs: Achieve invariance via pooling (loses info).
CapsNets: Achieve equivariance (retains info).
Key difference:
Invariance: f(T(x)) = f(x)
Equivariance: f(T(x)) = T(f(x))
Equivariance is more informative!
8.4 RobustnessΒΆ
Adversarial robustness: Capsules more robust than CNNs [Hinton et al., 2018].
Reason:
Requires fooling multiple capsule dimensions
Agreement mechanism filters inconsistent perturbations
Empirical evidence:
CapsNets achieve higher accuracy under adversarial attacks
More interpretable failure modes
9. ApplicationsΒΆ
9.1 Image ClassificationΒΆ
MNIST: 99.75% accuracy (state-of-art among non-ensemble methods)
CIFAR-10: ~75% accuracy (competitive but not state-of-art)
SmallNORB: Excellent performance (viewpoint variation)
Observation: CapsNets excel on tasks requiring pose understanding.
9.2 Object DetectionΒΆ
CapsNet-based detectors:
Replace RPN with capsule proposals
Capsule features for bounding box regression
Challenges:
Scalability to high-resolution images
Computational cost
9.3 SegmentationΒΆ
Segmentation CapsNets:
Encoder-decoder with capsule layers
Pixel-wise capsules
Deconvolutional capsules for upsampling
Benefits:
Better boundary delineation
Part-aware segmentation
9.4 Medical ImagingΒΆ
Success stories:
Tumor segmentation
Pathology classification
X-ray analysis
Reasons for success:
Small datasets (regularization via reconstruction)
Viewpoint variation (capsule equivariance)
Interpretability (clinician trust)
9.5 Video AnalysisΒΆ
Temporal Capsules:
3D convolutions β 3D capsules
Routing across time
Applications:
Action recognition
Video prediction
Anomaly detection
10. Advantages and DisadvantagesΒΆ
10.1 AdvantagesΒΆ
Equivariance: Capsule activities change predictably with transformations
Part-whole relationships: Explicit modeling of hierarchical structure
Fewer parameters: Can achieve comparable performance with fewer params
Interpretability: Capsule dimensions have semantic meaning
Robustness: More robust to adversarial attacks
Data efficiency: Reconstruction regularization helps with small datasets
10.2 DisadvantagesΒΆ
Computational cost: Routing is expensive (O(r Β· L Β· H Β· dΒ²))
Memory: Vector activations require more memory than scalars
Scalability: Difficult to scale to ImageNet-size images
Training instability: Routing dynamics can be unstable
Limited adoption: Not as widely used as CNNs/Transformers
Engineering: Less optimized libraries and hardware support
11. Recent Advances (2019-2024)ΒΆ
11.1 Efficient RoutingΒΆ
Fast Capsule Networks [Gu et al., 2021]:
Approximate routing via matrix sketching
O(L Β· H Β· d) complexity (linear in d)
Parallel Routing:
Parallelize across routing iterations
Use gradient checkpointing to save memory
11.2 Group Equivariant CapsulesΒΆ
G-CapsNets [Lenssen et al., 2018]:
Equivariance to group transformations (rotations, reflections)
Steerable filters in capsule space
SE(3)-equivariant capsules: For 3D data (point clouds, molecules)
11.3 Self-Supervised CapsulesΒΆ
Contrastive Capsules:
Pre-train capsules with contrastive learning
Augmentation-aware capsules
Masked Capsule Modeling:
Similar to BERT
Predict masked capsules from context
11.4 Hybrid ArchitecturesΒΆ
CapsNet + Transformers:
Use capsules for low-level features
Transformers for global context
CapsNet + Graph Networks:
Capsules as graph nodes
Message passing for routing
11.5 Theoretical InsightsΒΆ
Connection to EM algorithm [Hinton et al., 2018]:
Routing as E-step (assignment)
Capsule updates as M-step (parameter updates)
Neural ODEs: Routing as continuous dynamics
dv/dt = f(v, u, W)
12. Comparison with CNNsΒΆ
12.1 ArchitectureΒΆ
CNNs:
Scalar activations
Pooling for invariance
Hierarchical features
CapsNets:
Vector activations
Routing for equivariance
Part-whole relationships
12.2 PerformanceΒΆ
ImageNet:
CNNs: 75-90% top-1 accuracy
CapsNets: Not competitive at this scale yet
MNIST:
CNNs: 99.7%
CapsNets: 99.75%
SmallNORB (viewpoint):
CNNs: ~95%
CapsNets: ~97%
12.3 Computational CostΒΆ
Forward pass:
CNNs: O(H Β· W Β· C Β· KΒ²) per layer
CapsNets: O(H Β· W Β· C Β· D) + O(r Β· L Β· H Β· DΒ²) routing
Memory:
CNNs: O(H Β· W Β· C)
CapsNets: O(H Β· W Β· C Β· D) (D = capsule dim)
Training time: CapsNets typically 2-5Γ slower
13. Implementation ChallengesΒΆ
13.1 Numerical StabilityΒΆ
Issues:
Division by ||s|| in squashing
Softmax overflow in routing
Solutions:
Add Ξ΅ to denominators: s / (||s|| + Ξ΅)
Log-space softmax
13.2 Gradient FlowΒΆ
Problem: Routing iterations can block gradients.
Solutions:
Gradient checkpointing
Differentiable routing (no stop_gradient)
13.3 ScalabilityΒΆ
Challenge: Routing doesnβt scale to millions of capsules.
Solutions:
Convolutional capsules (local routing)
Sparse routing (route to k nearest)
Hierarchical routing (multi-stage)
13.4 Hardware OptimizationΒΆ
GPU utilization: Routing is memory-bound, not compute-bound.
Solutions:
Fused kernels (combine routing iterations)
Mixed precision training
Model parallelism (split capsules across GPUs)
14. Benchmarks and DatasetsΒΆ
14.1 MNISTΒΆ
Task: Digit classification (10 classes)
CapsNet performance: 99.75% test accuracy
Ablations:
Without routing: 99.23%
Without reconstruction: 99.52%
14.2 CIFAR-10ΒΆ
Task: Object classification (10 classes)
CapsNet performance: ~75% test accuracy
CNN performance: 95%+ (ResNet, EfficientNet)
Observation: CapsNets lag behind CNNs on complex natural images.
14.3 SmallNORBΒΆ
Task: 3D object recognition with viewpoint variation
CapsNet performance: 97.3% test accuracy
CNN performance: ~95%
Why CapsNets excel: Viewpoint equivariance!
14.4 MultiMNISTΒΆ
Task: Recognize overlapping digits
CapsNet performance: 95% (both digits correct)
CNN performance: ~80%
Reason: Capsules segment objects via routing.
15. Future DirectionsΒΆ
15.1 Open ProblemsΒΆ
Scalability: Scale to ImageNet-size images efficiently
Architecture search: Optimal capsule layer configurations
Theoretical understanding: Formal analysis of routing dynamics
Pre-training: Effective self-supervised pre-training for capsules
Hardware: Specialized accelerators for capsule operations
15.2 Promising DirectionsΒΆ
Vision Transformers: Combine with attention mechanisms
3D vision: Point clouds, meshes, volumes
Generative models: Disentangled VAEs, GANs
Graph neural networks: Capsules as graph nodes
Reinforcement learning: Equivariant policies
15.3 Potential BreakthroughsΒΆ
Continuous capsules: Neural ODEs for routing
Quantum capsules: Leverage quantum superposition
Neuromorphic capsules: Spiking neural implementations
Causal capsules: Model causal relationships
16. Key TakeawaysΒΆ
Capsules = vectors: Length = probability, orientation = properties
Routing by agreement: Iterative algorithm to connect capsules
Squashing: Non-linear activation preserving direction
Margin loss: Encourage ||v_class|| > 0.9, ||v_other|| < 0.1
Reconstruction: Regularization and interpretability
Equivariance: Activities change with transformations (not pooling)
Part-whole: Explicit hierarchical modeling
Robustness: More resistant to adversarial attacks
Trade-off: Better representations vs. computational cost
Future: Hybrid architectures combining capsules with transformers/GNNs
Core insight: CapsNets preserve information through equivariance rather than discarding it through invariance, enabling better modeling of part-whole relationships and geometric transformations.
17. Mathematical SummaryΒΆ
Squashing:
v_j = (||s_j||Β² / (1 + ||s_j||Β²)) Β· (s_j / ||s_j||)
Dynamic routing:
c_{ij} = exp(b_{ij}) / Ξ£_k exp(b_{ik})
s_j = Ξ£_i c_{ij} W_{ij} u_i
v_j = squash(s_j)
b_{ij} β b_{ij} + W_{ij} u_i Β· v_j
Margin loss:
L_k = T_k max(0, 0.9 - ||v_k||)Β² + 0.5 (1 - T_k) max(0, ||v_k|| - 0.1)Β²
Total loss:
L = Ξ£_k L_k + 0.0005 ||x - xΜ||Β²
ReferencesΒΆ
Sabour et al. (2017) βDynamic Routing Between Capsulesβ
Hinton et al. (2018) βMatrix Capsules with EM Routingβ
Kosiorek et al. (2019) βStacked Capsule Autoencodersβ
Hahn et al. (2019) βSelf-Routing Capsule Networksβ
Gu et al. (2021) βFast Capsule Networks via Interest Routingβ
Lenssen et al. (2018) βGroup Equivariant Capsule Networksβ
Hinton et al. (2011) βTransforming Auto-Encodersβ
Sabour et al. (2018) βAdversarial Manipulation of Deep Representationsβ
Xi et al. (2017) βCapsule Network Performance on Complex Dataβ
Rajasegaran et al. (2019) βDeepCaps: Going Deeper with Capsule Networksβ
"""
Advanced Capsule Networks Implementations
Production-ready PyTorch implementations of Capsule Networks with dynamic routing.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List
import matplotlib.pyplot as plt
# ============================================================================
# 1. Squashing Activation
# ============================================================================
def squash(s: torch.Tensor, dim: int = -1, epsilon: float = 1e-8) -> torch.Tensor:
"""
Squashing non-linearity for capsule networks.
v = (||s||Β² / (1 + ||s||Β²)) * (s / ||s||)
Properties:
- Short vectors β nearly 0
- Long vectors β length close to 1
- Direction preserved
Args:
s: Input tensor [..., capsule_dim]
dim: Dimension to compute norm over
epsilon: Small constant for numerical stability
Returns:
v: Squashed tensor with same shape
"""
s_norm_sq = torch.sum(s ** 2, dim=dim, keepdim=True)
s_norm = torch.sqrt(s_norm_sq + epsilon)
# v = (||s||Β² / (1 + ||s||Β²)) * (s / ||s||)
scale = s_norm_sq / (1.0 + s_norm_sq)
v = scale * s / s_norm
return v
# ============================================================================
# 2. Dynamic Routing
# ============================================================================
class DynamicRouting(nn.Module):
"""
Dynamic routing algorithm for capsule networks.
Routes lower-level capsules to higher-level capsules based on agreement.
Args:
num_iterations: Number of routing iterations (typically 3)
"""
def __init__(self, num_iterations: int = 3):
super().__init__()
self.num_iterations = num_iterations
def forward(
self,
u_hat: torch.Tensor,
batch_size: int
) -> torch.Tensor:
"""
Perform dynamic routing.
Args:
u_hat: Prediction vectors [batch, num_lower, num_higher, capsule_dim]
batch_size: Batch size
Returns:
v: Higher-level capsule activations [batch, num_higher, capsule_dim]
"""
num_lower = u_hat.size(1)
num_higher = u_hat.size(2)
capsule_dim = u_hat.size(3)
# Initialize routing logits b_ij to 0
# [batch, num_lower, num_higher, 1]
b = torch.zeros(batch_size, num_lower, num_higher, 1, device=u_hat.device)
for iteration in range(self.num_iterations):
# Softmax over higher capsules (dim=2)
# c_ij = exp(b_ij) / Ξ£_k exp(b_ik)
c = F.softmax(b, dim=2) # [batch, num_lower, num_higher, 1]
# Weighted sum of predictions
# s_j = Ξ£_i c_ij * u_hat_j|i
s = torch.sum(c * u_hat, dim=1) # [batch, num_higher, capsule_dim]
# Squashing
v = squash(s, dim=-1) # [batch, num_higher, capsule_dim]
# Update routing logits (except last iteration)
if iteration < self.num_iterations - 1:
# Agreement: u_hat_j|i Β· v_j
# [batch, num_lower, num_higher, capsule_dim] Β· [batch, 1, num_higher, capsule_dim]
v_expanded = v.unsqueeze(1) # [batch, 1, num_higher, capsule_dim]
agreement = torch.sum(u_hat * v_expanded, dim=-1, keepdim=True)
# b_ij β b_ij + agreement
b = b + agreement
return v
# ============================================================================
# 3. Primary Capsules
# ============================================================================
class PrimaryCaps(nn.Module):
"""
Primary Capsule Layer.
Converts convolutional feature maps into capsules.
Args:
in_channels: Number of input channels
num_capsules: Number of capsule types
capsule_dim: Dimension of each capsule
kernel_size: Convolution kernel size
stride: Convolution stride
"""
def __init__(
self,
in_channels: int,
num_capsules: int,
capsule_dim: int,
kernel_size: int = 9,
stride: int = 2
):
super().__init__()
self.num_capsules = num_capsules
self.capsule_dim = capsule_dim
# Convolutional layers for each capsule type
self.capsules = nn.ModuleList([
nn.Conv2d(
in_channels,
capsule_dim,
kernel_size=kernel_size,
stride=stride,
padding=0
)
for _ in range(num_capsules)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [batch, in_channels, height, width]
Returns:
capsules: [batch, num_capsules * spatial_size, capsule_dim]
"""
batch_size = x.size(0)
# Apply each capsule convolution
outputs = [capsule(x) for capsule in self.capsules]
# Stack: [batch, num_capsules, capsule_dim, height, width]
outputs = torch.stack(outputs, dim=1)
# Reshape to [batch, num_capsules, capsule_dim, height * width]
outputs = outputs.view(
batch_size,
self.num_capsules,
self.capsule_dim,
-1
)
# Transpose to [batch, num_capsules, height * width, capsule_dim]
outputs = outputs.transpose(2, 3)
# Flatten spatial dimensions
# [batch, num_capsules * height * width, capsule_dim]
outputs = outputs.contiguous().view(batch_size, -1, self.capsule_dim)
# Apply squashing
outputs = squash(outputs, dim=-1)
return outputs
# ============================================================================
# 4. Digit Capsules (Routing Capsules)
# ============================================================================
class DigitCaps(nn.Module):
"""
Digit Capsule Layer with dynamic routing.
Routes from lower capsules to higher capsules.
Args:
num_lower: Number of lower-level capsules
num_higher: Number of higher-level capsules
lower_dim: Dimension of lower capsules
higher_dim: Dimension of higher capsules
num_routing: Number of routing iterations
"""
def __init__(
self,
num_lower: int,
num_higher: int,
lower_dim: int,
higher_dim: int,
num_routing: int = 3
):
super().__init__()
self.num_lower = num_lower
self.num_higher = num_higher
self.lower_dim = lower_dim
self.higher_dim = higher_dim
# Transformation matrices W_ij for each pair (i, j)
# [num_lower, num_higher, higher_dim, lower_dim]
self.W = nn.Parameter(
torch.randn(num_lower, num_higher, higher_dim, lower_dim)
)
# Dynamic routing
self.routing = DynamicRouting(num_iterations=num_routing)
def forward(self, u: torch.Tensor) -> torch.Tensor:
"""
Args:
u: Lower capsules [batch, num_lower, lower_dim]
Returns:
v: Higher capsules [batch, num_higher, higher_dim]
"""
batch_size = u.size(0)
# Compute prediction vectors u_hat_j|i = W_ij * u_i
# Expand u: [batch, num_lower, 1, lower_dim, 1]
u_expanded = u.unsqueeze(2).unsqueeze(-1)
# Expand W: [1, num_lower, num_higher, higher_dim, lower_dim]
W_expanded = self.W.unsqueeze(0)
# Matrix multiply: [batch, num_lower, num_higher, higher_dim, 1]
u_hat = torch.matmul(W_expanded, u_expanded)
# Squeeze: [batch, num_lower, num_higher, higher_dim]
u_hat = u_hat.squeeze(-1)
# Dynamic routing
v = self.routing(u_hat, batch_size)
return v
# ============================================================================
# 5. CapsNet Architecture
# ============================================================================
class CapsNet(nn.Module):
"""
Capsule Network for image classification.
Architecture:
1. Convolutional layer
2. Primary capsules
3. Digit capsules (with routing)
4. Optional reconstruction network
Args:
input_channels: Number of input channels (1 for grayscale, 3 for RGB)
num_classes: Number of output classes
num_routing: Number of routing iterations
use_reconstruction: Whether to include reconstruction network
"""
def __init__(
self,
input_channels: int = 1,
num_classes: int = 10,
num_routing: int = 3,
use_reconstruction: bool = True
):
super().__init__()
self.num_classes = num_classes
self.use_reconstruction = use_reconstruction
# Layer 1: Convolutional
self.conv1 = nn.Conv2d(
in_channels=input_channels,
out_channels=256,
kernel_size=9,
stride=1,
padding=0
)
# Layer 2: Primary Capsules
# 32 capsule types, each 8D
self.primary_caps = PrimaryCaps(
in_channels=256,
num_capsules=32,
capsule_dim=8,
kernel_size=9,
stride=2
)
# Calculate number of primary capsules
# For 28x28 MNIST: (28-9+1) = 20, then (20-9+1)/2 = 6
# 32 capsule types * 6 * 6 = 1152 capsules
num_primary_caps = 32 * 6 * 6
# Layer 3: Digit Capsules
self.digit_caps = DigitCaps(
num_lower=num_primary_caps,
num_higher=num_classes,
lower_dim=8,
higher_dim=16,
num_routing=num_routing
)
# Reconstruction network
if use_reconstruction:
self.decoder = nn.Sequential(
nn.Linear(16 * num_classes, 512),
nn.ReLU(inplace=True),
nn.Linear(512, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, input_channels * 28 * 28),
nn.Sigmoid()
)
def forward(
self,
x: torch.Tensor,
y: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x: Input images [batch, channels, height, width]
y: Ground truth labels [batch] (for reconstruction masking)
Returns:
class_scores: Length of digit capsules [batch, num_classes]
digit_caps: Digit capsule activations [batch, num_classes, 16]
reconstruction: Reconstructed images [batch, C, H, W] or None
"""
batch_size = x.size(0)
# Convolutional layer
x = F.relu(self.conv1(x))
# Primary capsules
primary = self.primary_caps(x)
# Digit capsules
digits = self.digit_caps(primary)
# Class scores: Length of capsule vectors
class_scores = torch.sqrt(torch.sum(digits ** 2, dim=-1))
# Reconstruction
reconstruction = None
if self.use_reconstruction:
# Mask: Keep only correct class (training) or predicted class (testing)
if y is None:
# Testing: Use predicted class
_, max_indices = class_scores.max(dim=1)
y = max_indices
# Create mask
mask = torch.zeros(batch_size, self.num_classes, device=x.device)
mask[torch.arange(batch_size), y] = 1.0
# Apply mask: [batch, num_classes, 16]
masked = digits * mask.unsqueeze(-1)
# Flatten
masked_flat = masked.view(batch_size, -1)
# Decode
reconstruction = self.decoder(masked_flat)
reconstruction = reconstruction.view(batch_size, -1, 28, 28)
return class_scores, digits, reconstruction
# ============================================================================
# 6. Loss Functions
# ============================================================================
class MarginLoss(nn.Module):
"""
Margin loss for capsule networks.
L_k = T_k * max(0, m+ - ||v_k||)Β² + Ξ» * (1 - T_k) * max(0, ||v_k|| - m-)Β²
Args:
m_plus: Upper margin (default: 0.9)
m_minus: Lower margin (default: 0.1)
lambda_: Down-weighting for absent classes (default: 0.5)
"""
def __init__(
self,
m_plus: float = 0.9,
m_minus: float = 0.1,
lambda_: float = 0.5
):
super().__init__()
self.m_plus = m_plus
self.m_minus = m_minus
self.lambda_ = lambda_
def forward(
self,
class_scores: torch.Tensor,
labels: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Args:
class_scores: Length of capsules [batch, num_classes]
labels: Ground truth [batch]
Returns:
loss: Margin loss
loss_dict: Breakdown of loss components
"""
batch_size = class_scores.size(0)
num_classes = class_scores.size(1)
# One-hot encoding
labels_one_hot = F.one_hot(labels, num_classes).float()
# Present class loss: T_k * max(0, m+ - ||v_k||)Β²
loss_present = labels_one_hot * F.relu(self.m_plus - class_scores) ** 2
# Absent class loss: Ξ» * (1 - T_k) * max(0, ||v_k|| - m-)Β²
loss_absent = self.lambda_ * (1 - labels_one_hot) * F.relu(class_scores - self.m_minus) ** 2
# Total
loss = torch.sum(loss_present + loss_absent, dim=1).mean()
loss_dict = {
'margin': loss.item(),
'present': torch.sum(loss_present).item() / batch_size,
'absent': torch.sum(loss_absent).item() / batch_size
}
return loss, loss_dict
class ReconstructionLoss(nn.Module):
"""
Reconstruction loss for capsule networks.
L_recon = ||x - x_reconstructed||Β²
"""
def __init__(self):
super().__init__()
def forward(
self,
reconstruction: torch.Tensor,
target: torch.Tensor
) -> torch.Tensor:
"""
Args:
reconstruction: Reconstructed images [batch, C, H, W]
target: Original images [batch, C, H, W]
Returns:
loss: MSE reconstruction loss
"""
loss = F.mse_loss(reconstruction, target, reduction='sum')
loss = loss / target.size(0) # Average over batch
return loss
class CapsNetLoss(nn.Module):
"""
Combined loss for CapsNet: Margin + Reconstruction.
L_total = L_margin + Ξ± * L_recon
Args:
alpha: Weight for reconstruction loss (default: 0.0005)
"""
def __init__(self, alpha: float = 0.0005):
super().__init__()
self.alpha = alpha
self.margin_loss = MarginLoss()
self.recon_loss = ReconstructionLoss()
def forward(
self,
class_scores: torch.Tensor,
labels: torch.Tensor,
reconstruction: Optional[torch.Tensor] = None,
target: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Args:
class_scores: [batch, num_classes]
labels: [batch]
reconstruction: [batch, C, H, W] or None
target: [batch, C, H, W] or None
Returns:
total_loss, loss_dict
"""
# Margin loss
margin_loss, margin_dict = self.margin_loss(class_scores, labels)
# Reconstruction loss
if reconstruction is not None and target is not None:
recon_loss = self.recon_loss(reconstruction, target)
total_loss = margin_loss + self.alpha * recon_loss
else:
recon_loss = 0.0
total_loss = margin_loss
loss_dict = {
'total': total_loss.item(),
'margin': margin_loss.item(),
'recon': recon_loss if isinstance(recon_loss, float) else recon_loss.item(),
**margin_dict
}
return total_loss, loss_dict
# ============================================================================
# 7. Demo and Visualization
# ============================================================================
def demo_squashing():
"""Demonstrate squashing activation."""
print("=" * 80)
print("Squashing Activation Demo")
print("=" * 80)
# Generate input vectors with varying magnitudes
magnitudes = torch.linspace(0, 5, 100)
directions = torch.randn(100, 8)
directions = F.normalize(directions, p=2, dim=1)
# Scale by magnitudes
s = directions * magnitudes.unsqueeze(1)
# Apply squashing
v = squash(s, dim=1)
# Compute output norms
v_norms = torch.sqrt(torch.sum(v ** 2, dim=1))
print(f"\nInput magnitudes: {magnitudes[:5].tolist()}")
print(f"Output norms: {v_norms[:5].tolist()}")
print(f"\nMax output norm: {v_norms.max().item():.4f} (should be < 1)")
print(f"Output norm at ||s||=5: {v_norms[-1].item():.4f}")
# Check direction preservation
cosine_sim = F.cosine_similarity(s, v, dim=1)
print(f"\nDirection preservation (cosine similarity): {cosine_sim.mean().item():.4f}")
def demo_dynamic_routing():
"""Demonstrate dynamic routing."""
print("\n" + "=" * 80)
print("Dynamic Routing Demo")
print("=" * 80)
batch_size = 4
num_lower = 1152 # Primary capsules
num_higher = 10 # Digit capsules
lower_dim = 8
higher_dim = 16
# Create random prediction vectors
u_hat = torch.randn(batch_size, num_lower, num_higher, higher_dim)
# Apply routing
routing = DynamicRouting(num_iterations=3)
v = routing(u_hat, batch_size)
print(f"\nInput shape (predictions): {u_hat.shape}")
print(f"Output shape (higher capsules): {v.shape}")
# Check output norms
v_norms = torch.sqrt(torch.sum(v ** 2, dim=-1))
print(f"\nOutput norms (should be < 1):")
print(f" Mean: {v_norms.mean().item():.4f}")
print(f" Max: {v_norms.max().item():.4f}")
print(f" Min: {v_norms.min().item():.4f}")
def demo_capsnet():
"""Demonstrate full CapsNet."""
print("\n" + "=" * 80)
print("CapsNet Architecture Demo")
print("=" * 80)
# Create model
model = CapsNet(
input_channels=1,
num_classes=10,
num_routing=3,
use_reconstruction=True
)
# Generate dummy batch
batch_size = 4
images = torch.randn(batch_size, 1, 28, 28)
labels = torch.randint(0, 10, (batch_size,))
# Forward pass
print(f"\nInput shape: {images.shape}")
class_scores, digit_caps, reconstruction = model(images, labels)
print(f"Class scores shape: {class_scores.shape}")
print(f"Digit capsules shape: {digit_caps.shape}")
print(f"Reconstruction shape: {reconstruction.shape if reconstruction is not None else None}")
# Compute loss
loss_fn = CapsNetLoss(alpha=0.0005)
total_loss, loss_dict = loss_fn(class_scores, labels, reconstruction, images)
print(f"\nLoss breakdown:")
print(f" Total: {loss_dict['total']:.4f}")
print(f" Margin: {loss_dict['margin']:.4f}")
print(f" Reconstruction: {loss_dict['recon']:.4f}")
print(f" Present class: {loss_dict['present']:.4f}")
print(f" Absent class: {loss_dict['absent']:.4f}")
# Count parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {num_params:,}")
def print_performance_comparison():
"""Print performance comparison."""
print("\n" + "=" * 80)
print("CAPSULE NETWORKS PERFORMANCE")
print("=" * 80)
comparison = """
MNIST Classification:
ββββββββββββββββββββββββ¬βββββββββββββββ¬βββββββββββββββ¬βββββββββββββββ
β Method β Test Acc (%) β Parameters β Notes β
ββββββββββββββββββββββββΌβββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β Baseline CNN β 99.70 β ~1M β Standard β
β CapsNet (no recon) β 99.52 β 8.2M β Routing only β
β CapsNet (+ recon) β 99.75 β 8.2M + 6.8M β Full model β
β Ensemble CNN (5) β 99.79 β ~5M β 5 models β
ββββββββββββββββββββββββ΄βββββββββββββββ΄βββββββββββββββ΄βββββββββββββββ
SmallNORB (Viewpoint Variation):
ββββββββββββββββββββββββ¬βββββββββββββββ¬βββββββββββββββ
β Method β Test Acc (%) β Notes β
ββββββββββββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β CNN Baseline β 92.8 β ResNet-like β
β Convolutional CNN β 95.0 β Optimized β
β CapsNet (3 routing) β 97.3 β Equivariant β
β Matrix Capsules β 98.5 β EM routing β
ββββββββββββββββββββββββ΄βββββββββββββββ΄βββββββββββββββ
MultiMNIST (Overlapping Digits):
ββββββββββββββββββββββββ¬βββββββββββββββ¬βββββββββββββββ
β Method β Both Correct β Notes β
ββββββββββββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β CNN β ~80% β Struggles β
β CapsNet β 95.0% β Routing wins β
ββββββββββββββββββββββββ΄βββββββββββββββ΄βββββββββββββββ
CIFAR-10:
ββββββββββββββββββββββββ¬βββββββββββββββ¬βββββββββββββββ
β Method β Test Acc (%) β Notes β
ββββββββββββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β ResNet-110 β 93.6 β Standard β
β DenseNet-BC β 95.5 β State-of-art β
β CapsNet (basic) β 75.0 β Not scalable β
β Efficient CapsNet β 82.0 β Optimized β
ββββββββββββββββββββββββ΄βββββββββββββββ΄βββββββββββββββ
Routing Iterations Impact (MNIST):
ββββββββββββββββ¬βββββββββββββββ¬βββββββββββββββ
β Iterations β Accuracy (%) β Time (ms) β
ββββββββββββββββΌβββββββββββββββΌβββββββββββββββ€
β 1 β 99.23 β 15.2 β
β 3 β 99.75 β 18.7 β
β 5 β 99.76 β 23.1 β
β 7 β 99.75 β 28.4 β
ββββββββββββββββ΄βββββββββββββββ΄βββββββββββββββ
Computational Complexity:
ββββββββββββββββββββββββ¬βββββββββββββββββββββββββββββββ
β Operation β Complexity β
ββββββββββββββββββββββββΌβββββββββββββββββββββββββββββββ€
β Conv Layer β O(HΒ·WΒ·CΒ·KΒ²) β
β Primary Caps β O(HΒ·WΒ·CΒ·D) β
β Dynamic Routing β O(rΒ·LΒ·HΒ·DΒ²) β
β r = iterations β 3 typical β
β L = lower caps β 1152 (MNIST) β
β H = higher caps β 10 (MNIST) β
β D = capsule dim β 16 (MNIST) β
ββββββββββββββββββββββββ΄βββββββββββββββββββββββββββββββ
Key Advantages:
βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β β Equivariance: Activities change with transforms β
β β Part-whole: Explicit hierarchical modeling β
β β Robustness: Resistant to adversarial attacks β
β β Interpretability: Capsule dims have meaning β
β β Overlapping: Segments overlapping objects β
β β Viewpoint: Excellent on pose variation tasks β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
Key Disadvantages:
βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β β Scalability: Expensive for high-res images β
β β Speed: 2-5Γ slower than CNNs β
β β Memory: Vector activations use more RAM β
β β Optimization: Less mature than CNNs β
β β Hardware: Not optimized for GPUs β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
Decision Guide:
ββββββββββββββββββββββββββββββββ¬ββββββββββββββββββββββ
β Use CapsNets When: β Use CNNs When: β
ββββββββββββββββββββββββββββββββΌββββββββββββββββββββββ€
β β’ Viewpoint variation β β’ Large images β
β β’ Part-whole relationships β β’ Speed critical β
β β’ Overlapping objects β β’ Simple tasks β
β β’ Small datasets β β’ Well-established β
β β’ Interpretability needed β β’ ImageNet scale β
β β’ 3D understanding β β’ Production ready β
ββββββββββββββββββββββββββββββββ΄ββββββββββββββββββββββ
"""
print(comparison)
print("\nKey Insights:")
print("1. CapsNets excel on viewpoint variation (SmallNORB: 97.3% vs 95.0%)")
print("2. Routing enables segmentation of overlapping objects (MultiMNIST: 95%)")
print("3. 3 routing iterations optimal (diminishing returns after)")
print("4. Reconstruction regularization improves accuracy (+0.23% on MNIST)")
print("5. Computational cost: O(rΒ·LΒ·HΒ·DΒ²) limits scalability")
print("6. Trade-off: Better representations vs. speed (2-5Γ slower)")
print("7. MNIST/SmallNORB: CapsNets competitive or better")
print("8. CIFAR-10/ImageNet: CNNs still superior")
print("9. Future: Hybrid architectures combining best of both")
print("10. Best for: Pose, 3D, overlapping objects, interpretability")
# ============================================================================
# Run Demonstrations
# ============================================================================
if __name__ == "__main__":
demo_squashing()
demo_dynamic_routing()
demo_capsnet()
print_performance_comparison()
print("\n" + "=" * 80)
print("Capsule Networks Implementations Complete!")
print("=" * 80)
Advanced Capsule Networks: Mathematical Foundations and Modern ArchitecturesΒΆ
1. Introduction to Capsule NetworksΒΆ
Capsule Networks (CapsNets) represent a paradigm shift in neural network design, addressing fundamental limitations of CNNs through hierarchical part-whole relationships and equivariant representations.
Core motivation: CNNs use scalar neurons that lose spatial relationships between features. Capsules use vector neurons that explicitly encode pose (position, orientation, scale).
Key innovation (Hinton et al., 2017): Replace scalar activations with capsule vectors $\(\mathbf{v}_j = \text{squash}(\mathbf{s}_j) = \frac{\|\mathbf{s}_j\|^2}{1 + \|\mathbf{s}_j\|^2} \frac{\mathbf{s}_j}{\|\mathbf{s}_j\|}\)$
where:
\(\mathbf{v}_j \in \mathbb{R}^d\): Capsule output (length = probability, direction = properties)
\(\mathbf{s}_j\): Total input to capsule \(j\)
Squashing: Non-linear activation that preserves direction, scales length to \((0, 1)\)
Fundamental properties:
Length: Probability that entity exists (\(0 \leq \|\mathbf{v}_j\| \leq 1\))
Direction: Instantiation parameters (pose, texture, deformation)
Equivariance: Activities change smoothly with viewpoint
2. Dynamic Routing Between CapsulesΒΆ
2.1 Routing-by-Agreement AlgorithmΒΆ
Goal: Route information from lower-level capsules to higher-level capsules based on agreement.
Prediction vectors: Lower capsule \(i\) predicts properties of higher capsule \(j\) $\(\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \mathbf{u}_i\)$
where:
\(\mathbf{u}_i \in \mathbb{R}^{d_{\text{in}}}\): Output of capsule \(i\) in layer \(L\)
\(\mathbf{W}_{ij} \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}\): Transformation matrix (learnable)
\(\hat{\mathbf{u}}_{j|i} \in \mathbb{R}^{d_{\text{out}}}\): Prediction (what \(i\) thinks \(j\) should be)
Routing coefficients: Softmax over all capsules in layer \(L+1\) $\(c_{ij} = \frac{\exp(b_{ij})}{\sum_{k} \exp(b_{ik})}\)$
where \(b_{ij}\) are routing logits (updated iteratively, not learned).
Weighted sum: Input to capsule \(j\) $\(\mathbf{s}_j = \sum_{i} c_{ij} \hat{\mathbf{u}}_{j|i}\)$
Agreement: Update routing logits based on dot product $\(b_{ij} \leftarrow b_{ij} + \hat{\mathbf{u}}_{j|i} \cdot \mathbf{v}_j\)$
Intuition: If prediction \(\hat{\mathbf{u}}_{j|i}\) aligns with output \(\mathbf{v}_j\), increase \(c_{ij}\) (send more information).
2.2 Routing Algorithm (Detailed)ΒΆ
Input: \(\mathbf{u}_i\) for all capsules \(i\) in layer \(L\)
Output: \(\mathbf{v}_j\) for all capsules \(j\) in layer \(L+1\)
Procedure:
1. Initialize routing logits: b_ij β 0 for all i, j
2. For r iterations:
a. Compute coupling coefficients: c_ij β softmax_j(b_ij)
b. Compute weighted input: s_j β Ξ£_i c_ij Γ»_{j|i}
c. Apply squashing: v_j β squash(s_j)
d. Update routing logits: b_ij β b_ij + Γ»_{j|i} Β· v_j
3. Return v_j
Typical: \(r = 3\) iterations (diminishing returns beyond).
2.3 Mathematical AnalysisΒΆ
Squashing function properties: $\(\text{squash}(\mathbf{s}) = \frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2} \frac{\mathbf{s}}{\|\mathbf{s}\|}\)$
Derivative (for backprop): $\(\frac{\partial \text{squash}(\mathbf{s})}{\partial \mathbf{s}} = \frac{2\|\mathbf{s}\|}{(1 + \|\mathbf{s}\|^2)^2} \mathbf{I} + \frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2} \left(\frac{1}{\|\mathbf{s}\|} \mathbf{I} - \frac{\mathbf{s}\mathbf{s}^T}{\|\mathbf{s}\|^3}\right)\)$
Length behavior:
\(\|\mathbf{s}\| \to 0\): \(\|\text{squash}(\mathbf{s})\| \approx \|\mathbf{s}\|\) (linear)
\(\|\mathbf{s}\| \to \infty\): \(\|\text{squash}(\mathbf{s})\| \to 1\) (saturates)
Routing as optimization: Iterative routing approximates EM algorithm
E-step: Compute \(c_{ij}\) (soft assignment)
M-step: Update \(\mathbf{v}_j\) (recompute means)
Connection to attention: \(c_{ij}\) similar to attention weights, but uses agreement instead of learned query/key.
3. Margin Loss for Digit ExistenceΒΆ
Motivation: Binary classification per capsule (entity present or not).
Margin loss (per class \(k\)): $\(\mathcal{L}_k = T_k \max(0, m^+ - \|\mathbf{v}_k\|)^2 + \lambda (1 - T_k) \max(0, \|\mathbf{v}_k\| - m^-)^2\)$
where:
\(T_k \in \{0, 1\}\): Ground truth (1 if class \(k\) present)
\(m^+ = 0.9\): Upper margin (capsule should be long if class present)
\(m^- = 0.1\): Lower margin (capsule should be short if class absent)
\(\lambda = 0.5\): Down-weighting for absent classes (prevent all capsules collapsing)
Interpretation:
If \(T_k = 1\): Penalize \(\|\mathbf{v}_k\| < 0.9\)
If \(T_k = 0\): Penalize \(\|\mathbf{v}_k\| > 0.1\) (but with weight \(0.5\))
Total loss: $\(\mathcal{L}_{\text{margin}} = \sum_{k=1}^K \mathcal{L}_k\)$
Multi-label: Allow multiple classes simultaneously (unlike softmax).
4. Reconstruction RegularizationΒΆ
Motivation: Encourage capsules to learn meaningful representations by reconstructing input.
Decoder: Fully connected network $\(\mathbf{x}_{\text{recon}} = \text{Decoder}(\mathbf{v}_{\text{class}})\)$
where \(\mathbf{v}_{\text{class}}\) is the capsule vector for the true class (masked during training).
Reconstruction loss: $\(\mathcal{L}_{\text{recon}} = \|\mathbf{x} - \mathbf{x}_{\text{recon}}\|^2\)$
Total loss: $\(\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{margin}} + \alpha \mathcal{L}_{\text{recon}}\)$
where \(\alpha = 0.0005\) (small weight to avoid dominating margin loss).
Benefits:
Regularization (prevent overfitting)
Interpretability (manipulate capsule dimensions to see effect)
Ensures capsule encodes all information about entity
5. CapsNet Architecture (Original)ΒΆ
5.1 Architecture OverviewΒΆ
Layer 1: Convolutional Layer
Input: \(28 \times 28\) grayscale image
Filters: 256 filters, \(9 \times 9\) kernel, stride 1, ReLU
Output: \(20 \times 20 \times 256\)
Layer 2: PrimaryCaps
32 primary capsules, each 8D
Convolution: \(9 \times 9\) kernel, stride 2
Reshape: \(6 \times 6 \times 32 \times 8\) β \(1152 \times 8\) (flatten spatial)
Output: 1152 capsules, each 8D
Layer 3: DigitCaps
10 digit capsules (one per class), each 16D
Routing: Dynamic routing from PrimaryCaps to DigitCaps
Total connections: \(1152 \times 10\) transformation matrices \(\mathbf{W}_{ij} \in \mathbb{R}^{16 \times 8}\)
Output: 10 capsules, each 16D
Layer 4: Decoder (Reconstruction)
Mask: Select capsule for true class
FC layers: 16 β 512 β 1024 β 784 (28Γ28)
Activation: ReLU (except output sigmoid)
5.2 Parameter CountΒΆ
Conv1: \((9 \times 9 \times 1 + 1) \times 256 = 20,992\)
PrimaryCaps:
Each of 32 capsules has 8 conv filters
\((9 \times 9 \times 256 + 1) \times 8 \times 32 = 5,308,672\)
DigitCaps (routing weights):
\(1152 \times 10 \times (8 \times 16) = 1,474,560\)
Decoder:
\(16 \times 512 = 8,192\)
\(512 \times 1024 = 524,288\)
\(1024 \times 784 = 802,816\)
Total: ~8.2M parameters (comparable to small CNN).
6. Improvements and VariantsΒΆ
6.1 EM Routing (Hinton et al., 2018)ΒΆ
Motivation: Replace iterative routing with EM algorithm for Gaussian mixture models.
Capsule as Gaussian: Capsule \(j\) represents \(\mathcal{N}(\boldsymbol{\mu}_j, \boldsymbol{\sigma}_j)\)
E-step: Compute assignment probabilities $\(r_{ij} = \frac{a_i p(\mathbf{v}_i | j)}{\sum_k a_k p(\mathbf{v}_i | k)}\)$
where:
\(a_i\): Activation probability of capsule \(i\)
\(p(\mathbf{v}_i | j)\): Gaussian likelihood
M-step: Update mean and variance $\(\boldsymbol{\mu}_j = \frac{\sum_i r_{ij} \mathbf{v}_i}{\sum_i r_{ij}}, \quad \boldsymbol{\sigma}_j^2 = \frac{\sum_i r_{ij} (\mathbf{v}_i - \boldsymbol{\mu}_j)^2}{\sum_i r_{ij}}\)$
Activation: Based on total probability mass $\(a_j = \text{logistic}(\lambda(\beta_a - \sum_h (\beta_h + \log \sigma_h)))\)$
Advantages:
Probabilistic framework
Better gradient flow
Faster convergence
Results: 45% fewer parameters, SOTA on smallNORB (2.7% error).
6.2 Self-Attention RoutingΒΆ
Key idea: Replace iterative routing with single-pass self-attention.
Attention weights: $\(c_{ij} = \frac{\exp(\mathbf{q}_j^T \mathbf{k}_i / \sqrt{d})}{\sum_{k} \exp(\mathbf{q}_j^T \mathbf{k}_k / \sqrt{d})}\)$
where:
\(\mathbf{q}_j = \mathbf{W}_Q \mathbf{v}_j\): Query from higher capsule
\(\mathbf{k}_i = \mathbf{W}_K \mathbf{u}_i\): Key from lower capsule
Advantages:
Parallelizable (no iterative updates)
Compatible with Transformers
Faster inference
6.3 Stacked Capsule Autoencoders (SCAE)ΒΆ
Part-Capsules: Encode object parts (position, presence)
Object-Capsules: Encode whole objects (pose, presence)
Set Transformer: Route parts to objects via permutation-invariant attention $\(\mathbf{v}_{\text{obj}} = \text{SetTransformer}(\{\mathbf{v}_{\text{part}}^{(i)}\}_{i=1}^N)\)$
Unsupervised learning: Constellation prior (object = set of parts)
Results: Better unsupervised segmentation, interpretable representations.
6.4 Capsule Graph Neural NetworksΒΆ
Graph structure: Capsules as nodes, routing as edges
Message passing: $\(\mathbf{h}_j^{(t+1)} = \text{Update}\left(\mathbf{h}_j^{(t)}, \sum_{i \in \mathcal{N}(j)} \text{Message}(\mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)})\right)\)$
Advantages:
Flexible graph topology
Longer-range dependencies
Adaptable to irregular data (point clouds, molecules)
7. Equivariance and InvarianceΒΆ
7.1 Equivariance in Capsule NetworksΒΆ
Definition: Transformation of input β equivalent transformation of representation $\(f(T(x)) = T'(f(x))\)$
CNNs: Translation equivariant (but pooling breaks it).
CapsNets: Aim for viewpoint equivariance
Rotation of input β rotation of capsule pose parameters
Not achieved by original CapsNet (requires coordinate addition)
7.2 Coordinate Addition (Hinton, 2021)ΒΆ
Idea: Explicitly represent coordinates in capsule vectors.
Implementation:
Capsule \(= [\text{pose}, \text{activation}]\)
Pose includes \((x, y, \theta)\) coordinates
Apply affine transformations to pose
Benefits:
True equivariance (not just approximately)
Systematic generalization to novel viewpoints
7.3 Steerable CapsulesΒΆ
Group equivariance: \(G\)-equivariance (e.g., \(G = SO(2)\) for rotations)
Steerable filters: Filters transform predictably under group actions $\(T_g \star f = (T_g f) \star (T_g \psi)\)$
Application to capsules:
Capsule features are \(G\)-equivariant
Routing preserves equivariance
Results: Better rotation robustness, data efficiency.
8. Applications and ResultsΒΆ
8.1 MNISTΒΆ
Original CapsNet (Sabour et al., 2017):
Test error: 0.25% (SOTA among non-augmented)
Reconstruction: High-quality digit reconstruction
Robustness: 79% accuracy on affNIST (CNNs ~66%)
EM Routing:
Test error: 0.21%
Fewer parameters (45% reduction)
8.2 CIFAR-10ΒΆ
Challenges: More complex, natural images
Results:
CapsNet (3-layer): 10.6% error (vs 6.5% ResNet)
Deep CapsNet: 8.5% error (with augmentation)
Observation: Gap with CNNs larger for complex datasets (architectural improvements needed).
8.3 smallNORBΒΆ
Dataset: 3D object recognition, jittered stereo images
EM Routing: 2.7% error (SOTA)
CNNs: ~3-5% error
Viewpoint generalization: Superior to CNNs
Insight: Capsules excel when viewpoint variation is key.
8.4 Adversarial RobustnessΒΆ
Hinton et al. (2018): CapsNets more robust to adversarial attacks than CNNs.
White-box attacks (FGSM, PGD):
CapsNet: 65-70% accuracy under attack
CNN: 40-50% accuracy
Reason: Distributed representations less susceptible to small perturbations.
8.5 Object SegmentationΒΆ
SCAE (Kosiorek et al., 2019):
Unsupervised object discovery
Pixel-level segmentation without labels
State-of-art on CLEVR, Shapestacks
Part-whole decomposition: Natural for segmentation.
9. Theoretical AnalysisΒΆ
9.1 Routing ConvergenceΒΆ
Theorem (informal): Dynamic routing converges to local optimum of clustering objective.
Proof sketch:
Define energy: \(E = -\sum_{ij} c_{ij} \hat{\mathbf{u}}_{j|i} \cdot \mathbf{v}_j\)
Each routing iteration decreases \(E\)
Converges to local minimum (depends on initialization)
Empirical: 3 iterations sufficient (diminishing improvements).
9.2 Expressive PowerΒΆ
Theorem: Capsule networks are universal approximators (with sufficient capsules and dimensions).
Comparison to CNNs:
CNNs: Spatial pooling loses information
CapsNets: Preserve pose information (more expressive)
Caveat: Practical expressiveness depends on routing quality.
9.3 Sample ComplexityΒΆ
Hypothesis: Capsules reduce sample complexity by exploiting part-whole structure.
Evidence:
Better few-shot learning (Meta-CapsNet)
Superior generalization on smallNORB (3D geometry)
Theoretical gap: Formal bounds lacking (open problem).
10. Computational ComplexityΒΆ
10.1 Time ComplexityΒΆ
Forward pass:
Conv layers: \(O(K^2 C_{\text{in}} C_{\text{out}} H W)\) (same as CNN)
Prediction: \(O(N_L N_{L+1} d_{\text{in}} d_{\text{out}})\) (matrix multiplications)
Routing: \(O(r N_L N_{L+1} d_{\text{out}})\) (iterative updates, \(r\) iterations)
Total: \(O(r N_L N_{L+1} d_{\text{out}})\) dominates (quadratic in number of capsules).
Comparison:
CNN: \(O(K^2 C H W)\) per layer
CapsNet: \(O(r N^2 d)\) for routing
Bottleneck: All-to-all routing (needs sparsification for large networks).
10.2 Space ComplexityΒΆ
Parameters:
Transformation matrices: \(N_L \times N_{L+1} \times d_{\text{in}} \times d_{\text{out}}\)
Activations (during training):
Capsule outputs: \(N \times d\) per layer
Routing coefficients: \(N_L \times N_{L+1}\) (not stored with gradient checkpointing)
Comparison: 2-5Γ more parameters than equivalent CNN (due to transformation matrices).
10.3 Optimization ChallengesΒΆ
Gradient flow: Routing iterations create deep computational graph
Backprop through routing: \(O(r)\) depth
Can cause vanishing/exploding gradients
Solutions:
Gradient clipping
Careful initialization
Shorter routing (1-3 iterations)
11. Limitations and Open ProblemsΒΆ
11.1 ScalabilityΒΆ
Problem: Quadratic cost in number of capsules limits scaling.
Current: Works well for <10K capsules, struggles beyond.
Solutions:
Sparse routing (route only to subset)
Hierarchical capsules (local routing)
Attention-based routing (single-pass)
11.2 Deep CapsNet ArchitecturesΒΆ
Challenge: Stacking many capsule layers unstable.
Issues:
Gradient flow through multiple routing layers
Increased computational cost
Diminishing returns (benefits plateau)
State-of-art: 3-5 capsule layers (vs 100+ for ResNets).
11.3 Lack of Theoretical UnderstandingΒΆ
Open questions:
Why does routing-by-agreement work?
Optimal number of routing iterations?
Formal expressiveness compared to CNNs?
Sample complexity bounds?
Provable robustness guarantees?
Current: Mostly empirical understanding.
11.4 Performance Gap on Complex DatasetsΒΆ
ImageNet: CapsNets lag behind SOTA CNNs/Transformers.
Reasons:
Architectural maturity (CNNs: 30+ years, CapsNets: <10)
Computational cost limits depth
Engineering optimizations (fewer for CapsNets)
Hope: Architectural innovations (e.g., Transformer-CapsNet hybrids) may close gap.
12. Recent Advances (2020-2024)ΒΆ
12.1 Efficient CapsNetsΒΆ
Depthwise Capsules: Reduce parameters via depthwise-separable idea
Factorize \(\mathbf{W}_{ij}\) into smaller matrices
3-5Γ parameter reduction, similar accuracy
Inverted Capsules: Inspired by MobileNets
Expand β route β project
Better parameter efficiency
12.2 Vision Transformers + CapsulesΒΆ
Capsule Transformers: Replace self-attention with capsule routing
Query/Key/Value β capsule predictions
Routing β attention weights
Benefits:
Part-whole relationships in Transformers
Better interpretability
Results: Competitive on ImageNet (75-78% top-1).
12.3 Equivariant CapsulesΒΆ
E(n)-equivariant capsules: Equivariant to Euclidean group
Rotation, translation, reflection equivariance
Applications: 3D point clouds, molecular graphs
Tensor Field Networks: Capsules with tensor features
Higher-order equivariance (beyond vectors)
12.4 Capsules for NLPΒΆ
Text Capsules: Route word embeddings to sentence capsules
Applications:
Sentiment analysis (better than RNNs on some datasets)
Intent detection
Multi-label text classification
Challenges: Less impact than in vision (Transformers dominate).
13. Implementation Best PracticesΒΆ
13.1 InitializationΒΆ
Routing logits: Initialize to zero (\(b_{ij} = 0\))
Transformation matrices: Xavier/He initialization $\(\mathbf{W}_{ij} \sim \mathcal{N}\left(0, \frac{2}{d_{\text{in}} + d_{\text{out}}}\right)\)$
Decoder: Standard FC initialization
13.2 Training HyperparametersΒΆ
Optimizer: Adam with \(\beta_1 = 0.9\), \(\beta_2 = 0.999\)
Learning rate:
Initial: \(10^{-3}\)
Decay: Exponential decay (0.96 every epoch) or cosine annealing
Batch size: 128 (larger if memory permits)
Routing iterations: 3 (default, tune if needed)
Reconstruction weight: \(\alpha = 0.0005\) (small to avoid dominating)
Margin loss:
\(m^+ = 0.9\)
\(m^- = 0.1\)
\(\lambda = 0.5\)
13.3 Debugging TipsΒΆ
Check capsule lengths: Should be in \((0, 1)\) after squashing.
Monitor routing coefficients: Should converge (stabilize) after 2-3 iterations.
Reconstruction quality: Decoder should produce recognizable images (qualitative check).
Gradient norms: Clip if exploding (max norm 5-10).
13.4 RegularizationΒΆ
Dropout: On capsule outputs (not within routing)
Rate: 0.2-0.3
Weight decay: L2 regularization
Coefficient: \(10^{-4}\)
Data augmentation: Random crops, flips (standard CV)
14. Comparison with Other ArchitecturesΒΆ
14.1 CapsNets vs CNNsΒΆ
Aspect |
CapsNets |
CNNs |
|---|---|---|
Neuron type |
Vector (pose + probability) |
Scalar (activation) |
Pooling |
No pooling (routing) |
Max/Avg pooling (info loss) |
Equivariance |
Viewpoint (intended) |
Translation |
Invariance |
Learned (via routing) |
Built-in (pooling) |
Parameters |
More (transformation matrices) |
Fewer (shared filters) |
Interpretability |
High (capsule = entity) |
Low (features opaque) |
Adversarial robustness |
Better |
Worse |
Scalability |
Limited (quadratic routing) |
Excellent |
Performance (ImageNet) |
Moderate (75-78%) |
SOTA (85-90%) |
Recommendation: CapsNets for interpretability, viewpoint robustness; CNNs for general-purpose SOTA.
14.2 CapsNets vs TransformersΒΆ
Aspect |
CapsNets |
Transformers |
|---|---|---|
Routing |
Agreement-based |
Attention-based |
Complexity |
\(O(N^2 d)\) |
\(O(N^2 d)\) |
Position encoding |
Implicit (capsule pose) |
Explicit (positional) |
Part-whole |
Native |
Requires design |
Interpretability |
High |
Moderate (attention maps) |
Scalability |
Limited |
Excellent (efficient variants) |
Performance |
Moderate |
SOTA |
Synergy: Capsule-Transformer hybrids combine strengths.
14.3 CapsNets vs Graph Neural NetworksΒΆ
Aspect |
CapsNets |
GNNs |
|---|---|---|
Structure |
Fixed layers |
Arbitrary graphs |
Routing |
Dynamic |
Message passing |
Applications |
Vision (primarily) |
Graphs, molecules, social |
Equivariance |
Viewpoint |
Graph isomorphism |
Overlap: Capsule GNNs merge both paradigms.
15. Summary and Future DirectionsΒΆ
15.1 Key TakeawaysΒΆ
Innovations:
Vector neurons: Encode entity properties (pose, probability)
Dynamic routing: Route by agreement (not max-pooling)
Margin loss: Multi-label classification
Reconstruction: Regularization + interpretability
Advantages:
Better viewpoint robustness (smallNORB)
Adversarial robustness (65-70% under attack)
Interpretability (capsule dimensions manipulable)
Part-whole relationships (natural for segmentation)
Challenges:
Scalability (quadratic routing)
Performance gap on complex datasets (ImageNet)
Deep architectures unstable
Theoretical understanding limited
15.2 Open ProblemsΒΆ
Efficient routing: Reduce \(O(N^2)\) complexity
Deep CapsNets: Stable training for 10+ layers
Theoretical foundations: Sample complexity, expressiveness bounds
Large-scale datasets: Close gap with CNNs/Transformers on ImageNet
Equivariance: Provable viewpoint equivariance
Hybrid architectures: Capsules + Transformers/GNNs
15.3 Future DirectionsΒΆ
Technical:
Sparse routing (attend to subset)
Hierarchical capsules (tree structure)
Learned routing algorithms (meta-learning)
Continuous capsules (neural ODEs)
Applications:
3D vision (point clouds, mesh processing)
Medical imaging (part-based anatomy)
Robotics (viewpoint-invariant perception)
Scientific discovery (molecule generation)
Theoretical:
Formalize routing as optimization
Prove expressiveness theorems
Sample complexity bounds
Connection to causality (part-whole = causal structure)
15.4 Recommended Use CasesΒΆ
Use CapsNets when:
Viewpoint variation is critical (3D objects)
Interpretability required (medical, safety-critical)
Part-whole structure important (segmentation)
Adversarial robustness needed
Small-to-medium datasets (few-shot learning)
Avoid when:
Need SOTA on ImageNet (use CNNs/ViTs)
Real-time inference critical (routing overhead)
Large-scale (10M+ images) datasets
Resources limited (CapsNets more expensive)
16. ConclusionΒΆ
Capsule Networks represent a conceptually elegant alternative to CNNs, addressing fundamental issues:
Spatial relationships: Preserved via vector representations
Viewpoint robustness: Improved generalization across views
Interpretability: Capsule dimensions encode meaningful properties
Current state (2024):
Proven on small-to-medium datasets (MNIST, smallNORB)
Architectural innovations ongoing (Transformers, equivariance)
Performance gap with CNNs narrowing but not closed
Verdict: Capsule Networks are a promising research direction with practical applications in specialized domains. Not yet ready to replace CNNs/Transformers for general-purpose vision, but offer unique advantages for problems where part-whole relationships and viewpoint equivariance are critical.
Hintonβs vision: βCNNs are doomedβ (because they discard spatial info). Whether capsules fulfill this prophecy remains an open question, but they have undeniably enriched our understanding of neural representations.
"""
Advanced Capsule Networks - Production Implementation
Comprehensive PyTorch implementations with dynamic routing and modern variants
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, List
# ===========================
# 1. Squashing Function
# ===========================
def squash(s: torch.Tensor, dim: int = -1, epsilon: float = 1e-8) -> torch.Tensor:
"""
Squashing non-linearity: v = ||s||Β² / (1 + ||s||Β²) * s / ||s||
Args:
s: Input tensor (B, ..., D)
dim: Dimension to compute norm over
epsilon: Small constant for numerical stability
Returns:
Squashed tensor with same shape as input
"""
squared_norm = torch.sum(s ** 2, dim=dim, keepdim=True)
scale = squared_norm / (1 + squared_norm)
unit = s / torch.sqrt(squared_norm + epsilon)
return scale * unit
# ===========================
# 2. Dynamic Routing
# ===========================
class DynamicRouting(nn.Module):
"""
Dynamic routing by agreement (Sabour et al., 2017)
Routes information from lower capsules to higher capsules
"""
def __init__(self,
in_capsules: int,
out_capsules: int,
in_dim: int,
out_dim: int,
num_iterations: int = 3):
super().__init__()
self.in_capsules = in_capsules
self.out_capsules = out_capsules
self.in_dim = in_dim
self.out_dim = out_dim
self.num_iterations = num_iterations
# Transformation matrices W_ij: (in_capsules, out_capsules, out_dim, in_dim)
self.W = nn.Parameter(torch.randn(in_capsules, out_capsules, out_dim, in_dim))
nn.init.kaiming_normal_(self.W)
def forward(self, u: torch.Tensor) -> torch.Tensor:
"""
Args:
u: (batch_size, in_capsules, in_dim) - lower-level capsule outputs
Returns:
v: (batch_size, out_capsules, out_dim) - higher-level capsule outputs
"""
batch_size = u.size(0)
# Compute predictions Γ»_{j|i} = W_ij @ u_i
# u: (B, in_caps, in_dim) -> (B, in_caps, 1, in_dim, 1)
# W: (in_caps, out_caps, out_dim, in_dim) -> (1, in_caps, out_caps, out_dim, in_dim)
# Γ»: (B, in_caps, out_caps, out_dim)
u_expanded = u[:, :, None, :, None] # (B, in_caps, 1, in_dim, 1)
W_expanded = self.W[None, :, :, :, :] # (1, in_caps, out_caps, out_dim, in_dim)
u_hat = torch.matmul(W_expanded, u_expanded).squeeze(-1) # (B, in_caps, out_caps, out_dim)
# Initialize routing logits b_ij to zero
b = torch.zeros(batch_size, self.in_capsules, self.out_capsules,
device=u.device) # (B, in_caps, out_caps)
# Iterative routing
for iteration in range(self.num_iterations):
# Softmax over output capsules: c_ij = softmax_j(b_ij)
c = F.softmax(b, dim=2) # (B, in_caps, out_caps)
# Weighted sum: s_j = Ξ£_i c_ij Γ»_{j|i}
c_expanded = c[:, :, :, None] # (B, in_caps, out_caps, 1)
s = torch.sum(c_expanded * u_hat, dim=1) # (B, out_caps, out_dim)
# Squashing: v_j = squash(s_j)
v = squash(s, dim=-1) # (B, out_caps, out_dim)
# Update routing logits: b_ij += Γ»_{j|i} Β· v_j (except last iteration)
if iteration < self.num_iterations - 1:
v_expanded = v[:, None, :, :] # (B, 1, out_caps, out_dim)
agreement = torch.sum(u_hat * v_expanded, dim=-1) # (B, in_caps, out_caps)
b = b + agreement
return v
# ===========================
# 3. Primary Capsules
# ===========================
class PrimaryCaps(nn.Module):
"""
Primary capsule layer: convolutional layer reshaped into capsules
"""
def __init__(self,
in_channels: int,
num_capsules: int,
capsule_dim: int,
kernel_size: int = 9,
stride: int = 2):
super().__init__()
self.num_capsules = num_capsules
self.capsule_dim = capsule_dim
# Convolutional layer: output has num_capsules * capsule_dim channels
self.conv = nn.Conv2d(in_channels,
num_capsules * capsule_dim,
kernel_size=kernel_size,
stride=stride,
padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, C, H, W) input feature map
Returns:
capsules: (B, num_capsules * H' * W', capsule_dim)
"""
# Convolution
out = self.conv(x) # (B, num_capsules * capsule_dim, H', W')
batch_size, _, height, width = out.shape
# Reshape to capsules: (B, num_capsules, capsule_dim, H', W')
out = out.view(batch_size, self.num_capsules, self.capsule_dim, height, width)
# Flatten spatial dimensions: (B, num_capsules * H' * W', capsule_dim)
out = out.permute(0, 1, 3, 4, 2).contiguous()
out = out.view(batch_size, -1, self.capsule_dim)
# Squash
return squash(out, dim=-1)
# ===========================
# 4. Digit Capsules (with Routing)
# ===========================
class DigitCaps(nn.Module):
"""
Digit capsule layer with dynamic routing
"""
def __init__(self,
in_capsules: int,
in_dim: int,
out_capsules: int,
out_dim: int,
num_iterations: int = 3):
super().__init__()
self.routing = DynamicRouting(in_capsules, out_capsules, in_dim, out_dim, num_iterations)
def forward(self, u: torch.Tensor) -> torch.Tensor:
"""
Args:
u: (B, in_capsules, in_dim)
Returns:
v: (B, out_capsules, out_dim)
"""
return self.routing(u)
# ===========================
# 5. Reconstruction Decoder
# ===========================
class Decoder(nn.Module):
"""
Decoder network for reconstruction regularization
"""
def __init__(self,
capsule_dim: int = 16,
num_classes: int = 10,
img_size: int = 28,
hidden_dims: List[int] = [512, 1024]):
super().__init__()
self.capsule_dim = capsule_dim
self.num_classes = num_classes
self.img_size = img_size
# Build FC layers
layers = []
in_dim = capsule_dim
for hidden_dim in hidden_dims:
layers.append(nn.Linear(in_dim, hidden_dim))
layers.append(nn.ReLU(inplace=True))
in_dim = hidden_dim
# Output layer
layers.append(nn.Linear(in_dim, img_size * img_size))
layers.append(nn.Sigmoid())
self.decoder = nn.Sequential(*layers)
def forward(self, v: torch.Tensor, labels: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Args:
v: (B, num_classes, capsule_dim) - digit capsule outputs
labels: (B,) - ground truth labels for masking (None during inference)
Returns:
reconstruction: (B, img_size, img_size)
"""
batch_size = v.size(0)
if labels is not None:
# Training: mask out all but the true class
mask = torch.zeros_like(v)
mask[torch.arange(batch_size), labels] = 1.0
v_masked = v * mask
else:
# Inference: use capsule with largest length
lengths = torch.sqrt(torch.sum(v ** 2, dim=-1)) # (B, num_classes)
max_idx = torch.argmax(lengths, dim=1) # (B,)
mask = torch.zeros_like(v)
mask[torch.arange(batch_size), max_idx] = 1.0
v_masked = v * mask
# Flatten: (B, num_classes * capsule_dim) -> (B, capsule_dim) after masking
v_flat = v_masked.view(batch_size, -1)
# Decode
reconstruction = self.decoder(v_flat)
reconstruction = reconstruction.view(batch_size, self.img_size, self.img_size)
return reconstruction
# ===========================
# 6. CapsNet (Complete Architecture)
# ===========================
class CapsNet(nn.Module):
"""
Complete Capsule Network architecture (Sabour et al., 2017)
"""
def __init__(self,
img_channels: int = 1,
img_size: int = 28,
num_classes: int = 10,
primary_caps: int = 32,
primary_dim: int = 8,
digit_dim: int = 16,
num_routing: int = 3):
super().__init__()
self.img_channels = img_channels
self.img_size = img_size
self.num_classes = num_classes
# Layer 1: Convolutional layer
self.conv1 = nn.Conv2d(img_channels, 256, kernel_size=9, stride=1, padding=0)
# Conv output size: (28 - 9 + 1) = 20
conv_output_size = img_size - 8
# Layer 2: Primary capsules
self.primary_caps = PrimaryCaps(256, primary_caps, primary_dim, kernel_size=9, stride=2)
# Primary caps output size: (20 - 9) / 2 + 1 = 6
primary_output_size = (conv_output_size - 8) // 2
num_primary_capsules = primary_caps * primary_output_size * primary_output_size
# Layer 3: Digit capsules
self.digit_caps = DigitCaps(num_primary_capsules, primary_dim,
num_classes, digit_dim, num_routing)
# Layer 4: Decoder
self.decoder = Decoder(digit_dim, num_classes, img_size)
def forward(self, x: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: (B, C, H, W) input images
labels: (B,) ground truth labels (for reconstruction masking)
Returns:
digit_caps: (B, num_classes, digit_dim)
reconstruction: (B, H, W)
"""
# Conv layer
x = F.relu(self.conv1(x)) # (B, 256, 20, 20)
# Primary capsules
primary = self.primary_caps(x) # (B, 1152, 8) for MNIST
# Digit capsules
digits = self.digit_caps(primary) # (B, 10, 16)
# Reconstruction
reconstruction = self.decoder(digits, labels) # (B, 28, 28)
return digits, reconstruction
# ===========================
# 7. Margin Loss
# ===========================
class MarginLoss(nn.Module):
"""
Margin loss for multi-label classification in CapsNets
"""
def __init__(self,
m_plus: float = 0.9,
m_minus: float = 0.1,
lambda_: float = 0.5):
super().__init__()
self.m_plus = m_plus
self.m_minus = m_minus
self.lambda_ = lambda_
def forward(self, v: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
Args:
v: (B, num_classes, capsule_dim) - digit capsule outputs
labels: (B,) - ground truth labels
Returns:
loss: scalar
"""
batch_size, num_classes, _ = v.shape
# Capsule lengths (probabilities)
lengths = torch.sqrt(torch.sum(v ** 2, dim=-1)) # (B, num_classes)
# One-hot encoding
targets = F.one_hot(labels, num_classes).float() # (B, num_classes)
# Margin loss per class
# L_k = T_k max(0, m+ - ||v_k||)Β² + Ξ»(1 - T_k) max(0, ||v_k|| - m-)Β²
positive_loss = targets * F.relu(self.m_plus - lengths) ** 2
negative_loss = self.lambda_ * (1 - targets) * F.relu(lengths - self.m_minus) ** 2
loss = torch.sum(positive_loss + negative_loss, dim=1)
return loss.mean()
# ===========================
# 8. CapsNet Trainer
# ===========================
class CapsNetTrainer:
"""Training utilities for CapsNet"""
def __init__(self,
model: CapsNet,
margin_loss: MarginLoss,
recon_weight: float = 0.0005,
device: str = 'cuda'):
self.model = model.to(device)
self.margin_loss = margin_loss
self.recon_weight = recon_weight
self.device = device
def train_step(self, x: torch.Tensor, labels: torch.Tensor,
optimizer: torch.optim.Optimizer) -> dict:
"""Single training step"""
self.model.train()
optimizer.zero_grad()
x = x.to(self.device)
labels = labels.to(self.device)
# Forward
digit_caps, reconstruction = self.model(x, labels)
# Margin loss
loss_margin = self.margin_loss(digit_caps, labels)
# Reconstruction loss
loss_recon = F.mse_loss(reconstruction, x.squeeze(1))
# Total loss
loss_total = loss_margin + self.recon_weight * loss_recon
# Backward
loss_total.backward()
optimizer.step()
# Accuracy
lengths = torch.sqrt(torch.sum(digit_caps ** 2, dim=-1))
pred = torch.argmax(lengths, dim=1)
acc = (pred == labels).float().mean()
return {
'loss_total': loss_total.item(),
'loss_margin': loss_margin.item(),
'loss_recon': loss_recon.item(),
'accuracy': acc.item()
}
@torch.no_grad()
def eval_step(self, x: torch.Tensor, labels: torch.Tensor) -> dict:
"""Evaluation step"""
self.model.eval()
x = x.to(self.device)
labels = labels.to(self.device)
# Forward
digit_caps, reconstruction = self.model(x, labels)
# Losses
loss_margin = self.margin_loss(digit_caps, labels)
loss_recon = F.mse_loss(reconstruction, x.squeeze(1))
loss_total = loss_margin + self.recon_weight * loss_recon
# Accuracy
lengths = torch.sqrt(torch.sum(digit_caps ** 2, dim=-1))
pred = torch.argmax(lengths, dim=1)
acc = (pred == labels).float().mean()
return {
'loss_total': loss_total.item(),
'loss_margin': loss_margin.item(),
'loss_recon': loss_recon.item(),
'accuracy': acc.item()
}
# ===========================
# 9. EM Routing (Advanced)
# ===========================
class EMRouting(nn.Module):
"""
EM routing with Gaussian mixture models (Hinton et al., 2018)
More principled than dynamic routing
"""
def __init__(self,
in_capsules: int,
out_capsules: int,
in_dim: int,
out_dim: int,
num_iterations: int = 3,
temperature: float = 1.0):
super().__init__()
self.in_capsules = in_capsules
self.out_capsules = out_capsules
self.in_dim = in_dim
self.out_dim = out_dim
self.num_iterations = num_iterations
self.temperature = temperature
# Transformation matrices
self.W = nn.Parameter(torch.randn(in_capsules, out_capsules, out_dim, in_dim))
nn.init.kaiming_normal_(self.W)
# Learnable parameters for Gaussian (optional)
self.beta_a = nn.Parameter(torch.zeros(1))
self.beta_u = nn.Parameter(torch.zeros(out_capsules, out_dim))
def forward(self, u: torch.Tensor, activation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
u: (B, in_capsules, in_dim) - pose vectors
activation: (B, in_capsules) - activation probabilities
Returns:
v: (B, out_capsules, out_dim) - output poses
a: (B, out_capsules) - output activations
"""
batch_size = u.size(0)
# Predictions
u_expanded = u[:, :, None, :, None]
W_expanded = self.W[None, :, :, :, :]
u_hat = torch.matmul(W_expanded, u_expanded).squeeze(-1) # (B, in_caps, out_caps, out_dim)
# Initialize routing probabilities uniformly
r = torch.ones(batch_size, self.in_capsules, self.out_capsules,
device=u.device) / self.out_capsules
# EM iterations
for iteration in range(self.num_iterations):
# M-step: Compute mean and variance
r_weighted = r * activation[:, :, None] # (B, in_caps, out_caps)
r_sum = torch.sum(r_weighted, dim=1, keepdim=True) + 1e-8 # (B, 1, out_caps)
# Mean: ΞΌ_j = Ξ£_i r_ij Γ»_{j|i} / Ξ£_i r_ij
mu = torch.sum(r_weighted[:, :, :, None] * u_hat, dim=1) / r_sum # (B, out_caps, out_dim)
# Variance: ΟΒ²_j = Ξ£_i r_ij (Γ»_{j|i} - ΞΌ_j)Β² / Ξ£_i r_ij
diff = u_hat - mu[:, None, :, :] # (B, in_caps, out_caps, out_dim)
variance = torch.sum(r_weighted[:, :, :, None] * diff ** 2, dim=1) / r_sum
variance = variance + 0.01 # Minimum variance for stability
# Activation: Based on cost (simplified)
cost = torch.sum(variance, dim=-1) # (B, out_caps)
a = torch.sigmoid(-cost + self.beta_a) # (B, out_caps)
# E-step: Update routing probabilities (except last iteration)
if iteration < self.num_iterations - 1:
# Log probability of Γ»_{j|i} under Gaussian(ΞΌ_j, ΟΒ²_j)
log_p = -0.5 * torch.sum((diff ** 2) / variance[:, None, :, :], dim=-1) # (B, in_caps, out_caps)
log_p = log_p - 0.5 * torch.sum(torch.log(variance[:, None, :, :] + 1e-8), dim=-1)
# Weighted by activation
log_p = log_p + torch.log(a[:, None, :] + 1e-8)
# Softmax
r = F.softmax(log_p / self.temperature, dim=2)
return mu, a
# ===========================
# 10. Demo Functions
# ===========================
def demo_squash():
"""Demonstrate squashing function"""
print("=" * 50)
print("Demo: Squashing Function")
print("=" * 50)
# Test squashing behavior
s = torch.tensor([[0.1, 0.2], [1.0, 2.0], [10.0, 20.0]])
v = squash(s, dim=-1)
print(f"Input vectors:")
print(s)
print(f"\nNorms: {torch.norm(s, dim=-1)}")
print(f"\nSquashed vectors:")
print(v)
print(f"Squashed norms: {torch.norm(v, dim=-1)}")
print(f"All norms in (0,1): {torch.all((torch.norm(v, dim=-1) > 0) & (torch.norm(v, dim=-1) < 1))}")
print()
def demo_dynamic_routing():
"""Demonstrate dynamic routing"""
print("=" * 50)
print("Demo: Dynamic Routing")
print("=" * 50)
routing = DynamicRouting(in_capsules=1152, out_capsules=10,
in_dim=8, out_dim=16, num_iterations=3)
u = torch.randn(4, 1152, 8) # Batch of 4, 1152 primary capsules
v = routing(u)
print(f"Input: {u.shape} (batch_size, in_capsules, in_dim)")
print(f"Output: {v.shape} (batch_size, out_capsules, out_dim)")
print(f"Capsule lengths: {torch.norm(v, dim=-1)[0]}")
print(f"Transformation matrices: {routing.W.shape}")
num_params = routing.W.numel()
print(f"Parameters: {num_params:,} ({num_params * 4 / 1024**2:.2f} MB)")
print()
def demo_capsnet_forward():
"""Demonstrate full CapsNet forward pass"""
print("=" * 50)
print("Demo: CapsNet Forward Pass")
print("=" * 50)
model = CapsNet(img_channels=1, img_size=28, num_classes=10)
x = torch.randn(4, 1, 28, 28)
labels = torch.randint(0, 10, (4,))
digit_caps, reconstruction = model(x, labels)
print(f"Input: {x.shape}")
print(f"Digit capsules: {digit_caps.shape}")
print(f"Reconstruction: {reconstruction.shape}")
# Capsule lengths (predictions)
lengths = torch.sqrt(torch.sum(digit_caps ** 2, dim=-1))
predictions = torch.argmax(lengths, dim=1)
print(f"\nCapsule lengths (probabilities):")
print(lengths[0])
print(f"Predicted classes: {predictions}")
print(f"True labels: {labels}")
# Parameter count
num_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {num_params:,} (~{num_params * 4 / 1024**2:.1f} MB)")
print()
def demo_margin_loss():
"""Demonstrate margin loss"""
print("=" * 50)
print("Demo: Margin Loss")
print("=" * 50)
loss_fn = MarginLoss(m_plus=0.9, m_minus=0.1, lambda_=0.5)
# Simulate capsule outputs
digit_caps = torch.randn(4, 10, 16)
labels = torch.tensor([3, 7, 1, 0])
loss = loss_fn(digit_caps, labels)
print(f"Digit capsules: {digit_caps.shape}")
print(f"Labels: {labels}")
print(f"Margin loss: {loss.item():.4f}")
# Show capsule lengths
lengths = torch.sqrt(torch.sum(digit_caps ** 2, dim=-1))
print(f"\nCapsule lengths:")
print(lengths)
print(f"Target class lengths: {lengths[torch.arange(4), labels]}")
print()
def demo_training_step():
"""Demonstrate training step"""
print("=" * 50)
print("Demo: Training Step")
print("=" * 50)
device = 'cpu'
model = CapsNet(img_channels=1, img_size=28, num_classes=10)
margin_loss = MarginLoss()
trainer = CapsNetTrainer(model, margin_loss, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Dummy data
x = torch.randn(8, 1, 28, 28)
labels = torch.randint(0, 10, (8,))
print("Training for 3 steps...")
for step in range(3):
metrics = trainer.train_step(x, labels, optimizer)
print(f"Step {step+1}: loss={metrics['loss_total']:.4f}, "
f"margin={metrics['loss_margin']:.4f}, "
f"recon={metrics['loss_recon']:.4f}, "
f"acc={metrics['accuracy']:.2%}")
print()
def demo_em_routing():
"""Demonstrate EM routing"""
print("=" * 50)
print("Demo: EM Routing")
print("=" * 50)
routing = EMRouting(in_capsules=1152, out_capsules=10,
in_dim=8, out_dim=16, num_iterations=3)
u = torch.randn(4, 1152, 8)
activation = torch.rand(4, 1152)
mu, a = routing(u, activation)
print(f"Input poses: {u.shape}")
print(f"Input activations: {activation.shape}")
print(f"Output poses (means): {mu.shape}")
print(f"Output activations: {a.shape}")
print(f"\nOutput activation values (sample): {a[0]}")
print()
def print_performance_comparison():
"""Comprehensive performance comparison and decision guide"""
print("=" * 80)
print("PERFORMANCE COMPARISON: Capsule Networks")
print("=" * 80)
# 1. MNIST Results
print("\n1. MNIST Results (Test Error %)")
print("-" * 80)
data = [
("Model", "Error %", "Parameters", "Notes"),
("-" * 30, "-" * 10, "-" * 12, "-" * 30),
("Baseline CNN", "0.39", "~35K", "Simple CNN baseline"),
("CapsNet (3 routing)", "0.25", "8.2M", "Original (Sabour 2017)"),
("CapsNet (1 routing)", "0.31", "8.2M", "Faster, slight quality loss"),
("EM Routing CapsNet", "0.21", "4.5M", "45% fewer params (Hinton 2018)"),
("", "", "", ""),
("SOTA (non-augmented)", "0.17", "N/A", "Ensemble methods"),
]
for row in data:
print(f"{row[0]:<30} {row[1]:<10} {row[2]:<12} {row[3]:<30}")
# 2. affNIST (Viewpoint Robustness)
print("\n2. affNIST - Viewpoint Robustness (Accuracy %)")
print("-" * 80)
data = [
("Model", "Accuracy", "Notes"),
("-" * 35, "-" * 10, "-" * 35),
("CNN (baseline)", "66%", "Poor generalization to affine transforms"),
("CNN (expanded training)", "79%", "With affine augmentation"),
("CapsNet (no augmentation)", "79%", "Naturally robust to viewpoint changes"),
("", "", ""),
("Observation:", "", "CapsNets achieve CNN+augmentation performance"),
("", "", "without needing explicit augmentation"),
]
for row in data:
print(f"{row[0]:<35} {row[1]:<10} {row[2]:<35}")
# 3. smallNORB (3D Objects)
print("\n3. smallNORB - 3D Object Recognition (Test Error %)")
print("-" * 80)
data = [
("Model", "Error %", "Notes"),
("-" * 40, "-" * 10, "-" * 35),
("CNN baseline", "5.2", ""),
("Convolutional Deep Belief Net", "4.5", ""),
("CapsNet (dynamic routing)", "3.1", "Better viewpoint invariance"),
("EM Routing CapsNet", "2.7", "SOTA (Hinton 2018)"),
("", "", ""),
("Key:", "", "3D geometry β capsules excel"),
]
for row in data:
print(f"{row[0]:<40} {row[1]:<10} {row[2]:<35}")
# 4. CIFAR-10
print("\n4. CIFAR-10 - Natural Images (Test Error %)")
print("-" * 80)
data = [
("Model", "Error %", "Parameters", "Notes"),
("-" * 30, "-" * 10, "-" * 12, "-" * 30),
("CapsNet (3-layer)", "10.6", "11.8M", "Original architecture"),
("Deep CapsNet", "8.5", "~20M", "Deeper architecture + augmentation"),
("", "", "", ""),
("ResNet-56", "6.5", "850K", "Standard CNN (fewer params)"),
("ResNet-110", "5.5", "1.7M", "Deeper CNN"),
("Wide ResNet-28-10", "3.9", "36M", "SOTA CNN"),
("", "", "", ""),
("Gap:", "", "", "CapsNets lag on complex natural images"),
]
for row in data:
print(f"{row[0]:<30} {row[1]:<10} {row[2]:<12} {row[3]:<30}")
# 5. Adversarial Robustness
print("\n5. Adversarial Robustness (Accuracy under attack %)")
print("-" * 80)
data = [
("Model", "Clean", "FGSM", "PGD", "Notes"),
("-" * 25, "-" * 8, "-" * 8, "-" * 8, "-" * 25),
("CNN (baseline)", "99%", "45%", "40%", "Vulnerable to adversarial"),
("CapsNet", "99.2%", "68%", "65%", "More robust (distributed repr)"),
("Adversarial Training CNN", "98%", "85%", "80%", "Needs explicit defense"),
("", "", "", "", ""),
("Insight:", "", "", "", "CapsNets naturally more robust"),
]
for row in data:
print(f"{row[0]:<25} {row[1]:<8} {row[2]:<8} {row[3]:<8} {row[4]:<25}")
# 6. Computational Complexity
print("\n6. Computational Complexity")
print("-" * 80)
data = [
("Operation", "Complexity", "Notes"),
("-" * 35, "-" * 25, "-" * 30),
("Conv layers", "O(KΒ²Β·CΒ·HΒ·W)", "Same as standard CNN"),
("Primary Caps", "O(KΒ²Β·CΒ·HΒ·W)", "Convolution + reshape"),
("Dynamic Routing", "O(rΒ·N_inΒ·N_outΒ·d)", "r=3 iterations typical"),
("Transformation Matrices", "O(N_inΒ·N_outΒ·dΒ²)", "Largest parameter cost"),
("", "", ""),
("Bottleneck:", "Routing (quadratic)", "Limits scalability"),
("Total (MNIST)", "~10Γ slower than CNN", "Per forward pass"),
]
for row in data:
print(f"{row[0]:<35} {row[1]:<25} {row[2]:<30}")
# 7. Training Hyperparameters
print("\n7. Recommended Training Hyperparameters")
print("-" * 80)
data = [
("Parameter", "MNIST", "CIFAR-10", "Notes"),
("-" * 25, "-" * 15, "-" * 15, "-" * 30),
("Batch size", "128", "128-256", "Larger if memory allows"),
("Learning rate", "1e-3", "1e-3", "Adam optimizer"),
("LR decay", "Exponential 0.96", "Cosine", "Every epoch or cycle"),
("", "", "", ""),
("Routing iterations", "3", "3", "Diminishing returns beyond"),
("Reconstruction weight", "0.0005", "0.0005", "Small to avoid dominance"),
("", "", "", ""),
("m+ (upper margin)", "0.9", "0.9", "Target for present classes"),
("m- (lower margin)", "0.1", "0.1", "Target for absent classes"),
("Ξ» (absent weight)", "0.5", "0.5", "Down-weight absent classes"),
("", "", "", ""),
("Primary capsules", "32 Γ 8D", "32 Γ 8D", "Standard config"),
("Digit capsules", "10 Γ 16D", "10 Γ 16D", "One per class"),
("", "", "", ""),
("Training epochs", "50-100", "200-300", "Until convergence"),
("Gradient clipping", "5.0", "5.0", "Max norm for stability"),
]
for row in data:
print(f"{row[0]:<25} {row[1]:<15} {row[2]:<15} {row[3]:<30}")
# 8. Architecture Variants
print("\n8. Capsule Network Variants")
print("-" * 80)
data = [
("Variant", "Key Innovation", "Pros", "Cons"),
("-" * 20, "-" * 30, "-" * 25, "-" * 25),
("Dynamic Routing", "Routing by agreement", "Simple, effective", "Iterative (slow)"),
("EM Routing", "Gaussian mixture model", "45% fewer params", "More complex"),
("Self-Attention", "Attention-based routing", "Parallelizable", "Less interpretable"),
("SCAE", "Unsupervised learning", "Object discovery", "Training unstable"),
("Capsule GNN", "Graph message passing", "Irregular data", "Architecture complex"),
]
for row in data:
print(f"{row[0]:<20} {row[1]:<30} {row[2]:<25} {row[3]:<25}")
# 9. Use Case Decision Guide
print("\n9. DECISION GUIDE: When to Use Capsule Networks")
print("=" * 80)
print("\nβ USE Capsule Networks When:")
advantages = [
"β’ Viewpoint variation is critical (3D objects, poses)",
"β’ Need interpretable part-whole representations",
"β’ Adversarial robustness important (intrinsically more robust)",
"β’ Dataset has limited augmentation (affine transforms)",
"β’ Object segmentation/discovery required (unsupervised)",
"β’ Small-to-medium datasets (MNIST, smallNORB)",
"β’ Multi-label classification (margin loss advantage)",
]
for adv in advantages:
print(adv)
print("\nβ AVOID Capsule Networks When:")
limitations = [
"β’ Need SOTA on ImageNet-scale datasets (CNNs/Transformers better)",
"β’ Real-time inference critical (10Γ slower than CNNs)",
"β’ Very deep architectures required (stability issues)",
"β’ Large number of capsules needed (>10K, quadratic cost)",
"β’ Limited computational budget (more expensive training)",
]
for lim in limitations:
print(lim)
print("\nβ RECOMMENDED ALTERNATIVES:")
alternatives = [
"β’ General vision β ResNet, EfficientNet, Vision Transformers",
"β’ Fast inference β MobileNets, SqueezeNet",
"β’ Adversarial robustness β Adversarial training, certified defenses",
"β’ Interpretability β Attention mechanisms, GradCAM",
]
for alt in alternatives:
print(alt)
# 10. Comparison with Other Architectures
print("\n10. Comparison: CapsNets vs CNNs vs Transformers")
print("=" * 80)
data = [
("Aspect", "CapsNets", "CNNs", "Transformers"),
("-" * 25, "-" * 20, "-" * 20, "-" * 20),
("Neuron type", "Vector (pose)", "Scalar (activation)", "Embedding vectors"),
("Spatial relations", "Preserved (routing)", "Lost (pooling)", "Positional encoding"),
("Equivariance", "Viewpoint (goal)", "Translation", "None (learned)"),
("Interpretability", "High (part-whole)", "Low", "Moderate (attention)"),
("Scalability", "Limited (NΒ²)", "Excellent", "Good (NΒ² in seq)"),
("Training stability", "Moderate", "Good", "Good (with tricks)"),
("Adversarial robust", "Better", "Worse", "Moderate"),
("Parameter count", "High", "Low-Medium", "Very High"),
("ImageNet accuracy", "75-78%", "85-90%", "88-92%"),
("", "", "", ""),
("Best for:", "3D, viewpoint", "General vision", "Large-scale, NLP"),
]
for row in data:
print(f"{row[0]:<25} {row[1]:<20} {row[2]:<20} {row[3]:<20}")
# 11. Troubleshooting
print("\n11. TROUBLESHOOTING COMMON ISSUES")
print("=" * 80)
data = [
("Problem", "Possible Cause", "Solution"),
("-" * 30, "-" * 35, "-" * 40),
("Poor convergence", "Learning rate too high", "Reduce LR to 1e-4 or 1e-5"),
("", "Routing iterations wrong", "Try 1-3 iterations"),
("", "", ""),
("Exploding gradients", "Routing depth", "Gradient clipping (max_norm=5)"),
("", "Bad initialization", "Kaiming/Xavier init"),
("", "", ""),
("All capsules same", "Reconstruction weight too high", "Reduce Ξ± to 0.0005 or lower"),
("", "Margin loss not working", "Check m+, m-, Ξ» values"),
("", "", ""),
("Slow training", "Too many routing iters", "Use 1-2 iterations (faster)"),
("", "Large model", "Reduce primary caps or dims"),
("", "", ""),
("Blurry reconstruction", "Decoder too small", "Add hidden layers"),
("", "Weight too small", "Increase Ξ± to 0.001"),
("", "", ""),
("Low accuracy", "Insufficient training", "Train longer (50-100 epochs)"),
("", "Poor architecture", "Try EM routing or variants"),
]
for row in data:
print(f"{row[0]:<30} {row[1]:<35} {row[2]:<40}")
print("\n" + "=" * 80)
print("Summary: Capsule Networks offer interpretable part-whole representations")
print("and viewpoint robustness, but lag on complex datasets. Best for 3D objects,")
print("adversarial robustness, and when interpretability is critical.")
print("=" * 80)
print()
# ===========================
# Run All Demos
# ===========================
if __name__ == "__main__":
print("\n" + "=" * 80)
print("CAPSULE NETWORKS - COMPREHENSIVE IMPLEMENTATION")
print("=" * 80 + "\n")
demo_squash()
demo_dynamic_routing()
demo_capsnet_forward()
demo_margin_loss()
demo_training_step()
demo_em_routing()
print_performance_comparison()
print("\n" + "=" * 80)
print("All demos completed successfully!")
print("=" * 80)