import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Capsule Networks Concept

Capsule:ΒΆ

Group of neurons representing:

  • Length: probability entity exists

  • Orientation: instantiation parameters

Squashing Function:ΒΆ

\[v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2} \frac{s_j}{\|s_j\|}\]

Keeps direction, squashes magnitude to [0,1].

📚 Reference Materials:

def squash(s, dim=-1):
    """Capsule squashing non-linearity.

    Shrinks the magnitude of each vector into [0, 1) while keeping its
    direction, so vector length can be read as an existence probability.
    """
    sq = s.pow(2).sum(dim=dim, keepdim=True)
    # 1e-8 guards the division for all-zero vectors.
    unit = s / torch.sqrt(sq + 1e-8)
    return (sq / (1 + sq)) * unit

# Sanity check: squashed vectors must have norms strictly inside [0, 1).
sample = torch.randn(10, 16)
squashed = squash(sample)
norms = torch.norm(squashed, dim=1)
print(f"Squashed norms - min: {norms.min():.3f}, max: {norms.max():.3f}")

2. Dynamic RoutingΒΆ

Routing by Agreement:ΒΆ

\[s_j = \sum_i c_{ij} \hat{u}_{j|i}\]

where \(\hat{u}_{j|i} = W_{ij} u_i\) is prediction vector.

Update Rule:ΒΆ

\[c_{ij} = \frac{\exp(b_{ij})}{\sum_k \exp(b_{ik})}\]
\[b_{ij} \leftarrow b_{ij} + \hat{u}_{j|i} \cdot v_j\]
class DynamicRouting(nn.Module):
    """Dynamic routing between capsule layers (Sabour et al., 2017).

    Each input capsule predicts every output capsule via a learned linear
    map; coupling coefficients are refined iteratively by agreement
    between predictions and the current output.
    """

    def __init__(self, in_caps, out_caps, in_dim, out_dim, num_iterations=3):
        """
        Args:
            in_caps: number of input capsules.
            out_caps: number of output capsules.
            in_dim: dimension of each input capsule vector.
            out_dim: dimension of each output capsule vector.
            num_iterations: routing iterations (3 is the paper default).
        """
        super().__init__()
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.num_iterations = num_iterations

        # One (out_dim x in_dim) transformation matrix per capsule pair.
        self.W = nn.Parameter(torch.randn(in_caps, out_caps, out_dim, in_dim))

    @staticmethod
    def _squash(s, dim=-1):
        # Squashing non-linearity: keep direction, map magnitude into [0, 1).
        squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * s / torch.sqrt(squared_norm + 1e-8)

    def forward(self, u):
        """
        Args:
            u: (batch, in_caps, in_dim)
        Returns:
            v: (batch, out_caps, out_dim)
        """
        batch_size = u.size(0)

        # Prediction vectors u_hat[j|i] = W_ij @ u_i for every capsule pair.
        u_expand = u.unsqueeze(2).unsqueeze(4)  # (batch, in_caps, 1, in_dim, 1)
        W_expand = self.W.unsqueeze(0)  # (1, in_caps, out_caps, out_dim, in_dim)
        u_hat = torch.matmul(W_expand, u_expand).squeeze(-1)  # (batch, in_caps, out_caps, out_dim)

        # Fix: do not backprop through intermediate routing iterations
        # (matches the reference implementation and the DynamicRoutingLayer
        # class elsewhere in this file).
        u_hat_detached = u_hat.detach()

        # Routing logits start at zero => uniform coupling.
        # Fix: allocate directly on the right device/dtype instead of
        # zeros().to(), which allocates on CPU and then copies.
        b = torch.zeros(batch_size, self.in_caps, self.out_caps, 1,
                        device=u.device, dtype=u.dtype)

        for iteration in range(self.num_iterations):
            # Coupling coefficients: softmax over output capsules.
            c = F.softmax(b, dim=2)  # (batch, in_caps, out_caps, 1)

            last = iteration == self.num_iterations - 1
            # Only the final iteration carries gradient through u_hat.
            s = (c * (u_hat if last else u_hat_detached)).sum(dim=1)  # (batch, out_caps, out_dim)
            v = self._squash(s)  # (batch, out_caps, out_dim)

            if not last:
                # Agreement (dot product) raises the logits of input capsules
                # whose predictions align with the current output.
                agreement = (u_hat_detached * v.unsqueeze(1)).sum(dim=-1, keepdim=True)
                b = b + agreement

        return v

print("DynamicRouting defined")

Primary CapsulesΒΆ

Primary capsules form the first capsule layer, converting traditional convolutional features into capsule format. Instead of scalar activations, each capsule outputs a vector whose length represents the probability that a particular entity exists and whose orientation encodes the entity’s properties (pose, deformation, texture). The squashing function \(v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2} \frac{s_j}{\|s_j\|}\) ensures the vector length stays between 0 and 1 (interpretable as a probability) while preserving direction. Primary capsules detect basic visual patterns (edges, textures, simple shapes) and pass their pose information up to higher-level capsules via dynamic routing.

