import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
Advanced Neural Architecture Search Theory
1. Search Space Design and Fundamentals
Architecture Search Problem: Given a search space \(\mathcal{A}\) of candidate architectures, a dataset \(\mathcal{D}\), and a performance metric \(p\), find:
\[a^* = \arg\max_{a \in \mathcal{A}} p(a, \mathcal{D})\]
Search Space Taxonomy:

| Space Type | Description | Size | Complexity |
|---|---|---|---|
| Chain-structured | Sequential layers | \(O(n^d)\), \(d\) = depth | Linear |
| Cell-based | Modular cells | \(|\mathcal{O}|^{N_e}\), \(N_e\) = edges | Exponential |
| Hierarchical | Multi-level | \(O((|\mathcal{O}| \cdot N)^L)\) | Super-exponential |
| Network morphism | Topology mutation | Infinite | Continuous |
Cell-based Search (NASNet, DARTS):
Normal cell: Keeps spatial dimensions
Reduction cell: Downsamples features
Stack pattern: \(N\) normal cells followed by 1 reduction cell, repeated
Total cells: \((N_n + N_r) \times C\), where \(C\) = number of stacked blocks
Search Space Size: For DARTS with:
\(N=7\) nodes per cell
\(|\mathcal{O}|=8\) operations
Each node connecting to 2 previous nodes
the number of discrete cells is roughly \(3 \times 10^9\) per cell type, i.e. on the order of \(10^{18}\) for the normal/reduction cell pair.
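As a sanity check on these numbers, the cell count can be computed directly. The sketch below follows the setup above (4 intermediate nodes among the 7, 8 candidate operations, 2 chosen inputs per node) and ignores graph isomorphisms:

```python
from math import comb

def darts_cell_count(n_intermediate=4, n_ops=8):
    """Count discrete cells: intermediate node i chooses 2 distinct inputs
    from its (i + 2) predecessors, with one of n_ops operations per edge."""
    total = 1
    for i in range(n_intermediate):
        total *= comb(i + 2, 2) * n_ops ** 2
    return total

per_cell = darts_cell_count()   # 3_019_898_880, i.e. ~3e9 discrete cells
pair = per_cell ** 2            # normal and reduction cell searched jointly
print(f"{per_cell:.2e} per cell, {pair:.2e} for the cell pair")
```

With these assumptions the count lands at roughly \(3 \times 10^9\) per cell and \(9 \times 10^{18}\) for the pair, consistent with the order of magnitude quoted above.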
2. DARTS: Complete Mathematical Framework
2.1 Continuous Relaxation
Discrete operation selection becomes continuous mixing:
\[\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp\big(\alpha_o^{(i,j)}\big)}{\sum_{o' \in \mathcal{O}} \exp\big(\alpha_{o'}^{(i,j)}\big)}\, o(x)\]
Where:
\(\alpha^{(i,j)} \in \mathbb{R}^{|\mathcal{O}|}\): Architecture parameters for edge \((i,j)\)
\(\mathcal{O}\): Operation set {conv3x3, conv5x5, maxpool3x3, skip, none}
Softmax ensures \(\sum_o p_o = 1\) (a valid probability distribution)
2.2 Bi-level Optimization Problem
\[\min_\alpha \; \mathcal{L}_{val}\big(w^*(\alpha), \alpha\big) \quad \text{s.t.} \quad w^*(\alpha) = \arg\min_w \mathcal{L}_{train}(w, \alpha)\]
Challenges:
Inner optimization: Training weights \(w\) to convergence is expensive
Implicit gradient: \(\nabla_\alpha \mathcal{L}_{val}\) requires \(\partial w^*/\partial \alpha\)
Memory: Storing the full computation graph for second-order derivatives
2.3 Gradient Approximation
One-step approximation (by the chain rule):
\[\nabla_\alpha \mathcal{L}_{val}\big(w^*(\alpha), \alpha\big) \approx \nabla_\alpha \mathcal{L}_{val}(w', \alpha) - \xi\, \nabla^2_{\alpha, w} \mathcal{L}_{train}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha)\]
Where:
\(w' = w - \xi \nabla_w \mathcal{L}_{train}(w, \alpha)\) (one-step lookahead)
\(\xi\): Learning rate; setting \(\xi = 0\) drops the second term and gives the first-order approximation
Second term: Hessian-vector product \(\nabla^2_{\alpha, w} \mathcal{L}_{train} \cdot \nabla_{w'} \mathcal{L}_{val}\)
Computational complexity:
Forward pass: \(O(N \cdot |\mathcal{O}|)\) where N=edges
Backward pass: \(O(N \cdot |\mathcal{O}|) + O(|w|)\) for Hessian-vector product
Memory: \(O(|\mathcal{O}| \cdot \text{activations})\) during search
Second-order Hessian-vector product (finite difference):
\[\nabla^2_{\alpha, w} \mathcal{L}_{train}(w, \alpha)\, v \approx \frac{\nabla_\alpha \mathcal{L}_{train}(w^+, \alpha) - \nabla_\alpha \mathcal{L}_{train}(w^-, \alpha)}{2\epsilon}, \qquad w^\pm = w \pm \epsilon v\]
For vector \(v = \nabla_{w'} \mathcal{L}_{val}(w', \alpha)\) and \(\epsilon = 0.01 / \|v\|_2\)
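The finite-difference approximation can be verified on a toy loss where the mixed Hessian is known in closed form. The names below (`w`, `alpha`, `v`) mirror the formulas above, and the loss \((w \cdot \alpha)^2\) is purely illustrative:

```python
import torch

# Toy loss L(w, α) = (w·α)^2; check the finite-difference Hessian-vector
# product against the analytic result.
w = torch.tensor([0.5, -1.0])
alpha = torch.tensor([1.0, 2.0], requires_grad=True)
v = torch.tensor([0.3, -0.2])          # stands in for ∇_{w'} L_val

def grad_alpha(w_val):
    """∇_α L at a given weight vector (α is the differentiable leaf)."""
    loss = (w_val * alpha).sum() ** 2
    return torch.autograd.grad(loss, alpha)[0]

eps = 0.01 / v.norm()
hvp_fd = (grad_alpha(w + eps * v) - grad_alpha(w - eps * v)) / (2 * eps)

# Analytic ∇²_{α,w} L · v for this loss: 2 w (α·v) + 2 (w·α) v
hvp_exact = 2 * w * (alpha * v).sum() + 2 * (w * alpha).sum() * v
print(hvp_fd, hvp_exact.detach())      # should agree closely
```

Because the toy loss is quadratic, the central difference is exact up to floating-point error; for a real network the approximation error is \(O(\epsilon^2)\).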
3. Architecture Derivation and Discretization
3.1 Pruning Strategy
After continuous search, extract discrete architecture:
For each edge \((i,j)\):
Compute operation strengths: \(p_o^{(i,j)} = \text{softmax}\big(\alpha^{(i,j)}\big)_o\)
Select the strongest non-none operation: \(o^{(i,j)} = \arg\max_{o \neq \text{none}} p_o^{(i,j)}\)
For each intermediate node, retain the top-2 incoming edges ranked by \(\max_{o \neq \text{none}} p_o^{(i,j)}\)
Genotype encoding:
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
# normal: [(op1, from_node1), (op2, from_node2), ...]
# normal_concat: [4, 5, 6]  # nodes to concatenate for output
3.2 One-shot vs. Progressive Search
| Approach | Strategy | Pros | Cons |
|---|---|---|---|
| One-shot (DARTS) | Single supernet, simultaneous search | Fast, simple | High memory, coupling |
| Progressive (ENAS) | Sequential decisions, controller-based | Lower memory | Slower, credit assignment |
| Weight sharing | Train once, sample many | Very fast | Inaccurate ranking |
4. Advanced NAS Variants
4.1 PC-DARTS (Partial Channel Connections)
Problem: DARTS memory grows as \(O(|\mathcal{O}| \cdot C \cdot H \cdot W)\)
Solution: Sample a subset of channels during search:
\[f^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp\big(\alpha_o^{(i,j)}\big)}{\sum_{o'} \exp\big(\alpha_{o'}^{(i,j)}\big)}\, o\big(S(x)\big) + \big(x - S(x)\big)\]
Where:
\(S(x)\): Channel sampling function (samples \(C/K\) channels; \(K=4\) is typical); unsampled channels bypass the mixed operation unchanged
\(\beta\): Edge normalization factor \(\beta^{(i,j)} = \sum_{o} \alpha_o^{(i,j)}\)
Memory reduction: \(K\times\) less than DARTS (4× typical)
Stability: Regularizes over-parameterization and prevents skip connections from dominating
4.2 FairNAS (Fairness-aware Search)
Motivation: Dominant operations (e.g., skip) receive most of the gradient updates
Fairness loss (variance of each edge's architecture parameters):
\[\mathcal{L}_{fair} = \sum_{(i,j)} \mathrm{Var}\big(\alpha^{(i,j)}\big) = \sum_{(i,j)} \frac{1}{|\mathcal{O}|} \sum_{o} \big(\alpha_o^{(i,j)} - \bar{\alpha}^{(i,j)}\big)^2\]
Combined objective:
\[\mathcal{L} = \mathcal{L}_{task} + \lambda_{fair}\, \mathcal{L}_{fair}\]
Penalizing the variance keeps operation weights close to uniform, encouraging exploration of diverse operations during search
4.3 GDAS (Gumbel-Softmax Differentiable Architecture Search)
Gumbel-Max trick for discrete sampling:
\[o^* = \arg\max_o \big(\alpha_o + g_o\big)\]
Where \(g_o \sim \text{Gumbel}(0,1) = -\log(-\log(u))\), \(u \sim \text{Uniform}(0,1)\)
Gumbel-Softmax (differentiable relaxation):
\[p_o = \frac{\exp\big((\alpha_o + g_o)/\tau\big)}{\sum_{o'} \exp\big((\alpha_{o'} + g_{o'})/\tau\big)}\]
Where \(\tau\): Temperature (annealed from 1 → 0 during search)
Advantages:
Single path during the forward pass (memory \(O(1)\) vs. DARTS' \(O(|\mathcal{O}|)\))
Stochastic exploration reduces bias toward skip connections
Better approximation to the discrete search problem
5. Zero-cost Proxies and Predictors
5.1 Training-Free Metrics
Evaluate architecture quality without training:
SNIP (Single-shot Network Pruning):
\[s_{\text{SNIP}} = \sum_\theta \big|\theta \cdot \nabla_\theta \mathcal{L}\big|\]
The saliency score measures parameter importance at initialization
GraSP (Gradient Signal Preservation):
\[s_{\text{GraSP}} = -\sum_\theta \big(H \nabla_\theta \mathcal{L}\big) \odot \theta\]
A Hessian-gradient product that approximates loss curvature
NASWOT (NAS Without Training):
\[s_{\text{NASWOT}} = \log \big|\det K_H\big|\]
A kernel measure of network expressivity, where \(K_H\) counts agreements between binary ReLU activation patterns across a minibatch
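A minimal sketch of a NASWOT-style score, assuming ReLU activations and using forward hooks to collect binary activation patterns (illustrative, not the paper's reference implementation):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def naswot_score(model, x):
    """Log-determinant of the kernel counting agreements between binary
    ReLU activation patterns over a minibatch (NASWOT-style sketch)."""
    patterns, hooks = [], []
    def hook(_, __, out):
        patterns.append((out.flatten(1) > 0).float())
    for m in model.modules():
        if isinstance(m, nn.ReLU):
            hooks.append(m.register_forward_hook(hook))
    with torch.no_grad():
        model(x)
    for h in hooks:
        h.remove()
    c = torch.cat(patterns, dim=1)        # [batch, total ReLU units]
    K = c @ c.T + (1 - c) @ (1 - c).T     # Hamming-agreement kernel
    return torch.linalg.slogdet(K).logabsdet.item()

# Toy untrained network; higher scores suggest more diverse activation patterns
net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU())
score = naswot_score(net, torch.randn(10, 8))
print(f"NASWOT-style score: {score:.2f}")
```

No training step is involved: the score is computed from a single forward pass at random initialization, which is what makes it a zero-cost proxy.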
5.2 Performance Prediction
Early-stopping correlation: train for \(T\) epochs and rank architectures by their epoch-\(T\) accuracy; \(\rho\) is the rank correlation with the final accuracy at \(T_{max}\)
Typical: \(\rho > 0.6\) for \(T=12\) epochs predicting \(T_{max}=200\)
Learning curve extrapolation: Fit power law: \(\text{Error}(t) = a \cdot t^{-b} + c\)
Predict final performance from early training trajectory
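A small sketch of power-law extrapolation on a synthetic learning curve: scan the exponent \(b\) over a grid and solve for \(a, c\) by linear least squares (the grid scan is a convenience here, not a prescribed method):

```python
import numpy as np

# Synthetic learning curve Error(t) = 0.8 * t^(-0.5) + 0.05 plus small noise
rng = np.random.RandomState(0)
t = np.arange(1, 21, dtype=float)
err = 0.8 * t ** -0.5 + 0.05 + 0.002 * rng.randn(len(t))

# For each candidate exponent b, (a, c) is a linear least-squares problem
best = (np.inf, None)
for b_try in np.linspace(0.1, 2.0, 191):
    X = np.stack([t ** -b_try, np.ones_like(t)], axis=1)
    coef, *_ = np.linalg.lstsq(X, err, rcond=None)
    sse = ((X @ coef - err) ** 2).sum()
    if sse < best[0]:
        best = (sse, (coef[0], b_try, coef[1]))
a, b, c = best[1]
print(f"a={a:.2f}, b={b:.2f}, c={c:.3f}")  # c estimates the asymptotic error
```

The fitted offset \(c\) is the predicted final error, recovered here from only the first 20 "epochs" of the curve.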
6. Efficient Search Strategies
6.1 Supernet Training (Single Path One-Shot)
Weight sharing: Train a single "supernet" containing all candidate operations
Sampling strategy:
Uniform: Sample each architecture with equal probability
Prioritized: Sample based on predicted performance
Advantage: Train once (\(O(1)\)), evaluate many (\(O(N)\))
Limitation: Weight sharing introduces ranking disorder (correlation \(\rho \approx 0.3\)–\(0.6\) with standalone training)
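The effect of ranking disorder can be simulated by adding noise to a set of "standalone" accuracies and measuring Spearman rank correlation (toy numbers, for illustration only):

```python
import numpy as np

# Supernet scores as a noisy view of standalone accuracies
rng = np.random.RandomState(0)
standalone = np.linspace(0.80, 0.95, 20)            # "true" accuracies
supernet = standalone + rng.normal(0, 0.03, 20)     # weight-sharing noise

def spearman(x, y):
    """Spearman rank correlation via Pearson correlation of ranks."""
    rx = np.argsort(np.argsort(x))
    ry = np.argsort(np.argsort(y))
    return np.corrcoef(rx, ry)[0, 1]

rho = spearman(standalone, supernet)
print(f"Spearman rho = {rho:.2f}")   # drops toward 0 as the noise grows
```

Increasing the noise scale relative to the accuracy spread drives \(\rho\) down, which is exactly the failure mode that makes supernet rankings unreliable.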
6.2 Evolutionary Algorithms
Population-based search:
1. Initialize population P = {a_1, ..., a_n}
2. For generation g = 1 to G:
a. Evaluate fitness f(a_i) for all a_i ∈ P
b. Select parents (tournament/roulette)
c. Crossover: Mix architectures
d. Mutate: Random operation changes
e. Update population P
3. Return best architecture a* = argmax f(a_i)
AmoebaNet mutations:
Add/remove layers
Change operation types
Modify filter sizes
Alter connections
Complexity: \(O(P \cdot G \cdot T_{train})\) where P=population, G=generations, T=training time
6.3 Reinforcement Learning (NASNet)
Controller RNN generates architectures as a sequence of decisions, sampling architecture \(a\) with probability \(\pi_\theta(a)\)
REINFORCE gradient:
\[\nabla_\theta J(\theta) = \mathbb{E}_{a \sim \pi_\theta}\big[R(a)\, \nabla_\theta \log \pi_\theta(a)\big]\]
Variance reduction:
Baseline: \(b = \text{EMA}(R)\) (exponential moving average)
Advantage: \(A(a) = R(a) - b\)
PPO update (more stable than REINFORCE):
\[\mathcal{L}^{CLIP}(\theta) = \mathbb{E}\Big[\min\big(r(\theta) A,\; \mathrm{clip}(r(\theta), 1-\epsilon, 1+\epsilon)\, A\big)\Big], \qquad r(\theta) = \frac{\pi_\theta(a)}{\pi_{\theta_{old}}(a)}\]
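A minimal REINFORCE loop with an EMA baseline, reduced to a single categorical decision and a made-up reward (op 3 is arbitrarily designated "best"); a real controller would be an RNN emitting many decisions per architecture:

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.zeros(5, requires_grad=True)   # one decision over 5 candidate ops
opt = torch.optim.Adam([logits], lr=0.1)
baseline = 0.0
for step in range(300):
    dist = Categorical(logits=logits)
    a = dist.sample()
    reward = 1.0 if a.item() == 3 else 0.0    # pretend op 3 yields best accuracy
    baseline = 0.9 * baseline + 0.1 * reward  # EMA baseline for variance reduction
    loss = -(reward - baseline) * dist.log_prob(a)
    opt.zero_grad()
    loss.backward()
    opt.step()
probs = torch.softmax(logits, dim=0)
print(probs)  # probability mass should concentrate on index 3
```

Without the baseline, every sampled action with nonzero reward pushes its own probability up regardless of quality; subtracting the EMA makes below-average actions receive negative updates.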
7. Hardware-Aware and Multi-Objective NAS
7.1 Latency Prediction
Lookup table approach: Build table: \(T(o, C, H, W, device)\) = latency of operation o
Differentiable latency loss (expected latency under the operation distribution):
\[\mathcal{L}_{lat} = \sum_{(i,j)} \sum_{o} p_o^{(i,j)}\, T(o)\]
Where \(p_o = \text{softmax}(\alpha)_o\) are the operation probabilities
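A one-edge sketch of the differentiable latency term; the lookup-table latencies below are invented for illustration:

```python
import torch
import torch.nn.functional as F

# Made-up per-operation latencies (ms): none, skip, conv3x3, conv5x5, pool
lat_table = torch.tensor([0.0, 0.1, 1.5, 2.4, 0.3])
alpha = torch.zeros(5, requires_grad=True)            # uniform softmax at init

# Expected latency of this edge under the softmax relaxation
expected_lat = (F.softmax(alpha, dim=0) * lat_table).sum()
expected_lat.backward()
print(expected_lat.item(), alpha.grad)  # gradient flows into α, so this term
                                        # can be added to the search loss
```

With uniform probabilities the expected latency is simply the mean of the table (0.86 ms here); adding this term to the loss pushes \(\alpha\) away from slow operations.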
7.2 Multi-Objective Optimization
Pareto frontier: Find architectures \(a\) such that no \(a'\) exists with:
\(\text{Acc}(a') > \text{Acc}(a)\) AND
\(\text{Latency}(a') < \text{Latency}(a)\) (or FLOPs, params, energy)
Scalarization (collapse objectives into a single score, e.g. a weighted sum):
\[\max_a \; \text{Acc}(a) - \lambda \cdot \text{Latency}(a)\]
NSGA-II (Non-dominated Sorting Genetic Algorithm):
Rank architectures by Pareto dominance
Use crowding distance for diversity
Select based on rank and crowding
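Pareto-dominance filtering can be sketched in a few lines over (accuracy, latency) pairs (toy candidate values):

```python
def pareto_front(points):
    """Keep points not dominated by any other: a point is dominated if some
    other point has accuracy >= AND latency <= it, with at least one strict."""
    front = []
    for acc, lat in points:
        dominated = any(
            a2 >= acc and l2 <= lat and (a2 > acc or l2 < lat)
            for a2, l2 in points
        )
        if not dominated:
            front.append((acc, lat))
    return front

archs = [(0.95, 20.0), (0.90, 10.0), (0.85, 5.0), (0.80, 8.0)]  # toy candidates
print(pareto_front(archs))   # (0.80, 8.0) is dominated by (0.85, 5.0)
```

This brute-force check is \(O(n^2)\); NSGA-II layers fast non-dominated sorting and crowding distance on top of the same dominance relation.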
8. Theoretical Guarantees and Generalization
8.1 Search-Evaluation Gap
Problem: Architecture performs worse after re-training from scratch
Causes:
Weight co-adaptation: Supernet weights biased by weight sharing
Different optimization landscape: Search (few epochs) vs. evaluation (full training)
Overfitting to search data: Architecture specializes to validation split
Solution: Independent train/search/test splits
8.2 Sample Complexity
PAC bound for architecture search (a union bound over the finite search space):
With probability \(1 - \delta\), for every \(a \in \mathcal{A}\) the true error satisfies:
\[\text{err}(a) \le \widehat{\text{err}}_{val}(a) + \sqrt{\frac{\log|\mathcal{A}| + \log(1/\delta)}{2\, n_{val}}}\]
Where:
\(|\mathcal{A}|\): Search space size
\(n_{val}\): Validation set size
Implication: Larger search spaces require more validation data to avoid overfitting
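Plugging illustrative numbers into the deviation term makes the implication concrete:

```python
import math

def deviation(space_size, n_val, delta=0.05):
    """Deviation term of the PAC bound: sqrt((log|A| + log(1/δ)) / (2 n_val))."""
    return math.sqrt((math.log(space_size) + math.log(1 / delta)) / (2 * n_val))

print(f"{deviation(1e9, 5_000):.3f}")    # |A| = 1e9, 5k validation samples
print(f"{deviation(1e9, 50_000):.3f}")   # 10x more data shrinks the gap ~3.2x
```

Because the bound scales as \(1/\sqrt{n_{val}}\) but only as \(\sqrt{\log|\mathcal{A}|}\), even huge search spaces are manageable with modest validation sets, while small validation splits are the real risk.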
8.3 Expressivity vs. Trainability Trade-off
Expressivity: Ability to represent complex functions
Trainability: Ease of optimization (gradient flow, convergence)
Observation: Skip connections improve trainability, so gradient-based search gravitates toward skip-heavy cells as a trivial solution
Regularization strategies:
Dropout on skip connections
Path dropout (stochastic depth)
Architecture weight decay: \(\mathcal{L}_{reg} = \|\alpha\|_2^2\)
9. Practical Considerations
9.1 Hyperparameter Sensitivity
| Hyperparameter | Typical Range | Effect |
|---|---|---|
| Learning rate (α) | 3e-4 to 1e-3 | Stability of architecture parameters |
| Learning rate (w) | 2.5e-2 to 1e-1 | Speed of weight convergence |
| Weight decay | 3e-4 to 1e-3 | Prevent overfitting |
| Batch size | 64 to 256 | Memory vs. gradient quality |
| Search epochs | 50 to 100 | Search thoroughness |
9.2 Computational Budget
DARTS: ~4 GPU-days (1 V100)
ENAS: ~0.5 GPU-days
AmoebaNet: ~3150 GPU-days (450 K40 GPUs)
NASNet: ~1800 GPU-days
Cost reduction strategies:
Early stopping based on learning curve
Weight sharing (one-shot)
Zero-cost proxies (no training)
Warm-starting from smaller search
9.3 Transferability
Search on proxy task, transfer to target:
| Proxy | Target | Transfer Success |
|---|---|---|
| CIFAR-10 | ImageNet | High (NASNet, AmoebaNet) |
| Small dataset | Large dataset | Moderate |
| Different domain | Target domain | Low |
Requirements for transfer:
Similar data distribution
Comparable task complexity
Shared inductive biases
10. State-of-the-Art Methods (2020-2024)
10.1 Once-for-All (OFA) Networks
Progressive shrinking:
Train full network (largest architecture)
Progressively train sub-networks (elastic depth, width, kernel)
Final supernet supports \(10^{19}\) architectures
Deployment: Extract architecture matching device constraints (latency, memory)
10.2 AutoFormer (Transformer NAS)
Search space: Vision Transformer components
Number of heads
MLP expansion ratio
Embedding dimensions
Depth
Supernet training: Weight entanglement for parameter sharing across architectures
10.3 Neural Architecture Transfer (NAT)
Idea: Transfer learned architecture generators across tasks
Meta-controller: Learns to generate architectures conditioned on task
11. Open Research Questions
Why does DARTS prefer skip connections?
Hypothesis: Gradient flow, easier optimization, overfitting to search data
Mitigation: Regularization, early stopping, fairness constraints
How to improve weight sharing ranking correlation?
Current: \(\rho \approx 0.3\)–\(0.6\)
Goal: \(\rho > 0.8\) for reliable ranking
Approaches: Better training, calibration, progressive training
Can we eliminate the search-evaluation gap?
Problem: 1-5% accuracy drop after re-training
Potential: Direct optimization, better initialization, transfer learning
How to handle diverse hardware efficiently?
Challenge: Device-specific latency models
Direction: Universal latency predictors, neural latency models
What are fundamental limits of architecture search?
Theory: Sample complexity, generalization bounds
Practice: Minimum search budget, maximum transferability
Key Papers and Timeline
Foundation (2016-2018):
Zoph & Le 2017: RL-based search; Zoph et al. 2018: NASNet - cell-based search space
Pham et al. 2018: ENAS - Efficient weight sharing
Real et al. 2019: AmoebaNet - Evolutionary search
Differentiable Search (2018-2019):
Liu et al. 2019: DARTS - Gradient-based, bi-level optimization
Xu et al. 2020: PC-DARTS - Partial channel, stability
Dong & Yang 2019: GDAS - Gumbel-softmax sampling
Efficiency (2019-2021):
Cai et al. 2020: Once-for-All - Progressive shrinking, diverse hardware
Chen et al. 2021: AutoFormer - Transformer search
Mellor et al. 2021: NASWOT - Zero-cost proxy
Multi-Objective (2020-2022):
Tan et al. 2019: EfficientNet - Compound scaling
Chu et al. 2021: FairNAS - Fairness-aware search
Wu et al. 2019: FBNet - Hardware-aware differentiable
Implementation Complexity:
Basic DARTS: ~300 lines (PyTorch)
Full NAS framework: ~2000 lines (search space, controller, evaluator)
Production system: ~10,000+ lines (distributed training, experiment management)
1. DARTS: Differentiable Architecture Search
Continuous Relaxation
Mixed operation:
\[\bar{o}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o)}{\sum_{o'} \exp(\alpha_{o'})}\, o(x)\]
Bi-level Optimization
# Define primitive operations
OPS = {
'none': lambda C: Zero(),
'skip': lambda C: Identity(),
'conv_3x3': lambda C: Conv(C, C, 3, 1, 1),
'conv_5x5': lambda C: Conv(C, C, 5, 1, 2),
'pool_3x3': lambda C: nn.MaxPool2d(3, 1, 1),
}
class Zero(nn.Module):
def forward(self, x):
return x * 0
class Identity(nn.Module):
def forward(self, x):
return x
class Conv(nn.Module):
def __init__(self, C_in, C_out, kernel, stride, padding):
super().__init__()
self.op = nn.Sequential(
nn.Conv2d(C_in, C_out, kernel, stride, padding, bias=False),
nn.BatchNorm2d(C_out),
nn.ReLU()
)
def forward(self, x):
return self.op(x)
print(f"Defined {len(OPS)} operations")
Mixed Operation
In DARTS (Differentiable Architecture Search), the discrete choice among candidate operations is relaxed into a continuous mixture. Each edge in the search cell maintains a weighted combination of all candidate operations (convolutions, pooling, skip connections, zero), with learnable architecture weights \(\alpha\) controlling the mixing: \(\bar{o}(x) = \sum_o \frac{\exp(\alpha_o)}{\sum_{o'} \exp(\alpha_{o'})} o(x)\). This softmax-weighted sum makes the architecture differentiable with respect to \(\alpha\), enabling gradient-based optimization of both the architecture and the model weights simultaneously.
class MixedOp(nn.Module):
"""Mixed operation with architecture weights."""
def __init__(self, C):
super().__init__()
self.ops = nn.ModuleList()
for name, op in OPS.items():
self.ops.append(op(C))
def forward(self, x, weights):
"""Apply weighted combination of operations."""
return sum(w * op(x) for w, op in zip(weights, self.ops))
print("MixedOp defined")
Search Cell
The search cell is a directed acyclic graph (DAG) of nodes, where each node aggregates features from previous nodes through mixed operations on the connecting edges. Typically, a cell has 2 input nodes (from previous cells), several intermediate nodes, and a concatenation output. All cells in the network share the same architecture, so searching for one cell structure defines the entire network. DARTS searches for both a normal cell (standard feature processing) and a reduction cell (spatial downsampling), which are then stacked to form the final architecture.
class Cell(nn.Module):
"""Searchable cell with mixed operations."""
def __init__(self, C, n_nodes=4):
super().__init__()
self.n_nodes = n_nodes
# Mixed operations for each edge
self.ops = nn.ModuleList()
for i in range(n_nodes):
for j in range(i + 2): # Connect to previous nodes
self.ops.append(MixedOp(C))
def forward(self, x, alphas):
states = [x, x] # Initial states
offset = 0
for i in range(self.n_nodes):
# Aggregate from all previous nodes
s = sum(self.ops[offset + j](h, F.softmax(alphas[offset + j], dim=0))
for j, h in enumerate(states))
offset += len(states)
states.append(s)
# Concatenate intermediate nodes
return torch.cat(states[2:], dim=1)
print("Cell defined")
Search Network
The search network stacks multiple copies of the search cell (with shared architecture weights) to form a full network for the proxy task. During the search phase, this network is typically smaller than the final model (fewer cells, fewer channels) to keep the search computationally affordable. The network's classification accuracy on the proxy task serves as a signal for the quality of the architecture, though there can be discrepancies between proxy performance and final performance, a limitation known as the search-evaluation gap.
class SearchNetwork(nn.Module):
def __init__(self, C=16, n_cells=3, n_nodes=3, n_classes=10):
super().__init__()
self.C = C
self.n_cells = n_cells
self.n_nodes = n_nodes
# Initial convolution
self.stem = nn.Sequential(
nn.Conv2d(1, C, 3, 1, 1, bias=False),
nn.BatchNorm2d(C)
)
        # Stacked cells: each cell outputs C * n_nodes channels, so a 1x1
        # conv projects back to C channels between cells
        self.cells = nn.ModuleList()
        self.projs = nn.ModuleList()
        for i in range(n_cells):
            self.cells.append(Cell(C, n_nodes))
            if i < n_cells - 1:
                self.projs.append(nn.Conv2d(C * n_nodes, C, 1, bias=False))
        # Classifier
        self.classifier = nn.Linear(C * n_nodes, n_classes)
        # Architecture parameters
        self._init_alphas()

    def _init_alphas(self):
        """Initialize architecture parameters."""
        n_ops = len(OPS)
        n_edges = sum(2 + i for i in range(self.n_nodes))
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.randn(n_ops)) for _ in range(n_edges)
        ])

    def forward(self, x):
        x = self.stem(x)
        for i, cell in enumerate(self.cells):
            x = cell(x, self.alphas)
            if i < len(self.cells) - 1:
                x = self.projs[i](x)  # restore C channels for the next cell
        x = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)
        return self.classifier(x)
def arch_parameters(self):
return self.alphas
def model_parameters(self):
return [p for n, p in self.named_parameters() if 'alphas' not in n]
print("SearchNetwork defined")
DARTS Training
DARTS uses a bilevel optimization procedure: the model weights \(w\) are optimized on the training set, while the architecture weights \(\alpha\) are optimized on the validation set. In practice, these two updates alternate within each training step: one gradient step on \(w\) (standard cross-entropy loss), then one gradient step on \(\alpha\) (validation loss). This bilevel approach prevents the architecture from overfitting to the training data. The entire search typically requires only a few GPU-days, compared to weeks or months for reinforcement-learning-based NAS methods.
# Data
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)
# Split train into train/val for bi-level optimization
n_train = int(0.8 * len(train_data))
n_val = len(train_data) - n_train
train_subset, val_subset = torch.utils.data.random_split(train_data, [n_train, n_val])
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_subset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000)
print(f"Train: {len(train_subset)}, Val: {len(val_subset)}, Test: {len(test_data)}")
def train_darts(model, train_loader, val_loader, n_epochs=10):
# Separate optimizers
optimizer_w = torch.optim.Adam(model.model_parameters(), lr=3e-3)
optimizer_a = torch.optim.Adam(model.arch_parameters(), lr=3e-4)
train_losses = []
val_losses = []
for epoch in range(n_epochs):
model.train()
train_iter = iter(train_loader)
val_iter = iter(val_loader)
epoch_train_loss = 0
epoch_val_loss = 0
for step in range(min(len(train_loader), len(val_loader))):
# Update architecture (alpha)
try:
x_val, y_val = next(val_iter)
except StopIteration:
val_iter = iter(val_loader)
x_val, y_val = next(val_iter)
x_val, y_val = x_val.to(device), y_val.to(device)
optimizer_a.zero_grad()
output = model(x_val)
loss_val = F.cross_entropy(output, y_val)
loss_val.backward()
optimizer_a.step()
epoch_val_loss += loss_val.item()
# Update weights (w)
try:
x_train, y_train = next(train_iter)
except StopIteration:
train_iter = iter(train_loader)
x_train, y_train = next(train_iter)
x_train, y_train = x_train.to(device), y_train.to(device)
optimizer_w.zero_grad()
output = model(x_train)
loss_train = F.cross_entropy(output, y_train)
loss_train.backward()
optimizer_w.step()
epoch_train_loss += loss_train.item()
n_steps = min(len(train_loader), len(val_loader))
train_losses.append(epoch_train_loss / n_steps)
val_losses.append(epoch_val_loss / n_steps)
print(f"Epoch {epoch+1}: Train={train_losses[-1]:.4f}, Val={val_losses[-1]:.4f}")
return train_losses, val_losses
# Train
model = SearchNetwork(C=16, n_cells=2, n_nodes=3).to(device)
train_losses, val_losses = train_darts(model, train_loader, val_loader, n_epochs=8)
Extract Architecture
After the search phase, the continuous architecture weights \(\alpha\) are discretized by selecting the operation with the highest weight on each edge: \(o^* = \arg\max_o \alpha_o\). Edges with only weak connections (low maximum \(\alpha\)) may be pruned to produce a cleaner cell topology. The resulting discrete architecture is then retrained from scratch at full scale (more cells, more channels, longer training), which typically yields significantly higher accuracy than the search network. Visualizing the extracted cell structure as a DAG diagram provides an interpretable summary of what DARTS has discovered.
def parse_architecture(model):
"""Extract discrete architecture from alphas."""
genotype = []
for alpha in model.alphas:
weights = F.softmax(alpha, dim=0)
best_op = weights.argmax().item()
op_name = list(OPS.keys())[best_op]
genotype.append((op_name, weights[best_op].item()))
return genotype
arch = parse_architecture(model)
print("\nDiscovered Architecture:")
for i, (op, weight) in enumerate(arch):
    print(f"Edge {i}: {op} (α={weight:.3f})")
Visualize Results
Visualizing the evolution of architecture weights during the search reveals how DARTS progressively eliminates weak operations and converges on a specific cell design. Comparing the searched architecture's accuracy against hand-designed baselines quantifies the value of automated architecture search. The search cost (GPU hours) versus the accuracy improvement provides a practical cost-benefit analysis for whether NAS is worthwhile for a given problem domain.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Training curves
axes[0].plot(train_losses, 'b-o', label='Train', markersize=5)
axes[0].plot(val_losses, 'r-o', label='Validation', markersize=5)
axes[0].set_xlabel('Epoch', fontsize=11)
axes[0].set_ylabel('Loss', fontsize=11)
axes[0].set_title('DARTS Training', fontsize=12)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Architecture weights
op_names = list(OPS.keys())
alpha_matrix = torch.stack([F.softmax(a, dim=0) for a in model.alphas]).detach().cpu().numpy()
im = axes[1].imshow(alpha_matrix.T, aspect='auto', cmap='Blues')
axes[1].set_xlabel('Edge', fontsize=11)
axes[1].set_ylabel('Operation', fontsize=11)
axes[1].set_title('Architecture Weights', fontsize=12)
axes[1].set_yticks(range(len(op_names)))
axes[1].set_yticklabels(op_names)
plt.colorbar(im, ax=axes[1])
plt.tight_layout()
plt.show()
Summary
DARTS Key Ideas:
Continuous relaxation of discrete search space
Mixed operations with softmax weights
Bi-level optimization: architecture parameters α and weights w
Gradient-based search (efficient)
Search Space:
Operations: conv 3×3, conv 5×5, skip, pool, none
Cell-based structure
DAG with mixed edges
Advantages:
1000× faster than RL/evolution
End-to-end differentiable
Transfer to larger datasets
Variants:
PC-DARTS: Partial channel sampling
GDAS: Gumbel-softmax sampling
SNAS: Stochastic relaxation
ProxylessNAS: Memory-efficient
# ============================================================================
# Advanced NAS Implementations
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Categorical
from copy import deepcopy
import time
torch.manual_seed(42)
np.random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# ============================================================================
# 1. PC-DARTS: Partial Channel Connections
# ============================================================================
class PCDARTSMixedOp(nn.Module):
"""
Mixed operation with partial channel sampling for memory efficiency.
Memory reduction: O(C/K) instead of O(C) where K is channel divisor.
"""
def __init__(self, C, K=4):
super().__init__()
self.K = K # Channel sampling factor
self.C = C
self.C_sampled = C // K
# Build operations for sampled channels
self.ops = nn.ModuleList()
for name, op_fn in OPS.items():
if name in ['none', 'skip']:
self.ops.append(op_fn(C)) # Full channels for param-free ops
else:
self.ops.append(op_fn(self.C_sampled)) # Sampled channels
def forward(self, x, weights, edge_norm=None):
"""
Forward with partial channel sampling.
Args:
x: Input tensor [B, C, H, W]
weights: Architecture weights [|O|]
edge_norm: Edge normalization factor (PC-DARTS)
"""
# Sample random channels
channel_indices = torch.randperm(self.C)[:self.C_sampled]
x_sampled = x[:, channel_indices, :, :]
# Apply operations
outputs = []
for i, (op, w) in enumerate(zip(self.ops, weights)):
if i < 2: # none, skip (param-free, use full channels)
out = op(x)
else: # parameterized ops (use sampled channels)
out_sampled = op(x_sampled)
# Expand back to full channels
out = torch.zeros_like(x)
out[:, channel_indices, :, :] = out_sampled
# Apply edge normalization if provided
if edge_norm is not None:
w = w * edge_norm
outputs.append(w * out)
return sum(outputs)
class PCDARTSCell(nn.Module):
"""Cell with PC-DARTS mixed operations."""
def __init__(self, C, n_nodes=4, K=4):
super().__init__()
self.n_nodes = n_nodes
self.K = K
# Mixed operations for each edge
self.ops = nn.ModuleList()
for i in range(n_nodes):
for j in range(i + 2): # Connect to previous nodes
self.ops.append(PCDARTSMixedOp(C, K))
def forward(self, x, alphas):
states = [x, x]
offset = 0
for i in range(self.n_nodes):
# Compute edge normalization (sum of alphas)
edge_norms = []
for j in range(len(states)):
alpha = alphas[offset + j]
edge_norm = alpha.sum() # Sum over operations
edge_norms.append(edge_norm)
# Aggregate from all previous nodes
s = sum(
self.ops[offset + j](h, F.softmax(alphas[offset + j], dim=0), edge_norms[j])
for j, h in enumerate(states)
)
offset += len(states)
states.append(s)
return torch.cat(states[2:], dim=1)
# ============================================================================
# 2. GDAS: Gumbel-Softmax Differentiable Search
# ============================================================================
class GumbelSoftmax(nn.Module):
"""
Gumbel-Softmax for differentiable discrete sampling.
    τ → 0: Approaches one-hot (discrete)
    τ → ∞: Approaches uniform distribution
"""
def __init__(self, tau_min=0.1, tau_max=10.0):
super().__init__()
self.tau = tau_max # Temperature (annealed during training)
self.tau_min = tau_min
self.tau_max = tau_max
def anneal_temperature(self, progress):
"""
        Anneal temperature: τ(t) = τ_max · (τ_min/τ_max)^progress
Args:
progress: Training progress in [0, 1]
"""
self.tau = self.tau_max * (self.tau_min / self.tau_max) ** progress
def sample_gumbel(self, shape):
"""Sample from Gumbel(0, 1): -log(-log(U)) where U ~ Uniform(0,1)"""
U = torch.rand(shape, device=device)
return -torch.log(-torch.log(U + 1e-20) + 1e-20)
def forward(self, logits, hard=False):
"""
Gumbel-Softmax sampling.
Args:
logits: Architecture parameters [|O|]
hard: If True, return one-hot (discrete) in forward, soft in backward
Returns:
Sampled probabilities (soft or hard)
"""
gumbel_noise = self.sample_gumbel(logits.shape)
y = logits + gumbel_noise
y_soft = F.softmax(y / self.tau, dim=0)
if hard:
# Straight-through estimator
y_hard = torch.zeros_like(y_soft)
y_hard[y_soft.argmax()] = 1.0
# y_hard in forward, y_soft gradient in backward
y = (y_hard - y_soft).detach() + y_soft
else:
y = y_soft
return y
class GDASMixedOp(nn.Module):
"""Mixed operation with Gumbel-Softmax sampling (single-path)."""
def __init__(self, C):
super().__init__()
self.ops = nn.ModuleList()
for name, op_fn in OPS.items():
self.ops.append(op_fn(C))
self.gumbel = GumbelSoftmax()
def forward(self, x, alpha, hard=False):
"""
Single-path forward (memory-efficient).
Only one operation is executed based on Gumbel sampling.
"""
# Sample operation weights
weights = self.gumbel(alpha, hard=hard)
if hard:
# Single-path: Execute only sampled operation
selected_op = weights.argmax().item()
return self.ops[selected_op](x)
else:
# Multi-path: Weighted sum (used during evaluation)
return sum(w * op(x) for w, op in zip(weights, self.ops))
# ============================================================================
# 3. FairNAS: Fairness-Aware Search
# ============================================================================
def compute_fairness_loss(alphas):
"""
Compute variance of architecture parameters to encourage exploration.
    L_fair = Σ_edges Var(α^edge) = Σ_edges (1/|O|) Σ_o (α_o - ᾱ)²
"""
fairness_loss = 0.0
for alpha in alphas:
mean_alpha = alpha.mean()
variance = ((alpha - mean_alpha) ** 2).mean()
fairness_loss += variance
return fairness_loss
class FairNASNetwork(nn.Module):
"""
Search network with fairness regularization.
Prevents skip connections from dominating the search.
"""
def __init__(self, C=16, n_cells=2, n_nodes=3, n_classes=10, lambda_fair=0.1):
super().__init__()
self.C = C
self.n_cells = n_cells
self.n_nodes = n_nodes
self.lambda_fair = lambda_fair
# Network structure (same as DARTS)
self.stem = nn.Sequential(
nn.Conv2d(1, C, 3, 1, 1, bias=False),
nn.BatchNorm2d(C)
)
        # Each cell outputs C * n_nodes channels; project back to C between cells
        self.cells = nn.ModuleList()
        self.projs = nn.ModuleList()
        for i in range(n_cells):
            self.cells.append(Cell(C, n_nodes))
            if i < n_cells - 1:
                self.projs.append(nn.Conv2d(C * n_nodes, C, 1, bias=False))
        self.classifier = nn.Linear(C * n_nodes, n_classes)
        # Architecture parameters
        self._init_alphas()

    def _init_alphas(self):
        n_ops = len(OPS)
        n_edges = sum(2 + i for i in range(self.n_nodes))
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.randn(n_ops)) for _ in range(n_edges)
        ])

    def forward(self, x):
        x = self.stem(x)
        for i, cell in enumerate(self.cells):
            x = cell(x, self.alphas)
            if i < len(self.cells) - 1:
                x = self.projs[i](x)  # restore C channels for the next cell
        x = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)
        return self.classifier(x)
def loss(self, x, y):
"""
Combined loss: task loss + fairness regularization.
        L = L_task + λ_fair * L_fair
"""
logits = self(x)
task_loss = F.cross_entropy(logits, y)
fairness_loss = compute_fairness_loss(self.alphas)
total_loss = task_loss + self.lambda_fair * fairness_loss
return total_loss, task_loss.item(), fairness_loss.item()
# ============================================================================
# 4. Zero-Cost Proxy: SNIP (Single-Shot Network Pruning)
# ============================================================================
def compute_snip_score(model, dataloader, num_batches=1):
"""
    Compute SNIP score: Σ_θ |θ · ∇_θ L|
Measures parameter importance without training.
"""
model.train()
# Accumulate gradients over few batches
for i, (x, y) in enumerate(dataloader):
if i >= num_batches:
break
x, y = x.to(device), y.to(device)
model.zero_grad()
output = model(x)
loss = F.cross_entropy(output, y)
loss.backward()
    # Compute saliency: |θ · ∇_θ L|
total_score = 0.0
for param in model.parameters():
if param.grad is not None:
score = torch.abs(param * param.grad).sum().item()
total_score += score
return total_score
def rank_architectures_snip(search_space, dataloader, top_k=5):
"""
Rank architectures using SNIP without training.
Args:
search_space: List of architecture configurations
dataloader: Data for computing gradients
top_k: Number of top architectures to return
Returns:
List of (architecture, score) sorted by score
"""
scores = []
for i, arch_config in enumerate(search_space):
# Build model from architecture config
model = build_model_from_config(arch_config).to(device)
# Compute SNIP score
score = compute_snip_score(model, dataloader, num_batches=3)
scores.append((arch_config, score))
print(f"Architecture {i+1}/{len(search_space)}: SNIP={score:.2f}")
# Sort by score (higher is better)
scores.sort(key=lambda x: x[1], reverse=True)
return scores[:top_k]
def build_model_from_config(config):
"""
Build model from architecture configuration.
Config format: [(layer_type, params), ...]
"""
# Simplified: Return a basic model
# In practice, parse config to build custom architecture
return SearchNetwork(C=16, n_cells=2, n_nodes=3)
# ============================================================================
# 5. Evolutionary Search
# ============================================================================
class ArchitectureGenome:
"""
Genome representing an architecture for evolutionary search.
Encoding: List of operation IDs for each edge.
"""
def __init__(self, n_edges, n_ops):
self.n_edges = n_edges
self.n_ops = n_ops
self.genes = np.random.randint(0, n_ops, size=n_edges)
self.fitness = None
def mutate(self, mutation_rate=0.1):
"""Randomly change operations with given probability."""
mutated = ArchitectureGenome(self.n_edges, self.n_ops)
mutated.genes = self.genes.copy()
for i in range(self.n_edges):
if np.random.rand() < mutation_rate:
mutated.genes[i] = np.random.randint(0, self.n_ops)
return mutated
def crossover(self, other):
"""Single-point crossover with another genome."""
child = ArchitectureGenome(self.n_edges, self.n_ops)
crossover_point = np.random.randint(1, self.n_edges)  # >= 1 so both parents contribute
child.genes[:crossover_point] = self.genes[:crossover_point]
child.genes[crossover_point:] = other.genes[crossover_point:]
return child
class EvolutionarySearch:
"""
Evolutionary algorithm for architecture search.
Population-based search with mutation and crossover.
"""
def __init__(self, n_edges, n_ops, population_size=20, n_generations=10):
self.n_edges = n_edges
self.n_ops = n_ops
self.population_size = population_size
self.n_generations = n_generations
# Initialize random population
self.population = [
ArchitectureGenome(n_edges, n_ops)
for _ in range(population_size)
]
def evaluate_fitness(self, genome, train_loader, val_loader):
"""
Train model and evaluate accuracy (fitness).
In practice, use early stopping (e.g., 10 epochs).
"""
# Convert genome to architecture
model = self.genome_to_model(genome)
# Quick training (simplified)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()
for epoch in range(3): # Fast evaluation
for x, y in train_loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
# Evaluate accuracy
model.eval()
correct = 0
total = 0
with torch.no_grad():
for x, y in val_loader:
x, y = x.to(device), y.to(device)
pred = model(x).argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
accuracy = 100.0 * correct / total
return accuracy
def genome_to_model(self, genome):
"""Convert genome to PyTorch model (simplified)."""
# In practice: Build custom model based on genes
return SearchNetwork(C=16, n_cells=2, n_nodes=3).to(device)
def tournament_selection(self, tournament_size=3):
"""Select parent via tournament (best among random subset)."""
tournament = np.random.choice(self.population, tournament_size, replace=False)
return max(tournament, key=lambda g: g.fitness)
def evolve(self, train_loader, val_loader):
"""
Run evolutionary search.
Returns best genome found.
"""
history = []
for generation in range(self.n_generations):
print(f"\nGeneration {generation + 1}/{self.n_generations}")
# Evaluate population
for i, genome in enumerate(self.population):
if genome.fitness is None:
fitness = self.evaluate_fitness(genome, train_loader, val_loader)
genome.fitness = fitness
print(f" Individual {i+1}: Fitness={fitness:.2f}%")
# Track best
best_genome = max(self.population, key=lambda g: g.fitness)
avg_fitness = np.mean([g.fitness for g in self.population])
history.append({'best': best_genome.fitness, 'avg': avg_fitness})
print(f" Best: {best_genome.fitness:.2f}%, Avg: {avg_fitness:.2f}%")
# Create next generation
new_population = []
# Elitism: Keep best individual
new_population.append(best_genome)
# Generate offspring
while len(new_population) < self.population_size:
# Selection
parent1 = self.tournament_selection()
parent2 = self.tournament_selection()
# Crossover
child = parent1.crossover(parent2)
# Mutation
child = child.mutate(mutation_rate=0.2)
new_population.append(child)
self.population = new_population
# Return best evaluated genome (offspring created in the final
# generation still have fitness None, so exclude them)
evaluated = [g for g in self.population if g.fitness is not None]
best = max(evaluated, key=lambda g: g.fitness)
return best, history
# ============================================================================
# 6. Demonstration: Compare Search Methods
# ============================================================================
print("\n" + "="*80)
print("Comparing Neural Architecture Search Methods")
print("="*80)
# Use existing data loaders from DARTS section above
# train_loader, val_loader, test_loader defined earlier
# ============================================================================
# 6.1 PC-DARTS Demo
# ============================================================================
print("\n--- PC-DARTS (Partial Channel) ---")
class PCDARTSNet(nn.Module):
"""Simplified PC-DARTS network for demo."""
def __init__(self, C=16, K=4):
super().__init__()
self.stem = nn.Sequential(
nn.Conv2d(1, C, 3, 1, 1, bias=False),
nn.BatchNorm2d(C)
)
self.cell = PCDARTSCell(C, n_nodes=2, K=K)
self.classifier = nn.Linear(C * 2, 10)
# Architecture params
n_ops = len(OPS)
n_edges = 1 + 2 # 2 nodes: 1 edge into node 0, 2 edges into node 1
self.alphas = nn.ParameterList([
nn.Parameter(torch.randn(n_ops)) for _ in range(n_edges)
])
def forward(self, x):
x = self.stem(x)
x = self.cell(x, self.alphas)
x = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)
return self.classifier(x)
pc_model = PCDARTSNet(C=16, K=4).to(device)
# Quick training (3 epochs)
optimizer = torch.optim.Adam(pc_model.parameters(), lr=1e-3)
for epoch in range(3):
pc_model.train()
total_loss = 0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
loss = F.cross_entropy(pc_model(x), y)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}")
# Extract architecture
print("\nExtracted Architecture (PC-DARTS):")
for i, alpha in enumerate(pc_model.alphas):
weights = F.softmax(alpha, dim=0)
best_op = weights.argmax().item()
op_name = list(OPS.keys())[best_op]
print(f" Edge {i}: {op_name} (weight={weights[best_op].item():.3f})")
# ============================================================================
# 6.2 GDAS Demo
# ============================================================================
print("\n--- GDAS (Gumbel-Softmax) ---")
# Demonstrate Gumbel-Softmax sampling
gumbel = GumbelSoftmax(tau_min=0.1, tau_max=5.0)
logits = torch.tensor([1.0, 0.5, -0.5, 2.0, 0.0]) # Architecture weights
print("Temperature annealing:")
for progress in [0.0, 0.25, 0.5, 0.75, 1.0]:
gumbel.anneal_temperature(progress)
soft = gumbel(logits, hard=False)
hard = gumbel(logits, hard=True)
print(f" Progress={progress:.2f}, τ={gumbel.tau:.3f}")
print(f" Soft: {soft.numpy()}")
print(f" Hard: {hard.numpy()}")
# ============================================================================
# 6.3 FairNAS Demo
# ============================================================================
print("\n--- FairNAS (Fairness-Aware) ---")
# Demonstrate fairness loss
alphas_unfair = [torch.tensor([5.0, 0.1, 0.1, 0.1, 0.1])] # Skip dominates
alphas_fair = [torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])] # Balanced
loss_unfair = compute_fairness_loss(alphas_unfair)
loss_fair = compute_fairness_loss(alphas_fair)
print(f"Fairness loss (unfair): {loss_unfair:.4f}")
print(f"Fairness loss (fair): {loss_fair:.4f}")
# Quick FairNAS training
fair_model = FairNASNetwork(C=16, n_cells=2, n_nodes=2, lambda_fair=0.2).to(device)
optimizer_fair = torch.optim.Adam(fair_model.parameters(), lr=1e-3)
print("\nTraining FairNAS:")
for epoch in range(3):
fair_model.train()
epoch_task_loss = 0
epoch_fair_loss = 0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
optimizer_fair.zero_grad()
total_loss, task_loss, fair_loss = fair_model.loss(x, y)
total_loss.backward()
optimizer_fair.step()
epoch_task_loss += task_loss
epoch_fair_loss += fair_loss
n = len(train_loader)
print(f"Epoch {epoch+1}: Task={epoch_task_loss/n:.4f}, Fair={epoch_fair_loss/n:.4f}")
# ============================================================================
# 6.4 SNIP Zero-Cost Proxy Demo
# ============================================================================
print("\n--- SNIP (Zero-Cost Proxy) ---")
# Create small search space
search_space = [
{'type': 'small', 'C': 8, 'cells': 2},
{'type': 'medium', 'C': 16, 'cells': 2},
{'type': 'large', 'C': 32, 'cells': 3},
]
print("Ranking architectures with SNIP (no training):")
# Get small subset of data for SNIP
snip_loader = torch.utils.data.DataLoader(
torch.utils.data.Subset(train_subset, range(100)),
batch_size=32, shuffle=False
)
snip_scores = []
for i, config in enumerate(search_space):
model_snip = SearchNetwork(C=config['C'], n_cells=config['cells'], n_nodes=2).to(device)
score = compute_snip_score(model_snip, snip_loader, num_batches=2)
snip_scores.append((config['type'], score))
print(f"{config['type']:10s}: SNIP score = {score:.2f}")
snip_scores.sort(key=lambda x: x[1], reverse=True)
print(f"\nBest architecture (SNIP): {snip_scores[0][0]}")
# ============================================================================
# 6.5 Visualization: Method Comparison
# ============================================================================
print("\n" + "="*80)
print("Summary: NAS Methods Comparison")
print("="*80)
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# 1. PC-DARTS memory efficiency
methods = ['DARTS', 'PC-DARTS\n(K=2)', 'PC-DARTS\n(K=4)', 'GDAS']
memory = [100, 50, 25, 12.5] # Relative memory usage
axes[0, 0].bar(methods, memory, color=['red', 'orange', 'green', 'blue'], alpha=0.7)
axes[0, 0].set_ylabel('Memory Usage (%)', fontsize=11)
axes[0, 0].set_title('Memory Efficiency', fontsize=12, fontweight='bold')
axes[0, 0].axhline(y=50, color='gray', linestyle='--', linewidth=1, label='50% baseline')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3, axis='y')
# 2. Gumbel temperature annealing
progress_vals = np.linspace(0, 1, 50)
tau_max, tau_min = 5.0, 0.1
tau_curve = tau_max * (tau_min / tau_max) ** progress_vals
axes[0, 1].plot(progress_vals, tau_curve, 'b-', linewidth=2)
axes[0, 1].fill_between(progress_vals, tau_curve, alpha=0.3)
axes[0, 1].set_xlabel('Training Progress', fontsize=11)
axes[0, 1].set_ylabel('Temperature τ', fontsize=11)
axes[0, 1].set_title('GDAS Temperature Annealing', fontsize=12, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)
axes[0, 1].set_ylim([0, tau_max * 1.1])
# 3. Fairness loss effect
epochs_demo = np.arange(1, 11)
unfair_arch = 0.8 - 0.05 * epochs_demo # Skip connection dominates, lower diversity
fair_arch = 0.6 + 0.02 * epochs_demo # FairNAS maintains diversity
axes[1, 0].plot(epochs_demo, unfair_arch, 'r-o', label='No Fairness', linewidth=2, markersize=6)
axes[1, 0].plot(epochs_demo, fair_arch, 'g-s', label='With Fairness', linewidth=2, markersize=6)
axes[1, 0].set_xlabel('Epoch', fontsize=11)
axes[1, 0].set_ylabel('Architecture Diversity', fontsize=11)
axes[1, 0].set_title('FairNAS: Exploration vs. Exploitation', fontsize=12, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
# 4. SNIP ranking correlation
# Simulate: SNIP score vs. true accuracy correlation
np.random.seed(42)
true_acc = np.random.rand(20) * 20 + 70 # True accuracy 70-90%
snip_score_sim = true_acc + np.random.randn(20) * 3 # Correlated with noise
axes[1, 1].scatter(snip_score_sim, true_acc, alpha=0.6, s=80, c='purple', edgecolors='black')
axes[1, 1].set_xlabel('SNIP Score', fontsize=11)
axes[1, 1].set_ylabel('True Accuracy (%)', fontsize=11)
axes[1, 1].set_title('Zero-Cost Proxy Correlation', fontsize=12, fontweight='bold')
# Add fitted line with the actual sample correlation
z = np.polyfit(snip_score_sim, true_acc, 1)
p = np.poly1d(z)
x_line = np.linspace(snip_score_sim.min(), snip_score_sim.max(), 100)
rho = np.corrcoef(snip_score_sim, true_acc)[0, 1]
axes[1, 1].plot(x_line, p(x_line), 'r--', linewidth=2, label=f'Linear fit (ρ={rho:.2f})')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\n" + "="*80)
print("Key Takeaways:")
print("="*80)
print("1. PC-DARTS: 4× memory reduction via partial channel sampling")
print("2. GDAS: Single-path search with Gumbel-Softmax (discrete sampling)")
print("3. FairNAS: Regularization prevents skip connection dominance")
print("4. SNIP: Zero-cost proxy enables fast architecture ranking without training")
print("5. Trade-offs: Memory vs. accuracy, speed vs. thoroughness, exploration vs. exploitation")
print("\nNext: Apply NAS to custom search spaces and hardware-aware optimization!")
Advanced Neural Architecture Search TheoryΒΆ
1. Introduction to Neural Architecture Search (NAS)ΒΆ
1.1 MotivationΒΆ
Traditional approach: Manual architecture design
Requires expert knowledge
Time-consuming trial-and-error
May miss optimal designs
NAS approach: Automated architecture discovery
Search space: Set of possible architectures
Search strategy: How to explore space
Performance estimation: How to evaluate architectures
Success stories:
NASNet (2017): Surpassed human-designed architectures on ImageNet
EfficientNet (2019): SOTA accuracy with 8.4× fewer parameters
AmoebaNet (2019): Evolutionary search matched or surpassed hand-designed models on ImageNet
1.2 NAS Framework ComponentsΒΆ
1. Search Space:
A = {α₁, α₂, ..., αₙ} (set of candidate architectures)
Defines:
Macro-search: Entire network topology
Micro-search: Cell structure (repeated building blocks)
Hyperparameters: Depth, width, kernel sizes, etc.
2. Search Strategy: Methods to navigate search space:
Random search: Sample uniformly
Evolutionary algorithms: Mutation + selection
Reinforcement learning: Controller RNN generates architectures
Gradient-based: Differentiable architecture search (DARTS)
Bayesian optimization: Model performance landscape
3. Performance Estimation: How to evaluate architecture Ξ±:
Full training: Train to convergence (expensive!)
Early stopping: Train for few epochs
Weight sharing: One-shot models
Proxy tasks: Smaller datasets, lower resolution
Objective:
α* = argmax_{α ∈ A} Accuracy(α, D_val)
subject to constraints (e.g., FLOPs < budget).
2. Search SpacesΒΆ
2.1 Chain-Structured NetworksΒΆ
Simple sequential: layer₁ → layer₂ → … → layer_L
Search choices per layer:
Operation: Conv, MaxPool, Identity, Zero
Kernel size: 3×3, 5×5, 7×7
Filters: 16, 32, 64, 128
Activation: ReLU, Swish, Mish
Size: Exponential in depth
|A| = |O|^L with |O| candidate operations and L layers
2.2 Cell-Based Search SpaceΒΆ
Motivation: Reduce search space by searching for repeatable cells.
Architecture = Stack of cells:
Input → Cell₁ → Cell₂ → ... → Cell_N → Output
Cell structure (DAG):
Nodes: Feature maps at different resolutions
Edges: Operations (Conv, Pool, etc.)
Input nodes: Previous cell outputs
Output node: Concatenate intermediate nodes
NASNet search space:
Normal cell: Preserves spatial dimensions
Reduction cell: Downsamples (stride 2)
Size:
|A| ≈ |O|^E where E = edges per cell (much smaller than |O|^L)
2.3 Hierarchical Search SpacesΒΆ
Multiple levels:
Micro: Operations within cell
Meso: Cell connectivity
Macro: Number of cells, channels
Example (EfficientNet):
Depth: d ∈ [1.0, 2.0, 3.0, …]
Width: w ∈ [1.0, 1.5, 2.0, …]
Resolution: r ∈ [224, 260, 300, …]
Compound scaling:
depth: d = α^φ
width: w = β^φ
resolution: r = γ^φ
subject to α·β²·γ² ≈ 2 (resource constraint).
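The compound-scaling rule is easy to sanity-check numerically. A minimal sketch, assuming the coefficients reported for EfficientNet (α = 1.2, β = 1.1, γ = 1.15, found by grid search):

```python
# EfficientNet-style compound scaling. The coefficients below are the
# values reported in the EfficientNet paper; treat them as an assumption.
def compound_scale(phi, alpha=1.2, beta=1.1, gamma=1.15):
    """Return (depth, width, resolution) multipliers for coefficient phi."""
    return alpha ** phi, beta ** phi, gamma ** phi

d, w, r = compound_scale(phi=2)

# The constraint alpha * beta^2 * gamma^2 ~= 2 means each increment of phi
# roughly doubles total FLOPs (FLOPs scale with d * w^2 * r^2).
flops_factor = 1.2 * 1.1 ** 2 * 1.15 ** 2
```

With these values the FLOPs factor comes out to about 1.92, i.e. close to the target of 2 per unit of φ.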
2.4 Operation SpaceΒΆ
Common operations:
Convolutions:
Standard: 3×3, 5×5, 7×7
Depthwise separable (MobileNet-style)
Dilated convolutions (atrous)
Pooling:
Max pooling: 3×3, 5×5
Average pooling
Global average pooling
Skip connections:
Identity
1Γ1 projection
Zero operation:
Prunes the edge
Activations:
ReLU, ReLU6
Swish, Mish
GELU
3. Reinforcement Learning-Based NASΒΆ
3.1 NAS with RL (Zoph & Le, 2017)ΒΆ
Idea: Controller RNN generates architectures, trained via REINFORCE.
Controller RNN:
Outputs sequence of decisions (layer types, hyperparameters)
Example: [Conv 3×3, 64 filters, ReLU, MaxPool 2×2, Conv 5×5, 128 filters, …]
Training procedure:
Sample architecture: α ~ π_θ(α) (controller policy)
Train child network: Get validation accuracy R(α)
Update controller: Maximize expected reward
∇_θ J(θ) = E_{α~π_θ}[R(α) ∇_θ log π_θ(α)]
Reward: Validation accuracy (or accuracy − λ·complexity)
REINFORCE gradient:
∇_θ J(θ) ≈ (1/m) Σᵢ (R(αᵢ) − b) ∇_θ log π_θ(αᵢ)
where b is a baseline (running average of rewards).
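The REINFORCE update can be sketched on a toy controller. This is illustrative only: the "policy" here is a table of logits (one categorical operation choice per layer) rather than the RNN controller from the paper, and the reward is a stand-in constant instead of a trained child network's validation accuracy.

```python
import torch
import torch.nn.functional as F

# Toy NAS controller: one categorical op choice per layer.
n_layers, n_ops = 3, 4
theta = torch.zeros(n_layers, n_ops, requires_grad=True)

def sample_architecture():
    """Sample one operation per layer; return choices and log pi(alpha)."""
    probs = F.softmax(theta, dim=-1)
    ops = torch.multinomial(probs, 1).squeeze(-1)
    log_prob = torch.log(probs[torch.arange(n_layers), ops]).sum()
    return ops, log_prob

def reinforce_step(reward, baseline, log_prob, lr=0.1):
    """Gradient ascent on (R - b) * log pi(alpha)."""
    loss = -(reward - baseline) * log_prob
    loss.backward()
    with torch.no_grad():
        theta.sub_(lr * theta.grad)
        theta.grad.zero_()

ops, log_prob = sample_architecture()
reward = 0.9  # stand-in for the sampled child's validation accuracy
reinforce_step(reward, baseline=0.5, log_prob=log_prob)
# Positive advantage (0.9 - 0.5) pushes up the logits of the sampled ops.
```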
3.2 Efficiency ImprovementsΒΆ
Problem: Each architecture requires full training (expensive!).
Solutions:
Early stopping: Train for 5-10 epochs instead of 100+
Learning curve prediction: Extrapolate from early training
Weight inheritance: Initialize child from parent
Computational cost (original NAS):
800 GPUs for 28 days (22,400 GPU-days!)
Modern methods: 1-4 GPU-days
3.3 Block-Level Search (NASNet)ΒΆ
Search for cell structure, not entire network.
Normal cell:
Input: H_i, H_{i-1} (current and previous cell outputs)
Operations: Combine via learned ops
Output: Concatenate intermediate nodes
Reduction cell:
Similar structure but with stride 2 operations
Final architecture:
Stem → [N × Normal → Reduction] × 3 → Global Pool → FC
Search result (ImageNet):
NASNet-A: 82.7% top-1 accuracy, surpassing the best hand-designed architectures of the time
4. Evolutionary Algorithms for NASΒΆ
4.1 Evolution-Based SearchΒΆ
Inspired by natural selection: Mutation + crossover + survival of fittest.
Algorithm:
Initialize population: P = {α₁, α₂, …, α_pop}
Evaluate fitness: Train each architecture, get accuracy
Selection: Keep top-k (tournament selection)
Mutation: Randomly modify architectures
Add/remove layer
Change operation
Modify hyperparameters
Crossover: Combine two parent architectures
Repeat for T generations
Mutation examples:
Replace Conv 3×3 → Conv 5×5
Insert skip connection
Increase/decrease channels
4.2 Regularized Evolution (Real et al., 2019)ΒΆ
Key idea: Age-based removal instead of fitness-based.
Aging tournament selection:
Sample S architectures from population
Remove oldest architecture
Mutate best of S, add to population
Advantages:
Prevents premature convergence
Maintains diversity
Better exploration
Results:
AmoebaNet: 83.9% ImageNet top-1 (SOTA at time)
450 GPU-days (vs 22,400 for original NAS)
4.3 Crossover StrategiesΒΆ
Single-point crossover:
Parent 1: [Op1, Op2, Op3, Op4, Op5]
Parent 2: [Op6, Op7, Op8, Op9, Op10]
↓ (crossover at position 3)
Child: [Op1, Op2, Op3, Op9, Op10]
Uniform crossover: Each gene (operation) randomly from either parent.
Cell-level crossover: Combine cells from different parents.
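A minimal sketch of uniform crossover on flat operation-ID genomes (the same encoding as the `ArchitectureGenome` class above, which implements the single-point variant):

```python
import numpy as np

# Uniform crossover: each gene comes from either parent with probability
# 0.5, unlike single-point crossover which splits at one position.
def uniform_crossover(parent1, parent2, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    parent1, parent2 = np.asarray(parent1), np.asarray(parent2)
    mask = rng.random(parent1.shape) < 0.5  # True -> gene from parent1
    return np.where(mask, parent1, parent2)

child = uniform_crossover([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])
```

Every gene of the child is guaranteed to come from one of the two parents.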
5. Differentiable Architecture Search (DARTS)ΒΆ
5.1 Continuous RelaxationΒΆ
Problem: Discrete architecture choices → can't use gradient descent.
Solution: Relax discrete choices to continuous (differentiable).
Continuous architecture: Instead of selecting one operation, weight all operations:
f(x) = Σ_{o∈O} (exp(α_o) / Σ_{o'} exp(α_{o'})) · o(x)
where α_o are architecture parameters (learnable).
Softmax over operations:
f(x) = Σ_o softmax(α)_o · o(x)
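The softmax-weighted mixture can be written as a small PyTorch module. This sketch uses a three-operation candidate set for brevity (the full DARTS set also includes 5×5 and separable convolutions, average pooling, and 'none'):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedOp(nn.Module):
    """Softmax-weighted mixture: f(x) = sum_o softmax(alpha)_o * o(x)."""
    def __init__(self, C):
        super().__init__()
        self.ops = nn.ModuleList([
            nn.Conv2d(C, C, 3, padding=1, bias=False),  # conv 3x3
            nn.MaxPool2d(3, stride=1, padding=1),       # max pool 3x3
            nn.Identity(),                              # skip connection
        ])
        # One architecture parameter per candidate operation
        self.alpha = nn.Parameter(torch.zeros(len(self.ops)))

    def forward(self, x):
        weights = F.softmax(self.alpha, dim=0)
        return sum(w * op(x) for w, op in zip(weights, self.ops))

mixed = MixedOp(C=8)
out = mixed(torch.randn(2, 8, 16, 16))  # all ops preserve the spatial size
```

Since `alpha` is an `nn.Parameter`, gradients flow into the architecture weights through the same backward pass as the operation weights.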
5.2 Bi-Level OptimizationΒΆ
Two sets of parameters:
Network weights: w (operation parameters)
Architecture parameters: α (operation selection)
Objective:
min_α L_val(w*(α), α)
s.t. w*(α) = argmin_w L_train(w, α)
Interpretation:
Inner optimization: Train network weights w given architecture α
Outer optimization: Find architecture α that minimizes validation loss
5.3 Approximation via Gradient DescentΒΆ
Challenge: The inner optimization is expensive (requires training w to near convergence).
Approximation: One-step gradient descent
w' = w − ξ ∇_w L_train(w, α)
Architecture gradient:
∇_α L_val(w*(α), α) ≈ ∇_α L_val(w − ξ ∇_w L_train(w, α), α) = ∇_α L_val(w', α)
Chain rule (differentiating through w'):
d/dα L_val(w', α) = ∇_α L_val(w', α) − ξ ∇²_{α,w} L_train(w, α) · ∇_{w'} L_val(w', α)
where ∇²_{α,w} is a mixed second derivative (a Hessian-vector product, expensive!).
Finite difference approximation (with v = ∇_{w'} L_val(w', α)):
∇²_{α,w} L_train · v ≈ (∇_α L_train(w + ε·v, α) − ∇_α L_train(w − ε·v, α)) / (2ε)
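The finite-difference rule can be verified on a toy loss whose mixed derivative is known in closed form; L(w, α) = (wα)² here is an illustrative stand-in for a real training loss:

```python
import torch

# For L(w, alpha) = (w * alpha)^2 the exact mixed derivative is
# d^2 L / (dw dalpha) = 4 * w * alpha.
def loss_fn(w, alpha):
    return (w * alpha) ** 2

def mixed_hvp_fd(w, alpha, v, eps=1e-3):
    """Approximate grad^2_{alpha,w} L . v via central differences in w."""
    def grad_alpha(w_val):
        a = alpha.clone().requires_grad_(True)
        return torch.autograd.grad(loss_fn(w_val, a), a)[0]
    return (grad_alpha(w + eps * v) - grad_alpha(w - eps * v)) / (2 * eps)

w, alpha, v = torch.tensor(2.0), torch.tensor(3.0), torch.tensor(1.0)
approx = mixed_hvp_fd(w, alpha, v)  # exact value: 4 * 2 * 3 = 24
```

Two extra gradient evaluations replace an explicit Hessian, which is what makes second-order DARTS tractable.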
5.4 DiscretizationΒΆ
After search: Continuous α → discrete architecture.
Method: For each edge, select the operation with the largest α_o:
o* = argmax_o α_o
Pruning: Remove edges with weak connections (optional).
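Discretization is a one-liner per edge; a sketch (the operation names are illustrative):

```python
import torch

# Per edge, keep the operation with the largest architecture weight.
OP_NAMES = ['conv3x3', 'conv5x5', 'maxpool3x3', 'skip', 'none']

def discretize(alphas):
    """Map each edge's alpha vector to its strongest operation's name."""
    return [OP_NAMES[a.argmax().item()] for a in alphas]

alphas = [
    torch.tensor([0.1, 0.2, 0.05, 2.0, 0.0]),  # 'skip' dominates
    torch.tensor([1.5, 0.3, 0.2, 0.1, 0.0]),   # 'conv3x3' dominates
]
genotype = discretize(alphas)  # -> ['skip', 'conv3x3']
```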
5.5 DARTS ResultsΒΆ
Efficiency:
Search time: 4 GPU-days (vs 22,400 for NAS)
Search cost: 1000Γ reduction
Performance:
CIFAR-10: 2.76% error (comparable to NASNet)
ImageNet: 73.3% top-1 (competitive)
Advantages:
Gradient-based: Fast, scalable
End-to-end differentiable
No need for RL or evolution
Limitations:
Discretization gap (continuous → discrete)
Overfitting to validation set
Performance collapse (all skip connections)
6. One-Shot NAS and Weight SharingΒΆ
6.1 Supernet TrainingΒΆ
Idea: Train one βsupernetβ containing all possible operations.
Supernet:
f_super(x) = Σ_{path ∈ Paths} I[path sampled] · f_path(x)
where paths correspond to different architectures.
Weight sharing: Operations share weights across architectures.
Training:
Sample architecture (path) from supernet
Train on mini-batch
Update only the sampled path's weights
Repeat
Search: Evaluate architectures using shared weights (no retraining!).
6.2 ENAS (Efficient NAS)ΒΆ
Idea: Weight sharing + RL controller.
Algorithm:
Controller: Sample architecture α ~ π_θ(α)
Train child: Use shared weights from supernet, train briefly
Evaluate: Get validation accuracy R(α)
Update controller: REINFORCE gradient
Update supernet: Train shared weights on sampled architectures
Speedup: 1000× faster than original NAS (about 1 GPU-day).
Results:
CIFAR-10: 2.89% error
Penn Treebank: 55.8 perplexity (language modeling)
6.3 Single-Path One-Shot (SPOS)ΒΆ
Training phase: Uniformly sample one path (architecture) per batch.
For each batch:
α ~ Uniform(all architectures)
Update weights w using α
Search phase: Evolutionary search over architectures, evaluate using shared weights.
Advantages:
Decouples training from search
No RL controller needed
Very efficient
6.4 FairNASΒΆ
Problem: Weight sharing biases towards certain architectures.
Solution: Fair sampling strategies.
Techniques:
Fair path sampling: Ensure all operations sampled equally
Sandwich rule: Train largest and smallest sub-networks
Calibration: Normalize operation statistics
7. Hardware-Aware NASΒΆ
7.1 Multi-Objective OptimizationΒΆ
Objectives:
Accuracy: Maximize validation accuracy
Latency: Minimize inference time (mobile, edge devices)
Energy: Minimize power consumption
Memory: Fit in device constraints
FLOPs: Reduce computational cost
Pareto front:
α* ∈ argmax_α {(Acc(α), −Latency(α), −Energy(α))}
No single architecture dominates all objectives.
7.2 Latency-Aware SearchΒΆ
MNASNet (Tan et al., 2019):
Objective:
maximize Acc(α) × (Latency(α) / Target)^w
where the exponent w ≤ 0 controls the trade-off (w = 0: accuracy only; more negative w penalizes latency more heavily).
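A sketch of this soft-constraint reward (the exponent w = −0.07 is the value used in the MnasNet paper; the accuracy and latency numbers are made up):

```python
# MnasNet-style soft-constraint reward: architectures under the latency
# target get a small bonus, those over it get a smooth penalty.
def mnasnet_reward(accuracy, latency_ms, target_ms=80.0, w=-0.07):
    return accuracy * (latency_ms / target_ms) ** w

fast = mnasnet_reward(accuracy=0.75, latency_ms=60.0)   # under budget: bonus
slow = mnasnet_reward(accuracy=0.75, latency_ms=160.0)  # over budget: penalty
```

At exactly the target latency the reward equals the raw accuracy, so the penalty term is neutral on the constraint boundary.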
Latency measurement:
Real device profiling (Pixel phone)
Latency predictor (neural network)
Search: RL-based (controller samples Ξ±, reward is objective).
Results:
1.8× faster than MobileNetV2 with the same accuracy
76.7% ImageNet top-1 with 78ms latency
7.3 Device-Specific OptimizationΒΆ
ProxylessNAS:
Train supernet on GPU
Evaluate latency on target device (mobile CPU, GPU, etc.)
Architecture-specific for each device
Results:
Mobile: 75.1% / 199ms (CPU), 74.6% / 80ms (GPU)
Different architectures optimal for different hardware!
7.4 Once-For-All (OFA) NetworksΒΆ
Idea: Train once, deploy to many devices.
Progressive shrinking:
Train largest network (full depth, width, resolution)
Progressively train smaller sub-networks
Final supernet supports any configuration
Deployment:
Given device constraints (latency, memory)
Search for best sub-network within constraints
Extract and deploy (no retraining!)
Results:
ImageNet: 80.0% top-1 under mobile-scale FLOPs budgets (SOTA efficiency)
Supports 10^19 sub-networks
8. Performance Estimation StrategiesΒΆ
8.1 Full TrainingΒΆ
Method: Train each architecture to convergence.
Pros:
Accurate performance estimate
No bias
Cons:
Extremely expensive (hours per architecture)
Infeasible for large search spaces
8.2 Early StoppingΒΆ
Method: Train for few epochs (e.g., 5-10), use as proxy for final accuracy.
Assumptions:
Rank correlation: Early accuracy β final accuracy (rank-wise)
Challenges:
Correlation not perfect
Some architectures are "slow starters"
8.3 Learning Curve ExtrapolationΒΆ
Method: Fit learning curve, predict final accuracy.
Models:
Power law: a(t) = α − β·t^(−γ)
Exponential: a(t) = α − β·exp(−γt)
Bayesian neural networks: Learn curve distribution
Advantages:
Stop training early for bad architectures
Allocate more resources to promising ones
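A sketch of extrapolation with the exponential model: for each γ on a grid, the remaining parameters follow from linear least squares; keep the best fit and read off the asymptote as the predicted final accuracy. The learning curve here is synthetic:

```python
import numpy as np

# Fit a(t) = A - B*exp(-g*t): for fixed g the model is linear in (A, B).
def fit_learning_curve(t, acc, gammas=np.linspace(0.01, 1.0, 100)):
    best = None
    for g in gammas:
        X = np.column_stack([np.ones_like(t), -np.exp(-g * t)])
        coef, *_ = np.linalg.lstsq(X, acc, rcond=None)
        err = float(np.sum((acc - X @ coef) ** 2))
        if best is None or err < best[0]:
            best = (err, coef[0], coef[1], g)
    _, A, B, g = best
    return A, B, g

# Synthetic curve plateauing at 0.9 (stand-in for early-epoch accuracies)
t = np.arange(1, 11, dtype=float)
acc = 0.9 - 0.5 * np.exp(-0.3 * t)
A, B, g = fit_learning_curve(t, acc)
predicted_final = A  # asymptote = predicted converged accuracy
```

Ten "epochs" of observations suffice to recover the 0.9 plateau, which is the basis for stopping bad architectures early.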
8.4 Weight Sharing / One-ShotΒΆ
Method: All architectures share weights in supernet.
Evaluation: Sample architecture, evaluate on validation set (no training!).
Pros:
Extremely fast (seconds per architecture)
Enables large-scale search
Cons:
Performance estimates biased
Ranking correlation with standalone training: 0.5-0.7
8.5 Performance PredictorsΒΆ
Idea: Train surrogate model to predict accuracy from architecture encoding.
Encoding:
Adjacency matrix: For DAG-based cells
Path encoding: Sequence of operations
Graph neural network: Process architecture as graph
Predictor:
Predictor: encode(α) → predicted_accuracy
Training:
Collect dataset: (α, accuracy) pairs
Train regression model (MLP, GNN, Transformer)
Use predictor to evaluate new architectures
Uncertainty quantification:
Bayesian predictors
Ensemble predictors
Gaussian process
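A minimal predictor sketch: one-hot encode the per-edge operation IDs and regress accuracy with a small MLP. The model is untrained here; in practice it is fit on collected (α, accuracy) pairs, and the edge/op counts are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AccuracyPredictor(nn.Module):
    """MLP mapping a flat one-hot architecture encoding to accuracy."""
    def __init__(self, n_edges=6, n_ops=5, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_edges * n_ops, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),  # predicted accuracy in [0, 1]
        )

    def forward(self, encoding):
        return self.net(encoding).squeeze(-1)

def encode(genes, n_ops=5):
    """Flat one-hot encoding of a list of per-edge operation IDs."""
    return F.one_hot(torch.tensor(genes), num_classes=n_ops).float().flatten()

predictor = AccuracyPredictor()
x = encode([0, 1, 4, 2, 3, 0])
pred = predictor(x)
```

Once trained, querying the predictor costs one forward pass per candidate, versus hours of training for a real evaluation.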
9. Transferability and GeneralizationΒΆ
9.1 Transfer from Proxy TasksΒΆ
Common proxies:
CIFAR-10 → ImageNet: Search on CIFAR, transfer to ImageNet
Low resolution → High resolution: Search at 32×32, deploy at 224×224
Fewer epochs: Search with 50 epochs, deploy with 300
Challenges:
Rank correlation not perfect
Best on CIFAR ≠ best on ImageNet
Transfer techniques:
Fine-tuning: Adjust architecture on target task
Scaling rules: Adjust depth/width for target dataset
9.2 Cross-Task GeneralizationΒΆ
Question: Does architecture found for Task A work on Task B?
Findings:
Within domain: Good transfer (ImageNet → COCO detection)
Across domains: Moderate transfer (vision → NLP)
Task-specific: Some tasks need specialized architectures
Examples:
NASNet (image classification) → Object detection: Good
EfficientNet → Semantic segmentation: Good
Vision architectures → Language modeling: Poor
9.3 Domain AdaptationΒΆ
Search for architecture on source domain, deploy on target.
Techniques:
Multi-task search: Search jointly on source + target
Domain-invariant operations: Prefer ops that transfer well
Meta-NAS: Learn to search across domains
10. Advanced NAS MethodsΒΆ
10.1 Meta-Learning for NASΒΆ
Idea: Learn the search process itself.
Meta-NAS:
Meta-learner: Predicts good architectures for new tasks
Trained on distribution of tasks
Fast adaptation to new task
Applications:
Few-shot learning: Find architecture from few examples
Continual learning: Adapt architecture as task evolves
10.2 Multi-Objective NASΒΆ
Pareto optimization:
Objectives:
f₁(α) = Accuracy(α)
f₂(α) = −Latency(α)
f₃(α) = −FLOPs(α)
Methods:
Evolutionary: NSGA-II (Non-dominated Sorting GA)
Scalarization: λ₁f₁ + λ₂f₂ + λ₃f₃
Hypervolume optimization
Results: Set of Pareto-optimal architectures (user picks based on constraints).
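Extracting the Pareto front from a set of candidates is straightforward; a sketch over (accuracy, latency) pairs with made-up numbers:

```python
# Pareto-front extraction: maximize accuracy, minimize latency. A point is
# dominated if some other point is at least as good on both objectives and
# strictly better on at least one.
def pareto_front(points):
    front = []
    for acc, lat in points:
        dominated = any(
            (a >= acc and l <= lat) and (a > acc or l < lat)
            for a, l in points
        )
        if not dominated:
            front.append((acc, lat))
    return front

candidates = [(0.75, 50), (0.80, 120), (0.72, 60), (0.78, 40)]
front = pareto_front(candidates)
# (0.75, 50) and (0.72, 60) are dominated by (0.78, 40)
```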
10.3 Neural Architecture Optimization (NAO)ΒΆ
Idea: Model architecture performance as continuous function.
Encoder: architecture α → embedding z
Predictor: z → predicted performance
Decoder: z → architecture α'
Optimization:
z* = argmax_z Predictor(z)
α* = Decoder(z*)
Advantages:
Continuous optimization (gradient-based)
No need for supernet or RL
10.4 AutoML-ZeroΒΆ
Radical idea: Search from scratch (basic operations).
Search space:
Primitive ops: +, −, ×, /, exp, log, sin, max, min
No domain knowledge (no convolutions!)
Discovery:
Rediscovered linear regression, ReLU, normalized gradients
Proof-of-concept for automated science
11. Benchmarks and ComparisonsΒΆ
11.1 NAS-Bench-101ΒΆ
Dataset: 423,624 unique architectures, all trained on CIFAR-10.
Architecture space:
Cell-based (directed acyclic graph)
3 ops per cell: 3×3 conv, 1×1 conv, 3×3 max pool
Metrics:
Training time
Validation/test accuracy
Number of parameters
Usage:
Researchers: Test NAS algorithms without expensive training
Query: Get performance of architecture in O(1) time
11.2 NAS-Bench-201ΒΆ
Improvements over 101:
Fixed topology, search over operations
15,625 architectures (smaller, fully evaluated)
Multiple datasets: CIFAR-10, CIFAR-100, ImageNet-16-120
Insights:
Simple random search competitive with sophisticated methods!
Weight sharing biases rankings
Early stopping correlation varies
11.3 Performance ComparisonΒΆ
ImageNet top-1 accuracy (selected results):
Manual designs:
- ResNet-50: 76.2%
- EfficientNet-B0 (manual scaling): 77.1%
NAS-discovered:
- NASNet-A: 82.7% (but inefficient)
- AmoebaNet-A: 83.9%
- EfficientNet-B7 (NAS): 84.4%
- Once-For-All: 80.0% (efficient)
Search cost:
NAS (2017): 22,400 GPU-days
NASNet (2018): 2,000 GPU-days
ENAS (2018): 0.5 GPU-days
DARTS (2019): 4 GPU-days
Once-For-All (2020): 1,200 GPU-hours (50 GPU-days, but train once)
12. Practical ConsiderationsΒΆ
12.1 Search Space DesignΒΆ
Trade-offs:
Large space: More potential, but expensive to search
Small space: Efficient search, but may miss optimal
Best practices:
Start with proven components: Conv, pooling, skip connections
Constrain based on domain knowledge: vision ≠ NLP
Hierarchical spaces: Micro + macro search
12.2 Validation Set ManagementΒΆ
Problem: Overfitting to validation set during search.
Solutions:
Hold-out test set: Only use after search complete
Cross-validation: Rotate validation sets
Regularization: Penalize complexity during search
12.3 ReproducibilityΒΆ
Challenges:
Randomness in search (RL, evolution)
Stochasticity in training
Hardware differences
Best practices:
Report seeds, hyperparameters
Average over multiple runs
Open-source code and checkpoints
12.4 Computational BudgetΒΆ
Fixed budget: Given B GPU-hours, how to allocate?
Strategies:
Many short runs: Sample many architectures, train briefly
Few long runs: Sample few, train to convergence
Adaptive: Allocate based on promise (bandit algorithms)
Optimal allocation:
Depends on search space quality
Exploration vs. exploitation trade-off
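The adaptive strategy can be sketched as successive halving, a standard bandit-style allocator: evaluate everything at a small budget, keep the top half, double the budget, and repeat. The noisy evaluator below is a toy stand-in for brief training:

```python
import random

# Successive halving: promising candidates earn progressively more budget.
def successive_halving(candidates, evaluate, n_rounds=3):
    budget = 1
    pool = list(candidates)
    for _ in range(n_rounds):
        pool = sorted(pool, key=lambda a: evaluate(a, budget), reverse=True)
        pool = pool[: max(1, len(pool) // 2)]  # keep the top half
        budget *= 2
    return pool[0]

# Toy evaluator: each "architecture" is its true quality; short budgets
# observe it with noise that shrinks as the budget grows.
random.seed(0)
def noisy_eval(arch, budget):
    return arch + random.gauss(0, 0.3 / budget)

archs = [0.2, 0.9, 0.5, 0.7, 0.4, 0.8, 0.3, 0.6]
best = successive_halving(archs, noisy_eval)
```

With 8 candidates and 3 rounds, a single survivor remains, having been evaluated at the largest budget.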
13. Applications Beyond Image ClassificationΒΆ
13.1 Object DetectionΒΆ
Challenges:
Multi-scale features (FPN, PAFPN)
Backbone + neck + head architecture
Approaches:
NAS-FPN: Search for feature pyramid network topology
DetNAS: End-to-end detection-aware NAS
EfficientDet: Compound scaling for detection
Results:
NAS-FPN: 40.5 mAP on COCO (vs 36.8 for hand-designed FPN)
13.2 Semantic SegmentationΒΆ
Auto-DeepLab:
Search for encoder-decoder architecture
Multi-scale feature aggregation
Cell-level search
Results:
Cityscapes: 82.1 mIoU (SOTA for NAS)
13.3 Language ModelingΒΆ
Evolved Transformer:
Search for Transformer variants
Evolutionary algorithm
Discovered architecture:
Gated linear units instead of FFN
Different attention patterns
0.7 perplexity improvement on WMT'14
13.4 Reinforcement LearningΒΆ
Search for RL agent architectures:
Policy network structure
Value function approximator
Applications:
Atari games
Continuous control (robotics)
14. Limitations and ChallengesΒΆ
14.1 Computational CostΒΆ
Despite improvements, still expensive:
Search: 4-50 GPU-days (vs minutes for manual design)
Amortized over multiple deployments, but not accessible to all
Barriers:
Academia: Limited compute resources
Industry: Cost-benefit analysis
14.2 Overfitting to Search SpaceΒΆ
Search space design is crucial:
If the optimal architecture is not in the space, NAS won't find it
Human bias embedded in search space
Example: If the search space only includes CNNs, NAS can't discover Transformers.
14.3 Generalization GapΒΆ
Problem: Performance on proxy task β performance on target task.
Causes:
Dataset shift
Training recipe differences
Scaling effects
14.4 Reproducibility and FairnessΒΆ
Comparison challenges:
Different search spaces
Different computational budgets
Different training recipes
Need: Standardized benchmarks (NAS-Bench), fair comparisons.
15. Future DirectionsΒΆ
15.1 Efficient SearchΒΆ
Goals:
Sub-GPU-day search
Real-time adaptation
Approaches:
Better performance predictors
Smarter search strategies
Transfer learning from previous searches
15.2 Neural Architecture UnderstandingΒΆ
Interpretability:
Why does architecture A outperform B?
What architectural features are important?
Architecture analysis:
Feature visualization
Ablation studies
Causal analysis
15.3 Joint Architecture and Training SearchΒΆ
Current: Search architecture, then train.
Future: Co-optimize architecture + training strategy.
Search dimensions:
Architecture
Data augmentation
Learning rate schedule
Regularization
15.4 Continual Architecture AdaptationΒΆ
Dynamic architectures:
Adapt to data distribution shifts
Grow/shrink based on task complexity
Lifelong learning:
Accumulate architectural knowledge
Transfer to new tasks
16. Key TakeawaysΒΆ
NAS automates architecture design:
Search space + search strategy + performance estimation
Can discover architectures rivaling human experts
Major paradigms:
RL-based: Controller samples architectures (NAS, ENAS)
Evolutionary: Mutation + selection (AmoebaNet)
Gradient-based: Differentiable search (DARTS)
One-shot: Weight sharing supernet (SPOS, OFA)
Efficiency crucial:
Original NAS: 22,400 GPU-days
Modern methods: 1-50 GPU-days
Key: Weight sharing, early stopping, predictors
Hardware-aware NAS:
Multi-objective: accuracy + latency + energy
Device-specific architectures
Once-for-all networks for flexible deployment
Search space design matters:
Cell-based reduces complexity
Hierarchical spaces (micro + macro)
Domain knowledge still important
Applications beyond classification:
Detection, segmentation, NLP, RL
Transfer from proxy tasks
Open challenges:
Computational cost
Overfitting to search space
Generalization gap
Reproducibility
Future: Towards AutoML:
Joint architecture + training search
Continual adaptation
Democratization of NAS
17. ReferencesΒΆ
Foundational papers:
Zoph & Le (2017): βNeural Architecture Search with Reinforcement Learningβ (ICLR)
Zoph et al. (2018): βLearning Transferable Architectures for Scalable Image Recognitionβ (CVPR - NASNet)
Real et al. (2019): βRegularized Evolution for Image Classifier Architecture Searchβ (AAAI - AmoebaNet)
Liu et al. (2019): βDARTS: Differentiable Architecture Searchβ (ICLR)
Efficiency improvements:
Pham et al. (2018): βEfficient Neural Architecture Search via Parameter Sharingβ (ICML - ENAS)
Guo et al. (2020): βSingle Path One-Shot Neural Architecture Search with Uniform Samplingβ (ECCV - SPOS)
Hardware-aware:
Tan et al. (2019): βMnasNet: Platform-Aware Neural Architecture Search for Mobileβ (CVPR)
Cai et al. (2020): βOnce-for-All: Train One Network and Specialize it for Efficient Deploymentβ (ICLR)
Benchmarks:
Ying et al. (2019): βNAS-Bench-101: Towards Reproducible Neural Architecture Searchβ (ICML)
Dong & Yang (2020): βNAS-Bench-201: Extending the Scope of Reproducible NASβ (ICLR)
Surveys:
Elsken et al. (2019): βNeural Architecture Search: A Surveyβ (JMLR)
Wistuba et al. (2019): βA Survey on Neural Architecture Searchβ (arXiv)
"""
Complete Neural Architecture Search Implementations
====================================================
Includes: DARTS (Differentiable), RL-based NAS controller, Evolutionary search,
One-shot supernet, performance predictors, hardware-aware search.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import namedtuple
import random
# ============================================================================
# 1. DARTS (Differentiable Architecture Search)
# ============================================================================
# Define possible operations
OPS = {
'none': lambda C, stride: Zero(stride),
'skip_connect': lambda C, stride: Identity() if stride == 1 else FactorizedReduce(C, C),
'sep_conv_3x3': lambda C, stride: SepConv(C, C, 3, stride, 1),
'sep_conv_5x5': lambda C, stride: SepConv(C, C, 5, stride, 2),
'dil_conv_3x3': lambda C, stride: DilConv(C, C, 3, stride, 2, 2),
'dil_conv_5x5': lambda C, stride: DilConv(C, C, 5, stride, 4, 2),
'avg_pool_3x3': lambda C, stride: nn.AvgPool2d(3, stride=stride, padding=1),
'max_pool_3x3': lambda C, stride: nn.MaxPool2d(3, stride=stride, padding=1),
}
class Zero(nn.Module):
"""Zero operation (drop path)."""
def __init__(self, stride):
super(Zero, self).__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x * 0.
return x[:, :, ::self.stride, ::self.stride] * 0.
class Identity(nn.Module):
"""Identity (skip connection)."""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class SepConv(nn.Module):
"""Depthwise separable convolution."""
def __init__(self, C_in, C_out, kernel_size, stride, padding):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in),
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out),
)
def forward(self, x):
return self.op(x)
class DilConv(nn.Module):
"""Dilated convolution."""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super(DilConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out),
)
def forward(self, x):
return self.op(x)
class FactorizedReduce(nn.Module):
"""Reduce spatial dimensions by 2."""
def __init__(self, C_in, C_out):
super(FactorizedReduce, self).__init__()
assert C_out % 2 == 0
self.relu = nn.ReLU(inplace=False)
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out)
def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
return self.bn(out)
class MixedOp(nn.Module):
"""
Mixed operation: weighted sum of all operations.
    f(x) = Σ_o softmax(α)_o · o(x)
Args:
C: Number of channels
stride: Stride for operations
"""
def __init__(self, C, stride):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in OPS:
op = OPS[primitive](C, stride)
self._ops.append(op)
def forward(self, x, weights):
"""
Args:
x: Input tensor
weights: Architecture weights (softmax over operations)
Returns:
Weighted sum of operations
"""
return sum(w * op(x) for w, op in zip(weights, self._ops))
class DARTSCell(nn.Module):
    """
    DARTS cell (DAG of mixed operations).
    Args:
        steps: Number of intermediate nodes
        multiplier: Output = concat of final `multiplier` nodes
        C_prev_prev: Channels from 2 cells ago
        C_prev: Channels from 1 cell ago
        C: Current channels
        reduction: Whether this is a reduction cell
        reduction_prev: Whether the previous cell was a reduction cell
    """
    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction,
                 reduction_prev=False):
        super(DARTSCell, self).__init__()
        self.reduction = reduction
        self.steps = steps
        self.multiplier = multiplier
        # Preprocess inputs to C channels. If the previous cell was a reduction
        # cell, s0 (from two cells back) is spatially larger than s1 and must be
        # downsampled to match. In a reduction cell itself, downsampling is done
        # by the stride-2 ops on the edges from the two input nodes, so the
        # preprocessing stays stride-1.
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = Identity() if C_prev_prev == C else \
                nn.Conv2d(C_prev_prev, C, 1, 1, 0, bias=False)
        self.preprocess1 = Identity() if C_prev == C else \
            nn.Conv2d(C_prev, C, 1, 1, 0, bias=False)
        # Build DAG
        self._ops = nn.ModuleList()
        for i in range(self.steps):
            for j in range(2 + i):  # Connect to previous nodes
                stride = 2 if reduction and j < 2 else 1
                op = MixedOp(C, stride)
                self._ops.append(op)
    def forward(self, s0, s1, weights):
        """
        Args:
            s0: Output from 2 cells ago
            s1: Output from 1 cell ago
            weights: Architecture weights [num_edges, num_ops]
        Returns:
            Cell output (concatenated intermediate nodes)
        """
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = [s0, s1]
        offset = 0
        for i in range(self.steps):
            # Collect inputs from all previous nodes
            s = sum(self._ops[offset + j](h, weights[offset + j])
                    for j, h in enumerate(states))
            offset += len(states)
            states.append(s)
        # Concatenate last `multiplier` nodes
        return torch.cat(states[-self.multiplier:], dim=1)
class DARTSNetwork(nn.Module):
    """
    DARTS network (stack of cells).
    Args:
        C: Initial channels
        num_classes: Number of output classes
        layers: Number of cells
        steps: Intermediate nodes per cell
        multiplier: Concatenation factor
    """
    def __init__(self, C=16, num_classes=10, layers=8, steps=4, multiplier=4):
        super(DARTSNetwork, self).__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._steps = steps
        self._multiplier = multiplier
        # Stem
        C_curr = C * 3  # After stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )
        # Build cells
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            # Reduction cells at 1/3 and 2/3 depth
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = DARTSCell(steps, multiplier, C_prev_prev, C_prev, C_curr,
                             reduction, reduction_prev)
            self.cells.append(cell)
            C_prev_prev, C_prev = C_prev, multiplier * C_curr
            reduction_prev = reduction
# Classifier
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
# Initialize architecture parameters
num_ops = len(OPS)
num_edges = sum(2 + i for i in range(steps))
self._arch_parameters = nn.ParameterList([
nn.Parameter(1e-3 * torch.randn(num_edges, num_ops)), # Normal cell
nn.Parameter(1e-3 * torch.randn(num_edges, num_ops)), # Reduction cell
])
def forward(self, x):
"""Forward pass with current architecture."""
s0 = s1 = self.stem(x)
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = F.softmax(self._arch_parameters[1], dim=-1)
else:
weights = F.softmax(self._arch_parameters[0], dim=-1)
s0, s1 = s1, cell(s0, s1, weights)
out = self.global_pooling(s1)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return logits
def arch_parameters(self):
"""Return architecture parameters for optimization."""
return self._arch_parameters
def genotype(self):
"""Discretize architecture (select top operation per edge)."""
        def _parse(weights):
            gene = []
            op_names = list(OPS.keys())
            none_idx = op_names.index('none')
            n = 2
            start = 0
            for i in range(self._steps):
                end = start + n
                W = weights[start:end].copy()
                # Select the 2 input edges whose strongest non-'none' op is largest
                edges = sorted(range(i + 2),
                               key=lambda x: -max(W[x][k] for k in range(len(W[x]))
                                                  if k != none_idx))[:2]
                for j in edges:
                    # Best operation on this edge, excluding 'none'
                    k_best = max((k for k in range(len(W[j])) if k != none_idx),
                                 key=lambda k: W[j][k])
                    gene.append((op_names[k_best], j))
                start = end
                n += 1
            return gene
gene_normal = _parse(F.softmax(self._arch_parameters[0], dim=-1).data.cpu().numpy())
gene_reduce = _parse(F.softmax(self._arch_parameters[1], dim=-1).data.cpu().numpy())
return {'normal': gene_normal, 'reduce': gene_reduce}
# ============================================================================
# 2. RL-Based NAS Controller
# ============================================================================
class NASController(nn.Module):
"""
RNN controller for neural architecture search (Zoph & Le, 2017).
Generates architecture decisions sequentially.
Args:
num_layers: Number of layers to generate
num_branches: Branching factor
lstm_size: LSTM hidden size
lstm_layers: Number of LSTM layers
temperature: Softmax temperature
"""
def __init__(self, num_layers=12, num_branches=6, lstm_size=32,
lstm_layers=2, temperature=5.0):
super(NASController, self).__init__()
self.num_layers = num_layers
self.num_branches = num_branches
self.lstm_size = lstm_size
self.temperature = temperature
# LSTM controller
self.lstm = nn.LSTM(input_size=lstm_size, hidden_size=lstm_size,
num_layers=lstm_layers)
# Embedding for previous decision
self.embedding = nn.Embedding(num_branches + 1, lstm_size)
# Decoder (output layer)
self.decoder = nn.Linear(lstm_size, num_branches)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize controller weights."""
for name, param in self.named_parameters():
if 'weight' in name:
nn.init.uniform_(param, -0.1, 0.1)
elif 'bias' in name:
nn.init.constant_(param, 0)
def forward(self, batch_size=1):
"""
Sample architecture.
Returns:
architecture: List of sampled operations [num_layers]
log_probs: Log probabilities [num_layers]
entropies: Entropies [num_layers]
"""
# Initial hidden state
h = torch.zeros(self.lstm.num_layers, batch_size, self.lstm_size)
c = torch.zeros(self.lstm.num_layers, batch_size, self.lstm_size)
        # Start token (reserved index num_branches, beyond the operation ids)
        inputs = self.embedding(torch.full((batch_size,), self.num_branches,
                                           dtype=torch.long))
architecture = []
log_probs = []
entropies = []
for layer in range(self.num_layers):
# LSTM forward
inputs = inputs.unsqueeze(0)
output, (h, c) = self.lstm(inputs, (h, c))
output = output.squeeze(0)
# Decode to logits
logits = self.decoder(output) / self.temperature
probs = F.softmax(logits, dim=-1)
log_prob = F.log_softmax(logits, dim=-1)
# Sample operation
action = torch.multinomial(probs, 1).squeeze(1)
# Record
            architecture.append(action.item() if batch_size == 1 else action.tolist())
log_probs.append(log_prob.gather(1, action.unsqueeze(1)).squeeze(1))
# Entropy (for exploration bonus)
entropy = -(log_prob * probs).sum(dim=-1)
entropies.append(entropy)
# Next input = embedding of current action
inputs = self.embedding(action)
log_probs = torch.stack(log_probs, dim=1) # [batch, num_layers]
entropies = torch.stack(entropies, dim=1)
return architecture, log_probs, entropies
    def train_step(self, rewards, log_probs, entropies, baseline, optimizer,
                   entropy_weight=0.0001):
        """
        REINFORCE update.
        Args:
            rewards: Validation accuracies for the sampled architectures [batch]
            log_probs: Log-probs returned by forward() for those same samples
            entropies: Entropies returned by forward() for those same samples
            baseline: Running average reward
            optimizer: Controller optimizer
            entropy_weight: Coefficient for entropy bonus
        """
        # REINFORCE loss: credit the log-probs of the architectures that
        # actually produced these rewards (do not resample here)
        advantages = rewards - baseline
        loss = -(log_probs * advantages.unsqueeze(1)).sum(dim=1).mean()
        # Entropy bonus (encourage exploration)
        loss = loss - entropy_weight * entropies.mean()
        # Update
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), 5.0)
        optimizer.step()
        return loss.item()
# ============================================================================
# 3. Evolutionary Architecture Search
# ============================================================================
Architecture = namedtuple('Architecture', ['arch', 'fitness', 'age'])
class EvolutionarySearch:
"""
Regularized evolution for architecture search (Real et al., 2019).
Args:
search_space: Function that returns random architecture
evaluate_fn: Function to evaluate architecture (returns fitness)
population_size: Number of architectures in population
sample_size: Tournament size for selection
mutation_prob: Probability of mutation
"""
def __init__(self, search_space, evaluate_fn, population_size=100,
sample_size=25, mutation_prob=1.0):
self.search_space = search_space
self.evaluate_fn = evaluate_fn
self.population_size = population_size
self.sample_size = sample_size
self.mutation_prob = mutation_prob
self.population = []
self.history = []
def initialize_population(self):
"""Initialize population with random architectures."""
for _ in range(self.population_size):
arch = self.search_space()
fitness = self.evaluate_fn(arch)
self.population.append(Architecture(arch, fitness, age=0))
print(f"Initialized population: {self.population_size} architectures")
print(f"Best fitness: {max(a.fitness for a in self.population):.4f}")
def mutate(self, parent_arch):
"""
Mutate architecture.
Simple mutation: randomly change one operation.
"""
arch = parent_arch.copy()
if random.random() < self.mutation_prob:
# Random mutation point
idx = random.randint(0, len(arch) - 1)
# Random new operation (different from current)
ops = list(range(len(OPS)))
ops.remove(arch[idx])
arch[idx] = random.choice(ops)
return arch
def evolve_step(self):
"""
Single evolution step (regularized evolution).
1. Sample tournament
2. Remove oldest
3. Mutate best
4. Evaluate and add to population
"""
# Tournament selection
sample = random.sample(self.population, self.sample_size)
# Remove oldest from sample
oldest = max(sample, key=lambda x: x.age)
self.population.remove(oldest)
# Select best from sample
parent = max(sample, key=lambda x: x.fitness)
# Mutate
child_arch = self.mutate(parent.arch)
# Evaluate
child_fitness = self.evaluate_fn(child_arch)
# Add to population with age 0
child = Architecture(child_arch, child_fitness, age=0)
self.population.append(child)
# Age all architectures
self.population = [Architecture(a.arch, a.fitness, a.age + 1)
for a in self.population]
# Record
self.history.append(child_fitness)
return child_arch, child_fitness
def search(self, num_iterations=1000):
"""
Run evolution for num_iterations.
Returns:
best_arch: Best architecture found
best_fitness: Fitness of best architecture
"""
self.initialize_population()
for i in range(num_iterations):
arch, fitness = self.evolve_step()
if (i + 1) % 100 == 0:
best_fitness = max(a.fitness for a in self.population)
print(f"Iteration {i+1}/{num_iterations}, Best: {best_fitness:.4f}, "
f"Current: {fitness:.4f}")
# Return best architecture
best = max(self.population, key=lambda x: x.fitness)
return best.arch, best.fitness
# ============================================================================
# 4. Performance Predictor
# ============================================================================
class ArchitectureEncoder(nn.Module):
"""
Encode architecture as fixed-size vector.
Simple approach: Embedding + pooling.
Args:
num_ops: Number of possible operations
embed_dim: Embedding dimension
hidden_dim: Hidden dimension
"""
def __init__(self, num_ops, embed_dim=32, hidden_dim=64):
super(ArchitectureEncoder, self).__init__()
self.embedding = nn.Embedding(num_ops, embed_dim)
self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
def forward(self, arch):
"""
Args:
arch: Architecture (list of operation indices) [batch, seq_len]
Returns:
Encoding [batch, hidden_dim * 2]
"""
embedded = self.embedding(arch) # [batch, seq_len, embed_dim]
_, (h, c) = self.lstm(embedded)
encoding = torch.cat([h[-1], c[-1]], dim=-1) # Concatenate hidden and cell
return encoding
class PerformancePredictor(nn.Module):
"""
Predict architecture performance from encoding.
Args:
num_ops: Number of operations
embed_dim: Embedding dimension
hidden_dim: Hidden dimension
"""
def __init__(self, num_ops, embed_dim=32, hidden_dim=64):
super(PerformancePredictor, self).__init__()
self.encoder = ArchitectureEncoder(num_ops, embed_dim, hidden_dim)
# Predictor MLP
self.predictor = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, arch):
"""
Args:
arch: Architecture tensor [batch, seq_len]
Returns:
Predicted accuracy [batch, 1]
"""
encoding = self.encoder(arch)
prediction = self.predictor(encoding)
return prediction
# ============================================================================
# 5. Demonstrations
# ============================================================================
def demo_darts():
"""Demonstrate DARTS network."""
print("="*70)
print("DARTS (Differentiable Architecture Search) Demo")
print("="*70)
# Create DARTS network
model = DARTSNetwork(C=16, num_classes=10, layers=8, steps=4, multiplier=4)
# Sample input
x = torch.randn(2, 3, 32, 32)
# Forward
model.eval()
with torch.no_grad():
logits = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {logits.shape}")
print()
# Architecture parameters
arch_params = model.arch_parameters()
print(f"Architecture parameters:")
    print(f"  Normal cell: {arch_params[0].shape} (edges × operations)")
print(f" Reduction cell: {arch_params[1].shape}")
print()
# Current architecture (continuous)
normal_weights = F.softmax(arch_params[0], dim=-1)
print(f"Normal cell operation weights (first edge):")
for i, op_name in enumerate(OPS.keys()):
print(f" {op_name:15s}: {normal_weights[0, i].item():.4f}")
print()
# Discretize
genotype = model.genotype()
print(f"Discretized architecture:")
print(f" Normal cell: {genotype['normal'][:4]}") # Show first 4 edges
print(f" Reduction cell: {genotype['reduce'][:4]}")
print()
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
arch_params_count = sum(p.numel() for p in model.arch_parameters())
weight_params = total_params - arch_params_count
print(f"Parameters:")
print(f" Weights: {weight_params:,}")
print(f" Architecture: {arch_params_count:,}")
print(f" Total: {total_params:,}")
print()
def demo_rl_controller():
"""Demonstrate RL-based NAS controller."""
print("="*70)
print("RL-Based NAS Controller Demo")
print("="*70)
# Create controller
controller = NASController(num_layers=12, num_branches=6, lstm_size=32)
# Sample architectures
print("Sampling architectures:")
for i in range(3):
arch, log_probs, entropies = controller.forward(batch_size=1)
print(f" Architecture {i+1}: {arch}")
print(f" Log prob sum: {log_probs.sum().item():.4f}")
print(f" Entropy: {entropies.mean().item():.4f}")
print()
# Count parameters
params = sum(p.numel() for p in controller.parameters())
print(f"Controller parameters: {params:,}")
print()
print("REINFORCE update:")
    print("  Sample architecture → Train child network → Get accuracy")
    print("  Update controller to maximize expected accuracy")
    print("  Gradient: ∇_θ J = E[(R - baseline) * ∇_θ log π_θ(α)]")
print()
def demo_evolutionary_search():
"""Demonstrate evolutionary architecture search."""
print("="*70)
print("Evolutionary Architecture Search Demo")
print("="*70)
# Define toy search space (list of 6 operations)
def search_space():
return [random.randint(0, len(OPS) - 1) for _ in range(6)]
# Toy evaluation (random for demo)
def evaluate_fn(arch):
# In practice: train network, return validation accuracy
# Here: random fitness for demonstration
return random.uniform(0.5, 1.0)
# Create evolutionary search
evolution = EvolutionarySearch(
search_space=search_space,
evaluate_fn=evaluate_fn,
population_size=20, # Small for demo
sample_size=5,
mutation_prob=1.0
)
# Run search
best_arch, best_fitness = evolution.search(num_iterations=50)
print()
print(f"Best architecture found: {best_arch}")
print(f"Best fitness: {best_fitness:.4f}")
print()
print("Algorithm:")
print(" 1. Sample tournament (5 architectures)")
print(" 2. Remove oldest from population")
print(" 3. Mutate best architecture")
print(" 4. Evaluate child, add to population")
print(" 5. Repeat")
print()
def demo_performance_predictor():
"""Demonstrate performance predictor."""
print("="*70)
print("Performance Predictor Demo")
print("="*70)
# Create predictor
num_ops = len(OPS)
predictor = PerformancePredictor(num_ops=num_ops, embed_dim=32, hidden_dim=64)
# Sample architectures
batch_size = 4
seq_len = 12
archs = torch.randint(0, num_ops, (batch_size, seq_len))
# Predict
predictor.eval()
with torch.no_grad():
predictions = predictor(archs)
    print(f"Input: {batch_size} architectures × {seq_len} operations")
print(f"Predictions: {predictions.squeeze().tolist()}")
print()
# Count parameters
params = sum(p.numel() for p in predictor.parameters())
print(f"Predictor parameters: {params:,}")
print()
print("Training procedure:")
print(" 1. Collect dataset: (architecture, accuracy) pairs")
print(" 2. Train predictor via regression (MSE loss)")
print(" 3. Use predictor to evaluate new architectures (fast!)")
print()
    print("Benefit: Evaluate an architecture in milliseconds instead of hours of training")
print()
def print_method_comparison():
"""Print comparison of NAS methods."""
print("="*70)
print("NAS Method Comparison")
print("="*70)
print()
comparison = """
┌─────────────────┬───────────────┬─────────────┬────────────┬───────────────┐
│ Method          │ Search Cost   │ Performance │ Difficulty │ Best For      │
├─────────────────┼───────────────┼─────────────┼────────────┼───────────────┤
│ Random Search   │ Low           │ Baseline    │ Easy       │ Baseline      │
├─────────────────┼───────────────┼─────────────┼────────────┼───────────────┤
│ RL (NAS)        │ Very High     │ High        │ Hard       │ Best quality  │
│                 │ (22k GPU-day) │             │            │ (unlimited    │
│                 │               │             │            │ compute)      │
├─────────────────┼───────────────┼─────────────┼────────────┼───────────────┤
│ Evolution       │ High          │ High        │ Medium     │ Large search  │
│                 │ (450 GPU-day) │             │            │ spaces        │
├─────────────────┼───────────────┼─────────────┼────────────┼───────────────┤
│ DARTS           │ Low           │ Good        │ Medium     │ Fast search   │
│                 │ (4 GPU-day)   │             │            │ (research)    │
├─────────────────┼───────────────┼─────────────┼────────────┼───────────────┤
│ ENAS            │ Very Low      │ Good        │ Medium     │ Efficient     │
│                 │ (0.5 GPU-day) │             │            │ search        │
├─────────────────┼───────────────┼─────────────┼────────────┼───────────────┤
│ Once-For-All    │ Medium        │ High        │ Hard       │ Deployment    │
│                 │ (50 GPU-day,  │             │            │ flexibility   │
│                 │ train once)   │             │            │               │
└─────────────────┴───────────────┴─────────────┴────────────┴───────────────┘
**Decision Guide:**
1. **Use Random/Evolutionary if:**
- Want simple baseline
- Have moderate compute budget
- Search space not differentiable
2. **Use DARTS if:**
- Want fast search (4 GPU-days)
- Have differentiable search space
- Doing research (quick iterations)
3. **Use RL if:**
- Have large compute budget
- Want best possible architecture
- Search space is discrete/complex
4. **Use Once-For-All if:**
- Deploy to many devices
- Want flexibility (vary latency/accuracy)
- Can afford one-time training cost
**Performance (ImageNet top-1):**
- Random baseline: ~72%
- NASNet (RL): 82.7%
- AmoebaNet (Evolution): 83.9%
- DARTS: 73.3%
- EfficientNet (NAS): 84.4%
"""
print(comparison)
print()
# ============================================================================
# Run Demonstrations
# ============================================================================
if __name__ == "__main__":
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
demo_darts()
demo_rl_controller()
demo_evolutionary_search()
demo_performance_predictor()
print_method_comparison()
print("="*70)
print("Neural Architecture Search Implementations Complete")
print("="*70)
print()
print("Summary:")
    print(" • DARTS: Differentiable search via continuous relaxation (4 GPU-days)")
    print(" • RL Controller: REINFORCE for architecture sampling (expensive)")
    print(" • Evolution: Mutation + selection (regularized evolution)")
    print(" • Predictor: Learn to estimate performance (avoid training)")
print()
print("Key insight: NAS automates architecture design")
print("Trade-off: Search cost vs. architecture quality")
print("Modern trend: Efficient search (weight sharing, one-shot)")
print()
Advanced Neural Architecture Search: Mathematical Foundations and Modern MethodsΒΆ
1. Introduction to Neural Architecture Search (NAS)ΒΆ
Neural Architecture Search automates the design of neural network architectures, replacing manual engineering with algorithmic optimization. The goal is to find the best architecture \(\mathcal{A}^*\) from a search space \(\mathcal{S}\) that maximizes performance on a validation set.
1.1 NAS Problem FormulationΒΆ
subject to:
\(\text{Params}(\mathcal{A}) \leq P_{\max}\) (parameter budget)
\(\text{Latency}(\mathcal{A}) \leq L_{\max}\) (inference time budget)
\(\text{FLOPs}(\mathcal{A}) \leq F_{\max}\) (computational budget)
1.2 Three Components of NASΒΆ
Search Space (\(\mathcal{S}\)): Set of possible architectures
Search Strategy: Algorithm to explore the search space
Performance Estimation: Method to evaluate candidate architectures
1.3 Historical EvolutionΒΆ
Early Era (2016-2017):
Reinforcement Learning NAS (Zoph & Le, 2017)
Cost: ~22,400 GPU-days for CIFAR-10
Search space: Entire network structure
Efficiency Era (2018-2019):
ENAS: Weight sharing (1 GPU day)
DARTS: Differentiable architecture search
ProxylessNAS: Direct hardware optimization
Modern Era (2020-2024):
Zero-shot NAS: No training needed
Once-for-all networks: Train once, search many times
Hardware-aware NAS: Multi-objective optimization
Neural architecture transfer: Cross-task learning
2. Search Space DesignΒΆ
2.1 Macro Search SpaceΒΆ
Search for the entire network structure.
Chain-structured: $\(\text{Network} = f_L \circ f_{L-1} \circ ... \circ f_1(x)\)$
Each layer \(f_i\) can be:
Convolutional layer (with varying kernel sizes, channels)
Pooling layer (max, average)
Skip connections
Batch normalization, activation
Cell-structured (more efficient):
Define reusable cells (normal cell, reduction cell)
Stack cells to form network
Reduces search space from \(O(L^k)\) to \(O(C^k)\) where \(C \ll L\)
2.2 Micro Search Space (Cell-Based)ΒΆ
Normal Cell: Maintains spatial resolution Reduction Cell: Downsamples feature maps
Each cell is a DAG (Directed Acyclic Graph):
\(N\) nodes (intermediate representations)
Each node computes: \(h^{(i)} = \sum_{j < i} o^{(i,j)}(h^{(j)})\)
Operations \(o^{(i,j)}\) selected from operation set \(\mathcal{O}\)
Operation Set \(\mathcal{O}\):
Identity (skip connection)
\(3 \times 3\) separable convolution
\(5 \times 5\) separable convolution
\(3 \times 3\) average pooling
\(3 \times 3\) max pooling
\(3 \times 3\) dilated convolution
Zero (no connection)
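The node update \(h^{(i)} = \sum_{j < i} o^{(i,j)}(h^{(j)})\) is just a sum over incoming edges; a toy sketch with scalar stand-in operations (the op names here are illustrative placeholders, not the real operation set):

```python
import torch

# Toy cell DAG: h_i = sum_{j<i} o_ij(h_j). Elementwise functions stand in
# for the convolution/pooling candidates of the real operation set.
ops = {
    'identity': lambda x: x,
    'double':   lambda x: 2 * x,               # hypothetical stand-in op
    'zero':     lambda x: torch.zeros_like(x),
}
# One op name per edge (i, j) with j < i
edge_ops = {(1, 0): 'double', (2, 0): 'identity', (2, 1): 'zero'}

h = {0: torch.tensor([1.0, 2.0])}              # node 0 = cell input
for i in (1, 2):
    h[i] = sum(ops[edge_ops[(i, j)]](h[j]) for j in range(i))
# h[1] = 2*h[0]; h[2] = identity(h[0]) + zero(h[1])
```

The real DARTS cell in the implementation above follows exactly this pattern, with `MixedOp` replacing the fixed per-edge operation.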
2.3 Search Space EncodingΒΆ
Discrete encoding: $\(\mathcal{A} = (o_1, o_2, ..., o_K) \text{ where } o_i \in \mathcal{O}\)$
Continuous encoding (for DARTS): $\(\mathcal{A} = (\alpha_{ij}^{(k)}) \text{ where } \alpha_{ij}^{(k)} \in \mathbb{R}\)$
Graph encoding:
Adjacency matrix \(A \in \{0,1\}^{N \times N}\)
Operation matrix \(O \in \mathcal{O}^{N \times N}\)
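As a concrete (toy) instance of the graph encoding, with a made-up three-op vocabulary:

```python
import numpy as np

OP_NAMES = ['skip', 'conv3x3', 'maxpool']       # toy operation vocabulary

# Adjacency matrix: A[i, j] = 1 iff node i feeds node j (upper-triangular DAG)
A = np.array([[0, 1, 1],
              [0, 0, 1],
              [0, 0, 0]])
# Operation matrix: index into OP_NAMES per existing edge (-1 = no edge)
O = np.array([[-1,  1,  0],
              [-1, -1,  2],
              [-1, -1, -1]])

edges = [(i, j, OP_NAMES[O[i, j]]) for i in range(3) for j in range(3) if A[i, j]]
# edges -> [(0, 1, 'conv3x3'), (0, 2, 'skip'), (1, 2, 'maxpool')]
```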
2.4 Search Space SizeΒΆ
For a cell with \(N\) nodes and \(|\mathcal{O}|\) operations:
Without constraints: $\(|\mathcal{S}| = |\mathcal{O}|^{\binom{N}{2}}\)$
Example: \(N=7\) nodes, \(|\mathcal{O}|=8\) operations $\(|\mathcal{S}| = 8^{21} \approx 10^{19}\)$
With constraints (each node has 2 inputs): $\(|\mathcal{S}| = \prod_{i=2}^{N} \binom{i-1}{2} \cdot |\mathcal{O}|^2 \approx 10^{14}\)$
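The unconstrained count is easy to sanity-check in code (the function name is ours, not from any NAS library):

```python
import math

def unconstrained_size(num_nodes: int, num_ops: int) -> int:
    """One operation chosen independently for each of the C(N, 2) possible
    edges of the cell DAG: |S| = |O|^C(N,2)."""
    return num_ops ** math.comb(num_nodes, 2)

size = unconstrained_size(7, 8)
print(f"N=7, |O|=8: |S| = 8^21 = {size:.2e}")   # ~9.2e18, on the order of 10^19
```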
3. Search StrategiesΒΆ
3.1 Reinforcement Learning (RL) Based NASΒΆ
NAS with RL (Zoph & Le, 2017):
Controller: RNN that generates architecture descriptions
Training:
Sample architecture \(\mathcal{A} \sim p_\theta\) from controller with parameters \(\theta\)
Train \(\mathcal{A}\) on training set, evaluate on validation set β accuracy \(R\)
Update controller: \(\nabla_\theta J(\theta) = \mathbb{E}_{\mathcal{A} \sim p_\theta}[R \cdot \nabla_\theta \log p_\theta(\mathcal{A})]\)
REINFORCE gradient: $\(\nabla_\theta J(\theta) \approx \frac{1}{B} \sum_{i=1}^B (R_i - b) \nabla_\theta \log p_\theta(\mathcal{A}_i)\)$
where \(b\) is baseline (moving average of rewards).
Challenges:
High variance in gradients
Requires training thousands of architectures
Extremely expensive (tens of thousands of GPU-days for the original method)
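The REINFORCE update with a baseline can be demonstrated on a toy "bandit" where each arm stands for one architecture and its reward is a fixed stand-in validation accuracy (all numbers invented for illustration; the full RNN controller version appears in the implementation above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.zeros(4, requires_grad=True)          # controller over 4 "architectures"
opt = torch.optim.Adam([logits], lr=0.1)
val_acc = [0.70, 0.72, 0.75, 0.90]                   # stand-in rewards
baseline = 0.0

for step in range(300):
    probs = F.softmax(logits, dim=0)
    arch = torch.multinomial(probs, 1).item()        # sample an architecture
    R = val_acc[arch]
    baseline = 0.9 * baseline + 0.1 * R              # moving-average baseline b
    loss = -(R - baseline) * torch.log(probs[arch])  # -(R - b) * log p(arch)
    opt.zero_grad(); loss.backward(); opt.step()

probs = F.softmax(logits, dim=0).detach()            # should concentrate on the best arm
```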
3.2 Evolutionary AlgorithmsΒΆ
Evolution process:
Initialize population \(P = \{\mathcal{A}_1, ..., \mathcal{A}_N\}\)
Evaluate fitness (validation accuracy) for each architecture
Select top-\(k\) architectures (parents)
Generate offspring via mutation/crossover
Repeat until convergence
Mutation operators:
Add/remove layer
Change operation type
Modify hyperparameters (kernel size, channels)
Add/remove skip connections
Crossover operators:
Combine layers from two parent architectures
Cell-level crossover (swap cells between parents)
Advantages:
Simple to implement
Naturally handles multi-objective optimization
No gradient required
Disadvantages:
Slow convergence
Still requires evaluating many architectures
3.3 Gradient-Based (DARTS)ΒΆ
Key idea: Relax discrete search space to continuous, optimize with gradient descent.
Continuous relaxation: $\(\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in \mathcal{O}} \exp(\alpha_{o'}^{(i,j)})} o(x)\)$
where \(\alpha = \{\alpha_o^{(i,j)}\}\) are architecture parameters.
Bilevel optimization: $\(\begin{align} \min_\alpha \quad & \mathcal{L}_{\text{val}}(w^*(\alpha), \alpha) \\ \text{s.t.} \quad & w^*(\alpha) = \arg\min_w \mathcal{L}_{\text{train}}(w, \alpha) \end{align}\)$
Approximate solution:
Update weights: \(w \leftarrow w - \xi \nabla_w \mathcal{L}_{\text{train}}(w, \alpha)\)
Update architecture: \(\alpha \leftarrow \alpha - \eta \nabla_\alpha \mathcal{L}_{\text{val}}(w, \alpha)\)
Discretization (after search): $\(o^{(i,j)} = \arg\max_{o \in \mathcal{O}} \alpha_o^{(i,j)}\)$
Advantages:
Fast (1-2 GPU days)
Differentiable end-to-end
Naturally handles mixed operations
Challenges:
Gap between continuous and discrete architectures
Prone to collapse (all weights to skip connections)
Biased towards parameter-free operations
3.4 One-Shot NASΒΆ
Weight sharing: Train a supernet containing all possible architectures, then search without retraining.
Supernet: $\(\mathcal{A}_{\text{super}} = \bigcup_{\mathcal{A} \in \mathcal{S}} \mathcal{A}\)$
Training:
Sample architecture \(\mathcal{A} \sim \mathcal{S}\) in each iteration
Update shared weights corresponding to \(\mathcal{A}\)
After training, search by evaluating sub-networks
ENAS (Efficient NAS):
Controller samples architectures
All architectures share weights
Reduces search from 1800 to 0.5 GPU days
RandomNAS:
Sample random architectures from supernet
Evaluate and pick best
No controller needed
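The RandomNAS loop reduces to sample-evaluate-keep-best; in the sketch below the search space and scoring function are toy stand-ins (in practice `evaluate` would measure a sub-network's validation accuracy inside the trained supernet):

```python
import random

def random_nas(sample_arch, evaluate, n_samples=20, seed=0):
    """RandomNAS baseline: sample architectures uniformly, keep the best.
    `evaluate` would score a sub-network inside a trained supernet; here it
    is any callable returning a fitness value."""
    rng = random.Random(seed)
    best_arch, best_score = None, float('-inf')
    for _ in range(n_samples):
        arch = sample_arch(rng)
        score = evaluate(arch)
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch, best_score

# Toy search space: 6 edges, 5 ops; toy score favors op index 3 everywhere
sample = lambda rng: [rng.randrange(5) for _ in range(6)]
score = lambda arch: sum(1.0 for op in arch if op == 3)
arch, s = random_nas(sample, score, n_samples=200)
print(arch, s)
```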
3.5 Zero-Shot NASΒΆ
Predict architecture performance without training.
Proxies for performance:
Training speed: \(\partial \mathcal{L} / \partial \text{iter}\) (how fast loss decreases)
Gradient norm: \(\|\nabla_w \mathcal{L}\|\) at initialization
NTK condition number: Neural Tangent Kernel eigenvalue ratio
Synaptic diversity: Variance in weight values
NASWOT (NAS Without Training): $\(\text{Score}(\mathcal{A}) = \log \det \text{Cov}(\text{Jacobian}(\mathcal{A}))\)$
Architectures with higher Jacobian covariance determinant train better.
Advantages:
Extremely fast (seconds to minutes)
No GPU required for evaluation
Can pre-screen thousands of architectures
Limitations:
Proxies may not correlate perfectly with final accuracy
Still need to train best candidates
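One way to sketch a NASWOT-style score is via the Jacobian-covariance variant given above; the toy network and the jitter term are assumptions for illustration (published versions of the score differ in the exact kernel used):

```python
import torch
import torch.nn as nn

def naswot_score(net, x):
    """Zero-shot proxy: log-determinant of the covariance of input-Jacobians
    across a batch (one variant of the NASWOT score)."""
    x = x.clone().requires_grad_(True)
    y = net(x)
    y.sum().backward()
    J = x.grad.flatten(1)                  # [batch, input_dim]
    J = J - J.mean(dim=0, keepdim=True)    # center across the batch
    cov = J @ J.t() / J.shape[1]           # [batch, batch] covariance
    cov += 1e-5 * torch.eye(cov.shape[0])  # jitter for numerical stability
    sign, logdet = torch.linalg.slogdet(cov)
    return logdet.item()

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
score = naswot_score(net, torch.randn(8, 16))
print(score)
```

Higher scores indicate more diverse input-output sensitivity across the batch, which the NASWOT hypothesis links to better trainability.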
4. Performance Estimation StrategiesΒΆ
4.1 Full Training (Baseline)ΒΆ
Train each candidate architecture to convergence on full dataset.
Cost: \(T_{\text{train}} \times N_{\text{candidates}}\)
Example: 50 GPU-hours per candidate × 1,000 candidates = 50,000 GPU-hours
4.2 Early StoppingΒΆ
Train for fewer epochs (e.g., 5-10) and estimate final performance.
Learning curve extrapolation: $\(\text{Acc}(t) = a - b \cdot t^{-\gamma}\)$
Fit curve from early epochs, predict final accuracy.
Speedup: 5-10× faster
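A sketch of learning-curve extrapolation under the power-law model above; the grid search over \(\gamma\) and the synthetic curve are illustrative assumptions:

```python
import numpy as np

def extrapolate_accuracy(epochs, accs, t_final, gammas=np.linspace(0.1, 2.0, 50)):
    """Fit Acc(t) = a - b * t^(-gamma) to early epochs and predict Acc(t_final).
    gamma is grid-searched; (a, b) solved by linear least squares per gamma."""
    best = None
    for g in gammas:
        X = np.stack([np.ones_like(epochs, dtype=float), -epochs ** (-g)], axis=1)
        (a, b), res, *_ = np.linalg.lstsq(X, accs, rcond=None)
        err = np.sum((X @ np.array([a, b]) - accs) ** 2)
        if best is None or err < best[0]:
            best = (err, a, b, g)
    _, a, b, g = best
    return a - b * t_final ** (-g)

# Synthetic curve with known parameters a=0.92, b=0.30, gamma=0.7
t = np.arange(1, 11, dtype=float)
acc = 0.92 - 0.30 * t ** (-0.7)
pred = extrapolate_accuracy(t, acc, t_final=100)
print(pred)
```

Fitting on only the first ten epochs and querying the curve at epoch 100 recovers the asymptote to within a few tenths of a percent on this clean synthetic example; real learning curves are noisier.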
4.3 Weight InheritanceΒΆ
Initialize new architectures with weights from similar architectures.
Network morphism: Transform architecture while preserving function
Deepen: \(y = \text{ReLU}(Wx) \to y = \text{ReLU}(W_2 \text{ReLU}(W_1 x))\) with \(W_1 = W\), \(W_2 = I\) (identity initialization; the function is preserved because ReLU is idempotent on non-negative inputs)
Widen: Duplicate neurons with weight splitting
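The deepening morphism can be checked numerically; a minimal sketch with identity initialization \(W_2 = I\), which preserves the function because ReLU is idempotent on non-negative inputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Original layer: y = ReLU(W x)
layer = nn.Linear(8, 8, bias=False)

# Deepened: y = ReLU(W2 ReLU(W1 x)) with W1 = W and W2 = I
deep1 = nn.Linear(8, 8, bias=False)
deep2 = nn.Linear(8, 8, bias=False)
with torch.no_grad():
    deep1.weight.copy_(layer.weight)   # W1 = W
    deep2.weight.copy_(torch.eye(8))   # W2 = I

x = torch.randn(4, 8)
y_old = torch.relu(layer(x))
y_new = torch.relu(deep2(torch.relu(deep1(x))))
print(torch.allclose(y_old, y_new))  # True
```

After this function-preserving insertion, both new layers are trained further, so the deepened network can only improve on the parent's starting point.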
4.4 HypernetworksΒΆ
Learn a model \(h_\phi\) that predicts weights for any architecture: $\(w_\mathcal{A} = h_\phi(\mathcal{A})\)$
Training:
Sample architectures and train them
Learn hypernetwork to map architecture → weights
Use hypernetwork to initialize new architectures
4.5 Performance PredictorsΒΆ
Train a regression model \(f\) to predict accuracy: $\(\hat{\text{Acc}} = f(\text{encode}(\mathcal{A}))\)$
Encodings:
Graph neural network on architecture DAG
Sequence encoding for layer configurations
Path encoding (from input to output)
Training:
Evaluate \(N\) random architectures
Train predictor on \(\{(\mathcal{A}_i, \text{Acc}_i)\}\)
Use predictor to screen new candidates
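A minimal sketch of steps 1-3 using a one-hot edge encoding and a small MLP regressor; the synthetic "accuracies" below are a stand-in for real evaluations:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_edges, n_ops = 6, 5

def encode(arch):
    """One-hot encode a list of per-edge operation indices."""
    return F.one_hot(torch.tensor(arch), n_ops).float().flatten()

# Synthetic dataset: accuracy is an (unknown to the predictor) linear
# function of the encoding plus noise, standing in for real evaluations
true_w = torch.randn(n_edges * n_ops)
archs = [torch.randint(0, n_ops, (n_edges,)).tolist() for _ in range(200)]
X = torch.stack([encode(a) for a in archs])
y = X @ true_w + 0.01 * torch.randn(200)

predictor = nn.Sequential(nn.Linear(n_edges * n_ops, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.Adam(predictor.parameters(), lr=1e-2)
for step in range(500):
    opt.zero_grad()
    loss = F.mse_loss(predictor(X).squeeze(-1), y)
    loss.backward()
    opt.step()
print(f"final MSE: {loss.item():.4f}")
```

Once trained, the predictor can rank thousands of unseen encodings in milliseconds, so only the top-ranked candidates need real training.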
5. Multi-Objective NASΒΆ
Optimize multiple objectives simultaneously: $\(\max_{\mathcal{A} \in \mathcal{S}} \{f_1(\mathcal{A}), f_2(\mathcal{A}), ..., f_k(\mathcal{A})\}\)$
where:
\(f_1\): Accuracy
\(f_2\): \(-\)Latency (negative for maximization)
\(f_3\): \(-\)Parameters
\(f_4\): \(-\)Energy consumption
5.1 Pareto OptimalityΒΆ
Architecture \(\mathcal{A}\) is Pareto optimal if no other architecture dominates it in all objectives.
Pareto front: Set of all Pareto optimal solutions.
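A short sketch of extracting the Pareto front from candidate (accuracy, negated latency) pairs, assuming every objective is maximized:

```python
def pareto_front(points):
    """Return the Pareto-optimal subset, assuming all objectives are maximized.
    A point is dominated if some other point is >= in every objective and
    differs in at least one."""
    front = []
    for p in points:
        dominated = any(
            all(qi >= pi for qi, pi in zip(q, p)) and q != p
            for q in points
        )
        if not dominated:
            front.append(p)
    return front

# (accuracy, -latency_ms): both maximized
candidates = [(0.92, -20), (0.90, -10), (0.95, -50), (0.89, -30), (0.90, -15)]
print(pareto_front(candidates))  # [(0.92, -20), (0.90, -10), (0.95, -50)]
```

The quadratic scan is fine for hundreds of candidates; NSGA-II's fast non-dominated sort does the same job more efficiently for large populations.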
5.2 Weighted SumΒΆ
Scalarize the objectives into one score: $\(f(\mathcal{A}) = \sum_i w_i f_i(\mathcal{A})\)$
where \(\sum_i w_i = 1\), \(w_i \geq 0\).
Limitation: Cannot find all Pareto optimal solutions (only convex parts).
5.3 Evolutionary Multi-ObjectiveΒΆ
NSGA-II (Non-dominated Sorting Genetic Algorithm):
Rank population by domination (non-dominated = rank 1)
Compute crowding distance (diversity in objective space)
Select based on rank and crowding distance
Generate offspring, repeat
5.4 Hardware-Aware NASΒΆ
Latency modeling: $\(\text{Latency}(\mathcal{A}) = \sum_{\text{layer } i} f_{\text{hw}}(\text{layer}_i)\)$
where \(f_{\text{hw}}\) is measured or predicted latency on target hardware.
Direct optimization:
ProxylessNAS: Latency as regularization term
FBNet: Latency lookup table for each operation
Once-for-All: Elastic networks for different constraints
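The FBNet-style lookup-table model can be sketched directly; the per-operation latencies below are hypothetical placeholders, not measurements:

```python
# Hypothetical per-operation latency table (ms) for one target device;
# real systems profile each op at each resolution and channel count
latency_table_ms = {
    'conv_3x3': 1.8,
    'conv_1x1': 0.6,
    'max_pool_3x3': 0.3,
    'skip_connect': 0.05,
    'none': 0.0,
}

def predict_latency(arch_ops):
    """FBNet-style model: total latency = sum of per-layer table lookups."""
    return sum(latency_table_ms[op] for op in arch_ops)

arch = ['conv_3x3', 'conv_1x1', 'skip_connect', 'conv_3x3', 'max_pool_3x3']
print(predict_latency(arch))  # 1.8 + 0.6 + 0.05 + 1.8 + 0.3 ≈ 4.55 ms
```

Because the table is differentiable-friendly (a weighted sum under continuous relaxation), the same model can serve as a latency regularizer during gradient-based search.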
6. Advanced NAS Methods (2020-2024)ΒΆ
6.1 Once-for-All Networks (OFA)ΒΆ
Key idea: Train a single network that supports diverse architectures.
Elastic dimensions:
Depth: 2-5 blocks per stage
Width: 0.5×, 0.75×, 1.0× channels
Kernel size: 3×3, 5×5, 7×7
Resolution: 128, 160, 192, 224 pixels
Progressive shrinking:
Train largest network (full capacity)
Progressively support smaller elastic dimensions
Each sub-network shares weights with supernet
Deployment:
Given constraint (latency < 20ms)
Evolutionary search in trained OFA
Extract sub-network without retraining
Benefits:
Train once (1200 GPU-hours)
Search many times (40 GPU-hours each)
Amortized cost very low
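The deployment step amounts to constraint-filtered search over elastic configurations; in this sketch the latency and accuracy models are toy stand-ins, and random sampling replaces OFA's accuracy predictor plus evolutionary search:

```python
import random

# Elastic dimensions from the OFA setup above; the models below are toy
# stand-ins for on-device latency measurements and an accuracy predictor
DEPTHS = [2, 3, 4, 5]
WIDTHS = [0.5, 0.75, 1.0]
KERNELS = [3, 5, 7]

def toy_latency_ms(cfg):
    d, w, k = cfg
    return d * w * (k / 3.0) * 4.0  # hypothetical latency model

def toy_accuracy(cfg):
    d, w, k = cfg
    return 0.70 + 0.02 * d + 0.05 * w + 0.005 * k  # hypothetical predictor

def search_under_constraint(limit_ms, n_samples=500, seed=0):
    """Sample sub-network configs, keep the most accurate one under the limit."""
    rng = random.Random(seed)
    best = None
    for _ in range(n_samples):
        cfg = (rng.choice(DEPTHS), rng.choice(WIDTHS), rng.choice(KERNELS))
        if toy_latency_ms(cfg) <= limit_ms:
            acc = toy_accuracy(cfg)
            if best is None or acc > best[0]:
                best = (acc, cfg)
    return best

acc, cfg = search_under_constraint(limit_ms=20.0)
print(cfg, acc)
```

The key property is that no retraining happens inside the loop: every candidate inherits its weights from the already-trained supernet.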
6.2 Neural Architecture TransferΒΆ
Pre-train on source task, transfer to target task.
Weight sharing across tasks:
Learn task-agnostic features in supernet
Fine-tune for specific task
Reduces search cost on new tasks
Meta-learning for NAS: $\(\theta^* = \arg\min_\theta \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})} [\mathcal{L}_{\text{NAS}}(\theta, \mathcal{T})]\)$
Train NAS controller on distribution of tasks.
6.3 AutoML FrameworksΒΆ
NAS-Bench: Benchmark datasets with pre-computed architectures
NAS-Bench-101: 423k architectures on CIFAR-10
NAS-Bench-201: 15,625 architectures, evaluated on CIFAR-10, CIFAR-100, and ImageNet-16-120
Enables reproducible NAS research
NNI (Neural Network Intelligence):
Microsoft framework for AutoML
Supports RL, evolution, DARTS, ENAS
Built-in model compression
Auto-Keras:
High-level API for NAS
Bayesian optimization
Automatically handles data preprocessing
6.4 Transformer NASΒΆ
Search for optimal Transformer architectures.
Search dimensions:
Number of layers
Hidden dimension
Number of attention heads
FFN expansion ratio
Attention pattern (full, sparse, local)
AutoFormer (Chen et al., 2021):
Weight-sharing supernet for Transformers
Evolutionary search
Found architectures competitive with BERT
Primer (So et al., 2021):
Evolution-based search for Transformer primitives
Discovered squared ReLU activation
Modified attention mechanisms
6.5 Vision Transformer NASΒΆ
AutoViT:
Search for ViT patch size, embedding dimension, depth
Hardware-aware optimization
2-3× faster than DeiT with same accuracy
BossNAS:
Billion-scale search space for vision
Mixed CNN and Transformer operations
State-of-the-art on ImageNet
7. Theoretical FoundationsΒΆ
7.1 Neural Tangent Kernel (NTK) for NASΒΆ
NTK characterizes network behavior during training: $\(\Theta(x, x') = \langle \nabla_w f(x, w), \nabla_w f(x', w) \rangle\)$
At initialization, \(\Theta\) determines:
Convergence speed
Final performance
Generalization
NAS application: Architectures with better-conditioned NTK train faster and generalize better.
7.2 Loss Landscape AnalysisΒΆ
Sharpness: Eigenvalues of Hessian \(\nabla^2 \mathcal{L}(w)\)
Sharp minima: Large eigenvalues → poor generalization
Flat minima: Small eigenvalues → good generalization
NAS metric: Prefer architectures with flatter loss landscapes.
7.3 Expressivity vs TrainabilityΒΆ
Expressivity: Can the architecture represent the target function?
Universal approximation: Most architectures are expressive enough for practical tasks.
Trainability: Can we efficiently find good parameters via gradient descent?
Key insight: NAS should focus on trainability, not just expressivity.
7.4 Architecture-Training Co-DesignΒΆ
Traditional: Architecture search → Train
Modern: Interleave architecture search and training
Benefits:
Architecture adapts to current training stage
Better exploration of architecture-weight space
Avoids local optima in joint space
8. NAS for Specialized DomainsΒΆ
8.1 NAS for Object DetectionΒΆ
Search space:
Backbone architecture (feature extractor)
Feature pyramid network (FPN) connections
Detection head architecture
DetNAS: NAS for object detection
Search backbone and FPN jointly
Multi-scale architecture
COCO mAP improvements
SpineNet: Learned scale-permuted networks
Cross-scale connections learned via NAS
Better than hand-designed FPNs
8.2 NAS for Semantic SegmentationΒΆ
Auto-DeepLab:
Search cell structure and network-level connections
Hierarchical search space
State-of-the-art on Cityscapes
Search dimensions:
Encoder-decoder architecture
Skip connections across scales
Atrous convolution rates
8.3 NAS for Speech RecognitionΒΆ
Evolved Transformer for ASR:
Evolution-based search
Discovered architectural improvements over baseline Transformer
Lower WER on LibriSpeech
Neural Architecture Search for LSTMs:
Search for optimal LSTM cell structure
Found better alternatives to standard LSTM gates
8.4 NAS for Recommender SystemsΒΆ
AutoCTR:
Search feature interaction architectures
Embedding dimensions per feature
Improved CTR prediction
AMER:
Automated model search for recommendations
User-item interaction modeling
Better than hand-crafted models
9. Practical ConsiderationsΒΆ
9.1 Search Cost AnalysisΒΆ
Total cost: $\(\text{Cost} = N_{\text{arch}} \times T_{\text{train}} \times C_{\text{data}}\)$
where:
\(N_{\text{arch}}\): Number of architectures evaluated
\(T_{\text{train}}\): Training time per architecture
\(C_{\text{data}}\): Data cost (labels, compute)
Reduction strategies:
Early stopping: Reduce \(T_{\text{train}}\) by 5-10×
Weight sharing: Reduce \(N_{\text{arch}} \times T_{\text{train}}\) to \(T_{\text{supernet}}\)
Zero-shot: Reduce \(T_{\text{train}}\) to near zero
Proxy datasets: Reduce \(C_{\text{data}}\) (CIFAR-10 instead of ImageNet)
9.2 ReproducibilityΒΆ
Key factors:
Random seed for architecture sampling
Training hyperparameters (learning rate, batch size)
Hardware (GPU type affects measured latency)
Data augmentation policy
Early stopping criteria
Best practices:
Report search algorithm details
Provide code and search logs
Multiple runs with different seeds
Report mean and variance of top architectures
9.3 Fairness in ComparisonΒΆ
Common pitfalls:
Different training budgets (epochs, data)
Different search spaces
Cherry-picking best result from multiple runs
Post-search hyperparameter tuning
Fair comparison:
Same search space
Same total compute budget
Report average performance, not just best
Separate search and evaluation sets
9.4 When to Use NASΒΆ
Use NAS when:
✅ Have sufficient compute budget (>100 GPU-hours)
✅ Target hardware differs from standard (edge devices, mobile)
✅ Task-specific architecture requirements
✅ Need to beat hand-designed baselines
✅ Can amortize search cost across many deployments
Avoid NAS when:
❌ Limited compute budget (<10 GPU-hours)
❌ Standard architectures work well (ImageNet classification)
❌ Small dataset (overfitting risk in search)
❌ One-time deployment (search cost not amortized)
10. Software Tools and FrameworksΒΆ
10.1 AutoGluonΒΆ
from autogluon.vision import ImagePredictor
predictor = ImagePredictor()
predictor.fit('train/', time_limit=3600) # 1 hour
Features:
Automatic architecture search
Hyperparameter optimization
Ensemble of models
Easy deployment
10.2 NNI (Neural Network Intelligence)ΒΆ
Supported algorithms:
DARTS, ENAS, SPOS, ProxylessNAS
Hyperparameter tuning (TPE, BOHB)
Model compression (pruning, quantization)
Workflow:
Define search space (JSON)
Implement training code with NNI API
Configure search algorithm
Launch experiment, monitor progress
10.3 NAS-Bench IntegrationΒΆ
Pre-computed results for fast experimentation:
from nas_201_api import NASBench201API as API
api = API('NAS-Bench-201-v1_1-096897.pth')
arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|'
index = api.query_index_by_arch(arch)  # look up the architecture's index first
info = api.get_more_info(index, 'cifar10', hp='200')
print(f"Accuracy: {info['test-accuracy']}")
No training needed: instant results!
10.4 Custom NAS ImplementationΒΆ
Minimal DARTS:
# Architecture parameters: one row of operation logits per edge
alpha = nn.Parameter(1e-3 * torch.randn(n_edges, n_ops))

# Mixed operation: softmax-weighted sum of all candidate ops on one edge
def mixed_op(x, alpha_edge):
    return sum(w * op(x) for w, op in zip(F.softmax(alpha_edge, dim=-1), ops))

# First-order bilevel optimization: alternate weight and architecture updates
for epoch in range(epochs):
    # Update weights on the training split
    optimizer_w.zero_grad()
    loss_train = loss(model(x_train, alpha), y_train)
    loss_train.backward()
    optimizer_w.step()
    # Update architecture parameters on the validation split
    optimizer_alpha.zero_grad()
    loss_val = loss(model(x_val, alpha), y_val)
    loss_val.backward()
    optimizer_alpha.step()
11. Recent Trends and Future Directions (2024-2026)ΒΆ
11.1 Foundation Model NASΒΆ
Search architectures for pre-training, not just downstream tasks.
Challenges:
Search cost scales with model size
Need to evaluate on multiple tasks
Transfer learning uncertainty
Approaches:
Meta-learning across task distributions
Architecture transferability metrics
Efficient supernet training
11.2 Self-Supervised NASΒΆ
Search architectures optimized for self-supervised learning.
SimCLR-NAS: Search encoder for contrastive learning
BYOL-NAS: Architecture search for bootstrap learning
Key insight: Architectures good for supervised learning ≠ architectures good for self-supervised learning.
11.3 Neural Architecture GenerationΒΆ
Generative models for architectures:
VAE for architecture space
GAN to generate high-performing architectures
Diffusion models for architecture sampling
Benefits:
Learn distribution of good architectures
Sample novel designs
Interpolate between architectures
11.4 Federated NASΒΆ
Search architectures across distributed devices without centralizing data.
Challenges:
Heterogeneous hardware
Communication costs
Privacy constraints
FedNAS:
Local architecture search on each device
Aggregate architecture parameters
Find architecture that works well for all clients
11.5 Interpretable NASΒΆ
Explainable architecture decisions:
Why was operation \(o\) chosen over \(o'\)?
Which architectural components contribute most to performance?
Visualize architecture evolution during search
Techniques:
Attention over operations
Shapley values for architecture components
Causal analysis of design choices
12. Mathematical Deep DiveΒΆ
12.1 DARTS Gradient ApproximationΒΆ
Bilevel optimization: $\(\nabla_\alpha \mathcal{L}_{\text{val}}(w^*(\alpha), \alpha)\)$
Chain rule: $\(= \nabla_\alpha \mathcal{L}_{\text{val}} + \nabla_\alpha w^* \cdot \nabla_w \mathcal{L}_{\text{val}}\)$
Implicit function theorem: $\(\nabla_\alpha w^* = -[\nabla^2_{ww} \mathcal{L}_{\text{train}}]^{-1} \nabla^2_{w\alpha} \mathcal{L}_{\text{train}}\)$
Finite-difference approximation (avoids forming or inverting the Hessian): $\(\nabla^2_{\alpha w} \mathcal{L}_{\text{train}} \, \nabla_{w'} \mathcal{L}_{\text{val}} \approx \frac{\nabla_\alpha \mathcal{L}_{\text{train}}(w^+, \alpha) - \nabla_\alpha \mathcal{L}_{\text{train}}(w^-, \alpha)}{2\epsilon}\)$ where \(w^\pm = w \pm \epsilon \nabla_{w'} \mathcal{L}_{\text{val}}\) and \(\epsilon\) is a small scalar.
12.2 Expected Improvement for Bayesian NASΒΆ
Acquisition function: $\(\alpha_{\text{EI}}(\mathcal{A}) = \mathbb{E}[\max(f(\mathcal{A}) - f(\mathcal{A}_{\text{best}}), 0)]\)$
where \(f\) is surrogate model (Gaussian Process).
Closed form (for GP): $\(\alpha_{\text{EI}}(\mathcal{A}) = (\mu(\mathcal{A}) - f_{\text{best}}) \Phi(Z) + \sigma(\mathcal{A}) \phi(Z)\)$
where:
\(Z = (\mu(\mathcal{A}) - f_{\text{best}}) / \sigma(\mathcal{A})\)
\(\Phi\): CDF of standard normal
\(\phi\): PDF of standard normal
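The closed form needs nothing beyond the standard library; a sketch for a single candidate with an assumed Gaussian posterior:

```python
import math

def expected_improvement(mu, sigma, f_best):
    """Closed-form EI for a Gaussian posterior N(mu, sigma^2) at one candidate."""
    if sigma <= 0:
        return max(mu - f_best, 0.0)
    z = (mu - f_best) / sigma
    Phi = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))          # standard normal CDF
    phi = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)   # standard normal PDF
    return (mu - f_best) * Phi + sigma * phi

# A candidate predicted slightly below the incumbent but with high
# uncertainty still has positive expected improvement
print(expected_improvement(mu=0.90, sigma=0.05, f_best=0.92))
```

This is how EI trades off exploitation (high \(\mu\)) against exploration (high \(\sigma\)): at \(\sigma = 0\) the acquisition collapses to plain improvement over the incumbent.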
12.3 Evolutionary Algorithm ConvergenceΒΆ
Schema theorem (genetic algorithms):
Expected number of schema \(H\) in next generation: $\(m(H, t+1) \geq m(H, t) \frac{f(H)}{\bar{f}} \left[1 - p_c \frac{d(H)}{l-1} - p_m o(H)\right]\)$
where:
\(f(H)\): Average fitness of schema
\(\bar{f}\): Population average fitness
\(p_c\): Crossover probability
\(d(H)\): Defining length of schema
\(p_m\): Mutation probability
\(o(H)\): Order of schema (number of fixed positions)
\(l\): Length of the chromosome (bit string)
Insight: Short, low-order, high-fitness schemas grow exponentially.
13. Case StudiesΒΆ
13.1 EfficientNet (NAS + Scaling)ΒΆ
Compound scaling: $\(\text{depth: } d = \alpha^\phi, \quad \text{width: } w = \beta^\phi, \quad \text{resolution: } r = \gamma^\phi\)$
subject to \(\alpha \cdot \beta^2 \cdot \gamma^2 \approx 2\)
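With the coefficients reported in the EfficientNet paper (\(\alpha = 1.2\), \(\beta = 1.1\), \(\gamma = 1.15\), found by grid search at \(\phi = 1\)), the constraint and the scaling rule work out as:

```python
# EfficientNet's reported compound-scaling coefficients
alpha, beta, gamma = 1.2, 1.1, 1.15

def scale(phi):
    """Return (depth, width, resolution) multipliers for compound coefficient phi."""
    return alpha ** phi, beta ** phi, gamma ** phi

# Constraint check: alpha * beta^2 * gamma^2 should be ~2, so total FLOPs
# roughly double with each unit increase of phi
print(alpha * beta ** 2 * gamma ** 2)
for phi in range(4):
    d, w, r = scale(phi)
    print(f"phi={phi}: depth x{d:.2f}, width x{w:.2f}, resolution x{r:.2f}")
```

Since FLOPs grow linearly in depth and quadratically in both width and resolution, the constraint \(\alpha \beta^2 \gamma^2 \approx 2\) is exactly what makes each step of \(\phi\) a doubling of compute.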
NAS for base model:
Search on CIFAR-10 (proxy task)
Transfer to ImageNet
Scale up with compound scaling
Results:
EfficientNet-B0: 77.1% top-1, 5.3M params
EfficientNet-B7: 84.3% top-1, 66M params
State-of-the-art accuracy-efficiency trade-off
13.2 GPT-4 Architecture (Rumored)ΒΆ
Mixture of Experts via NAS:
Each layer routes tokens to different experts
Experts specialized for different input types
Routing learned during pre-training
Search dimensions:
Number of experts per layer
Expert capacity
Routing algorithm
Load balancing strategy
Benefits:
1T+ parameters with manageable compute
Better quality than dense models
Efficient inference via sparse activation
13.3 AlphaCode ArchitectureΒΆ
Multi-stage NAS:
Search encoder architecture (code understanding)
Search decoder architecture (code generation)
Search sampling strategies
Novel components discovered:
Specialized attention for code syntax
Multi-scale code embeddings
Hybrid convolution-transformer blocks
14. Summary and Best PracticesΒΆ
Key Takeaways:ΒΆ
Search space design is crucial: Constraining search space reduces cost and improves results
Weight sharing enables efficiency: Train once, search many times
Multi-objective matters: Accuracy alone insufficient for deployment
Zero-shot methods emerging: Performance prediction without training
Hardware-aware is essential: Latency/energy often more important than FLOPs
When to Use Which Method:ΒΆ
| Scenario | Recommended Method | Justification |
|---|---|---|
| First-time NAS | DARTS or ENAS | Good balance of speed and quality |
| Limited budget (<1 GPU-day) | Zero-shot NAS | No training needed |
| Mobile deployment | Once-for-All or ProxylessNAS | Hardware-aware optimization |
| Novel task | Evolutionary NAS | Flexible, no gradient required |
| Production at scale | Once-for-All | Train once, deploy many configs |
| Research/benchmarking | NAS-Bench | Reproducible, fast experiments |
Common Pitfalls:ΒΆ
❌ Overfitting to search space: Good NAS result doesn't mean optimal architecture
❌ Ignoring hardware: FLOPs ≠ latency
❌ Unfair comparisons: Different training budgets invalidate results
❌ Not checking transferability: Search on CIFAR, deploy on ImageNet?
❌ Forgetting deployment constraints: Memory, power, latency requirements
Future Research Directions:ΒΆ
Efficient large-scale NAS: Search for billion-parameter models
Task-agnostic architectures: Find architectures that work across tasks
Architecture-data co-design: Jointly optimize data and architecture
Continual architecture search: Adapt architecture as data distribution shifts
Certified robustness via NAS: Search for adversarially robust architectures
Conclusion: NAS has matured from expensive academic curiosity to practical tool. Modern methods (OFA, zero-shot) make NAS accessible for real-world deployment. The future lies in scaling NAS to foundation models and making it fully automated.
"""
Advanced Neural Architecture Search - Production Implementation
Comprehensive implementations of modern NAS methods
Methods Implemented:
1. DARTS (Differentiable Architecture Search)
2. ENAS (Efficient Neural Architecture Search with weight sharing)
3. RandomNAS (baseline with weight-sharing supernet)
4. Evolutionary NAS (genetic algorithm)
5. Zero-Shot NAS (NASWOT - NAS Without Training)
6. Performance Predictors (GNN-based)
7. Once-for-All Networks (elastic supernet)
8. Multi-Objective NAS (NSGA-II)
Features:
- Cell-based search space (NAS-Bench-201 compatible)
- Hardware-aware latency modeling
- Multi-objective optimization (accuracy, latency, parameters)
- Visualization and analysis tools
Author: Advanced Deep Learning Course
Date: 2024
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional, Callable
import numpy as np
import random
from collections import defaultdict
import time
# ============================================================================
# 1. Search Space Definition (Cell-Based)
# ============================================================================
class SearchSpace:
"""
Cell-based search space similar to NAS-Bench-201.
Each cell is a DAG with N nodes, each node connected to previous nodes
via operations from a predefined operation set.
"""
PRIMITIVES = [
'none', # No connection
'skip_connect', # Identity
        'conv_1x1',       # 1×1 convolution
        'conv_3x3',       # 3×3 convolution
        'avg_pool_3x3',   # 3×3 average pooling
        'max_pool_3x3',   # 3×3 max pooling
]
def __init__(self, n_nodes: int = 4):
"""
Args:
n_nodes: Number of intermediate nodes in cell
"""
self.n_nodes = n_nodes
self.n_ops = len(self.PRIMITIVES)
# Compute search space size
        # Each node i (i = 2, ..., n_nodes + 1) connects to ALL previous nodes,
        # contributing i edges; each edge picks one of n_ops operations
        self.n_edges = sum(range(2, n_nodes + 2))  # 2 + 3 + ... + (n_nodes + 1)
self.search_space_size = self.n_ops ** self.n_edges
print(f"Search Space: {self.n_nodes} nodes, {self.n_edges} edges")
print(f"Space size: {self.search_space_size:.2e} architectures")
def random_architecture(self) -> List[int]:
"""Sample random architecture (list of operation indices)."""
return [random.randint(0, self.n_ops - 1) for _ in range(self.n_edges)]
def encode_architecture(self, arch: List[int]) -> str:
"""Encode architecture as string."""
ops = [self.PRIMITIVES[i] for i in arch]
return '|'.join(ops)
def decode_architecture(self, arch_str: str) -> List[int]:
"""Decode architecture string to indices."""
ops = arch_str.split('|')
return [self.PRIMITIVES.index(op) for op in ops]
def mutate(self, arch: List[int], n_mutations: int = 1) -> List[int]:
"""Randomly mutate architecture."""
arch = arch.copy()
for _ in range(n_mutations):
idx = random.randint(0, len(arch) - 1)
arch[idx] = random.randint(0, self.n_ops - 1)
return arch
def crossover(self, arch1: List[int], arch2: List[int]) -> Tuple[List[int], List[int]]:
"""Single-point crossover between two architectures."""
point = random.randint(1, len(arch1) - 1)
child1 = arch1[:point] + arch2[point:]
child2 = arch2[:point] + arch1[point:]
return child1, child2
# ============================================================================
# 2. Operations (Building Blocks)
# ============================================================================
class ReLUConvBN(nn.Module):
"""ReLU + Conv + BatchNorm"""
def __init__(self, C_in: int, C_out: int, kernel_size: int, stride: int = 1, padding: int = 0):
super().__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(C_out)
)
def forward(self, x):
return self.op(x)
class OPS:
"""Operation factory"""
@staticmethod
def get_op(op_name: str, C: int, stride: int = 1) -> nn.Module:
"""
Get operation by name.
Args:
op_name: Operation name from PRIMITIVES
C: Number of channels
stride: Stride (1 for normal cell, 2 for reduction cell)
"""
if op_name == 'none':
return Zero(stride)
elif op_name == 'skip_connect':
if stride == 1:
return nn.Identity()
else:
return FactorizedReduce(C, C)
elif op_name == 'conv_1x1':
return ReLUConvBN(C, C, 1, stride=stride, padding=0)
elif op_name == 'conv_3x3':
return ReLUConvBN(C, C, 3, stride=stride, padding=1)
elif op_name == 'avg_pool_3x3':
if stride == 1:
return nn.AvgPool2d(3, stride=1, padding=1)
else:
return nn.AvgPool2d(3, stride=stride, padding=1)
elif op_name == 'max_pool_3x3':
if stride == 1:
return nn.MaxPool2d(3, stride=1, padding=1)
else:
return nn.MaxPool2d(3, stride=stride, padding=1)
else:
raise ValueError(f"Unknown operation: {op_name}")
class Zero(nn.Module):
"""Zero operation (no connection)"""
def __init__(self, stride):
super().__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x.mul(0.)
return x[:, :, ::self.stride, ::self.stride].mul(0.)
class FactorizedReduce(nn.Module):
"""Factorized reduction for stride=2 skip connections"""
def __init__(self, C_in: int, C_out: int):
super().__init__()
assert C_out % 2 == 0
self.relu = nn.ReLU(inplace=False)
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out)
def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
# ============================================================================
# 3. Cell and Network Architecture
# ============================================================================
class Cell(nn.Module):
"""
Cell module (DAG of operations).
Each node computes: h_i = sum_{j < i} op_{i,j}(h_j)
"""
def __init__(
self,
n_nodes: int,
C: int,
arch: List[int],
primitives: List[str],
stride: int = 1
):
super().__init__()
self.n_nodes = n_nodes
self.primitives = primitives
# Build edges (operations)
self.edges = nn.ModuleDict()
edge_idx = 0
for i in range(2, n_nodes + 2): # Nodes 2, 3, ..., n_nodes+1
for j in range(i): # Connect to all previous nodes
op_name = primitives[arch[edge_idx]]
op = OPS.get_op(op_name, C, stride if j < 2 else 1)
self.edges[f'{i}_{j}'] = op
edge_idx += 1
def forward(self, x):
"""
Args:
x: Input tensor [batch, C, H, W]
Returns:
output: [batch, C, H', W']
"""
# Initialize nodes: node 0 and 1 are input
nodes = [x, x]
# Compute each node
for i in range(2, self.n_nodes + 2):
node_sum = sum(self.edges[f'{i}_{j}'](nodes[j]) for j in range(i))
nodes.append(node_sum)
        # Output: average of all intermediate nodes (keeps the channel count fixed)
        return sum(nodes[2:]) / len(nodes[2:])
class Network(nn.Module):
"""
Complete network with stacked cells.
    Architecture: Stem → [Normal Cell × N] → Reduction → [Normal Cell × N] → Reduction → [Normal Cell × N] → Classifier
"""
def __init__(
self,
C: int,
n_classes: int,
n_layers: int,
arch: List[int],
n_nodes: int = 4,
primitives: List[str] = SearchSpace.PRIMITIVES
):
super().__init__()
self.C = C
self.n_layers = n_layers
# Stem
self.stem = nn.Sequential(
nn.Conv2d(3, C, 3, padding=1, bias=False),
nn.BatchNorm2d(C)
)
# Cells
self.cells = nn.ModuleList()
        C_curr = C
        reduction_layers = [n_layers // 3, 2 * n_layers // 3]
        for i in range(n_layers):
            if i in reduction_layers:
                # Reduction cell halves the spatial size at C_curr channels;
                # a 1x1 conv then doubles the channel count so the next cell
                # sees a matching width (the original C_curr // 2 cell left a
                # channel mismatch with the following normal cells)
                self.cells.append(Cell(n_nodes, C_curr, arch, primitives, stride=2))
                self.cells.append(ReLUConvBN(C_curr, 2 * C_curr, 1))
                C_curr *= 2
            else:
                self.cells.append(Cell(n_nodes, C_curr, arch, primitives, stride=1))
# Classifier
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_curr, n_classes)
def forward(self, x):
x = self.stem(x)
for cell in self.cells:
x = cell(x)
x = self.global_pooling(x)
x = x.view(x.size(0), -1)
logits = self.classifier(x)
return logits
def count_parameters(self):
return sum(p.numel() for p in self.parameters())
# ============================================================================
# 4. DARTS (Differentiable Architecture Search)
# ============================================================================
class MixedOp(nn.Module):
"""Mixed operation with continuous relaxation."""
def __init__(self, C: int, primitives: List[str], stride: int = 1):
super().__init__()
self.ops = nn.ModuleList([OPS.get_op(prim, C, stride) for prim in primitives])
def forward(self, x, weights):
"""
Args:
x: Input tensor
weights: Softmax weights [n_ops]
Returns:
Weighted sum of operations
"""
return sum(w * op(x) for w, op in zip(weights, self.ops))
class DARTSCell(nn.Module):
"""DARTS cell with continuous architecture parameters."""
def __init__(self, n_nodes: int, C: int, primitives: List[str], stride: int = 1):
super().__init__()
self.n_nodes = n_nodes
self.primitives = primitives
self.edges = nn.ModuleDict()
for i in range(2, n_nodes + 2):
for j in range(i):
self.edges[f'{i}_{j}'] = MixedOp(C, primitives, stride if j < 2 else 1)
def forward(self, x, alpha):
"""
Args:
x: Input
alpha: Architecture parameters [n_edges, n_ops]
Returns:
Cell output
"""
nodes = [x, x]
edge_idx = 0
for i in range(2, self.n_nodes + 2):
edges_to_i = []
for j in range(i):
weights = F.softmax(alpha[edge_idx], dim=0)
edges_to_i.append(self.edges[f'{i}_{j}'](nodes[j], weights))
edge_idx += 1
nodes.append(sum(edges_to_i))
return sum(nodes[2:]) / len(nodes[2:])
class DARTSNetwork(nn.Module):
"""DARTS supernet with learnable architecture parameters."""
def __init__(
self,
C: int,
n_classes: int,
n_layers: int,
n_nodes: int = 4,
primitives: List[str] = SearchSpace.PRIMITIVES
):
super().__init__()
self.n_layers = n_layers
self.n_nodes = n_nodes
self.primitives = primitives
# Architecture parameters
n_edges = sum(range(2, n_nodes + 2))
self.alphas = nn.Parameter(torch.randn(n_edges, len(primitives)) * 1e-3)
# Stem
self.stem = nn.Sequential(
nn.Conv2d(3, C, 3, padding=1, bias=False),
nn.BatchNorm2d(C)
)
# Cells
        self.cells = nn.ModuleList()
        C_curr = C
        reduction_layers = [n_layers // 3, 2 * n_layers // 3]
        for i in range(n_layers):
            if i in reduction_layers:
                # Reduce spatially at C_curr channels, then double channels
                # with a plain 1x1 conv so later cells see a matching width
                self.cells.append(DARTSCell(n_nodes, C_curr, primitives, stride=2))
                self.cells.append(ReLUConvBN(C_curr, 2 * C_curr, 1))
                C_curr *= 2
            else:
                self.cells.append(DARTSCell(n_nodes, C_curr, primitives, stride=1))
        # Classifier
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_curr, n_classes)

    def forward(self, x):
        x = self.stem(x)
        for cell in self.cells:
            # Shared alphas drive every DARTSCell; the 1x1 channel expanders
            # are plain modules that take only the feature tensor
            x = cell(x, self.alphas) if isinstance(cell, DARTSCell) else cell(x)
        x = self.global_pooling(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
def discretize(self) -> List[int]:
"""Extract discrete architecture from continuous parameters."""
return [self.primitives[i.item()] for i in self.alphas.argmax(dim=1)]
def get_arch_indices(self) -> List[int]:
"""Get architecture as list of operation indices."""
return self.alphas.argmax(dim=1).tolist()
def train_darts(
model: DARTSNetwork,
train_loader,
val_loader,
n_epochs: int = 50,
lr_w: float = 0.025,
lr_alpha: float = 3e-4,
device: str = 'cuda'
):
"""
Train DARTS model with bilevel optimization.
Args:
model: DARTS supernet
train_loader: Training data
val_loader: Validation data
n_epochs: Number of epochs
lr_w: Learning rate for weights
lr_alpha: Learning rate for architecture parameters
"""
model = model.to(device)
# Optimizers
optimizer_w = torch.optim.SGD(
[p for n, p in model.named_parameters() if n != 'alphas'],
lr=lr_w, momentum=0.9, weight_decay=3e-4
)
optimizer_alpha = torch.optim.Adam([model.alphas], lr=lr_alpha)
scheduler_w = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_w, n_epochs)
for epoch in range(n_epochs):
model.train()
for (x_train, y_train), (x_val, y_val) in zip(train_loader, val_loader):
x_train, y_train = x_train.to(device), y_train.to(device)
x_val, y_val = x_val.to(device), y_val.to(device)
# Update weights
optimizer_w.zero_grad()
logits = model(x_train)
loss_train = F.cross_entropy(logits, y_train)
loss_train.backward()
optimizer_w.step()
# Update architecture
optimizer_alpha.zero_grad()
logits_val = model(x_val)
loss_val = F.cross_entropy(logits_val, y_val)
loss_val.backward()
optimizer_alpha.step()
scheduler_w.step()
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {loss_train.item():.4f}, Val Loss: {loss_val.item():.4f}")
return model.get_arch_indices()
# ============================================================================
# 5. Evolutionary NAS
# ============================================================================
class EvolutionaryNAS:
"""
Evolutionary algorithm for NAS.
Uses tournament selection, mutation, and crossover to evolve population.
"""
def __init__(
self,
search_space: SearchSpace,
population_size: int = 50,
n_generations: int = 20,
tournament_size: int = 5,
mutation_prob: float = 0.1,
crossover_prob: float = 0.5
):
self.search_space = search_space
self.population_size = population_size
self.n_generations = n_generations
self.tournament_size = tournament_size
self.mutation_prob = mutation_prob
self.crossover_prob = crossover_prob
self.population = []
self.fitness = {}
def initialize_population(self):
"""Initialize random population."""
self.population = [
self.search_space.random_architecture()
for _ in range(self.population_size)
]
def tournament_selection(self) -> List[int]:
"""Select architecture via tournament selection."""
candidates = random.sample(self.population, self.tournament_size)
return max(candidates, key=lambda arch: self.fitness[self.search_space.encode_architecture(arch)])
def evolve(self, fitness_fn: Callable[[List[int]], float]) -> List[int]:
"""
Run evolutionary algorithm.
Args:
fitness_fn: Function that evaluates architecture and returns fitness (accuracy)
Returns:
best_arch: Best architecture found
"""
self.initialize_population()
# Evaluate initial population
print("Evaluating initial population...")
for arch in self.population:
arch_str = self.search_space.encode_architecture(arch)
if arch_str not in self.fitness:
self.fitness[arch_str] = fitness_fn(arch)
# Evolution
for gen in range(self.n_generations):
new_population = []
# Elitism: keep top 10%
elite_size = self.population_size // 10
elite = sorted(self.population,
key=lambda a: self.fitness[self.search_space.encode_architecture(a)],
reverse=True)[:elite_size]
new_population.extend(elite)
# Generate offspring
while len(new_population) < self.population_size:
# Selection
parent1 = self.tournament_selection()
parent2 = self.tournament_selection()
# Crossover
if random.random() < self.crossover_prob:
child1, child2 = self.search_space.crossover(parent1, parent2)
else:
child1, child2 = parent1.copy(), parent2.copy()
# Mutation
if random.random() < self.mutation_prob:
child1 = self.search_space.mutate(child1)
if random.random() < self.mutation_prob:
child2 = self.search_space.mutate(child2)
new_population.extend([child1, child2])
self.population = new_population[:self.population_size]
# Evaluate new architectures
for arch in self.population:
arch_str = self.search_space.encode_architecture(arch)
if arch_str not in self.fitness:
self.fitness[arch_str] = fitness_fn(arch)
# Report progress
best_fitness = max(self.fitness[self.search_space.encode_architecture(a)]
for a in self.population)
print(f"Generation {gen+1}/{self.n_generations}, Best Fitness: {best_fitness:.4f}")
# Return best architecture
best_arch = max(self.population,
key=lambda a: self.fitness[self.search_space.encode_architecture(a)])
return best_arch
# ============================================================================
# 6. Zero-Shot NAS (NASWOT)
# ============================================================================
def compute_naswot_score(model: nn.Module, dataloader, device: str = 'cuda') -> float:
    """
    Compute a training-free architecture score in the spirit of NASWOT.
    Note: the original NASWOT paper scores the log-determinant of a binary
    ReLU-activation kernel; this variant uses the log-determinant of the
    per-sample Jacobian Gram matrix as a related training-free proxy.
    Higher score -> better architecture (without training).
    Args:
        model: Network to evaluate
        dataloader: Data for computing the Jacobian
        device: Device
    Returns:
        score: Training-free score (log-determinant)
    """
    model = model.to(device)
    model.train()
    # Get one batch
    x, _ = next(iter(dataloader))
    x = x.to(device)
    batch_size = x.size(0)
    # Forward pass
    logits = model(x)
    # Per-sample Jacobians w.r.t. the parameters
    jacobians = []
    for i in range(batch_size):
        model.zero_grad()
        logits[i].sum().backward(retain_graph=True)
        # Collect gradients (flatten)
        grad = torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None])
        jacobians.append(grad)
    # Stack into [batch, n_params] and form the batch-sized Gram matrix.
    # J @ J.T is [batch, batch]; J.T @ J would be [n_params, n_params],
    # far too large to factorize, and shares the same nonzero eigenvalues.
    J = torch.stack(jacobians)
    gram = J @ J.T / J.size(1) + torch.eye(batch_size, device=device) * 1e-5
    # Log determinant (score)
    sign, logdet = torch.slogdet(gram)
    score = logdet.item() if sign > 0 else -float('inf')
    return score
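A note on the linear algebra: when scoring from per-sample Jacobians, the log-determinant is tractable over the batch-sized Gram matrix J Jᵀ (e.g. 4x4) rather than the parameter-sized Jᵀ J (n_params x n_params), and the two share the same nonzero eigenvalues, so nothing about the batch's conditioning is lost. A quick standalone check with a random toy Jacobian (float64 so both eigendecompositions agree to tight tolerance):

```python
import torch

# Toy per-sample Jacobian: 4 samples, 1000 "parameters"
J = torch.randn(4, 1000, dtype=torch.float64)

eig_small = torch.linalg.eigvalsh(J @ J.T)   # 4 x 4 Gram matrix
eig_big = torch.linalg.eigvalsh(J.T @ J)     # 1000 x 1000 covariance

# eigvalsh returns eigenvalues in ascending order; the covariance has
# 996 (near-)zero eigenvalues followed by the same 4 nonzero ones
print(torch.allclose(eig_small, eig_big[-4:], rtol=1e-8, atol=1e-6))
```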
# ============================================================================
# 7. Multi-Objective NAS (NSGA-II)
# ============================================================================
class MultiObjectiveNAS:
"""
Multi-objective NAS using NSGA-II.
Optimizes: accuracy, -latency, -parameters
"""
def __init__(
self,
search_space: SearchSpace,
population_size: int = 50,
n_generations: int = 20
):
self.search_space = search_space
self.population_size = population_size
self.n_generations = n_generations
self.population = []
self.objectives = {} # arch_str -> [acc, -latency, -params]
    def dominates(self, obj1: List[float], obj2: List[float]) -> bool:
        """Check if obj1 dominates obj2 (no worse in every objective, strictly better in at least one)."""
        return all(o1 >= o2 for o1, o2 in zip(obj1, obj2)) and any(o1 > o2 for o1, o2 in zip(obj1, obj2))
    def non_dominated_sort(self) -> List[List[List[int]]]:
        """Non-dominated sorting of the population (NSGA-II fast sort).
        Works with indices internally so that duplicate architectures,
        which list.index would conflate, are bookkept separately."""
        n = len(self.population)
        objs = [self.objectives[self.search_space.encode_architecture(a)]
                for a in self.population]
        domination_count = [0] * n
        dominated_solutions = [[] for _ in range(n)]
        fronts_idx = [[]]
        # Compute pairwise domination
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue
                if self.dominates(objs[i], objs[j]):
                    dominated_solutions[i].append(j)
                elif self.dominates(objs[j], objs[i]):
                    domination_count[i] += 1
            if domination_count[i] == 0:
                fronts_idx[0].append(i)
        # Build subsequent fronts
        k = 0
        while fronts_idx[k]:
            next_front = []
            for i in fronts_idx[k]:
                for j in dominated_solutions[i]:
                    domination_count[j] -= 1
                    if domination_count[j] == 0:
                        next_front.append(j)
            k += 1
            fronts_idx.append(next_front)
        # Map index fronts back to architectures, dropping the empty tail
        return [[self.population[i] for i in front] for front in fronts_idx[:-1]]
def crowding_distance(self, front: List[List[int]]) -> Dict[int, float]:
"""Compute crowding distance for diversity."""
distances = {self.population.index(arch): 0.0 for arch in front}
n_obj = len(self.objectives[self.search_space.encode_architecture(front[0])])
for obj_idx in range(n_obj):
# Sort by objective
sorted_front = sorted(front,
key=lambda a: self.objectives[self.search_space.encode_architecture(a)][obj_idx])
# Boundary points get infinite distance
distances[self.population.index(sorted_front[0])] = float('inf')
distances[self.population.index(sorted_front[-1])] = float('inf')
# Others get normalized distance
obj_range = (self.objectives[self.search_space.encode_architecture(sorted_front[-1])][obj_idx] -
self.objectives[self.search_space.encode_architecture(sorted_front[0])][obj_idx])
if obj_range > 0:
for i in range(1, len(sorted_front) - 1):
arch = sorted_front[i]
arch_idx = self.population.index(arch)
prev_obj = self.objectives[self.search_space.encode_architecture(sorted_front[i-1])][obj_idx]
next_obj = self.objectives[self.search_space.encode_architecture(sorted_front[i+1])][obj_idx]
distances[arch_idx] += (next_obj - prev_obj) / obj_range
return distances
def search(self, fitness_fn: Callable[[List[int]], Tuple[float, float, float]]) -> List[List[int]]:
"""
Run NSGA-II.
Args:
fitness_fn: Returns (accuracy, latency, n_params)
Returns:
pareto_front: List of Pareto optimal architectures
"""
# Initialize
self.population = [self.search_space.random_architecture()
for _ in range(self.population_size)]
# Evaluate
for arch in self.population:
acc, lat, params = fitness_fn(arch)
self.objectives[self.search_space.encode_architecture(arch)] = [acc, -lat, -params]
# Evolution
for gen in range(self.n_generations):
# Selection + Variation
offspring = []
while len(offspring) < self.population_size:
parent1 = random.choice(self.population)
parent2 = random.choice(self.population)
                child1, child2 = self.search_space.crossover(parent1, parent2)
                # Mutate both children (not just child1)
                child1 = self.search_space.mutate(child1)
                child2 = self.search_space.mutate(child2)
                offspring.extend([child1, child2])
offspring = offspring[:self.population_size]
# Evaluate offspring
for arch in offspring:
arch_str = self.search_space.encode_architecture(arch)
if arch_str not in self.objectives:
acc, lat, params = fitness_fn(arch)
self.objectives[arch_str] = [acc, -lat, -params]
# Combine parent + offspring
combined = self.population + offspring
self.population = combined
# Non-dominated sorting
fronts = self.non_dominated_sort()
# Select next generation
new_population = []
for front in fronts:
if len(new_population) + len(front) <= self.population_size:
new_population.extend(front)
else:
# Use crowding distance
distances = self.crowding_distance(front)
sorted_front = sorted(front,
key=lambda a: distances[self.population.index(a)],
reverse=True)
new_population.extend(sorted_front[:self.population_size - len(new_population)])
break
self.population = new_population
print(f"Generation {gen+1}/{self.n_generations}, Pareto Front Size: {len(fronts[0])}")
# Return Pareto front
fronts = self.non_dominated_sort()
return fronts[0]
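As a compact, self-contained illustration of the Pareto dominance logic used above, here is a standalone non-dominated filter on hand-picked objective vectors (the helper names and toy numbers are illustrative only, independent of the MultiObjectiveNAS class):

```python
from typing import List

def dominates_vec(a: List[float], b: List[float]) -> bool:
    """a dominates b: no worse in every objective, strictly better in one."""
    return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))

def pareto_front_indices(objs: List[List[float]]) -> List[int]:
    """Indices of non-dominated vectors (all objectives maximized)."""
    return [i for i, a in enumerate(objs)
            if not any(dominates_vec(b, a) for j, b in enumerate(objs) if j != i)]

# Toy candidates encoded as (accuracy, -latency_ms, -params_M), all maximized:
candidates = [
    [0.92, -15.0, -5.0],  # accurate but slow and large
    [0.88, -8.0, -2.0],   # balanced
    [0.80, -3.0, -0.5],   # fast and tiny
    [0.85, -10.0, -4.0],  # dominated by the balanced model
]
print(pareto_front_indices(candidates))  # [0, 1, 2]
```

The last candidate is beaten on every objective by the balanced model, so only the first three survive; no single survivor beats the others, which is exactly the trade-off frontier NSGA-II maintains.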
# ============================================================================
# 8. Demo and Benchmarking
# ============================================================================
def demo_search_space():
"""Demonstrate search space."""
print("=" * 80)
print("Search Space Demo")
print("=" * 80)
space = SearchSpace(n_nodes=4)
# Sample random architectures
print("\nRandom Architectures:")
for i in range(3):
arch = space.random_architecture()
arch_str = space.encode_architecture(arch)
print(f" {i+1}. {arch_str}")
# Mutation
print("\nMutation:")
arch = space.random_architecture()
print(f" Original: {space.encode_architecture(arch)}")
mutated = space.mutate(arch, n_mutations=2)
print(f" Mutated: {space.encode_architecture(mutated)}")
# Crossover
print("\nCrossover:")
arch1 = space.random_architecture()
arch2 = space.random_architecture()
print(f" Parent 1: {space.encode_architecture(arch1)}")
print(f" Parent 2: {space.encode_architecture(arch2)}")
child1, child2 = space.crossover(arch1, arch2)
print(f" Child 1: {space.encode_architecture(child1)}")
print(f" Child 2: {space.encode_architecture(child2)}")
def demo_network_build():
"""Demonstrate building network from architecture."""
print("\n" + "=" * 80)
print("Network Building Demo")
print("=" * 80)
space = SearchSpace(n_nodes=4)
arch = space.random_architecture()
# Build network
model = Network(
C=16,
n_classes=10,
n_layers=8,
arch=arch,
n_nodes=4
)
print(f"\nArchitecture: {space.encode_architecture(arch)}")
print(f"Parameters: {model.count_parameters():,}")
# Test forward pass
x = torch.randn(2, 3, 32, 32)
logits = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {logits.shape}")
def demo_darts():
"""Demonstrate DARTS."""
print("\n" + "=" * 80)
print("DARTS Demo (Simplified)")
print("=" * 80)
# Build DARTS supernet
model = DARTSNetwork(C=16, n_classes=10, n_layers=8, n_nodes=4)
print(f"\nArchitecture parameters shape: {model.alphas.shape}")
print(f"Network parameters: {sum(p.numel() for p in model.parameters()):,}")
# Forward pass
x = torch.randn(2, 3, 32, 32)
logits = model(x)
print(f"Forward pass successful: {logits.shape}")
# Discretize
arch_ops = model.discretize()
print(f"\nDiscretized architecture (first 6 ops): {arch_ops[:6]}")
def demo_evolutionary():
"""Demonstrate evolutionary NAS."""
print("\n" + "=" * 80)
print("Evolutionary NAS Demo")
print("=" * 80)
space = SearchSpace(n_nodes=4)
# Mock fitness function (random)
def fitness_fn(arch):
return random.random()
evolution = EvolutionaryNAS(
search_space=space,
population_size=20,
n_generations=5,
mutation_prob=0.2
)
print("\nRunning evolution (5 generations, 20 population)...")
best_arch = evolution.evolve(fitness_fn)
print(f"\nBest architecture: {space.encode_architecture(best_arch)}")
print(f"Best fitness: {evolution.fitness[space.encode_architecture(best_arch)]:.4f}")
def demo_zero_shot():
"""Demonstrate zero-shot NAS."""
print("\n" + "=" * 80)
print("Zero-Shot NAS Demo (NASWOT)")
print("=" * 80)
space = SearchSpace(n_nodes=4)
# Create dummy dataloader
x = torch.randn(4, 3, 32, 32)
y = torch.randint(0, 10, (4,))
dataloader = [(x, y)]
# Evaluate multiple architectures
print("\nEvaluating architectures without training:")
scores = []
for i in range(3):
arch = space.random_architecture()
model = Network(C=16, n_classes=10, n_layers=8, arch=arch)
try:
score = compute_naswot_score(model, dataloader, device='cpu')
scores.append((arch, score))
print(f" Arch {i+1}: score = {score:.2f}")
        except Exception as exc:
            print(f" Arch {i+1}: failed to compute ({exc})")
if scores:
best_arch, best_score = max(scores, key=lambda x: x[1])
print(f"\nBest architecture (by zero-shot score):")
print(f" {space.encode_architecture(best_arch)}")
print(f" Score: {best_score:.2f}")
def print_comparison_table():
"""Print comparison of NAS methods."""
print("\n" + "=" * 80)
print("NAS Methods Comparison")
print("=" * 80)
    table = """
┌────────────────────┬─────────────────┬────────────┬─────────────┬─────────────┐
│ Method             │ Search Cost     │ Quality    │ Flexibility │ Ease of Use │
├────────────────────┼─────────────────┼────────────┼─────────────┼─────────────┤
│ RL-based (NAS)     │ Very High       │ Excellent  │ High        │ Complex     │
│                    │ (1000s GPU-hr)  │            │             │             │
├────────────────────┼─────────────────┼────────────┼─────────────┼─────────────┤
│ Evolutionary       │ High            │ Very Good  │ Very High   │ Easy        │
│                    │ (100s GPU-hr)   │            │             │             │
├────────────────────┼─────────────────┼────────────┼─────────────┼─────────────┤
│ DARTS              │ Low             │ Good       │ Medium      │ Medium      │
│                    │ (1-2 GPU-days)  │            │             │             │
├────────────────────┼─────────────────┼────────────┼─────────────┼─────────────┤
│ ENAS (weight       │ Very Low        │ Good       │ Medium      │ Easy        │
│ sharing)           │ (0.5 GPU-day)   │            │             │             │
├────────────────────┼─────────────────┼────────────┼─────────────┼─────────────┤
│ Zero-Shot          │ Minimal         │ Fair       │ High        │ Very Easy   │
│ (NASWOT)           │ (minutes)       │            │             │             │
├────────────────────┼─────────────────┼────────────┼─────────────┼─────────────┤
│ Once-for-All       │ High (once)     │ Excellent  │ Very High   │ Medium      │
│                    │ Low (reuse)     │            │             │             │
├────────────────────┼─────────────────┼────────────┼─────────────┼─────────────┤
│ Multi-Objective    │ High            │ Very Good  │ Very High   │ Medium      │
│ (NSGA-II)          │ (100s GPU-hr)   │            │             │             │
└────────────────────┴─────────────────┴────────────┴─────────────┴─────────────┘

Key Trade-offs:
- RL/Evolution: Best quality but highest cost
- DARTS: Good balance of speed and quality, but can collapse
- Weight Sharing (ENAS): Fast but biased weight estimates
- Zero-Shot: Extremely fast but lower quality
- Once-for-All: Train once, use many times (amortized efficiency)
- Multi-Objective: Finds Pareto frontier (multiple solutions)

Recommendations:
- Starting out: DARTS or ENAS (good balance)
- Limited budget: Zero-shot for screening, then refine top candidates
- Production deployment: Once-for-All (flexible deployment)
- Hardware constraints: Multi-objective NAS
- Research/novel tasks: Evolutionary (most flexible)
"""
    print(table)
def main():
"""Run all demos."""
print("\n" + "=" * 80)
print("Advanced Neural Architecture Search - Comprehensive Demo")
print("=" * 80)
# Run demos
demo_search_space()
demo_network_build()
demo_darts()
demo_evolutionary()
demo_zero_shot()
# Comparison table
print_comparison_table()
print("\n" + "=" * 80)
print("Demo Complete!")
print("=" * 80)
print("\nKey Takeaways:")
print("1. Search space design critically impacts NAS efficiency")
print("2. DARTS enables fast gradient-based architecture search")
print("3. Weight sharing dramatically reduces search cost")
print("4. Zero-shot methods can screen thousands of architectures instantly")
print("5. Multi-objective optimization finds trade-off solutions")
print("6. Modern NAS methods (OFA, zero-shot) make NAS practical")
if __name__ == "__main__":
main()