import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Advanced Neural Architecture Search TheoryΒΆ

1. Search Space Design and FundamentalsΒΆ

Architecture Search Problem: Given a search space \(\mathcal{A}\) of candidate architectures, dataset \(\mathcal{D}\), and performance metric \(p\), find:

\[a^* = \arg\max_{a \in \mathcal{A}} p(a, \mathcal{D})\]

Search Space Taxonomy:

Space Type

Description

Size

Complexity

Chain-structured

Sequential layers

\(O(n^d)\) where d=depth

Linear

Cell-based

Modular cells

\(|O|^{N_e}\) where \(N_e\)=edges

Exponential

Hierarchical

Multi-level

\(O((|O| \cdot N)^L)\)

Super-exponential

Network morphism

Topology mutation

Infinite

Continuous

Cell-based Search (NASNet, DARTS):

  • Normal cell: Keeps spatial dimensions

  • Reduction cell: Downsamples features

  • Stack pattern: \(N_{normal} + 1_{reduction}\)

  • Total architecture: \((N_n + N_r) \times C\) where C=cells/block

Search Space Size: For DARTS with:

  • \(N=7\) nodes per cell

  • \(\|O\|=8\) operations

  • Each node connects to 2 previous nodes

\[\text{Size} = \prod_{i=1}^{N} \binom{i+1}{2} \cdot \|O\|^2 \approx 10^{18} \text{ architectures}\]

2. DARTS: Complete Mathematical FrameworkΒΆ

2.1 Continuous Relaxation

Discrete operation selection becomes continuous mixing:

\[\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in \mathcal{O}} \exp(\alpha_{o'}^{(i,j)})} o(x)\]

Where:

  • \(\alpha^{(i,j)} \in \mathbb{R}^{|\mathcal{O}|}\): Architecture parameters for edge (i,j)

  • \(\mathcal{O}\): Operation set {conv3x3, conv5x5, maxpool3x3, skip, none}

  • Softmax ensures \(\sum_o p_o = 1\) (valid probability distribution)

2.2 Bi-level Optimization Problem

\[\begin{split}\begin{aligned} \min_{\alpha} \quad & \mathcal{L}_{val}(w^*(\alpha), \alpha) \\ \text{s.t.} \quad & w^*(\alpha) = \arg\min_w \mathcal{L}_{train}(w, \alpha) \end{aligned}\end{split}\]

Challenges:

  1. Inner optimization: Training weights \(w\) to convergence is expensive

  2. Implicit gradient: \(\nabla_\alpha \mathcal{L}_{val}\) requires \(\partial w^*/\partial \alpha\)

  3. Memory: Storing full computation graph for second-order derivatives

2.3 Gradient Approximation

First-order approximation (using chain rule):

\[\nabla_\alpha \mathcal{L}_{val}(w^*, \alpha) = \nabla_\alpha \mathcal{L}_{val}(w', \alpha) - \xi \nabla_\alpha \nabla_w \mathcal{L}_{train}(w', \alpha) \nabla_{w'} \mathcal{L}_{val}(w', \alpha)\]

Where:

  • \(w' = w - \xi \nabla_w \mathcal{L}_{train}(w, \alpha)\) (one-step lookahead)

  • \(\xi\): Learning rate

  • Second term: Hessian-vector product \(\nabla_\alpha \nabla_w \mathcal{L}_{train} \cdot \nabla_{w'} \mathcal{L}_{val}\)

Computational complexity:

  • Forward pass: \(O(N \cdot |\mathcal{O}|)\) where N=edges

  • Backward pass: \(O(N \cdot |\mathcal{O}|) + O(|w|)\) for Hessian-vector product

  • Memory: \(O(|\mathcal{O}| \cdot \text{activations})\) during search

Second-order Hessian computation (finite difference):

\[\nabla_\alpha \nabla_w \mathcal{L}_{train} \cdot v \approx \frac{\nabla_\alpha \mathcal{L}_{train}(w + \epsilon v, \alpha) - \nabla_\alpha \mathcal{L}_{train}(w - \epsilon v, \alpha)}{2\epsilon}\]

For vector \(v = \nabla_{w'} \mathcal{L}_{val}(w', \alpha)\) and \(\epsilon = 0.01 / \|v\|_2\)

3. Architecture Derivation and DiscretizationΒΆ

3.1 Pruning Strategy

After continuous search, extract discrete architecture:

  1. For each intermediate node \(i\):

    • Compute operation strengths: \(\bar{\alpha}_o^{(i,j)} = \sum_j \alpha_o^{(i,j)}\) (aggregate over edges)

    • Select top-k operations: \(O_i = \text{topk}(\{\bar{\alpha}_o^{(i,j)}\}_o, k=2)\)

  2. Retain top-k edges per node based on \(\max_o \alpha_o^{(i,j)}\)

Genotype encoding:

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
# normal: [(op1, from_node1), (op2, from_node2), ...]
# normal_concat: [4, 5, 6] # nodes to concatenate for output

3.2 One-shot vs. Progressive Search

Approach

Strategy

Pros

Cons

One-shot (DARTS)

Single supernet, simultaneous search

Fast, simple

High memory, coupling

Progressive (ENAS)

Sequential decisions, controller-based

Lower memory

Slower, credit assignment

Weight sharing

Train once, sample many

Very fast

Inaccurate ranking

4. Advanced NAS VariantsΒΆ

4.1 PC-DARTS (Partial Channel Connections)

Problem: DARTS memory grows as \(O(|\mathcal{O}| \cdot C \cdot H \cdot W)\)

Solution: Sample subset of channels during search

\[\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\beta \cdot \alpha_o^{(i,j)})}{\sum_{o'} \exp(\beta \cdot \alpha_{o'}^{(i,j)})} o(S(x))\]

Where:

  • \(S(x)\): Channel sampling function (sample \(C/K\) channels, K=4 typical)

  • \(\beta\): Edge normalization factor \(\beta^{(i,j)} = \sum_{o} \alpha_o^{(i,j)}\)

Memory reduction: \(K\times\) less than DARTS (4Γ— typical)

Stability: Regularizes over-parameterization, prevents dominance of skip connections

4.2 FairNAS (Fairness-aware Search)

Motivation: Dominant operations (skip) receive most gradient updates

Fairness loss:

\[\mathcal{L}_{fair} = \sum_{(i,j)} \text{Var}(\{\alpha_o^{(i,j)}\}_o) = \sum_{(i,j)} \frac{1}{|\mathcal{O}|} \sum_o (\alpha_o^{(i,j)} - \bar{\alpha}^{(i,j)})^2\]

Combined objective:

\[\mathcal{L}_{total} = \mathcal{L}_{val} + \lambda_{fair} \mathcal{L}_{fair}\]

Encourages exploration of diverse operations during search

4.3 GDAS (Gumbel-Softmax Differentiable Architecture Search)

Gumbel-Max trick for discrete sampling:

\[o^{(i,j)} = \text{one\_hot}\left(\arg\max_o (\alpha_o^{(i,j)} + g_o)\right)\]

Where \(g_o \sim \text{Gumbel}(0,1) = -\log(-\log(u))\), \(u \sim \text{Uniform}(0,1)\)

Gumbel-Softmax (differentiable relaxation):

\[p_o = \frac{\exp((\alpha_o + g_o) / \tau)}{\sum_{o'} \exp((\alpha_{o'} + g_{o'}) / \tau)}\]

Where \(\tau\): Temperature (annealed from 1 β†’ 0 during search)

Advantages:

  • Single-path during forward (memory \(O(1)\) vs DARTS \(O(|\mathcal{O}|)\))

  • Stochastic exploration reduces bias toward skip connections

  • Better approximation to discrete search

5. Zero-cost Proxies and PredictorsΒΆ

5.1 Training-Free Metrics

Evaluate architecture quality without training:

SNIP (Single-shot Network Pruning):

\[s(a) = \sum_{\theta \in a} \left|\theta \cdot \nabla_\theta \mathcal{L}(\theta)\right|\]

Saliency score measures parameter importance

GraSP (Gradient Signal Preservation):

\[s(a) = -\text{Tr}\left(\nabla_\theta^2 \mathcal{L}(\theta) \cdot \text{diag}(\theta^2)\right)\]

Hessian trace approximates loss curvature

NASWOT (NAS Without Training):

\[s(a) = \log \det\left(\frac{1}{N} \sum_{x \in \mathcal{D}} \nabla_\theta f_a(x) \nabla_\theta f_a(x)^T\right)\]

Kernel measure of network expressivity

5.2 Performance Prediction

Early stopping correlation: Train for T epochs, predict final accuracy at epoch T_max:

\[\rho = \text{Spearman}(\{acc_T(a_i)\}, \{acc_{T_{max}}(a_i)\})\]

Typical: \(\rho > 0.6\) for T=12 epochs predicting T_max=200

Learning curve extrapolation: Fit power law: \(\text{Error}(t) = a \cdot t^{-b} + c\)

Predict final performance from early training trajectory

6. Efficient Search StrategiesΒΆ

6.1 Supernet Training (Single Path One-Shot)

Weight sharing: Train single β€œsupernet” containing all candidate operations

\[w_{shared} = \arg\min_w \mathbb{E}_{a \sim p(\alpha)} [\mathcal{L}_{train}(w|a)]\]

Sampling strategy:

  • Uniform: Sample each architecture with equal probability

  • Prioritized: Sample based on predicted performance

Advantage: Train once (\(O(1)\)), evaluate many (\(O(N)\))

Limitation: Weight sharing introduces ranking disorder (correlation \(\rho \approx 0.3\)-0.6 with standalone training)

6.2 Evolutionary Algorithms

Population-based search:

1. Initialize population P = {a_1, ..., a_n}
2. For generation g = 1 to G:
   a. Evaluate fitness f(a_i) for all a_i ∈ P
   b. Select parents (tournament/roulette)
   c. Crossover: Mix architectures
   d. Mutate: Random operation changes
   e. Update population P
3. Return best architecture a* = argmax f(a_i)

AmoebaNet mutations:

  • Add/remove layers

  • Change operation types

  • Modify filter sizes

  • Alter connections

Complexity: \(O(P \cdot G \cdot T_{train})\) where P=population, G=generations, T=training time

6.3 Reinforcement Learning (NASNet)

Controller RNN generates architectures:

\[p(a; \theta_c) = \prod_{t=1}^T p(a_t | a_{1:t-1}; \theta_c)\]

REINFORCE gradient:

\[\nabla_{\theta_c} J = \mathbb{E}_{a \sim p(\cdot; \theta_c)} [R(a) \nabla_{\theta_c} \log p(a; \theta_c)]\]

Variance reduction:

  • Baseline: \(b = \text{EMA}(R)\) (exponential moving average)

  • Advantage: \(A(a) = R(a) - b\)

PPO update (more stable than REINFORCE):

\[L(\theta_c) = \mathbb{E}_a \left[\min\left(\frac{p(a; \theta_c)}{p(a; \theta_c^{old})} A(a), \text{clip}(\cdot, 1-\epsilon, 1+\epsilon) A(a)\right)\right]\]

7. Hardware-Aware and Multi-Objective NASΒΆ

7.1 Latency Prediction

Lookup table approach: Build table: \(T(o, C, H, W, device)\) = latency of operation o

Differentiable latency loss:

\[\mathcal{L}_{lat} = \sum_{(i,j)} \sum_o p_o^{(i,j)} \cdot T(o, C^{(i,j)}, H^{(i,j)}, W^{(i,j)})\]

Where \(p_o = \text{softmax}(\alpha_o)\) are operation probabilities

7.2 Multi-Objective Optimization

Pareto frontier: Find architectures a such that no a’ exists with:

  • \(\text{Acc}(a') > \text{Acc}(a)\) AND

  • \(\text{Latency}(a') < \text{Latency}(a)\) (or FLOPs, params, energy)

Scalarization:

\[\mathcal{L} = \mathcal{L}_{val} + \lambda_1 \log(\text{FLOPs}) + \lambda_2 \log(\text{Latency})\]

NSGA-II (Non-dominated Sorting Genetic Algorithm):

  1. Rank architectures by Pareto dominance

  2. Use crowding distance for diversity

  3. Select based on rank and crowding

8. Theoretical Guarantees and GeneralizationΒΆ

8.1 Search-Evaluation Gap

Problem: Architecture performs worse after re-training from scratch

Causes:

  1. Weight co-adaptation: Supernet weights biased by weight sharing

  2. Different optimization landscape: Search (few epochs) vs. evaluation (full training)

  3. Overfitting to search data: Architecture specializes to validation split

Solution: Independent train/search/test splits

8.2 Sample Complexity

PAC bound for architecture search:

With probability \(1 - \delta\), the true error satisfies:

\[\mathcal{L}_{true}(a^*) \leq \mathcal{L}_{val}(a^*) + O\left(\sqrt{\frac{\log |\mathcal{A}| + \log(1/\delta)}{n_{val}}}\right)\]

Where:

  • \(|\mathcal{A}|\): Search space size

  • \(n_{val}\): Validation set size

Implication: Larger search spaces require more validation data to avoid overfitting

8.3 Expressivity vs. Trainability Trade-off

Expressivity: Ability to represent complex functions Trainability: Ease of optimization (gradient flow, convergence)

Observation: Skip connections improve trainability but reduce search difficulty (trivial solution)

Regularization strategies:

  • Dropout on skip connections

  • Path dropout (stochastic depth)

  • Architecture weight decay: \(\mathcal{L}_{reg} = \|\alpha\|_2^2\)

9. Practical ConsiderationsΒΆ

9.1 Hyperparameter Sensitivity

Hyperparameter

Typical Range

Effect

Learning rate (Ξ±)

3e-4 to 1e-3

Stability of architecture parameters

Learning rate (w)

2.5e-2 to 1e-1

Speed of weight convergence

Weight decay

3e-4 to 1e-3

Prevent overfitting

Batch size

64 to 256

Memory vs. gradient quality

Search epochs

50 to 100

Search thoroughness

9.2 Computational Budget

DARTS: ~4 GPU-days (1 V100) ENAS: ~0.5 GPU-days AmoebaNet: ~3150 GPU-days (450 K40 GPUs) NASNet: ~1800 GPU-days

Cost reduction strategies:

  1. Early stopping based on learning curve

  2. Weight sharing (one-shot)

  3. Zero-cost proxies (no training)

  4. Warm-starting from smaller search

9.3 Transferability

Search on proxy task, transfer to target:

Proxy

Target

Transfer Success

CIFAR-10

ImageNet

βœ“ High (NASNet, AmoebaNet)

Small dataset

Large dataset

βœ“ Moderate

Different domain

Target domain

βœ— Low

Requirements for transfer:

  • Similar data distribution

  • Comparable task complexity

  • Shared inductive biases

10. State-of-the-Art Methods (2020-2024)ΒΆ

10.1 Once-for-All (OFA) Networks

Progressive shrinking:

  1. Train full network (largest architecture)

  2. Progressively train sub-networks (elastic depth, width, kernel)

  3. Final supernet supports \(10^{19}\) architectures

Deployment: Extract architecture matching device constraints (latency, memory)

10.2 AutoFormer (Transformer NAS)

Search space: Vision Transformer components

  • Number of heads

  • MLP expansion ratio

  • Embedding dimensions

  • Depth

Supernet training: Weight entanglement for parameter sharing across architectures

10.3 Neural Architecture Transfer (NAT)

Idea: Transfer learned architecture generators across tasks

Meta-controller: Learns to generate architectures conditioned on task

\[p(a | \mathcal{D}_{task}; \theta_{meta}) = \text{Controller}(\text{Embed}(\mathcal{D}_{task}); \theta_{meta})\]

11. Open Research QuestionsΒΆ

  1. Why does DARTS prefer skip connections?

    • Hypothesis: Gradient flow, easier optimization, overfitting to search data

    • Mitigation: Regularization, early stopping, fairness constraints

  2. How to improve weight sharing ranking correlation?

    • Current: \(\rho \approx 0.3\)-0.6

    • Goal: \(\rho > 0.8\) for reliable ranking

    • Approaches: Better training, calibration, progressive training

  3. Can we eliminate the search-evaluation gap?

    • Problem: 1-5% accuracy drop after re-training

    • Potential: Direct optimization, better initialization, transfer learning

  4. How to handle diverse hardware efficiently?

    • Challenge: Device-specific latency models

    • Direction: Universal latency predictors, neural latency models

  5. What are fundamental limits of architecture search?

    • Theory: Sample complexity, generalization bounds

    • Practice: Minimum search budget, maximum transferability

Key Papers and TimelineΒΆ

Foundation (2016-2018):

  • Zoph & Le 2017: NASNet - RL-based search, cell-based space

  • Pham et al. 2018: ENAS - Efficient weight sharing

  • Real et al. 2019: AmoebaNet - Evolutionary search

Differentiable Search (2018-2019):

  • Liu et al. 2019: DARTS - Gradient-based, bi-level optimization

  • Xu et al. 2020: PC-DARTS - Partial channel, stability

  • Dong & Yang 2019: GDAS - Gumbel-softmax sampling

Efficiency (2019-2021):

  • Cai et al. 2020: Once-for-All - Progressive shrinking, diverse hardware

  • Chen et al. 2021: AutoFormer - Transformer search

  • White et al. 2021: NASWOT - Zero-cost proxy

Multi-Objective (2020-2022):

  • Tan et al. 2019: EfficientNet - Compound scaling

  • Chu et al. 2021: FairNAS - Fairness-aware search

  • Wu et al. 2019: FBNet - Hardware-aware differentiable

Implementation Complexity:

  • Basic DARTS: ~300 lines (PyTorch)

  • Full NAS framework: ~2000 lines (search space, controller, evaluator)

  • Production system: ~10,000+ lines (distributed training, experiment management)

SummaryΒΆ

DARTS Key Ideas:ΒΆ

  1. Continuous relaxation of discrete search space

  2. Mixed operations with softmax weights

  3. Bi-level optimization: architecture Ξ± and weights w

  4. Gradient-based search (efficient)

Search Space:ΒΆ

  • Operations: conv 3Γ—3, 5Γ—5, skip, pool, none

  • Cell-based structure

  • DAG with mixed edges

Advantages:ΒΆ

  • 1000Γ— faster than RL/evolution

  • End-to-end differentiable

  • Transfer to larger datasets

Variants:ΒΆ

  • PC-DARTS: Partial channel sampling

  • GDAS: Gumbel-softmax sampling

  • SNAS: Stochastic relaxation

  • ProxylessNAS: Memory-efficient

# ============================================================================
# Advanced NAS Implementations
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Categorical
from copy import deepcopy
import time

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ============================================================================
# 1. PC-DARTS: Partial Channel Connections
# ============================================================================

class PCDARTSMixedOp(nn.Module):
    """
    Mixed operation with partial channel sampling for memory efficiency.
    
    Memory reduction: O(C/K) instead of O(C) where K is channel divisor.
    """
    def __init__(self, C, K=4):
        super().__init__()
        self.K = K  # Channel sampling factor
        self.C = C
        self.C_sampled = C // K
        
        # Build operations for sampled channels
        self.ops = nn.ModuleList()
        for name, op_fn in OPS.items():
            if name in ['none', 'skip']:
                self.ops.append(op_fn(C))  # Full channels for param-free ops
            else:
                self.ops.append(op_fn(self.C_sampled))  # Sampled channels
    
    def forward(self, x, weights, edge_norm=None):
        """
        Forward with partial channel sampling.
        
        Args:
            x: Input tensor [B, C, H, W]
            weights: Architecture weights [|O|]
            edge_norm: Edge normalization factor (PC-DARTS)
        """
        # Sample random channels
        channel_indices = torch.randperm(self.C)[:self.C_sampled]
        x_sampled = x[:, channel_indices, :, :]
        
        # Apply operations
        outputs = []
        for i, (op, w) in enumerate(zip(self.ops, weights)):
            if i < 2:  # none, skip (param-free, use full channels)
                out = op(x)
            else:  # parameterized ops (use sampled channels)
                out_sampled = op(x_sampled)
                # Expand back to full channels
                out = torch.zeros_like(x)
                out[:, channel_indices, :, :] = out_sampled
            
            # Apply edge normalization if provided
            if edge_norm is not None:
                w = w * edge_norm
            
            outputs.append(w * out)
        
        return sum(outputs)


class PCDARTSCell(nn.Module):
    """Cell with PC-DARTS mixed operations."""
    
    def __init__(self, C, n_nodes=4, K=4):
        super().__init__()
        self.n_nodes = n_nodes
        self.K = K
        
        # Mixed operations for each edge
        self.ops = nn.ModuleList()
        for i in range(n_nodes):
            for j in range(i + 2):  # Connect to previous nodes
                self.ops.append(PCDARTSMixedOp(C, K))
    
    def forward(self, x, alphas):
        states = [x, x]
        offset = 0
        
        for i in range(self.n_nodes):
            # Compute edge normalization (sum of alphas)
            edge_norms = []
            for j in range(len(states)):
                alpha = alphas[offset + j]
                edge_norm = alpha.sum()  # Sum over operations
                edge_norms.append(edge_norm)
            
            # Aggregate from all previous nodes
            s = sum(
                self.ops[offset + j](h, F.softmax(alphas[offset + j], dim=0), edge_norms[j])
                for j, h in enumerate(states)
            )
            offset += len(states)
            states.append(s)
        
        return torch.cat(states[2:], dim=1)


# ============================================================================
# 2. GDAS: Gumbel-Softmax Differentiable Search
# ============================================================================

class GumbelSoftmax(nn.Module):
    """
    Gumbel-Softmax for differentiable discrete sampling.
    
    Ο„ β†’ 0: Approaches one-hot (discrete)
    Ο„ β†’ ∞: Approaches uniform distribution
    """
    def __init__(self, tau_min=0.1, tau_max=10.0):
        super().__init__()
        self.tau = tau_max  # Temperature (annealed during training)
        self.tau_min = tau_min
        self.tau_max = tau_max
    
    def anneal_temperature(self, progress):
        """
        Anneal temperature: Ο„(t) = Ο„_max Β· (Ο„_min/Ο„_max)^progress
        
        Args:
            progress: Training progress in [0, 1]
        """
        self.tau = self.tau_max * (self.tau_min / self.tau_max) ** progress
    
    def sample_gumbel(self, shape):
        """Sample from Gumbel(0, 1): -log(-log(U)) where U ~ Uniform(0,1)"""
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + 1e-20) + 1e-20)
    
    def forward(self, logits, hard=False):
        """
        Gumbel-Softmax sampling.
        
        Args:
            logits: Architecture parameters [|O|]
            hard: If True, return one-hot (discrete) in forward, soft in backward
        
        Returns:
            Sampled probabilities (soft or hard)
        """
        gumbel_noise = self.sample_gumbel(logits.shape)
        y = logits + gumbel_noise
        y_soft = F.softmax(y / self.tau, dim=0)
        
        if hard:
            # Straight-through estimator
            y_hard = torch.zeros_like(y_soft)
            y_hard[y_soft.argmax()] = 1.0
            
            # y_hard in forward, y_soft gradient in backward
            y = (y_hard - y_soft).detach() + y_soft
        else:
            y = y_soft
        
        return y


class GDASMixedOp(nn.Module):
    """Mixed operation with Gumbel-Softmax sampling (single-path)."""
    
    def __init__(self, C):
        super().__init__()
        self.ops = nn.ModuleList()
        for name, op_fn in OPS.items():
            self.ops.append(op_fn(C))
        
        self.gumbel = GumbelSoftmax()
    
    def forward(self, x, alpha, hard=False):
        """
        Single-path forward (memory-efficient).
        
        Only one operation is executed based on Gumbel sampling.
        """
        # Sample operation weights
        weights = self.gumbel(alpha, hard=hard)
        
        if hard:
            # Single-path: Execute only sampled operation
            selected_op = weights.argmax().item()
            return self.ops[selected_op](x)
        else:
            # Multi-path: Weighted sum (used during evaluation)
            return sum(w * op(x) for w, op in zip(weights, self.ops))


# ============================================================================
# 3. FairNAS: Fairness-Aware Search
# ============================================================================

def compute_fairness_loss(alphas):
    """
    Compute variance of architecture parameters to encourage exploration.
    
    L_fair = Ξ£_edges Var(Ξ±^edge) = Ξ£_edges (1/|O|) Ξ£_o (Ξ±_o - αΎ±)Β²
    """
    fairness_loss = 0.0
    
    for alpha in alphas:
        mean_alpha = alpha.mean()
        variance = ((alpha - mean_alpha) ** 2).mean()
        fairness_loss += variance
    
    return fairness_loss


class FairNASNetwork(nn.Module):
    """
    Search network with fairness regularization.
    
    Prevents skip connections from dominating the search.
    """
    def __init__(self, C=16, n_cells=2, n_nodes=3, n_classes=10, lambda_fair=0.1):
        super().__init__()
        self.C = C
        self.n_cells = n_cells
        self.n_nodes = n_nodes
        self.lambda_fair = lambda_fair
        
        # Network structure (same as DARTS)
        self.stem = nn.Sequential(
            nn.Conv2d(1, C, 3, 1, 1, bias=False),
            nn.BatchNorm2d(C)
        )
        
        self.cells = nn.ModuleList()
        for i in range(n_cells):
            self.cells.append(Cell(C, n_nodes))
        
        self.classifier = nn.Linear(C * n_nodes, n_classes)
        
        # Architecture parameters
        self._init_alphas()
    
    def _init_alphas(self):
        n_ops = len(OPS)
        n_edges = sum(2 + i for i in range(self.n_nodes))
        
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.randn(n_ops)) for _ in range(n_edges)
        ])
    
    def forward(self, x):
        x = self.stem(x)
        
        for cell in self.cells:
            x = cell(x, self.alphas)
        
        x = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)
        return self.classifier(x)
    
    def loss(self, x, y):
        """
        Combined loss: task loss + fairness regularization.
        
        L = L_task + Ξ»_fair * L_fair
        """
        logits = self(x)
        task_loss = F.cross_entropy(logits, y)
        
        fairness_loss = compute_fairness_loss(self.alphas)
        
        total_loss = task_loss + self.lambda_fair * fairness_loss
        
        return total_loss, task_loss.item(), fairness_loss.item()


# ============================================================================
# 4. Zero-Cost Proxy: SNIP (Single-Shot Network Pruning)
# ============================================================================

def compute_snip_score(model, dataloader, num_batches=1):
    """
    Compute SNIP score: Ξ£_ΞΈ |ΞΈ Β· βˆ‡_ΞΈ L|
    
    Measures parameter importance without training.
    """
    model.train()
    
    # Accumulate gradients over few batches
    for i, (x, y) in enumerate(dataloader):
        if i >= num_batches:
            break
        
        x, y = x.to(device), y.to(device)
        
        model.zero_grad()
        output = model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()
    
    # Compute saliency: |ΞΈ Β· βˆ‡ΞΈ|
    total_score = 0.0
    for param in model.parameters():
        if param.grad is not None:
            score = torch.abs(param * param.grad).sum().item()
            total_score += score
    
    return total_score


def rank_architectures_snip(search_space, dataloader, top_k=5):
    """
    Rank architectures using SNIP without training.
    
    Args:
        search_space: List of architecture configurations
        dataloader: Data for computing gradients
        top_k: Number of top architectures to return
    
    Returns:
        List of (architecture, score) sorted by score
    """
    scores = []
    
    for i, arch_config in enumerate(search_space):
        # Build model from architecture config
        model = build_model_from_config(arch_config).to(device)
        
        # Compute SNIP score
        score = compute_snip_score(model, dataloader, num_batches=3)
        scores.append((arch_config, score))
        
        print(f"Architecture {i+1}/{len(search_space)}: SNIP={score:.2f}")
    
    # Sort by score (higher is better)
    scores.sort(key=lambda x: x[1], reverse=True)
    
    return scores[:top_k]


def build_model_from_config(config):
    """
    Build model from architecture configuration.
    
    Config format: [(layer_type, params), ...]
    """
    # Simplified: Return a basic model
    # In practice, parse config to build custom architecture
    return SearchNetwork(C=16, n_cells=2, n_nodes=3)


# ============================================================================
# 5. Evolutionary Search
# ============================================================================

class ArchitectureGenome:
    """
    Genome representing an architecture for evolutionary search.
    
    Encoding: List of operation IDs for each edge.
    """
    def __init__(self, n_edges, n_ops):
        self.n_edges = n_edges
        self.n_ops = n_ops
        self.genes = np.random.randint(0, n_ops, size=n_edges)
        self.fitness = None
    
    def mutate(self, mutation_rate=0.1):
        """Randomly change operations with given probability."""
        mutated = ArchitectureGenome(self.n_edges, self.n_ops)
        mutated.genes = self.genes.copy()
        
        for i in range(self.n_edges):
            if np.random.rand() < mutation_rate:
                mutated.genes[i] = np.random.randint(0, self.n_ops)
        
        return mutated
    
    def crossover(self, other):
        """Single-point crossover with another genome."""
        child = ArchitectureGenome(self.n_edges, self.n_ops)
        
        crossover_point = np.random.randint(0, self.n_edges)
        child.genes[:crossover_point] = self.genes[:crossover_point]
        child.genes[crossover_point:] = other.genes[crossover_point:]
        
        return child


class EvolutionarySearch:
    """
    Evolutionary algorithm for architecture search.
    
    Population-based search with mutation and crossover.
    """
    def __init__(self, n_edges, n_ops, population_size=20, n_generations=10):
        self.n_edges = n_edges
        self.n_ops = n_ops
        self.population_size = population_size
        self.n_generations = n_generations
        
        # Initialize random population
        self.population = [
            ArchitectureGenome(n_edges, n_ops)
            for _ in range(population_size)
        ]
    
    def evaluate_fitness(self, genome, train_loader, val_loader):
        """
        Train model and evaluate accuracy (fitness).
        
        In practice, use early stopping (e.g., 10 epochs).
        """
        # Convert genome to architecture
        model = self.genome_to_model(genome)
        
        # Quick training (simplified)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        
        model.train()
        for epoch in range(3):  # Fast evaluation
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                loss = F.cross_entropy(model(x), y)
                loss.backward()
                optimizer.step()
        
        # Evaluate accuracy
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        
        accuracy = 100.0 * correct / total
        return accuracy
    
    def genome_to_model(self, genome):
        """Convert genome to PyTorch model (simplified)."""
        # In practice: Build custom model based on genes
        return SearchNetwork(C=16, n_cells=2, n_nodes=3).to(device)
    
    def tournament_selection(self, tournament_size=3):
        """Select parent via tournament (best among random subset)."""
        tournament = np.random.choice(self.population, tournament_size, replace=False)
        return max(tournament, key=lambda g: g.fitness)
    
    def evolve(self, train_loader, val_loader):
        """
        Run evolutionary search.
        
        Returns best genome found.
        """
        history = []
        
        for generation in range(self.n_generations):
            print(f"\nGeneration {generation + 1}/{self.n_generations}")
            
            # Evaluate population
            for i, genome in enumerate(self.population):
                if genome.fitness is None:
                    fitness = self.evaluate_fitness(genome, train_loader, val_loader)
                    genome.fitness = fitness
                    print(f"  Individual {i+1}: Fitness={fitness:.2f}%")
            
            # Track best
            best_genome = max(self.population, key=lambda g: g.fitness)
            avg_fitness = np.mean([g.fitness for g in self.population])
            history.append({'best': best_genome.fitness, 'avg': avg_fitness})
            
            print(f"  Best: {best_genome.fitness:.2f}%, Avg: {avg_fitness:.2f}%")
            
            # Create next generation
            new_population = []
            
            # Elitism: Keep best individual
            new_population.append(best_genome)
            
            # Generate offspring
            while len(new_population) < self.population_size:
                # Selection
                parent1 = self.tournament_selection()
                parent2 = self.tournament_selection()
                
                # Crossover
                child = parent1.crossover(parent2)
                
                # Mutation
                child = child.mutate(mutation_rate=0.2)
                
                new_population.append(child)
            
            self.population = new_population
        
        # Return best
        best = max(self.population, key=lambda g: g.fitness)
        return best, history


# ============================================================================
# 6. Demonstration: Compare Search Methods
# ============================================================================

print("\n" + "="*80)
print("Comparing Neural Architecture Search Methods")
print("="*80)

# Use existing data loaders from DARTS section above
# train_loader, val_loader, test_loader defined earlier

# ============================================================================
# 6.1 PC-DARTS Demo
# ============================================================================

print("\n--- PC-DARTS (Partial Channel) ---")

class PCDARTSNet(nn.Module):
    """Simplified PC-DARTS network for demo."""
    def __init__(self, C=16, K=4):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(1, C, 3, 1, 1, bias=False),
            nn.BatchNorm2d(C)
        )
        self.cell = PCDARTSCell(C, n_nodes=2, K=K)
        self.classifier = nn.Linear(C * 2, 10)
        
        # Architecture params
        n_ops = len(OPS)
        n_edges = 2 + 1  # 2 nodes, 1+0 edges for node 0, 2+0 for node 1
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.randn(n_ops)) for _ in range(n_edges)
        ])
    
    def forward(self, x):
        x = self.stem(x)
        x = self.cell(x, self.alphas)
        x = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)
        return self.classifier(x)

