PyTorch Training and Inference Optimization Guide
The guide includes practical code examples you can directly use in your projects, along with best practices and common pitfalls to avoid. Each section builds upon the previous ones, so you can implement these optimizations incrementally based on your specific needs and performance requirements.
Table of Contents
- General Optimization Principles
- Training Optimizations
- Inference Optimizations
- Memory Management
- Hardware-Specific Optimizations
- Profiling and Debugging
- Best Practices Summary
- Common Pitfalls to Avoid
General Optimization Principles
1. Use the Right Data Types
import torch

# Use half precision when possible (reduces memory and increases speed)
model = model.half()  # Convert to float16

# Or use mixed precision training
from torch.cuda.amp import autocast, GradScaler

# Use appropriate tensor types
x = torch.tensor(data, dtype=torch.float32)  # Explicit dtype
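On recent GPUs, bfloat16 is often a safer alternative to plain float16 because it keeps float32's dynamic range. A minimal sketch, assuming a CUDA device with bfloat16 support (the tensor shapes are arbitrary):

import torch
import torch.nn as nn

if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    linear = nn.Linear(16, 4).cuda()
    x = torch.randn(8, 16, device="cuda")
    # autocast runs the matmul in bfloat16 while keeping the weights in float32
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        y = linear(x)
    print(y.dtype)  # torch.bfloat16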
2. Optimize Data Loading
from torch.utils.data import DataLoader
import torch.multiprocessing as mp

# Optimize DataLoader
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,            # Use multiple workers
    pin_memory=True,          # Faster GPU transfer
    persistent_workers=True,  # Keep workers alive
    prefetch_factor=2         # Prefetch batches
)

# Use non_blocking transfers (effective when pin_memory=True)
for batch in train_loader:
    data = batch[0].to(device, non_blocking=True)
    target = batch[1].to(device, non_blocking=True)
3. Tensor Operations Best Practices
# Avoid unnecessary CPU-GPU transfers
x = torch.randn(1000, 1000, device='cuda')  # Create directly on GPU

# Use in-place operations when possible
x.add_(y)  # Instead of x = x + y
x.mul_(2)  # Instead of x = x * 2

# Batch operations instead of loops
# Bad
for i in range(batch_size):
    result[i] = model(x[i])

# Good
result = model(x)  # Process entire batch
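As a self-contained illustration of the same batching principle outside a model call, here is a loop versus a single vectorized op (the sizes are arbitrary):

import torch

x = torch.randn(10000, 128)

# Loop: one small operation per row, dominated by Python and launch overhead
norms_loop = torch.empty(x.shape[0])
for i in range(x.shape[0]):
    norms_loop[i] = x[i].norm()

# Batched: a single call over the whole tensor
norms_batched = x.norm(dim=1)

print(torch.allclose(norms_loop, norms_batched))  # True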
Training Optimizations
1. Mixed Precision Training
from torch.cuda.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

for epoch in range(num_epochs):
    for batch in train_loader:
        optimizer.zero_grad()

        # Forward pass with autocast
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
2. Gradient Accumulation
accumulation_steps = 4

optimizer.zero_grad()
for i, batch in enumerate(train_loader):
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps

    scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
3. Efficient Learning Rate Scheduling
from torch.optim.lr_scheduler import OneCycleLR

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = OneCycleLR(
    optimizer,
    max_lr=0.01,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader)
)

# Step the scheduler after each batch for OneCycleLR
for batch in train_loader:
    # ... training step ...
    scheduler.step()
4. Model Compilation (PyTorch 2.0+)
# Compile model for faster training
model = torch.compile(model)

# Different modes for different use cases
model = torch.compile(model, mode="reduce-overhead")  # Cuts framework overhead; good for small models/batches
model = torch.compile(model, mode="max-autotune")     # Longest compile time, fastest generated code
5. Checkpoint and Resume Training
def save_checkpoint(model, optimizer, epoch, loss, filename):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, filename)

def load_checkpoint(model, optimizer, filename):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']
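One possible way to wire these helpers into a training loop so a run can resume where it left off (checkpoint_path is a placeholder; the per-epoch training code is elided):

import os

checkpoint_path = "checkpoint.pt"  # hypothetical location
start_epoch = 0

if os.path.exists(checkpoint_path):
    last_epoch, last_loss = load_checkpoint(model, optimizer, checkpoint_path)
    start_epoch = last_epoch + 1  # continue from the epoch after the saved one

for epoch in range(start_epoch, num_epochs):
    # ... train for one epoch, producing `loss` ...
    save_checkpoint(model, optimizer, epoch, loss.item(), checkpoint_path)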
Inference Optimizations
1. Model Optimization for Inference
# Set model to evaluation mode
model.eval()

# Disable gradient computation
with torch.no_grad():
    outputs = model(inputs)

# Use torch.inference_mode() for even better performance
with torch.inference_mode():
    outputs = model(inputs)
2. TorchScript Optimization
# Trace the model
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

# Or script the model
scripted_model = torch.jit.script(model)

# Optimize the scripted model
optimized_model = torch.jit.optimize_for_inference(scripted_model)

# Save and load
torch.jit.save(optimized_model, "optimized_model.pt")
loaded_model = torch.jit.load("optimized_model.pt")
3. Quantization
import torch.quantization as quant

# Post-training quantization
model.eval()
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Quantization-aware training
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

# Train the model...

# Convert to quantized model
quantized_model = torch.quantization.convert(model, inplace=False)
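A quick sanity check for the dynamically quantized model is to run it on CPU and compare serialized sizes. A rough sketch, assuming the model takes a (N, 128) float input (the actual savings depend on how much of the model is nn.Linear):

import os

example = torch.randn(1, 128)  # assumed input shape
with torch.inference_mode():
    _ = quantized_model(example)  # dynamically quantized modules run on CPU

torch.save(model.state_dict(), "fp32.pt")
torch.save(quantized_model.state_dict(), "int8.pt")
print(f"fp32: {os.path.getsize('fp32.pt') / 1e6:.1f} MB, "
      f"int8: {os.path.getsize('int8.pt') / 1e6:.1f} MB")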
4. Batch Processing for Inference
def batch_inference(model, data_loader, device):
    model.eval()
    results = []

    with torch.inference_mode():
        for batch in data_loader:
            inputs = batch.to(device, non_blocking=True)
            outputs = model(inputs)
            results.append(outputs.cpu())

    return torch.cat(results, dim=0)
Memory Management
1. Memory Efficient Training
# Clear unnecessary variables
del intermediate_results

# Free GPU memory
torch.cuda.empty_cache()

# Use gradient checkpointing for large models
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def forward(self, x):
        # Use checkpointing for memory-intensive layers
        x = checkpoint(self.expensive_layer, x)
        return x
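A more complete, self-contained version of the same idea (layer sizes are arbitrary; use_reentrant=False is the recommended setting on recent PyTorch releases):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.expensive_layer = nn.Sequential(
            nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)
        )
        self.head = nn.Linear(1024, 10)

    def forward(self, x):
        # Activations of expensive_layer are not stored; they are recomputed
        # during backward, trading extra compute for lower memory use
        x = checkpoint(self.expensive_layer, x, use_reentrant=False)
        return self.head(x)

model = CheckpointedModel()
loss = model(torch.randn(8, 1024)).sum()
loss.backward()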
2. Monitor Memory Usage
def print_memory_usage():
    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Monitor during training
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(train_loader):
        # ... training code ...
        if batch_idx % 100 == 0:
            print_memory_usage()
3. Memory-Efficient Data Loading
class MemoryEfficientDataset(torch.utils.data.Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths

    def __getitem__(self, idx):
        # Load data on-demand instead of keeping it in memory
        data = self.load_data(self.data_paths[idx])
        return data

    def __len__(self):
        return len(self.data_paths)
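load_data is left abstract above; one possible implementation, assuming each path points to a tensor saved with torch.save (adjust to your storage format):

def load_data(self, path):
    # Read a single sample from disk only when it is requested
    return torch.load(path, map_location="cpu")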
Hardware-Specific Optimizations
1. GPU Optimizations
# Set optimal GPU settings
torch.backends.cudnn.benchmark = True       # For fixed input sizes
torch.backends.cudnn.deterministic = False  # Set to True only if you need reproducibility (slower)

# Use multiple GPUs
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# Or use DistributedDataParallel for better performance
from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model, device_ids=[local_rank])
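The DDP line above assumes a process group and local_rank are already set up. A minimal per-process sketch, assuming the script is launched with torchrun (which sets LOCAL_RANK) and that MyModel and dataset are defined elsewhere:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

dist.init_process_group(backend="nccl")     # one process per GPU
local_rank = int(os.environ["LOCAL_RANK"])  # provided by torchrun
torch.cuda.set_device(local_rank)

model = MyModel().cuda(local_rank)
model = DDP(model, device_ids=[local_rank])

# Each process sees a disjoint shard of the dataset
sampler = DistributedSampler(dataset)
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler,
                          num_workers=4, pin_memory=True)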
2. CPU Optimizations
# Set number of threads
torch.set_num_threads(4)

# Use Intel MKL-DNN optimizations
torch.backends.mkldnn.enabled = True
3. Apple Silicon (MPS) Support
# Use Metal Performance Shaders on Apple Silicon
if torch.backends.mps.is_available():
    device = torch.device("mps")
    model = model.to(device)
Profiling and Debugging
1. PyTorch Profiler
from torch.profiler import profile, record_function, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    for batch in train_loader:
        with record_function("forward"):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        with record_function("backward"):
            loss.backward()
        with record_function("optimizer"):
            optimizer.step()

# Save a Chrome trace (viewable in chrome://tracing or Perfetto)
prof.export_chrome_trace("trace.json")
2. Memory Profiling
# Profile memory usage
with profile(profile_memory=True) as prof:
    model(inputs)

print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))
3. Speed Benchmarking
import time

def benchmark_model(model, input_tensor, num_runs=100):
    model.eval()

    # Warmup
    with torch.no_grad():
        for _ in range(10):
            _ = model(input_tensor)

    # Benchmark
    torch.cuda.synchronize()
    start_time = time.time()

    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_tensor)

    torch.cuda.synchronize()
    end_time = time.time()

    avg_time = (end_time - start_time) / num_runs
    print(f"Average inference time: {avg_time*1000:.2f} ms")
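On GPU, CUDA events avoid mixing host-side time.time() with asynchronous kernel launches; a sketch of the same benchmark using torch.cuda.Event:

def benchmark_model_cuda(model, input_tensor, num_runs=100):
    model.eval()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        for _ in range(10):  # warmup
            _ = model(input_tensor)

        torch.cuda.synchronize()
        start.record()
        for _ in range(num_runs):
            _ = model(input_tensor)
        end.record()
        torch.cuda.synchronize()

    avg_ms = start.elapsed_time(end) / num_runs  # elapsed_time returns milliseconds
    print(f"Average inference time: {avg_ms:.2f} ms")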
Best Practices Summary
- Always profile first - Identify bottlenecks before optimizing
- Use mixed precision - Significant speedup with minimal accuracy loss
- Optimize data loading - Use multiple workers and pin memory
- Batch operations - Avoid loops over individual samples
- Model compilation - Use torch.compile() on PyTorch 2.0+
- Memory management - Monitor and optimize memory usage
- Hardware utilization - Use all available compute resources
- Quantization for inference - Reduce model size and increase speed
- TorchScript for production - Better performance and deployment options
- Regular checkpointing - Save training progress and enable resumption
Common Pitfalls to Avoid
- Moving tensors between CPU and GPU unnecessarily
- Using small batch sizes that underutilize hardware
- Not using torch.no_grad() during inference
- Creating tensors in loops instead of batching
- Not clearing variables and calling torch.cuda.empty_cache()
- Using synchronous operations when asynchronous would work
- Not leveraging built-in optimized functions