Introduction

Neural Architecture Search (NAS) is an automated approach to designing neural network architectures. Instead of manually crafting network designs, NAS algorithms explore the space of possible architectures to find optimal configurations for specific tasks.

Why NAS Matters

  • Automation: Reduces human effort in architecture design
  • Performance: Can discover architectures that outperform human-designed ones
  • Efficiency: Optimizes for specific constraints (latency, memory, energy)
  • Scalability: Adapts to different tasks and domains

Theoretical Foundations

The NAS Framework

NAS consists of three main components:

  1. Search Space: Defines the set of possible architectures
  2. Search Strategy: Determines how to explore the search space
  3. Performance Estimation: Evaluates architecture quality
import copy
import random
from collections import defaultdict
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn

class NASFramework:
    """Generic NAS driver tying a search space, strategy and estimator together.

    The collaborators are duck-typed:

    * ``search_strategy.sample_architecture(search_space, history)`` proposes
      the next candidate; ``search_strategy.update(architecture, performance)``
      is told its score.
    * ``performance_estimator.evaluate(architecture)`` returns a score; larger
      scores are treated as better by ``get_best_architecture``.
    """

    def __init__(self, search_space, search_strategy, performance_estimator):
        self.search_space = search_space
        self.search_strategy = search_strategy
        self.performance_estimator = performance_estimator
        # One dict per evaluated candidate, in evaluation order, with keys
        # 'architecture', 'performance' and 'iteration'.
        self.history = []

    def search(self, num_iterations: int):
        """Run the sample -> evaluate -> update loop.

        Args:
            num_iterations: Number of candidates to sample and evaluate.

        Returns:
            The best history record (see ``get_best_architecture``), or
            ``None`` if nothing has been evaluated (e.g. ``num_iterations=0``).
        """
        for iteration in range(num_iterations):
            architecture = self.search_strategy.sample_architecture(
                self.search_space, self.history
            )
            performance = self.performance_estimator.evaluate(architecture)
            self.history.append({
                'architecture': architecture,
                'performance': performance,
                'iteration': iteration
            })
            # Feed the result back so the strategy can adapt its sampling.
            self.search_strategy.update(architecture, performance)

        return self.get_best_architecture()

    def get_best_architecture(self):
        """Return the history record with the highest ``'performance'``.

        Returns ``None`` when no architecture has been evaluated yet, instead
        of letting ``max`` raise ``ValueError`` on an empty history.
        """
        if not self.history:
            return None
        return max(self.history, key=lambda record: record['performance'])

Search Space Design

Macro Search Space

Defines the overall structure of the network (number of layers, skip connections, etc.).

class MacroSearchSpace:
    """Macro search space: a chain of layers plus optional skip connections.

    Args:
        max_layers: Inclusive upper bound on the number of sampled layers.
        operations: Candidate operation names. Defaults to convolutions,
            pools, identity and zero.
        min_layers: Inclusive lower bound on the number of sampled layers.
            It is clamped to ``max_layers``, so spaces with fewer than 8
            layers no longer make ``random.randint`` raise.
    """

    def __init__(self, max_layers: int = 20, operations: List[str] = None,
                 min_layers: int = 8):
        self.max_layers = max_layers
        self.min_layers = min_layers
        self.operations = operations or [
            'conv3x3', 'conv5x5', 'conv7x7', 'maxpool3x3',
            'avgpool3x3', 'identity', 'zero'
        ]

    def sample_architecture(self) -> Dict:
        """Sample a random architecture.

        Returns:
            ``{'layers': [...], 'skip_connections': [...]}``.
            - Each layer is a dict with ``'operation'``, ``'filters'`` and
              ``'kernel_size'``.
            - Each skip connection is a ``(source, target)`` pair with
              ``source < target``.
        """
        lower = min(self.min_layers, self.max_layers)
        num_layers = random.randint(lower, self.max_layers)
        architecture = {
            'layers': [],
            'skip_connections': []
        }

        for _ in range(num_layers):
            operation = random.choice(self.operations)
            architecture['layers'].append({
                'operation': operation,
                'filters': random.choice([32, 64, 128, 256, 512]),
                # Only convolutions get a sampled kernel size. Decide based on
                # the operation chosen for THIS layer, not the first entry of
                # the operation list.
                'kernel_size': random.choice([3, 5, 7]) if 'conv' in operation else 3
            })

        # Each layer after the first gets a skip connection from a random
        # earlier layer with 30% probability.
        for target in range(1, num_layers):
            if random.random() < 0.3:
                source = random.randint(0, target - 1)
                architecture['skip_connections'].append((source, target))

        return architecture

Micro Search Space (Cell-based)

Focuses on designing building blocks (cells) that are repeated throughout the network.

class CellSearchSpace:
    """Cell-based (micro) search space with a normal and a reduction cell.

    A cell is a flat list of ``(input_node, operation)`` pairs. Each
    intermediate node contributes two pairs. Nodes 0 and 1 are the cell
    inputs, and node ``k`` may read from any node ``0..k-1``.

    Note: ``num_ops`` is accepted but ignored. ``self.num_ops`` is always the
    length of the fixed operation list.
    """

    def __init__(self, num_nodes: int = 7, num_ops: int = 8):
        self.num_nodes = num_nodes
        self.operations = [
            'none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
            'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5'
        ]
        self.num_ops = len(self.operations)

    def sample_cell(self) -> Dict:
        """Sample a normal cell and then a reduction cell."""
        normal = self._sample_single_cell()
        reduction = self._sample_single_cell()
        return {'normal_cell': normal, 'reduction_cell': reduction}

    def _sample_single_cell(self) -> List[Tuple]:
        """Sample two (input_node, operation) edges per intermediate node."""
        first_node = 2
        last_node = self.num_nodes + 1
        return [
            (random.randint(0, node - 1), random.choice(self.operations))
            for node in range(first_node, last_node + 1)
            for _edge in range(2)
        ]

Differentiable Search Space

Enables gradient-based optimization of architectures.

class DifferentiableSearchSpace(nn.Module):
    """Continuous relaxation of a per-node operation choice (DARTS-style).

    Every candidate operation is instantiated once. ``alpha`` holds one row of
    logits per node. The forward pass uses only row 0 (single-node
    simplification) and returns the softmax-weighted sum of all candidate
    outputs.

    Args:
        operations: Operation names understood by ``_get_operation``.
        num_nodes: Number of rows in ``alpha``.
        channels: Input/output channels of the convolution candidates.
            Defaults to 32, the previously hard-coded value.
    """

    def __init__(self, operations: List[str], num_nodes: int = 4, channels: int = 32):
        super().__init__()
        self.operations = operations
        self.num_nodes = num_nodes
        self.num_ops = len(operations)
        self.channels = channels

        # Architecture parameters (alpha): one logit per (node, operation).
        self.alpha = nn.Parameter(torch.randn(num_nodes, self.num_ops))

        # One module per candidate operation, in the order of `operations`.
        self.ops = nn.ModuleList([
            self._get_operation(op) for op in operations
        ])

    def _get_operation(self, op_name: str) -> nn.Module:
        """Build the module for ``op_name``; shape-preserving for stride 1.

        Raises:
            ValueError: If ``op_name`` is not a known operation.
        """
        c = self.channels
        if op_name == 'conv3x3':
            return nn.Conv2d(c, c, 3, padding=1)
        elif op_name == 'conv5x5':
            return nn.Conv2d(c, c, 5, padding=2)
        elif op_name == 'maxpool3x3':
            return nn.MaxPool2d(3, stride=1, padding=1)
        elif op_name == 'avgpool3x3':
            return nn.AvgPool2d(3, stride=1, padding=1)
        elif op_name == 'identity':
            return nn.Identity()
        elif op_name == 'zero':
            return Zero()
        else:
            raise ValueError(f"Unknown operation: {op_name}")

    def forward(self, x):
        """Return the mixed output: sum_i softmax(alpha[0])_i * op_i(x)."""
        weights = torch.softmax(self.alpha, dim=-1)
        output = 0
        for i, op in enumerate(self.ops):
            output = output + weights[0, i] * op(x)  # Simplified for single node
        return output

    def get_discrete_architecture(self):
        """Return the highest-weighted operation name for each node."""
        return [
            self.operations[int(torch.argmax(self.alpha[node]).item())]
            for node in range(self.num_nodes)
        ]

class Zero(nn.Module):
    """Operation that ignores its input values and emits zeros of the same shape."""

    def forward(self, x):
        # zeros_like keeps the input's shape, dtype and device.
        result = torch.zeros_like(x)
        return result

Search Strategies

Differentiable Architecture Search (DARTS)

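Gradient-based search using continuous relaxation: the discrete choice of operation on each edge is replaced by a softmax-weighted mixture of all candidates, so the architecture parameters can be optimized by gradient descent alongside the network weights.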

class DARTSSearch:
    """First-order DARTS optimizer: alternate architecture and weight updates.

    - The architecture logits (``model.alpha``) are updated with Adam on the
      validation loss.
    - All other parameters are updated with SGD (momentum 0.9) on the
      training loss.
    - ``alpha`` is deliberately excluded from the SGD optimizer, so the
      training loss never moves the architecture.
    """

    def __init__(self, model: "DifferentiableSearchSpace", learning_rate: float = 0.025):
        self.model = model
        weight_params = [p for p in model.parameters() if p is not model.alpha]
        # A model whose only parameter is alpha (e.g. only parameter-free ops)
        # has no weights to train; torch.optim rejects an empty parameter list.
        self.optimizer = (
            torch.optim.SGD(weight_params, lr=learning_rate, momentum=0.9)
            if weight_params else None
        )
        self.arch_optimizer = torch.optim.Adam(
            [self.model.alpha], lr=3e-4
        )

    def search_step(self, train_data, val_data, criterion):
        """Run one architecture step followed by one weight step.

        Args:
            train_data: ``(inputs, targets)`` batch for the weight update.
            val_data: ``(inputs, targets)`` batch for the architecture update.
            criterion: Callable ``criterion(outputs, targets) -> loss tensor``.

        Returns:
            ``(train_loss, val_loss)`` as Python floats.
        """
        # Architecture update on the validation loss.
        self.arch_optimizer.zero_grad()
        val_loss = self._compute_val_loss(val_data, criterion)
        val_loss.backward()
        self.arch_optimizer.step()

        # Weight update on the training loss. zero_grad also discards the
        # weight gradients left behind by val_loss.backward().
        train_loss = self._compute_train_loss(train_data, criterion)
        if self.optimizer is not None:
            self.optimizer.zero_grad()
            train_loss.backward()
            self.optimizer.step()

        return train_loss.item(), val_loss.item()

    def _compute_loss(self, data, criterion):
        """Forward ``data = (inputs, targets)`` through the model and score it."""
        inputs, targets = data
        return criterion(self.model(inputs), targets)

    def _compute_train_loss(self, data, criterion):
        """Compute training loss."""
        return self._compute_loss(data, criterion)

    def _compute_val_loss(self, data, criterion):
        """Compute validation loss."""
        return self._compute_loss(data, criterion)

    def get_final_architecture(self):
        """Extract final discrete architecture from the model."""
        return self.model.get_discrete_architecture()

Performance Estimation

Full Training

Most accurate but computationally expensive.

class FullTrainingEvaluator:
    """Score an architecture by training it from scratch for a fixed budget.

    Subclasses must implement ``_build_model`` and ``_evaluate_model``; both
    depend on the project's architecture format and validation data.

    Args:
        dataset: Iterable of ``(inputs, targets)`` training batches.
        num_epochs: Number of full passes over ``dataset``.
    """

    def __init__(self, dataset, num_epochs: int = 100):
        self.dataset = dataset
        self.num_epochs = num_epochs

    def evaluate(self, architecture) -> float:
        """Build, train with SGD (lr=0.01) and return ``_evaluate_model``'s score."""
        model = self._build_model(architecture)
        model.train()

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        for _ in range(self.num_epochs):
            for inputs, targets in self.dataset:
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()

        return self._evaluate_model(model)

    def _build_model(self, architecture):
        """Build an ``nn.Module`` from an architecture description.

        Raises:
            NotImplementedError: Always; the format is project-specific.
        """
        raise NotImplementedError(
            "Subclasses must implement _build_model for their architecture format"
        )

    def _evaluate_model(self, model):
        """Return the trained model's validation score.

        Raises:
            NotImplementedError: Always; validation data is project-specific.
        """
        raise NotImplementedError("Subclasses must implement _evaluate_model")

Early Stopping

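Reduces training time by stopping once validation accuracy has not improved for a fixed number of epochs (the patience); the best accuracy reached is used as a cheap proxy that should correlate with full-training performance.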

class EarlyStoppingEvaluator:
    """Score an architecture by training until validation accuracy plateaus.

    The training and validation loops are provided. Subclasses must
    implement ``_build_model``, because the architecture format is
    project-specific.

    Args:
        dataset: Iterable of ``(inputs, targets)`` training batches.
        max_epochs: Hard cap on training epochs.
        patience: Stop after this many consecutive epochs without a new best
            validation accuracy.
        val_dataset: Batches used for validation; defaults to ``dataset``.
    """

    def __init__(self, dataset, max_epochs: int = 20, patience: int = 5,
                 val_dataset=None):
        self.dataset = dataset
        self.val_dataset = dataset if val_dataset is None else val_dataset
        self.max_epochs = max_epochs
        self.patience = patience

    def evaluate(self, architecture) -> float:
        """Train with SGD (lr=0.01) and early stopping; return the best val accuracy."""
        model = self._build_model(architecture)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        best_val_acc = 0.0
        epochs_without_improvement = 0

        for _ in range(self.max_epochs):
            self._train_epoch(model, optimizer, criterion)
            val_acc = self._validate_epoch(model)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= self.patience:
                    break

        return best_val_acc

    def _build_model(self, architecture):
        """Build an ``nn.Module`` from an architecture description.

        Raises:
            NotImplementedError: Always; the format is project-specific.
        """
        raise NotImplementedError(
            "Subclasses must implement _build_model for their architecture format"
        )

    def _train_epoch(self, model, optimizer, criterion) -> float:
        """Run one pass over ``self.dataset``; return the mean batch loss (0.0 if empty)."""
        model.train()
        total_loss = 0.0
        num_batches = 0
        for inputs, targets in self.dataset:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        return total_loss / num_batches if num_batches else 0.0

    def _validate_epoch(self, model) -> float:
        """Return top-1 accuracy on ``self.val_dataset`` (0.0 if it is empty)."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in self.val_dataset:
                predicted = model(inputs).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total if total else 0.0

Weight Sharing (One-Shot)

Trains a single super-network once; candidate sub-networks are then scored using the weights they inherit from the supernet, without being trained individually.

class WeightSharingEvaluator:
    """One-shot evaluator: train a weight-sharing supernet once, then score
    candidate architectures using the weights they inherit from it.

    ``supernet`` must provide ``sample_active_subnet()`` (called before each
    training batch) and ``set_active_subnet(architecture)`` (called before
    scoring).

    Args:
        supernet: The weight-sharing network.
        dataset: Iterable of ``(inputs, targets)`` batches for supernet training.
        val_dataset: Batches used to score subnets. Defaults to ``dataset``;
            pass held-out data to avoid scoring on the training set.
        num_epochs: Supernet training epochs (default 50).
        lr: SGD learning rate for supernet training (default 0.01).
    """

    def __init__(self, supernet: nn.Module, dataset, val_dataset=None,
                 num_epochs: int = 50, lr: float = 0.01):
        self.supernet = supernet
        self.dataset = dataset
        self.val_dataset = dataset if val_dataset is None else val_dataset
        self.num_epochs = num_epochs
        self.lr = lr
        self.trained = False

    def train_supernet(self):
        """Train the supernet once; later calls are no-ops."""
        if self.trained:
            return

        optimizer = torch.optim.SGD(self.supernet.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()
        # Ensure training-mode behaviour for layers such as BatchNorm/Dropout.
        self.supernet.train()

        for _ in range(self.num_epochs):
            for inputs, targets in self.dataset:
                optimizer.zero_grad()
                # Sample a random path through the supernet for this batch.
                self.supernet.sample_active_subnet()
                loss = criterion(self.supernet(inputs), targets)
                loss.backward()
                optimizer.step()

        self.trained = True

    def evaluate(self, architecture) -> float:
        """Return the top-1 accuracy of ``architecture`` under inherited weights."""
        if not self.trained:
            self.train_supernet()

        self.supernet.set_active_subnet(architecture)
        return self._evaluate_subnet()

    def _evaluate_subnet(self):
        """Top-1 accuracy of the active subnet on ``self.val_dataset`` (0.0 if empty)."""
        self.supernet.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in self.val_dataset:
                _, predicted = torch.max(self.supernet(inputs), 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        return correct / total if total else 0.0

Advanced Techniques

Multi-Objective NAS

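Optimizes multiple objectives simultaneously (e.g., accuracy, latency, FLOPs) by maintaining a Pareto front: the set of architectures that no other evaluated architecture beats on every objective.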

class MultiObjectiveNAS:
    """Multi-objective NAS bookkeeping with a Pareto front of non-dominated architectures.

    Args:
        objectives: Objective names, e.g. ``['accuracy', 'latency', 'flops']``.
        higher_is_better: Objectives to maximize. All other objectives are
            minimized. Defaults to ``('accuracy',)``.

    Subclasses supply the measurements by overriding ``_evaluate_accuracy``,
    ``_evaluate_latency`` and ``_evaluate_flops``.
    """

    def __init__(self, objectives: List[str], higher_is_better=('accuracy',)):
        self.objectives = objectives
        self.higher_is_better = set(higher_is_better)
        # List of (architecture, objectives_dict) pairs.
        self.pareto_front = []

    def evaluate_architecture(self, architecture) -> Dict[str, float]:
        """Measure ``architecture`` on each supported objective in ``self.objectives``."""
        results = {}

        if 'accuracy' in self.objectives:
            results['accuracy'] = self._evaluate_accuracy(architecture)

        if 'latency' in self.objectives:
            results['latency'] = self._evaluate_latency(architecture)

        if 'flops' in self.objectives:
            results['flops'] = self._evaluate_flops(architecture)

        return results

    def _evaluate_accuracy(self, architecture) -> float:
        raise NotImplementedError("Subclasses must implement _evaluate_accuracy")

    def _evaluate_latency(self, architecture) -> float:
        raise NotImplementedError("Subclasses must implement _evaluate_latency")

    def _evaluate_flops(self, architecture) -> float:
        raise NotImplementedError("Subclasses must implement _evaluate_flops")

    def update_pareto_front(self, architecture, objectives):
        """Insert ``architecture`` unless an existing member dominates it.

        When it is inserted, every existing member it dominates is removed.
        Members with identical objectives do not dominate each other, so
        both are kept.
        """
        if any(self._dominates(front_obj, objectives)
               for _, front_obj in self.pareto_front):
            return

        self.pareto_front = [
            (arch, obj) for arch, obj in self.pareto_front
            if not self._dominates(objectives, obj)
        ]
        self.pareto_front.append((architecture, objectives))

    def _dominates(self, obj1: Dict, obj2: Dict) -> bool:
        """Return True if obj1 is no worse on every objective and strictly better on one."""
        strictly_better = False

        for objective in self.objectives:
            a, b = obj1[objective], obj2[objective]
            if objective in self.higher_is_better:
                better, worse = a > b, a < b
            else:  # Lower is better (e.g. latency, flops)
                better, worse = a < b, a > b
            if worse:
                return False
            if better:
                strictly_better = True

        return strictly_better

Implementation Examples

Complete DARTS Implementation

class DARTSCell(nn.Module):
    """DARTS cell: every intermediate node sums mixed ops over all earlier states.

    - States 0 and 1 are the two cell inputs.
    - Intermediate node ``i`` has ``2 + i`` incoming edges, one from every
      earlier state.
    - Each edge owns a ``MixedOp`` and a row of ``alpha``.
    - The output concatenates the ``num_nodes`` intermediate states along
      dim 1, i.e. ``num_nodes * channels`` channels.
    """

    def __init__(self, num_nodes: int, channels: int):
        super().__init__()
        self.num_nodes = num_nodes
        self.channels = channels

        # Edges are stored node by node: node 0's 2 edges, then node 1's 3, ...
        num_edges = sum(2 + node for node in range(num_nodes))
        self.mixed_ops = nn.ModuleList([MixedOp(channels) for _ in range(num_edges)])

        # One row of 8 operation logits per edge (MixedOp has 8 candidates).
        self.alpha = nn.Parameter(torch.randn(len(self.mixed_ops), 8))

    def forward(self, inputs):
        states = [inputs[0], inputs[1]]

        edge_start = 0
        for _ in range(self.num_nodes):
            num_incoming = len(states)
            contributions = [
                self.mixed_ops[edge_start + k](states[k], self.alpha[edge_start + k])
                for k in range(num_incoming)
            ]
            states.append(sum(contributions))
            # Skip past the edges this node just consumed.
            edge_start += num_incoming

        return torch.cat(states[-self.num_nodes:], dim=1)

class MixedOp(nn.Module):
    """Softmax-weighted mixture of the eight DARTS candidate operations on one edge.

    Entry ``i`` of the ``alpha`` vector passed to ``forward`` weights ``ops[i]``.
    """

    def __init__(self, channels: int):
        super().__init__()
        candidates = [
            SepConv(channels, channels, 3, 1, 1),      # separable 3x3
            SepConv(channels, channels, 5, 1, 2),      # separable 5x5
            DilConv(channels, channels, 3, 1, 2, 2),   # dilated 3x3
            DilConv(channels, channels, 5, 1, 4, 2),   # dilated 5x5
            nn.MaxPool2d(3, 1, 1),
            nn.AvgPool2d(3, 1, 1),
            Identity(),
            Zero(),
        ]
        self.ops = nn.ModuleList(candidates)

    def forward(self, x, alpha):
        weights = torch.softmax(alpha, dim=0)
        mixed = 0
        for weight, op in zip(weights, self.ops):
            mixed = mixed + weight * op(x)
        return mixed

class SepConv(nn.Module):
    """Depthwise-separable convolution: depthwise conv -> 1x1 pointwise conv -> BN -> ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding,
            groups=in_channels,
        )
        pointwise = nn.Conv2d(in_channels, out_channels, 1)
        self.conv = nn.Sequential(
            depthwise,
            pointwise,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        out = self.conv(x)
        return out

class DilConv(nn.Module):
    """Dilated convolution followed by BatchNorm and ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                      dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out

class Identity(nn.Module):
    """Pass-through operation: returns its input object unchanged."""

    def forward(self, x):
        output = x
        return output

class Zero(nn.Module):
    """Operation that ignores its input values and emits zeros of the same shape."""

    def forward(self, x):
        # zeros_like keeps the input's shape, dtype and device.
        result = torch.zeros_like(x)
        return result

# Complete DARTS Network
class DARTSNetwork(nn.Module):
    """Stem, stacked DARTS cells, global average pooling and a linear classifier.

    Each ``DARTSCell`` concatenates its ``num_nodes`` intermediate states, so
    a cell fed ``C`` channels emits ``num_nodes * C`` channels. Feeding that
    directly to the next cell or the classifier is a channel mismatch.

    A 1x1 conv + BatchNorm projection after every cell maps the output back
    to the width the next stage expects:
    - At reduction positions (``num_cells // 3`` and ``2 * num_cells // 3``)
      the projection doubles the channels and uses stride 2.
    - Elsewhere it keeps the width unchanged.

    Args:
        num_classes: Classifier output size.
        num_cells: Number of cells.
        channels: Channel width after the stem.
        num_nodes: Intermediate nodes per cell (default 4).
    """

    def __init__(self, num_classes: int, num_cells: int = 8, channels: int = 36,
                 num_nodes: int = 4):
        super().__init__()
        self.num_cells = num_cells
        self.channels = channels

        # Stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, channels, 3, 1, 1),
            nn.BatchNorm2d(channels)
        )

        reduction_positions = {num_cells // 3, 2 * num_cells // 3}
        self.cells = nn.ModuleList()
        self.projections = nn.ModuleList()
        current = channels
        for i in range(num_cells):
            is_reduction = i in reduction_positions
            self.cells.append(DARTSCell(num_nodes, current))
            out_channels = current * 2 if is_reduction else current
            self.projections.append(nn.Sequential(
                nn.Conv2d(num_nodes * current, out_channels, 1,
                          stride=2 if is_reduction else 1, bias=False),
                nn.BatchNorm2d(out_channels),
            ))
            current = out_channels

        # Classifier
        self.classifier = nn.Linear(current, num_classes)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.stem(x)

        for cell, projection in zip(self.cells, self.projections):
            # Both cell inputs receive the same tensor.
            x = projection(cell([x, x]))

        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

Evolutionary Search Example

class EvolutionaryNAS:
    """Generational genetic algorithm over dict-encoded architectures.

    Individuals are dicts returned by ``search_space.sample_architecture()``,
    with a ``'layers'`` list of layer dicts and optionally a
    ``'skip_connections'`` list of index pairs.

    Each generation:
    1. Evaluate every individual.
    2. Record the best fitness.
    3. Build the next population by tournament selection, single-point
       crossover and mutation.
    """

    def __init__(self, population_size: int = 50, generations: int = 100):
        self.population_size = population_size
        self.generations = generations
        self.population = []
        # Best fitness observed in each generation.
        self.fitness_history = []

    def run_search(self, search_space, evaluator):
        """Run the search; return ``(best_architecture, fitness)`` from the final population."""
        self.population = [
            search_space.sample_architecture()
            for _ in range(self.population_size)
        ]

        for _ in range(self.generations):
            fitness_scores = [evaluator.evaluate(ind) for ind in self.population]
            self.fitness_history.append(max(fitness_scores))

            new_population = []
            for _ in range(self.population_size):
                parent1 = self._tournament_selection(fitness_scores)
                parent2 = self._tournament_selection(fitness_scores)
                child = self._crossover(parent1, parent2)
                child = self._mutate(child, search_space)
                new_population.append(child)

            self.population = new_population

        final_fitness = [evaluator.evaluate(ind) for ind in self.population]
        best_idx = np.argmax(final_fitness)
        return self.population[best_idx], final_fitness[best_idx]

    def _tournament_selection(self, fitness_scores, tournament_size: int = 3):
        """Return the fittest of a random subset of the current population.

        The tournament is capped at the population size, so small populations
        do not make ``random.sample`` raise.
        """
        size = min(tournament_size, len(fitness_scores))
        tournament_indices = random.sample(range(len(fitness_scores)), size)
        tournament_fitness = [fitness_scores[i] for i in tournament_indices]
        winner_idx = tournament_indices[np.argmax(tournament_fitness)]
        return self.population[winner_idx]

    def _crossover(self, parent1, parent2):
        """Single-point crossover on ``'layers'``; the child is an independent copy.

        Skip connections are inherited from ``parent1``. Any that reference
        a layer index beyond the child's length are dropped.
        """
        # Deep copy so that in-place mutation of the child cannot alter the
        # parents, which may be selected again this generation.
        child = copy.deepcopy(parent1)

        if 'layers' in parent1 and 'layers' in parent2:
            min_length = min(len(parent1['layers']), len(parent2['layers']))
            if min_length > 1:
                crossover_point = random.randint(1, min_length - 1)
                child['layers'] = (child['layers'][:crossover_point] +
                                   copy.deepcopy(parent2['layers'][crossover_point:]))
                # The child now has len(parent2['layers']) layers.
                num_layers = len(child['layers'])
                if 'skip_connections' in child:
                    child['skip_connections'] = [
                        conn for conn in child['skip_connections']
                        if max(conn) < num_layers
                    ]

        return child

    def _mutate(self, individual, search_space, mutation_rate: float = 0.1):
        """With probability ``mutation_rate``, mutate one random layer in place and return it."""
        if random.random() < mutation_rate:
            if 'layers' in individual and individual['layers']:
                layer_idx = random.randint(0, len(individual['layers']) - 1)
                layer = individual['layers'][layer_idx]

                if random.random() < 0.5:
                    layer['operation'] = random.choice(search_space.operations)

                if random.random() < 0.5:
                    layer['filters'] = random.choice([32, 64, 128, 256, 512])

        return individual


Reinforcement Learning NAS Example

class RLNASController(nn.Module):
    """LSTM controller that samples per-layer architecture decisions.

    For each of ``num_layers`` steps the LSTM is advanced once. A convolution
    operation and a kernel size are then sampled from the corresponding
    linear heads.

    The filter and pooling heads are created (and initialized) but not yet
    sampled in ``forward``. Their sizes come from the module-level
    ``CONV_*``/``POOLING_*`` lists.
    """

    def __init__(self, num_layers: int = 6, lstm_size: int = 32,
                 num_branches: int = 6, out_filters: int = 48):
        super().__init__()
        self.num_layers = num_layers
        self.lstm_size = lstm_size
        self.num_branches = num_branches
        self.out_filters = out_filters

        # LSTM controller
        self.lstm = nn.LSTMCell(lstm_size, lstm_size)

        # Embedding layers for different architecture decisions
        self.g_emb = nn.Embedding(1, lstm_size)  # Go embedding
        self.encoder = nn.Linear(lstm_size, lstm_size)

        # Decision heads
        self.conv_op = nn.Linear(lstm_size, len(CONV_OPS))
        self.conv_ksize = nn.Linear(lstm_size, len(CONV_KERNEL_SIZES))
        self.conv_filters = nn.Linear(lstm_size, len(CONV_FILTERS))
        self.pooling_op = nn.Linear(lstm_size, len(POOLING_OPS))
        self.pooling_ksize = nn.Linear(lstm_size, len(POOLING_KERNEL_SIZES))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize every parameter uniformly in [-0.1, 0.1]."""
        init_range = 0.1
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)

    def _sample_decision(self, head: nn.Linear, h: torch.Tensor):
        """Sample one categorical decision from ``head(h)``.

        Returns:
            ``(sample, log_prob, entropy)`` with shapes ``(B,)``, ``(B, 1)``
            and ``(B, 1)``.
        """
        logits = head(h)
        # torch.softmax/log_softmax: the file never imports
        # torch.nn.functional as F, so the previous F.* calls raised NameError.
        probs = torch.softmax(logits, dim=-1)
        log_probs = torch.log_softmax(logits, dim=-1)
        entropy = -(log_probs * probs).sum(1, keepdim=True)
        sample = torch.multinomial(probs, 1).view(-1)
        log_prob = log_probs.gather(1, sample.unsqueeze(1))
        return sample, log_prob, entropy

    def forward(self, batch_size: int = 1):
        """Sample an architecture.

        Returns:
            ``(arc_seq, log_probs, entropies)``.
            - ``arc_seq`` is a list of sampled index tensors: the op, then the
              kernel size, for each layer.
            - The other two are the per-decision ``(B, 1)`` tensors
              concatenated along dim 0.
        """
        # Create the hidden state on the same device as the parameters.
        device = self.g_emb.weight.device
        h = torch.zeros(batch_size, self.lstm_size, device=device)
        c = torch.zeros(batch_size, self.lstm_size, device=device)

        # Start with the "go" embedding.
        inputs = self.g_emb.weight.repeat(batch_size, 1)

        arc_seq = []
        entropies = []
        log_probs = []

        for _ in range(self.num_layers):
            h, c = self.lstm(inputs, (h, c))

            for head in (self.conv_op, self.conv_ksize):
                sample, log_prob, entropy = self._sample_decision(head, h)
                arc_seq.append(sample)
                entropies.append(entropy)
                log_probs.append(log_prob)

            # Filter and pooling decisions are not sampled yet.
            inputs = h  # Current hidden state feeds the next step.

        return arc_seq, torch.cat(log_probs), torch.cat(entropies)

# Constants for architecture choices
CONV_OPS = ['conv', 'depthwise_conv', 'separable_conv']
CONV_KERNEL_SIZES = [3, 5, 7]
CONV_FILTERS = [24, 36, 48, 64]
POOLING_OPS = ['max_pool', 'avg_pool', 'no_pool']
POOLING_KERNEL_SIZES = [2, 3]

class RLNASTrainer:
    """Train an architecture controller with REINFORCE and a moving-average baseline.

    Args:
        controller: ``nn.Module`` whose call returns
            ``(arc_seq, log_probs, entropies)``.
        child_model_builder: Object with ``build(arc_seq)``.
        evaluator: Object with ``evaluate(child_model)`` returning a scalar reward.
        entropy_weight: Weight of the entropy bonus (default 1e-4).
        grad_clip: Max global gradient norm for controller updates (default 5.0).
        log_interval: Print progress every this many epochs; 0 disables printing.
    """

    def __init__(self, controller, child_model_builder, evaluator,
                 entropy_weight: float = 1e-4, grad_clip: float = 5.0,
                 log_interval: int = 100):
        self.controller = controller
        self.child_model_builder = child_model_builder
        self.evaluator = evaluator

        self.controller_optimizer = torch.optim.Adam(
            controller.parameters(), lr=3.5e-4
        )

        # Exponential moving average of rewards, for variance reduction.
        self.baseline = None
        self.baseline_decay = 0.99

        self.entropy_weight = entropy_weight
        self.grad_clip = grad_clip
        self.log_interval = log_interval

    def train_controller(self, num_epochs: int = 2000):
        """Run ``num_epochs`` REINFORCE updates, one sampled architecture each."""
        for epoch in range(num_epochs):
            arc_seq, log_probs, entropies = self.controller()

            child_model = self.child_model_builder.build(arc_seq)
            # The reward is a constant w.r.t. the controller; float() also
            # drops any autograd graph a tensor reward might carry.
            reward = float(self.evaluator.evaluate(child_model))

            if self.baseline is None:
                self.baseline = reward
            else:
                self.baseline = self.baseline_decay * self.baseline + \
                                (1 - self.baseline_decay) * reward

            advantage = reward - self.baseline

            # REINFORCE: raise log-probability of above-baseline choices.
            controller_loss = (-log_probs * advantage).sum()
            # Negative entropy is minimized, i.e. entropy is rewarded,
            # which encourages exploration.
            entropy_bonus = -entropies.sum() * self.entropy_weight
            total_loss = controller_loss + entropy_bonus

            self.controller_optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
            self.controller_optimizer.step()

            if self.log_interval and epoch % self.log_interval == 0:
                print(f'Epoch {epoch}, Reward: {reward:.4f}, '
                      f'Baseline: {self.baseline:.4f}, Loss: {total_loss.item():.4f}')

Best Practices

1. Search Space Design Guidelines

class SearchSpaceDesignPrinciples:
    """
    Rules of thumb for NAS search spaces, plus helpers that produce and
    sanity-check search-space configuration dicts.
    """

    def __init__(self):
        self.principles = dict(
            expressiveness='Include diverse operations and connections',
            efficiency='Balance search space size with computational cost',
            human_knowledge='Incorporate domain-specific insights',
            scalability='Design for different input sizes and tasks',
        )

    def design_macro_space(self, task_type: str):
        """Return a fresh macro search-space config for ``task_type``.

        Supported tasks are ``'image_classification'`` and
        ``'object_detection'``; any other value yields ``None``.
        """
        templates = {
            'image_classification': {
                'operations': ['conv3x3', 'conv5x5', 'depthwise_conv', 'pointwise_conv',
                               'max_pool', 'avg_pool', 'global_pool', 'identity'],
                'max_layers': 20,
                'channels': [16, 32, 64, 128, 256, 512],
                'skip_connections': True,
                'batch_norm': True,
                'activation': ['relu', 'relu6', 'swish'],
            },
            'object_detection': {
                'operations': ['conv3x3', 'conv5x5', 'depthwise_conv', 'atrous_conv',
                               'max_pool', 'avg_pool', 'identity'],
                'max_layers': 30,
                'channels': [32, 64, 128, 256, 512, 1024],
                'skip_connections': True,
                'fpn_layers': True,
                'anchor_scales': [32, 64, 128, 256, 512],
            },
        }
        for name, template in templates.items():
            if task_type == name:
                return template
        return None

    def validate_search_space(self, search_space):
        """Return a list of human-readable issues found in ``search_space`` (empty if none)."""
        operations = search_space.get('operations', [])
        channels = search_space.get('channels', [])
        issues = []

        if len(operations) < 3:
            issues.append("Too few operations - may limit expressiveness")

        if 'identity' not in operations:
            issues.append("Missing identity operation - may hurt skip connections")

        is_non_decreasing = all(lo <= hi for lo, hi in zip(channels, channels[1:]))
        if channels and not is_non_decreasing:
            issues.append("Non-monotonic channel progression")

        return issues

2. Training Strategies

class NASTrainingStrategies:
    """Advanced training strategies for NAS.

    The methods call ``self._train_stage``, ``self._build_model`` and
    ``self._evaluate_model``, which are not defined in this class; they must
    be provided (e.g. by a subclass) before the corresponding method is used.
    """

    def __init__(self):
        # Strategy registry; not read by any method in this class.
        self.strategies = {}

    def progressive_shrinking(self, supernet, dataset, stages: int = 4,
                              epochs_per_stage: int = 50):
        """Train ``supernet`` in stages under a halving channel cap.

        At 0-based stage ``s`` the cap is ``supernet.max_channels // 2**s``
        (clamped to at least 1). It is applied through
        ``supernet.set_channel_constraint`` before the stage is trained.

        Args:
            supernet: object exposing ``max_channels`` and
                ``set_channel_constraint(int)``.
            dataset: forwarded unchanged to ``self._train_stage``.
            stages: number of stages to run.
            epochs_per_stage: ``epochs`` value passed to ``self._train_stage``.
        """
        max_channels = supernet.max_channels

        for stage in range(stages):
            # Halve the cap each stage; never request a zero-width network
            # when the schedule is deeper than log2(max_channels).
            target_channels = max(1, max_channels // (2 ** stage))
            supernet.set_channel_constraint(target_channels)

            self._train_stage(supernet, dataset, epochs=epochs_per_stage)

            print(f"Stage {stage + 1}: Max channels = {target_channels}")

    def sandwich_sampling(self, supernet, dataset, epochs: int = 100,
                          lr: float = 0.01, num_random: int = 2):
        """Train ``supernet`` with the sandwich rule.

        For every batch, the largest sub-network, the smallest one and
        ``num_random`` random ones are run in turn. Their gradients are
        accumulated, and a single optimizer step is taken per batch.

        Args:
            supernet: module exposing ``parameters()``,
                ``largest_architecture()``, ``smallest_architecture()``,
                ``random_architecture()`` and ``set_active_subnet(arch)``.
            dataset: iterable of ``(inputs, targets)`` batches.
            epochs: number of passes over ``dataset``.
            lr: SGD learning rate.
            num_random: number of random sub-networks sampled per batch.
        """
        optimizer = torch.optim.SGD(supernet.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0

            for inputs, targets in dataset:
                architectures = [
                    supernet.largest_architecture(),
                    supernet.smallest_architecture(),
                ]
                architectures.extend(
                    supernet.random_architecture() for _ in range(num_random)
                )

                # Zero once per batch: the sandwich rule sums the gradients
                # of all sampled sub-networks before stepping. Zeroing inside
                # the loop below would keep only the last sub-network's grad.
                optimizer.zero_grad()
                batch_loss = 0.0
                for arch in architectures:
                    supernet.set_active_subnet(arch)
                    loss = criterion(supernet(inputs), targets)
                    loss.backward()
                    batch_loss += loss.item()
                optimizer.step()

                epoch_loss += batch_loss / len(architectures)
                num_batches += 1

            # Log once per logged epoch (average over batches), not per batch.
            if epoch % 10 == 0 and num_batches:
                print(f"Epoch {epoch}, Loss: {epoch_loss / num_batches:.4f}")

    def knowledge_distillation(self, student_arch, teacher_model, dataset,
                               epochs: int = 50, alpha: float = 0.7,
                               temperature: float = 4.0, lr: float = 0.01):
        """Train a student built from ``student_arch`` by distillation.

        The loss is ``alpha * T**2 * KL(teacher_T || student_T) +
        (1 - alpha) * CE(student, targets)``, where ``*_T`` are
        temperature-softened distributions.

        Args:
            student_arch: architecture passed to ``self._build_model``.
            teacher_model: ``nn.Module`` used (frozen, in eval mode) to
                produce soft targets; its train/eval mode is restored on exit.
            dataset: iterable of ``(inputs, targets)`` batches.
            epochs: number of passes over ``dataset``.
            alpha: weight of the distillation term.
            temperature: softmax temperature ``T``.
            lr: SGD learning rate.

        Returns:
            whatever ``self._evaluate_model(student_model)`` returns.
        """
        import torch.nn.functional as F

        student_model = self._build_model(student_arch)

        optimizer = torch.optim.SGD(student_model.parameters(), lr=lr)
        # 'batchmean' matches the mathematical KL divergence; the default
        # 'mean' additionally divides by the number of classes.
        kd_loss = nn.KLDivLoss(reduction='batchmean')
        ce_loss = nn.CrossEntropyLoss()

        # Eval mode keeps teacher dropout off and stops BatchNorm from
        # updating running stats (which happens even under no_grad).
        teacher_was_training = teacher_model.training
        teacher_model.eval()
        student_model.train()
        try:
            for _ in range(epochs):
                for inputs, targets in dataset:
                    with torch.no_grad():
                        teacher_outputs = teacher_model(inputs)
                        teacher_probs = F.softmax(teacher_outputs / temperature, dim=1)

                    student_outputs = student_model(inputs)
                    student_log_probs = F.log_softmax(student_outputs / temperature, dim=1)

                    # T**2 keeps soft-target gradient magnitudes comparable
                    # to the hard-target term (Hinton et al., 2015).
                    distill_loss = kd_loss(student_log_probs, teacher_probs) * temperature ** 2
                    hard_loss = ce_loss(student_outputs, targets)

                    total_loss = alpha * distill_loss + (1 - alpha) * hard_loss

                    optimizer.zero_grad()
                    total_loss.backward()
                    optimizer.step()
        finally:
            teacher_model.train(teacher_was_training)

        return self._evaluate_model(student_model)

3. Evaluation and Benchmarking

class NASBenchmarking:
    """Benchmarking and evaluation utilities.

    ``comprehensive_evaluation`` calls ``_build_model``,
    ``_evaluate_accuracy``, ``_measure_memory``,
    ``_evaluate_adversarial_robustness`` and ``_evaluate_noise_robustness``.
    These are not defined in this class and must be provided elsewhere
    (e.g. by a subclass).
    """

    def __init__(self):
        # Metric store; not read by any method in this class.
        self.metrics = {}

    def comprehensive_evaluation(self, architecture, datasets):
        """Evaluate ``architecture`` on accuracy, efficiency and robustness.

        Args:
            architecture: passed to ``self._build_model``.
            datasets: mapping of dataset name to dataset; each produces a
                ``'<name>_accuracy'`` entry.

        Returns:
            dict of metric name to value.
        """
        results = {}

        model = self._build_model(architecture)

        # Accuracy metrics
        for dataset_name, dataset in datasets.items():
            accuracy = self._evaluate_accuracy(model, dataset)
            results[f'{dataset_name}_accuracy'] = accuracy

        # Efficiency metrics
        results['params'] = self._count_parameters(model)
        results['flops'] = self._count_flops(model)
        results['latency'] = self._measure_latency(model)
        results['memory'] = self._measure_memory(model)

        # Robustness metrics
        results['adversarial_robustness'] = self._evaluate_adversarial_robustness(model)
        results['noise_robustness'] = self._evaluate_noise_robustness(model)

        return results

    def _count_parameters(self, model):
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    @staticmethod
    def _model_device(model):
        """Device of the model's first parameter, or CPU if it has none."""
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def _count_flops(self, model, input_size=(1, 3, 224, 224)):
        """Count multiply-accumulates of ``nn.Conv2d`` and ``nn.Linear`` layers.

        Runs one forward pass on a random input of shape ``input_size``.
        Each layer call contributes ``output.numel() * fan_in``, where
        ``fan_in`` is ``in_channels // groups * kh * kw`` for Conv2d and
        ``in_features`` for Linear. Biases, activations, normalization and
        pooling are not counted. The count scales with the batch dimension of
        ``input_size``. The model's train/eval mode is restored afterwards.

        Returns:
            int total operation count.
        """
        total = 0

        def count_hook(module, inputs, output):
            nonlocal total
            if isinstance(module, nn.Conv2d):
                kernel_height, kernel_width = module.kernel_size
                # Grouped/depthwise convs only see in_channels // groups inputs.
                fan_in = (module.in_channels // module.groups) * kernel_height * kernel_width
            else:  # nn.Linear
                fan_in = module.in_features
            total += output.numel() * fan_in

        hooks = [
            module.register_forward_hook(count_hook)
            for module in model.modules()
            if isinstance(module, (nn.Conv2d, nn.Linear))
        ]
        was_training = model.training
        try:
            model.eval()
            with torch.no_grad():
                dummy_input = torch.randn(input_size, device=self._model_device(model))
                model(dummy_input)
        finally:
            # Always detach hooks so the model is left unmodified, even if
            # the forward pass raises.
            for hook in hooks:
                hook.remove()
            model.train(was_training)

        return total

    def _measure_latency(self, model, input_size=(1, 3, 224, 224), runs=100):
        """Return mean seconds per forward pass over ``runs`` timed passes.

        Ten untimed warm-up passes run first. On CUDA devices the stream is
        synchronized around the timed region so asynchronous kernels are
        included. The model's train/eval mode is restored afterwards.

        Raises:
            ValueError: if ``runs`` is not positive.
        """
        import time

        if runs <= 0:
            raise ValueError(f"runs must be positive, got {runs}")

        device = self._model_device(model)
        dummy_input = torch.randn(input_size, device=device)

        def synchronize():
            if device.type == 'cuda':
                torch.cuda.synchronize(device)

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                # Warmup
                for _ in range(10):
                    model(dummy_input)

                synchronize()
                # perf_counter is monotonic and high-resolution, unlike time.time.
                start_time = time.perf_counter()
                for _ in range(runs):
                    model(dummy_input)
                synchronize()
                elapsed = time.perf_counter() - start_time
        finally:
            model.train(was_training)

        return elapsed / runs

    def compare_search_methods(self, methods, search_space, evaluator, runs=5):
        """Run every search method ``runs`` times and analyze the results.

        Each run ``r`` seeds torch, ``random`` and NumPy with ``r`` so that
        all methods see the same seed sequence.

        Args:
            methods: mapping of method name to an object whose
                ``search(search_space, evaluator)`` returns
                ``(best_arch, best_performance)``.
            search_space: forwarded to each ``search`` call.
            evaluator: forwarded to each ``search`` call.
            runs: number of seeded runs per method.

        Returns:
            the dict produced by ``_analyze_comparison_results``.
        """
        results = {}

        for method_name, method in methods.items():
            method_results = []

            for run in range(runs):
                # Set random seed for reproducibility
                torch.manual_seed(run)
                random.seed(run)
                np.random.seed(run)

                best_arch, best_performance = method.search(search_space, evaluator)
                method_results.append({
                    'architecture': best_arch,
                    'performance': best_performance,
                    'run': run
                })

            results[method_name] = method_results

        return self._analyze_comparison_results(results)

    def _analyze_comparison_results(self, results):
        """Summarize per-method performance statistics and rank the methods.

        Methods are ranked by mean performance, higher first. Each method
        needs at least one result (NumPy's max/min reject empty input).

        Returns:
            dict with ``'detailed_results'``, ``'ranking'`` (list of
            ``(name, stats)`` pairs) and ``'summary'`` (text).
        """
        analysis = {}

        for method_name, method_results in results.items():
            performances = [r['performance'] for r in method_results]

            analysis[method_name] = {
                'mean_performance': np.mean(performances),
                'std_performance': np.std(performances),
                'best_performance': np.max(performances),
                'worst_performance': np.min(performances),
                'median_performance': np.median(performances)
            }

        # Rank methods (higher mean performance is better)
        ranked_methods = sorted(analysis.items(),
                                key=lambda x: x[1]['mean_performance'],
                                reverse=True)

        return {
            'detailed_results': analysis,
            'ranking': ranked_methods,
            'summary': self._generate_summary(ranked_methods)
        }

    def _generate_summary(self, ranked_methods):
        """Format ranked ``(name, stats)`` pairs as a multi-line report."""
        summary = []
        summary.append("=== NAS Method Comparison Results ===")

        for i, (method_name, stats) in enumerate(ranked_methods):
            summary.append(f"{i+1}. {method_name}:")
            summary.append(f"   Mean: {stats['mean_performance']:.4f}")
            summary.append(f"   Std:  {stats['std_performance']:.4f}")
            summary.append(f"   Best: {stats['best_performance']:.4f}")
            summary.append("")

        return "\n".join(summary)

Tools and Frameworks