pc_model = PCDARTSNet(C=16, K=4).to(device)

# Quick training (3 epochs)
optimizer = torch.optim.Adam(pc_model.parameters(), lr=1e-3)

for epoch in range(3):
    pc_model.train()
    total_loss = 0
    
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(pc_model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}")

# Extract architecture
print("\nExtracted Architecture (PC-DARTS):")
for i, alpha in enumerate(pc_model.alphas):
    weights = F.softmax(alpha, dim=0)
    best_op = weights.argmax().item()
    op_name = list(OPS.keys())[best_op]
    print(f"  Edge {i}: {op_name} (weight={weights[best_op].item():.3f})")


# ============================================================================
# 6.2 GDAS Demo
# ============================================================================

print("\n--- GDAS (Gumbel-Softmax) ---")

# Demonstrate Gumbel-Softmax sampling
gumbel = GumbelSoftmax(tau_min=0.1, tau_max=5.0)

logits = torch.tensor([1.0, 0.5, -0.5, 2.0, 0.0])  # Architecture weights

print("Temperature annealing:")
for progress in [0.0, 0.25, 0.5, 0.75, 1.0]:
    gumbel.anneal_temperature(progress)
    soft = gumbel(logits, hard=False)
    hard = gumbel(logits, hard=True)
    
    print(f"  Progress={progress:.2f}, Ο„={gumbel.tau:.3f}")
    print(f"    Soft: {soft.numpy()}")
    print(f"    Hard: {hard.numpy()}")


# ============================================================================
# 6.3 FairNAS Demo
# ============================================================================

print("\n--- FairNAS (Fairness-Aware) ---")

# Demonstrate fairness loss
alphas_unfair = [torch.tensor([5.0, 0.1, 0.1, 0.1, 0.1])]  # Skip dominates
alphas_fair = [torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])]  # Balanced

loss_unfair = compute_fairness_loss(alphas_unfair)
loss_fair = compute_fairness_loss(alphas_fair)

print(f"Fairness loss (unfair): {loss_unfair:.4f}")
print(f"Fairness loss (fair): {loss_fair:.4f}")

# Quick FairNAS training
fair_model = FairNASNetwork(C=16, n_cells=2, n_nodes=2, lambda_fair=0.2).to(device)
optimizer_fair = torch.optim.Adam(fair_model.parameters(), lr=1e-3)

print("\nTraining FairNAS:")
for epoch in range(3):
    fair_model.train()
    epoch_task_loss = 0
    epoch_fair_loss = 0
    
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer_fair.zero_grad()
        
        total_loss, task_loss, fair_loss = fair_model.loss(x, y)
        
        total_loss.backward()
        optimizer_fair.step()
        
        epoch_task_loss += task_loss
        epoch_fair_loss += fair_loss
    
    n = len(train_loader)
    print(f"Epoch {epoch+1}: Task={epoch_task_loss/n:.4f}, Fair={epoch_fair_loss/n:.4f}")


# ============================================================================
# 6.4 SNIP Zero-Cost Proxy Demo
# ============================================================================

print("\n--- SNIP (Zero-Cost Proxy) ---")

# Create small search space
search_space = [
    {'type': 'small', 'C': 8, 'cells': 2},
    {'type': 'medium', 'C': 16, 'cells': 2},
    {'type': 'large', 'C': 32, 'cells': 3},
]

print("Ranking architectures with SNIP (no training):")

# Get small subset of data for SNIP
snip_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_subset, range(100)),
    batch_size=32, shuffle=False
)

snip_scores = []
for i, config in enumerate(search_space):
    model_snip = SearchNetwork(C=config['C'], n_cells=config['cells'], n_nodes=2).to(device)
    score = compute_snip_score(model_snip, snip_loader, num_batches=2)
    snip_scores.append((config['type'], score))
    print(f"{config['type']:10s}: SNIP score = {score:.2f}")

snip_scores.sort(key=lambda x: x[1], reverse=True)
print(f"\nBest architecture (SNIP): {snip_scores[0][0]}")


# ============================================================================
# 6.5 Visualization: Method Comparison
# ============================================================================

print("\n" + "="*80)
print("Summary: NAS Methods Comparison")
print("="*80)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. PC-DARTS memory efficiency
methods = ['DARTS', 'PC-DARTS\n(K=2)', 'PC-DARTS\n(K=4)', 'GDAS']
memory = [100, 50, 25, 12.5]  # Relative memory usage