class PrimaryCaps(nn.Module):
    """Primary capsule layer: a convolution whose output is grouped into
    capsule vectors, one per capsule type per spatial location.
    """

    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride):
        """
        Args:
            in_channels: channels of the incoming feature map.
            out_channels: conv output channels; must be a multiple of dim_caps.
            dim_caps: dimension of each capsule vector.
            kernel_size, stride: passed through to the convolution.
        """
        super().__init__()
        self.dim_caps = dim_caps
        # Number of capsule *types* (channel groups of size dim_caps).
        self.num_caps = out_channels // dim_caps

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """
        Args:
            x: (batch, in_channels, H, W)
        Returns:
            capsules: (batch, num_caps * H' * W', dim_caps), squashed.
        """
        out = self.conv(x)  # (batch, out_channels, H', W')
        batch_size = out.size(0)

        # Group channels into capsule types, then move dim_caps last so each
        # capsule vector is the dim_caps channels of ONE spatial location.
        # Fix: the previous double view() reinterpreted memory laid out as
        # (caps, dim, space), so "capsule vectors" were consecutive spatial
        # positions of a single dim-channel instead of per-location groups.
        out = out.view(batch_size, self.num_caps, self.dim_caps, -1)
        out = out.permute(0, 1, 3, 2).contiguous()  # (batch, num_caps, H'*W', dim_caps)
        out = out.view(batch_size, -1, self.dim_caps)

        # Squash each capsule vector: length in [0, 1), direction preserved.
        squared_norm = (out ** 2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * out / torch.sqrt(squared_norm + 1e-8)

print("PrimaryCaps defined")

CapsNet ArchitectureΒΆ

The full CapsNet architecture consists of: (1) a standard convolutional layer for initial feature extraction, (2) primary capsules that convert features to capsule vectors, and (3) digit capsules (or class capsules) connected via dynamic routing. Dynamic routing iteratively refines the coupling coefficients between lower-level and higher-level capsules: capsules that agree on the predicted pose of a higher-level entity strengthen their connection, while disagreeing capsules weaken theirs. The length of each digit capsule vector gives the predicted probability for that class, and a reconstruction network can decode the capsule vector back to an image, serving as a regularizer.

class CapsNet(nn.Module):
    """Capsule network for 28x28 grayscale digits (MNIST).

    Pipeline: conv features -> primary capsules -> digit capsules via
    dynamic routing. Capsule lengths act as class scores; a small decoder
    reconstructs the input from the selected class capsule.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Feature extractor: a single 9x9 convolution.
        self.conv1 = nn.Conv2d(1, 256, 9, stride=1)

        # 32 capsule types of dimension 8 (256 = 32 * 8 channels).
        self.primary_caps = PrimaryCaps(
            in_channels=256,
            out_channels=256,
            dim_caps=8,
            kernel_size=9,
            stride=2,
        )

        # One 16-D capsule per class, routed from all primary capsules
        # (32 types over a 6x6 spatial grid).
        self.digit_caps = DynamicRouting(
            in_caps=32 * 6 * 6,
            out_caps=num_classes,
            in_dim=8,
            out_dim=16,
            num_iterations=3,
        )

        # Reconstruction decoder, used as a regularizer during training.
        self.decoder = nn.Sequential(
            nn.Linear(16 * num_classes, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid(),
        )

    def forward(self, x, y=None):
        """Return (class-capsule lengths, flattened 784-pixel reconstruction).

        When labels `y` are given the reconstruction is decoded from the
        TRUE class capsule; otherwise from the predicted one.
        """
        feats = F.relu(self.conv1(x))
        primary = self.primary_caps(feats)
        digit_caps = self.digit_caps(primary)  # (batch, num_classes, 16)

        # Vector length = class score.
        lengths = torch.sqrt((digit_caps ** 2).sum(dim=-1))

        # Choose which capsule feeds the decoder.
        index = lengths.argmax(dim=1) if y is None else y

        # Zero out every capsule except the selected one, then decode.
        mask = torch.zeros_like(digit_caps)
        mask[torch.arange(digit_caps.size(0)), index] = 1
        masked = (digit_caps * mask).view(digit_caps.size(0), -1)

        return lengths, self.decoder(masked)

print("CapsNet defined")

5. Margin LossΒΆ

\[L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^-)^2\]

where \(T_k=1\) if class \(k\) present, \(m^+=0.9\), \(m^-=0.1\).

def margin_loss(lengths, labels, m_plus=0.9, m_minus=0.1, lambda_=0.5):
    """Margin loss for capsule networks (Sabour et al., 2017).

    L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2

    Args:
        lengths: (batch, num_classes) capsule lengths in [0, 1).
        labels: (batch,) integer class labels.
        m_plus: target lower bound on the present class's length.
        m_minus: target upper bound on absent classes' lengths.
        lambda_: down-weight for the absent-class term.
    Returns:
        Scalar loss, summed over classes and averaged over the batch.
    """
    # Fix: infer the number of classes from the input instead of
    # hard-coding 10, so the loss works for any classifier head.
    T = F.one_hot(labels, num_classes=lengths.size(1)).float()

    # Present classes: push length above m_plus.
    loss_present = T * F.relu(m_plus - lengths) ** 2

    # Absent classes: push length below m_minus (down-weighted so early
    # training is not dominated by shrinking all capsules).
    loss_absent = (1 - T) * F.relu(lengths - m_minus) ** 2

    loss = loss_present + lambda_ * loss_absent
    return loss.sum(dim=1).mean()

print("Margin loss defined")

Train CapsNetΒΆ

CapsNet training uses a margin loss per class: \(L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^-)^2\), where \(T_k = 1\) if class \(k\) is present and \(m^+, m^-\) are the positive and negative margins. An optional reconstruction loss (MSE between the original image and the reconstruction from the correct class capsule) encourages the capsule vectors to encode meaningful pose information. The total loss balances classification accuracy with reconstruction quality, with a small weight on the reconstruction term to prevent it from dominating.

# Data: plain MNIST tensors in [0, 1]; no normalization, since the decoder
# reconstructs raw pixel intensities.
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000)

# Model
model = CapsNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_epoch(model, loader):
    """Run one training epoch; return the mean total loss per batch.

    NOTE(review): relies on the module-level `optimizer`, `device`,
    `margin_loss` globals defined above.
    """
    model.train()
    total_loss = 0
    
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        
        # Forward: labels are passed so the reconstruction is decoded
        # from the TRUE class capsule during training.
        lengths, reconstruction = model(x, y)
        
        # Classification term.
        loss_margin = margin_loss(lengths, y)
        
        # Reconstruction term (decoder outputs flattened 784-pixel images).
        loss_recon = F.mse_loss(reconstruction, x.view(x.size(0), -1))
        
        # Total loss: reconstruction scaled down so it regularizes without
        # drowning out the margin loss.
        loss = loss_margin + 0.0005 * loss_recon
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(loader)

# Train for a few epochs, recording the per-epoch mean loss.
losses = []
for epoch in range(5):
    loss = train_epoch(model, train_loader)
    losses.append(loss)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

Evaluate and VisualizeΒΆ

Beyond classification accuracy, CapsNet evaluation includes dimension perturbation analysis: varying each dimension of the class capsule vector independently and observing the effect on the reconstructed image. Each dimension should control a distinct visual attribute (stroke width, skew, translation, etc.), demonstrating that capsules learn disentangled representations of pose and appearance. This interpretability is a key advantage over standard CNNs, where individual neuron activations rarely correspond to single, interpretable factors of variation.

# Evaluation: accuracy from capsule lengths. No labels are passed, so the
# model's reconstruction path uses the predicted class (unused here).
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        lengths, _ = model(x)
        # Predicted class = capsule with the largest length.
        pred = lengths.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

print(f"\nTest Accuracy: {100 * correct / total:.2f}%")

# Visualize reconstructions for 8 test digits, decoded from the TRUE
# class capsule (labels are passed to the model).
x_test, y_test = next(iter(test_loader))
x_test = x_test[:8].to(device)
y_test = y_test[:8].to(device)

with torch.no_grad():
    _, recon = model(x_test, y_test)

fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i in range(8):
    # Original image (squeeze drops the channel dimension).
    axes[0, i].imshow(x_test[i].cpu().squeeze(), cmap='gray')
    axes[0, i].axis('off')
    
    # Reconstruction: decoder outputs flat 784-vectors; reshape to 28x28.
    axes[1, i].imshow(recon[i].cpu().view(28, 28), cmap='gray')
    axes[1, i].axis('off')

# NOTE(review): axis('off') above hides these ylabels; kept as-is since
# this is a documentation-only pass.
axes[0, 0].set_ylabel('Original', fontsize=11)
axes[1, 0].set_ylabel('Reconstructed', fontsize=11)
plt.suptitle('CapsNet Reconstructions', fontsize=12)
plt.tight_layout()
plt.show()

SummaryΒΆ

Capsule Networks:ΒΆ

Key Ideas:

  1. Capsules = groups of neurons

  2. Length = existence probability

  3. Orientation = instantiation parameters

  4. Dynamic routing by agreement

Advantages:

  • Viewpoint equivariance

  • Part-whole relationships

  • Better generalization to novel viewpoints

  • Interpretable representations

Challenges:

  • Computationally expensive

  • Harder to scale

  • Limited to small images

Applications:ΒΆ

  • Overlapping digits

  • 3D object recognition

  • Medical imaging

  • Video understanding

Advanced Capsule Networks TheoryΒΆ

1. Mathematical FoundationsΒΆ

Capsule RepresentationΒΆ

Definition: A capsule is a group of neurons whose activity vector represents instantiation parameters of an entity.

Capsule Output: v_j ∈ ℝᵈ where:

  • Magnitude ||v_j||: Probability entity exists (0 to 1)

  • Direction: Instantiation parameters (pose, deformation, texture)

Key Difference from CNNs:

  • CNN scalar neuron: Binary feature detection

  • Capsule vector: Rich entity representation with properties

Squashing Non-linearityΒΆ

Function: Maps any vector to unit length while preserving direction:

\[v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2} \frac{s_j}{\|s_j\|}\]

where s_j is total input to capsule j.

Properties:

  1. ||v_j|| ∈ [0, 1): Short vectors → near 0, long vectors → near 1

  2. Direction preserved: v_j ∥ s_j

  3. Differentiable: Smooth for gradient descent

Gradient:

\[\frac{\partial v_j}{\partial s_j} = \frac{1}{(1 + \|s_j\|^2)^2} \left[ (1 + \|s_j\|^2)I - \frac{2s_j s_j^T}{\|s_j\|} \right]\]

2. Dynamic Routing by AgreementΒΆ

Routing AlgorithmΒΆ

Intuition: Lower-level capsules "vote" for higher-level capsules based on agreement.

Prediction Vectors: Each lower capsule i predicts higher capsule j’s output:

\[\hat{u}_{j|i} = W_{ij} u_i\]

where W_ij ∈ ℝ^{d_out × d_in} is the transformation matrix.

Routing Coefficients: Probability capsule i sends output to capsule j:

\[c_{ij} = \frac{\exp(b_{ij})}{\sum_k \exp(b_{ik})}\]

where b_ij are routing logits (initially 0).

Total Input: Weighted sum of predictions:

\[s_j = \sum_i c_{ij} \hat{u}_{j|i}\]

Update Rule (iterative):

\[b_{ij} \leftarrow b_{ij} + \hat{u}_{j|i} \cdot v_j\]

Agreement: Dot product measures how well prediction aligns with actual output.

Routing Algorithm (Detailed)ΒΆ

Procedure: DynamicRouting(u, r, W)
  Input: u_i (lower capsule outputs), r (iterations), W_ij (weights)
  Output: v_j (higher capsule outputs)
  
  1. Initialize: b_ij ← 0 for all i, j
  2. For iteration in 1 to r:
       a. c_ij ← softmax(b_ij) over j
       b. s_j ← Ξ£_i c_ij (W_ij u_i)
       c. v_j ← squash(s_j)
       d. b_ij ← b_ij + (W_ij u_i) Β· v_j  (except last iteration)
  3. Return v_j

Computational Cost: O(r · I · J · d²) where r = iterations, I = input caps, J = output caps, d = dimension.

Theoretical JustificationΒΆ

EM Perspective: Routing is similar to EM algorithm:

  • E-step: Compute c_ij (soft assignment)

  • M-step: Update v_j (cluster means)

Convergence: Routing maximizes agreement between predictions and outputs.

3. Margin LossΒΆ

Classification LossΒΆ

Objective: Separate present and absent classes with margin.

Loss for Class k:

\[L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^-)^2\]

where:

  • T_k ∈ {0, 1}: Indicator if class k present

  • m⁺ = 0.9: Target magnitude for present classes

  • m⁻ = 0.1: Max magnitude for absent classes

  • λ = 0.5: Down-weight absent class loss

Total Loss:

\[L = \sum_k L_k\]

Intuition:

  • Present class: Push ||v_k|| above 0.9

  • Absent class: Push ||v_k|| below 0.1

  • Down-weight absent to prevent initial learning stagnation

Reconstruction RegularizationΒΆ

Decoder Network: Reconstruct input from capsule outputs.

Objective: Force capsules to encode useful information.

Loss:

\[L_{recon} = \|X - \hat{X}\|^2\]

where X is input image, XΜ‚ is reconstruction.

Combined:

\[L_{total} = L_{margin} + \alpha L_{recon}\]

Typical: α = 0.0005 (small to avoid drowning out margin loss).

4. EM Routing (Matrix Capsules)ΒΆ

MotivationΒΆ

Limitation of Dynamic Routing: Scalar weights c_ij can’t represent uncertainty.

Solution: Model each capsule as Gaussian with mean ΞΌ_j and variance Οƒ_jΒ².

Capsule as GaussianΒΆ

Activation a_j: Probability capsule j is active.

Pose Matrix M_j ∈ ℝ^{4Γ—4}: Viewpoint-invariant representation.

Distribution: p(M_j) = N(ΞΌ_j, Οƒ_jΒ²I)

EM Routing AlgorithmΒΆ

E-Step: Compute assignment probabilities R_ij:

\[R_{ij} = \frac{a_i p(V_i | \mu_j, \sigma_j^2)}{\sum_k a_k p(V_i | \mu_k, \sigma_k^2)}\]

where V_i = W_ij M_i is vote from capsule i.

M-Step: Update higher-level capsule parameters:

\[a_j = \frac{\sum_i R_{ij} a_i}{n}\]
\[\mu_j = \frac{\sum_i R_{ij} a_i V_i}{\sum_i R_{ij} a_i}\]
\[\sigma_j^2 = \frac{\sum_i R_{ij} a_i \|V_i - \mu_j\|^2}{\sum_i R_{ij} a_i}\]

Cost Function: Negative log-likelihood with activation cost:

\[-\sum_j \sum_i R_{ij} a_i \log p(V_i | \mu_j, \sigma_j^2) + \beta \sum_j a_j\]

where Ξ² controls sparsity.

5. Equivariance PropertiesΒΆ

Viewpoint EquivarianceΒΆ

Definition: Transformation of input causes corresponding transformation of capsule output.

Mathematical Formulation: For transformation T:

\[T(input) \rightarrow capsule(T(input)) = T'(capsule(input))\]

where T’ is corresponding transformation in capsule space.

Example: Rotate image 15° → capsule orientation rotates 15°.

Advantage over CNNs: CNNs only have translation equivariance (via convolution).

Part-Whole RelationshipsΒΆ

Compositional Structure: Higher capsules represent wholes, lower capsules represent parts.

Coordinate Frame Agreement: Routing ensures parts agree on whole’s pose.

Example: Two eye capsules route to face capsule if they agree face is present.

6. Capsule Architecture VariantsΒΆ

Original CapsNet (Sabour et al., 2017)ΒΆ

Structure:

Input (28×28) 
  → Conv(256, 9×9, stride=1) + ReLU
  → PrimaryCaps(32 caps × 8D, 9×9, stride=2)
  → DigitCaps(10 caps × 16D, dynamic routing)
  → Length as class probability

Parameters: ~8.2M (similar to baseline CNN)

Performance:

  • MNIST: 99.75% (SOTA at time)

  • MultiMNIST (overlapping): 95.7% (vs 93.3% baseline)

Matrix Capsules (Hinton et al., 2018)ΒΆ

Key Changes:

  • 4Γ—4 pose matrices instead of vectors

  • EM routing instead of dynamic routing

  • Coordinate addition for translation

Structure:

Input β†’ Conv β†’ Primary Capsules (pose + activation)
       β†’ Class Capsules (EM routing, K iterations)
       β†’ Spread loss

Performance: SmallNORB 97.8% (viewpoint generalization)

Stacked Capsule Autoencoders (Kosiorek et al., 2019)ΒΆ

Unsupervised Learning: Learn capsule representations without labels.

Components:

  1. Part Capsule Encoder: Detect object parts

  2. Object Capsule Encoder: Compose parts into objects

  3. Decoder: Reconstruct from object capsules

Set-Based Routing: Treat capsules as sets, use attention.

7. Training TechniquesΒΆ

Learning Rate SchedulingΒΆ

Strategy: Exponential decay with warmup.

Schedule:

lr(t) = lr_base * min(1, t/T_warmup) * decay^(t/T_decay)

Typical: lr_base=1e-3, T_warmup=5000 steps, decay=0.96, T_decay=2000 steps.

RegularizationΒΆ

Techniques:

  1. Reconstruction loss (Ξ±=0.0005)

  2. Weight decay (1e-4)

  3. Dropout on decoder (p=0.5)

No BatchNorm: Interferes with capsule magnitudes.

Data AugmentationΒΆ

Affine Transformations:

  • Random rotation: Β±15Β°

  • Translation: Β±2 pixels

  • Scaling: 0.9-1.1Γ—

Cutout: Randomly mask image patches (improves robustness).

8. Computational EfficiencyΒΆ

Memory ComplexityΒΆ

Dynamic Routing: O(B Β· I Β· J Β· d) per iteration

  • B: Batch size

  • I: Input capsules

  • J: Output capsules

  • d: Capsule dimension

Bottleneck: Storing u_hat for all pairs (IΓ—JΓ—d).

Speed OptimizationsΒΆ

1. Fewer Routing Iterations:

  • r=3 is standard, r=1 often sufficient

  • Diminishing returns beyond r=3

2. Sparse Routing:

  • Route each input to top-k output capsules only

  • Reduces O(IΒ·J) to O(IΒ·k)

3. Grouped Capsules:

  • Split capsules into independent groups

  • Route within groups only

4. Parallel Implementation:

  • Routing iterations sequential, but batch parallel

  • GPU-friendly matrix operations

9. Limitations and ChallengesΒΆ

Scalability IssuesΒΆ

Problem: Quadratic complexity in number of capsules.

Impact: Hard to scale to ImageNet (224Γ—224).

Solutions:

  • Multi-scale architecture

  • Local routing (spatial locality)

  • Attention-based routing

Adversarial VulnerabilityΒΆ

Finding: Capsule networks as vulnerable as CNNs to adversarial examples.

Reason: Margin loss doesn’t inherently provide robustness.

Mitigation: Adversarial training + capsules.

Training InstabilityΒΆ

Issues:

  • Gradient explosion in early training

  • Routing coefficients may collapse to one capsule

Solutions:

  • Careful initialization (Xavier/He)

  • Gradient clipping (max norm 1.0)

  • Warmup learning rate schedule

10. Recent Advances (2020-2024)ΒΆ

Self-Routing CapsulesΒΆ

Idea: Learn routing mechanism instead of iterative algorithm.

Method: Attention-based routing in single forward pass.

Advantage: Faster, differentiable routing weights.

Efficient Capsules (Mazzia et al., 2021)ΒΆ

Improvements:

  • Depthwise separable convolutions in capsule layers

  • Squash activation only at final layer

  • Faster than original CapsNet by 3Γ—

Capsule Networks for Vision TransformersΒΆ

Hybrid Architecture: Combine ViT attention with capsule routing.

Benefit: Part-whole relationships + long-range dependencies.

11. ApplicationsΒΆ

Medical ImagingΒΆ

Use Case: Lesion detection, organ segmentation.

Advantage: Handles overlapping structures better than CNNs.

Example: Brain tumor segmentation (BraTS dataset).

3D Object RecognitionΒΆ

Task: Classify objects under different viewpoints.

Dataset: ShapeNet, ModelNet40.

Performance: Better viewpoint generalization than CNNs.

Action RecognitionΒΆ

Method: Temporal capsules for video understanding.

Structure: 3D convolution β†’ temporal capsules β†’ action class.

Benefit: Model hierarchical action parts.

Adversarial DefenseΒΆ

Approach: Capsule networks as certified defense.

Mechanism: Reconstruction loss detects adversarial perturbations.

Limitation: Still vulnerable to sophisticated attacks.

12. Comparison with Other ArchitecturesΒΆ

Capsules vs CNNsΒΆ

Aspect

CNN

Capsule

Feature

Scalar

Vector

Pooling

Max pool (loses info)

Routing (preserves)

Equivariance

Translation only

Viewpoint

Part-whole

No

Yes (via routing)

Overlapping

Struggles

Handles well

Speed

Fast

Slower (routing)

Capsules vs TransformersΒΆ

Similarities:

  • Both use attention/routing mechanisms

  • Both model relationships between entities

Differences:

  • Capsules: Hierarchical part-whole (compositional)

  • Transformers: Flat self-attention (relational)

  • Capsules: Viewpoint equivariance

  • Transformers: Permutation equivariance

Capsules vs Graph Neural NetworksΒΆ

Connection: Both propagate information between entities.

Difference:

  • GNN: General graph structure

  • Capsules: Hierarchical (layer-wise routing)

13. Theoretical InsightsΒΆ

Routing as ClusteringΒΆ

Interpretation: Dynamic routing clusters lower capsules to higher capsules.

Objective: Minimize within-cluster variance (maximize agreement).

Connection to k-means: Soft assignment via softmax (vs hard assignment).

Information BottleneckΒΆ

View: Capsule network as layered information bottleneck.

Compression: Routing compresses lower-level information.

Preservation: Maintains task-relevant information in higher capsules.

Generalization BoundsΒΆ

Analysis: VC dimension and Rademacher complexity for capsule networks.

Finding: Similar generalization to CNNs with comparable parameters.

Advantage: Better generalization on viewpoint variations (empirical).

14. Implementation Best PracticesΒΆ

InitializationΒΆ

Weight Matrices W_ij:

  • Xavier/Glorot: \(W \sim \mathcal{N}\left(0, \frac{2}{d_{in} + d_{out}}\right)\)

  • He initialization: \(W \sim \mathcal{N}\left(0, \frac{2}{d_{in}}\right)\)

Routing Logits b_ij: Initialize to 0 (uniform routing).

Hyperparameter TuningΒΆ

Parameter

Typical Value

Sensitivity

Routing iterations

3

Low (2-5 works)

Capsule dimension

8-16

Medium

Reconstruction weight

0.0005

High

Margin m⁺

0.9

Medium

Margin m⁻

0.1

Low

Down-weight Ξ»

0.5

Low

Debugging TipsΒΆ

Check:

  1. Capsule norms: Should be in [0, 1]

  2. Routing coefficients: Should sum to 1 across j

  3. Gradient norms: Clip if >1.0

  4. Loss breakdown: Margin vs reconstruction ratio

Common Issues:

  • Routing collapse: All c_ij route to one capsule β†’ use entropy regularization

  • Vanishing capsules: ||v_j|| always near 0 β†’ increase learning rate, check initialization

  • Reconstruction dominates: Scale down Ξ±

15. Key PapersΒΆ

FoundationalΒΆ

  • Sabour, Frosst, Hinton (2017): β€œDynamic Routing Between Capsules” (Original CapsNet)

  • Hinton, Sabour, Frosst (2018): β€œMatrix Capsules with EM Routing”

ImprovementsΒΆ

  • Mazzia et al. (2021): β€œEfficient-CapsNet: Capsule Network with Self-Attention Routing”

  • Kosiorek et al. (2019): β€œStacked Capsule Autoencoders” (Unsupervised)

  • Wang & Liu (2020): β€œGroup Equivariant Capsule Networks”

ApplicationsΒΆ

  • LaLonde & Bagci (2018): β€œCapsules for Object Segmentation” (Medical imaging)

  • Zhao et al. (2019): β€œ3D Point Capsule Networks” (Point cloud classification)

TheoryΒΆ

  • Paik et al. (2019): β€œCapsule Networks Need an Improved Routing Algorithm”

  • Zhang et al. (2020): β€œRethinking the Inception Architecture for Capsule Networks”

16. Future DirectionsΒΆ

Open ProblemsΒΆ

  1. Scalability: How to scale capsules to ImageNet-scale images?

  2. Architecture Search: Optimal capsule layer configurations?

  3. Theoretical Understanding: Why do capsules work? Formal guarantees?

  4. Routing Alternatives: Better than dynamic routing?

Promising ApproachesΒΆ

Self-Attention Routing: Replace iterative routing with learned attention.

Sparse Capsules: Activate only relevant capsules (conditional computation).

Continuous Capsules: Infinite-dimensional capsules (functional representation).

Hybrid Models: Combine capsules with transformers/GNNs.

17. When to Use Capsule NetworksΒΆ

Choose Capsules When:

  • Viewpoint equivariance critical (3D object recognition)

  • Overlapping/occluded objects common (medical imaging)

  • Part-whole relationships important (compositional reasoning)

  • Interpretable representations desired

Avoid Capsules When:

  • Speed critical (real-time inference)

  • Large-scale images (computational cost)

  • Standard classification sufficient (CNNs faster)

  • Limited compute resources (capsules memory-intensive)

# Advanced Capsule Networks Implementations

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional

class SquashActivation(nn.Module):
    """
    Squashing non-linearity: v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)
    Shrinks vector magnitude into [0, 1) while keeping direction, so
    capsule length can be read as an existence probability.
    """

    def __init__(self, dim=-1, eps=1e-8):
        super().__init__()
        # Axis holding the capsule components; eps guards the division
        # for all-zero vectors.
        self.dim = dim
        self.eps = eps

    def forward(self, s):
        """Squash `s` along `self.dim`."""
        sq = s.pow(2).sum(dim=self.dim, keepdim=True)
        scale = sq / (1 + sq)
        direction = s / torch.sqrt(sq + self.eps)
        return scale * direction

class DynamicRoutingLayer(nn.Module):
    """
    Dynamic routing between capsule layers (Sabour et al., 2017).

    Lower capsules predict each higher capsule's output; coupling
    coefficients are sharpened over a few iterations of agreement.
    Gradients flow through the predictions only on the final iteration.
    """

    def __init__(self, in_capsules, out_capsules, in_dim, out_dim,
                 num_iterations=3, share_weights=False):
        super().__init__()
        self.in_capsules = in_capsules
        self.out_capsules = out_capsules
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_iterations = num_iterations

        # One transform per (input, output) capsule pair, or a single
        # shared one broadcast over inputs when share_weights is set.
        rows = 1 if share_weights else in_capsules
        self.W = nn.Parameter(torch.randn(rows, out_capsules, out_dim, in_dim))

        self.squash = SquashActivation(dim=-1)

    def forward(self, u):
        """
        Args:
            u: (batch, in_capsules, in_dim) - lower capsule outputs
        Returns:
            v: (batch, out_capsules, out_dim) - higher capsule outputs
        """
        batch = u.size(0)

        # Predictions u_hat[j|i] = W_ij @ u_i, computed for all pairs at once.
        preds = torch.matmul(
            self.W.unsqueeze(0),          # (1, in_caps, out_caps, out_dim, in_dim)
            u.unsqueeze(2).unsqueeze(4),  # (batch, in_caps, 1, in_dim, 1)
        ).squeeze(-1)                     # (batch, in_caps, out_caps, out_dim)

        # Detached copy: the routing procedure itself is not differentiated.
        preds_ng = preds.detach()

        # Routing logits start at zero => uniform coupling.
        logits = torch.zeros(batch, self.in_capsules, self.out_capsules, 1,
                             device=u.device, dtype=u.dtype)

        last = self.num_iterations - 1
        for it in range(self.num_iterations):
            # Coupling coefficients: softmax over output capsules.
            coupling = F.softmax(logits, dim=2)  # (batch, in_caps, out_caps, 1)

            # s_j = sum_i c_ij * u_hat_j|i ; gradient only on the last pass.
            source = preds if it == last else preds_ng
            v = self.squash((coupling * source).sum(dim=1))  # (batch, out_caps, out_dim)

            if it != last:
                # b_ij += u_hat_j|i . v_j  (agreement raises the logit).
                agreement = (preds_ng * v.unsqueeze(1)).sum(dim=-1, keepdim=True)
                logits = logits + agreement

        return v


class EMRoutingLayer(nn.Module):
    """
    EM routing between capsule layers (Hinton et al., 2018).

    Each output capsule is modeled as an isotropic Gaussian over the votes
    cast by the input capsules. A few EM iterations alternate between
    fitting the Gaussian parameters (M-step) and re-assigning votes to
    output capsules (E-step).
    """
    
    def __init__(self, in_capsules, out_capsules, pose_dim=4, num_iterations=3):
        super().__init__()
        self.in_capsules = in_capsules
        self.out_capsules = out_capsules
        self.pose_dim = pose_dim
        self.num_iterations = num_iterations
        
        # Per (input, output) pair transformation matrices that turn
        # input poses into votes for output poses.
        self.W = nn.Parameter(torch.randn(in_capsules, out_capsules, 
                                         pose_dim, pose_dim))
        
        # Learned parameters.
        # NOTE(review): beta_v (the per-capsule activation cost from the
        # paper) is declared but not used by this simplified forward pass;
        # it is kept for interface compatibility.
        self.beta_v = nn.Parameter(torch.randn(out_capsules))  # Activation cost
        self.beta_a = nn.Parameter(torch.randn(out_capsules))  # Activation threshold
    
    def forward(self, a_in, M_in):
        """
        Args:
            a_in: (batch, in_capsules) - input activations
            M_in: (batch, in_capsules, pose_dim) - input pose vectors
        Returns:
            a_out: (batch, out_capsules) - output activations
            M_out: (batch, out_capsules, pose_dim) - output poses
        """
        batch_size = a_in.size(0)
        
        # Compute votes: V_ij = W_ij @ M_i
        M_expand = M_in.unsqueeze(2).unsqueeze(4)  # (batch, in_caps, 1, pose_dim, 1)
        W_expand = self.W.unsqueeze(0)  # (1, in_caps, out_caps, pose_dim, pose_dim)
        V = torch.matmul(W_expand, M_expand).squeeze(-1)  # (batch, in_caps, out_caps, pose_dim)
        
        # Start from a uniform soft assignment of inputs to outputs.
        R = torch.ones(batch_size, self.in_capsules, self.out_capsules, 
                      device=a_in.device) / self.out_capsules
        
        # EM iterations
        for iteration in range(self.num_iterations):
            # ---- M-step: fit per-output-capsule Gaussian parameters ----
            R_expand = R.unsqueeze(-1)  # (batch, in_caps, out_caps, 1)
            a_expand = a_in.unsqueeze(2).unsqueeze(3)  # (batch, in_caps, 1, 1)
            
            # Total activation-weighted responsibility per output capsule.
            sum_R_a = (R_expand * a_expand).sum(dim=1)  # (batch, out_caps, 1)
            a_out = torch.sigmoid(sum_R_a.squeeze(-1) - self.beta_a)  # (batch, out_caps)
            
            # Mean pose = responsibility-weighted average of votes.
            weighted_votes = R_expand * a_expand * V  # (batch, in_caps, out_caps, pose_dim)
            sum_weighted = weighted_votes.sum(dim=1)  # (batch, out_caps, pose_dim)
            M_out = sum_weighted / (sum_R_a + 1e-8)  # (batch, out_caps, pose_dim)
            
            # Isotropic variance of the votes around the mean pose.
            mu_expand = M_out.unsqueeze(1)  # (batch, 1, out_caps, pose_dim)
            diff = V - mu_expand  # (batch, in_caps, out_caps, pose_dim)
            variance = ((R_expand * a_expand * diff ** 2).sum(dim=1) / 
                       (sum_R_a + 1e-8))  # (batch, out_caps, pose_dim)
            sigma_sq = variance.mean(dim=-1, keepdim=True)  # (batch, out_caps, 1)
            
            # ---- E-step: update assignment probabilities ----
            if iteration < self.num_iterations - 1:
                # Gaussian log-likelihood of each vote under its output capsule.
                log_p = -0.5 * ((diff ** 2) / (sigma_sq.unsqueeze(1) + 1e-8)).sum(dim=-1)
                # BUGFIX: the normalization term must be shaped (batch, 1, out_caps)
                # to broadcast against log_p's (batch, in_caps, out_caps). The old
                # (batch, 1, out_caps, 1) shape raised a broadcast error whenever
                # in_capsules != out_capsules. The epsilon guards against log(0).
                log_norm = torch.log(2 * np.pi * sigma_sq.squeeze(-1) + 1e-8)  # (batch, out_caps)
                log_p = log_p - 0.5 * self.pose_dim * log_norm.unsqueeze(1)
                
                # Weight by input activation, then renormalize over outputs.
                log_p_weighted = log_p + torch.log(a_expand.squeeze(-1) + 1e-8)
                R = F.softmax(log_p_weighted, dim=2)  # (batch, in_caps, out_caps)
        
        return a_out, M_out


class PrimaryCapsuleLayer(nn.Module):
    """
    First capsule layer: turns a CNN feature map into a flat sequence of
    squashed capsule vectors (one capsule per capsule type per spatial
    position of the convolution output).
    """
    
    def __init__(self, in_channels, num_capsules, capsule_dim, 
                 kernel_size=9, stride=2, padding=0):
        super().__init__()
        self.num_capsules = num_capsules
        self.capsule_dim = capsule_dim
        
        # A single convolution emits every capsule channel at once.
        self.conv = nn.Conv2d(in_channels, num_capsules * capsule_dim, 
                            kernel_size, stride, padding)
        
        self.squash = SquashActivation(dim=-1)
    
    def forward(self, x):
        """
        Args:
            x: (batch, in_channels, H, W)
        Returns:
            capsules: (batch, num_capsules * H' * W', capsule_dim)
        """
        features = self.conv(x)  # (batch, num_caps * cap_dim, H', W')
        n = features.size(0)
        
        # Split channels into (capsule type, capsule dim), flatten the
        # spatial grid, and move the capsule dimension last.
        grouped = features.view(n, self.num_capsules, self.capsule_dim, -1)
        grouped = grouped.permute(0, 1, 3, 2)  # (batch, num_caps, H'*W', cap_dim)
        flat = grouped.reshape(n, -1, self.capsule_dim)
        
        # Squash so every capsule's length lies in [0, 1).
        return self.squash(flat)


class CapsuleNetworkClassifier(nn.Module):
    """
    Complete CapsNet for classification with reconstruction.
    Based on Sabour et al. (2017).

    Pipeline: Conv (ReLU) -> PrimaryCaps -> DigitCaps (dynamic routing),
    plus an optional fully-connected decoder used as a reconstruction
    regularizer. Sized for 28x28 single-channel inputs (MNIST).
    """
    
    def __init__(self, input_channels=1, num_classes=10, 
                 routing_iterations=3, reconstruction=True):
        super().__init__()
        self.num_classes = num_classes
        self.reconstruction = reconstruction
        
        # Feature extractor: one wide 9x9 convolution.
        self.conv1 = nn.Conv2d(input_channels, 256, kernel_size=9, stride=1)
        
        # Primary capsules: 32 capsule types, each 8-dimensional.
        self.primary_caps = PrimaryCapsuleLayer(
            in_channels=256,
            num_capsules=32,
            capsule_dim=8,
            kernel_size=9,
            stride=2
        )
        
        # Fixed for 28x28 inputs: conv1 gives 20x20, the primary-capsule
        # conv gives 6x6, hence 32 * 6 * 6 = 1152 primary capsules.
        self.num_primary_caps = 32 * 6 * 6  # For MNIST 28x28
        
        # One 16D output capsule per class, connected via dynamic routing.
        self.digit_caps = DynamicRoutingLayer(
            in_capsules=self.num_primary_caps,
            out_capsules=num_classes,
            in_dim=8,
            out_dim=16,
            num_iterations=routing_iterations
        )
        
        # Decoder reconstructing the image from the selected class capsule.
        if self.reconstruction:
            self.decoder = nn.Sequential(
                nn.Linear(16 * num_classes, 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 28 * 28),
                nn.Sigmoid()
            )
    
    def forward(self, x, labels=None):
        """
        Args:
            x: (batch, channels, H, W) - input images
            labels: (batch,) - true labels (for reconstruction)
        Returns:
            class_probs: (batch, num_classes) - capsule lengths per class
            reconstruction: (batch, H*W) - reconstructed images, or None
        """
        features = F.relu(self.conv1(x))  # (batch, 256, 20, 20)
        primary = self.primary_caps(features)  # (batch, 1152, 8)
        digit_caps = self.digit_caps(primary)  # (batch, 10, 16)
        
        # A capsule's length encodes the probability its class is present.
        class_probs = digit_caps.pow(2).sum(dim=-1).sqrt()  # (batch, 10)
        
        if not self.reconstruction:
            return class_probs, None
        
        # Keep only the capsule of the true class (training) or of the
        # predicted class (inference); zero all others before decoding.
        if labels is not None:
            selected = labels
        else:
            selected = class_probs.argmax(dim=1)
        
        mask = torch.zeros_like(digit_caps)
        mask[torch.arange(digit_caps.size(0)), selected] = 1
        masked = (digit_caps * mask).flatten(start_dim=1)
        
        return class_probs, self.decoder(masked)


class MarginLoss(nn.Module):
    """
    Margin loss over capsule lengths (Sabour et al., 2017).

    L_k = T_k * max(0, m+ - ||v_k||)^2
        + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2

    The present class is pushed above m+, absent classes below m-, with
    absent-class terms down-weighted by lambda.
    """
    
    def __init__(self, m_plus=0.9, m_minus=0.1, lambda_=0.5):
        super().__init__()
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambda_ = lambda_
    
    def forward(self, lengths, labels):
        """
        Args:
            lengths: (batch, num_classes) - capsule lengths
            labels: (batch,) - true class labels
        Returns:
            loss: scalar (mean over the batch of per-sample class sums)
        """
        # T_k = 1 for the true class, 0 elsewhere.
        targets = F.one_hot(labels, num_classes=lengths.size(1)).float()
        
        # Penalize the present class for falling short of m+.
        positive_term = targets * torch.clamp(self.m_plus - lengths, min=0) ** 2
        
        # Penalize absent classes for exceeding m-.
        negative_term = (1 - targets) * torch.clamp(lengths - self.m_minus, min=0) ** 2
        
        per_sample = (positive_term + self.lambda_ * negative_term).sum(dim=1)
        return per_sample.mean()


class CapsuleTrainer:
    """
    Training/evaluation helper for capsule networks.

    Combines the margin loss on capsule lengths with an optional, weighted
    reconstruction (MSE) loss, and clips gradients before each step.
    """
    
    def __init__(self, model, optimizer, reconstruction_weight=0.0005):
        self.model = model
        self.optimizer = optimizer
        self.margin_loss = MarginLoss()
        self.reconstruction_weight = reconstruction_weight
    
    def train_step(self, images, labels):
        """Run one optimization step; return the loss components as floats."""
        self.model.train()
        
        lengths, recon = self.model(images, labels)
        margin = self.margin_loss(lengths, labels)
        
        if recon is None:
            recon_loss = torch.tensor(0.0)
            total = margin
        else:
            # Compare against the flattened input image.
            flat_images = images.view(images.size(0), -1)
            recon_loss = F.mse_loss(recon, flat_images)
            total = margin + self.reconstruction_weight * recon_loss
        
        self.optimizer.zero_grad()
        total.backward()
        # Routing dynamics can produce large gradients; clip for stability.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        
        return {
            'loss': total.item(),
            'margin_loss': margin.item(),
            'recon_loss': recon_loss.item() if isinstance(recon_loss, torch.Tensor) else 0.0
        }
    
    def evaluate(self, data_loader):
        """Return top-1 accuracy (percent) over all batches in data_loader."""
        self.model.eval()
        hits, seen = 0, 0
        
        with torch.no_grad():
            for batch_images, batch_labels in data_loader:
                lengths, _ = self.model(batch_images)
                preds = lengths.argmax(dim=1)
                hits += (preds == batch_labels).sum().item()
                seen += batch_labels.size(0)
        
        return 100.0 * hits / seen


# ============================================================================
# Demonstrations
# ============================================================================

print("=" * 70)
print("Capsule Networks - Advanced Implementations")
print("=" * 70)

# 1. Squashing function
print("\n1. Squashing Non-linearity:")
# NOTE(review): SquashActivation is defined elsewhere in this file; this
# rebinds the module-level `squash` function name to an nn.Module instance.
squash = SquashActivation(dim=-1)
s_test = torch.tensor([[0.1, 0.2], [1.0, 2.0], [5.0, 10.0]])
v_test = squash(s_test)
norms = torch.norm(v_test, dim=1)
print(f"   Input vectors: {s_test.shape}")
print(f"   Input norms: {torch.norm(s_test, dim=1).numpy()}")
print(f"   Output norms: {norms.numpy()}")
print(f"   Property: All norms in [0, 1): βœ“")

# 2. Dynamic routing
print("\n2. Dynamic Routing:")
# 10 input capsules (8D) routed to 5 output capsules (16D).
routing = DynamicRoutingLayer(
    in_capsules=10, 
    out_capsules=5, 
    in_dim=8, 
    out_dim=16, 
    num_iterations=3
)
u_test = torch.randn(2, 10, 8)
v_out = routing(u_test)
print(f"   Input: {u_test.shape} (batch, in_caps, in_dim)")
print(f"   Output: {v_out.shape} (batch, out_caps, out_dim)")
print(f"   Routing iterations: 3")
print(f"   Mechanism: Agreement-based soft assignment")
print(f"   Parameters: {sum(p.numel() for p in routing.parameters()):,}")

# 3. EM routing
print("\n3. EM Routing:")
em_routing = EMRoutingLayer(
    in_capsules=10,
    out_capsules=5,
    pose_dim=4,
    num_iterations=3
)
a_in = torch.rand(2, 10)  # Activations
M_in = torch.randn(2, 10, 4)  # Poses
a_out, M_out = em_routing(a_in, M_in)
print(f"   Input: activations {a_in.shape}, poses {M_in.shape}")
print(f"   Output: activations {a_out.shape}, poses {M_out.shape}")
print(f"   Algorithm: E-step (soft assignment) + M-step (Gaussian params)")
print(f"   Advantage: Models uncertainty in routing")

# 4. Complete CapsNet
print("\n4. CapsNet Architecture:")
capsnet = CapsuleNetworkClassifier(
    input_channels=1,
    num_classes=10,
    routing_iterations=3,
    reconstruction=True
)
x_test = torch.randn(2, 1, 28, 28)
labels_test = torch.tensor([3, 7])
# Labels supplied so the reconstruction path masks by the true class.
probs, recon = capsnet(x_test, labels_test)
print(f"   Input: {x_test.shape}")
print(f"   Conv1: 1 -> 256 channels (9x9)")
print(f"   Primary Caps: 32 capsules Γ— 8D (1,152 total)")
print(f"   Digit Caps: 10 capsules Γ— 16D (dynamic routing)")
print(f"   Output probs: {probs.shape}")
print(f"   Reconstruction: {recon.shape}")
print(f"   Total parameters: {sum(p.numel() for p in capsnet.parameters()):,}")

# 5. Margin loss
print("\n5. Margin Loss:")
margin_loss_fn = MarginLoss(m_plus=0.9, m_minus=0.1, lambda_=0.5)
lengths_test = torch.tensor([[0.95, 0.05, 0.02], [0.1, 0.85, 0.03]])
# NOTE(review): reuses (rebinds) labels_test from the CapsNet demo above.
labels_test = torch.tensor([0, 1])
loss_test = margin_loss_fn(lengths_test, labels_test)
print(f"   Formula: L_k = T_k·max(0, m⁺ - ||v_k||)² + λ(1-T_k)·max(0, ||v_k|| - m⁻)²")
print(f"   m⁺ = 0.9 (target for present class)")
print(f"   m⁻ = 0.1 (max for absent classes)")
print(f"   Ξ» = 0.5 (down-weight absent classes)")
print(f"   Example loss: {loss_test.item():.4f}")

# 6. Routing comparison
print("\n6. Dynamic vs EM Routing:")
print("   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”")
print("   β”‚ Aspect           β”‚ Dynamic     β”‚ EM           β”‚ Best For     β”‚")
print("   β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€")
print("   β”‚ Representation   β”‚ Vectors     β”‚ Gaussians    β”‚ EM (richer)  β”‚")
print("   β”‚ Uncertainty      β”‚ No          β”‚ Yes          β”‚ EM           β”‚")
print("   β”‚ Complexity       β”‚ Lower       β”‚ Higher       β”‚ Dynamic      β”‚")
print("   β”‚ Pose info        β”‚ Implicit    β”‚ Explicit     β”‚ EM           β”‚")
print("   β”‚ Speed            β”‚ Faster      β”‚ Slower       β”‚ Dynamic      β”‚")
print("   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")

# 7. Equivariance demonstration
print("\n7. Viewpoint Equivariance:")
print("   CNN: Translation equivariant only")
print("     rotate(image) -> features != rotate(features)")
print("   ")
print("   Capsule: Viewpoint equivariant")
print("     rotate(image) -> capsule orientation rotates")
print("     Capsule direction encodes pose information")
print("   ")
print("   Benefit: Better generalization to novel viewpoints")

# 8. When to use guide
print("\n8. When to Use Capsule Networks:")
print("   Use Capsules when:")
print("     βœ“ Viewpoint equivariance critical (3D objects)")
print("     βœ“ Overlapping objects common (medical imaging)")
print("     βœ“ Part-whole relationships important")
print("     βœ“ Interpretable representations desired")
print("\n   Use CNNs when:")
print("     βœ“ Speed critical (real-time inference)")
print("     βœ“ Large-scale images (computational cost)")
print("     βœ“ Standard classification sufficient")
print("     βœ“ Limited compute resources")

# 9. Training tips
print("\n9. Training Best Practices:")
print("   Initialization:")
print("     β€’ Xavier/He for weight matrices")
print("     β€’ Zero for routing logits")
print("   ")
print("   Hyperparameters:")
print("     β€’ Routing iterations: 3 (2-5 works)")
print("     β€’ Reconstruction weight: 0.0005")
print("     β€’ Margin m⁺: 0.9, m⁻: 0.1")
print("     β€’ Learning rate: 1e-3 with decay")
print("   ")
print("   Regularization:")
print("     β€’ Gradient clipping (max norm 1.0)")
print("     β€’ Reconstruction loss (forces useful encoding)")
print("     β€’ No batch normalization (interferes with magnitudes)")

print("\n" + "=" * 70)

Advanced Capsule Networks TheoryΒΆ

1. Introduction to Capsule NetworksΒΆ

Capsule Networks (CapsNets) represent a paradigm shift from traditional CNNs, replacing scalar neurons with vector-valued capsules.

1.1 MotivationΒΆ

Problems with CNNs:

  • Pooling loses spatial information: Max pooling discards precise localization

  • Lack of part-whole relationships: CNNs don’t model hierarchical structure

  • Pose insensitivity: Difficult to capture viewpoint, orientation, scale

  • Adversarial vulnerability: Small perturbations fool the network

Capsule philosophy: β€œActivities of neurons in a capsule represent various properties of the same entity.”

1.2 What is a Capsule?ΒΆ

Capsule: Group of neurons whose activity vector represents:

  • Length: Probability that entity exists (||v|| ∈ [0, 1])

  • Orientation: Instantiation parameters (pose, texture, deformation, etc.)

Example: Face detection capsule

  • Length: Confidence face is present

  • Orientation: Face angle, position, lighting

Key innovation: Replace scalar activations with vectors.

1.3 Core PrinciplesΒΆ

  1. Equivariance: Activities should change predictably with viewpoint

  2. Part-whole relationships: Lower capsules vote for higher capsules

  3. Routing by agreement: Higher capsules represent agreements of lower capsules

  4. Inverse graphics: Reconstruct input from capsule activations

2. Dynamic Routing AlgorithmΒΆ

Dynamic routing replaces pooling with iterative routing between capsule layers.

2.1 Routing ProcedureΒΆ

Goal: Route lower-level capsules to higher-level capsules based on agreement.

Notation:

  • u_i: Activation of capsule i in layer l

  • v_j: Activation of capsule j in layer l+1

  • W_{ij}: Transformation matrix from i to j

  • c_{ij}: Coupling coefficient (routing weight)

Prediction vector (vote):

Γ»_{j|i} = W_{ij} u_i

Weighted sum:

s_j = Ξ£_i c_{ij} Γ»_{j|i}

Squashing (non-linearity):

v_j = squash(s_j) = (||s_j||Β² / (1 + ||s_j||Β²)) Β· (s_j / ||s_j||)

Properties of squash:

  • Short vectors β†’ nearly 0

  • Long vectors β†’ length close to 1

  • Direction preserved

2.2 Routing by AgreementΒΆ

Iterative procedure:

  1. Initialize: b_{ij} = 0 (log prior probabilities)

  2. Repeat for r iterations:

    • Softmax: c_{ij} = exp(b_{ij}) / Ξ£_k exp(b_{ik})

    • Weighted sum: s_j = Ξ£_i c_{ij} Γ»_{j|i}

    • Squash: v_j = squash(s_j)

    • Update: b_{ij} ← b_{ij} + Γ»_{j|i} Β· v_j

Agreement: Γ»_{j|i} Β· v_j measures how well prediction matches output

Intuition:

  • High agreement β†’ increase c_{ij}

  • Low agreement β†’ decrease c_{ij}

2.3 Routing AlgorithmΒΆ

Procedure: DynamicRouting(Γ»_{j|i}, r, l)
  for all capsule i in layer l and capsule j in layer l+1:
    b_{ij} ← 0
  
  for iteration in 1 to r:
    for all capsule i in layer l:
      c_i ← softmax(b_i)  # c_i = [c_{i1}, ..., c_{iJ}]
    
    for all capsule j in layer l+1:
      s_j ← Ξ£_i c_{ij} Γ»_{j|i}
      v_j ← squash(s_j)
    
    for all capsule i in layer l and capsule j in layer l+1:
      b_{ij} ← b_{ij} + Γ»_{j|i} Β· v_j
  
  return v

Complexity: O(r Β· L Β· H Β· dΒ²)

  • r: Routing iterations (typically 3)

  • L: Lower capsules

  • H: Higher capsules

  • d: Capsule dimension

3. CapsNet ArchitectureΒΆ

3.1 Original CapsNet (Sabour et al., 2017)ΒΆ

Architecture for MNIST:

Layer 1: Convolutional

  • 256 filters, 9Γ—9 kernel, stride 1, ReLU

  • Output: [batch, 20, 20, 256]

Layer 2: PrimaryCaps

  • 32 channels of 8D capsules

  • Each channel: Conv 9Γ—9, stride 2

  • Output: [batch, 6, 6, 32, 8] = 1,152 capsules of dim 8

Layer 3: DigitCaps

  • 10 capsules (one per class), dim 16

  • Dynamic routing from PrimaryCaps

  • Output: [batch, 10, 16]

Prediction: argmax_j ||v_j||

3.2 Squashing FunctionΒΆ

Mathematical form:

squash(s) = (||s||Β² / (1 + ||s||Β²)) Β· (s / ||s||)

Derivative:

βˆ‚squash(s) / βˆ‚s = (1 / (1 + ||s||Β²)) [I + (s s^T / ||s||Β²) Β· (2||s||Β² / (1 + ||s||Β²) - 1)]

Properties:

  • squash(0) = 0

  • squash(s) β†’ s/||s|| as ||s|| β†’ ∞

  • ||squash(s)|| < 1 for all s

3.3 Margin LossΒΆ

For multi-class classification:

L_k = T_k max(0, m⁺ - ||v_k||)² + λ (1 - T_k) max(0, ||v_k|| - m⁻)²

Where:

  • T_k = 1 if class k present, 0 otherwise

  • m⁺ = 0.9: Upper margin

  • m⁻ = 0.1: Lower margin

  • Ξ» = 0.5: Down-weighting for absent classes

Total loss:

L_margin = Ξ£_k L_k

Intuition:

  • Present class: ||v_k|| should be > 0.9

  • Absent class: ||v_k|| should be < 0.1

3.4 Reconstruction RegularizationΒΆ

Decoder network: Reconstruct input from DigitCaps

Mask: Keep only correct class capsule (or predicted during testing)

Architecture:

DigitCaps (16D) β†’ FC(512) β†’ FC(1024) β†’ FC(784) β†’ Sigmoid β†’ Reconstruction

Reconstruction loss:

L_recon = ||x - xΜ‚||Β²

Total loss:

L_total = L_margin + Ξ± L_recon

Where Ξ± = 0.0005 (small weight to not dominate)

Purpose:

  • Regularization

  • Force capsules to learn meaningful representations

  • Enable interpretability

4. Routing AlgorithmsΒΆ

4.1 Dynamic RoutingΒΆ

Original algorithm (Sabour et al., 2017):

  • Iterative (3 iterations typical)

  • Coupling coefficients sum to 1

  • Agreement-based updates

Advantages:

  • Captures part-whole relationships

  • Equivariant to affine transformations

Disadvantages:

  • Computationally expensive

  • Non-parallelizable across routing iterations

4.2 EM RoutingΒΆ

EM Routing [Hinton et al., 2018]: Treat routing as EM clustering.

E-step: Assign lower capsules to higher capsules

r_{ij} = a_i Β· p_j Β· exp(-cost_{ij})

Where cost_{ij} is Mahalanobis distance.

M-step: Update higher capsule parameters

ΞΌ_j = (Ξ£_i r_{ij} Γ»_{j|i}) / (Ξ£_i r_{ij})
σ²_j = (Ξ£_i r_{ij} (Γ»_{j|i} - ΞΌ_j)Β²) / (Ξ£_i r_{ij})

Activation:

a_j = logistic(Ξ» (Ξ²_a - Ξ£_i cost_j))

Advantages:

  • Probabilistic interpretation

  • Models uncertainty (variance)

Disadvantages:

  • More complex

  • More hyperparameters

4.3 Self-RoutingΒΆ

Self-Routing [Hahn et al., 2019]: Non-iterative, single-pass routing.

Key idea: Predict routing coefficients directly.

c_{ij} = softmax_j(MLP([u_i, Γ»_{j|i}]))

Advantages:

  • Faster (no iterations)

  • Parallelizable

  • Gradient-friendly

Disadvantages:

  • Less expressive than iterative routing

4.4 Attention RoutingΒΆ

Use attention mechanism for routing:

c_{ij} = softmax_j((W_Q u_i)^T (W_K û_{j|i}) / √d)

Similar to transformer attention but for capsule routing.

5. Capsule LayersΒΆ

5.1 Primary CapsulesΒΆ

First capsule layer: Converts CNN features to capsules.

Implementation:

conv_output = Conv(input)  # [batch, H, W, C]
capsules = reshape(conv_output, [batch, H, W, num_capsules, capsule_dim])
capsules = squash(capsules)

Purpose:

  • Extract low-level features

  • Initialize capsule hierarchy

5.2 Convolutional CapsulesΒΆ

ConvCaps [Hinton et al., 2018]: Capsules with local receptive fields.

Votes:

Γ»_{j|i} = W_{ij} u_i

Where i ranges over spatial neighborhood.

Advantages:

  • Translation equivariance

  • Fewer parameters than fully-connected capsules

5.3 Class CapsulesΒΆ

Final layer: One capsule per class.

Length: Class probability Orientation: Class-specific properties

6. Variants and ExtensionsΒΆ

6.1 Matrix Capsules (Hinton et al., 2018)ΒΆ

Capsules as matrices instead of vectors:

Activation: Matrix M_{ij}

  • Represents pose (rotation, scale, etc.)

Routing: EM algorithm

Advantages:

  • More expressive (4Γ—4 matrix = 16 parameters)

  • Better models affine transformations

6.2 Stacked Capsule Autoencoders (SCAE)ΒΆ

SCAE [Kosiorek et al., 2019]: Unsupervised capsule learning.

Components:

  • Part Capsules: Discover object parts

  • Object Capsules: Discover objects from parts

  • Template-based: Learn templates for parts

Training:

  • Reconstruction loss

  • Set loss (permutation invariant)

6.3 Efficient CapsNetsΒΆ

Bottleneck: Routing is expensive (O(L Β· H Β· dΒ²)).

Solutions:

1. Sparse Routing:

  • Route only to k nearest capsules

  • k-means clustering

2. Low-Rank Approximation:

W_{ij} = U_i V_j^T

Reduces parameters from dΒ² to 2dr.

3. Shared Weights:

  • Weight sharing across spatial positions

  • Similar to convolution

6.4 Capsule GANsΒΆ

CapsGAN: Use capsules in generator/discriminator.

Benefits:

  • Disentangled representations

  • Controllable generation

6.5 Capsule TransformersΒΆ

CapsFormer: Combine capsules with transformers.

Routing as attention:

Attention(Q, K, V) = softmax(QK^T / √d) V

Where Q, K, V are capsule activations.

7. Training ConsiderationsΒΆ

7.1 Loss FunctionsΒΆ

Margin loss (classification):

L_k = T_k max(0, m⁺ - ||v_k||)² + λ (1 - T_k) max(0, ||v_k|| - m⁻)²

Spread loss [Hinton et al., 2018]:

L = Σ_{i≠t} max(0, margin - (a_t - a_i))²

Encourages correct class activation to exceed others by margin.

Reconstruction loss:

L_recon = ||x - decoder(mask(v))||Β²

7.2 HyperparametersΒΆ

Critical hyperparameters:

| Parameter            | Typical Value | Notes                                |
|----------------------|---------------|--------------------------------------|
| Routing iterations r | 3             | More iterations = better, but slower |
| Capsule dimension    | 8-16          | Higher = more expressive             |
| m⁺ (upper margin)    | 0.9           | Target for present class             |
| m⁻ (lower margin)    | 0.1           | Target for absent class              |
| Ξ» (down-weight)      | 0.5           | Balance false positives              |
| Ξ± (recon weight)     | 0.0005        | Regularization strength              |

7.3 InitializationΒΆ

Capsule weights: Xavier/He initialization

Routing logits: Initialize b_{ij} = 0 (uniform routing initially)

Reconstruction decoder: Standard initialization

7.4 OptimizationΒΆ

Optimizer: Adam typical (lr = 0.001)

Gradient clipping: Important due to routing dynamics

Batch normalization: Can be applied to capsule activations

8. Theoretical PropertiesΒΆ

8.1 EquivarianceΒΆ

Capsules exhibit equivariance to transformations:

If input undergoes affine transformation T:

v(T(x)) β‰ˆ T(v(x))

Proof sketch:

  • Transformation matrices W_{ij} model geometric relationships

  • Routing preserves structure through agreement

8.2 Part-Whole RelationshipsΒΆ

Capsules model compositional hierarchy:

Lower capsule u_i represents part. Higher capsule v_j represents whole. Agreement Γ»_{j|i} Β· v_j measures compatibility.

Example: Face recognition

  • Lower: Eyes, nose, mouth capsules

  • Higher: Face capsule

  • Routing: Face capsule active when parts agree on face pose

8.3 Viewpoint InvarianceΒΆ

Traditional CNNs: Achieve invariance via pooling (loses info).

CapsNets: Achieve equivariance (retains info).

Key difference:

  • Invariance: f(T(x)) = f(x)

  • Equivariance: f(T(x)) = T(f(x))

Equivariance is more informative!

8.4 RobustnessΒΆ

Adversarial robustness: Capsules more robust than CNNs [Hinton et al., 2018].

Reason:

  • Requires fooling multiple capsule dimensions

  • Agreement mechanism filters inconsistent perturbations

Empirical evidence:

  • CapsNets achieve higher accuracy under adversarial attacks

  • More interpretable failure modes

9. ApplicationsΒΆ

9.1 Image ClassificationΒΆ

MNIST: 99.75% accuracy (state-of-the-art among non-ensemble methods)

CIFAR-10: ~75% accuracy (competitive but not state-of-the-art)

SmallNORB: Excellent performance (viewpoint variation)

Observation: CapsNets excel on tasks requiring pose understanding.

9.2 Object DetectionΒΆ

CapsNet-based detectors:

  • Replace RPN with capsule proposals

  • Capsule features for bounding box regression

Challenges:

  • Scalability to high-resolution images

  • Computational cost

9.3 SegmentationΒΆ

Segmentation CapsNets:

  • Encoder-decoder with capsule layers

  • Pixel-wise capsules

  • Deconvolutional capsules for upsampling

Benefits:

  • Better boundary delineation

  • Part-aware segmentation

9.4 Medical ImagingΒΆ

Success stories:

  • Tumor segmentation

  • Pathology classification

  • X-ray analysis

Reasons for success:

  • Small datasets (regularization via reconstruction)

  • Viewpoint variation (capsule equivariance)

  • Interpretability (clinician trust)

9.5 Video AnalysisΒΆ

Temporal Capsules:

  • 3D convolutions β†’ 3D capsules

  • Routing across time

Applications:

  • Action recognition

  • Video prediction

  • Anomaly detection

10. Advantages and DisadvantagesΒΆ

10.1 AdvantagesΒΆ

  1. Equivariance: Capsule activities change predictably with transformations

  2. Part-whole relationships: Explicit modeling of hierarchical structure

  3. Fewer parameters: Can achieve comparable performance with fewer params

  4. Interpretability: Capsule dimensions have semantic meaning

  5. Robustness: More robust to adversarial attacks

  6. Data efficiency: Reconstruction regularization helps with small datasets

10.2 DisadvantagesΒΆ

  1. Computational cost: Routing is expensive (O(r Β· L Β· H Β· dΒ²))

  2. Memory: Vector activations require more memory than scalars

  3. Scalability: Difficult to scale to ImageNet-size images

  4. Training instability: Routing dynamics can be unstable

  5. Limited adoption: Not as widely used as CNNs/Transformers

  6. Engineering: Less optimized libraries and hardware support

11. Recent Advances (2019-2024)ΒΆ

11.1 Efficient RoutingΒΆ

Fast Capsule Networks [Gu et al., 2021]:

  • Approximate routing via matrix sketching

  • O(L Β· H Β· d) complexity (linear in d)

Parallel Routing:

  • Parallelize across routing iterations

  • Use gradient checkpointing to save memory

11.2 Group Equivariant CapsulesΒΆ

G-CapsNets [Lenssen et al., 2018]:

  • Equivariance to group transformations (rotations, reflections)

  • Steerable filters in capsule space

SE(3)-equivariant capsules: For 3D data (point clouds, molecules)

11.3 Self-Supervised CapsulesΒΆ

Contrastive Capsules:

  • Pre-train capsules with contrastive learning

  • Augmentation-aware capsules

Masked Capsule Modeling:

  • Similar to BERT

  • Predict masked capsules from context

11.4 Hybrid ArchitecturesΒΆ

CapsNet + Transformers:

  • Use capsules for low-level features

  • Transformers for global context

CapsNet + Graph Networks:

  • Capsules as graph nodes

  • Message passing for routing

11.5 Theoretical InsightsΒΆ

Connection to EM algorithm [Hinton et al., 2018]:

  • Routing as E-step (assignment)

  • Capsule updates as M-step (parameter updates)

Neural ODEs: Routing as continuous dynamics

dv/dt = f(v, u, W)

12. Comparison with CNNsΒΆ

12.1 ArchitectureΒΆ

CNNs:

  • Scalar activations

  • Pooling for invariance

  • Hierarchical features

CapsNets:

  • Vector activations

  • Routing for equivariance

  • Part-whole relationships

12.2 PerformanceΒΆ

ImageNet:

  • CNNs: 75-90% top-1 accuracy

  • CapsNets: Not competitive at this scale yet

MNIST:

  • CNNs: 99.7%

  • CapsNets: 99.75%

SmallNORB (viewpoint):

  • CNNs: ~95%

  • CapsNets: ~97%

12.3 Computational CostΒΆ

Forward pass:

  • CNNs: O(H Β· W Β· C Β· KΒ²) per layer

  • CapsNets: O(H Β· W Β· C Β· D) + O(r Β· L Β· H Β· DΒ²) routing

Memory:

  • CNNs: O(H Β· W Β· C)

  • CapsNets: O(H Β· W Β· C Β· D) (D = capsule dim)

Training time: CapsNets typically 2-5Γ— slower

13. Implementation ChallengesΒΆ

13.1 Numerical StabilityΒΆ

Issues:

  • Division by ||s|| in squashing

  • Softmax overflow in routing

Solutions:

  • Add Ξ΅ to denominators: s / (||s|| + Ξ΅)

  • Log-space softmax

13.2 Gradient FlowΒΆ

Problem: Routing iterations can block gradients.

Solutions:

  • Gradient checkpointing

  • Differentiable routing (no stop_gradient)

13.3 ScalabilityΒΆ

Challenge: Routing doesn’t scale to millions of capsules.

Solutions:

  • Convolutional capsules (local routing)

  • Sparse routing (route to k nearest)

  • Hierarchical routing (multi-stage)

13.4 Hardware OptimizationΒΆ

GPU utilization: Routing is memory-bound, not compute-bound.

Solutions:

  • Fused kernels (combine routing iterations)

  • Mixed precision training

  • Model parallelism (split capsules across GPUs)

14. Benchmarks and DatasetsΒΆ

14.1 MNISTΒΆ

Task: Digit classification (10 classes)

CapsNet performance: 99.75% test accuracy

Ablations:

  • Without routing: 99.23%

  • Without reconstruction: 99.52%

14.2 CIFAR-10ΒΆ

Task: Object classification (10 classes)

CapsNet performance: ~75% test accuracy

CNN performance: 95%+ (ResNet, EfficientNet)

Observation: CapsNets lag behind CNNs on complex natural images.

14.3 SmallNORBΒΆ

Task: 3D object recognition with viewpoint variation

CapsNet performance: 97.3% test accuracy

CNN performance: ~95%

Why CapsNets excel: Viewpoint equivariance!

14.4 MultiMNISTΒΆ

Task: Recognize overlapping digits

CapsNet performance: 95% (both digits correct)

CNN performance: ~80%

Reason: Capsules segment objects via routing.

15. Future DirectionsΒΆ

15.1 Open ProblemsΒΆ

  1. Scalability: Scale to ImageNet-size images efficiently

  2. Architecture search: Optimal capsule layer configurations

  3. Theoretical understanding: Formal analysis of routing dynamics

  4. Pre-training: Effective self-supervised pre-training for capsules

  5. Hardware: Specialized accelerators for capsule operations

15.2 Promising DirectionsΒΆ

  1. Vision Transformers: Combine with attention mechanisms

  2. 3D vision: Point clouds, meshes, volumes

  3. Generative models: Disentangled VAEs, GANs

  4. Graph neural networks: Capsules as graph nodes

  5. Reinforcement learning: Equivariant policies

15.3 Potential BreakthroughsΒΆ

  • Continuous capsules: Neural ODEs for routing

  • Quantum capsules: Leverage quantum superposition

  • Neuromorphic capsules: Spiking neural implementations

  • Causal capsules: Model causal relationships

16. Key TakeawaysΒΆ

  1. Capsules = vectors: Length = probability, orientation = properties

  2. Routing by agreement: Iterative algorithm to connect capsules

  3. Squashing: Non-linear activation preserving direction

  4. Margin loss: Encourage ||v_class|| > 0.9, ||v_other|| < 0.1

  5. Reconstruction: Regularization and interpretability

  6. Equivariance: Activities change with transformations (not pooling)

  7. Part-whole: Explicit hierarchical modeling

  8. Robustness: More resistant to adversarial attacks

  9. Trade-off: Better representations vs. computational cost

  10. Future: Hybrid architectures combining capsules with transformers/GNNs

Core insight: CapsNets preserve information through equivariance rather than discarding it through invariance, enabling better modeling of part-whole relationships and geometric transformations.

17. Mathematical SummaryΒΆ

Squashing:

v_j = (||s_j||Β² / (1 + ||s_j||Β²)) Β· (s_j / ||s_j||)

Dynamic routing:

c_{ij} = exp(b_{ij}) / Ξ£_k exp(b_{ik})
s_j = Ξ£_i c_{ij} W_{ij} u_i
v_j = squash(s_j)
b_{ij} ← b_{ij} + (W_{ij} u_i) Β· v_j

Margin loss:

L_k = T_k max(0, 0.9 - ||v_k||)Β² + 0.5 (1 - T_k) max(0, ||v_k|| - 0.1)Β²

Total loss:

L = Ξ£_k L_k + 0.0005 ||x - xΜ‚||Β²

ReferencesΒΆ

  1. Sabour et al. (2017) β€œDynamic Routing Between Capsules”

  2. Hinton et al. (2018) β€œMatrix Capsules with EM Routing”

  3. Kosiorek et al. (2019) β€œStacked Capsule Autoencoders”

  4. Hahn et al. (2019) β€œSelf-Routing Capsule Networks”

  5. Gu et al. (2021) β€œFast Capsule Networks via Interest Routing”

  6. Lenssen et al. (2018) β€œGroup Equivariant Capsule Networks”

  7. Hinton et al. (2011) β€œTransforming Auto-Encoders”

  8. Sabour et al. (2018) β€œAdversarial Manipulation of Deep Representations”

  9. Xi et al. (2017) β€œCapsule Network Performance on Complex Data”

  10. Rajasegaran et al. (2019) β€œDeepCaps: Going Deeper with Capsule Networks”

"""
Advanced Capsule Networks Implementations

