Optuna for Deep Learning and Neural Architecture Search: A Comprehensive Guide

Introduction
Hyperparameter optimization is one of the most critical yet challenging aspects of deep learning. With the exponential growth in model complexity and the vast hyperparameter search spaces, manual tuning becomes impractical. Optuna, developed by Preferred Networks, emerges as a powerful automatic hyperparameter optimization framework that addresses these challenges with sophisticated algorithms and intuitive APIs.
This comprehensive guide explores how Optuna revolutionizes deep learning workflows, from basic hyperparameter tuning to advanced neural architecture search (NAS), providing practical implementations and real-world optimization strategies.
What is Optuna?
Optuna is an open-source hyperparameter optimization framework designed for machine learning. It offers several key advantages:
- Efficient Sampling: Uses Tree-structured Parzen Estimator (TPE) and other advanced algorithms
- Pruning: Automatically stops unpromising trials early
- Distributed Optimization: Supports parallel and distributed hyperparameter search
- Framework Agnostic: Works with PyTorch, TensorFlow, Keras, and other ML frameworks
- Visualization: Rich dashboard for monitoring optimization progress
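These features compose naturally: a study backed by a relational-database storage can be optimized by several worker processes at once, and the optional dashboard can watch it live. Below is a minimal sketch, assuming a local SQLite file, a toy stand-in objective, and a hypothetical study name shared_demo:

import optuna

def objective(trial):
    # Stand-in objective; in practice this trains and evaluates a model.
    x = trial.suggest_float('x', -10.0, 10.0)
    return (x - 2.0) ** 2

# Every worker process can run this same script; load_if_exists lets later
# workers join the study instead of failing on the duplicate name.
study = optuna.create_study(
    study_name='shared_demo',
    storage='sqlite:///optuna_demo.db',
    direction='minimize',
    load_if_exists=True,
)
study.optimize(objective, n_trials=50)

# Live monitoring with the optional dashboard package:
#   optuna-dashboard sqlite:///optuna_demo.db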
Core Concepts
Studies and Trials
In Optuna terminology:
- Study: An optimization session that tries to find optimal hyperparameters
- Trial: A single execution of the objective function with specific hyperparameter values
- Objective Function: The function to optimize (typically validation loss or accuracy)
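Stripped to its bare bones, a toy example illustrates all three terms: the function below is the objective, each call to it is a trial, and the study coordinates the search and records the results (the quadratic stands in for a real training run).

import optuna

def objective(trial):
    # One trial: sample a candidate value and return the score to optimize.
    x = trial.suggest_float('x', -10.0, 10.0)
    return (x - 2.0) ** 2

study = optuna.create_study(direction='minimize')  # the study coordinates the search
study.optimize(objective, n_trials=30)             # each objective() call is one trial

# Every trial records its sampled parameters, value, and state.
for t in study.trials[:3]:
    print(t.number, t.params, t.value, t.state)
print('Best:', study.best_trial.number, study.best_params, study.best_value)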
Sampling Algorithms
Optuna implements several sophisticated sampling strategies:
- TPE (Tree-structured Parzen Estimator): Default algorithm that models the probability distribution of hyperparameters
- Random Sampling: Baseline method for comparison
- Grid Search: Exhaustive search over specified parameter combinations
- CMA-ES: Covariance Matrix Adaptation Evolution Strategy for continuous optimization
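Each of these strategies is exposed as a sampler class passed to create_study. A brief sketch of the built-in samplers (CMA-ES additionally requires the cmaes package to be installed):

import optuna

# Default: TPE models the distribution of good vs. bad parameter values
tpe_study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=42))

# Random search as a baseline
random_study = optuna.create_study(sampler=optuna.samplers.RandomSampler(seed=42))

# Exhaustive grid search over an explicit search space
grid_study = optuna.create_study(
    sampler=optuna.samplers.GridSampler({'lr': [1e-3, 1e-2, 1e-1], 'n_units': [64, 128, 256]})
)

# CMA-ES for continuous parameters (requires the cmaes package)
cmaes_study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler(seed=42))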
Pruning Algorithms
Pruning eliminates unpromising trials early:
- Median Pruner: Prunes trials below the median performance
- Successive Halving: Allocates resources progressively to promising trials
- Hyperband: Combines successive halving with different resource allocations
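Pruning only works if the objective reports intermediate results. The sketch below wires trial.report() and trial.should_prune() into a training loop (train_one_epoch is a hypothetical stand-in for real training) and shows where each pruner plugs in:

import optuna

def train_one_epoch(lr, epoch):
    # Hypothetical stand-in: returns a fake, slowly improving validation score.
    return max(0.0, 1.0 - 1.0 / (epoch + 1) - abs(lr - 1e-2))

def objective(trial):
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
    accuracy = 0.0
    for epoch in range(20):
        accuracy = train_one_epoch(lr, epoch)
        trial.report(accuracy, step=epoch)   # expose the intermediate value
        if trial.should_prune():             # the pruner compares it against other trials
            raise optuna.TrialPruned()
    return accuracy

study = optuna.create_study(
    direction='maximize',
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
    # Alternatives:
    # pruner=optuna.pruners.SuccessiveHalvingPruner(),
    # pruner=optuna.pruners.HyperbandPruner(max_resource=20),
)
study.optimize(objective, n_trials=30)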
Installation and Setup
pip install optuna
pip install optuna-dashboard  # Optional: for visualization

For specific deep learning frameworks:

pip install torch torchvision    # PyTorch
pip install tensorflow           # TensorFlow
pip install optuna[integration]  # Framework integrations

Basic Hyperparameter Optimization
Simple PyTorch Example
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def create_model(trial):
    # Suggest architecture hyperparameters
    n_layers = trial.suggest_int('n_layers', 1, 3)
    n_units = trial.suggest_int('n_units', 64, 512)
    dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)

    layers = []
    in_features = 784  # MNIST input size
    for i in range(n_layers):
        layers.append(nn.Linear(in_features, n_units))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(dropout_rate))
        in_features = n_units
    layers.append(nn.Linear(in_features, 10))  # Output layer
    return nn.Sequential(*layers)

def objective(trial):
    # Model hyperparameters
    model = create_model(trial)

    # Optimizer hyperparameters
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
    optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'RMSprop'])
    if optimizer_name == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optimizer_name == 'SGD':
        momentum = trial.suggest_float('momentum', 0.0, 0.99)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    else:  # RMSprop
        optimizer = optim.RMSprop(model.parameters(), lr=lr)

    # Data loading
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_dataset = datasets.MNIST('data', train=False, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=128)

    criterion = nn.CrossEntropyLoss()

    def evaluate():
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data.view(-1, 784))
                _, predicted = torch.max(output, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        return correct / total

    # Training
    for epoch in range(10):
        model.train()
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data.view(-1, 784))
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        # Report the intermediate validation accuracy so the pruner can stop
        # unpromising trials early. Report a value that matches the study
        # direction ('maximize' here) rather than the training loss.
        accuracy = evaluate()
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return accuracy

# Create study and optimize (MedianPruner is the default pruner)
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

print(f"Best trial: {study.best_trial.value}")
print(f"Best params: {study.best_params}")

Advanced Hyperparameter Optimization
Multi-Objective Optimization
def multi_objective_function(trial):
    # create_model(trial) already suggests the architecture hyperparameters
    # (n_layers, n_units, dropout_rate); suggesting them again here with
    # different ranges would raise an error.
    model = create_model(trial)
    accuracy = train_and_evaluate(model)  # placeholder training/evaluation routine

    # Calculate model complexity (number of parameters)
    model_size = sum(p.numel() for p in model.parameters())

    # Return multiple objectives: maximize accuracy, minimize model size
    return accuracy, model_size

# Multi-objective study
study = optuna.create_study(directions=['maximize', 'minimize'])
study.optimize(multi_objective_function, n_trials=100)

# Get Pareto front
pareto_front = study.best_trials
for trial in pareto_front:
    print(f"Trial {trial.number}: Accuracy={trial.values[0]:.3f}, "
          f"Model Size={trial.values[1]}")

Conditional Hyperparameters
def conditional_objective(trial):
    # Main architecture choice
    model_type = trial.suggest_categorical('model_type', ['CNN', 'ResNet', 'DenseNet'])

    if model_type == 'CNN':
        # CNN-specific parameters
        n_conv_layers = trial.suggest_int('n_conv_layers', 2, 4)
        kernel_size = trial.suggest_categorical('kernel_size', [3, 5, 7])
        n_filters = trial.suggest_int('n_filters', 32, 128)
        model = create_cnn(n_conv_layers, kernel_size, n_filters)
    elif model_type == 'ResNet':
        # ResNet-specific parameters
        depth = trial.suggest_categorical('depth', [18, 34, 50])
        width_multiplier = trial.suggest_float('width_multiplier', 0.5, 2.0)
        model = create_resnet(depth, width_multiplier)
    else:  # DenseNet
        # DenseNet-specific parameters. Categorical choices should be simple
        # types (None, bool, int, float, str), so encode the block config as a
        # string key and map it back to a tuple.
        growth_rate = trial.suggest_int('growth_rate', 12, 48)
        block_configs = {
            'small': (6, 12, 24, 16),
            'large': (6, 12, 32, 32),
        }
        block_config_name = trial.suggest_categorical('block_config', list(block_configs))
        model = create_densenet(growth_rate, block_configs[block_config_name])

    # Common hyperparameters
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
    return train_and_evaluate(model, lr, batch_size)

Neural Architecture Search (NAS)
Basic NAS Implementation
import optuna
import torch
import torch.nn as nn

class SearchableBlock(nn.Module):
    def __init__(self, in_channels, out_channels, trial, block_id):
        super().__init__()
        self.block_id = block_id

        # Searchable operations
        op_name = trial.suggest_categorical(f'op_{block_id}', [
            'conv3x3', 'conv5x5', 'conv7x7', 'depthwise_conv', 'skip_connect'
        ])
        if op_name == 'conv3x3':
            self.op = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        elif op_name == 'conv5x5':
            self.op = nn.Conv2d(in_channels, out_channels, 5, padding=2)
        elif op_name == 'conv7x7':
            self.op = nn.Conv2d(in_channels, out_channels, 7, padding=3)
        elif op_name == 'depthwise_conv':
            self.op = nn.Sequential(
                nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
                nn.Conv2d(in_channels, out_channels, 1)
            )
        else:  # skip_connect
            self.op = (nn.Identity() if in_channels == out_channels
                       else nn.Conv2d(in_channels, out_channels, 1))

        # Activation and normalization
        activation = trial.suggest_categorical(f'activation_{block_id}',
                                               ['relu', 'gelu', 'swish'])
        self.use_batch_norm = trial.suggest_categorical(f'batch_norm_{block_id}', [True, False])
        if self.use_batch_norm:
            self.bn = nn.BatchNorm2d(out_channels)

        if activation == 'relu':
            self.act = nn.ReLU()
        elif activation == 'gelu':
            self.act = nn.GELU()
        else:  # swish
            self.act = nn.SiLU()

    def forward(self, x):
        out = self.op(x)
        if self.use_batch_norm:
            out = self.bn(out)
        out = self.act(out)
        return out

class SearchableNet(nn.Module):
    def __init__(self, trial, num_classes=10):
        super().__init__()

        # Search for the overall architecture
        num_stages = trial.suggest_int('num_stages', 3, 5)
        base_channels = trial.suggest_int('base_channels', 32, 128)

        # Build the searchable architecture
        self.stages = nn.ModuleList()
        in_channels = 3
        for stage in range(num_stages):
            # Number of blocks in this stage
            num_blocks = trial.suggest_int(f'num_blocks_stage_{stage}', 1, 4)
            # Channel progression
            out_channels = base_channels * (2 ** stage)

            stage_blocks = nn.ModuleList()
            for block in range(num_blocks):
                block_id = f'stage_{stage}_block_{block}'
                stage_blocks.append(SearchableBlock(in_channels, out_channels, trial, block_id))
                in_channels = out_channels
            self.stages.append(stage_blocks)

            # Downsampling between stages
            if stage < num_stages - 1:
                self.stages.append(nn.MaxPool2d(2))

        # Global pooling and classifier
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(in_channels, num_classes)

        # Dropout
        dropout_rate = trial.suggest_float('dropout_rate', 0.0, 0.5)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        for stage in self.stages:
            if isinstance(stage, nn.ModuleList):
                for block in stage:
                    x = block(x)
            else:
                x = stage(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.classifier(x)
        return x

def nas_objective(trial):
    # Create searchable model
    model = SearchableNet(trial)

    # Training hyperparameters
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Data augmentation search
    use_cutmix = trial.suggest_categorical('use_cutmix', [True, False])
    use_mixup = trial.suggest_categorical('use_mixup', [True, False])

    # Train and evaluate (placeholder; it should call trial.report() per epoch
    # so the MedianPruner below has intermediate values to act on)
    accuracy = train_model_with_augmentation(model, optimizer, use_cutmix, use_mixup)
    return accuracy

# Run NAS
study = optuna.create_study(direction='maximize',
                            pruner=optuna.pruners.MedianPruner())
study.optimize(nas_objective, n_trials=200)

Advanced NAS with Weight Sharing
import torch.nn as nn
import torch.nn.functional as F

class SuperNet(nn.Module):
    """Supernet that contains all possible operations with shared weights."""
    def __init__(self, num_classes=10):
        super().__init__()
        # Define all possible operations
        self.operations = nn.ModuleDict({
            'conv3x3': nn.Conv2d(64, 64, 3, padding=1),
            'conv5x5': nn.Conv2d(64, 64, 5, padding=2),
            'conv7x7': nn.Conv2d(64, 64, 7, padding=3),
            'depthwise_conv': nn.Sequential(
                nn.Conv2d(64, 64, 3, padding=1, groups=64),
                nn.Conv2d(64, 64, 1)
            ),
            'skip_connect': nn.Identity()
        })
        self.stem = nn.Conv2d(3, 64, 3, padding=1)
        self.classifier = nn.Linear(64, num_classes)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x, architecture):
        """Forward pass with a specific architecture (list of operation names)."""
        x = self.stem(x)
        for i, op_name in enumerate(architecture):
            x = self.operations[op_name](x)
            if i % 2 == 0:  # Add downsampling periodically
                x = F.max_pool2d(x, 2)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

def progressive_nas_objective(trial):
    """NAS with weight sharing: each trial samples a sub-architecture of the supernet."""
    # Sample architecture
    num_blocks = trial.suggest_int('num_blocks', 4, 8)
    architecture = []
    for i in range(num_blocks):
        op = trial.suggest_categorical(f'op_{i}', [
            'conv3x3', 'conv5x5', 'conv7x7', 'depthwise_conv', 'skip_connect'
        ])
        architecture.append(op)

    # Create the supernet once and share its weights across trials
    if not hasattr(progressive_nas_objective, 'supernet'):
        progressive_nas_objective.supernet = SuperNet()
    model = progressive_nas_objective.supernet

    # Training with early stopping (placeholder routine)
    accuracy = train_with_early_stopping(model, architecture, trial)
    return accuracy

Distributed Optimization
Multi-GPU Training with Optuna
import optuna
from optuna.integration import PyTorchLightningPruningCallback
import pytorch_lightning as pl
import torch
import torch.nn as nn

class LightningModel(pl.LightningModule):
    def __init__(self, trial):
        super().__init__()
        # Training hyperparameters
        self.lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
        self.batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
        # Model definition
        self.model = self.build_model(trial)
        self.criterion = nn.CrossEntropyLoss()

    def build_model(self, trial):
        # Build the model based on trial suggestions (omitted for brevity)
        pass

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.criterion(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss, sync_dist=True)
        self.log('val_acc', acc, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

def distributed_objective(trial):
    model = LightningModel(trial)

    # Pruning callback reports 'val_acc' to Optuna after each validation epoch
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor='val_acc')

    # Multi-GPU trainer (newer Lightning versions use accelerator='gpu', devices=4)
    trainer = pl.Trainer(
        gpus=4,
        strategy='ddp',
        max_epochs=50,
        callbacks=[pruning_callback],
        enable_checkpointing=False
    )
    trainer.fit(model, train_dataloader, val_dataloader)  # dataloaders defined elsewhere
    return trainer.callback_metrics['val_acc'].item()

# Distributed study: each worker process runs this same script against the
# shared storage; load_if_exists lets workers join an existing study.
study = optuna.create_study(
    direction='maximize',
    storage='sqlite:///distributed_study.db',  # Shared storage (use MySQL/PostgreSQL for many workers)
    study_name='distributed_nas',
    load_if_exists=True
)
study.optimize(distributed_objective, n_trials=500)

Advanced Techniques
Hyperband Integration
# Optuna ships a Hyperband pruner, so there is no need to implement one.
# The objective reports an intermediate value after each epoch, and the
# pruner decides how much of the resource budget each trial receives.
def hyperband_objective(trial):
    model = create_model(trial)
    accuracy = 0.0
    for epoch in range(81):  # max_resource epochs
        accuracy = train_one_epoch_and_evaluate(model, trial)  # placeholder routine
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return accuracy

study = optuna.create_study(
    direction='maximize',
    pruner=optuna.pruners.HyperbandPruner(
        min_resource=1,
        max_resource=81,
        reduction_factor=3
    )
)

Population-Based Training
import random
import numpy as np

def population_based_optimization():
    """A hand-rolled population-based search loop (selection plus mutation)."""
    population_size = 20
    generations = 10

    # Initialize population with random hyperparameter configurations
    population = []
    for i in range(population_size):
        params = {
            'lr': float(np.random.uniform(1e-5, 1e-1)),
            'batch_size': int(np.random.choice([16, 32, 64, 128])),
            'weight_decay': float(np.random.uniform(1e-6, 1e-2))
        }
        population.append(params)

    for generation in range(generations):
        # Evaluate population
        fitness_scores = []
        for params in population:
            model = build_model(params)               # placeholder model builder
            score = train_and_evaluate(model, params)  # placeholder trainer
            fitness_scores.append(score)

        # Select top performers (exploit)
        top_indices = np.argsort(fitness_scores)[-population_size // 2:]
        new_population = [population[idx] for idx in top_indices]

        # Mutate survivors to refill the population (explore)
        while len(new_population) < population_size:
            parent = random.choice(new_population)
            child = mutate_hyperparameters(parent)     # placeholder mutation
            new_population.append(child)

        population = new_population

    return population

Visualization and Analysis
Study Analysis
# Basic study analysis
print(f"Number of finished trials: {len(study.trials)}")
print(f"Best trial: {study.best_trial.number}")
print(f"Best value: {study.best_value}")
print(f"Best parameters: {study.best_params}")

# Parameter importance
importance = optuna.importance.get_param_importances(study)
print("Parameter importance:")
for param, imp in importance.items():
    print(f"  {param}: {imp:.4f}")

# Visualization
import optuna.visualization as vis

# Optimization history
fig = vis.plot_optimization_history(study)
fig.show()

# Parameter importance plot
fig = vis.plot_param_importances(study)
fig.show()

# Parameter relationships
fig = vis.plot_parallel_coordinate(study)
fig.show()

# Hyperparameter slice plot
fig = vis.plot_slice(study)
fig.show()

Custom Metrics Tracking
class CustomCallback:
    def __init__(self):
        self.metrics = {}

    def __call__(self, study, trial):
        # Track custom metrics for every finished trial
        self.metrics[trial.number] = {
            'params': trial.params,
            'value': trial.value,
            'state': trial.state,
            'duration': trial.duration
        }
        # Custom analysis every 10 trials
        if len(study.trials) % 10 == 0:
            self.analyze_progress(study)

    def analyze_progress(self, study):
        # Convergence analysis: how much has the best value improved over the
        # last 10 completed trials? (assumes a maximization study)
        values = [t.value for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
        if len(values) > 10:
            improvement = max(values) - max(values[:-10])
            print(f"Improvement of best value over last 10 trials: {improvement:.4f}")

# Use custom callback
callback = CustomCallback()
study.optimize(objective, n_trials=100, callbacks=[callback])

Best Practices and Tips
Study Configuration
# A study configuration that works well in practice
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(
        n_startup_trials=20,  # Random trials before TPE kicks in
        n_ei_candidates=24,   # Candidates evaluated for expected improvement
        multivariate=True,    # Model parameter interactions
        seed=42               # Reproducibility
    ),
    pruner=optuna.pruners.MedianPruner(
        n_startup_trials=10,  # Minimum trials before pruning starts
        n_warmup_steps=5,     # Steps before a trial can be pruned
        interval_steps=1      # Frequency of pruning checks
    )
)

Memory Management
def memory_efficient_objective(trial):
    # Release cached GPU memory left over from previous trials
    torch.cuda.empty_cache()

    # Gradient checkpointing trades compute for memory
    # (gradient_checkpointing_enable() is provided by Hugging Face transformers
    # models; for plain PyTorch modules use torch.utils.checkpoint instead)
    model = create_model(trial)
    model.gradient_checkpointing_enable()

    # Mixed precision training
    scaler = torch.cuda.amp.GradScaler()
    with torch.cuda.amp.autocast():
        # Training loop goes here
        pass

    # Evaluate, then clean up so the next trial starts fresh
    accuracy = evaluate(model)  # placeholder evaluation routine
    del model
    torch.cuda.empty_cache()
    return accuracy

Reproducibility
def set_seed(seed=42):
    import random
    import numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def reproducible_objective(trial):
    # Use a fixed seed (or one derived from trial.number) so that differences
    # between trials come from the hyperparameters, not from random noise
    set_seed(42)
    # Rest of the objective function
    pass

Real-World Applications
Computer Vision NAS
def vision_nas_objective(trial):
    # Data augmentation search
    augmentation_policy = {
        'rotation': trial.suggest_float('rotation', 0, 30),
        'brightness': trial.suggest_float('brightness', 0.8, 1.2),
        'contrast': trial.suggest_float('contrast', 0.8, 1.2),
        'saturation': trial.suggest_float('saturation', 0.8, 1.2),
        'hue': trial.suggest_float('hue', -0.1, 0.1)
    }

    # Architecture search
    backbone = trial.suggest_categorical('backbone', ['resnet', 'efficientnet', 'mobilenet'])
    if backbone == 'resnet':
        depth = trial.suggest_categorical('depth', [18, 34, 50, 101])
        model = create_resnet(depth)
    elif backbone == 'efficientnet':
        version = trial.suggest_categorical('version', ['b0', 'b1', 'b2', 'b3'])
        model = create_efficientnet(version)
    else:  # mobilenet
        width_mult = trial.suggest_float('width_mult', 0.25, 2.0)
        model = create_mobilenet(width_mult)

    # Training strategy
    training_strategy = {
        'optimizer': trial.suggest_categorical('optimizer', ['adam', 'sgd', 'adamw']),
        'lr_schedule': trial.suggest_categorical('lr_schedule', ['cosine', 'step', 'exponential']),
        'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    }
    return train_vision_model(model, augmentation_policy, training_strategy)

NLP Architecture Search
def nlp_nas_objective(trial):
    # Transformer architecture search
    config = {
        'num_layers': trial.suggest_int('num_layers', 4, 12),
        'num_heads': trial.suggest_categorical('num_heads', [4, 8, 12, 16]),
        'hidden_size': trial.suggest_categorical('hidden_size', [256, 512, 768, 1024]),
        'ffn_size': trial.suggest_categorical('ffn_size', [1024, 2048, 3072, 4096]),
        'dropout': trial.suggest_float('dropout', 0.0, 0.3),
        'attention_dropout': trial.suggest_float('attention_dropout', 0.0, 0.3)
    }
    # Skip combinations where the hidden size is not divisible by the head count
    if config['hidden_size'] % config['num_heads'] != 0:
        raise optuna.TrialPruned()

    # Positional encoding
    pos_encoding = trial.suggest_categorical('pos_encoding', ['learned', 'sinusoidal', 'rotary'])
    # Activation function
    activation = trial.suggest_categorical('activation', ['gelu', 'relu', 'swish'])

    model = create_transformer(config, pos_encoding, activation)

    # Training hyperparameters
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    warmup_steps = trial.suggest_int('warmup_steps', 1000, 10000)
    return train_nlp_model(model, lr, warmup_steps)

Conclusion
Optuna provides a powerful, flexible framework for hyperparameter optimization and neural architecture search in deep learning. Its sophisticated algorithms, pruning capabilities, and extensive integration ecosystem make it an essential tool for modern ML practitioners.
Key takeaways:
- Start Simple: Begin with basic hyperparameter optimization before moving to complex NAS
- Use Pruning: Implement pruning to save computational resources
- Leverage Distributed Computing: Scale optimization across multiple GPUs/nodes
- Monitor Progress: Use visualization tools to understand optimization dynamics
- Consider Multi-Objective: Balance multiple criteria like accuracy and efficiency
- Reproducibility: Set seeds and use consistent evaluation protocols
The future of automated ML lies in intelligent optimization frameworks like Optuna, which democratize access to state-of-the-art hyperparameter tuning and architecture search techniques. By mastering these tools, practitioners can focus on higher-level design decisions while letting algorithms handle the tedious parameter optimization process.
Whether you’re working on computer vision, NLP, or other domains, Optuna’s flexibility and power make it an invaluable addition to your deep learning toolkit. Start with the basic examples provided here, then gradually incorporate more advanced techniques as your optimization needs grow in complexity.