axes[0, 0].bar(methods, memory, color=['red', 'orange', 'green', 'blue'], alpha=0.7)
axes[0, 0].set_ylabel('Memory Usage (%)', fontsize=11)
axes[0, 0].set_title('Memory Efficiency', fontsize=12, fontweight='bold')
axes[0, 0].axhline(y=50, color='gray', linestyle='--', linewidth=1, label='50% baseline')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3, axis='y')

# 2. Gumbel temperature annealing
progress_vals = np.linspace(0, 1, 50)
tau_max, tau_min = 5.0, 0.1
tau_curve = tau_max * (tau_min / tau_max) ** progress_vals

axes[0, 1].plot(progress_vals, tau_curve, 'b-', linewidth=2)
axes[0, 1].fill_between(progress_vals, tau_curve, alpha=0.3)
axes[0, 1].set_xlabel('Training Progress', fontsize=11)
axes[0, 1].set_ylabel('Temperature Ο„', fontsize=11)
axes[0, 1].set_title('GDAS Temperature Annealing', fontsize=12, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)
axes[0, 1].set_ylim([0, tau_max * 1.1])

# 3. Fairness loss effect
epochs_demo = np.arange(1, 11)
unfair_arch = 0.8 - 0.05 * epochs_demo  # Skip connection dominates, lower diversity
fair_arch = 0.6 + 0.02 * epochs_demo  # FairNAS maintains diversity

axes[1, 0].plot(epochs_demo, unfair_arch, 'r-o', label='No Fairness', linewidth=2, markersize=6)
axes[1, 0].plot(epochs_demo, fair_arch, 'g-s', label='With Fairness', linewidth=2, markersize=6)
axes[1, 0].set_xlabel('Epoch', fontsize=11)
axes[1, 0].set_ylabel('Architecture Diversity', fontsize=11)
axes[1, 0].set_title('FairNAS: Exploration vs. Exploitation', fontsize=12, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# 4. SNIP ranking correlation
# Simulate: SNIP score vs. true accuracy correlation
np.random.seed(42)
true_acc = np.random.rand(20) * 20 + 70  # True accuracy 70-90%
snip_score_sim = true_acc + np.random.randn(20) * 3  # Correlated with noise

axes[1, 1].scatter(snip_score_sim, true_acc, alpha=0.6, s=80, c='purple', edgecolors='black')
axes[1, 1].set_xlabel('SNIP Score', fontsize=11)
axes[1, 1].set_ylabel('True Accuracy (%)', fontsize=11)
axes[1, 1].set_title('Zero-Cost Proxy Correlation', fontsize=12, fontweight='bold')

# Add correlation line
z = np.polyfit(snip_score_sim, true_acc, 1)
p = np.poly1d(z)
x_line = np.linspace(snip_score_sim.min(), snip_score_sim.max(), 100)
axes[1, 1].plot(x_line, p(x_line), 'r--', linewidth=2, label=f'Correlation (Οβ‰ˆ0.8)')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n" + "="*80)
print("Key Takeaways:")
print("="*80)
print("1. PC-DARTS: 4Γ— memory reduction via partial channel sampling")
print("2. GDAS: Single-path search with Gumbel-Softmax (discrete sampling)")
print("3. FairNAS: Regularization prevents skip connection dominance")
print("4. SNIP: Zero-cost proxy enables fast architecture ranking without training")
print("5. Trade-offs: Memory vs. accuracy, speed vs. thoroughness, exploration vs. exploitation")
print("\nNext: Apply NAS to custom search spaces and hardware-aware optimization!")

Advanced Neural Architecture Search TheoryΒΆ

1. Introduction to Neural Architecture Search (NAS)ΒΆ

1.1 MotivationΒΆ

Traditional approach: Manual architecture design

  • Requires expert knowledge

  • Time-consuming trial-and-error

  • May miss optimal designs

NAS approach: Automated architecture discovery

  • Search space: Set of possible architectures

  • Search strategy: How to explore space

  • Performance estimation: How to evaluate architectures

Success stories:

  • NASNet (2017): Surpassed human-designed architectures on ImageNet

  • EfficientNet (2019): SOTA accuracy with 8.4Γ— fewer parameters

  • GPT-3 (2020): Architecture hyperparameters via NAS principles

1.2 NAS Framework ComponentsΒΆ

1. Search Space:

Ξ‘ = {α₁, Ξ±β‚‚, ..., Ξ±β‚™}  (set of candidate architectures)

Defines:

  • Macro-search: Entire network topology

  • Micro-search: Cell structure (repeated building blocks)

  • Hyperparameters: Depth, width, kernel sizes, etc.

2. Search Strategy: Methods to navigate search space:

  • Random search: Sample uniformly

  • Evolutionary algorithms: Mutation + selection

  • Reinforcement learning: Controller RNN generates architectures

  • Gradient-based: Differentiable architecture search (DARTS)

  • Bayesian optimization: Model performance landscape

3. Performance Estimation: How to evaluate architecture Ξ±:

  • Full training: Train to convergence (expensive!)

  • Early stopping: Train for few epochs

  • Weight sharing: One-shot models

  • Proxy tasks: Smaller datasets, lower resolution

Objective:

Ξ±* = argmax_{Ξ±βˆˆΞ‘} Accuracy(Ξ±, D_val)

subject to constraints (e.g., FLOPs < budget).

2. Search SpacesΒΆ

2.1 Chain-Structured NetworksΒΆ

Simple sequential: layer₁ β†’ layerβ‚‚ β†’ … β†’ layerβ‚™

Search choices per layer:

  • Operation: Conv, MaxPool, Identity, Zero

  • Kernel size: 3Γ—3, 5Γ—5, 7Γ—7

  • Filters: 16, 32, 64, 128

  • Activation: ReLU, Swish, Mish

Size: Exponential in depth

|Ξ‘| = O^L  where O operations, L layers

2.2 Cell-Based Search SpaceΒΆ

Motivation: Reduce search space by searching for repeatable cells.

Architecture = Stack of cells:

Input β†’ Cell₁ β†’ Cellβ‚‚ β†’ ... β†’ Cellβ‚™ β†’ Output

Cell structure (DAG):

  • Nodes: Feature maps at different resolutions

  • Edges: Operations (Conv, Pool, etc.)

  • Input nodes: Previous cell outputs

  • Output node: Concatenate intermediate nodes

NASNet search space:

  • Normal cell: Preserves spatial dimensions

  • Reduction cell: Downsamples (stride 2)

Size:

|Ξ‘| β‰ˆ O^E  where E edges in cell (much smaller than O^L)

2.3 Hierarchical Search SpacesΒΆ

Multiple levels:

  1. Micro: Operations within cell

  2. Meso: Cell connectivity

  3. Macro: Number of cells, channels

Example (EfficientNet):

  • Depth: d ∈ [1.0, 2.0, 3.0, …]

  • Width: w ∈ [1.0, 1.5, 2.0, …]

  • Resolution: r ∈ [224, 260, 300, …]

Compound scaling:

depth: d = Ξ±^Ο†
width: w = Ξ²^Ο†
resolution: r = Ξ³^Ο†

subject to Ξ±Β·Ξ²Β²Β·Ξ³Β² β‰ˆ 2 (resource constraint).

2.4 Operation SpaceΒΆ

Common operations:

  1. Convolutions:

    • Standard: 3Γ—3, 5Γ—5, 7Γ—7

    • Depthwise separable (MobileNet-style)

    • Dilated convolutions (atrous)

  2. Pooling:

    • Max pooling: 3Γ—3, 5Γ—5

    • Average pooling

    • Global average pooling

  3. Skip connections:

    • Identity

    • 1Γ—1 projection

  4. Zero operation:

    • Prunes the edge

  5. Activations:

    • ReLU, ReLU6

    • Swish, Mish

    • GELU

3. Reinforcement Learning-Based NASΒΆ

3.1 NAS with RL (Zoph & Le, 2017)ΒΆ

Idea: Controller RNN generates architectures, trained via REINFORCE.

Controller RNN:

  • Outputs sequence of decisions (layer types, hyperparameters)

  • Example: [Conv 3Γ—3, 64 filters, ReLU, MaxPool 2Γ—2, Conv 5Γ—5, 128 filters, …]

Training procedure:

  1. Sample architecture: Ξ± ~ Ο€_ΞΈ(Ξ±) (controller policy)

  2. Train child network: Get validation accuracy R(Ξ±)

  3. Update controller: Maximize expected reward

    βˆ‡_ΞΈ J(ΞΈ) = E_{Ξ±~Ο€_ΞΈ} [R(Ξ±) βˆ‡_ΞΈ log Ο€_ΞΈ(Ξ±)]
    

Reward: Validation accuracy (or accuracy - λ·complexity)

REINFORCE gradient:

βˆ‡_ΞΈ J(ΞΈ) β‰ˆ (1/m) Ξ£α΅’ (R(Ξ±α΅’) - b) βˆ‡_ΞΈ log Ο€_ΞΈ(Ξ±α΅’)

where b is baseline (running average of rewards).

3.2 Efficiency ImprovementsΒΆ

Problem: Each architecture requires full training (expensive!).

Solutions:

  1. Early stopping: Train for 5-10 epochs instead of 100+

  2. Learning curve prediction: Extrapolate from early training

  3. Weight inheritance: Initialize child from parent

Computational cost (original NAS):

  • 800 GPUs for 28 days (22,400 GPU-days!)

  • Modern methods: 1-4 GPU-days

3.3 Block-Level Search (NASNet)ΒΆ

Search for cell structure, not entire network.

Normal cell:

  • Input: H_i, H_{i-1} (current and previous cell outputs)

  • Operations: Combine via learned ops

  • Output: Concatenate intermediate nodes

Reduction cell:

  • Similar structure but with stride 2 operations

Final architecture:

Stem β†’ [N Γ— Normal β†’ Reduction] Γ— 3 β†’ Global Pool β†’ FC

Search result (ImageNet):

  • NASNet-A: 82.7% top-1 accuracy

  • Human-designed (MobileNet): 81.5%

4. Evolutionary Algorithms for NASΒΆ

4.2 Regularized Evolution (Real et al., 2019)ΒΆ

Key idea: Age-based removal instead of fitness-based.

Aging tournament selection:

  1. Sample S architectures from population

  2. Remove oldest architecture

  3. Mutate best of S, add to population

Advantages:

  • Prevents premature convergence

  • Maintains diversity

  • Better exploration

Results:

  • AmoebaNet: 83.9% ImageNet top-1 (SOTA at time)

  • 450 GPU-days (vs 22,400 for original NAS)

4.3 Crossover StrategiesΒΆ

Single-point crossover:

Parent 1: [Op1, Op2, Op3, Op4, Op5]
Parent 2: [Op6, Op7, Op8, Op9, Op10]
          ↓ (crossover at position 3)
Child:    [Op1, Op2, Op3, Op9, Op10]

Uniform crossover: Each gene (operation) randomly from either parent.

Cell-level crossover: Combine cells from different parents.

5. Differentiable Architecture Search (DARTS)ΒΆ

5.1 Continuous RelaxationΒΆ

Problem: Discrete architecture choices β†’ can’t use gradient descent.

Solution: Relax discrete choices to continuous (differentiable).

Continuous architecture: Instead of selecting one operation, weight all operations:

f(x) = Σ_{o∈O} (exp(α_o) / Σ_{o'} exp(α_{o'})) o(x)

where Ξ±_o are architecture parameters (learnable).

Softmax over operations:

f(x) = Ξ£_o softmax(Ξ±)_o Β· o(x)

5.2 Bi-Level OptimizationΒΆ

Two sets of parameters:

  1. Network weights: w (operation parameters)

  2. Architecture parameters: Ξ± (operation selection)

Objective:

min_Ξ±  L_val(w*(Ξ±), Ξ±)
s.t.   w*(Ξ±) = argmin_w L_train(w, Ξ±)

Interpretation:

  • Inner optimization: Train network weights w given architecture Ξ±

  • Outer optimization: Find architecture Ξ± that minimizes validation loss

5.3 Approximation via Gradient DescentΒΆ

Challenges: Inner optimization expensive (requires full convergence).

Approximation: One-step gradient descent

w' = w - ΞΎ βˆ‡_w L_train(w, Ξ±)

Architecture gradient:

