Neural Architecture Search: Complete Code Guide

Introduction
Neural Architecture Search (NAS) is an automated approach to designing neural network architectures. Instead of manually crafting network designs, NAS algorithms explore the space of possible architectures to find optimal configurations for specific tasks.
Why NAS Matters
- Automation: Reduces human effort in architecture design
- Performance: Can discover architectures that outperform human-designed ones
- Efficiency: Optimizes for specific constraints (latency, memory, energy)
- Scalability: Adapts to different tasks and domains
Theoretical Foundations
The NAS Framework
NAS consists of three main components:
- Search Space: Defines the set of possible architectures
- Search Strategy: Determines how to explore the search space
- Performance Estimation: Evaluates architecture quality
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import time
import copy
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
class NASFramework:
def __init__(self, search_space, search_strategy, performance_estimator):
self.search_space = search_space
self.search_strategy = search_strategy
self.performance_estimator = performance_estimator
self.history = []
def search(self, num_iterations: int):
"""Main NAS loop"""
for iteration in range(num_iterations):
# Sample architecture from search space
architecture = self.search_strategy.sample_architecture(
self.search_space, self.history
)
# Evaluate architecture
performance = self.performance_estimator.evaluate(architecture)
# Update history
self.history.append({
'architecture': architecture,
'performance': performance,
'iteration': iteration
})
# Update search strategy
self.search_strategy.update(architecture, performance)
return self.get_best_architecture()
def get_best_architecture(self):
        return max(self.history, key=lambda x: x['performance'])
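
As a rough sketch of how the three components plug together (using MacroSearchSpace, RandomSearch, and EarlyStoppingEvaluator, all defined later in this guide; my_train_loader is a placeholder for whatever DataLoader you supply):

# Hypothetical wiring of the framework; my_train_loader is assumed, not defined here
search_space = MacroSearchSpace(max_layers=12)
strategy = RandomSearch()
evaluator = EarlyStoppingEvaluator(dataset=my_train_loader, max_epochs=5)

nas = NASFramework(search_space, strategy, evaluator)
best = nas.search(num_iterations=50)
print(best['architecture'], best['performance'])

Search Space Design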
Macro Search Space
Defines the overall structure of the network (number of layers, skip connections, etc.).
class MacroSearchSpace:
def __init__(self, max_layers: int = 20, operations: List[str] = None):
self.max_layers = max_layers
self.operations = operations or [
'conv3x3', 'conv5x5', 'conv7x7', 'maxpool3x3',
'avgpool3x3', 'identity', 'zero'
]
def sample_architecture(self) -> Dict:
"""Sample a random architecture"""
num_layers = random.randint(8, self.max_layers)
architecture = {
'layers': [],
'skip_connections': []
}
for i in range(num_layers):
            operation = random.choice(self.operations)
            layer = {
                'operation': operation,
                'filters': random.choice([32, 64, 128, 256, 512]),
                # Only convolution operations get a sampled kernel size
                'kernel_size': random.choice([3, 5, 7]) if 'conv' in operation else 3
            }
architecture['layers'].append(layer)
# Add skip connections
for i in range(1, num_layers):
if random.random() < 0.3: # 30% chance of skip connection
source = random.randint(0, i-1)
architecture['skip_connections'].append((source, i))
        return architecture
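
The evaluators later in this guide leave a _build_model helper unimplemented because it depends on the architecture encoding. For the macro encoding above, a minimal, hypothetical builder that ignores skip connections and assumes 3-channel image input could look like this:

def build_macro_model(architecture: Dict, num_classes: int = 10) -> nn.Module:
    """Sketch: map each sampled layer to a PyTorch module (skip connections ignored)."""
    modules = []
    in_channels = 3
    for layer in architecture['layers']:
        op, filters = layer['operation'], layer['filters']
        if 'conv' in op:
            k = layer.get('kernel_size', 3)
            modules += [nn.Conv2d(in_channels, filters, k, padding=k // 2),
                        nn.BatchNorm2d(filters), nn.ReLU(inplace=True)]
            in_channels = filters
        elif 'maxpool' in op:
            modules.append(nn.MaxPool2d(3, stride=1, padding=1))
        elif 'avgpool' in op:
            modules.append(nn.AvgPool2d(3, stride=1, padding=1))
        # 'identity' and 'zero' contribute nothing in this simplified builder
    modules += [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(in_channels, num_classes)]
    return nn.Sequential(*modules)

A sampled architecture can then be instantiated with model = build_macro_model(MacroSearchSpace().sample_architecture()).

Micro Search Space (Cell-based)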
Focuses on designing building blocks (cells) that are repeated throughout the network.
class CellSearchSpace:
def __init__(self, num_nodes: int = 7, num_ops: int = 8):
self.num_nodes = num_nodes
self.operations = [
'none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5'
]
self.num_ops = len(self.operations)
def sample_cell(self) -> Dict:
"""Sample a cell architecture"""
cell = {
'normal_cell': self._sample_single_cell(),
'reduction_cell': self._sample_single_cell()
}
return cell
def _sample_single_cell(self) -> List[Tuple]:
"""Sample a single cell with intermediate nodes"""
cell = []
for i in range(2, self.num_nodes + 2): # Nodes 2 to num_nodes+1
# Each node has two inputs
for j in range(2):
# Sample input node (0 to i-1)
input_node = random.randint(0, i-1)
# Sample operation
operation = random.choice(self.operations)
cell.append((input_node, operation))
        return cell

Differentiable Search Space
Enables gradient-based optimization of architectures.
class DifferentiableSearchSpace(nn.Module):
def __init__(self, operations: List[str], num_nodes: int = 4):
super().__init__()
self.operations = operations
self.num_nodes = num_nodes
self.num_ops = len(operations)
# Architecture parameters (alpha)
        self.alpha = nn.Parameter(torch.randn(num_nodes, self.num_ops))
# Operation modules
self.ops = nn.ModuleList([
self._get_operation(op) for op in operations
])
def _get_operation(self, op_name: str) -> nn.Module:
"""Get operation module by name"""
if op_name == 'conv3x3':
return nn.Conv2d(32, 32, 3, padding=1)
elif op_name == 'conv5x5':
return nn.Conv2d(32, 32, 5, padding=2)
elif op_name == 'maxpool3x3':
return nn.MaxPool2d(3, stride=1, padding=1)
elif op_name == 'avgpool3x3':
return nn.AvgPool2d(3, stride=1, padding=1)
elif op_name == 'identity':
return nn.Identity()
elif op_name == 'zero':
return Zero()
else:
raise ValueError(f"Unknown operation: {op_name}")
def forward(self, x):
# Softmax over operations
weights = torch.softmax(self.alpha, dim=-1)
# Mixed operation
output = 0
for i, op in enumerate(self.ops):
output += weights[0, i] * op(x) # Simplified for single node
return output
def get_discrete_architecture(self):
"""Extract discrete architecture from continuous parameters"""
arch = []
for node in range(self.num_nodes):
best_op_idx = torch.argmax(self.alpha[node])
arch.append(self.operations[best_op_idx])
return arch
class Zero(nn.Module):
def forward(self, x):
        return torch.zeros_like(x)

Search Strategies
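Every strategy in this section is assumed to expose the two methods that NASFramework calls: sample_architecture(search_space, history) and update(architecture, performance). A minimal sketch of that contract (not part of any library):

class SearchStrategy:
    """Interface assumed by NASFramework above (a sketch, not a library class)."""

    def sample_architecture(self, search_space, history) -> Dict:
        """Propose the next architecture to evaluate."""
        raise NotImplementedError

    def update(self, architecture: Dict, performance: float) -> None:
        """Receive the measured performance of a previously proposed architecture."""
        raise NotImplementedError
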
Random Search
Simple baseline that samples architectures randomly.
class RandomSearch:
def __init__(self):
self.history = []
def sample_architecture(self, search_space, history):
return search_space.sample_architecture()
def update(self, architecture, performance):
        self.history.append((architecture, performance))

Evolutionary Search
Uses genetic algorithms to evolve architectures.
class EvolutionarySearch:
def __init__(self, population_size: int = 50, mutation_rate: float = 0.1):
self.population_size = population_size
self.mutation_rate = mutation_rate
self.population = []
self.fitness_scores = []
def initialize_population(self, search_space):
"""Initialize random population"""
self.population = [
search_space.sample_architecture()
for _ in range(self.population_size)
]
    def sample_architecture(self, search_space, history):
        # Grow the population with random samples first, then evolve via
        # tournament selection plus mutation so new architectures keep appearing
        if len(self.population) < self.population_size:
            return search_space.sample_architecture()
        parent = self._tournament_selection()
        return self.mutate_architecture(copy.deepcopy(parent), search_space)
def _tournament_selection(self, tournament_size: int = 3):
"""Select parent via tournament selection"""
tournament_indices = random.sample(
range(len(self.population)), tournament_size
)
tournament_fitness = [self.fitness_scores[i] for i in tournament_indices]
winner_idx = tournament_indices[np.argmax(tournament_fitness)]
return self.population[winner_idx]
def update(self, architecture, performance):
"""Update population with new architecture"""
if len(self.population) < self.population_size:
self.population.append(architecture)
self.fitness_scores.append(performance)
else:
# Replace worst performing architecture
worst_idx = np.argmin(self.fitness_scores)
if performance > self.fitness_scores[worst_idx]:
self.population[worst_idx] = architecture
self.fitness_scores[worst_idx] = performance
def mutate_architecture(self, architecture, search_space):
"""Mutate architecture"""
if random.random() < self.mutation_rate:
# Simple mutation: change random operation
if 'layers' in architecture:
layer_idx = random.randint(0, len(architecture['layers']) - 1)
architecture['layers'][layer_idx]['operation'] = random.choice(
search_space.operations
)
        return architecture

Reinforcement Learning Search
Uses RL to learn architecture sampling policies.
class RLController(nn.Module):
def __init__(self, vocab_size: int, hidden_size: int = 64):
super().__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.lstm = nn.LSTM(vocab_size, hidden_size, batch_first=True)
self.classifier = nn.Linear(hidden_size, vocab_size)
def forward(self, x):
lstm_out, _ = self.lstm(x)
logits = self.classifier(lstm_out)
return logits
def sample_architecture(self, max_length: int = 20):
"""Sample architecture using the controller"""
self.eval()
with torch.no_grad():
sequence = []
hidden = None
# Start token
input_token = torch.zeros(1, 1, self.vocab_size)
for _ in range(max_length):
logits, hidden = self.lstm(input_token, hidden)
logits = self.classifier(logits)
# Sample next token
probs = torch.softmax(logits.squeeze(), dim=0)
next_token = torch.multinomial(probs, 1).item()
sequence.append(next_token)
# Prepare input for next step
input_token = torch.zeros(1, 1, self.vocab_size)
input_token[0, 0, next_token] = 1
return sequence
class ReinforcementLearningSearch:
def __init__(self, vocab_size: int, learning_rate: float = 0.001):
self.controller = RLController(vocab_size)
self.optimizer = torch.optim.Adam(
self.controller.parameters(), lr=learning_rate
)
self.baseline = 0
self.baseline_decay = 0.99
def sample_architecture(self, search_space, history):
sequence = self.controller.sample_architecture()
return self._sequence_to_architecture(sequence, search_space)
def _sequence_to_architecture(self, sequence, search_space):
"""Convert sequence to architecture"""
# Simplified conversion
architecture = {'layers': []}
for i in range(0, len(sequence), 2):
if i + 1 < len(sequence):
op_idx = sequence[i] % len(search_space.operations)
filter_idx = sequence[i + 1] % 4
layer = {
'operation': search_space.operations[op_idx],
'filters': [32, 64, 128, 256][filter_idx]
}
architecture['layers'].append(layer)
return architecture
def update(self, architecture, performance):
"""Update controller using REINFORCE"""
# Update baseline
self.baseline = self.baseline_decay * self.baseline + \
(1 - self.baseline_decay) * performance
# Calculate advantage
advantage = performance - self.baseline
        # Update controller (simplified placeholder)
        self.optimizer.zero_grad()
        # In practice you would store the log probabilities of the sampled tokens
        # during sampling and minimize loss = -log_prob * advantage; a complete
        # REINFORCE update of this kind is shown in RLNASTrainer further below.
        self.optimizer.step()

Differentiable Architecture Search (DARTS)
Gradient-based search that relaxes the discrete choice of operations into a softmax-weighted mixture, so the architecture parameters (alpha) can be updated by gradient descent on the validation loss while the network weights are trained on the training loss.
class DARTSSearch:
def __init__(self, model: DifferentiableSearchSpace, learning_rate: float = 0.025):
self.model = model
        # Weight optimizer; the architecture parameters (alpha) are excluded
        # so that only the architecture optimizer below updates them
        self.optimizer = torch.optim.SGD(
            [p for n, p in self.model.named_parameters() if n != 'alpha'],
            lr=learning_rate, momentum=0.9
        )
        self.arch_optimizer = torch.optim.Adam(
            [self.model.alpha], lr=3e-4
        )
def search_step(self, train_data, val_data, criterion):
"""Single search step in DARTS"""
# Update architecture parameters
self.arch_optimizer.zero_grad()
val_loss = self._compute_val_loss(val_data, criterion)
val_loss.backward()
self.arch_optimizer.step()
# Update model parameters
self.optimizer.zero_grad()
train_loss = self._compute_train_loss(train_data, criterion)
train_loss.backward()
self.optimizer.step()
return train_loss.item(), val_loss.item()
def _compute_train_loss(self, data, criterion):
"""Compute training loss"""
inputs, targets = data
outputs = self.model(inputs)
return criterion(outputs, targets)
def _compute_val_loss(self, data, criterion):
"""Compute validation loss"""
inputs, targets = data
outputs = self.model(inputs)
return criterion(outputs, targets)
def get_final_architecture(self):
"""Extract final discrete architecture"""
        return self.model.get_discrete_architecture()

Performance Estimation
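Each estimator below is assumed to implement a single evaluate(architecture) method returning a float where higher is better (NASFramework takes the maximum over its history). A minimal sketch of that contract:

class PerformanceEstimator:
    """Interface assumed by NASFramework (a sketch; higher returned values are better)."""

    def evaluate(self, architecture: Dict) -> float:
        """Return an estimate of architecture quality, e.g. validation accuracy in [0, 1]."""
        raise NotImplementedError
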
Full Training
Most accurate but computationally expensive.
class FullTrainingEvaluator:
def __init__(self, dataset, num_epochs: int = 100):
self.dataset = dataset
self.num_epochs = num_epochs
def evaluate(self, architecture) -> float:
"""Evaluate architecture by full training"""
model = self._build_model(architecture)
# Training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
for epoch in range(self.num_epochs):
for batch in self.dataset:
inputs, targets = batch
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
# Evaluate on validation set
return self._evaluate_model(model)
def _build_model(self, architecture):
"""Build model from architecture description"""
# Implementation depends on architecture format
pass
def _evaluate_model(self, model):
"""Evaluate model accuracy"""
# Implementation for model evaluation
        pass
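
_build_model and _evaluate_model are left as stubs because they depend on your architecture encoding and data pipeline. A minimal accuracy helper that _evaluate_model could delegate to, assuming a validation DataLoader yielding (inputs, targets) batches:

def evaluate_accuracy(model: nn.Module, val_loader) -> float:
    """Top-1 accuracy of `model` on `val_loader` (a sketch; val_loader is assumed to yield (inputs, targets))."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    return correct / total

Early Stopping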
Reduces training time while maintaining correlation with full training.
class EarlyStoppingEvaluator:
def __init__(self, dataset, max_epochs: int = 20, patience: int = 5):
self.dataset = dataset
self.max_epochs = max_epochs
self.patience = patience
def evaluate(self, architecture) -> float:
"""Evaluate with early stopping"""
model = self._build_model(architecture)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
best_val_acc = 0
patience_counter = 0
for epoch in range(self.max_epochs):
# Training
train_loss = self._train_epoch(model, optimizer, criterion)
# Validation
val_acc = self._validate_epoch(model)
# Early stopping check
if val_acc > best_val_acc:
best_val_acc = val_acc
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= self.patience:
break
        return best_val_acc
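
The _build_model, _train_epoch, and _validate_epoch helpers referenced above are assumed to be ordinary training and evaluation loops. A sketch of the training half under that assumption (the validation half can reuse the evaluate_accuracy helper sketched earlier):

def train_one_epoch(model: nn.Module, loader, optimizer, criterion) -> float:
    """One pass over the training loader; returns the mean loss (a sketch)."""
    model.train()
    total_loss = 0.0
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)

Weight Sharing (One-Shot)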
Trains a super-network once and evaluates sub-networks by inheritance.
class WeightSharingEvaluator:
def __init__(self, supernet: nn.Module, dataset):
self.supernet = supernet
self.dataset = dataset
self.trained = False
def train_supernet(self):
"""Train the supernet once"""
if self.trained:
return
optimizer = torch.optim.SGD(self.supernet.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
for epoch in range(50): # Train supernet
for batch in self.dataset:
inputs, targets = batch
optimizer.zero_grad()
# Sample random path through supernet
self.supernet.sample_active_subnet()
outputs = self.supernet(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
self.trained = True
def evaluate(self, architecture) -> float:
"""Evaluate architecture using trained supernet"""
if not self.trained:
self.train_supernet()
# Configure supernet for specific architecture
self.supernet.set_active_subnet(architecture)
# Evaluate on validation set
return self._evaluate_subnet()
def _evaluate_subnet(self):
"""Evaluate current subnet configuration"""
self.supernet.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in self.dataset:
inputs, targets = batch
outputs = self.supernet(inputs)
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += (predicted == targets).sum().item()
        return correct / total
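
WeightSharingEvaluator assumes the supernet provides sample_active_subnet() and set_active_subnet(architecture); neither is a standard PyTorch method. A toy single-choice supernet satisfying that interface might look like this:

class ToySupernet(nn.Module):
    """Minimal one-shot model: one searchable layer whose active operation can be switched."""

    def __init__(self, channels: int = 16, num_classes: int = 10):
        super().__init__()
        self.stem = nn.Conv2d(3, channels, 3, padding=1)
        self.candidates = nn.ModuleDict({
            'conv3x3': nn.Conv2d(channels, channels, 3, padding=1),
            'conv5x5': nn.Conv2d(channels, channels, 5, padding=2),
            'identity': nn.Identity(),
        })
        self.active = 'conv3x3'
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                                  nn.Linear(channels, num_classes))

    def sample_active_subnet(self):
        # Pick a random candidate path for this training step
        self.active = random.choice(list(self.candidates.keys()))

    def set_active_subnet(self, architecture):
        # In this toy example an "architecture" is just the name of the chosen operation
        self.active = architecture

    def forward(self, x):
        return self.head(self.candidates[self.active](self.stem(x)))

Advanced Techniques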
Progressive Search
Gradually increases search space complexity.
class ProgressiveSearch:
def __init__(self, base_search_space, max_complexity: int = 5):
self.base_search_space = base_search_space
self.max_complexity = max_complexity
self.current_complexity = 1
self.search_strategy = EvolutionarySearch()
def search(self, iterations_per_stage: int = 100):
"""Progressive search with increasing complexity"""
best_architectures = []
for complexity in range(1, self.max_complexity + 1):
self.current_complexity = complexity
# Search at current complexity level
for _ in range(iterations_per_stage):
architecture = self._sample_architecture_at_complexity()
performance = self._evaluate_architecture(architecture)
self.search_strategy.update(architecture, performance)
            # Get the best architecture found at this complexity level
            # (EvolutionarySearch keeps a population and fitness scores, not a history list)
            best_idx = int(np.argmax(self.search_strategy.fitness_scores))
            best_architectures.append((self.search_strategy.population[best_idx],
                                       self.search_strategy.fitness_scores[best_idx]))
return best_architectures
def _sample_architecture_at_complexity(self):
"""Sample architecture with limited complexity"""
arch = self.base_search_space.sample_architecture()
# Limit architecture complexity
arch['layers'] = arch['layers'][:self.current_complexity * 3]
        return arch

Multi-Objective NAS
Optimizes multiple objectives simultaneously.
class MultiObjectiveNAS:
def __init__(self, objectives: List[str]):
self.objectives = objectives # e.g., ['accuracy', 'latency', 'flops']
self.pareto_front = []
def evaluate_architecture(self, architecture) -> Dict[str, float]:
"""Evaluate architecture on multiple objectives"""
results = {}
if 'accuracy' in self.objectives:
results['accuracy'] = self._evaluate_accuracy(architecture)
if 'latency' in self.objectives:
results['latency'] = self._evaluate_latency(architecture)
if 'flops' in self.objectives:
results['flops'] = self._evaluate_flops(architecture)
return results
def update_pareto_front(self, architecture, objectives):
"""Update Pareto front with new architecture"""
# Check if architecture is dominated
dominated = False
for pareto_arch, pareto_obj in self.pareto_front:
if self._dominates(pareto_obj, objectives):
dominated = True
break
if not dominated:
# Remove dominated architectures
self.pareto_front = [
(arch, obj) for arch, obj in self.pareto_front
if not self._dominates(objectives, obj)
]
# Add new architecture
self.pareto_front.append((architecture, objectives))
def _dominates(self, obj1: Dict, obj2: Dict) -> bool:
"""Check if obj1 dominates obj2"""
better_in_all = True
strictly_better_in_one = False
for objective in self.objectives:
if objective in ['accuracy']: # Higher is better
if obj1[objective] < obj2[objective]:
better_in_all = False
elif obj1[objective] > obj2[objective]:
strictly_better_in_one = True
else: # Lower is better (latency, flops)
if obj1[objective] > obj2[objective]:
better_in_all = False
elif obj1[objective] < obj2[objective]:
strictly_better_in_one = True
        return better_in_all and strictly_better_in_one
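
A quick illustration of the dominance logic with two hypothetical architectures (the accuracy and latency numbers are made up):

nas = MultiObjectiveNAS(objectives=['accuracy', 'latency'])
nas.update_pareto_front('arch_a', {'accuracy': 0.94, 'latency': 12.0})
nas.update_pareto_front('arch_b', {'accuracy': 0.92, 'latency': 20.0})  # dominated: worse on both objectives
print(len(nas.pareto_front))  # 1 -- only arch_a survives

Implementation Examples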
Complete DARTS Implementation
class DARTSCell(nn.Module):
def __init__(self, num_nodes: int, channels: int):
super().__init__()
self.num_nodes = num_nodes
self.channels = channels
# Mixed operations for each edge
self.mixed_ops = nn.ModuleList()
for i in range(num_nodes):
for j in range(2 + i): # Each node connects to all previous nodes
self.mixed_ops.append(MixedOp(channels))
# Architecture parameters
self.alpha = nn.Parameter(torch.randn(len(self.mixed_ops), 8))
def forward(self, inputs):
# inputs[0] and inputs[1] are the two input nodes
states = [inputs[0], inputs[1]]
offset = 0
for i in range(self.num_nodes):
# Collect inputs from all previous nodes
node_inputs = []
for j in range(len(states)):
op_idx = offset + j
node_inputs.append(self.mixed_ops[op_idx](states[j], self.alpha[op_idx]))
# Sum all inputs to this node
state = sum(node_inputs)
states.append(state)
offset += len(states) - 1
# Concatenate final nodes
return torch.cat(states[-self.num_nodes:], dim=1)
class MixedOp(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.ops = nn.ModuleList([
SepConv(channels, channels, 3, 1, 1),
SepConv(channels, channels, 5, 1, 2),
DilConv(channels, channels, 3, 1, 2, 2),
DilConv(channels, channels, 5, 1, 4, 2),
nn.MaxPool2d(3, 1, 1),
nn.AvgPool2d(3, 1, 1),
Identity(),
Zero()
])
def forward(self, x, alpha):
# Apply weighted sum of operations
weights = torch.softmax(alpha, dim=0)
return sum(w * op(x) for w, op in zip(weights, self.ops))
class SepConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding,
groups=in_channels),
nn.Conv2d(in_channels, out_channels, 1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
class DilConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
dilation=dilation),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
class Identity(nn.Module):
def forward(self, x):
return x
class Zero(nn.Module):
def forward(self, x):
return torch.zeros_like(x)
# Complete DARTS Network
class DARTSNetwork(nn.Module):
def __init__(self, num_classes: int, num_cells: int = 8, channels: int = 36):
super().__init__()
self.num_cells = num_cells
self.channels = channels
# Stem
self.stem = nn.Sequential(
nn.Conv2d(3, channels, 3, 1, 1),
nn.BatchNorm2d(channels)
)
        # Cells (simplified: each cell concatenates its node outputs, so the channel
        # bookkeeping between cells and the classifier is illustrative rather than exact)
        self.cells = nn.ModuleList()
for i in range(num_cells):
if i in [num_cells // 3, 2 * num_cells // 3]:
# Reduction cell
self.cells.append(DARTSCell(4, channels))
channels *= 2
else:
# Normal cell
self.cells.append(DARTSCell(4, channels))
# Classifier
self.classifier = nn.Linear(channels, num_classes)
self.global_pool = nn.AdaptiveAvgPool2d(1)
def forward(self, x):
x = self.stem(x)
for cell in self.cells:
x = cell([x, x]) # Use same input for both inputs
x = self.global_pool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
        return x

Evolutionary Search Example
class EvolutionaryNAS:
def __init__(self, population_size: int = 50, generations: int = 100):
self.population_size = population_size
self.generations = generations
self.population = []
self.fitness_history = []
def run_search(self, search_space, evaluator):
"""Run evolutionary search"""
# Initialize population
self.population = [
search_space.sample_architecture()
for _ in range(self.population_size)
]
for generation in range(self.generations):
# Evaluate population
fitness_scores = []
for individual in self.population:
fitness = evaluator.evaluate(individual)
fitness_scores.append(fitness)
self.fitness_history.append(max(fitness_scores))
# Selection and reproduction
new_population = []
for _ in range(self.population_size):
# Tournament selection
parent1 = self._tournament_selection(fitness_scores)
parent2 = self._tournament_selection(fitness_scores)
# Crossover
child = self._crossover(parent1, parent2)
# Mutation
child = self._mutate(child, search_space)
new_population.append(child)
self.population = new_population
# Return best architecture
final_fitness = [evaluator.evaluate(ind) for ind in self.population]
best_idx = np.argmax(final_fitness)
return self.population[best_idx], final_fitness[best_idx]
def _tournament_selection(self, fitness_scores, tournament_size: int = 3):
"""Tournament selection"""
tournament_indices = random.sample(range(len(fitness_scores)), tournament_size)
tournament_fitness = [fitness_scores[i] for i in tournament_indices]
winner_idx = tournament_indices[np.argmax(tournament_fitness)]
return self.population[winner_idx]
def _crossover(self, parent1, parent2):
"""Single-point crossover"""
        child = copy.deepcopy(parent1)  # deep copy so crossover/mutation never modifies the parents
if 'layers' in parent1 and 'layers' in parent2:
# Crossover layers
min_length = min(len(parent1['layers']), len(parent2['layers']))
if min_length > 1:
crossover_point = random.randint(1, min_length - 1)
child['layers'] = (parent1['layers'][:crossover_point] +
parent2['layers'][crossover_point:])
return child
def _mutate(self, individual, search_space, mutation_rate: float = 0.1):
"""Mutate individual"""
if random.random() < mutation_rate:
if 'layers' in individual and individual['layers']:
# Randomly mutate a layer
layer_idx = random.randint(0, len(individual['layers']) - 1)
layer = individual['layers'][layer_idx]
# Mutate operation
if random.random() < 0.5:
layer['operation'] = random.choice(search_space.operations)
# Mutate filters
if random.random() < 0.5:
layer['filters'] = random.choice([32, 64, 128, 256, 512])
return individual
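
A sketch of driving this evolutionary search end to end, reusing MacroSearchSpace and EarlyStoppingEvaluator from earlier sections (my_train_loader again stands in for your own DataLoader):

search_space = MacroSearchSpace(max_layers=12)
evaluator = EarlyStoppingEvaluator(dataset=my_train_loader, max_epochs=5)  # my_train_loader is assumed
nas = EvolutionaryNAS(population_size=20, generations=10)
best_arch, best_fitness = nas.run_search(search_space, evaluator)
print(f"Best fitness: {best_fitness:.4f} with {len(best_arch['layers'])} layers")
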
Reinforcement Learning NAS Example
class RLNASController(nn.Module):
def __init__(self, num_layers: int = 6, lstm_size: int = 32,
num_branches: int = 6, out_filters: int = 48):
super().__init__()
self.num_layers = num_layers
self.lstm_size = lstm_size
self.num_branches = num_branches
self.out_filters = out_filters
# LSTM controller
self.lstm = nn.LSTMCell(lstm_size, lstm_size)
# Embedding layers for different architecture decisions
self.g_emb = nn.Embedding(1, lstm_size) # Go embedding
self.encoder = nn.Linear(lstm_size, lstm_size)
# Decision heads
self.conv_op = nn.Linear(lstm_size, len(CONV_OPS))
self.conv_ksize = nn.Linear(lstm_size, len(CONV_KERNEL_SIZES))
self.conv_filters = nn.Linear(lstm_size, len(CONV_FILTERS))
self.pooling_op = nn.Linear(lstm_size, len(POOLING_OPS))
self.pooling_ksize = nn.Linear(lstm_size, len(POOLING_KERNEL_SIZES))
# Initialize parameters
self.reset_parameters()
def reset_parameters(self):
init_range = 0.1
for param in self.parameters():
param.data.uniform_(-init_range, init_range)
def forward(self, batch_size: int = 1):
"""Sample architecture using the controller"""
# Initialize hidden state
h = torch.zeros(batch_size, self.lstm_size)
c = torch.zeros(batch_size, self.lstm_size)
# Start with go embedding
inputs = self.g_emb.weight.repeat(batch_size, 1)
# Store sampled architecture
arc_seq = []
entropies = []
log_probs = []
for layer_id in range(self.num_layers):
# LSTM step
h, c = self.lstm(inputs, (h, c))
# Sample convolution operation
conv_op_logits = self.conv_op(h)
conv_op_prob = F.softmax(conv_op_logits, dim=-1)
conv_op_log_prob = F.log_softmax(conv_op_logits, dim=-1)
conv_op_entropy = -(conv_op_log_prob * conv_op_prob).sum(1, keepdim=True)
conv_op_sample = torch.multinomial(conv_op_prob, 1)
conv_op_sample = conv_op_sample.view(-1)
arc_seq.append(conv_op_sample)
entropies.append(conv_op_entropy)
log_probs.append(conv_op_log_prob.gather(1, conv_op_sample.unsqueeze(1)))
# Sample kernel size
conv_ksize_logits = self.conv_ksize(h)
conv_ksize_prob = F.softmax(conv_ksize_logits, dim=-1)
conv_ksize_log_prob = F.log_softmax(conv_ksize_logits, dim=-1)
conv_ksize_entropy = -(conv_ksize_log_prob * conv_ksize_prob).sum(1, keepdim=True)
conv_ksize_sample = torch.multinomial(conv_ksize_prob, 1)
conv_ksize_sample = conv_ksize_sample.view(-1)
arc_seq.append(conv_ksize_sample)
entropies.append(conv_ksize_entropy)
log_probs.append(conv_ksize_log_prob.gather(1, conv_ksize_sample.unsqueeze(1)))
# Continue for other decisions...
inputs = h # Use current hidden state as input for next step
return arc_seq, torch.cat(log_probs), torch.cat(entropies)
# Constants for architecture choices
CONV_OPS = ['conv', 'depthwise_conv', 'separable_conv']
CONV_KERNEL_SIZES = [3, 5, 7]
CONV_FILTERS = [24, 36, 48, 64]
POOLING_OPS = ['max_pool', 'avg_pool', 'no_pool']
POOLING_KERNEL_SIZES = [2, 3]
class RLNASTrainer:
def __init__(self, controller, child_model_builder, evaluator):
self.controller = controller
self.child_model_builder = child_model_builder
self.evaluator = evaluator
# Controller optimizer
self.controller_optimizer = torch.optim.Adam(
controller.parameters(), lr=3.5e-4
)
# Baseline for variance reduction
self.baseline = None
self.baseline_decay = 0.99
def train_controller(self, num_epochs: int = 2000):
"""Train the controller using REINFORCE"""
for epoch in range(num_epochs):
# Sample architectures
arc_seq, log_probs, entropies = self.controller()
# Build and evaluate child model
child_model = self.child_model_builder.build(arc_seq)
reward = self.evaluator.evaluate(child_model)
# Update baseline
if self.baseline is None:
self.baseline = reward
else:
self.baseline = self.baseline_decay * self.baseline + \
(1 - self.baseline_decay) * reward
# Compute advantage
advantage = reward - self.baseline
# Controller loss (REINFORCE)
controller_loss = -log_probs * advantage
controller_loss = controller_loss.sum()
# Add entropy regularization
entropy_penalty = -entropies.sum() * 1e-4
total_loss = controller_loss + entropy_penalty
# Update controller
self.controller_optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.controller.parameters(), 5.0)
self.controller_optimizer.step()
if epoch % 100 == 0:
print(f'Epoch {epoch}, Reward: {reward:.4f}, '
                      f'Baseline: {self.baseline:.4f}, Loss: {total_loss.item():.4f}')

Best Practices
1. Search Space Design Guidelines
class SearchSpaceDesignPrinciples:
"""
Guidelines for designing effective search spaces
"""
def __init__(self):
self.principles = {
'expressiveness': 'Include diverse operations and connections',
'efficiency': 'Balance search space size with computational cost',
'human_knowledge': 'Incorporate domain-specific insights',
'scalability': 'Design for different input sizes and tasks'
}
def design_macro_space(self, task_type: str):
"""Design macro search space based on task"""
if task_type == 'image_classification':
return {
'operations': ['conv3x3', 'conv5x5', 'depthwise_conv', 'pointwise_conv',
'max_pool', 'avg_pool', 'global_pool', 'identity'],
'max_layers': 20,
'channels': [16, 32, 64, 128, 256, 512],
'skip_connections': True,
'batch_norm': True,
'activation': ['relu', 'relu6', 'swish']
}
elif task_type == 'object_detection':
return {
'operations': ['conv3x3', 'conv5x5', 'depthwise_conv', 'atrous_conv',
'max_pool', 'avg_pool', 'identity'],
'max_layers': 30,
'channels': [32, 64, 128, 256, 512, 1024],
'skip_connections': True,
'fpn_layers': True,
'anchor_scales': [32, 64, 128, 256, 512]
}
def validate_search_space(self, search_space):
"""Validate search space design"""
issues = []
# Check for minimal viable operations
if len(search_space.get('operations', [])) < 3:
issues.append("Too few operations - may limit expressiveness")
# Check for identity operation
if 'identity' not in search_space.get('operations', []):
issues.append("Missing identity operation - may hurt skip connections")
# Check channel progression
channels = search_space.get('channels', [])
if channels and not all(channels[i] <= channels[i+1] for i in range(len(channels)-1)):
issues.append("Non-monotonic channel progression")
        return issues

2. Training Strategies
class NASTrainingStrategies:
"""Advanced training strategies for NAS"""
def __init__(self):
self.strategies = {}
def progressive_shrinking(self, supernet, dataset, stages: int = 4):
"""Progressive shrinking strategy"""
current_channels = supernet.max_channels
for stage in range(stages):
# Reduce search space
target_channels = current_channels // (2 ** stage)
supernet.set_channel_constraint(target_channels)
# Train for this stage
self._train_stage(supernet, dataset, epochs=50)
print(f"Stage {stage + 1}: Max channels = {target_channels}")
def sandwich_sampling(self, supernet, dataset):
"""Sandwich sampling for training efficiency"""
optimizer = torch.optim.SGD(supernet.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
for batch in dataset:
inputs, targets = batch
# Sample architectures: largest, smallest, and random
architectures = [
supernet.largest_architecture(),
supernet.smallest_architecture(),
supernet.random_architecture(),
supernet.random_architecture()
]
                total_loss = 0
                # Accumulate gradients from every sampled subnet, then take a single step
                optimizer.zero_grad()
                for arch in architectures:
                    supernet.set_active_subnet(arch)
                    outputs = supernet(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    total_loss += loss.item()
                optimizer.step()
if epoch % 10 == 0:
print(f"Epoch {epoch}, Loss: {total_loss/len(architectures):.4f}")
def knowledge_distillation(self, student_arch, teacher_model, dataset):
"""Knowledge distillation for architecture evaluation"""
student_model = self._build_model(student_arch)
optimizer = torch.optim.SGD(student_model.parameters(), lr=0.01)
        kd_loss = nn.KLDivLoss(reduction='batchmean')  # 'batchmean' gives the correctly scaled KL divergence
ce_loss = nn.CrossEntropyLoss()
alpha = 0.7 # Distillation weight
temperature = 4.0
for epoch in range(50):
for batch in dataset:
inputs, targets = batch
# Teacher predictions
with torch.no_grad():
teacher_outputs = teacher_model(inputs)
teacher_probs = F.softmax(teacher_outputs / temperature, dim=1)
# Student predictions
student_outputs = student_model(inputs)
student_log_probs = F.log_softmax(student_outputs / temperature, dim=1)
# Combined loss
distill_loss = kd_loss(student_log_probs, teacher_probs)
hard_loss = ce_loss(student_outputs, targets)
total_loss = alpha * distill_loss + (1 - alpha) * hard_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
        return self._evaluate_model(student_model)

3. Evaluation and Benchmarking
class NASBenchmarking:
"""Benchmarking and evaluation utilities"""
def __init__(self):
self.metrics = {}
def comprehensive_evaluation(self, architecture, datasets):
"""Comprehensive evaluation across multiple metrics"""
results = {}
# Build model
model = self._build_model(architecture)
# Accuracy metrics
for dataset_name, dataset in datasets.items():
accuracy = self._evaluate_accuracy(model, dataset)
results[f'{dataset_name}_accuracy'] = accuracy
# Efficiency metrics
results['params'] = self._count_parameters(model)
results['flops'] = self._count_flops(model)
results['latency'] = self._measure_latency(model)
results['memory'] = self._measure_memory(model)
# Robustness metrics
results['adversarial_robustness'] = self._evaluate_adversarial_robustness(model)
results['noise_robustness'] = self._evaluate_noise_robustness(model)
return results
def _count_parameters(self, model):
"""Count model parameters"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def _count_flops(self, model, input_size=(1, 3, 224, 224)):
"""Count FLOPs using a simple profiler"""
def flop_count_hook(module, input, output):
if isinstance(module, nn.Conv2d):
# Conv2d FLOPs
batch_size, in_channels, input_height, input_width = input[0].shape
output_height, output_width = output.shape[2], output.shape[3]
kernel_height, kernel_width = module.kernel_size
flops = batch_size * in_channels * kernel_height * kernel_width * \
output_height * output_width * module.out_channels
if hasattr(module, 'flops'):
module.flops += flops
else:
module.flops = flops
# Register hooks
hooks = []
for module in model.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
hooks.append(module.register_forward_hook(flop_count_hook))
# Forward pass
model.eval()
with torch.no_grad():
dummy_input = torch.randn(input_size)
model(dummy_input)
# Collect FLOPs
total_flops = 0
for module in model.modules():
if hasattr(module, 'flops'):
total_flops += module.flops
# Remove hooks
for hook in hooks:
hook.remove()
return total_flops
def _measure_latency(self, model, input_size=(1, 3, 224, 224), runs=100):
"""Measure inference latency"""
model.eval()
dummy_input = torch.randn(input_size)
# Warmup
for _ in range(10):
with torch.no_grad():
model(dummy_input)
# Measure
torch.cuda.synchronize() if torch.cuda.is_available() else None
start_time = time.time()
for _ in range(runs):
with torch.no_grad():
model(dummy_input)
torch.cuda.synchronize() if torch.cuda.is_available() else None
end_time = time.time()
return (end_time - start_time) / runs
def compare_search_methods(self, methods, search_space, evaluator, runs=5):
"""Compare different search methods"""
results = {}
for method_name, method in methods.items():
method_results = []
for run in range(runs):
# Set random seed for reproducibility
torch.manual_seed(run)
random.seed(run)
np.random.seed(run)
# Run search
best_arch, best_performance = method.search(search_space, evaluator)
method_results.append({
'architecture': best_arch,
'performance': best_performance,
'run': run
})
results[method_name] = method_results
return self._analyze_comparison_results(results)
def _analyze_comparison_results(self, results):
"""Analyze comparison results"""
analysis = {}
for method_name, method_results in results.items():
performances = [r['performance'] for r in method_results]
analysis[method_name] = {
'mean_performance': np.mean(performances),
'std_performance': np.std(performances),
'best_performance': np.max(performances),
'worst_performance': np.min(performances),
'median_performance': np.median(performances)
}
# Rank methods
ranked_methods = sorted(analysis.items(),
key=lambda x: x[1]['mean_performance'],
reverse=True)
return {
'detailed_results': analysis,
'ranking': ranked_methods,
'summary': self._generate_summary(ranked_methods)
}
def _generate_summary(self, ranked_methods):
"""Generate summary of comparison"""
summary = []
summary.append("=== NAS Method Comparison Results ===")
for i, (method_name, stats) in enumerate(ranked_methods):
summary.append(f"{i+1}. {method_name}:")
summary.append(f" Mean: {stats['mean_performance']:.4f}")
summary.append(f" Std: {stats['std_performance']:.4f}")
summary.append(f" Best: {stats['best_performance']:.4f}")
summary.append("")
return "\n".join(summary)Tools and Frameworks
Popular NAS Libraries
class NASFrameworkGuide:
"""Guide to popular NAS frameworks"""
def __init__(self):
self.frameworks = {
'nni': {
'description': 'Microsoft\'s Neural Network Intelligence toolkit',
'strengths': ['Easy to use', 'Multiple search algorithms', 'Good documentation'],
'installation': 'pip install nni',
'example_usage': '''
from nni.nas.pytorch import DartsTrainer
from nni.nas.pytorch.search_space_zoo import ENASMacroSearchSpace
# Define search space
search_space = ENASMacroSearchSpace()
# Create trainer
trainer = DartsTrainer(
model=search_space,
loss=nn.CrossEntropyLoss(),
optimizer=torch.optim.SGD(search_space.parameters(), lr=0.1)
)
# Train
trainer.train()
'''
},
'automl': {
'description': 'Google\'s AutoML toolkit',
'strengths': ['State-of-the-art methods', 'Research-oriented'],
'installation': 'Custom installation from GitHub',
'example_usage': '''
# Example for AdaNet
import adanet
# Define search space and estimator
estimator = adanet.Estimator(
head=head,
subnetwork_generator=generator,
max_iteration_steps=1000
)
# Train
estimator.train(input_fn=train_input_fn)
'''
},
'optuna': {
'description': 'Hyperparameter optimization framework',
'strengths': ['Flexible', 'Multiple optimization algorithms', 'Good for hyperparameter tuning'],
'installation': 'pip install optuna',
'example_usage': '''
import optuna
def objective(trial):
# Define architecture parameters
n_layers = trial.suggest_int('n_layers', 2, 8)
n_filters = trial.suggest_int('n_filters', 16, 128)
# Build and train model
model = build_model(n_layers, n_filters)
accuracy = train_and_evaluate(model)
return accuracy
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)
'''
},
'ray_tune': {
'description': 'Distributed hyperparameter tuning',
'strengths': ['Scalable', 'Multiple search algorithms', 'Good for distributed training'],
'installation': 'pip install ray[tune]',
'example_usage': '''
from ray import tune
from ray.tune.schedulers import ASHAScheduler
def trainable(config):
model = build_model(config)
accuracy = train_model(model)
tune.report(accuracy=accuracy)
scheduler = ASHAScheduler(metric="accuracy", mode="max")
result = tune.run(
trainable,
resources_per_trial={"cpu": 2, "gpu": 1},
config={
"n_layers": tune.choice([2, 4, 6, 8]),
"n_filters": tune.choice([16, 32, 64, 128])
},
scheduler=scheduler
)
'''
}
}
def get_framework_recommendation(self, use_case: str):
"""Get framework recommendation based on use case"""
recommendations = {
'research': ['automl', 'custom_implementation'],
'production': ['nni', 'ray_tune'],
'hyperparameter_tuning': ['optuna', 'ray_tune'],
'distributed_training': ['ray_tune'],
'beginner_friendly': ['nni', 'optuna']
}
return recommendations.get(use_case, ['nni'])
# Example: Custom NAS implementation using PyTorch
class CustomNASExample:
def __init__(self):
self.search_space = None
self.search_strategy = None
self.evaluator = None
def setup_cifar10_nas(self):
"""Setup NAS for CIFAR-10"""
# Define search space
self.search_space = CellSearchSpace(num_nodes=4)
# Define search strategy
self.search_strategy = EvolutionarySearch(population_size=20)
# Define evaluator
self.evaluator = EarlyStoppingEvaluator(
dataset=self._get_cifar10_dataset(),
max_epochs=10,
patience=3
)
def run_nas_experiment(self):
"""Run complete NAS experiment"""
framework = NASFramework(
search_space=self.search_space,
search_strategy=self.search_strategy,
performance_estimator=self.evaluator
)
# Run search
best_result = framework.search(num_iterations=100)
# Analyze results
print(f"Best architecture: {best_result['architecture']}")
print(f"Best performance: {best_result['performance']:.4f}")
return best_result
    def _get_cifar10_dataset(self):
        """Get a CIFAR-10 training DataLoader (a minimal sketch; assumes torchvision is installed)"""
        import torchvision
        import torchvision.transforms as T  # local imports: torchvision is an extra dependency
        transform = T.Compose([T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465),
                                                         (0.2470, 0.2435, 0.2616))])
        train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        return torch.utils.data.DataLoader(train_set, batch_size=96, shuffle=True)

This comprehensive guide covers the essential aspects of Neural Architecture Search, from theoretical foundations to practical implementations. The code examples provide a solid foundation for understanding and implementing NAS algorithms, while the best practices and framework recommendations help guide practical applications.
The key takeaways from this guide are:
- NAS Framework: Understanding the three core components (search space, search strategy, performance estimation) is crucial
- Search Space Design: Careful design of search spaces balances expressiveness with computational efficiency
- Search Strategies: Different strategies have different trade-offs between exploration and exploitation
- Performance Estimation: Efficient evaluation methods are essential for practical NAS
- Implementation: Modern frameworks provide good starting points, but custom implementations offer more control
- Best Practices: Following established guidelines improves NAS effectiveness and reproducibility
For beginners, I recommend starting with existing frameworks like NNI or Optuna, then gradually moving to custom implementations as understanding deepens. For research applications, implementing methods from scratch using the patterns shown in this guide provides the most flexibility and insight into the algorithms.