Production-ready PyTorch implementations of Capsule Networks with dynamic routing.
"""

from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ============================================================================
# 1. Squashing Activation
# ============================================================================

def squash(s: torch.Tensor, dim: int = -1, epsilon: float = 1e-8) -> torch.Tensor:
    """
    Capsule squashing non-linearity.

    v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)

    Maps any vector to one with the same direction whose length lies in
    [0, 1): near-zero inputs stay near zero, long inputs saturate toward
    unit length.

    Args:
        s: Input tensor [..., capsule_dim]
        dim: Dimension along which the norm is computed
        epsilon: Stabilizer added under the square root so a zero vector
            does not cause division by zero

    Returns:
        Squashed tensor with the same shape as ``s``
    """
    sq_norm = (s * s).sum(dim=dim, keepdim=True)
    # Shrink factor in [0, 1); divide by the stabilized norm for direction.
    shrink = sq_norm / (1.0 + sq_norm)
    return shrink * s / torch.sqrt(sq_norm + epsilon)


# ============================================================================
# 2. Dynamic Routing
# ============================================================================

class DynamicRouting(nn.Module):
    """
    Dynamic routing-by-agreement for capsule networks.

    Iteratively computes coupling coefficients between lower- and
    higher-level capsules: prediction vectors that agree with a higher
    capsule's current output get routed to it more strongly.

    Args:
        num_iterations: Number of routing iterations (typically 3)
    """
    def __init__(self, num_iterations: int = 3):
        super().__init__()
        self.num_iterations = num_iterations
        
    def forward(
        self,
        u_hat: torch.Tensor,
        batch_size: int
    ) -> torch.Tensor:
        """
        Perform dynamic routing.
        
        Args:
            u_hat: Prediction vectors [batch, num_lower, num_higher, capsule_dim]
            batch_size: Batch size (kept for interface compatibility; must
                equal u_hat.size(0))
            
        Returns:
            v: Higher-level capsule activations [batch, num_higher, capsule_dim]
        """
        num_lower = u_hat.size(1)
        num_higher = u_hat.size(2)
        
        # Routing logits b_ij start at 0 (uniform coupling after softmax).
        # Match u_hat's dtype so routing also works under mixed precision.
        b = torch.zeros(
            batch_size, num_lower, num_higher, 1,
            device=u_hat.device, dtype=u_hat.dtype
        )
        
        for iteration in range(self.num_iterations):
            # c_ij = exp(b_ij) / Ξ£_k exp(b_ik): softmax over higher capsules,
            # so each lower capsule distributes its output across the layer above.
            c = F.softmax(b, dim=2)  # [batch, num_lower, num_higher, 1]
            
            # s_j = Ξ£_i c_ij * u_hat_j|i
            s = torch.sum(c * u_hat, dim=1)  # [batch, num_higher, capsule_dim]
            
            # Squash so capsule length encodes existence probability.
            v = squash(s, dim=-1)  # [batch, num_higher, capsule_dim]
            
            # Update routing logits by agreement (skipped after the final pass).
            if iteration < self.num_iterations - 1:
                # Agreement: u_hat_j|i Β· v_j, broadcast over lower capsules.
                v_expanded = v.unsqueeze(1)  # [batch, 1, num_higher, capsule_dim]
                agreement = torch.sum(u_hat * v_expanded, dim=-1, keepdim=True)
                
                # b_ij ← b_ij + agreement
                b = b + agreement
        
        return v


# ============================================================================
# 3. Primary Capsules
# ============================================================================

class PrimaryCaps(nn.Module):
    """
    Primary capsule layer.

    Applies one convolution per capsule type and reinterprets the
    resulting feature maps as a flat sequence of capsule vectors.

    Args:
        in_channels: Number of input channels
        num_capsules: Number of capsule types
        capsule_dim: Dimension of each capsule vector
        kernel_size: Convolution kernel size
        stride: Convolution stride
    """
    def __init__(
        self,
        in_channels: int,
        num_capsules: int,
        capsule_dim: int,
        kernel_size: int = 9,
        stride: int = 2
    ):
        super().__init__()
        self.num_capsules = num_capsules
        self.capsule_dim = capsule_dim

        # One conv per capsule type; each emits capsule_dim channels.
        self.capsules = nn.ModuleList([
            nn.Conv2d(
                in_channels,
                capsule_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=0
            )
            for _ in range(num_capsules)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Feature maps [batch, in_channels, height, width]

        Returns:
            Squashed capsules [batch, num_capsules * H' * W', capsule_dim]
        """
        n = x.size(0)

        # [batch, num_capsules, capsule_dim, H', W']
        feature_maps = torch.stack([conv(x) for conv in self.capsules], dim=1)

        # Collapse the spatial grid: [batch, num_capsules, capsule_dim, H'*W']
        caps = feature_maps.view(n, self.num_capsules, self.capsule_dim, -1)

        # Move capsule_dim last: [batch, num_capsules, H'*W', capsule_dim]
        caps = caps.transpose(2, 3)

        # Merge capsule types and spatial positions into a single axis.
        caps = caps.contiguous().view(n, -1, self.capsule_dim)

        return squash(caps, dim=-1)