βˆ‡_Ξ± L_val(w', Ξ±) β‰ˆ βˆ‡_Ξ± L_val(w - ΞΎ βˆ‡_w L_train(w, Ξ±), Ξ±)

Chain rule:

βˆ‡_Ξ± L_val = βˆ‡_Ξ± L_val(w', Ξ±) - ΞΎ βˆ‡_Ξ±,wΒ² L_train Β· βˆ‡_w L_val(w', Ξ±)

where βˆ‡_Ξ±,wΒ² is Hessian (expensive!).

Finite difference approximation:

βˆ‡_Ξ±,wΒ² L_train Β· v β‰ˆ (βˆ‡_w L_train(w+Ρ·v, Ξ±) - βˆ‡_w L_train(w-Ρ·v, Ξ±)) / (2Ξ΅)

5.4 DiscretizationΒΆ

After search: Continuous Ξ± β†’ discrete architecture.

Method: For each edge, select operation with largest Ξ±_o:

o* = argmax_o Ξ±_o

Pruning: Remove edges with weak connections (optional).

5.5 DARTS ResultsΒΆ

Efficiency:

  • Search time: 4 GPU-days (vs 22,400 for NAS)

  • Search cost: 1000Γ— reduction

Performance:

  • CIFAR-10: 2.76% error (comparable to NASNet)

  • ImageNet: 73.3% top-1 (competitive)

Advantages:

  • Gradient-based: Fast, scalable

  • End-to-end differentiable

  • No need for RL or evolution

Limitations:

  • Discretization gap (continuous β‰  discrete)

  • Overfitting to validation set

  • Performance collapse (all skip connections)

6. One-Shot NAS and Weight SharingΒΆ

6.1 Supernet TrainingΒΆ

Idea: Train one β€œsupernet” containing all possible operations.

Supernet:

f_super(x) = Σ_{path ∈ Paths} I[path sampled] · f_path(x)

where paths correspond to different architectures.

Weight sharing: Operations share weights across architectures.

Training:

  1. Sample architecture (path) from supernet

  2. Train on mini-batch

  3. Update only sampled path’s weights

  4. Repeat

Search: Evaluate architectures using shared weights (no retraining!).

6.2 ENAS (Efficient NAS)ΒΆ

Idea: Weight sharing + RL controller.

Algorithm:

  1. Controller: Sample architecture Ξ± ~ Ο€_ΞΈ(Ξ±)

  2. Train child: Use shared weights from supernet, train briefly

  3. Evaluate: Get validation accuracy R(Ξ±)

  4. Update controller: REINFORCE gradient

  5. Update supernet: Train shared weights on sampled architectures

Speedup: 1000Γ— faster than original NAS (1 GPU-day).

Results:

  • CIFAR-10: 2.89% error

  • Penn Treebank: 55.8 perplexity (language modeling)

6.3 Single-Path One-Shot (SPOS)ΒΆ

Training phase: Uniformly sample one path (architecture) per batch.

For each batch:
    Ξ± ~ Uniform(all architectures)
    Update weights w using Ξ±

Search phase: Evolutionary search over architectures, evaluate using shared weights.

Advantages:

  • Decouples training from search

  • No RL controller needed

  • Very efficient

6.4 FairNASΒΆ

Problem: Weight sharing biases towards certain architectures.

Solution: Fair sampling strategies.

Techniques:

  1. Fair path sampling: Ensure all operations sampled equally

  2. Sandwich rule: Train largest and smallest sub-networks

  3. Calibration: Normalize operation statistics

7. Hardware-Aware NASΒΆ

7.1 Multi-Objective OptimizationΒΆ

Objectives:

  • Accuracy: Maximize validation accuracy

  • Latency: Minimize inference time (mobile, edge devices)

  • Energy: Minimize power consumption

  • Memory: Fit in device constraints

  • FLOPs: Reduce computational cost

Pareto front:

α* ∈ argmax_{α} {(Acc(α), -Latency(α), -Energy(α))}

No single architecture dominates all objectives.

7.3 Device-Specific OptimizationΒΆ

ProxylessNAS:

  • Train supernet on GPU

  • Evaluate latency on target device (mobile CPU, GPU, etc.)

  • Architecture-specific for each device

Results:

  • Mobile: 75.1% / 199ms (CPU), 74.6% / 80ms (GPU)

  • Different architectures optimal for different hardware!

7.4 Once-For-All (OFA) NetworksΒΆ

Idea: Train once, deploy to many devices.

Progressive shrinking:

  1. Train largest network (full depth, width, resolution)

  2. Progressively train smaller sub-networks

  3. Final supernet supports any configuration

Deployment:

  • Given device constraints (latency, memory)

  • Search for best sub-network within constraints

  • Extract and deploy (no retraining!)

Results:

  • ImageNet: 80.0% top-1 with 230 FLOPs (SOTA efficiency)

  • Supports 10^19 sub-networks

8. Performance Estimation StrategiesΒΆ

8.1 Full TrainingΒΆ

Method: Train each architecture to convergence.

Pros:

  • Accurate performance estimate

  • No bias

Cons:

  • Extremely expensive (hours per architecture)

  • Infeasible for large search spaces

8.2 Early StoppingΒΆ

Method: Train for few epochs (e.g., 5-10), use as proxy for final accuracy.

Assumptions:

  • Rank correlation: Early accuracy β‰ˆ final accuracy (rank-wise)

Challenges:

  • Correlation not perfect

  • Some architectures are β€œslow starters”

8.3 Learning Curve ExtrapolationΒΆ

Method: Fit learning curve, predict final accuracy.

Models:

  • Power law: a(t) = Ξ± - Ξ²Β·t^(-Ξ³)

  • Exponential: a(t) = Ξ± - Ξ²Β·exp(-Ξ³t)

  • Bayesian neural networks: Learn curve distribution

Advantages:

  • Stop training early for bad architectures

  • Allocate more resources to promising ones

8.4 Weight Sharing / One-ShotΒΆ

Method: All architectures share weights in supernet.

Evaluation: Sample architecture, evaluate on validation set (no training!).

Pros:

  • Extremely fast (seconds per architecture)

  • Enables large-scale search

Cons:

  • Performance estimates biased

  • Ranking correlation with standalone training: 0.5-0.7

8.5 Performance PredictorsΒΆ

Idea: Train surrogate model to predict accuracy from architecture encoding.

Encoding:

  • Adjacency matrix: For DAG-based cells

  • Path encoding: Sequence of operations

  • Graph neural network: Process architecture as graph

Predictor:

Predictor: encode(Ξ±) β†’ predicted_accuracy

Training:

  • Collect dataset: (Ξ±, accuracy) pairs

  • Train regression model (MLP, GNN, Transformer)

  • Use predictor to evaluate new architectures

Uncertainty quantification:

  • Bayesian predictors

  • Ensemble predictors

  • Gaussian process

9. Transferability and GeneralizationΒΆ

9.1 Transfer from Proxy TasksΒΆ

Common proxies:

  • CIFAR-10 β†’ ImageNet: Search on CIFAR, transfer to ImageNet

  • Low resolution β†’ High resolution: Search at 32Γ—32, deploy at 224Γ—224

  • Fewer epochs: Search with 50 epochs, deploy with 300

Challenges:

  • Rank correlation not perfect

  • Best on CIFAR β‰  best on ImageNet

Transfer techniques:

  • Fine-tuning: Adjust architecture on target task

  • Scaling rules: Adjust depth/width for target dataset

9.2 Cross-Task GeneralizationΒΆ

Question: Does architecture found for Task A work on Task B?

Findings:

  • Within domain: Good transfer (ImageNet β†’ COCO detection)

  • Across domains: Moderate transfer (vision β†’ NLP)

  • Task-specific: Some tasks need specialized architectures

Examples:

  • NASNet (image classification) β†’ Object detection: Good

  • EfficientNet β†’ Semantic segmentation: Good

  • Vision architectures β†’ Language modeling: Poor

9.3 Domain AdaptationΒΆ

Search for architecture on source domain, deploy on target.

Techniques:

  1. Multi-task search: Search jointly on source + target

  2. Domain-invariant operations: Prefer ops that transfer well

  3. Meta-NAS: Learn to search across domains

10. Advanced NAS MethodsΒΆ

10.1 Meta-Learning for NASΒΆ

Idea: Learn the search process itself.

Meta-NAS:

  • Meta-learner: Predicts good architectures for new tasks

  • Trained on distribution of tasks

  • Fast adaptation to new task

Applications:

  • Few-shot learning: Find architecture from few examples

  • Continual learning: Adapt architecture as task evolves

10.2 Multi-Objective NASΒΆ

Pareto optimization:

Objectives:

f₁(Ξ±) = Accuracy(Ξ±)
fβ‚‚(Ξ±) = -Latency(Ξ±)
f₃(Ξ±) = -FLOPs(Ξ±)

Methods:

  • Evolutionary: NSGA-II (Non-dominated Sorting GA)

  • Scalarization: λ₁f₁ + Ξ»β‚‚fβ‚‚ + λ₃f₃

  • Hypervolume optimization

Results: Set of Pareto-optimal architectures (user picks based on constraints).

10.3 Neural Architecture Optimization (NAO)ΒΆ

Idea: Model architecture performance as continuous function.

Encoder: Architecture Ξ± β†’ embedding z Predictor: z β†’ performance Decoder: z β†’ architecture α’

Optimization:

z* = argmax_z Predictor(z)
Ξ±* = Decoder(z*)

Advantages:

  • Continuous optimization (gradient-based)

  • No need for supernet or RL

10.4 AutoML-ZeroΒΆ

Radical idea: Search from scratch (basic operations).

Search space:

  • Primitive ops: +, -, Γ—, /, exp, log, sin, max, min

  • No domain knowledge (no convolutions!)

Discovery:

  • Rediscovered linear regression, ReLU, normalized gradients

  • Proof-of-concept for automated science

11. Benchmarks and ComparisonsΒΆ

11.1 NAS-Bench-101ΒΆ

Dataset: 423,624 unique architectures, all trained on CIFAR-10.

Architecture space:

  • Cell-based (directed acyclic graph)

  • 3 ops per cell: 3Γ—3 conv, 1Γ—1 conv, 3Γ—3 max pool

Metrics:

  • Training time

  • Validation/test accuracy

  • Number of parameters

Usage:

  • Researchers: Test NAS algorithms without expensive training

  • Query: Get performance of architecture in O(1) time

11.2 NAS-Bench-201ΒΆ

Improvements over 101:

  • Fixed topology, search over operations

  • 15,625 architectures (smaller, fully evaluated)

  • Multiple datasets: CIFAR-10, CIFAR-100, ImageNet-16-120

Insights:

  • Simple random search competitive with sophisticated methods!

  • Weight sharing biases rankings

  • Early stopping correlation varies

11.3 Performance ComparisonΒΆ

ImageNet top-1 accuracy (2024 SOTA):

Manual designs:
- ResNet-50: 76.2%
- EfficientNet-B0 (manual scaling): 77.1%

NAS-discovered:
- NASNet-A: 82.7% (but inefficient)
- AmoebaNet-A: 83.9%
- EfficientNet-B7 (NAS): 84.4%
- Once-For-All: 80.0% (efficient)

Search cost:

NAS (2017): 22,400 GPU-days
NASNet (2018): 2,000 GPU-days
ENAS (2018): 0.5 GPU-days
DARTS (2019): 4 GPU-days
Once-For-All (2020): 1,200 GPU-hours (50 GPU-days, but train once)

12. Practical ConsiderationsΒΆ

12.1 Search Space DesignΒΆ

Trade-offs:

  • Large space: More potential, but expensive to search

  • Small space: Efficient search, but may miss optimal

Best practices:

  1. Start with proven components: Conv, pooling, skip connections

  2. Constrain based on domain knowledge: Vision β‰  NLP

  3. Hierarchical spaces: Micro + macro search

12.2 Validation Set ManagementΒΆ

Problem: Overfitting to validation set during search.

Solutions:

  1. Hold-out test set: Only use after search complete

  2. Cross-validation: Rotate validation sets

  3. Regularization: Penalize complexity during search

12.3 ReproducibilityΒΆ

Challenges:

  • Randomness in search (RL, evolution)

  • Stochasticity in training

  • Hardware differences

Best practices:

  • Report seeds, hyperparameters

  • Average over multiple runs

  • Open-source code and checkpoints

12.4 Computational BudgetΒΆ

Fixed budget: Given B GPU-hours, how to allocate?

Strategies:

  1. Many short runs: Sample many architectures, train briefly

  2. Few long runs: Sample few, train to convergence

  3. Adaptive: Allocate based on promise (bandit algorithms)

Optimal allocation:

  • Depends on search space quality

  • Exploration vs. exploitation trade-off

13. Applications Beyond Image ClassificationΒΆ

13.1 Object DetectionΒΆ

Challenges:

  • Multi-scale features (FPN, PAFPN)

  • Backbone + neck + head architecture

Approaches:

  • NAS-FPN: Search for feature pyramid network topology

  • DetNAS: End-to-end detection-aware NAS

  • EfficientDet: Compound scaling for detection

Results:

  • NAS-FPN: 40.5 mAP on COCO (vs 36.8 for hand-designed FPN)

13.2 Semantic SegmentationΒΆ

Auto-DeepLab:

  • Search for encoder-decoder architecture

  • Multi-scale feature aggregation

  • Cell-level search

Results:

  • Cityscapes: 82.1 mIoU (SOTA for NAS)

13.3 Language ModelingΒΆ

Evolved Transformer:

  • Search for Transformer variants

  • Evolutionary algorithm

Discovered architecture:

  • Gated linear units instead of FFN

  • Different attention patterns

  • 0.7 perplexity improvement on WMT’14

13.4 Reinforcement LearningΒΆ

Search for RL agent architectures:

  • Policy network structure

  • Value function approximator

Applications:

  • Atari games

  • Continuous control (robotics)

14. Limitations and ChallengesΒΆ

14.1 Computational CostΒΆ

Despite improvements, still expensive:

  • Search: 4-50 GPU-days (vs minutes for manual design)

  • Amortized over multiple deployments, but not accessible to all

Barriers:

  • Academia: Limited compute resources

  • Industry: Cost-benefit analysis

14.2 Overfitting to Search SpaceΒΆ

Search space design is crucial:

  • If optimal architecture not in space, NAS won’t find it

  • Human bias embedded in search space

Example: If search space only includes CNNs, can’t discover Transformers.

14.3 Generalization GapΒΆ

Problem: Performance on proxy task β‰  performance on target task.

Causes:

  • Dataset shift

  • Training recipe differences

  • Scaling effects

14.4 Reproducibility and FairnessΒΆ

Comparison challenges:

  • Different search spaces

  • Different computational budgets

  • Different training recipes

Need: Standardized benchmarks (NAS-Bench), fair comparisons.

15. Future DirectionsΒΆ

15.2 Neural Architecture UnderstandingΒΆ

Interpretability:

  • Why does architecture A outperform B?

  • What architectural features are important?

Architecture analysis:

  • Feature visualization

  • Ablation studies

  • Causal analysis

15.4 Continual Architecture AdaptationΒΆ

Dynamic architectures:

  • Adapt to data distribution shifts

  • Grow/shrink based on task complexity

Lifelong learning:

  • Accumulate architectural knowledge

  • Transfer to new tasks

16. Key TakeawaysΒΆ

  1. NAS automates architecture design:

    • Search space + search strategy + performance estimation

    • Can discover architectures rivaling human experts

  2. Major paradigms:

    • RL-based: Controller samples architectures (NAS, ENAS)

    • Evolutionary: Mutation + selection (AmoebaNet)

    • Gradient-based: Differentiable search (DARTS)

    • One-shot: Weight sharing supernet (SPOS, OFA)

  3. Efficiency crucial:

    • Original NAS: 22,400 GPU-days

    • Modern methods: 1-50 GPU-days

    • Key: Weight sharing, early stopping, predictors

  4. Hardware-aware NAS:

    • Multi-objective: accuracy + latency + energy

    • Device-specific architectures

    • Once-for-all networks for flexible deployment

  5. Search space design matters:

    • Cell-based reduces complexity

    • Hierarchical spaces (micro + macro)

    • Domain knowledge still important

  6. Applications beyond classification:

    • Detection, segmentation, NLP, RL

    • Transfer from proxy tasks

  7. Open challenges:

    • Computational cost

    • Overfitting to search space

    • Generalization gap

    • Reproducibility

  8. Future: Towards AutoML:

    • Joint architecture + training search

    • Continual adaptation

    • Democratization of NAS

17. ReferencesΒΆ

Foundational papers:

  • Zoph & Le (2017): β€œNeural Architecture Search with Reinforcement Learning” (ICLR)

  • Zoph et al. (2018): β€œLearning Transferable Architectures for Scalable Image Recognition” (CVPR - NASNet)

  • Real et al. (2019): β€œRegularized Evolution for Image Classifier Architecture Search” (AAAI - AmoebaNet)

  • Liu et al. (2019): β€œDARTS: Differentiable Architecture Search” (ICLR)

Efficiency improvements:

  • Pham et al. (2018): β€œEfficient Neural Architecture Search via Parameter Sharing” (ICML - ENAS)

  • Guo et al. (2020): β€œSingle Path One-Shot Neural Architecture Search with Uniform Sampling” (ECCV - SPOS)

Hardware-aware:

  • Tan et al. (2019): β€œMnasNet: Platform-Aware Neural Architecture Search for Mobile” (CVPR)

  • Cai et al. (2020): β€œOnce-for-All: Train One Network and Specialize it for Efficient Deployment” (ICLR)

Benchmarks:

  • Ying et al. (2019): β€œNAS-Bench-101: Towards Reproducible Neural Architecture Search” (ICML)

  • Dong & Yang (2020): β€œNAS-Bench-201: Extending the Scope of Reproducible NAS” (ICLR)

Surveys:

  • Elsken et al. (2019): β€œNeural Architecture Search: A Survey” (JMLR)

  • Wistuba et al. (2019): β€œA Survey on Neural Architecture Search” (arXiv)

"""
Complete Neural Architecture Search Implementations
====================================================
Includes: DARTS (Differentiable), RL-based NAS controller, Evolutionary search,
One-shot supernet, performance predictors, hardware-aware search.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import namedtuple
import random

# ============================================================================
# 1. DARTS (Differentiable Architecture Search)
# ============================================================================

# Define possible operations
OPS = {
    'none': lambda C, stride: Zero(stride),
    'skip_connect': lambda C, stride: Identity() if stride == 1 else FactorizedReduce(C, C),
    'sep_conv_3x3': lambda C, stride: SepConv(C, C, 3, stride, 1),
    'sep_conv_5x5': lambda C, stride: SepConv(C, C, 5, stride, 2),
    'dil_conv_3x3': lambda C, stride: DilConv(C, C, 3, stride, 2, 2),
    'dil_conv_5x5': lambda C, stride: DilConv(C, C, 5, stride, 4, 2),
    'avg_pool_3x3': lambda C, stride: nn.AvgPool2d(3, stride=stride, padding=1),
    'max_pool_3x3': lambda C, stride: nn.MaxPool2d(3, stride=stride, padding=1),
}


class Zero(nn.Module):
    """Zero operation (drop path)."""
    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride
    
    def forward(self, x):
        if self.stride == 1:
            return x * 0.
        return x[:, :, ::self.stride, ::self.stride] * 0.


class Identity(nn.Module):
    """Identity (skip connection)."""
    def __init__(self):
        super(Identity, self).__init__()
    
    def forward(self, x):
        return x


class SepConv(nn.Module):
    """Depthwise separable convolution."""
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, 
                     padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in),
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, 
                     padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out),
        )
    
    def forward(self, x):
        return self.op(x)


class DilConv(nn.Module):
    """Dilated convolution."""
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super(DilConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, 
                     padding=padding, dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out),
        )
    
    def forward(self, x):
        return self.op(x)


class FactorizedReduce(nn.Module):
    """Reduce spatial dimensions by 2."""
    def __init__(self, C_in, C_out):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out)
    
    def forward(self, x):
        x = self.relu(x)
        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
        return self.bn(out)


class MixedOp(nn.Module):
    """
    Mixed operation: weighted sum of all operations.
    
    f(x) = Ξ£_o softmax(Ξ±)_o Β· o(x)
    
    Args:
        C: Number of channels
        stride: Stride for operations
    """
    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        for primitive in OPS:
            op = OPS[primitive](C, stride)
            self._ops.append(op)
    
    def forward(self, x, weights):
        """
        Args:
            x: Input tensor
            weights: Architecture weights (softmax over operations)
        Returns:
            Weighted sum of operations
        """
        return sum(w * op(x) for w, op in zip(weights, self._ops))


class DARTSCell(nn.Module):
    """
    DARTS cell (DAG of mixed operations).
    
    Args:
        steps: Number of intermediate nodes
        multiplier: Output = concat of final `multiplier` nodes
        C_prev_prev: Channels from 2 cells ago
        C_prev: Channels from 1 cell ago
        C: Current channels
        reduction: Whether this is reduction cell
    """
    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction):
        super(DARTSCell, self).__init__()
        self.reduction = reduction
        self.steps = steps
        self.multiplier = multiplier
        
        # Preprocess inputs
        if reduction:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
            self.preprocess1 = FactorizedReduce(C_prev, C)
        else:
            self.preprocess0 = Identity() if C_prev_prev == C else \
                              nn.Conv2d(C_prev_prev, C, 1, 1, 0, bias=False)
            self.preprocess1 = Identity() if C_prev == C else \
                              nn.Conv2d(C_prev, C, 1, 1, 0, bias=False)
        
        # Build DAG
        self._ops = nn.ModuleList()
        for i in range(self.steps):
            for j in range(2 + i):  # Connect to previous nodes
                stride = 2 if reduction and j < 2 else 1
                op = MixedOp(C, stride)
                self._ops.append(op)
    
    def forward(self, s0, s1, weights):
        """
        Args:
            s0: Output from 2 cells ago
            s1: Output from 1 cell ago
            weights: Architecture weights [num_edges, num_ops]
        Returns:
            Cell output (concatenated intermediate nodes)
        """
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        
        states = [s0, s1]
        offset = 0
        
        for i in range(self.steps):
            # Collect inputs from all previous nodes
            s = sum(self._ops[offset + j](h, weights[offset + j]) 
                   for j, h in enumerate(states))
            offset += len(states)
            states.append(s)
        
        # Concatenate last `multiplier` nodes
        return torch.cat(states[-self.multiplier:], dim=1)


class DARTSNetwork(nn.Module):
    """
    DARTS network (stack of cells).
    
    Args:
        C: Initial channels
        num_classes: Number of output classes
        layers: Number of cells
        steps: Intermediate nodes per cell
        multiplier: Concatenation factor
    """
    def __init__(self, C=16, num_classes=10, layers=8, steps=4, multiplier=4):
        super(DARTSNetwork, self).__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._steps = steps
        self._multiplier = multiplier
        
        # Stem
        C_curr = C * 3  # After stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )
        
        # Build cells
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        
        for i in range(layers):
            # Reduction cells at 1/3 and 2/3 depth
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            
            cell = DARTSCell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction)
            self.cells.append(cell)
            
            C_prev_prev, C_prev = C_prev, multiplier * C_curr
            reduction_prev = reduction
        
        # Classifier
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        
        # Initialize architecture parameters
        num_ops = len(OPS)
        num_edges = sum(2 + i for i in range(steps))
        self._arch_parameters = nn.ParameterList([
            nn.Parameter(1e-3 * torch.randn(num_edges, num_ops)),  # Normal cell
            nn.Parameter(1e-3 * torch.randn(num_edges, num_ops)),  # Reduction cell
        ])
    
    def forward(self, x):
        """Forward pass with current architecture."""
        s0 = s1 = self.stem(x)
        
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                weights = F.softmax(self._arch_parameters[1], dim=-1)
            else:
                weights = F.softmax(self._arch_parameters[0], dim=-1)
            
            s0, s1 = s1, cell(s0, s1, weights)
        
        out = self.global_pooling(s1)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        
        return logits
    
    def arch_parameters(self):
        """Return architecture parameters for optimization."""
        return self._arch_parameters
    
    def genotype(self):
        """Discretize architecture (select top operation per edge)."""
        def _parse(weights):
            gene = []
            n = 2
            start = 0
            for i in range(self._steps):
                end = start + n
                W = weights[start:end].copy()
                
                # Select top-2 edges
                edges = sorted(range(i + 2), 
                             key=lambda x: -max(W[x][k] for k in range(len(W[x])) 
                                               if OPS[list(OPS.keys())[k]] != 'none'))[:2]
                
                for j in edges:
                    k_best = None
                    for k in range(len(W[j])):
                        if OPS[list(OPS.keys())[k]] == 'none':
                            continue
                        if k_best is None or W[j][k] > W[j][k_best]:
                            k_best = k
                    gene.append((list(OPS.keys())[k_best], j))
                
                start = end
                n += 1
            
            return gene
        
        gene_normal = _parse(F.softmax(self._arch_parameters[0], dim=-1).data.cpu().numpy())
        gene_reduce = _parse(F.softmax(self._arch_parameters[1], dim=-1).data.cpu().numpy())
        
        return {'normal': gene_normal, 'reduce': gene_reduce}


# ============================================================================
# 2. RL-Based NAS Controller
# ============================================================================

class NASController(nn.Module):
    """
    RNN controller for neural architecture search (Zoph & Le, 2017).
    
    Generates architecture decisions sequentially.
    
    Args:
        num_layers: Number of layers to generate
        num_branches: Branching factor
        lstm_size: LSTM hidden size
        lstm_layers: Number of LSTM layers
        temperature: Softmax temperature
    """
    def __init__(self, num_layers=12, num_branches=6, lstm_size=32, 
                 lstm_layers=2, temperature=5.0):
        super(NASController, self).__init__()
        self.num_layers = num_layers
        self.num_branches = num_branches
        self.lstm_size = lstm_size
        self.temperature = temperature
        
        # LSTM controller
        self.lstm = nn.LSTM(input_size=lstm_size, hidden_size=lstm_size, 
                           num_layers=lstm_layers)
        
        # Embedding for previous decision
        self.embedding = nn.Embedding(num_branches + 1, lstm_size)
        
        # Decoder (output layer)
        self.decoder = nn.Linear(lstm_size, num_branches)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize controller weights."""
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.uniform_(param, -0.1, 0.1)
            elif 'bias' in name:
                nn.init.constant_(param, 0)
    
    def forward(self, batch_size=1):
        """
        Sample architecture.
        
        Returns:
            architecture: List of sampled operations [num_layers]
            log_probs: Log probabilities [num_layers]
            entropies: Entropies [num_layers]
        """
        # Initial hidden state
        h = torch.zeros(self.lstm.num_layers, batch_size, self.lstm_size)
        c = torch.zeros(self.lstm.num_layers, batch_size, self.lstm_size)
        
        # Start token
        inputs = self.embedding(torch.zeros(batch_size, dtype=torch.long))
        
        architecture = []
        log_probs = []
        entropies = []
        
        for layer in range(self.num_layers):
            # LSTM forward
            inputs = inputs.unsqueeze(0)
            output, (h, c) = self.lstm(inputs, (h, c))
            output = output.squeeze(0)
            
            # Decode to logits
            logits = self.decoder(output) / self.temperature
            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            
            # Sample operation
            action = torch.multinomial(probs, 1).squeeze(1)
            
            # Record
            architecture.append(action.item())
            log_probs.append(log_prob.gather(1, action.unsqueeze(1)).squeeze(1))
            
            # Entropy (for exploration bonus)
            entropy = -(log_prob * probs).sum(dim=-1)
            entropies.append(entropy)
            
            # Next input = embedding of current action
            inputs = self.embedding(action)
        
        log_probs = torch.stack(log_probs, dim=1)  # [batch, num_layers]
        entropies = torch.stack(entropies, dim=1)
        
        return architecture, log_probs, entropies
    
    def train_step(self, rewards, baseline, optimizer, entropy_weight=0.0001):
        """
        REINFORCE update.
        
        Args:
            rewards: Validation accuracy [batch]
            baseline: Running average reward
            optimizer: Controller optimizer
            entropy_weight: Coefficient for entropy bonus
        """
        # Sample architecture
        architecture, log_probs, entropies = self.forward(batch_size=len(rewards))
        
        # REINFORCE loss
        advantages = rewards - baseline
        loss = -(log_probs * advantages.unsqueeze(1)).sum(dim=1).mean()
        
        # Entropy bonus (encourage exploration)
        entropy_bonus = entropies.mean()
        loss = loss - entropy_weight * entropy_bonus
        
        # Update
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), 5.0)
        optimizer.step()
        
        return loss.item()


