PyTorch Training and Inference Optimization Guide

Categories: code, tutorial, intermediate

Author: Krishnatheja Vanka
Published: June 1, 2025

This guide includes practical code examples you can use directly in your projects, along with best practices and common pitfalls to avoid. Each section builds on the previous ones, so you can adopt these optimizations incrementally based on your needs and performance requirements.

Table of Contents

  1. General Optimization Principles
  2. Training Optimizations
  3. Inference Optimizations
  4. Memory Management
  5. Hardware-Specific Optimizations
  6. Profiling and Debugging

General Optimization Principles

1. Use the Right Data Types

import torch

# Use half precision when possible (reduces memory and increases speed)
model = model.half()  # Convert to float16
# Or use mixed precision training
from torch.cuda.amp import autocast, GradScaler

# Use appropriate tensor types
x = torch.tensor(data, dtype=torch.float32)  # Explicit dtype

2. Optimize Data Loading

from torch.utils.data import DataLoader
import torch.multiprocessing as mp

# Optimize DataLoader
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,  # Use multiple workers
    pin_memory=True,  # Faster GPU transfer
    persistent_workers=True,  # Keep workers alive
    prefetch_factor=2  # Prefetch batches
)

# Use non_blocking transfers
for batch in train_loader:
    data = batch[0].to(device, non_blocking=True)
    target = batch[1].to(device, non_blocking=True)

3. Tensor Operations Best Practices

# Avoid unnecessary CPU-GPU transfers
x = torch.randn(1000, 1000, device='cuda')  # Create directly on GPU

# Use in-place operations when possible
y = torch.randn(1000, 1000, device='cuda')
x.add_(y)  # Instead of x = x + y
x.mul_(2)  # Instead of x = x * 2

# Batch operations instead of loops
# Bad
for i in range(batch_size):
    result[i] = model(x[i])

# Good
result = model(x)  # Process entire batch

Training Optimizations

1. Mixed Precision Training

from torch.cuda.amp import autocast, GradScaler

model = MyModel().cuda()
criterion = torch.nn.CrossEntropyLoss()  # Example loss for classification
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        
        # Forward pass with autocast
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        
        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
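Note that recent PyTorch 2.x releases consolidate these entry points under the device-agnostic torch.amp namespace. The equivalent loop below is a sketch to verify against the version you are running.

# Same loop using the device-agnostic torch.amp API (recent PyTorch 2.x)
scaler = torch.amp.GradScaler("cuda")

for inputs, targets in train_loader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()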

2. Gradient Accumulation

accumulation_steps = 4
optimizer.zero_grad()

for i, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.cuda(), targets.cuda()
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps
    
    scaler.scale(loss).backward()
    
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

3. Efficient Learning Rate Scheduling

from torch.optim.lr_scheduler import OneCycleLR

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = OneCycleLR(
    optimizer,
    max_lr=0.01,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader)
)

# Use scheduler after each batch for OneCycleLR
for batch in train_loader:
    # ... training step ...
    scheduler.step()

4. Model Compilation (PyTorch 2.0+)

# Compile model for faster training
model = torch.compile(model)

# Different modes for different use cases
model = torch.compile(model, mode="reduce-overhead")  # Uses CUDA graphs to cut per-call overhead (good for small batches)
model = torch.compile(model, mode="max-autotune")     # Longest compile time, best runtime performance

5. Checkpoint and Resume Training

def save_checkpoint(model, optimizer, epoch, loss, filename):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, filename)

def load_checkpoint(model, optimizer, filename):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']
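Resuming a run then just reloads the state and continues from the saved epoch. A minimal sketch follows; the filename "checkpoint.pt" is a placeholder.

# Resume training from a saved checkpoint ("checkpoint.pt" is a placeholder path)
start_epoch, last_loss = load_checkpoint(model, optimizer, "checkpoint.pt")

for epoch in range(start_epoch + 1, num_epochs):
    # ... run one epoch of training, producing `loss` ...
    save_checkpoint(model, optimizer, epoch, loss.item(), "checkpoint.pt")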

Inference Optimizations

1. Model Optimization for Inference

# Set model to evaluation mode
model.eval()

# Disable gradient computation
with torch.no_grad():
    outputs = model(inputs)

# Use torch.inference_mode() for even better performance
with torch.inference_mode():
    outputs = model(inputs)

2. TorchScript Optimization

# Put the model in eval mode, then trace it with an example input
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

# Or script the model
scripted_model = torch.jit.script(model)

# Optimize the scripted model for inference (it must be in eval mode)
optimized_model = torch.jit.optimize_for_inference(scripted_model.eval())

# Save and load
torch.jit.save(optimized_model, "optimized_model.pt")
loaded_model = torch.jit.load("optimized_model.pt")
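The exported file can be loaded and run without the original Python class definition; a quick sanity check might look like this:

# Run the loaded TorchScript module (no Python model class needed)
loaded_model.eval()
with torch.inference_mode():
    output = loaded_model(torch.randn(1, 3, 224, 224))
print(output.shape)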

3. Quantization

import torch.quantization

# Post-training quantization
model.eval()
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Quantization-aware training
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

# Train the model...

# Convert to quantized model
quantized_model = torch.quantization.convert(model, inplace=False)
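To sanity-check the size benefit, one rough approach is to save both state dicts and compare the files; this measures on-disk size only, not runtime memory.

import os

# Compare on-disk size of the original vs. quantized model (rough check)
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
print(f"Original:  {os.path.getsize('model_fp32.pt') / 1e6:.2f} MB")
print(f"Quantized: {os.path.getsize('model_int8.pt') / 1e6:.2f} MB")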

4. Batch Processing for Inference

def batch_inference(model, data_loader, device):
    model.eval()
    results = []
    
    with torch.inference_mode():
        for batch in data_loader:
            inputs = batch.to(device, non_blocking=True)
            outputs = model(inputs)
            results.append(outputs.cpu())
    
    return torch.cat(results, dim=0)

Memory Management

1. Memory Efficient Training

# Clear unnecessary variables
del intermediate_results
torch.cuda.empty_cache()  # Free GPU memory

# Use gradient checkpointing for large models
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder layer sizes for illustration
        self.expensive_layer = nn.Sequential(
            nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)
        )
    
    def forward(self, x):
        # Recompute this block's activations during backward to save memory
        x = checkpoint(self.expensive_layer, x, use_reentrant=False)
        return x

2. Monitor Memory Usage

def print_memory_usage():
    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"GPU memory cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Monitor during training
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(train_loader):
        # ... training code ...
        if batch_idx % 100 == 0:
            print_memory_usage()

3. Memory-Efficient Data Loading

class MemoryEfficientDataset(torch.utils.data.Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
    
    def __getitem__(self, idx):
        # Load data on-demand instead of keeping it all in memory
        data = self.load_data(self.data_paths[idx])
        return data
    
    def __len__(self):
        return len(self.data_paths)
    
    def load_data(self, path):
        # Example loader; swap in your own format (images, numpy arrays, etc.)
        return torch.load(path)

Hardware-Specific Optimizations

1. GPU Optimizations

# Set optimal GPU settings
torch.backends.cudnn.benchmark = True  # Autotune conv algorithms (best when input sizes are fixed)
torch.backends.cudnn.deterministic = False  # Faster; set to True only if you need reproducibility

# Use multiple GPUs
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# Or use DistributedDataParallel for better performance
from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model, device_ids=[local_rank])
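DistributedDataParallel needs a process group, one process per GPU, and a DistributedSampler so each rank sees a different shard of the data. A minimal sketch, assuming the script is launched with torchrun (which sets LOCAL_RANK):

# Minimal DDP setup, assuming launch via: torchrun --nproc_per_node=<num_gpus> train.py
import os
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = model.cuda(local_rank)
model = DDP(model, device_ids=[local_rank])

sampler = DistributedSampler(dataset)
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler,
                          num_workers=4, pin_memory=True)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # Ensure a different shuffle each epoch
    # ... training loop ...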

2. CPU Optimizations

# Limit intra-op parallelism (tune to your CPU core count)
torch.set_num_threads(4)

# Ensure oneDNN (formerly MKL-DNN) CPU optimizations are enabled (on by default)
torch.backends.mkldnn.enabled = True

3. Apple Silicon (MPS) Support

# Use Metal Performance Shaders on Apple Silicon
if torch.backends.mps.is_available():
    device = torch.device("mps")
    model = model.to(device)
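A small helper can pick the best available backend automatically (CUDA first, then MPS, then CPU); the ordering here is just one reasonable default.

def get_device():
    # Prefer CUDA, then Apple Silicon MPS, then fall back to CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_device()
model = model.to(device)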

Profiling and Debugging

1. PyTorch Profiler

from torch.profiler import profile, record_function, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    for inputs, targets in train_loader:
        with record_function("forward"):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        with record_function("backward"):
            loss.backward()
        with record_function("optimizer"):
            optimizer.step()
            optimizer.zero_grad()

# Export a Chrome trace (open in chrome://tracing or Perfetto)
prof.export_chrome_trace("trace.json")

2. Memory Profiling

# Profile peak memory of a single forward pass (input shape is a placeholder)
inputs = torch.randn(32, 3, 224, 224, device='cuda')
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True
) as prof:
    model(inputs)

print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))

3. Speed Benchmarking

import time

def benchmark_model(model, input_tensor, num_runs=100):
    model.eval()
    
    # Warmup
    with torch.no_grad():
        for _ in range(10):
            _ = model(input_tensor)
    
    # Benchmark (synchronize so pending GPU work is included in the timing)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_tensor)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    
    avg_time = (end_time - start_time) / num_runs
    print(f"Average inference time: {avg_time*1000:.2f} ms")
    return avg_time

Best Practices Summary

  1. Always profile first - Identify bottlenecks before optimizing
  2. Use mixed precision - Significant speedup with minimal accuracy loss
  3. Optimize data loading - Use multiple workers and pin memory
  4. Batch operations - Avoid loops over individual samples
  5. Model compilation - Use torch.compile() for PyTorch 2.0+
  6. Memory management - Monitor and optimize memory usage
  7. Hardware utilization - Use all available compute resources
  8. Quantization for inference - Reduce model size and increase speed
  9. TorchScript for production - Better performance and deployment options
  10. Regular checkpointing - Save training progress and enable resumption

Common Pitfalls to Avoid

  • Moving tensors between CPU and GPU unnecessarily
  • Using small batch sizes that underutilize hardware
  • Not using torch.no_grad() during inference
  • Creating tensors in loops instead of batching
  • Not clearing variables and calling torch.cuda.empty_cache()
  • Using synchronous operations when asynchronous would work
  • Not leveraging built-in optimized functions