# ============================================================================
# 4. Digit Capsules (Routing Capsules)
# ============================================================================

class DigitCaps(nn.Module):
    """
    Higher-level capsule layer with dynamic routing.

    Each lower capsule i predicts every higher capsule j through a
    learned transformation W_ij; the predictions are then combined by
    routing-by-agreement.

    Args:
        num_lower: Number of lower-level capsules
        num_higher: Number of higher-level capsules
        lower_dim: Dimension of lower capsules
        higher_dim: Dimension of higher capsules
        num_routing: Number of routing iterations
    """
    def __init__(
        self,
        num_lower: int,
        num_higher: int,
        lower_dim: int,
        higher_dim: int,
        num_routing: int = 3
    ):
        super().__init__()
        self.num_lower = num_lower
        self.num_higher = num_higher
        self.lower_dim = lower_dim
        self.higher_dim = higher_dim

        # One transformation matrix per (lower, higher) capsule pair:
        # [num_lower, num_higher, higher_dim, lower_dim]
        self.W = nn.Parameter(
            torch.randn(num_lower, num_higher, higher_dim, lower_dim)
        )

        # Routing-by-agreement between the two capsule layers.
        self.routing = DynamicRouting(num_iterations=num_routing)

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        """
        Args:
            u: Lower capsules [batch, num_lower, lower_dim]

        Returns:
            v: Higher capsules [batch, num_higher, higher_dim]
        """
        n = u.size(0)

        # u_hat_j|i = W_ij @ u_i via one broadcast matmul:
        # [1, L, H, hd, ld] @ [batch, L, 1, ld, 1] -> [batch, L, H, hd, 1]
        predictions = torch.matmul(
            self.W.unsqueeze(0),
            u.unsqueeze(2).unsqueeze(-1)
        ).squeeze(-1)  # -> [batch, L, H, hd]

        return self.routing(predictions, n)