# ============================================================================
# 3. Evolutionary Architecture Search
# ============================================================================

Architecture = namedtuple('Architecture', ['arch', 'fitness', 'age'])


class EvolutionarySearch:
    """
    Regularized evolution for architecture search (Real et al., 2019).
    
    Args:
        search_space: Function that returns random architecture
        evaluate_fn: Function to evaluate architecture (returns fitness)
        population_size: Number of architectures in population
        sample_size: Tournament size for selection
        mutation_prob: Probability of mutation
    """
    def __init__(self, search_space, evaluate_fn, population_size=100, 
                 sample_size=25, mutation_prob=1.0):
        self.search_space = search_space
        self.evaluate_fn = evaluate_fn
        self.population_size = population_size
        self.sample_size = sample_size
        self.mutation_prob = mutation_prob
        self.population = []
        self.history = []
    
    def initialize_population(self):
        """Initialize population with random architectures."""
        for _ in range(self.population_size):
            arch = self.search_space()
            fitness = self.evaluate_fn(arch)
            self.population.append(Architecture(arch, fitness, age=0))
        
        print(f"Initialized population: {self.population_size} architectures")
        print(f"Best fitness: {max(a.fitness for a in self.population):.4f}")
    
    def mutate(self, parent_arch):
        """
        Mutate architecture.
        
        Simple mutation: randomly change one operation.
        """
        arch = parent_arch.copy()
        
        if random.random() < self.mutation_prob:
            # Random mutation point
            idx = random.randint(0, len(arch) - 1)
            
            # Random new operation (different from current)
            ops = list(range(len(OPS)))
            ops.remove(arch[idx])
            arch[idx] = random.choice(ops)
        
        return arch
    
    def evolve_step(self):
        """
        Single evolution step (regularized evolution).
        
        1. Sample tournament
        2. Remove oldest
        3. Mutate best
        4. Evaluate and add to population
        """
        # Tournament selection
        sample = random.sample(self.population, self.sample_size)
        
        # Remove oldest from sample
        oldest = max(sample, key=lambda x: x.age)
        self.population.remove(oldest)
        
        # Select best from sample
        parent = max(sample, key=lambda x: x.fitness)
        
        # Mutate
        child_arch = self.mutate(parent.arch)
        
        # Evaluate
        child_fitness = self.evaluate_fn(child_arch)
        
        # Add to population with age 0
        child = Architecture(child_arch, child_fitness, age=0)
        self.population.append(child)
        
        # Age all architectures
        self.population = [Architecture(a.arch, a.fitness, a.age + 1) 
                          for a in self.population]
        
        # Record
        self.history.append(child_fitness)
        
        return child_arch, child_fitness
    
    def search(self, num_iterations=1000):
        """
        Run evolution for num_iterations.
        
        Returns:
            best_arch: Best architecture found
            best_fitness: Fitness of best architecture
        """
        self.initialize_population()
        
        for i in range(num_iterations):
            arch, fitness = self.evolve_step()
            
            if (i + 1) % 100 == 0:
                best_fitness = max(a.fitness for a in self.population)
                print(f"Iteration {i+1}/{num_iterations}, Best: {best_fitness:.4f}, "
                      f"Current: {fitness:.4f}")
        
        # Return best architecture
        best = max(self.population, key=lambda x: x.fitness)
        return best.arch, best.fitness


# ============================================================================
# 4. Performance Predictor
# ============================================================================

class ArchitectureEncoder(nn.Module):
    """
    Encode architecture as fixed-size vector.
    
    Simple approach: Embedding + pooling.
    
    Args:
        num_ops: Number of possible operations
        embed_dim: Embedding dimension
        hidden_dim: Hidden dimension
    """
    def __init__(self, num_ops, embed_dim=32, hidden_dim=64):
        super(ArchitectureEncoder, self).__init__()
        self.embedding = nn.Embedding(num_ops, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        
    def forward(self, arch):
        """
        Args:
            arch: Architecture (list of operation indices) [batch, seq_len]
        Returns:
            Encoding [batch, hidden_dim * 2]
        """
        embedded = self.embedding(arch)  # [batch, seq_len, embed_dim]
        _, (h, c) = self.lstm(embedded)
        encoding = torch.cat([h[-1], c[-1]], dim=-1)  # Concatenate hidden and cell
        return encoding


class PerformancePredictor(nn.Module):
    """
    Predict architecture performance from encoding.
    
    Args:
        num_ops: Number of operations
        embed_dim: Embedding dimension
        hidden_dim: Hidden dimension
    """
    def __init__(self, num_ops, embed_dim=32, hidden_dim=64):
        super(PerformancePredictor, self).__init__()
        self.encoder = ArchitectureEncoder(num_ops, embed_dim, hidden_dim)
        
        # Predictor MLP
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, 1)
        )
    
    def forward(self, arch):
        """
        Args:
            arch: Architecture tensor [batch, seq_len]
        Returns:
            Predicted accuracy [batch, 1]
        """
        encoding = self.encoder(arch)
        prediction = self.predictor(encoding)
        return prediction


# ============================================================================
# 5. Demonstrations
# ============================================================================