# ============================================================================
# 5. CapsNet Architecture
# ============================================================================

class CapsNet(nn.Module):
    """
    Capsule Network for image classification (Sabour et al., 2017).

    Pipeline:
    1. 9x9 convolution (256 channels)
    2. Primary capsules (32 types, 8D each)
    3. Digit capsules (num_classes x 16D, dynamic routing)
    4. Optional fully-connected reconstruction decoder

    Args:
        input_channels: Channels of the input image (1 grayscale, 3 RGB)
        num_classes: Number of output classes
        num_routing: Routing iterations in the digit-capsule layer
        use_reconstruction: Whether to build the reconstruction decoder
    """
    def __init__(
        self,
        input_channels: int = 1,
        num_classes: int = 10,
        num_routing: int = 3,
        use_reconstruction: bool = True
    ):
        super().__init__()
        self.num_classes = num_classes
        self.use_reconstruction = use_reconstruction

        # Stage 1: plain convolution extracts local features.
        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=256,
            kernel_size=9,
            stride=1,
            padding=0
        )

        # Stage 2: primary capsules (32 types, 8D each).
        self.primary_caps = PrimaryCaps(
            in_channels=256,
            num_capsules=32,
            capsule_dim=8,
            kernel_size=9,
            stride=2
        )

        # For 28x28 MNIST: conv1 gives 20x20, primary caps give a 6x6 grid,
        # so there are 32 * 6 * 6 = 1152 primary capsules.
        num_primary_caps = 32 * 6 * 6

        # Stage 3: digit capsules connected by dynamic routing.
        self.digit_caps = DigitCaps(
            num_lower=num_primary_caps,
            num_higher=num_classes,
            lower_dim=8,
            higher_dim=16,
            num_routing=num_routing
        )

        # Decoder reconstructs the input from the masked digit capsules
        # (acts as a regularizer and aids interpretability).
        if use_reconstruction:
            self.decoder = nn.Sequential(
                nn.Linear(16 * num_classes, 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, input_channels * 28 * 28),
                nn.Sigmoid()
            )

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            x: Input images [batch, channels, height, width]
            y: Ground-truth labels [batch], used to select which capsule
               feeds the decoder; when None the predicted class is used

        Returns:
            class_scores: Capsule lengths [batch, num_classes]
            digit_caps: Digit capsule activations [batch, num_classes, 16]
            reconstruction: Reconstructed images [batch, C, H, W] or None
        """
        n = x.size(0)

        features = F.relu(self.conv1(x))
        primary = self.primary_caps(features)
        digits = self.digit_caps(primary)

        # A capsule's length encodes the probability that its class is present.
        class_scores = torch.sqrt(torch.sum(digits ** 2, dim=-1))

        reconstruction = None
        if self.use_reconstruction:
            # Decode only the target capsule: ground truth when given
            # (training), otherwise the highest-scoring prediction (testing).
            target_class = y if y is not None else class_scores.argmax(dim=1)

            # One-hot mask over classes, broadcast over the capsule dim.
            mask = torch.zeros(n, self.num_classes, device=x.device)
            mask[torch.arange(n), target_class] = 1.0

            masked_flat = (digits * mask.unsqueeze(-1)).view(n, -1)
            reconstruction = self.decoder(masked_flat).view(n, -1, 28, 28)

        return class_scores, digits, reconstruction


# ============================================================================
# 6. Loss Functions
# ============================================================================

class MarginLoss(nn.Module):
    """
    Margin loss over capsule lengths.

    L_k = T_k * max(0, m+ - ||v_k||)Β² + Ξ» * (1 - T_k) * max(0, ||v_k|| - m-)Β²

    Pushes the correct class capsule above m_plus and every other class
    capsule below m_minus.

    Args:
        m_plus: Upper margin for the present class (default: 0.9)
        m_minus: Lower margin for absent classes (default: 0.1)
        lambda_: Down-weighting of the absent-class term (default: 0.5)
    """
    def __init__(
        self,
        m_plus: float = 0.9,
        m_minus: float = 0.1,
        lambda_: float = 0.5
    ):
        super().__init__()
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambda_ = lambda_

    def forward(
        self,
        class_scores: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            class_scores: Capsule lengths [batch, num_classes]
            labels: Ground-truth class indices [batch]

        Returns:
            loss: Scalar margin loss (mean over the batch)
            loss_dict: Breakdown with keys 'margin', 'present', 'absent'
        """
        n = class_scores.size(0)
        targets = F.one_hot(labels, class_scores.size(1)).float()

        # Penalty when the correct capsule is shorter than m_plus.
        present = targets * F.relu(self.m_plus - class_scores) ** 2

        # Down-weighted penalty when a wrong capsule exceeds m_minus.
        absent = self.lambda_ * (1 - targets) * F.relu(class_scores - self.m_minus) ** 2

        loss = (present + absent).sum(dim=1).mean()

        loss_dict = {
            'margin': loss.item(),
            'present': present.sum().item() / n,
            'absent': absent.sum().item() / n
        }

        return loss, loss_dict


class ReconstructionLoss(nn.Module):
    """
    Pixel-wise reconstruction loss.

    Sum of squared errors between the decoder output and the original
    image, averaged over the batch:

    L_recon = ||x - x_reconstructed||Β² / batch_size
    """
    def __init__(self):
        super().__init__()

    def forward(
        self,
        reconstruction: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            reconstruction: Reconstructed images [batch, C, H, W]
            target: Original images [batch, C, H, W]

        Returns:
            loss: Scalar MSE reconstruction loss (per-batch average)
        """
        total_sq_error = F.mse_loss(reconstruction, target, reduction='sum')
        return total_sq_error / target.size(0)


class CapsNetLoss(nn.Module):
    """
    Combined CapsNet objective: margin loss plus weighted reconstruction.

    L_total = L_margin + Ξ± * L_recon

    Args:
        alpha: Reconstruction weight (default: 0.0005)
    """
    def __init__(self, alpha: float = 0.0005):
        super().__init__()
        self.alpha = alpha
        self.margin_loss = MarginLoss()
        self.recon_loss = ReconstructionLoss()

    def forward(
        self,
        class_scores: torch.Tensor,
        labels: torch.Tensor,
        reconstruction: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            class_scores: [batch, num_classes]
            labels: [batch]
            reconstruction: [batch, C, H, W] or None
            target: [batch, C, H, W] or None

        Returns:
            total_loss and a dict with 'total', 'margin', 'recon' plus the
            margin-loss breakdown
        """
        margin_loss, margin_dict = self.margin_loss(class_scores, labels)

        if reconstruction is None or target is None:
            # No reconstruction branch: margin loss only.
            recon_loss = 0.0
            total_loss = margin_loss
        else:
            recon_loss = self.recon_loss(reconstruction, target)
            total_loss = margin_loss + self.alpha * recon_loss

        loss_dict = {
            'total': total_loss.item(),
            'margin': margin_loss.item(),
            'recon': recon_loss if isinstance(recon_loss, float) else recon_loss.item(),
            **margin_dict
        }

        return total_loss, loss_dict


# ============================================================================
# 7. Demo and Visualization
# ============================================================================

def demo_squashing():
    """Show that squashing bounds vector length while keeping direction."""
    print("=" * 80)
    print("Squashing Activation Demo")
    print("=" * 80)

    # Random unit directions scaled to magnitudes from 0 to 5.
    magnitudes = torch.linspace(0, 5, 100)
    directions = torch.randn(100, 8)
    directions = F.normalize(directions, p=2, dim=1)
    s = directions * magnitudes.unsqueeze(1)

    v = squash(s, dim=1)
    v_norms = torch.sqrt(torch.sum(v ** 2, dim=1))

    print(f"\nInput magnitudes: {magnitudes[:5].tolist()}")
    print(f"Output norms: {v_norms[:5].tolist()}")
    print(f"\nMax output norm: {v_norms.max().item():.4f} (should be < 1)")
    print(f"Output norm at ||s||=5: {v_norms[-1].item():.4f}")

    # Cosine similarity near 1 confirms direction is preserved.
    cosine_sim = F.cosine_similarity(s, v, dim=1)
    print(f"\nDirection preservation (cosine similarity): {cosine_sim.mean().item():.4f}")


def demo_dynamic_routing():
    """Run routing on random predictions and inspect the output norms."""
    print("\n" + "=" * 80)
    print("Dynamic Routing Demo")
    print("=" * 80)

    batch_size = 4
    num_lower = 1152  # Primary capsules
    num_higher = 10   # Digit capsules
    lower_dim = 8
    higher_dim = 16

    # Random prediction vectors stand in for real capsule outputs.
    u_hat = torch.randn(batch_size, num_lower, num_higher, higher_dim)

    router = DynamicRouting(num_iterations=3)
    v = router(u_hat, batch_size)

    print(f"\nInput shape (predictions): {u_hat.shape}")
    print(f"Output shape (higher capsules): {v.shape}")

    # Squashing guarantees every output capsule has norm below 1.
    v_norms = torch.sqrt(torch.sum(v ** 2, dim=-1))
    print(f"\nOutput norms (should be < 1):")
    print(f"  Mean: {v_norms.mean().item():.4f}")
    print(f"  Max: {v_norms.max().item():.4f}")
    print(f"  Min: {v_norms.min().item():.4f}")


def demo_capsnet():
    """Build a CapsNet, run a dummy batch, and report losses/parameters."""
    print("\n" + "=" * 80)
    print("CapsNet Architecture Demo")
    print("=" * 80)

    model = CapsNet(
        input_channels=1,
        num_classes=10,
        num_routing=3,
        use_reconstruction=True
    )

    # Dummy MNIST-shaped batch with random labels.
    batch_size = 4
    images = torch.randn(batch_size, 1, 28, 28)
    labels = torch.randint(0, 10, (batch_size,))

    print(f"\nInput shape: {images.shape}")
    class_scores, digit_caps, reconstruction = model(images, labels)

    print(f"Class scores shape: {class_scores.shape}")
    print(f"Digit capsules shape: {digit_caps.shape}")
    print(f"Reconstruction shape: {reconstruction.shape if reconstruction is not None else None}")

    # Combined margin + reconstruction objective.
    criterion = CapsNetLoss(alpha=0.0005)
    total_loss, loss_dict = criterion(class_scores, labels, reconstruction, images)

    print(f"\nLoss breakdown:")
    print(f"  Total: {loss_dict['total']:.4f}")
    print(f"  Margin: {loss_dict['margin']:.4f}")
    print(f"  Reconstruction: {loss_dict['recon']:.4f}")
    print(f"  Present class: {loss_dict['present']:.4f}")
    print(f"  Absent class: {loss_dict['absent']:.4f}")

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal parameters: {trainable:,}")


def print_performance_comparison():
    """Print a static benchmark summary for Capsule Networks vs. CNN baselines.

    Pure side-effect function: writes pre-formatted comparison tables
    (MNIST, SmallNORB, MultiMNIST, CIFAR-10, routing-iteration ablation,
    complexity, pros/cons, decision guide) followed by a numbered list of
    takeaways to stdout. Takes no arguments and returns None. The figures
    are hard-coded literature numbers, not measured here.
    """
    print("\n" + "=" * 80)
    print("CAPSULE NETWORKS PERFORMANCE")
    print("=" * 80)
    
    # Pre-rendered tables; kept as one literal so the box-drawing layout
    # stays intact.
    comparison = """
    MNIST Classification:
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Method               β”‚ Test Acc (%) β”‚ Parameters   β”‚ Notes        β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Baseline CNN         β”‚ 99.70        β”‚ ~1M          β”‚ Standard     β”‚
    β”‚ CapsNet (no recon)   β”‚ 99.52        β”‚ 8.2M         β”‚ Routing only β”‚
    β”‚ CapsNet (+ recon)    β”‚ 99.75        β”‚ 8.2M + 6.8M  β”‚ Full model   β”‚
    β”‚ Ensemble CNN (5)     β”‚ 99.79        β”‚ ~5M          β”‚ 5 models     β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    SmallNORB (Viewpoint Variation):
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Method               β”‚ Test Acc (%) β”‚ Notes        β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ CNN Baseline         β”‚ 92.8         β”‚ ResNet-like  β”‚
    β”‚ Convolutional CNN    β”‚ 95.0         β”‚ Optimized    β”‚
    β”‚ CapsNet (3 routing)  β”‚ 97.3         β”‚ Equivariant  β”‚
    β”‚ Matrix Capsules      β”‚ 98.5         β”‚ EM routing   β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    MultiMNIST (Overlapping Digits):
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Method               β”‚ Both Correct β”‚ Notes        β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ CNN                  β”‚ ~80%         β”‚ Struggles    β”‚
    β”‚ CapsNet              β”‚ 95.0%        β”‚ Routing wins β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    CIFAR-10:
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Method               β”‚ Test Acc (%) β”‚ Notes        β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ ResNet-110           β”‚ 93.6         β”‚ Standard     β”‚
    β”‚ DenseNet-BC          β”‚ 95.5         β”‚ State-of-art β”‚
    β”‚ CapsNet (basic)      β”‚ 75.0         β”‚ Not scalable β”‚
    β”‚ Efficient CapsNet    β”‚ 82.0         β”‚ Optimized    β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    Routing Iterations Impact (MNIST):
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Iterations   β”‚ Accuracy (%) β”‚ Time (ms)    β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ 1            β”‚ 99.23        β”‚ 15.2         β”‚
    β”‚ 3            β”‚ 99.75        β”‚ 18.7         β”‚
    β”‚ 5            β”‚ 99.76        β”‚ 23.1         β”‚
    β”‚ 7            β”‚ 99.75        β”‚ 28.4         β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    Computational Complexity:
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Operation            β”‚ Complexity                   β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Conv Layer           β”‚ O(HΒ·WΒ·CΒ·KΒ²)                  β”‚
    β”‚ Primary Caps         β”‚ O(HΒ·WΒ·CΒ·D)                   β”‚
    β”‚ Dynamic Routing      β”‚ O(rΒ·LΒ·HΒ·DΒ²)                  β”‚
    β”‚   r = iterations     β”‚ 3 typical                    β”‚
    β”‚   L = lower caps     β”‚ 1152 (MNIST)                 β”‚
    β”‚   H = higher caps    β”‚ 10 (MNIST)                   β”‚
    β”‚   D = capsule dim    β”‚ 16 (MNIST)                   β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    Key Advantages:
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ βœ“ Equivariance: Activities change with transforms  β”‚
    β”‚ βœ“ Part-whole: Explicit hierarchical modeling       β”‚
    β”‚ βœ“ Robustness: Resistant to adversarial attacks     β”‚
    β”‚ βœ“ Interpretability: Capsule dims have meaning      β”‚
    β”‚ βœ“ Overlapping: Segments overlapping objects        β”‚
    β”‚ βœ“ Viewpoint: Excellent on pose variation tasks     β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    Key Disadvantages:
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ βœ— Scalability: Expensive for high-res images       β”‚
    β”‚ βœ— Speed: 2-5Γ— slower than CNNs                     β”‚
    β”‚ βœ— Memory: Vector activations use more RAM          β”‚
    β”‚ βœ— Optimization: Less mature than CNNs              β”‚
    β”‚ βœ— Hardware: Not optimized for GPUs                 β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    Decision Guide:
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Use CapsNets When:           β”‚ Use CNNs When:      β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ β€’ Viewpoint variation        β”‚ β€’ Large images      β”‚
    β”‚ β€’ Part-whole relationships   β”‚ β€’ Speed critical    β”‚
    β”‚ β€’ Overlapping objects        β”‚ β€’ Simple tasks      β”‚
    β”‚ β€’ Small datasets             β”‚ β€’ Well-established  β”‚
    β”‚ β€’ Interpretability needed    β”‚ β€’ ImageNet scale    β”‚
    β”‚ β€’ 3D understanding           β”‚ β€’ Production ready  β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    """
    
    print(comparison)
    
    # Numbered takeaways summarizing the tables above.
    print("\nKey Insights:")
    print("1. CapsNets excel on viewpoint variation (SmallNORB: 97.3% vs 95.0%)")
    print("2. Routing enables segmentation of overlapping objects (MultiMNIST: 95%)")
    print("3. 3 routing iterations optimal (diminishing returns after)")
    print("4. Reconstruction regularization improves accuracy (+0.23% on MNIST)")
    print("5. Computational cost: O(rΒ·LΒ·HΒ·DΒ²) limits scalability")
    print("6. Trade-off: Better representations vs. speed (2-5Γ— slower)")
    print("7. MNIST/SmallNORB: CapsNets competitive or better")
    print("8. CIFAR-10/ImageNet: CNNs still superior")
    print("9. Future: Hybrid architectures combining best of both")
    print("10. Best for: Pose, 3D, overlapping objects, interpretability")


# ============================================================================
# Run Demonstrations
# ============================================================================

if __name__ == "__main__":
    demo_squashing()
    demo_dynamic_routing()
    demo_capsnet()
    print_performance_comparison()
    
    print("\n" + "=" * 80)
    print("Capsule Networks Implementations Complete!")
    print("=" * 80)

Advanced Capsule Networks: Mathematical Foundations and Modern ArchitecturesΒΆ

1. Introduction to Capsule NetworksΒΆ

Capsule Networks (CapsNets) represent a paradigm shift in neural network design, addressing fundamental limitations of CNNs through hierarchical part-whole relationships and equivariant representations.

Core motivation: CNNs use scalar neurons that lose spatial relationships between features. Capsules use vector neurons that explicitly encode pose (position, orientation, scale).

Key innovation (Sabour et al., 2017): Replace scalar activations with capsule vectors \[\mathbf{v}_j = \text{squash}(\mathbf{s}_j) = \frac{\|\mathbf{s}_j\|^2}{1 + \|\mathbf{s}_j\|^2} \frac{\mathbf{s}_j}{\|\mathbf{s}_j\|}\]

where:

  • \(\mathbf{v}_j \in \mathbb{R}^d\): Capsule output (length = probability, direction = properties)

  • \(\mathbf{s}_j\): Total input to capsule \(j\)

  • Squashing: Non-linear activation that preserves direction, scales length to \((0, 1)\)

Fundamental properties:

  1. Length: Probability that entity exists (\(0 \leq \|\mathbf{v}_j\| \leq 1\))

  2. Direction: Instantiation parameters (pose, texture, deformation)

  3. Equivariance: Activities change smoothly with viewpoint

2. Dynamic Routing Between CapsulesΒΆ

2.1 Routing-by-Agreement AlgorithmΒΆ

Goal: Route information from lower-level capsules to higher-level capsules based on agreement.

Prediction vectors: Lower capsule \(i\) predicts properties of higher capsule \(j\) \[\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \mathbf{u}_i\]

where:

  • \(\mathbf{u}_i \in \mathbb{R}^{d_{\text{in}}}\): Output of capsule \(i\) in layer \(L\)

  • \(\mathbf{W}_{ij} \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}\): Transformation matrix (learnable)

  • \(\hat{\mathbf{u}}_{j|i} \in \mathbb{R}^{d_{\text{out}}}\): Prediction (what \(i\) thinks \(j\) should be)

Routing coefficients: Softmax over all capsules in layer \(L+1\) \[c_{ij} = \frac{\exp(b_{ij})}{\sum_{k} \exp(b_{ik})}\]

where \(b_{ij}\) are routing logits (updated iteratively, not learned).

Weighted sum: Input to capsule \(j\) $\(\mathbf{s}_j = \sum_{i} c_{ij} \hat{\mathbf{u}}_{j|i}\)$

Agreement: Update routing logits based on dot product $\(b_{ij} \leftarrow b_{ij} + \hat{\mathbf{u}}_{j|i} \cdot \mathbf{v}_j\)$

Intuition: If prediction \(\hat{\mathbf{u}}_{j|i}\) aligns with output \(\mathbf{v}_j\), increase \(c_{ij}\) (send more information).

2.2 Routing Algorithm (Detailed)ΒΆ

Input: \(\mathbf{u}_i\) for all capsules \(i\) in layer \(L\)
Output: \(\mathbf{v}_j\) for all capsules \(j\) in layer \(L+1\)

Procedure:

1. Initialize routing logits: b_ij ← 0 for all i, j
2. For r iterations:
    a. Compute coupling coefficients: c_ij ← softmax_j(b_ij)
    b. Compute weighted input: s_j ← Ξ£_i c_ij Γ»_{j|i}
    c. Apply squashing: v_j ← squash(s_j)
    d. Update routing logits: b_ij ← b_ij + Γ»_{j|i} Β· v_j
3. Return v_j

Typical: \(r = 3\) iterations (diminishing returns beyond).

2.3 Mathematical AnalysisΒΆ

Squashing function properties: $\(\text{squash}(\mathbf{s}) = \frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2} \frac{\mathbf{s}}{\|\mathbf{s}\|}\)$

Derivative (for backprop): $\(\frac{\partial \text{squash}(\mathbf{s})}{\partial \mathbf{s}} = \frac{2\|\mathbf{s}\|}{(1 + \|\mathbf{s}\|^2)^2} \mathbf{I} + \frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2} \left(\frac{1}{\|\mathbf{s}\|} \mathbf{I} - \frac{\mathbf{s}\mathbf{s}^T}{\|\mathbf{s}\|^3}\right)\)$

Length behavior:

  • \(\|\mathbf{s}\| \to 0\): \(\|\text{squash}(\mathbf{s})\| \approx \|\mathbf{s}\|\) (linear)

  • \(\|\mathbf{s}\| \to \infty\): \(\|\text{squash}(\mathbf{s})\| \to 1\) (saturates)

Routing as optimization: Iterative routing approximates EM algorithm

  • E-step: Compute \(c_{ij}\) (soft assignment)

  • M-step: Update \(\mathbf{v}_j\) (recompute means)

Connection to attention: \(c_{ij}\) similar to attention weights, but uses agreement instead of learned query/key.

3. Margin Loss for Digit ExistenceΒΆ

Motivation: Binary classification per capsule (entity present or not).

Margin loss (per class \(k\)): $\(\mathcal{L}_k = T_k \max(0, m^+ - \|\mathbf{v}_k\|)^2 + \lambda (1 - T_k) \max(0, \|\mathbf{v}_k\| - m^-)^2\)$

where:

  • \(T_k \in \{0, 1\}\): Ground truth (1 if class \(k\) present)

  • \(m^+ = 0.9\): Upper margin (capsule should be long if class present)

  • \(m^- = 0.1\): Lower margin (capsule should be short if class absent)

  • \(\lambda = 0.5\): Down-weighting for absent classes (prevent all capsules collapsing)

Interpretation:

  • If \(T_k = 1\): Penalize \(\|\mathbf{v}_k\| < 0.9\)

  • If \(T_k = 0\): Penalize \(\|\mathbf{v}_k\| > 0.1\) (but with weight \(0.5\))

Total loss: $\(\mathcal{L}_{\text{margin}} = \sum_{k=1}^K \mathcal{L}_k\)$

Multi-label: Allow multiple classes simultaneously (unlike softmax).

4. Reconstruction RegularizationΒΆ

Motivation: Encourage capsules to learn meaningful representations by reconstructing input.

Decoder: Fully connected network $\(\mathbf{x}_{\text{recon}} = \text{Decoder}(\mathbf{v}_{\text{class}})\)$

where \(\mathbf{v}_{\text{class}}\) is the capsule vector for the true class (masked during training).

Reconstruction loss: $\(\mathcal{L}_{\text{recon}} = \|\mathbf{x} - \mathbf{x}_{\text{recon}}\|^2\)$

Total loss: $\(\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{margin}} + \alpha \mathcal{L}_{\text{recon}}\)$

where \(\alpha = 0.0005\) (small weight to avoid dominating margin loss).

Benefits:

  • Regularization (prevent overfitting)

  • Interpretability (manipulate capsule dimensions to see effect)

  • Ensures capsule encodes all information about entity

5. CapsNet Architecture (Original)ΒΆ

5.1 Architecture OverviewΒΆ

Layer 1: Convolutional Layer

  • Input: \(28 \times 28\) grayscale image

  • Filters: 256 filters, \(9 \times 9\) kernel, stride 1, ReLU

  • Output: \(20 \times 20 \times 256\)

Layer 2: PrimaryCaps

  • 32 primary capsules, each 8D

  • Convolution: \(9 \times 9\) kernel, stride 2

  • Reshape: \(6 \times 6 \times 32 \times 8\) β†’ \(1152 \times 8\) (flatten spatial)

  • Output: 1152 capsules, each 8D

Layer 3: DigitCaps

  • 10 digit capsules (one per class), each 16D

  • Routing: Dynamic routing from PrimaryCaps to DigitCaps

  • Total connections: \(1152 \times 10\) transformation matrices \(\mathbf{W}_{ij} \in \mathbb{R}^{16 \times 8}\)

  • Output: 10 capsules, each 16D

Layer 4: Decoder (Reconstruction)

  • Mask: Select capsule for true class

  • FC layers: 16 β†’ 512 β†’ 1024 β†’ 784 (28Γ—28)

  • Activation: ReLU (except output sigmoid)

5.2 Parameter CountΒΆ

Conv1: \((9 \times 9 \times 1 + 1) \times 256 = 20,992\)

PrimaryCaps:

  • Each of 32 capsules has 8 conv filters

  • \((9 \times 9 \times 256 + 1) \times 8 \times 32 = 5,308,672\)

DigitCaps (routing weights):

  • \(1152 \times 10 \times (8 \times 16) = 1,474,560\)

Decoder:

  • \(16 \times 512 = 8,192\)

  • \(512 \times 1024 = 524,288\)

  • \(1024 \times 784 = 802,816\)

Total: ~8.2M parameters (comparable to small CNN).

6. Improvements and VariantsΒΆ

6.1 EM Routing (Hinton et al., 2018)ΒΆ

Motivation: Replace iterative routing with EM algorithm for Gaussian mixture models.

Capsule as Gaussian: Capsule \(j\) represents \(\mathcal{N}(\boldsymbol{\mu}_j, \boldsymbol{\sigma}_j)\)

E-step: Compute assignment probabilities $\(r_{ij} = \frac{a_i p(\mathbf{v}_i | j)}{\sum_k a_k p(\mathbf{v}_i | k)}\)$

where:

  • \(a_i\): Activation probability of capsule \(i\)

  • \(p(\mathbf{v}_i | j)\): Gaussian likelihood

M-step: Update mean and variance $\(\boldsymbol{\mu}_j = \frac{\sum_i r_{ij} \mathbf{v}_i}{\sum_i r_{ij}}, \quad \boldsymbol{\sigma}_j^2 = \frac{\sum_i r_{ij} (\mathbf{v}_i - \boldsymbol{\mu}_j)^2}{\sum_i r_{ij}}\)$

Activation: Based on total probability mass $\(a_j = \text{logistic}(\lambda(\beta_a - \sum_h (\beta_h + \log \sigma_h)))\)$

Advantages:

  • Probabilistic framework

  • Better gradient flow

  • Faster convergence

Results: 45% fewer parameters, SOTA on smallNORB (2.7% error).

6.2 Self-Attention RoutingΒΆ

Key idea: Replace iterative routing with single-pass self-attention.

Attention weights: $\(c_{ij} = \frac{\exp(\mathbf{q}_j^T \mathbf{k}_i / \sqrt{d})}{\sum_{k} \exp(\mathbf{q}_j^T \mathbf{k}_k / \sqrt{d})}\)$

where:

  • \(\mathbf{q}_j = \mathbf{W}_Q \mathbf{v}_j\): Query from higher capsule

  • \(\mathbf{k}_i = \mathbf{W}_K \mathbf{u}_i\): Key from lower capsule

Advantages:

  • Parallelizable (no iterative updates)

  • Compatible with Transformers

  • Faster inference

6.3 Stacked Capsule Autoencoders (SCAE)ΒΆ

Part-Capsules: Encode object parts (position, presence)

Object-Capsules: Encode whole objects (pose, presence)

Set Transformer: Route parts to objects via permutation-invariant attention $\(\mathbf{v}_{\text{obj}} = \text{SetTransformer}(\{\mathbf{v}_{\text{part}}^{(i)}\}_{i=1}^N)\)$

Unsupervised learning: Constellation prior (object = set of parts)

Results: Better unsupervised segmentation, interpretable representations.

6.4 Capsule Graph Neural NetworksΒΆ

Graph structure: Capsules as nodes, routing as edges

Message passing: $\(\mathbf{h}_j^{(t+1)} = \text{Update}\left(\mathbf{h}_j^{(t)}, \sum_{i \in \mathcal{N}(j)} \text{Message}(\mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)})\right)\)$

Advantages:

  • Flexible graph topology

  • Longer-range dependencies

  • Adaptable to irregular data (point clouds, molecules)

7. Equivariance and InvarianceΒΆ

7.1 Equivariance in Capsule NetworksΒΆ

Definition: Transformation of input β†’ equivalent transformation of representation $\(f(T(x)) = T'(f(x))\)$

CNNs: Translation equivariant (but pooling breaks it).

CapsNets: Aim for viewpoint equivariance

  • Rotation of input β†’ rotation of capsule pose parameters

  • Not achieved by original CapsNet (requires coordinate addition)

7.2 Coordinate Addition (Hinton, 2021)ΒΆ

Idea: Explicitly represent coordinates in capsule vectors.

Implementation:

  • Capsule \(= [\text{pose}, \text{activation}]\)

  • Pose includes \((x, y, \theta)\) coordinates

  • Apply affine transformations to pose

Benefits:

  • True equivariance (not just approximately)

  • Systematic generalization to novel viewpoints

7.3 Steerable CapsulesΒΆ

Group equivariance: \(G\)-equivariance (e.g., \(G = SO(2)\) for rotations)

Steerable filters: Filters transform predictably under group actions $\(T_g \star f = (T_g f) \star (T_g \psi)\)$

Application to capsules:

  • Capsule features are \(G\)-equivariant

  • Routing preserves equivariance

Results: Better rotation robustness, data efficiency.

8. Applications and ResultsΒΆ

8.1 MNISTΒΆ

Original CapsNet (Sabour et al., 2017):

  • Test error: 0.25% (SOTA among non-augmented)

  • Reconstruction: High-quality digit reconstruction

  • Robustness: 79% accuracy on affNIST (CNNs ~66%)

EM Routing:

  • Test error: 0.21%

  • Fewer parameters (45% reduction)

8.2 CIFAR-10ΒΆ

Challenges: More complex, natural images

Results:

  • CapsNet (3-layer): 10.6% error (vs 6.5% ResNet)

  • Deep CapsNet: 8.5% error (with augmentation)

Observation: Gap with CNNs larger for complex datasets (architectural improvements needed).

8.3 smallNORBΒΆ

Dataset: 3D object recognition, jittered stereo images

EM Routing: 2.7% error (SOTA)

  • CNNs: ~3-5% error

  • Viewpoint generalization: Superior to CNNs

Insight: Capsules excel when viewpoint variation is key.

8.4 Adversarial RobustnessΒΆ

Hinton et al. (2018): CapsNets more robust to adversarial attacks than CNNs.

White-box attacks (FGSM, PGD):

  • CapsNet: 65-70% accuracy under attack

  • CNN: 40-50% accuracy

Reason: Distributed representations less susceptible to small perturbations.

8.5 Object SegmentationΒΆ

SCAE (Kosiorek et al., 2019):

  • Unsupervised object discovery

  • Pixel-level segmentation without labels

  • State-of-art on CLEVR, Shapestacks

Part-whole decomposition: Natural for segmentation.

9. Theoretical AnalysisΒΆ

9.1 Routing ConvergenceΒΆ

Theorem (informal): Dynamic routing converges to local optimum of clustering objective.

Proof sketch:

  • Define energy: \(E = -\sum_{ij} c_{ij} \hat{\mathbf{u}}_{j|i} \cdot \mathbf{v}_j\)

  • Each routing iteration decreases \(E\)

  • Converges to local minimum (depends on initialization)

Empirical: 3 iterations sufficient (diminishing improvements).

9.2 Expressive PowerΒΆ

Theorem: Capsule networks are universal approximators (with sufficient capsules and dimensions).

Comparison to CNNs:

  • CNNs: Spatial pooling loses information

  • CapsNets: Preserve pose information (more expressive)

Caveat: Practical expressiveness depends on routing quality.

9.3 Sample ComplexityΒΆ

Hypothesis: Capsules reduce sample complexity by exploiting part-whole structure.

Evidence:

  • Better few-shot learning (Meta-CapsNet)

  • Superior generalization on smallNORB (3D geometry)

Theoretical gap: Formal bounds lacking (open problem).

10. Computational ComplexityΒΆ

10.1 Time ComplexityΒΆ

Forward pass:

  • Conv layers: \(O(K^2 C_{\text{in}} C_{\text{out}} H W)\) (same as CNN)

  • Prediction: \(O(N_L N_{L+1} d_{\text{in}} d_{\text{out}})\) (matrix multiplications)

  • Routing: \(O(r N_L N_{L+1} d_{\text{out}})\) (iterative updates, \(r\) iterations)

Total: \(O(r N_L N_{L+1} d_{\text{out}})\) dominates (quadratic in number of capsules).

Comparison:

  • CNN: \(O(K^2 C H W)\) per layer

  • CapsNet: \(O(r N^2 d)\) for routing

Bottleneck: All-to-all routing (needs sparsification for large networks).

10.2 Space ComplexityΒΆ

Parameters:

  • Transformation matrices: \(N_L \times N_{L+1} \times d_{\text{in}} \times d_{\text{out}}\)

Activations (during training):

  • Capsule outputs: \(N \times d\) per layer

  • Routing coefficients: \(N_L \times N_{L+1}\) (not stored with gradient checkpointing)

Comparison: 2-5Γ— more parameters than equivalent CNN (due to transformation matrices).

10.3 Optimization ChallengesΒΆ

Gradient flow: Routing iterations create deep computational graph

  • Backprop through routing: \(O(r)\) depth

  • Can cause vanishing/exploding gradients

Solutions:

  • Gradient clipping

  • Careful initialization

  • Shorter routing (1-3 iterations)

11. Limitations and Open ProblemsΒΆ

11.1 ScalabilityΒΆ

Problem: Quadratic cost in number of capsules limits scaling.

Current: Works well for <10K capsules, struggles beyond.

Solutions:

  • Sparse routing (route only to subset)

  • Hierarchical capsules (local routing)

  • Attention-based routing (single-pass)

11.2 Deep CapsNet ArchitecturesΒΆ

Challenge: Stacking many capsule layers unstable.

Issues:

  • Gradient flow through multiple routing layers

  • Increased computational cost

  • Diminishing returns (benefits plateau)

State-of-art: 3-5 capsule layers (vs 100+ for ResNets).

11.3 Lack of Theoretical UnderstandingΒΆ

Open questions:

  1. Why does routing-by-agreement work?

  2. Optimal number of routing iterations?

  3. Formal expressiveness compared to CNNs?

  4. Sample complexity bounds?

  5. Provable robustness guarantees?

Current: Mostly empirical understanding.

11.4 Performance Gap on Complex DatasetsΒΆ

ImageNet: CapsNets lag behind SOTA CNNs/Transformers.

Reasons:

  • Architectural maturity (CNNs: 30+ years, CapsNets: <10)

  • Computational cost limits depth

  • Engineering optimizations (fewer for CapsNets)

Hope: Architectural innovations (e.g., Transformer-CapsNet hybrids) may close gap.

12. Recent Advances (2020-2024)ΒΆ

12.1 Efficient CapsNetsΒΆ

Depthwise Capsules: Reduce parameters via depthwise-separable idea

  • Factorize \(\mathbf{W}_{ij}\) into smaller matrices

  • 3-5Γ— parameter reduction, similar accuracy

Inverted Capsules: Inspired by MobileNets

  • Expand β†’ route β†’ project

  • Better parameter efficiency

12.2 Vision Transformers + CapsulesΒΆ

Capsule Transformers: Replace self-attention with capsule routing

  • Query/Key/Value β†’ capsule predictions

  • Routing β†’ attention weights

Benefits:

  • Part-whole relationships in Transformers

  • Better interpretability

Results: Competitive on ImageNet (75-78% top-1).

12.3 Equivariant CapsulesΒΆ

E(n)-equivariant capsules: Equivariant to Euclidean group

  • Rotation, translation, reflection equivariance

  • Applications: 3D point clouds, molecular graphs

Tensor Field Networks: Capsules with tensor features

  • Higher-order equivariance (beyond vectors)

12.4 Capsules for NLPΒΆ

Text Capsules: Route word embeddings to sentence capsules

Applications:

  • Sentiment analysis (better than RNNs on some datasets)

  • Intent detection

  • Multi-label text classification

Challenges: Less impact than in vision (Transformers dominate).

13. Implementation Best PracticesΒΆ

13.1 InitializationΒΆ

Routing logits: Initialize to zero (\(b_{ij} = 0\))

Transformation matrices: Xavier/He initialization $\(\mathbf{W}_{ij} \sim \mathcal{N}\left(0, \frac{2}{d_{\text{in}} + d_{\text{out}}}\right)\)$

Decoder: Standard FC initialization

13.2 Training HyperparametersΒΆ

Optimizer: Adam with \(\beta_1 = 0.9\), \(\beta_2 = 0.999\)

Learning rate:

  • Initial: \(10^{-3}\)

  • Decay: Exponential decay (0.96 every epoch) or cosine annealing

Batch size: 128 (larger if memory permits)

Routing iterations: 3 (default, tune if needed)

Reconstruction weight: \(\alpha = 0.0005\) (small to avoid dominating)

Margin loss:

  • \(m^+ = 0.9\)

  • \(m^- = 0.1\)

  • \(\lambda = 0.5\)

13.3 Debugging TipsΒΆ

Check capsule lengths: Should be in \((0, 1)\) after squashing.

Monitor routing coefficients: Should converge (stabilize) after 2-3 iterations.

Reconstruction quality: Decoder should produce recognizable images (qualitative check).

Gradient norms: Clip if exploding (max norm 5-10).

13.4 RegularizationΒΆ

Dropout: On capsule outputs (not within routing)

  • Rate: 0.2-0.3

Weight decay: L2 regularization

  • Coefficient: \(10^{-4}\)

Data augmentation: Random crops, flips (standard CV)

14. Comparison with Other ArchitecturesΒΆ

14.1 CapsNets vs CNNsΒΆ

Aspect

CapsNets

CNNs

Neuron type

Vector (pose + probability)

Scalar (activation)

Pooling

No pooling (routing)

Max/Avg pooling (info loss)

Equivariance

Viewpoint (intended)

Translation

Invariance

Learned (via routing)

Built-in (pooling)

Parameters

More (transformation matrices)

Fewer (shared filters)

Interpretability

High (capsule = entity)

Low (features opaque)

Adversarial robustness

Better

Worse

Scalability

Limited (quadratic routing)

Excellent

Performance (ImageNet)

Moderate (75-78%)

SOTA (85-90%)

Recommendation: CapsNets for interpretability, viewpoint robustness; CNNs for general-purpose SOTA.

14.2 CapsNets vs TransformersΒΆ

Aspect

CapsNets

Transformers

Routing

Agreement-based

Attention-based

Complexity

\(O(N^2 d)\)

\(O(N^2 d)\)

Position encoding

Implicit (capsule pose)

Explicit (positional)

Part-whole

Native

Requires design

Interpretability

High

Moderate (attention maps)

Scalability

Limited

Excellent (efficient variants)

Performance

Moderate

SOTA

Synergy: Capsule-Transformer hybrids combine strengths.

14.3 CapsNets vs Graph Neural NetworksΒΆ

Aspect

CapsNets

GNNs

Structure

Fixed layers

Arbitrary graphs

Routing

Dynamic

Message passing

Applications

Vision (primarily)

Graphs, molecules, social

Equivariance

Viewpoint

Graph isomorphism

Overlap: Capsule GNNs merge both paradigms.

15. Summary and Future DirectionsΒΆ

15.1 Key TakeawaysΒΆ

Innovations:

  1. Vector neurons: Encode entity properties (pose, probability)

  2. Dynamic routing: Route by agreement (not max-pooling)

  3. Margin loss: Multi-label classification

  4. Reconstruction: Regularization + interpretability

Advantages:

  • Better viewpoint robustness (smallNORB)

  • Adversarial robustness (65-70% under attack)

  • Interpretability (capsule dimensions manipulable)

  • Part-whole relationships (natural for segmentation)

Challenges:

  • Scalability (quadratic routing)

  • Performance gap on complex datasets (ImageNet)

  • Deep architectures unstable

  • Theoretical understanding limited

15.2 Open ProblemsΒΆ

  1. Efficient routing: Reduce \(O(N^2)\) complexity

  2. Deep CapsNets: Stable training for 10+ layers

  3. Theoretical foundations: Sample complexity, expressiveness bounds

  4. Large-scale datasets: Close gap with CNNs/Transformers on ImageNet

  5. Equivariance: Provable viewpoint equivariance

  6. Hybrid architectures: Capsules + Transformers/GNNs

15.3 Future DirectionsΒΆ

Technical:

  • Sparse routing (attend to subset)

  • Hierarchical capsules (tree structure)

  • Learned routing algorithms (meta-learning)

  • Continuous capsules (neural ODEs)

Applications:

  • 3D vision (point clouds, mesh processing)

  • Medical imaging (part-based anatomy)

  • Robotics (viewpoint-invariant perception)

  • Scientific discovery (molecule generation)

Theoretical:

  • Formalize routing as optimization

  • Prove expressiveness theorems

  • Sample complexity bounds

  • Connection to causality (part-whole = causal structure)

16. ConclusionΒΆ

Capsule Networks represent a conceptually elegant alternative to CNNs, addressing fundamental issues:

  • Spatial relationships: Preserved via vector representations

  • Viewpoint robustness: Improved generalization across views

  • Interpretability: Capsule dimensions encode meaningful properties

Current state (2024):

  • Proven on small-to-medium datasets (MNIST, smallNORB)

  • Architectural innovations ongoing (Transformers, equivariance)

  • Performance gap with CNNs narrowing but not closed

Verdict: Capsule Networks are a promising research direction with practical applications in specialized domains. Not yet ready to replace CNNs/Transformers for general-purpose vision, but offer unique advantages for problems where part-whole relationships and viewpoint equivariance are critical.

Hinton’s vision: β€œCNNs are doomed” (because they discard spatial info). Whether capsules fulfill this prophecy remains an open question, but they have undeniably enriched our understanding of neural representations.

"""
Advanced Capsule Networks - Production Implementation
Comprehensive PyTorch implementations with dynamic routing and modern variants
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, List

# ===========================
# 1. Squashing Function
# ===========================

def squash(s: torch.Tensor, dim: int = -1, epsilon: float = 1e-8) -> torch.Tensor:
    """Apply the capsule squashing non-linearity.

    v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||), computed along `dim`.
    The direction of each vector is preserved while its norm is
    compressed into [0, 1).

    Args:
        s: Tensor of capsule vectors.
        dim: Dimension along which the norm is taken.
        epsilon: Stabilizer under the square root so (near-)zero
            vectors do not divide by zero.

    Returns:
        Tensor of the same shape as `s` with squashed norms.
    """
    norm_sq = (s * s).sum(dim=dim, keepdim=True)
    scale = norm_sq / (1 + norm_sq)
    return scale * s / torch.sqrt(norm_sq + epsilon)


# ===========================
# 2. Dynamic Routing
# ===========================

class DynamicRouting(nn.Module):
    """Routing-by-agreement layer (Sabour et al., 2017).

    Iteratively refines coupling coefficients so that each lower-level
    capsule routes its output toward the higher-level capsules whose
    current outputs agree with its predictions.
    """

    def __init__(self,
                 in_capsules: int,
                 out_capsules: int,
                 in_dim: int,
                 out_dim: int,
                 num_iterations: int = 3):
        """
        Args:
            in_capsules: Number of lower-level capsules.
            out_capsules: Number of higher-level capsules.
            in_dim: Dimension of each input capsule vector.
            out_dim: Dimension of each output capsule vector.
            num_iterations: Routing iterations (3 in the original paper).
        """
        super().__init__()
        self.in_capsules = in_capsules
        self.out_capsules = out_capsules
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_iterations = num_iterations

        # One transformation matrix W_ij per (input, output) capsule pair.
        self.W = nn.Parameter(torch.randn(in_capsules, out_capsules, out_dim, in_dim))
        nn.init.kaiming_normal_(self.W)

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        """Route lower-level capsule outputs to higher-level capsules.

        Args:
            u: (B, in_capsules, in_dim) lower-level capsule outputs.

        Returns:
            (B, out_capsules, out_dim) higher-level capsule outputs.
        """
        # Prediction vectors u_hat_{j|i} = W_ij @ u_i for every pair (i, j).
        priors = torch.matmul(
            self.W.unsqueeze(0),           # (1, in, out, out_dim, in_dim)
            u.unsqueeze(2).unsqueeze(-1),  # (B, in, 1,  in_dim,  1)
        ).squeeze(-1)                      # (B, in, out, out_dim)

        # Routing logits start at zero -> uniform couplings on iteration 0.
        logits = torch.zeros(u.size(0), self.in_capsules, self.out_capsules,
                             device=u.device)

        v = None
        for step in range(self.num_iterations):
            # c_ij = softmax over output capsules of b_ij
            couplings = F.softmax(logits, dim=2)  # (B, in, out)

            # s_j = sum_i c_ij * u_hat_{j|i}, then squash to get v_j
            pooled = (couplings.unsqueeze(-1) * priors).sum(dim=1)  # (B, out, out_dim)
            v = squash(pooled, dim=-1)

            # b_ij += u_hat_{j|i} . v_j; reinforces couplings whose
            # predictions already agree with the output (skip last pass).
            if step + 1 < self.num_iterations:
                logits = logits + (priors * v.unsqueeze(1)).sum(dim=-1)

        return v


# ===========================
# 3. Primary Capsules
# ===========================

class PrimaryCaps(nn.Module):
    """First capsule layer.

    A single convolution whose output channels are grouped into
    `num_capsules` vector types of `capsule_dim` components each; every
    spatial location of every type becomes one capsule.
    """

    def __init__(self,
                 in_channels: int,
                 num_capsules: int,
                 capsule_dim: int,
                 kernel_size: int = 9,
                 stride: int = 2):
        """
        Args:
            in_channels: Channels of the incoming feature map.
            num_capsules: Number of capsule types.
            capsule_dim: Dimension of each capsule vector.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
        """
        super().__init__()
        self.num_capsules = num_capsules
        self.capsule_dim = capsule_dim

        # One conv producing num_capsules * capsule_dim output channels.
        self.conv = nn.Conv2d(in_channels,
                             num_capsules * capsule_dim,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve, regroup channels into capsule vectors, and squash.

        Args:
            x: (B, C, H, W) input feature map.

        Returns:
            (B, num_capsules * H' * W', capsule_dim) squashed capsules.
        """
        feats = self.conv(x)  # (B, num_capsules * capsule_dim, H', W')
        b, _, h, w = feats.shape

        # Split channels into (type, vector-component) axes.
        feats = feats.view(b, self.num_capsules, self.capsule_dim, h, w)

        # Vector components last, then flatten capsule types x grid.
        capsules = feats.permute(0, 1, 3, 4, 2).reshape(b, -1, self.capsule_dim)

        return squash(capsules, dim=-1)