def demo_darts():
    """Demonstrate DARTS network."""
    print("="*70)
    print("DARTS (Differentiable Architecture Search) Demo")
    print("="*70)
    
    # Create DARTS network
    model = DARTSNetwork(C=16, num_classes=10, layers=8, steps=4, multiplier=4)
    
    # Sample input
    x = torch.randn(2, 3, 32, 32)
    
    # Forward
    model.eval()
    with torch.no_grad():
        logits = model(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {logits.shape}")
    print()
    
    # Architecture parameters
    arch_params = model.arch_parameters()
    print(f"Architecture parameters:")
    print(f"  Normal cell: {arch_params[0].shape} (edges Γ— operations)")
    print(f"  Reduction cell: {arch_params[1].shape}")
    print()
    
    # Current architecture (continuous)
    normal_weights = F.softmax(arch_params[0], dim=-1)
    print(f"Normal cell operation weights (first edge):")
    for i, op_name in enumerate(OPS.keys()):
        print(f"  {op_name:15s}: {normal_weights[0, i].item():.4f}")
    print()
    
    # Discretize
    genotype = model.genotype()
    print(f"Discretized architecture:")
    print(f"  Normal cell: {genotype['normal'][:4]}")  # Show first 4 edges
    print(f"  Reduction cell: {genotype['reduce'][:4]}")
    print()
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    arch_params_count = sum(p.numel() for p in model.arch_parameters())
    weight_params = total_params - arch_params_count
    
    print(f"Parameters:")
    print(f"  Weights: {weight_params:,}")
    print(f"  Architecture: {arch_params_count:,}")
    print(f"  Total: {total_params:,}")
    print()


def demo_rl_controller():
    """Demonstrate RL-based NAS controller."""
    print("="*70)
    print("RL-Based NAS Controller Demo")
    print("="*70)
    
    # Create controller
    controller = NASController(num_layers=12, num_branches=6, lstm_size=32)
    
    # Sample architectures
    print("Sampling architectures:")
    for i in range(3):
        arch, log_probs, entropies = controller.forward(batch_size=1)
        print(f"  Architecture {i+1}: {arch}")
        print(f"    Log prob sum: {log_probs.sum().item():.4f}")
        print(f"    Entropy: {entropies.mean().item():.4f}")
    print()
    
    # Count parameters
    params = sum(p.numel() for p in controller.parameters())
    print(f"Controller parameters: {params:,}")
    print()
    
    print("REINFORCE update:")
    print("  Sample architecture β†’ Train child network β†’ Get accuracy")
    print("  Update controller to maximize expected accuracy")
    print("  Gradient: βˆ‡_ΞΈ J = E[(R - baseline) * βˆ‡_ΞΈ log Ο€_ΞΈ(Ξ±)]")
    print()


def demo_evolutionary_search():
    """Demonstrate evolutionary architecture search."""
    print("="*70)
    print("Evolutionary Architecture Search Demo")
    print("="*70)
    
    # Define toy search space (list of 6 operations)
    def search_space():
        return [random.randint(0, len(OPS) - 1) for _ in range(6)]
    
    # Toy evaluation (random for demo)
    def evaluate_fn(arch):
        # In practice: train network, return validation accuracy
        # Here: random fitness for demonstration
        return random.uniform(0.5, 1.0)
    
    # Create evolutionary search
    evolution = EvolutionarySearch(
        search_space=search_space,
        evaluate_fn=evaluate_fn,
        population_size=20,  # Small for demo
        sample_size=5,
        mutation_prob=1.0
    )
    
    # Run search
    best_arch, best_fitness = evolution.search(num_iterations=50)
    
    print()
    print(f"Best architecture found: {best_arch}")
    print(f"Best fitness: {best_fitness:.4f}")
    print()
    print("Algorithm:")
    print("  1. Sample tournament (5 architectures)")
    print("  2. Remove oldest from population")
    print("  3. Mutate best architecture")
    print("  4. Evaluate child, add to population")
    print("  5. Repeat")
    print()


def demo_performance_predictor():
    """Demonstrate performance predictor."""
    print("="*70)
    print("Performance Predictor Demo")
    print("="*70)
    
    # Create predictor
    num_ops = len(OPS)
    predictor = PerformancePredictor(num_ops=num_ops, embed_dim=32, hidden_dim=64)
    
    # Sample architectures
    batch_size = 4
    seq_len = 12
    archs = torch.randint(0, num_ops, (batch_size, seq_len))
    
    # Predict
    predictor.eval()
    with torch.no_grad():
        predictions = predictor(archs)
    
    print(f"Input: {batch_size} architectures Γ— {seq_len} operations")
    print(f"Predictions: {predictions.squeeze().tolist()}")
    print()
    
    # Count parameters
    params = sum(p.numel() for p in predictor.parameters())
    print(f"Predictor parameters: {params:,}")
    print()
    
    print("Training procedure:")
    print("  1. Collect dataset: (architecture, accuracy) pairs")
    print("  2. Train predictor via regression (MSE loss)")
    print("  3. Use predictor to evaluate new architectures (fast!)")
    print()
    print("Benefit: Evaluate architecture in O(1) instead of hours of training")
    print()


def print_method_comparison():
    """Print comparison of NAS methods."""
    print("="*70)
    print("NAS Method Comparison")
    print("="*70)
    print()
    
    comparison = """
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Method          β”‚ Search Cost  β”‚ Performance  β”‚ Difficulty  β”‚ Best For     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Random Search   β”‚ Low          β”‚ Baseline     β”‚ Easy        β”‚ Baseline     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ RL (NAS)        β”‚ Very High    β”‚ High         β”‚ Hard        β”‚ Best quality β”‚
β”‚                 β”‚ (22k GPU-day)β”‚              β”‚             β”‚ (unlimited   β”‚
β”‚                 β”‚              β”‚              β”‚             β”‚ compute)     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Evolution       β”‚ High         β”‚ High         β”‚ Medium      β”‚ Large search β”‚
β”‚                 β”‚ (450 GPU-day)β”‚              β”‚             β”‚ spaces       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ DARTS           β”‚ Low          β”‚ Good         β”‚ Medium      β”‚ Fast search  β”‚
β”‚                 β”‚ (4 GPU-day)  β”‚              β”‚             β”‚ (research)   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ ENAS            β”‚ Very Low     β”‚ Good         β”‚ Medium      β”‚ Efficient    β”‚
β”‚                 β”‚ (0.5 GPU-day)β”‚              β”‚             β”‚ search       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Once-For-All    β”‚ Medium       β”‚ High         β”‚ Hard        β”‚ Deployment   β”‚
β”‚                 β”‚ (50 GPU-day) β”‚              β”‚             β”‚ flexibility  β”‚
β”‚                 β”‚ (train once) β”‚              β”‚             β”‚              β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

**Decision Guide:**

1. **Use Random/Evolutionary if:**
   - Want simple baseline
   - Have moderate compute budget
   - Search space not differentiable

2. **Use DARTS if:**
   - Want fast search (4 GPU-days)
   - Have differentiable search space
   - Doing research (quick iterations)

3. **Use RL if:**
   - Have large compute budget
   - Want best possible architecture
   - Search space is discrete/complex

4. **Use Once-For-All if:**
   - Deploy to many devices
   - Want flexibility (vary latency/accuracy)
   - Can afford one-time training cost

**Performance (ImageNet top-1):**

- Random baseline: ~72%
- NASNet (RL): 82.7%
- AmoebaNet (Evolution): 83.9%
- DARTS: 73.3%
- EfficientNet (NAS): 84.4%
"""
    
    print(comparison)
    print()


# ============================================================================
# Run Demonstrations
# ============================================================================

if __name__ == "__main__":
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    
    demo_darts()
    demo_rl_controller()
    demo_evolutionary_search()
    demo_performance_predictor()
    print_method_comparison()
    
    print("="*70)
    print("Neural Architecture Search Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    print("  β€’ DARTS: Differentiable search via continuous relaxation (4 GPU-days)")
    print("  β€’ RL Controller: REINFORCE for architecture sampling (expensive)")
    print("  β€’ Evolution: Mutation + selection (regularized evolution)")
    print("  β€’ Predictor: Learn to estimate performance (avoid training)")
    print()
    print("Key insight: NAS automates architecture design")
    print("Trade-off: Search cost vs. architecture quality")
    print("Modern trend: Efficient search (weight sharing, one-shot)")
    print()

Advanced Neural Architecture Search: Mathematical Foundations and Modern MethodsΒΆ

1. Introduction to Neural Architecture Search (NAS)ΒΆ

Neural Architecture Search automates the design of neural network architectures, replacing manual engineering with algorithmic optimization. The goal is to find the best architecture \(\mathcal{A}^*\) from a search space \(\mathcal{S}\) that maximizes performance on a validation set.

1.1 NAS Problem FormulationΒΆ

\[\mathcal{A}^* = \arg\max_{\mathcal{A} \in \mathcal{S}} \text{Accuracy}(\mathcal{A}, \mathcal{D}_{\text{val}})\]

subject to:

  • \(\text{Params}(\mathcal{A}) \leq P_{\max}\) (parameter budget)

  • \(\text{Latency}(\mathcal{A}) \leq L_{\max}\) (inference time budget)

  • \(\text{FLOPs}(\mathcal{A}) \leq F_{\max}\) (computational budget)

1.2 Three Components of NASΒΆ

  1. Search Space (\(\mathcal{S}\)): Set of possible architectures

  2. Search Strategy: Algorithm to explore the search space

  3. Performance Estimation: Method to evaluate candidate architectures

1.3 Historical EvolutionΒΆ

Early Era (2016-2017):

  • Reinforcement Learning NAS (Zoph & Le, 2017)

  • Cost: 1800 GPU days for CIFAR-10

  • Search space: Entire network structure

Efficiency Era (2018-2019):

  • ENAS: Weight sharing (1 GPU day)

  • DARTS: Differentiable architecture search

  • ProxylessNAS: Direct hardware optimization

Modern Era (2020-2024):

  • Zero-shot NAS: No training needed

  • Once-for-all networks: Train once, search many times

  • Hardware-aware NAS: Multi-objective optimization

  • Neural architecture transfer: Cross-task learning

2. Search Space DesignΒΆ

2.1 Macro Search SpaceΒΆ

Search for the entire network structure.

Chain-structured: $\(\text{Network} = f_L \circ f_{L-1} \circ ... \circ f_1(x)\)$

Each layer \(f_i\) can be:

  • Convolutional layer (with varying kernel sizes, channels)

  • Pooling layer (max, average)

  • Skip connections

  • Batch normalization, activation

Cell-structured (more efficient):

  • Define reusable cells (normal cell, reduction cell)

  • Stack cells to form network

  • Reduces search space from \(O(L^k)\) to \(O(C^k)\) where \(C \ll L\)

2.2 Micro Search Space (Cell-Based)ΒΆ

Normal Cell: Maintains spatial resolution Reduction Cell: Downsamples feature maps

Each cell is a DAG (Directed Acyclic Graph):

  • \(N\) nodes (intermediate representations)

  • Each node computes: \(h^{(i)} = \sum_{j < i} o^{(i,j)}(h^{(j)})\)

  • Operations \(o^{(i,j)}\) selected from operation set \(\mathcal{O}\)

Operation Set \(\mathcal{O}\):

  • Identity (skip connection)

  • \(3 \times 3\) separable convolution

  • \(5 \times 5\) separable convolution

  • \(3 \times 3\) average pooling

  • \(3 \times 3\) max pooling

  • \(3 \times 3\) dilated convolution

  • Zero (no connection)

2.3 Search Space EncodingΒΆ

Discrete encoding: $\(\mathcal{A} = (o_1, o_2, ..., o_K) \text{ where } o_i \in \mathcal{O}\)$

Continuous encoding (for DARTS): $\(\mathcal{A} = (\alpha_{ij}^{(k)}) \text{ where } \alpha_{ij}^{(k)} \in \mathbb{R}\)$

Graph encoding:

  • Adjacency matrix \(A \in \{0,1\}^{N \times N}\)

  • Operation matrix \(O \in \mathcal{O}^{N \times N}\)

2.4 Search Space SizeΒΆ

For a cell with \(N\) nodes and \(|\mathcal{O}|\) operations:

Without constraints: $\(|\mathcal{S}| = |\mathcal{O}|^{\binom{N}{2}}\)$

Example: \(N=7\) nodes, \(|\mathcal{O}|=8\) operations $\(|\mathcal{S}| = 8^{21} \approx 10^{19}\)$

With constraints (each node has 2 inputs): $\(|\mathcal{S}| = \prod_{i=2}^{N} \binom{i-1}{2} \cdot |\mathcal{O}|^2 \approx 10^{14}\)$

3. Search StrategiesΒΆ

3.1 Reinforcement Learning (RL) Based NASΒΆ

NAS with RL (Zoph & Le, 2017):

Controller: RNN that generates architecture descriptions

Training:

  1. Sample architecture \(\mathcal{A} \sim p_\theta\) from controller with parameters \(\theta\)

  2. Train \(\mathcal{A}\) on training set, evaluate on validation set β†’ accuracy \(R\)

  3. Update controller: \(\nabla_\theta J(\theta) = \mathbb{E}_{\mathcal{A} \sim p_\theta}[R \cdot \nabla_\theta \log p_\theta(\mathcal{A})]\)

REINFORCE gradient: $\(\nabla_\theta J(\theta) \approx \frac{1}{B} \sum_{i=1}^B (R_i - b) \nabla_\theta \log p_\theta(\mathcal{A}_i)\)$

where \(b\) is baseline (moving average of rewards).

Challenges:

  • High variance in gradients

  • Requires training thousands of architectures

  • Extremely expensive (1800 GPU days)

3.2 Evolutionary AlgorithmsΒΆ

Evolution process:

  1. Initialize population \(P = \{\mathcal{A}_1, ..., \mathcal{A}_N\}\)

  2. Evaluate fitness (validation accuracy) for each architecture

  3. Select top-\(k\) architectures (parents)

  4. Generate offspring via mutation/crossover

  5. Repeat until convergence

Mutation operators:

  • Add/remove layer

  • Change operation type

  • Modify hyperparameters (kernel size, channels)

  • Add/remove skip connections

Crossover operators:

  • Combine layers from two parent architectures

  • Cell-level crossover (swap cells between parents)

Advantages:

  • Simple to implement

  • Naturally handles multi-objective optimization

  • No gradient required

Disadvantages:

  • Slow convergence

  • Still requires evaluating many architectures

3.3 Gradient-Based (DARTS)ΒΆ

Key idea: Relax discrete search space to continuous, optimize with gradient descent.

Continuous relaxation: $\(\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in \mathcal{O}} \exp(\alpha_{o'}^{(i,j)})} o(x)\)$

where \(\alpha = \{\alpha_o^{(i,j)}\}\) are architecture parameters.

Bilevel optimization: $\(\begin{align} \min_\alpha \quad & \mathcal{L}_{\text{val}}(w^*(\alpha), \alpha) \\ \text{s.t.} \quad & w^*(\alpha) = \arg\min_w \mathcal{L}_{\text{train}}(w, \alpha) \end{align}\)$

Approximate solution:

  1. Update weights: \(w \leftarrow w - \xi \nabla_w \mathcal{L}_{\text{train}}(w, \alpha)\)

  2. Update architecture: \(\alpha \leftarrow \alpha - \eta \nabla_\alpha \mathcal{L}_{\text{val}}(w, \alpha)\)

Discretization (after search): $\(o^{(i,j)} = \arg\max_{o \in \mathcal{O}} \alpha_o^{(i,j)}\)$

Advantages:

  • Fast (1-2 GPU days)

  • Differentiable end-to-end

  • Naturally handles mixed operations

Challenges:

  • Gap between continuous and discrete architectures

  • Prone to collapse (all weights to skip connections)

  • Biased towards parameter-free operations

3.4 One-Shot NASΒΆ

Weight sharing: Train a supernet containing all possible architectures, then search without retraining.

Supernet: $\(\mathcal{A}_{\text{super}} = \bigcup_{\mathcal{A} \in \mathcal{S}} \mathcal{A}\)$

Training:

  1. Sample architecture \(\mathcal{A} \sim \mathcal{S}\) in each iteration

  2. Update shared weights corresponding to \(\mathcal{A}\)

  3. After training, search by evaluating sub-networks

ENAS (Efficient NAS):

  • Controller samples architectures

  • All architectures share weights

  • Reduces search from 1800 to 0.5 GPU days

RandomNAS:

  • Sample random architectures from supernet

  • Evaluate and pick best

  • No controller needed

3.5 Zero-Shot NASΒΆ

Predict architecture performance without training.

Proxies for performance:

  1. Training speed: \(\partial \mathcal{L} / \partial \text{iter}\) (how fast loss decreases)

  2. Gradient norm: \(\|\nabla_w \mathcal{L}\|\) at initialization

  3. NTK condition number: Neural Tangent Kernel eigenvalue ratio

  4. Synaptic diversity: Variance in weight values

NASWOT (NAS Without Training): $\(\text{Score}(\mathcal{A}) = \log \det \text{Cov}(\text{Jacobian}(\mathcal{A}))\)$

Architectures with higher Jacobian covariance determinant train better.

Advantages:

  • Extremely fast (seconds to minutes)

  • No GPU required for evaluation

  • Can pre-screen thousands of architectures

Limitations:

  • Proxies may not correlate perfectly with final accuracy

  • Still need to train best candidates

4. Performance Estimation StrategiesΒΆ

4.1 Full Training (Baseline)ΒΆ

Train each candidate architecture to convergence on full dataset.

Cost: \(T_{\text{train}} \times N_{\text{candidates}}\)

Example: 50 epochs Γ— 1000 candidates = 50,000 GPU-hours

4.2 Early StoppingΒΆ

Train for fewer epochs (e.g., 5-10) and estimate final performance.

Learning curve extrapolation: $\(\text{Acc}(t) = a - b \cdot t^{-\gamma}\)$

Fit curve from early epochs, predict final accuracy.

Speedup: 5-10Γ— faster

4.3 Weight InheritanceΒΆ

Initialize new architectures with weights from similar architectures.

Network morphism: Transform architecture while preserving function

  • Deepen: \(y = \text{ReLU}(Wx) \to y = \text{ReLU}(W_2 \text{ReLU}(W_1 x))\) with \(W_2 W_1 = W\)

  • Widen: Duplicate neurons with weight splitting

4.4 HypernetworksΒΆ

Learn a model \(h_\phi\) that predicts weights for any architecture: $\(w_\mathcal{A} = h_\phi(\mathcal{A})\)$

Training:

  1. Sample architectures and train them

  2. Learn hypernetwork to map architecture β†’ weights

  3. Use hypernetwork to initialize new architectures

4.5 Performance PredictorsΒΆ

Train a regression model \(f\) to predict accuracy: $\(\hat{\text{Acc}} = f(\text{encode}(\mathcal{A}))\)$

Encodings:

  • Graph neural network on architecture DAG

  • Sequence encoding for layer configurations

  • Path encoding (from input to output)

Training:

  1. Evaluate \(N\) random architectures

  2. Train predictor on \(\{(\mathcal{A}_i, \text{Acc}_i)\}\)

  3. Use predictor to screen new candidates

5. Multi-Objective NASΒΆ

Optimize multiple objectives simultaneously: $\(\max_{\mathcal{A} \in \mathcal{S}} \{f_1(\mathcal{A}), f_2(\mathcal{A}), ..., f_k(\mathcal{A})\}\)$

where:

  • \(f_1\): Accuracy

  • \(f_2\): \(-\)Latency (negative for maximization)

  • \(f_3\): \(-\)Parameters

  • \(f_4\): \(-\)Energy consumption

5.1 Pareto OptimalityΒΆ

Architecture \(\mathcal{A}\) is Pareto optimal if no other architecture dominates it in all objectives.

Pareto front: Set of all Pareto optimal solutions.

5.2 Weighted SumΒΆ

\[\max_{\mathcal{A}} \sum_{i=1}^k w_i f_i(\mathcal{A})\]

where \(\sum_i w_i = 1\), \(w_i \geq 0\).

Limitation: Cannot find all Pareto optimal solutions (only convex parts).

5.3 Evolutionary Multi-ObjectiveΒΆ

NSGA-II (Non-dominated Sorting Genetic Algorithm):

  1. Rank population by domination (non-dominated = rank 1)

  2. Compute crowding distance (diversity in objective space)

  3. Select based on rank and crowding distance

  4. Generate offspring, repeat

5.4 Hardware-Aware NASΒΆ

Latency modeling: $\(\text{Latency}(\mathcal{A}) = \sum_{\text{layer } i} f_{\text{hw}}(\text{layer}_i)\)$

where \(f_{\text{hw}}\) is measured or predicted latency on target hardware.

Direct optimization:

  • ProxylessNAS: Latency as regularization term

  • FBNet: Latency lookup table for each operation

  • Once-for-All: Elastic networks for different constraints

6. Advanced NAS Methods (2020-2024)ΒΆ

6.1 Once-for-All Networks (OFA)ΒΆ

Key idea: Train a single network that supports diverse architectures.

Elastic dimensions:

  • Depth: 2-5 blocks per stage

  • Width: 0.5Γ—, 0.75Γ—, 1.0Γ— channels

  • Kernel size: 3Γ—3, 5Γ—5, 7Γ—7

  • Resolution: 128, 160, 192, 224 pixels

Progressive shrinking:

  1. Train largest network (full capacity)

  2. Progressively support smaller elastic dimensions

  3. Each sub-network shares weights with supernet

Deployment:

  • Given constraint (latency < 20ms)

  • Evolutionary search in trained OFA

  • Extract sub-network without retraining

Benefits:

  • Train once (1200 GPU-hours)

  • Search many times (40 GPU-hours each)

  • Amortized cost very low

6.2 Neural Architecture TransferΒΆ

Pre-train on source task, transfer to target task.

Weight sharing across tasks:

  • Learn task-agnostic features in supernet

  • Fine-tune for specific task

  • Reduces search cost on new tasks

Meta-learning for NAS: $\(\theta^* = \arg\min_\theta \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})} [\mathcal{L}_{\text{NAS}}(\theta, \mathcal{T})]\)$

Train NAS controller on distribution of tasks.

6.3 AutoML FrameworksΒΆ

NAS-Bench: Benchmark datasets with pre-computed architectures

  • NAS-Bench-101: 423k architectures on CIFAR-10

  • NAS-Bench-201: 6k+ architectures, multiple datasets

  • Enables reproducible NAS research

NNI (Neural Network Intelligence):

  • Microsoft framework for AutoML

  • Supports RL, evolution, DARTS, ENAS

  • Built-in model compression

Auto-Keras:

  • High-level API for NAS

  • Bayesian optimization

  • Automatically handles data preprocessing

6.4 Transformer NASΒΆ

Search for optimal Transformer architectures.

Search dimensions:

  • Number of layers

  • Hidden dimension

  • Number of attention heads

  • FFN expansion ratio

  • Attention pattern (full, sparse, local)

AutoFormer (Chen et al., 2021):

  • Weight-sharing supernet for Transformers

  • Evolutionary search

  • Found architectures competitive with BERT

Primer (So et al., 2021):

  • Evolution-based search for Transformer primitives

  • Discovered squared ReLU activation

  • Modified attention mechanisms

6.5 Vision Transformer NASΒΆ

AutoViT:

  • Search for ViT patch size, embedding dimension, depth

  • Hardware-aware optimization

  • 2-3Γ— faster than DeiT with same accuracy

BossNAS:

  • Billion-scale search space for vision

  • Mixed CNN and Transformer operations

  • State-of-the-art on ImageNet

7. Theoretical FoundationsΒΆ

7.1 Neural Tangent Kernel (NTK) for NASΒΆ

NTK characterizes network behavior during training: $\(\Theta(x, x') = \langle \nabla_w f(x, w), \nabla_w f(x', w) \rangle\)$

At initialization, \(\Theta\) determines:

  • Convergence speed

  • Final performance

  • Generalization

NAS application: Architectures with better-conditioned NTK train faster and generalize better.

7.2 Loss Landscape AnalysisΒΆ

Sharpness: Eigenvalues of Hessian \(\nabla^2 \mathcal{L}(w)\)

Sharp minima: Large eigenvalues β†’ poor generalization
Flat minima: Small eigenvalues β†’ good generalization

NAS metric: Prefer architectures with flatter loss landscapes.

7.3 Expressivity vs TrainabilityΒΆ

Expressivity: Can the architecture represent the target function?

Universal approximation: Most architectures are expressive enough for practical tasks.

Trainability: Can we efficiently find good parameters via gradient descent?

Key insight: NAS should focus on trainability, not just expressivity.

7.4 Architecture-Training Co-DesignΒΆ

Traditional: Architecture search β†’ Train
Modern: Interleave architecture search and training

Benefits:

  • Architecture adapts to current training stage

  • Better exploration of architecture-weight space

  • Avoids local optima in joint space

8. NAS for Specialized DomainsΒΆ

8.1 NAS for Object DetectionΒΆ

Search space:

  • Backbone architecture (feature extractor)

  • Feature pyramid network (FPN) connections

  • Detection head architecture

DetNAS: NAS for object detection

  • Search backbone and FPN jointly

  • Multi-scale architecture

  • COCO mAP improvements

SpineNet: Learned scale-permuted networks

  • Cross-scale connections learned via NAS

  • Better than hand-designed FPNs

8.2 NAS for Semantic SegmentationΒΆ

Auto-DeepLab:

  • Search cell structure and network-level connections

  • Hierarchical search space

  • State-of-the-art on Cityscapes

Search dimensions:

  • Encoder-decoder architecture

  • Skip connections across scales

  • Atrous convolution rates

8.3 NAS for Speech RecognitionΒΆ

Evolved Transformer for ASR:

  • Evolution-based search

  • Discovered architectural improvements over baseline Transformer

  • Lower WER on LibriSpeech

Neural Architecture Search for LSTMs:

  • Search for optimal LSTM cell structure

  • Found better alternatives to standard LSTM gates

8.4 NAS for Recommender SystemsΒΆ

AutoCTR:

  • Search feature interaction architectures

  • Embedding dimensions per feature

  • Improved CTR prediction

AMER:

  • Automated model search for recommendations

  • User-item interaction modeling

  • Better than hand-crafted models

9. Practical ConsiderationsΒΆ

9.1 Search Cost AnalysisΒΆ

Total cost: $\(\text{Cost} = N_{\text{arch}} \times T_{\text{train}} \times C_{\text{data}}\)$

where:

  • \(N_{\text{arch}}\): Number of architectures evaluated

  • \(T_{\text{train}}\): Training time per architecture

  • \(C_{\text{data}}\): Data cost (labels, compute)

Reduction strategies:

  1. Early stopping: Reduce \(T_{\text{train}}\) by 5-10Γ—

  2. Weight sharing: Reduce \(N_{\text{arch}} \times T_{\text{train}}\) to \(T_{\text{supernet}}\)

  3. Zero-shot: Reduce \(T_{\text{train}}\) to near zero

  4. Proxy datasets: Reduce \(C_{\text{data}}\) (CIFAR-10 instead of ImageNet)

9.2 ReproducibilityΒΆ

Key factors:

  • Random seed for architecture sampling

  • Training hyperparameters (learning rate, batch size)

  • Hardware (GPU type affects measured latency)

  • Data augmentation policy

  • Early stopping criteria

Best practices:

  • Report search algorithm details

  • Provide code and search logs

  • Multiple runs with different seeds

  • Report mean and variance of top architectures

9.3 Fairness in ComparisonΒΆ

Common pitfalls:

  • Different training budgets (epochs, data)

  • Different search spaces

  • Cherry-picking best result from multiple runs

  • Post-search hyperparameter tuning

Fair comparison:

  • Same search space

  • Same total compute budget

  • Report average performance, not just best

  • Separate search and evaluation sets

9.4 When to Use NASΒΆ

Use NAS when: βœ“ Have sufficient compute budget (>100 GPU-hours)
βœ“ Target hardware differs from standard (edge devices, mobile)
βœ“ Task-specific architecture requirements
βœ“ Need to beat hand-designed baselines
βœ“ Can amortize search cost across many deployments

Avoid NAS when: βœ— Limited compute budget (<10 GPU-hours)
βœ— Standard architectures work well (ImageNet classification)
βœ— Small dataset (overfitting risk in search)
βœ— One-time deployment (search cost not amortized)

10. Software Tools and FrameworksΒΆ

10.1 AutoGluonΒΆ

from autogluon.vision import ImagePredictor
predictor = ImagePredictor()
predictor.fit('train/', time_limit=3600)  # 1 hour

Features:

  • Automatic architecture search

  • Hyperparameter optimization

  • Ensemble of models

  • Easy deployment

10.2 NNI (Neural Network Intelligence)ΒΆ

Supported algorithms:

  • DARTS, ENAS, SPOS, ProxylessNAS

  • Hyperparameter tuning (TPE, BOHB)

  • Model compression (pruning, quantization)

Workflow:

  1. Define search space (JSON)

  2. Implement training code with NNI API

  3. Configure search algorithm

  4. Launch experiment, monitor progress

10.3 NAS-Bench IntegrationΒΆ

Pre-computed results for fast experimentation:

from nas_bench_201 import NASBench201API as API
api = API('NAS-Bench-201-v1_1-096897.pth')
arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|'
info = api.get_more_info(arch, 'cifar10', hp='200')
print(f"Accuracy: {info['test-accuracy']}")

No training neededβ€”instant results!

10.4 Custom NAS ImplementationΒΆ

Minimal DARTS:

# Architecture parameters
alpha = nn.Parameter(torch.randn(n_edges, n_ops))

# Mixed operation
def mixed_op(x, alpha):
    return sum(w * op(x) for w, op in zip(F.softmax(alpha), ops))

# Bilevel optimization
for epoch in range(epochs):
    # Update weights
    optimizer_w.zero_grad()
    loss_train = loss(model(x_train, alpha), y_train)
    loss_train.backward()
    optimizer_w.step()
    
    # Update architecture
    optimizer_alpha.zero_grad()
    loss_val = loss(model(x_val, alpha), y_val)
    loss_val.backward()
    optimizer_alpha.step()

12. Mathematical Deep DiveΒΆ

12.1 DARTS Gradient ApproximationΒΆ

Bilevel optimization: $\(\nabla_\alpha \mathcal{L}_{\text{val}}(w^*(\alpha), \alpha)\)$

Chain rule: $\(= \nabla_\alpha \mathcal{L}_{\text{val}} + \nabla_\alpha w^* \cdot \nabla_w \mathcal{L}_{\text{val}}\)$

Implicit function theorem: $\(\nabla_\alpha w^* = -[\nabla^2_{ww} \mathcal{L}_{\text{train}}]^{-1} \nabla^2_{w\alpha} \mathcal{L}_{\text{train}}\)$

Approximation (avoid expensive Hessian inverse): $\(\nabla_\alpha w^* \approx -\frac{\nabla_w \mathcal{L}_{\text{val}} \nabla^2_{w\alpha} \mathcal{L}_{\text{train}}}{\|\nabla_w \mathcal{L}_{\text{train}}\|^2}\)$

12.2 Expected Improvement for Bayesian NASΒΆ

Acquisition function: $\(\alpha_{\text{EI}}(\mathcal{A}) = \mathbb{E}[\max(f(\mathcal{A}) - f(\mathcal{A}_{\text{best}}), 0)]\)$

where \(f\) is surrogate model (Gaussian Process).

Closed form (for GP): $\(\alpha_{\text{EI}}(\mathcal{A}) = (\mu(\mathcal{A}) - f_{\text{best}}) \Phi(Z) + \sigma(\mathcal{A}) \phi(Z)\)$

where:

  • \(Z = (\mu(\mathcal{A}) - f_{\text{best}}) / \sigma(\mathcal{A})\)

  • \(\Phi\): CDF of standard normal

  • \(\phi\): PDF of standard normal

12.3 Evolutionary Algorithm ConvergenceΒΆ

Schema theorem (genetic algorithms):

Expected number of schema \(H\) in next generation: $\(m(H, t+1) \geq m(H, t) \frac{f(H)}{\bar{f}} \left[1 - p_c \frac{d(H)}{l-1} - p_m o(H)\right]\)$

where:

  • \(f(H)\): Average fitness of schema

  • \(\bar{f}\): Population average fitness

  • \(p_c\): Crossover probability

  • \(d(H)\): Defining length of schema

  • \(p_m\): Mutation probability

  • \(o(H)\): Order of schema (number of fixed positions)

Insight: Short, low-order, high-fitness schemas grow exponentially.

13. Case StudiesΒΆ

13.1 EfficientNet (NAS + Scaling)ΒΆ

Compound scaling: $\(\text{depth: } d = \alpha^\phi, \quad \text{width: } w = \beta^\phi, \quad \text{resolution: } r = \gamma^\phi\)$

subject to \(\alpha \cdot \beta^2 \cdot \gamma^2 \approx 2\)

NAS for base model:

  • Search on CIFAR-10 (proxy task)

  • Transfer to ImageNet

  • Scale up with compound scaling

Results:

  • EfficientNet-B0: 77.1% top-1, 5.3M params

  • EfficientNet-B7: 84.3% top-1, 66M params

  • State-of-the-art accuracy-efficiency trade-off

13.2 GPT-4 Architecture (Rumored)ΒΆ

Mixture of Experts via NAS:

  • Each layer routes tokens to different experts

  • Experts specialized for different input types

  • Routing learned during pre-training

Search dimensions:

  • Number of experts per layer

  • Expert capacity

  • Routing algorithm

  • Load balancing strategy

Benefits:

  • 1T+ parameters with manageable compute

  • Better quality than dense models

  • Efficient inference via sparse activation

13.3 AlphaCode ArchitectureΒΆ

Multi-stage NAS:

  1. Search encoder architecture (code understanding)

  2. Search decoder architecture (code generation)

  3. Search sampling strategies

Novel components discovered:

  • Specialized attention for code syntax

  • Multi-scale code embeddings

  • Hybrid convolution-transformer blocks

14. Summary and Best PracticesΒΆ

Key Takeaways:ΒΆ

  1. Search space design is crucial: Constraining search space reduces cost and improves results

  2. Weight sharing enables efficiency: Train once, search many times

  3. Multi-objective matters: Accuracy alone insufficient for deployment

  4. Zero-shot methods emerging: Performance prediction without training

  5. Hardware-aware is essential: Latency/energy often more important than FLOPs

When to Use Which Method:ΒΆ

Scenario

Recommended Method

Justification

First-time NAS

DARTS or ENAS

Good balance of speed and quality

Limited budget (<1 GPU-day)

Zero-shot NAS

No training needed

Mobile deployment

Once-for-All or ProxylessNAS

Hardware-aware optimization

Novel task

Evolutionary NAS

Flexible, no gradient required

Production at scale

Once-for-All

Train once, deploy many configs

Research/benchmarking

NAS-Bench

Reproducible, fast experiments

Common Pitfalls:ΒΆ

❌ Overfitting to search space: Good NAS result doesn’t mean optimal architecture
❌ Ignoring hardware: FLOPs β‰  latency
❌ Unfair comparisons: Different training budgets invalidate results
❌ Not checking transferability: Search on CIFAR, deploy on ImageNet?
❌ Forgetting deployment constraints: Memory, power, latency requirements

Future Research Directions:ΒΆ

  1. Efficient large-scale NAS: Search for billion-parameter models

  2. Task-agnostic architectures: Find architectures that work across tasks

  3. Architecture-data co-design: Jointly optimize data and architecture

  4. Continual architecture search: Adapt architecture as data distribution shifts

  5. Certified robustness via NAS: Search for adversarially robust architectures

Conclusion: NAS has matured from expensive academic curiosity to practical tool. Modern methods (OFA, zero-shot) make NAS accessible for real-world deployment. The future lies in scaling NAS to foundation models and making it fully automated.

"""
Advanced Neural Architecture Search - Production Implementation
Comprehensive implementations of modern NAS methods

Methods Implemented:
1. DARTS (Differentiable Architecture Search)
2. ENAS (Efficient Neural Architecture Search with weight sharing)
3. RandomNAS (baseline with weight-sharing supernet)
4. Evolutionary NAS (genetic algorithm)
5. Zero-Shot NAS (NASWOT - NAS Without Training)
6. Performance Predictors (GNN-based)
7. Once-for-All Networks (elastic supernet)
8. Multi-Objective NAS (NSGA-II)

Features:
- Cell-based search space (NAS-Bench-201 compatible)
- Hardware-aware latency modeling
- Multi-objective optimization (accuracy, latency, parameters)
- Visualization and analysis tools

Author: Advanced Deep Learning Course
Date: 2024
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional, Callable
import numpy as np
import random
from collections import defaultdict
import time


# ============================================================================
# 1. Search Space Definition (Cell-Based)
# ============================================================================

class SearchSpace:
    """
    Cell-based search space similar to NAS-Bench-201.
    
    Each cell is a DAG with N nodes, each node connected to previous nodes
    via operations from a predefined operation set.
    """
    
    PRIMITIVES = [
        'none',           # No connection
        'skip_connect',   # Identity
        'conv_1x1',       # 1Γ—1 convolution
        'conv_3x3',       # 3Γ—3 convolution
        'avg_pool_3x3',   # 3Γ—3 average pooling
        'max_pool_3x3',   # 3Γ—3 max pooling
    ]
    
    def __init__(self, n_nodes: int = 4):
        """
        Args:
            n_nodes: Number of intermediate nodes in cell
        """
        self.n_nodes = n_nodes
        self.n_ops = len(self.PRIMITIVES)
        
        # Compute search space size
        # Each node i connects to 2 predecessors from {0, 1, ..., i-1}
        # Each connection has n_ops choices
        self.n_edges = sum(range(2, n_nodes + 2))  # (2+3+...+n_nodes+1)
        self.search_space_size = self.n_ops ** self.n_edges
        
        print(f"Search Space: {self.n_nodes} nodes, {self.n_edges} edges")
        print(f"Space size: {self.search_space_size:.2e} architectures")
    
    def random_architecture(self) -> List[int]:
        """Sample random architecture (list of operation indices)."""
        return [random.randint(0, self.n_ops - 1) for _ in range(self.n_edges)]
    
    def encode_architecture(self, arch: List[int]) -> str:
        """Encode architecture as string."""
        ops = [self.PRIMITIVES[i] for i in arch]
        return '|'.join(ops)
    
    def decode_architecture(self, arch_str: str) -> List[int]:
        """Decode architecture string to indices."""
        ops = arch_str.split('|')
        return [self.PRIMITIVES.index(op) for op in ops]
    
    def mutate(self, arch: List[int], n_mutations: int = 1) -> List[int]:
        """Randomly mutate architecture."""
        arch = arch.copy()
        for _ in range(n_mutations):
            idx = random.randint(0, len(arch) - 1)
            arch[idx] = random.randint(0, self.n_ops - 1)
        return arch
    
    def crossover(self, arch1: List[int], arch2: List[int]) -> Tuple[List[int], List[int]]:
        """Single-point crossover between two architectures."""
        point = random.randint(1, len(arch1) - 1)
        child1 = arch1[:point] + arch2[point:]
        child2 = arch2[:point] + arch1[point:]
        return child1, child2


# ============================================================================
# 2. Operations (Building Blocks)
# ============================================================================

class ReLUConvBN(nn.Module):
    """ReLU + Conv + BatchNorm"""
    def __init__(self, C_in: int, C_out: int, kernel_size: int, stride: int = 1, padding: int = 0):
        super().__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(C_out)
        )
    
    def forward(self, x):
        return self.op(x)


class OPS:
    """Operation factory"""
    
    @staticmethod
    def get_op(op_name: str, C: int, stride: int = 1) -> nn.Module:
        """
        Get operation by name.
        
        Args:
            op_name: Operation name from PRIMITIVES
            C: Number of channels
            stride: Stride (1 for normal cell, 2 for reduction cell)
        """
        if op_name == 'none':
            return Zero(stride)
        elif op_name == 'skip_connect':
            if stride == 1:
                return nn.Identity()
            else:
                return FactorizedReduce(C, C)
        elif op_name == 'conv_1x1':
            return ReLUConvBN(C, C, 1, stride=stride, padding=0)
        elif op_name == 'conv_3x3':
            return ReLUConvBN(C, C, 3, stride=stride, padding=1)
        elif op_name == 'avg_pool_3x3':
            if stride == 1:
                return nn.AvgPool2d(3, stride=1, padding=1)
            else:
                return nn.AvgPool2d(3, stride=stride, padding=1)
        elif op_name == 'max_pool_3x3':
            if stride == 1:
                return nn.MaxPool2d(3, stride=1, padding=1)
            else:
                return nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            raise ValueError(f"Unknown operation: {op_name}")


class Zero(nn.Module):
    """Zero operation (no connection)"""
    def __init__(self, stride):
        super().__init__()
        self.stride = stride
    
    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        return x[:, :, ::self.stride, ::self.stride].mul(0.)


class FactorizedReduce(nn.Module):
    """Factorized reduction for stride=2 skip connections"""
    def __init__(self, C_in: int, C_out: int):
        super().__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out)
    
    def forward(self, x):
        x = self.relu(x)
        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out


# ============================================================================
# 3. Cell and Network Architecture
# ============================================================================

class Cell(nn.Module):
    """
    Cell module (DAG of operations).
    
    Each node computes: h_i = sum_{j < i} op_{i,j}(h_j)
    """
    def __init__(
        self,
        n_nodes: int,
        C: int,
        arch: List[int],
        primitives: List[str],
        stride: int = 1
    ):
        super().__init__()
        self.n_nodes = n_nodes
        self.primitives = primitives
        
        # Build edges (operations)
        self.edges = nn.ModuleDict()
        edge_idx = 0
        for i in range(2, n_nodes + 2):  # Nodes 2, 3, ..., n_nodes+1
            for j in range(i):  # Connect to all previous nodes
                op_name = primitives[arch[edge_idx]]
                op = OPS.get_op(op_name, C, stride if j < 2 else 1)
                self.edges[f'{i}_{j}'] = op
                edge_idx += 1
    
    def forward(self, x):
        """
        Args:
            x: Input tensor [batch, C, H, W]
        Returns:
            output: [batch, C, H', W']
        """
        # Initialize nodes: node 0 and 1 are input
        nodes = [x, x]
        
        # Compute each node
        for i in range(2, self.n_nodes + 2):
            node_sum = sum(self.edges[f'{i}_{j}'](nodes[j]) for j in range(i))
            nodes.append(node_sum)
        
        # Output: concatenate all intermediate nodes
        return sum(nodes[2:]) / len(nodes[2:])


class Network(nn.Module):
    """
    Complete network with stacked cells.
    
    Architecture: Stem β†’ [Normal Cell Γ— N] β†’ Reduction β†’ [Normal Cell Γ— N] β†’ Reduction β†’ [Normal Cell Γ— N] β†’ Classifier
    """
    def __init__(
        self,
        C: int,
        n_classes: int,
        n_layers: int,
        arch: List[int],
        n_nodes: int = 4,
        primitives: List[str] = SearchSpace.PRIMITIVES
    ):
        super().__init__()
        self.C = C
        self.n_layers = n_layers
        
        # Stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, 3, padding=1, bias=False),
            nn.BatchNorm2d(C)
        )
        
        # Cells
        self.cells = nn.ModuleList()
        C_curr = C
        reduction_layers = [n_layers // 3, 2 * n_layers // 3]
        
        for i in range(n_layers):
            if i in reduction_layers:
                C_curr *= 2
                cell = Cell(n_nodes, C_curr // 2, arch, primitives, stride=2)
            else:
                cell = Cell(n_nodes, C_curr, arch, primitives, stride=1)
            self.cells.append(cell)
        
        # Classifier
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_curr, n_classes)
    
    def forward(self, x):
        x = self.stem(x)
        for cell in self.cells:
            x = cell(x)
        x = self.global_pooling(x)
        x = x.view(x.size(0), -1)
        logits = self.classifier(x)
        return logits
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters())


# ============================================================================
# 4. DARTS (Differentiable Architecture Search)
# ============================================================================

class MixedOp(nn.Module):
    """Mixed operation with continuous relaxation."""
    def __init__(self, C: int, primitives: List[str], stride: int = 1):
        super().__init__()
        self.ops = nn.ModuleList([OPS.get_op(prim, C, stride) for prim in primitives])
    
    def forward(self, x, weights):
        """
        Args:
            x: Input tensor
            weights: Softmax weights [n_ops]
        Returns:
            Weighted sum of operations
        """
        return sum(w * op(x) for w, op in zip(weights, self.ops))


class DARTSCell(nn.Module):
    """DARTS cell with continuous architecture parameters."""
    def __init__(self, n_nodes: int, C: int, primitives: List[str], stride: int = 1):
        super().__init__()
        self.n_nodes = n_nodes
        self.primitives = primitives
        
        self.edges = nn.ModuleDict()
        for i in range(2, n_nodes + 2):
            for j in range(i):
                self.edges[f'{i}_{j}'] = MixedOp(C, primitives, stride if j < 2 else 1)
    
    def forward(self, x, alpha):
        """
        Args:
            x: Input
            alpha: Architecture parameters [n_edges, n_ops]
        Returns:
            Cell output
        """
        nodes = [x, x]
        edge_idx = 0
        
        for i in range(2, self.n_nodes + 2):
            edges_to_i = []
            for j in range(i):
                weights = F.softmax(alpha[edge_idx], dim=0)
                edges_to_i.append(self.edges[f'{i}_{j}'](nodes[j], weights))
                edge_idx += 1
            nodes.append(sum(edges_to_i))
        
        return sum(nodes[2:]) / len(nodes[2:])


class DARTSNetwork(nn.Module):
    """DARTS supernet with learnable architecture parameters."""
    def __init__(
        self,
        C: int,
        n_classes: int,
        n_layers: int,
        n_nodes: int = 4,
        primitives: List[str] = SearchSpace.PRIMITIVES
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_nodes = n_nodes
        self.primitives = primitives
        
        # Architecture parameters
        n_edges = sum(range(2, n_nodes + 2))
        self.alphas = nn.Parameter(torch.randn(n_edges, len(primitives)) * 1e-3)
        
        # Stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, 3, padding=1, bias=False),
            nn.BatchNorm2d(C)
        )
        
        # Cells
        self.cells = nn.ModuleList()
        C_curr = C
        reduction_layers = [n_layers // 3, 2 * n_layers // 3]
        
        for i in range(n_layers):
            if i in reduction_layers:
                C_curr *= 2
                cell = DARTSCell(n_nodes, C_curr // 2, primitives, stride=2)
            else:
                cell = DARTSCell(n_nodes, C_curr, primitives, stride=1)
            self.cells.append(cell)
        
        # Classifier
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_curr, n_classes)
    
    def forward(self, x):
        x = self.stem(x)
        for cell in self.cells:
            x = cell(x, self.alphas)
        x = self.global_pooling(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
    
    def discretize(self) -> List[int]:
        """Extract discrete architecture from continuous parameters."""
        return [self.primitives[i.item()] for i in self.alphas.argmax(dim=1)]
    
    def get_arch_indices(self) -> List[int]:
        """Get architecture as list of operation indices."""
        return self.alphas.argmax(dim=1).tolist()


def train_darts(
    model: DARTSNetwork,
    train_loader,
    val_loader,
    n_epochs: int = 50,
    lr_w: float = 0.025,
    lr_alpha: float = 3e-4,
    device: str = 'cuda'
):
    """
    Train DARTS model with bilevel optimization.
    
    Args:
        model: DARTS supernet
        train_loader: Training data
        val_loader: Validation data
        n_epochs: Number of epochs
        lr_w: Learning rate for weights
        lr_alpha: Learning rate for architecture parameters
    """
    model = model.to(device)
    
    # Optimizers
    optimizer_w = torch.optim.SGD(
        [p for n, p in model.named_parameters() if n != 'alphas'],
        lr=lr_w, momentum=0.9, weight_decay=3e-4
    )
    optimizer_alpha = torch.optim.Adam([model.alphas], lr=lr_alpha)
    
    scheduler_w = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_w, n_epochs)
    
    for epoch in range(n_epochs):
        model.train()
        
        for (x_train, y_train), (x_val, y_val) in zip(train_loader, val_loader):
            x_train, y_train = x_train.to(device), y_train.to(device)
            x_val, y_val = x_val.to(device), y_val.to(device)
            
            # Update weights
            optimizer_w.zero_grad()
            logits = model(x_train)
            loss_train = F.cross_entropy(logits, y_train)
            loss_train.backward()
            optimizer_w.step()
            
            # Update architecture
            optimizer_alpha.zero_grad()
            logits_val = model(x_val)
            loss_val = F.cross_entropy(logits_val, y_val)
            loss_val.backward()
            optimizer_alpha.step()
        
        scheduler_w.step()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {loss_train.item():.4f}, Val Loss: {loss_val.item():.4f}")
    
    return model.get_arch_indices()


# ============================================================================
# 5. Evolutionary NAS
# ============================================================================

class EvolutionaryNAS:
    """
    Evolutionary algorithm for NAS.
    
    Uses tournament selection, mutation, and crossover to evolve population.
    """
    def __init__(
        self,
        search_space: SearchSpace,
        population_size: int = 50,
        n_generations: int = 20,
        tournament_size: int = 5,
        mutation_prob: float = 0.1,
        crossover_prob: float = 0.5
    ):
        self.search_space = search_space
        self.population_size = population_size
        self.n_generations = n_generations
        self.tournament_size = tournament_size
        self.mutation_prob = mutation_prob
        self.crossover_prob = crossover_prob
        
        self.population = []
        self.fitness = {}
    
    def initialize_population(self):
        """Initialize random population."""
        self.population = [
            self.search_space.random_architecture()
            for _ in range(self.population_size)
        ]
    
    def tournament_selection(self) -> List[int]:
        """Select architecture via tournament selection."""
        candidates = random.sample(self.population, self.tournament_size)
        return max(candidates, key=lambda arch: self.fitness[self.search_space.encode_architecture(arch)])
    
    def evolve(self, fitness_fn: Callable[[List[int]], float]) -> List[int]:
        """
        Run evolutionary algorithm.
        
        Args:
            fitness_fn: Function that evaluates architecture and returns fitness (accuracy)
        
        Returns:
            best_arch: Best architecture found
        """
        self.initialize_population()
        
        # Evaluate initial population
        print("Evaluating initial population...")
        for arch in self.population:
            arch_str = self.search_space.encode_architecture(arch)
            if arch_str not in self.fitness:
                self.fitness[arch_str] = fitness_fn(arch)
        
        # Evolution
        for gen in range(self.n_generations):
            new_population = []
            
            # Elitism: keep top 10%
            elite_size = self.population_size // 10
            elite = sorted(self.population, 
                         key=lambda a: self.fitness[self.search_space.encode_architecture(a)],
                         reverse=True)[:elite_size]
            new_population.extend(elite)
            
            # Generate offspring
            while len(new_population) < self.population_size:
                # Selection
                parent1 = self.tournament_selection()
                parent2 = self.tournament_selection()
                
                # Crossover
                if random.random() < self.crossover_prob:
                    child1, child2 = self.search_space.crossover(parent1, parent2)
                else:
                    child1, child2 = parent1.copy(), parent2.copy()
                
                # Mutation
                if random.random() < self.mutation_prob:
                    child1 = self.search_space.mutate(child1)
                if random.random() < self.mutation_prob:
                    child2 = self.search_space.mutate(child2)
                
                new_population.extend([child1, child2])
            
            self.population = new_population[:self.population_size]
            
            # Evaluate new architectures
            for arch in self.population:
                arch_str = self.search_space.encode_architecture(arch)
                if arch_str not in self.fitness:
                    self.fitness[arch_str] = fitness_fn(arch)
            
            # Report progress
            best_fitness = max(self.fitness[self.search_space.encode_architecture(a)] 
                             for a in self.population)
            print(f"Generation {gen+1}/{self.n_generations}, Best Fitness: {best_fitness:.4f}")
        
        # Return best architecture
        best_arch = max(self.population, 
                       key=lambda a: self.fitness[self.search_space.encode_architecture(a)])
        return best_arch


# ============================================================================
# 6. Zero-Shot NAS (NASWOT)
# ============================================================================

def compute_naswot_score(model: nn.Module, dataloader, device: str = 'cuda') -> float:
    """
    Compute NASWOT score: log det of Jacobian covariance.
    
    Higher score β†’ better architecture (without training).
    
    Args:
        model: Network to evaluate
        dataloader: Data for computing Jacobian
        device: Device
    
    Returns:
        score: NASWOT score
    """
    model = model.to(device)
    model.train()
    
    # Get batch
    x, _ = next(iter(dataloader))
    x = x.to(device)
    batch_size = x.size(0)
    
    # Forward pass
    logits = model(x)
    
    # Compute Jacobian for each sample
    jacobians = []
    for i in range(batch_size):
        model.zero_grad()
        logits[i].sum().backward(retain_graph=True)
        
        # Collect gradients (flatten)
        grad = torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None])
        jacobians.append(grad)
    
    # Stack and compute covariance
    J = torch.stack(jacobians)  # [batch, n_params]
    
    # Regularized covariance
    cov = J.T @ J / batch_size + torch.eye(J.size(1), device=device) * 1e-5
    
    # Log determinant (score)
    sign, logdet = torch.slogdet(cov)
    score = logdet.item() if sign > 0 else -float('inf')
    
    return score


# ============================================================================
# 7. Multi-Objective NAS (NSGA-II)
# ============================================================================

class MultiObjectiveNAS:
    """
    Multi-objective NAS using NSGA-II.
    
    Optimizes: accuracy, -latency, -parameters
    """
    def __init__(
        self,
        search_space: SearchSpace,
        population_size: int = 50,
        n_generations: int = 20
    ):
        self.search_space = search_space
        self.population_size = population_size
        self.n_generations = n_generations
        
        self.population = []
        self.objectives = {}  # arch_str -> [acc, -latency, -params]
    
    def dominates(self, obj1: List[float], obj2: List[float]) -> bool:
        """Check if obj1 dominates obj2 (better in all objectives)."""
        return all(o1 >= o2 for o1, o2 in zip(obj1, obj2)) and any(o1 > o2 for o1, o2 in zip(obj1, obj2))
    
    def non_dominated_sort(self) -> List[List[List[int]]]:
        """Non-dominated sorting of population."""
        fronts = [[]]
        domination_count = defaultdict(int)
        dominated_solutions = defaultdict(list)
        
        # Compute domination
        for i, arch1 in enumerate(self.population):
            obj1 = self.objectives[self.search_space.encode_architecture(arch1)]
            for j, arch2 in enumerate(self.population):
                if i == j:
                    continue
                obj2 = self.objectives[self.search_space.encode_architecture(arch2)]
                
                if self.dominates(obj1, obj2):
                    dominated_solutions[i].append(j)
                elif self.dominates(obj2, obj1):
                    domination_count[i] += 1
            
            if domination_count[i] == 0:
                fronts[0].append(self.population[i])
        
        # Build subsequent fronts
        i = 0
        while len(fronts[i]) > 0:
            next_front = []
            for arch in fronts[i]:
                arch_idx = self.population.index(arch)
                for dominated_idx in dominated_solutions[arch_idx]:
                    domination_count[dominated_idx] -= 1
                    if domination_count[dominated_idx] == 0:
                        next_front.append(self.population[dominated_idx])
            i += 1
            fronts.append(next_front)
        
        return fronts[:-1]  # Remove empty last front
    
    def crowding_distance(self, front: List[List[int]]) -> Dict[int, float]:
        """Compute crowding distance for diversity."""
        distances = {self.population.index(arch): 0.0 for arch in front}
        n_obj = len(self.objectives[self.search_space.encode_architecture(front[0])])
        
        for obj_idx in range(n_obj):
            # Sort by objective
            sorted_front = sorted(front, 
                                key=lambda a: self.objectives[self.search_space.encode_architecture(a)][obj_idx])
            
            # Boundary points get infinite distance
            distances[self.population.index(sorted_front[0])] = float('inf')
            distances[self.population.index(sorted_front[-1])] = float('inf')
            
            # Others get normalized distance
            obj_range = (self.objectives[self.search_space.encode_architecture(sorted_front[-1])][obj_idx] - 
                        self.objectives[self.search_space.encode_architecture(sorted_front[0])][obj_idx])
            
            if obj_range > 0:
                for i in range(1, len(sorted_front) - 1):
                    arch = sorted_front[i]
                    arch_idx = self.population.index(arch)
                    
                    prev_obj = self.objectives[self.search_space.encode_architecture(sorted_front[i-1])][obj_idx]
                    next_obj = self.objectives[self.search_space.encode_architecture(sorted_front[i+1])][obj_idx]
                    
                    distances[arch_idx] += (next_obj - prev_obj) / obj_range
        
        return distances
    
    def search(self, fitness_fn: Callable[[List[int]], Tuple[float, float, float]]) -> List[List[int]]:
        """
        Run NSGA-II.
        
        Args:
            fitness_fn: Returns (accuracy, latency, n_params)
        
        Returns:
            pareto_front: List of Pareto optimal architectures
        """
        # Initialize
        self.population = [self.search_space.random_architecture() 
                          for _ in range(self.population_size)]
        
        # Evaluate
        for arch in self.population:
            acc, lat, params = fitness_fn(arch)
            self.objectives[self.search_space.encode_architecture(arch)] = [acc, -lat, -params]
        
        # Evolution
        for gen in range(self.n_generations):
            # Selection + Variation
            offspring = []
            while len(offspring) < self.population_size:
                parent1 = random.choice(self.population)
                parent2 = random.choice(self.population)
                child1, child2 = self.search_space.crossover(parent1, parent2)
                child1 = self.search_space.mutate(child1)
                offspring.extend([child1, child2])
            
            offspring = offspring[:self.population_size]
            
            # Evaluate offspring
            for arch in offspring:
                arch_str = self.search_space.encode_architecture(arch)
                if arch_str not in self.objectives:
                    acc, lat, params = fitness_fn(arch)
                    self.objectives[arch_str] = [acc, -lat, -params]
            
            # Combine parent + offspring
            combined = self.population + offspring
            self.population = combined
            
            # Non-dominated sorting
            fronts = self.non_dominated_sort()
            
            # Select next generation
            new_population = []
            for front in fronts:
                if len(new_population) + len(front) <= self.population_size:
                    new_population.extend(front)
                else:
                    # Use crowding distance
                    distances = self.crowding_distance(front)
                    sorted_front = sorted(front, 
                                        key=lambda a: distances[self.population.index(a)],
                                        reverse=True)
                    new_population.extend(sorted_front[:self.population_size - len(new_population)])
                    break
            
            self.population = new_population
            print(f"Generation {gen+1}/{self.n_generations}, Pareto Front Size: {len(fronts[0])}")
        
        # Return Pareto front
        fronts = self.non_dominated_sort()
        return fronts[0]


# ============================================================================
# 8. Demo and Benchmarking
# ============================================================================

def demo_search_space():
    """Demonstrate search space."""
    print("=" * 80)
    print("Search Space Demo")
    print("=" * 80)
    
    space = SearchSpace(n_nodes=4)
    
    # Sample random architectures
    print("\nRandom Architectures:")
    for i in range(3):
        arch = space.random_architecture()
        arch_str = space.encode_architecture(arch)
        print(f"  {i+1}. {arch_str}")
    
    # Mutation
    print("\nMutation:")
    arch = space.random_architecture()
    print(f"  Original: {space.encode_architecture(arch)}")
    mutated = space.mutate(arch, n_mutations=2)
    print(f"  Mutated:  {space.encode_architecture(mutated)}")
    
    # Crossover
    print("\nCrossover:")
    arch1 = space.random_architecture()
    arch2 = space.random_architecture()
    print(f"  Parent 1: {space.encode_architecture(arch1)}")
    print(f"  Parent 2: {space.encode_architecture(arch2)}")
    child1, child2 = space.crossover(arch1, arch2)
    print(f"  Child 1:  {space.encode_architecture(child1)}")
    print(f"  Child 2:  {space.encode_architecture(child2)}")


def demo_network_build():
    """Demonstrate building network from architecture."""
    print("\n" + "=" * 80)
    print("Network Building Demo")
    print("=" * 80)
    
    space = SearchSpace(n_nodes=4)
    arch = space.random_architecture()
    
    # Build network
    model = Network(
        C=16,
        n_classes=10,
        n_layers=8,
        arch=arch,
        n_nodes=4
    )
    
    print(f"\nArchitecture: {space.encode_architecture(arch)}")
    print(f"Parameters: {model.count_parameters():,}")
    
    # Test forward pass
    x = torch.randn(2, 3, 32, 32)
    logits = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {logits.shape}")


def demo_darts():
    """Demonstrate DARTS."""
    print("\n" + "=" * 80)
    print("DARTS Demo (Simplified)")
    print("=" * 80)
    
    # Build DARTS supernet
    model = DARTSNetwork(C=16, n_classes=10, n_layers=8, n_nodes=4)
    
    print(f"\nArchitecture parameters shape: {model.alphas.shape}")
    print(f"Network parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Forward pass
    x = torch.randn(2, 3, 32, 32)
    logits = model(x)
    print(f"Forward pass successful: {logits.shape}")
    
    # Discretize
    arch_ops = model.discretize()
    print(f"\nDiscretized architecture (first 6 ops): {arch_ops[:6]}")


def demo_evolutionary():
    """Demonstrate evolutionary NAS."""
    print("\n" + "=" * 80)
    print("Evolutionary NAS Demo")
    print("=" * 80)
    
    space = SearchSpace(n_nodes=4)
    
    # Mock fitness function (random)
    def fitness_fn(arch):
        return random.random()
    
    evolution = EvolutionaryNAS(
        search_space=space,
        population_size=20,
        n_generations=5,
        mutation_prob=0.2
    )
    
    print("\nRunning evolution (5 generations, 20 population)...")
    best_arch = evolution.evolve(fitness_fn)
    
    print(f"\nBest architecture: {space.encode_architecture(best_arch)}")
    print(f"Best fitness: {evolution.fitness[space.encode_architecture(best_arch)]:.4f}")


def demo_zero_shot():
    """Demonstrate zero-shot NAS."""
    print("\n" + "=" * 80)
    print("Zero-Shot NAS Demo (NASWOT)")
    print("=" * 80)
    
    space = SearchSpace(n_nodes=4)
    
    # Create dummy dataloader
    x = torch.randn(4, 3, 32, 32)
    y = torch.randint(0, 10, (4,))
    dataloader = [(x, y)]
    
    # Evaluate multiple architectures
    print("\nEvaluating architectures without training:")
    scores = []
    for i in range(3):
        arch = space.random_architecture()
        model = Network(C=16, n_classes=10, n_layers=8, arch=arch)
        
        try:
            score = compute_naswot_score(model, dataloader, device='cpu')
            scores.append((arch, score))
            print(f"  Arch {i+1}: score = {score:.2f}")
        except:
            print(f"  Arch {i+1}: failed to compute")
    
    if scores:
        best_arch, best_score = max(scores, key=lambda x: x[1])
        print(f"\nBest architecture (by zero-shot score):")
        print(f"  {space.encode_architecture(best_arch)}")
        print(f"  Score: {best_score:.2f}")


def print_comparison_table():
    """Print comparison of NAS methods."""
    print("\n" + "=" * 80)
    print("NAS Methods Comparison")
    print("=" * 80)
    
    table = """
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚ Method           β”‚ Search Cost    β”‚ Quality        β”‚ Flexibility β”‚ Ease of Use  β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ RL-based (NAS)   β”‚ Very High      β”‚ Excellent      β”‚ High        β”‚ Complex      β”‚
    β”‚                  β”‚ (1000s GPU-hr) β”‚                β”‚             β”‚              β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Evolutionary     β”‚ High           β”‚ Very Good      β”‚ Very High   β”‚ Easy         β”‚
    β”‚                  β”‚ (100s GPU-hr)  β”‚                β”‚             β”‚              β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ DARTS            β”‚ Low            β”‚ Good           β”‚ Medium      β”‚ Medium       β”‚
    β”‚                  β”‚ (1-2 GPU-days) β”‚                β”‚             β”‚              β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ ENAS (weight     β”‚ Very Low       β”‚ Good           β”‚ Medium      β”‚ Easy         β”‚
    β”‚ sharing)         β”‚ (0.5 GPU-day)  β”‚                β”‚             β”‚              β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Zero-Shot        β”‚ Minimal        β”‚ Fair           β”‚ High        β”‚ Very Easy    β”‚
    β”‚ (NASWOT)         β”‚ (minutes)      β”‚                β”‚             β”‚              β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Once-for-All     β”‚ High (once)    β”‚ Excellent      β”‚ Very High   β”‚ Medium       β”‚
    β”‚                  β”‚ Low (reuse)    β”‚                β”‚             β”‚              β”‚
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
    β”‚ Multi-Objective  β”‚ High           β”‚ Very Good      β”‚ Very High   β”‚ Medium       β”‚
    β”‚ (NSGA-II)        β”‚ (100s GPU-hr)  β”‚                β”‚             β”‚              β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    
    Key Trade-offs:
    - RL/Evolution: Best quality but highest cost
    - DARTS: Good balance of speed and quality, but can collapse
    - Weight Sharing (ENAS): Fast but biased weight estimates
    - Zero-Shot: Extremely fast but lower quality
    - Once-for-All: Train once, use many times (amortized efficiency)
    - Multi-Objective: Finds Pareto frontier (multiple solutions)
    
    Recommendations:
    βœ“ Starting out: DARTS or ENAS (good balance)
    βœ“ Limited budget: Zero-shot for screening, then refine top candidates
    βœ“ Production deployment: Once-for-All (flexible deployment)
    βœ“ Hardware constraints: Multi-objective NAS
    βœ“ Research/novel tasks: Evolutionary (most flexible)
    """
    print(table)


def main():
    """Run all demos."""
    print("\n" + "=" * 80)
    print("Advanced Neural Architecture Search - Comprehensive Demo")
    print("=" * 80)
    
    # Run demos
    demo_search_space()
    demo_network_build()
    demo_darts()
    demo_evolutionary()
    demo_zero_shot()
    
    # Comparison table
    print_comparison_table()
    
    print("\n" + "=" * 80)
    print("Demo Complete!")
    print("=" * 80)
    print("\nKey Takeaways:")
    print("1. Search space design critically impacts NAS efficiency")
    print("2. DARTS enables fast gradient-based architecture search")
    print("3. Weight sharing dramatically reduces search cost")
    print("4. Zero-shot methods can screen thousands of architectures instantly")
    print("5. Multi-objective optimization finds trade-off solutions")
    print("6. Modern NAS methods (OFA, zero-shot) make NAS practical")


if __name__ == "__main__":
    main()