# ===========================
# 4. Digit Capsules (with Routing)
# ===========================

class DigitCaps(nn.Module):
    """Top-level (class) capsule layer.

    A thin wrapper that delegates all computation to DynamicRouting;
    kept as a separate module so the architecture mirrors the paper.
    """

    def __init__(self,
                 in_capsules: int,
                 in_dim: int,
                 out_capsules: int,
                 out_dim: int,
                 num_iterations: int = 3):
        """See DynamicRouting for the parameter semantics."""
        super().__init__()
        # All the routing work happens inside this submodule.
        self.routing = DynamicRouting(in_capsules, out_capsules, in_dim, out_dim, num_iterations)

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        """Route (B, in_capsules, in_dim) to (B, out_capsules, out_dim)."""
        return self.routing(u)


# ===========================
# 5. Reconstruction Decoder
# ===========================

class Decoder(nn.Module):
    """
    Decoder network for reconstruction regularization.

    Takes the digit-capsule outputs, masks out all but one class
    capsule, and reconstructs the input image with a small MLP
    (Sabour et al., 2017).
    """

    def __init__(self,
                 capsule_dim: int = 16,
                 num_classes: int = 10,
                 img_size: int = 28,
                 hidden_dims: Optional[List[int]] = None):
        """
        Args:
            capsule_dim: Dimension of each digit capsule.
            num_classes: Number of digit capsules (classes).
            img_size: Side length of the (square) reconstructed image.
            hidden_dims: Hidden layer widths; defaults to [512, 1024].
        """
        super().__init__()
        # Resolve the default here instead of using a mutable default argument.
        if hidden_dims is None:
            hidden_dims = [512, 1024]
        self.capsule_dim = capsule_dim
        self.num_classes = num_classes
        self.img_size = img_size

        # Build FC layers. BUG FIX: the decoder consumes the *flattened*
        # masked capsule tensor, which has num_classes * capsule_dim
        # entries (non-target rows are zeroed, not removed), so the first
        # Linear must accept that many features — not just capsule_dim.
        layers = []
        in_dim = num_classes * capsule_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU(inplace=True))
            in_dim = hidden_dim

        # Output layer: one sigmoid pixel per image position.
        layers.append(nn.Linear(in_dim, img_size * img_size))
        layers.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*layers)

    def forward(self, v: torch.Tensor, labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            v: (B, num_classes, capsule_dim) - digit capsule outputs
            labels: (B,) - ground truth labels for masking (None during inference)

        Returns:
            reconstruction: (B, img_size, img_size)
        """
        batch_size = v.size(0)

        if labels is None:
            # Inference: select the capsule with the largest length
            # (length encodes class probability).
            lengths = torch.sqrt(torch.sum(v ** 2, dim=-1))  # (B, num_classes)
            labels = torch.argmax(lengths, dim=1)  # (B,)

        # Zero every capsule except the selected/true class.
        mask = torch.zeros_like(v)
        mask[torch.arange(batch_size), labels] = 1.0
        v_masked = v * mask

        # Flatten to (B, num_classes * capsule_dim); only one row is non-zero.
        v_flat = v_masked.view(batch_size, -1)

        # Decode to pixel space.
        reconstruction = self.decoder(v_flat)
        reconstruction = reconstruction.view(batch_size, self.img_size, self.img_size)

        return reconstruction


# ===========================
# 6. CapsNet (Complete Architecture)
# ===========================

class CapsNet(nn.Module):
    """
    Complete Capsule Network architecture (Sabour et al., 2017).

    Pipeline: Conv -> PrimaryCaps -> DigitCaps (dynamic routing) -> Decoder.
    """

    def __init__(self,
                 img_channels: int = 1,
                 img_size: int = 28,
                 num_classes: int = 10,
                 primary_caps: int = 32,
                 primary_dim: int = 8,
                 digit_dim: int = 16,
                 num_routing: int = 3):
        """
        Args:
            img_channels: Number of input image channels.
            img_size: Side length of the (square) input image.
            num_classes: Number of output (digit) capsules.
            primary_caps: Number of primary capsule types.
            primary_dim: Dimension of each primary capsule vector.
            digit_dim: Dimension of each digit capsule vector.
            num_routing: Dynamic-routing iterations.
        """
        super().__init__()

        self.img_channels = img_channels
        self.img_size = img_size
        self.num_classes = num_classes

        # Layer 1: Convolutional layer (kernel 9, stride 1, no padding)
        conv1_kernel = 9
        self.conv1 = nn.Conv2d(img_channels, 256, kernel_size=conv1_kernel,
                               stride=1, padding=0)

        # Conv arithmetic H' = (H - K) // S + 1: 28 - 9 + 1 = 20 for MNIST
        conv_output_size = img_size - conv1_kernel + 1

        # Layer 2: Primary capsules (kernel 9, stride 2)
        primary_kernel, primary_stride = 9, 2
        self.primary_caps = PrimaryCaps(256, primary_caps, primary_dim,
                                        kernel_size=primary_kernel,
                                        stride=primary_stride)

        # BUG FIX: use the exact conv formula (H' - K) // S + 1. The old
        # shortcut (H' - 8) // 2 only agrees with it for even H'
        # (e.g. 20 -> 6) and is wrong for other image sizes.
        primary_output_size = (conv_output_size - primary_kernel) // primary_stride + 1
        num_primary_capsules = primary_caps * primary_output_size * primary_output_size

        # Layer 3: Digit capsules with dynamic routing
        self.digit_caps = DigitCaps(num_primary_capsules, primary_dim,
                                   num_classes, digit_dim, num_routing)

        # Layer 4: Reconstruction decoder
        self.decoder = Decoder(digit_dim, num_classes, img_size)

    def forward(self, x: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (B, C, H, W) input images
            labels: (B,) ground truth labels (for reconstruction masking);
                None at inference time.

        Returns:
            digit_caps: (B, num_classes, digit_dim)
            reconstruction: (B, H, W)
        """
        # Conv feature extraction
        x = F.relu(self.conv1(x))  # (B, 256, 20, 20) for MNIST

        # Primary capsules
        primary = self.primary_caps(x)  # (B, 1152, 8) for MNIST

        # Digit capsules via dynamic routing
        digits = self.digit_caps(primary)  # (B, 10, 16)

        # Reconstruction (masked by labels during training)
        reconstruction = self.decoder(digits, labels)  # (B, 28, 28)

        return digits, reconstruction


# ===========================
# 7. Margin Loss
# ===========================

class MarginLoss(nn.Module):
    """Margin loss for CapsNet classification (Sabour et al., 2017).

    Per class k:
        L_k = T_k * max(0, m+ - ||v_k||)^2
            + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2
    where T_k = 1 iff k is the true class. Losses are summed over
    classes and averaged over the batch.
    """

    def __init__(self,
                 m_plus: float = 0.9,
                 m_minus: float = 0.1,
                 lambda_: float = 0.5):
        """
        Args:
            m_plus: Margin the true-class capsule length should exceed.
            m_minus: Margin other capsule lengths should stay below.
            lambda_: Down-weighting of the absent-class term.
        """
        super().__init__()
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambda_ = lambda_

    def forward(self, v: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the mean margin loss.

        Args:
            v: (B, num_classes, capsule_dim) digit capsule outputs.
            labels: (B,) ground-truth class indices.

        Returns:
            Scalar loss tensor.
        """
        num_classes = v.size(1)

        # Capsule length encodes class probability.
        lengths = v.pow(2).sum(dim=-1).sqrt()  # (B, num_classes)

        present = F.one_hot(labels, num_classes).float()  # (B, num_classes)
        absent = 1.0 - present

        hit_penalty = F.relu(self.m_plus - lengths).pow(2)
        miss_penalty = F.relu(lengths - self.m_minus).pow(2)

        per_sample = (present * hit_penalty
                      + self.lambda_ * absent * miss_penalty).sum(dim=1)

        return per_sample.mean()


# ===========================
# 8. CapsNet Trainer
# ===========================

class CapsNetTrainer:
    """Training utilities for CapsNet.

    Wraps a CapsNet model with the combined objective
    margin_loss + recon_weight * MSE(reconstruction, input)
    and exposes single-batch train/eval steps that also report accuracy.
    """
    
    def __init__(self,
                 model: CapsNet,
                 margin_loss: MarginLoss,
                 recon_weight: float = 0.0005,
                 device: str = 'cuda'):
        """
        Args:
            model: CapsNet instance; moved to `device` here.
            margin_loss: Classification criterion.
            recon_weight: Weight of the reconstruction (MSE) term; the
                small default keeps it from dominating the margin loss.
            device: Torch device string, e.g. 'cuda' or 'cpu'.
        """
        self.model = model.to(device)
        self.margin_loss = margin_loss
        self.recon_weight = recon_weight
        self.device = device
    
    def train_step(self, x: torch.Tensor, labels: torch.Tensor, 
                   optimizer: torch.optim.Optimizer) -> dict:
        """Run one optimization step on a batch.

        Args:
            x: (B, C, H, W) input images.
            labels: (B,) ground-truth class indices.
            optimizer: Optimizer over `self.model.parameters()`.

        Returns:
            Dict with 'loss_total', 'loss_margin', 'loss_recon', 'accuracy'.
        """
        self.model.train()
        optimizer.zero_grad()
        
        x = x.to(self.device)
        labels = labels.to(self.device)
        
        # Forward: labels mask the decoder input during training
        digit_caps, reconstruction = self.model(x, labels)
        
        # Margin loss on capsule lengths
        loss_margin = self.margin_loss(digit_caps, labels)
        
        # Reconstruction loss; reconstruction is (B, H, W), so squeeze(1)
        # assumes single-channel input — TODO confirm for multi-channel data
        loss_recon = F.mse_loss(reconstruction, x.squeeze(1))
        
        # Total loss: margin dominates; reconstruction acts as regularizer
        loss_total = loss_margin + self.recon_weight * loss_recon
        
        # Backward
        loss_total.backward()
        optimizer.step()
        
        # Accuracy: predicted class = capsule with largest length
        lengths = torch.sqrt(torch.sum(digit_caps ** 2, dim=-1))
        pred = torch.argmax(lengths, dim=1)
        acc = (pred == labels).float().mean()
        
        return {
            'loss_total': loss_total.item(),
            'loss_margin': loss_margin.item(),
            'loss_recon': loss_recon.item(),
            'accuracy': acc.item()
        }
    
    @torch.no_grad()
    def eval_step(self, x: torch.Tensor, labels: torch.Tensor) -> dict:
        """Evaluate one batch without gradient tracking.

        Mirrors train_step's metrics; note that labels are also used to
        mask the decoder here, so reconstructions use ground truth.

        Args:
            x: (B, C, H, W) input images.
            labels: (B,) ground-truth class indices.

        Returns:
            Dict with 'loss_total', 'loss_margin', 'loss_recon', 'accuracy'.
        """
        self.model.eval()
        
        x = x.to(self.device)
        labels = labels.to(self.device)
        
        # Forward
        digit_caps, reconstruction = self.model(x, labels)
        
        # Losses (same combination as training)
        loss_margin = self.margin_loss(digit_caps, labels)
        loss_recon = F.mse_loss(reconstruction, x.squeeze(1))
        loss_total = loss_margin + self.recon_weight * loss_recon
        
        # Accuracy from capsule lengths
        lengths = torch.sqrt(torch.sum(digit_caps ** 2, dim=-1))
        pred = torch.argmax(lengths, dim=1)
        acc = (pred == labels).float().mean()
        
        return {
            'loss_total': loss_total.item(),
            'loss_margin': loss_margin.item(),
            'loss_recon': loss_recon.item(),
            'accuracy': acc.item()
        }


# ===========================
# 9. EM Routing (Advanced)
# ===========================

class EMRouting(nn.Module):
    """
    EM routing with Gaussian mixture models (Hinton et al., 2018).
    More principled than dynamic routing.

    Each output capsule is modeled as a diagonal Gaussian over the
    prediction vectors of the input capsules; routing probabilities are
    the (activation-weighted) posterior responsibilities. This is a
    simplified variant: the activation cost uses a single learned bias
    (beta_a) rather than the paper's full per-capsule cost terms.
    """
    
    def __init__(self,
                 in_capsules: int,
                 out_capsules: int,
                 in_dim: int,
                 out_dim: int,
                 num_iterations: int = 3,
                 temperature: float = 1.0):
        """
        Args:
            in_capsules: Number of lower-level capsules.
            out_capsules: Number of higher-level capsules.
            in_dim: Dimension of input pose vectors.
            out_dim: Dimension of output pose vectors.
            num_iterations: EM iterations to run.
            temperature: Softmax temperature in the E-step; higher values
                soften the routing distribution.
        """
        super().__init__()
        self.in_capsules = in_capsules
        self.out_capsules = out_capsules
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_iterations = num_iterations
        self.temperature = temperature
        
        # Transformation matrices, one per (input, output) capsule pair
        self.W = nn.Parameter(torch.randn(in_capsules, out_capsules, out_dim, in_dim))
        nn.init.kaiming_normal_(self.W)
        
        # Learnable parameters for Gaussian (optional)
        # beta_a biases the output-activation sigmoid; beta_u is declared
        # but not used in this simplified forward pass
        self.beta_a = nn.Parameter(torch.zeros(1))
        self.beta_u = nn.Parameter(torch.zeros(out_capsules, out_dim))
    
    def forward(self, u: torch.Tensor, activation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            u: (B, in_capsules, in_dim) - pose vectors
            activation: (B, in_capsules) - activation probabilities
        
        Returns:
            v: (B, out_capsules, out_dim) - output poses (Gaussian means)
            a: (B, out_capsules) - output activations
        """
        batch_size = u.size(0)
        
        # Predictions u_hat_{j|i} = W_ij @ u_i (votes for each output capsule)
        u_expanded = u[:, :, None, :, None]
        W_expanded = self.W[None, :, :, :, :]
        u_hat = torch.matmul(W_expanded, u_expanded).squeeze(-1)  # (B, in_caps, out_caps, out_dim)
        
        # Initialize routing probabilities uniformly over output capsules
        r = torch.ones(batch_size, self.in_capsules, self.out_capsules, 
                      device=u.device) / self.out_capsules
        
        # EM iterations
        for iteration in range(self.num_iterations):
            # M-step: Compute mean and variance of each output Gaussian,
            # weighting each vote by routing prob * input activation
            r_weighted = r * activation[:, :, None]  # (B, in_caps, out_caps)
            r_sum = torch.sum(r_weighted, dim=1, keepdim=True) + 1e-8  # (B, 1, out_caps)
            
            # Mean: ΞΌ_j = Ξ£_i r_ij Γ»_{j|i} / Ξ£_i r_ij
            mu = torch.sum(r_weighted[:, :, :, None] * u_hat, dim=1) / r_sum  # (B, out_caps, out_dim)
            
            # Variance: σ²_j = Ξ£_i r_ij (Γ»_{j|i} - ΞΌ_j)Β² / Ξ£_i r_ij
            diff = u_hat - mu[:, None, :, :]  # (B, in_caps, out_caps, out_dim)
            variance = torch.sum(r_weighted[:, :, :, None] * diff ** 2, dim=1) / r_sum
            variance = variance + 0.01  # Minimum variance for stability
            
            # Activation: Based on cost (simplified) — low total variance
            # (tight vote cluster) means high agreement, hence high activation
            cost = torch.sum(variance, dim=-1)  # (B, out_caps)
            a = torch.sigmoid(-cost + self.beta_a)  # (B, out_caps)
            
            # E-step: Update routing probabilities (except last iteration)
            if iteration < self.num_iterations - 1:
                # Log probability of Γ»_{j|i} under Gaussian(ΞΌ_j, σ²_j)
                # (diagonal covariance; constant terms omitted — they cancel
                # in the softmax)
                log_p = -0.5 * torch.sum((diff ** 2) / variance[:, None, :, :], dim=-1)  # (B, in_caps, out_caps)
                log_p = log_p - 0.5 * torch.sum(torch.log(variance[:, None, :, :] + 1e-8), dim=-1)
                
                # Weighted by activation (mixture weight of capsule j)
                log_p = log_p + torch.log(a[:, None, :] + 1e-8)
                
                # Softmax over output capsules yields new responsibilities
                r = F.softmax(log_p / self.temperature, dim=2)
        
        return mu, a


# ===========================
# 10. Demo Functions
# ===========================

def demo_squash():
    """Show how squash() compresses vector norms into (0, 1)."""
    banner = "=" * 50
    print(banner)
    print("Demo: Squashing Function")
    print(banner)

    # Vectors with small, medium, and large norms
    vectors = torch.tensor([[0.1, 0.2], [1.0, 2.0], [10.0, 20.0]])
    squashed = squash(vectors, dim=-1)
    out_norms = torch.norm(squashed, dim=-1)

    print("Input vectors:")
    print(vectors)
    print(f"\nNorms: {torch.norm(vectors, dim=-1)}")

    print("\nSquashed vectors:")
    print(squashed)
    print(f"Squashed norms: {out_norms}")
    print(f"All norms in (0,1): {torch.all((out_norms > 0) & (out_norms < 1))}")
    print()


def demo_dynamic_routing():
    """Route random primary capsules and report shapes and parameter cost."""
    banner = "=" * 50
    print(banner)
    print("Demo: Dynamic Routing")
    print(banner)

    layer = DynamicRouting(in_capsules=1152, out_capsules=10,
                           in_dim=8, out_dim=16, num_iterations=3)

    inputs = torch.randn(4, 1152, 8)  # batch of 4, 1152 primary capsules
    outputs = layer(inputs)

    print(f"Input: {inputs.shape} (batch_size, in_capsules, in_dim)")
    print(f"Output: {outputs.shape} (batch_size, out_capsules, out_dim)")
    print(f"Capsule lengths: {torch.norm(outputs, dim=-1)[0]}")
    print(f"Transformation matrices: {layer.W.shape}")
    weight_count = layer.W.numel()
    print(f"Parameters: {weight_count:,} ({weight_count * 4 / 1024**2:.2f} MB)")
    print()


def demo_capsnet_forward():
    """Run a full CapsNet forward pass on random MNIST-shaped input."""
    banner = "=" * 50
    print(banner)
    print("Demo: CapsNet Forward Pass")
    print(banner)

    net = CapsNet(img_channels=1, img_size=28, num_classes=10)
    images = torch.randn(4, 1, 28, 28)
    targets = torch.randint(0, 10, (4,))

    caps_out, recon = net(images, targets)

    print(f"Input: {images.shape}")
    print(f"Digit capsules: {caps_out.shape}")
    print(f"Reconstruction: {recon.shape}")

    # Capsule length encodes class probability; longest capsule wins
    class_lengths = torch.sqrt(torch.sum(caps_out ** 2, dim=-1))
    predicted = torch.argmax(class_lengths, dim=1)

    print("\nCapsule lengths (probabilities):")
    print(class_lengths[0])
    print(f"Predicted classes: {predicted}")
    print(f"True labels: {targets}")

    # Parameter count (float32 -> 4 bytes each)
    total = sum(p.numel() for p in net.parameters())
    print(f"\nTotal parameters: {total:,} (~{total * 4 / 1024**2:.1f} MB)")
    print()


def demo_margin_loss():
    """Evaluate the margin loss on random capsule outputs."""
    banner = "=" * 50
    print(banner)
    print("Demo: Margin Loss")
    print(banner)

    criterion = MarginLoss(m_plus=0.9, m_minus=0.1, lambda_=0.5)

    # Simulated digit-capsule outputs and labels
    caps = torch.randn(4, 10, 16)
    targets = torch.tensor([3, 7, 1, 0])

    value = criterion(caps, targets)

    print(f"Digit capsules: {caps.shape}")
    print(f"Labels: {targets}")
    print(f"Margin loss: {value.item():.4f}")

    # Inspect the lengths the loss operates on
    caps_lengths = torch.sqrt(torch.sum(caps ** 2, dim=-1))
    print("\nCapsule lengths:")
    print(caps_lengths)
    print(f"Target class lengths: {caps_lengths[torch.arange(4), targets]}")
    print()


def demo_training_step():
    """Run a few optimization steps on dummy data and print metrics."""
    banner = "=" * 50
    print(banner)
    print("Demo: Training Step")
    print(banner)

    device = 'cpu'
    net = CapsNet(img_channels=1, img_size=28, num_classes=10)
    criterion = MarginLoss()
    trainer = CapsNetTrainer(net, criterion, device=device)
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)

    # Dummy batch (same batch reused each step)
    batch = torch.randn(8, 1, 28, 28)
    targets = torch.randint(0, 10, (8,))

    print("Training for 3 steps...")
    for step in range(3):
        metrics = trainer.train_step(batch, targets, opt)
        print(f"Step {step+1}: loss={metrics['loss_total']:.4f}, "
              f"margin={metrics['loss_margin']:.4f}, "
              f"recon={metrics['loss_recon']:.4f}, "
              f"acc={metrics['accuracy']:.2%}")
    print()


def demo_em_routing():
    """Run EM routing on random poses/activations and report shapes."""
    banner = "=" * 50
    print(banner)
    print("Demo: EM Routing")
    print(banner)

    layer = EMRouting(in_capsules=1152, out_capsules=10,
                      in_dim=8, out_dim=16, num_iterations=3)

    poses = torch.randn(4, 1152, 8)
    acts = torch.rand(4, 1152)

    out_poses, out_acts = layer(poses, acts)

    print(f"Input poses: {poses.shape}")
    print(f"Input activations: {acts.shape}")
    print(f"Output poses (means): {out_poses.shape}")
    print(f"Output activations: {out_acts.shape}")
    print(f"\nOutput activation values (sample): {out_acts[0]}")
    print()


def print_performance_comparison():
    """Print a comprehensive performance comparison and decision guide.

    Pure reporting: writes benchmark tables (MNIST, affNIST, smallNORB,
    CIFAR-10, adversarial robustness), complexity notes, recommended
    hyperparameters, architecture variants, a use-case decision guide,
    and a troubleshooting table to stdout. Takes no arguments, returns None.

    Note: the original source contained mojibake (UTF-8 text decoded as
    Latin-1), e.g. "Γ—" for "×" and "Ξ»" for "λ"; the intended characters
    are restored here.
    """

    def _table(rows, *widths):
        # Left-pad each cell to its column width; cells are joined by a
        # single space (matches the original f-string layout exactly).
        for row in rows:
            print(" ".join(f"{cell:<{w}}" for cell, w in zip(row, widths)))

    print("=" * 80)
    print("PERFORMANCE COMPARISON: Capsule Networks")
    print("=" * 80)

    # 1. MNIST Results
    print("\n1. MNIST Results (Test Error %)")
    print("-" * 80)
    _table([
        ("Model", "Error %", "Parameters", "Notes"),
        ("-" * 30, "-" * 10, "-" * 12, "-" * 30),
        ("Baseline CNN", "0.39", "~35K", "Simple CNN baseline"),
        ("CapsNet (3 routing)", "0.25", "8.2M", "Original (Sabour 2017)"),
        ("CapsNet (1 routing)", "0.31", "8.2M", "Faster, slight quality loss"),
        ("EM Routing CapsNet", "0.21", "4.5M", "45% fewer params (Hinton 2018)"),
        ("", "", "", ""),
        ("SOTA (non-augmented)", "0.17", "N/A", "Ensemble methods"),
    ], 30, 10, 12, 30)

    # 2. affNIST (Viewpoint Robustness)
    print("\n2. affNIST - Viewpoint Robustness (Accuracy %)")
    print("-" * 80)
    _table([
        ("Model", "Accuracy", "Notes"),
        ("-" * 35, "-" * 10, "-" * 35),
        ("CNN (baseline)", "66%", "Poor generalization to affine transforms"),
        ("CNN (expanded training)", "79%", "With affine augmentation"),
        ("CapsNet (no augmentation)", "79%", "Naturally robust to viewpoint changes"),
        ("", "", ""),
        ("Observation:", "", "CapsNets achieve CNN+augmentation performance"),
        ("", "", "without needing explicit augmentation"),
    ], 35, 10, 35)

    # 3. smallNORB (3D Objects)
    print("\n3. smallNORB - 3D Object Recognition (Test Error %)")
    print("-" * 80)
    _table([
        ("Model", "Error %", "Notes"),
        ("-" * 40, "-" * 10, "-" * 35),
        ("CNN baseline", "5.2", ""),
        ("Convolutional Deep Belief Net", "4.5", ""),
        ("CapsNet (dynamic routing)", "3.1", "Better viewpoint invariance"),
        ("EM Routing CapsNet", "2.7", "SOTA (Hinton 2018)"),
        ("", "", ""),
        ("Key:", "", "3D geometry → capsules excel"),
    ], 40, 10, 35)

    # 4. CIFAR-10
    print("\n4. CIFAR-10 - Natural Images (Test Error %)")
    print("-" * 80)
    _table([
        ("Model", "Error %", "Parameters", "Notes"),
        ("-" * 30, "-" * 10, "-" * 12, "-" * 30),
        ("CapsNet (3-layer)", "10.6", "11.8M", "Original architecture"),
        ("Deep CapsNet", "8.5", "~20M", "Deeper architecture + augmentation"),
        ("", "", "", ""),
        ("ResNet-56", "6.5", "850K", "Standard CNN (fewer params)"),
        ("ResNet-110", "5.5", "1.7M", "Deeper CNN"),
        ("Wide ResNet-28-10", "3.9", "36M", "SOTA CNN"),
        ("", "", "", ""),
        ("Gap:", "", "", "CapsNets lag on complex natural images"),
    ], 30, 10, 12, 30)

    # 5. Adversarial Robustness
    print("\n5. Adversarial Robustness (Accuracy under attack %)")
    print("-" * 80)
    _table([
        ("Model", "Clean", "FGSM", "PGD", "Notes"),
        ("-" * 25, "-" * 8, "-" * 8, "-" * 8, "-" * 25),
        ("CNN (baseline)", "99%", "45%", "40%", "Vulnerable to adversarial"),
        ("CapsNet", "99.2%", "68%", "65%", "More robust (distributed repr)"),
        ("Adversarial Training CNN", "98%", "85%", "80%", "Needs explicit defense"),
        ("", "", "", "", ""),
        ("Insight:", "", "", "", "CapsNets naturally more robust"),
    ], 25, 8, 8, 8, 25)

    # 6. Computational Complexity
    print("\n6. Computational Complexity")
    print("-" * 80)
    _table([
        ("Operation", "Complexity", "Notes"),
        ("-" * 35, "-" * 25, "-" * 30),
        ("Conv layers", "O(K²·C·H·W)", "Same as standard CNN"),
        ("Primary Caps", "O(K²·C·H·W)", "Convolution + reshape"),
        ("Dynamic Routing", "O(r·N_in·N_out·d)", "r=3 iterations typical"),
        ("Transformation Matrices", "O(N_in·N_out·d²)", "Largest parameter cost"),
        ("", "", ""),
        ("Bottleneck:", "Routing (quadratic)", "Limits scalability"),
        ("Total (MNIST)", "~10× slower than CNN", "Per forward pass"),
    ], 35, 25, 30)

    # 7. Training Hyperparameters
    print("\n7. Recommended Training Hyperparameters")
    print("-" * 80)
    _table([
        ("Parameter", "MNIST", "CIFAR-10", "Notes"),
        ("-" * 25, "-" * 15, "-" * 15, "-" * 30),
        ("Batch size", "128", "128-256", "Larger if memory allows"),
        ("Learning rate", "1e-3", "1e-3", "Adam optimizer"),
        ("LR decay", "Exponential 0.96", "Cosine", "Every epoch or cycle"),
        ("", "", "", ""),
        ("Routing iterations", "3", "3", "Diminishing returns beyond"),
        ("Reconstruction weight", "0.0005", "0.0005", "Small to avoid dominance"),
        ("", "", "", ""),
        ("m+ (upper margin)", "0.9", "0.9", "Target for present classes"),
        ("m- (lower margin)", "0.1", "0.1", "Target for absent classes"),
        ("λ (absent weight)", "0.5", "0.5", "Down-weight absent classes"),
        ("", "", "", ""),
        ("Primary capsules", "32 × 8D", "32 × 8D", "Standard config"),
        ("Digit capsules", "10 × 16D", "10 × 16D", "One per class"),
        ("", "", "", ""),
        ("Training epochs", "50-100", "200-300", "Until convergence"),
        ("Gradient clipping", "5.0", "5.0", "Max norm for stability"),
    ], 25, 15, 15, 30)

    # 8. Architecture Variants
    print("\n8. Capsule Network Variants")
    print("-" * 80)
    _table([
        ("Variant", "Key Innovation", "Pros", "Cons"),
        ("-" * 20, "-" * 30, "-" * 25, "-" * 25),
        ("Dynamic Routing", "Routing by agreement", "Simple, effective", "Iterative (slow)"),
        ("EM Routing", "Gaussian mixture model", "45% fewer params", "More complex"),
        ("Self-Attention", "Attention-based routing", "Parallelizable", "Less interpretable"),
        ("SCAE", "Unsupervised learning", "Object discovery", "Training unstable"),
        ("Capsule GNN", "Graph message passing", "Irregular data", "Architecture complex"),
    ], 20, 30, 25, 25)

    # 9. Use Case Decision Guide
    print("\n9. DECISION GUIDE: When to Use Capsule Networks")
    print("=" * 80)

    print("\n✓ USE Capsule Networks When:")
    advantages = [
        "• Viewpoint variation is critical (3D objects, poses)",
        "• Need interpretable part-whole representations",
        "• Adversarial robustness important (intrinsically more robust)",
        "• Dataset has limited augmentation (affine transforms)",
        "• Object segmentation/discovery required (unsupervised)",
        "• Small-to-medium datasets (MNIST, smallNORB)",
        "• Multi-label classification (margin loss advantage)",
    ]
    for adv in advantages:
        print(adv)

    print("\n✗ AVOID Capsule Networks When:")
    limitations = [
        "• Need SOTA on ImageNet-scale datasets (CNNs/Transformers better)",
        "• Real-time inference critical (10× slower than CNNs)",
        "• Very deep architectures required (stability issues)",
        "• Large number of capsules needed (>10K, quadratic cost)",
        "• Limited computational budget (more expensive training)",
    ]
    for lim in limitations:
        print(lim)

    print("\n→ RECOMMENDED ALTERNATIVES:")
    alternatives = [
        "• General vision → ResNet, EfficientNet, Vision Transformers",
        "• Fast inference → MobileNets, SqueezeNet",
        "• Adversarial robustness → Adversarial training, certified defenses",
        "• Interpretability → Attention mechanisms, GradCAM",
    ]
    for alt in alternatives:
        print(alt)

    # 10. Comparison with Other Architectures
    print("\n10. Comparison: CapsNets vs CNNs vs Transformers")
    print("=" * 80)
    _table([
        ("Aspect", "CapsNets", "CNNs", "Transformers"),
        ("-" * 25, "-" * 20, "-" * 20, "-" * 20),
        ("Neuron type", "Vector (pose)", "Scalar (activation)", "Embedding vectors"),
        ("Spatial relations", "Preserved (routing)", "Lost (pooling)", "Positional encoding"),
        ("Equivariance", "Viewpoint (goal)", "Translation", "None (learned)"),
        ("Interpretability", "High (part-whole)", "Low", "Moderate (attention)"),
        ("Scalability", "Limited (N²)", "Excellent", "Good (N² in seq)"),
        ("Training stability", "Moderate", "Good", "Good (with tricks)"),
        ("Adversarial robust", "Better", "Worse", "Moderate"),
        ("Parameter count", "High", "Low-Medium", "Very High"),
        ("ImageNet accuracy", "75-78%", "85-90%", "88-92%"),
        ("", "", "", ""),
        ("Best for:", "3D, viewpoint", "General vision", "Large-scale, NLP"),
    ], 25, 20, 20, 20)

    # 11. Troubleshooting
    print("\n11. TROUBLESHOOTING COMMON ISSUES")
    print("=" * 80)
    _table([
        ("Problem", "Possible Cause", "Solution"),
        ("-" * 30, "-" * 35, "-" * 40),
        ("Poor convergence", "Learning rate too high", "Reduce LR to 1e-4 or 1e-5"),
        ("", "Routing iterations wrong", "Try 1-3 iterations"),
        ("", "", ""),
        ("Exploding gradients", "Routing depth", "Gradient clipping (max_norm=5)"),
        ("", "Bad initialization", "Kaiming/Xavier init"),
        ("", "", ""),
        ("All capsules same", "Reconstruction weight too high", "Reduce α to 0.0005 or lower"),
        ("", "Margin loss not working", "Check m+, m-, λ values"),
        ("", "", ""),
        ("Slow training", "Too many routing iters", "Use 1-2 iterations (faster)"),
        ("", "Large model", "Reduce primary caps or dims"),
        ("", "", ""),
        ("Blurry reconstruction", "Decoder too small", "Add hidden layers"),
        ("", "Weight too small", "Increase α to 0.001"),
        ("", "", ""),
        ("Low accuracy", "Insufficient training", "Train longer (50-100 epochs)"),
        ("", "Poor architecture", "Try EM routing or variants"),
    ], 30, 35, 40)

    print("\n" + "=" * 80)
    print("Summary: Capsule Networks offer interpretable part-whole representations")
    print("and viewpoint robustness, but lag on complex datasets. Best for 3D objects,")
    print("adversarial robustness, and when interpretability is critical.")
    print("=" * 80)
    print()

# ===========================
# Run All Demos
# ===========================

if __name__ == "__main__":
    # Entry point: run every demo in order, framed by banner lines.
    title = "=" * 80
    print("\n" + title)
    print("CAPSULE NETWORKS - COMPREHENSIVE IMPLEMENTATION")
    print(title + "\n")

    for demo in (demo_squash, demo_dynamic_routing, demo_capsnet_forward,
                 demo_margin_loss, demo_training_step, demo_em_routing,
                 print_performance_comparison):
        demo()

    print("\n" + title)
    print("All demos completed successfully!")
    print(title)