import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
# Basic environment setup: initialize and tear down the process group
def setup_distributed(rank, world_size, backend='nccl'):
    """Initialize distributed training environment"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize process group
    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size
    )

    # Set device for current process
    torch.cuda.set_device(rank)

def cleanup_distributed():
    """Clean up distributed training"""
    dist.destroy_process_group()
Distributed Training with PyTorch - Complete Code Guide
Introduction
Distributed training allows you to scale PyTorch models across multiple GPUs and machines, dramatically reducing training time for large models and datasets. This guide covers practical implementation patterns from basic data parallelism to advanced distributed strategies.
Core Concepts
Key Terminology
- World Size: Total number of processes participating in training
- Rank: Unique identifier for each process (0 to world_size-1)
- Local Rank: Process identifier within a single node/machine
- Process Group: Collection of processes that can communicate with each other
- Backend: Communication backend (NCCL for GPU, Gloo for CPU)
Communication Patterns
- All-Reduce: Combine values from all processes and distribute the result
- Broadcast: Send data from one process to all others
- Gather: Collect data from all processes to one process
- Scatter: Distribute data from one process to all others
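These terms map directly onto the torch.distributed API. Below is a minimal sketch of how a process can inspect its place in the job and use one collective; the helper name is illustrative, and it assumes the process group has already been initialized (for example with the setup code at the top of this guide).

import os
import torch
import torch.distributed as dist

def describe_process():
    rank = dist.get_rank()                              # global process id: 0 .. world_size-1
    world_size = dist.get_world_size()                  # total number of processes
    local_rank = int(os.environ.get("LOCAL_RANK", 0))   # id within this node, if the launcher sets it
    backend = dist.get_backend()                        # e.g. 'nccl' or 'gloo'
    print(f"rank {rank}/{world_size}, local rank {local_rank}, backend {backend}")

    # All-gather: every rank contributes one tensor and receives all of them
    # (with the NCCL backend these tensors would need to live on the GPU)
    mine = torch.tensor([float(rank)])
    gathered = [torch.zeros(1) for _ in range(world_size)]
    dist.all_gather(gathered, mine)
    print(f"rank {rank} sees contributions from all ranks: {gathered}")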
Setup and Initialization
Basic Environment Setup
The setup_distributed and cleanup_distributed helpers defined at the very top of this guide handle single-node initialization: they set MASTER_ADDR and MASTER_PORT, create the process group with the chosen backend, and pin each process to its own GPU with torch.cuda.set_device.
Multi-Node Setup
def setup_multinode(rank, world_size, master_addr, master_port):
    """Setup for multi-node distributed training"""
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)

    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=rank,
        world_size=world_size
    )
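On a multi-node job, each process needs a globally unique rank. A minimal sketch of how the global rank is typically derived from the node index and the GPU index within the node; the helper and variable names here are illustrative, not part of the guide's code.

def compute_global_rank(node_rank, local_rank, gpus_per_node):
    """Global rank = node index * GPUs per node + local GPU index (illustrative helper)."""
    return node_rank * gpus_per_node + local_rank

# Example: the third GPU (local_rank=2) on the second node (node_rank=1) of a 2-node, 4-GPU-per-node job
# global_rank = 1 * 4 + 2 = 6, world_size = 2 * 4 = 8
global_rank = compute_global_rank(node_rank=1, local_rank=2, gpus_per_node=4)
torch.cuda.set_device(2)  # each process drives its *local* GPU index, not its global rank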
Data Parallel Training
Simple DataParallel (Single Node)
import torch.nn as nn
class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def train_dataparallel(train_loader, num_epochs=10):
    """Basic DataParallel training"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create model and wrap with DataParallel
    model = SimpleModel(784, 256, 10)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    # Setup optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
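A minimal, hedged usage sketch with synthetic data; the TensorDataset, tensor shapes, and epoch count below are illustrative assumptions, not part of the original guide.

from torch.utils.data import DataLoader, TensorDataset

# Synthetic MNIST-shaped data purely for illustration
inputs = torch.randn(1024, 784)
labels = torch.randint(0, 10, (1024,))
loader = DataLoader(TensorDataset(inputs, labels), batch_size=64, shuffle=True)

train_dataparallel(loader, num_epochs=2)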
Distributed Data Parallel (DDP)
Basic DDP Implementation
# train_dataset, batch_size and num_epochs are assumed to be defined at module level
def train_ddp(rank, world_size):
    """Distributed Data Parallel training function"""
    # Setup distributed environment
    setup_distributed(rank, world_size)

    # Create model and move to GPU
    model = SimpleModel(784, 256, 10).to(rank)

    # Wrap model with DDP
    ddp_model = DDP(model, device_ids=[rank])

    # Setup distributed sampler
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True
    )

    # Setup optimizer and loss
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(num_epochs):
        # Important for shuffling
        train_sampler.set_epoch(epoch)

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)

            optimizer.zero_grad()
            output = ddp_model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if rank == 0 and batch_idx % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

    cleanup_distributed()

def main():
    """Main function to spawn distributed processes"""
    world_size = torch.cuda.device_count()
    mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()
Complete Training Script with Validation
import time
from torch.utils.tensorboard import SummaryWriter
class DistributedTrainer:
    def __init__(self, model, rank, world_size, train_loader, val_loader=None):
        self.model = model
        self.rank = rank
        self.world_size = world_size
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Setup DDP
        self.ddp_model = DDP(model, device_ids=[rank])

        # Setup optimizer and scheduler
        self.optimizer = torch.optim.AdamW(
            self.ddp_model.parameters(),
            lr=0.001,
            weight_decay=0.01
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100
        )
        self.criterion = nn.CrossEntropyLoss()

        # Logging (only on rank 0)
        if rank == 0:
            self.writer = SummaryWriter('runs/distributed_training')

    def train_epoch(self, epoch):
        """Train for one epoch"""
        self.ddp_model.train()
        total_loss = 0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.rank), target.to(self.rank)

            self.optimizer.zero_grad()
            output = self.ddp_model(data)
            loss = self.criterion(output, target)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.ddp_model.parameters(), max_norm=1.0)

            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if self.rank == 0 and batch_idx % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

        avg_loss = total_loss / num_batches
        return avg_loss

    def validate(self):
        """Validate the model"""
        if self.val_loader is None:
            return None

        self.ddp_model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.rank), target.to(self.rank)
                output = self.ddp_model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        # Gather metrics from all processes
        total_loss_tensor = torch.tensor(total_loss).to(self.rank)
        correct_tensor = torch.tensor(correct).to(self.rank)
        total_tensor = torch.tensor(total).to(self.rank)

        dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)

        avg_loss = total_loss_tensor.item() / self.world_size
        accuracy = correct_tensor.item() / total_tensor.item()

        return avg_loss, accuracy

    def save_checkpoint(self, epoch, loss):
        """Save model checkpoint (only on rank 0)"""
        if self.rank == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.ddp_model.module.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'loss': loss,
            }
            torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')

    def train(self, num_epochs):
        """Complete training loop"""
        for epoch in range(num_epochs):
            start_time = time.time()

            # Set epoch for distributed sampler
            if hasattr(self.train_loader.sampler, 'set_epoch'):
                self.train_loader.sampler.set_epoch(epoch)

            # Train
            train_loss = self.train_epoch(epoch)

            # Validate
            val_metrics = self.validate()

            # Step scheduler
            self.scheduler.step()

            # Logging and checkpointing (rank 0 only)
            if self.rank == 0:
                epoch_time = time.time() - start_time
                print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, '
                      f'Time: {epoch_time:.2f}s')

                # TensorBoard logging
                self.writer.add_scalar('Loss/Train', train_loss, epoch)

                if val_metrics:
                    val_loss, val_acc = val_metrics
                    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
                    self.writer.add_scalar('Loss/Val', val_loss, epoch)
                    self.writer.add_scalar('Accuracy/Val', val_acc, epoch)

                # Save checkpoint
                if epoch % 10 == 0:
                    self.save_checkpoint(epoch, train_loss)
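A minimal sketch of how this trainer might be driven from a spawned worker process. The worker function name, dataset objects, batch size, and epoch count below are illustrative assumptions; the building blocks (setup_distributed, SimpleModel, DistributedSampler) come from earlier in this guide.

def main_worker(rank, world_size, train_dataset, val_dataset):
    # Hypothetical glue code: initialize, build loaders, train, tear down
    setup_distributed(rank, world_size)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size,
                                       rank=rank, shuffle=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64,
                                               sampler=train_sampler, pin_memory=True)
    # Every rank evaluates the full validation set here; the all_reduce in
    # validate() still produces consistent aggregate metrics
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64)

    model = SimpleModel(784, 256, 10).to(rank)
    trainer = DistributedTrainer(model, rank, world_size, train_loader, val_loader)
    trainer.train(num_epochs=50)

    cleanup_distributed()

# mp.spawn(main_worker, args=(world_size, train_dataset, val_dataset),
#          nprocs=world_size, join=True)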
Advanced Patterns
Mixed Precision Training
from torch.cuda.amp import GradScaler, autocast
class MixedPrecisionTrainer(DistributedTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scaler = GradScaler()

    def train_epoch(self, epoch):
        """Train with mixed precision"""
        self.ddp_model.train()
        total_loss = 0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.rank), target.to(self.rank)

            self.optimizer.zero_grad()

            # Forward pass with autocast
            with autocast():
                output = self.ddp_model(data)
                loss = self.criterion(output, target)

            # Backward pass with scaled gradients
            self.scaler.scale(loss).backward()

            # Gradient clipping (unscale gradients first)
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.ddp_model.parameters(), max_norm=1.0)

            # Optimizer step with scaler
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches
Model Sharding with FSDP
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

def create_fsdp_model(model, rank):
    """Create FSDP wrapped model"""
    # The wrap policy is configured via functools.partial so FSDP can call it
    # with its own arguments while using our min_num_params threshold
    wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100_000)

    fsdp_model = FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        mixed_precision=MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16
        ),
        device_id=rank,
        sync_module_states=True,
        sharding_strategy=ShardingStrategy.FULL_SHARD
    )

    return fsdp_model
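One practical detail worth calling out: the optimizer should be constructed from the wrapped model's parameters after FSDP wrapping, because FSDP replaces the original parameters with sharded flat parameters. A hedged usage sketch, reusing SimpleModel from earlier in this guide:

# rank is this process's GPU index, as set up by setup_distributed
model = SimpleModel(784, 256, 10)
fsdp_model = create_fsdp_model(model, rank)

# Build the optimizer only after wrapping, from the sharded parameters
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=0.001)

# From here the training loop looks like the DDP loop:
# forward on fsdp_model, compute loss, backward, optimizer.step()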
Pipeline Parallelism
from torch.distributed.pipeline.sync import Pipe

class PipelineModel(nn.Module):
    def __init__(self, chunks=8):
        super().__init__()
        # First pipeline stage on GPU 0
        stage1 = nn.Sequential(
            nn.Linear(784, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
        ).to('cuda:0')

        # Second pipeline stage on GPU 1
        stage2 = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 10),
        ).to('cuda:1')

        # Pipe splits each mini-batch into `chunks` micro-batches and overlaps
        # their execution across the two stages.
        # Note: torch.distributed.rpc.init_rpc(...) must be called before
        # constructing Pipe, even in a single-process setup.
        self.pipe = Pipe(nn.Sequential(stage1, stage2), chunks=chunks)

    def forward(self, x):
        # Input must live on the first stage's device; Pipe returns an RRef
        return self.pipe(x.to('cuda:0')).local_value()
Monitoring and Debugging
Performance Profiling
import torch.profiler
def profile_training(trainer, num_steps=100):
    """Profile distributed training performance"""
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profiler'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        for step, (data, target) in enumerate(trainer.train_loader):
            if step >= num_steps:
                break

            data, target = data.to(trainer.rank), target.to(trainer.rank)

            trainer.optimizer.zero_grad()
            output = trainer.ddp_model(data)
            loss = trainer.criterion(output, target)
            loss.backward()
            trainer.optimizer.step()

            prof.step()
Communication Debugging
def debug_communication():
    """Debug distributed communication"""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Test all-reduce
    tensor = torch.randn(10).cuda()
    print(f"Rank {rank}: Before all-reduce: {tensor.sum().item():.4f}")
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f"Rank {rank}: After all-reduce: {tensor.sum().item():.4f}")

    # Test broadcast
    if rank == 0:
        broadcast_tensor = torch.randn(5).cuda()
    else:
        broadcast_tensor = torch.zeros(5).cuda()

    dist.broadcast(broadcast_tensor, src=0)
    print(f"Rank {rank}: Broadcast result: {broadcast_tensor.sum().item():.4f}")
Best Practices
Data Loading Optimization
def create_efficient_dataloader(dataset, batch_size, world_size, rank):
    """Create optimized distributed data loader"""
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        drop_last=True  # Ensures consistent batch sizes
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,  # Adjust based on system
        pin_memory=True,
        persistent_workers=True,  # Reuse worker processes
        prefetch_factor=2
    )

    return loader
Gradient Synchronization
def train_with_gradient_accumulation(model, optimizer, criterion, data_loader,
                                     accumulation_steps=4):
    """Training with gradient accumulation"""
    model.train()

    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.cuda(), target.cuda()

        # Forward pass
        output = model(data)
        loss = criterion(output, target) / accumulation_steps

        # Backward pass
        loss.backward()

        # Update parameters every accumulation_steps
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
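With DDP, every backward() call triggers a gradient all-reduce, but when accumulating gradients that synchronization is only needed on the final micro-batch. A hedged sketch using DDP's no_sync() context manager to skip the intermediate all-reduces; it assumes ddp_model is a DDP-wrapped module, and the function name is illustrative.

import contextlib

def train_ddp_with_accumulation(ddp_model, optimizer, criterion, data_loader,
                                accumulation_steps=4):
    """Gradient accumulation that only synchronizes on the last micro-batch."""
    ddp_model.train()

    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.cuda(), target.cuda()
        is_sync_step = (batch_idx + 1) % accumulation_steps == 0

        # Skip the gradient all-reduce on non-final micro-batches
        context = contextlib.nullcontext() if is_sync_step else ddp_model.no_sync()
        with context:
            output = ddp_model(data)
            loss = criterion(output, target) / accumulation_steps
            loss.backward()

        if is_sync_step:
            optimizer.step()
            optimizer.zero_grad()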
Dynamic Loss Scaling
class DynamicLossScaler:
def __init__(self, init_scale=2.**16, scale_factor=2., scale_window=2000):
self.scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self.counter = 0
def update(self, overflow):
if overflow:
self.scale /= self.scale_factor
self.counter = 0
else:
self.counter += 1
if self.counter >= self.scale_window:
self.scale *= self.scale_factor
self.counter = 0
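A hedged sketch of how a manual scaler like this could be wired into an fp16 training step; the step function is illustrative, and detecting inf/NaN gradients is the caller's responsibility here (in practice, the GradScaler used in the mixed precision section above handles this bookkeeping automatically).

scaler = DynamicLossScaler()

def fp16_step(model, optimizer, criterion, data, target):
    optimizer.zero_grad()
    loss = criterion(model(data), target)

    # Scale the loss so small fp16 gradients do not underflow
    (loss * scaler.scale).backward()

    # Detect overflow (inf/NaN) in the scaled gradients
    overflow = any(
        p.grad is not None and not torch.isfinite(p.grad).all()
        for p in model.parameters()
    )
    if overflow:
        scaler.update(True)   # shrink the scale and skip this step
        return loss.item()

    # Unscale gradients with the scale used for backward, then step;
    # the scaler grows again after a run of overflow-free steps
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(scaler.scale)
    optimizer.step()
    scaler.update(False)
    return loss.item()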
Launch Script Example
launch_distributed.sh
#!/bin/bash
# launch_distributed.sh
# Single node, multiple GPUs
python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=1 \
--node_rank=0 \
--master_addr="localhost" \
--master_port=12345 \
train_distributed.py
# Multi-node setup
# Node 0:
python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=2 \
--node_rank=0 \
--master_addr="192.168.1.100" \
--master_port=12345 \
train_distributed.py
# Node 1:
python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=2 \
--node_rank=1 \
--master_addr="192.168.1.100" \
--master_port=12345 \
train_distributed.py
Error Handling and Recovery
def robust_train_loop(trainer, num_epochs, checkpoint_dir):
    """Training loop with error handling and recovery"""
    start_epoch = 0

    # Load checkpoint if one exists
    # (find_latest_checkpoint, load_checkpoint and save_checkpoint are
    #  checkpoint helpers assumed to be defined elsewhere)
    latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
    if latest_checkpoint:
        start_epoch = load_checkpoint(trainer, latest_checkpoint)

    for epoch in range(start_epoch, num_epochs):
        try:
            trainer.train_epoch(epoch)

            # Save checkpoint
            if epoch % 5 == 0:
                save_checkpoint(trainer, epoch, checkpoint_dir)

        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"OOM error at epoch {epoch}, reducing batch size")
                # Implement batch size reduction logic
                torch.cuda.empty_cache()
                continue
            else:
                raise

        except Exception as e:
            print(f"Error at epoch {epoch}: {e}")
            # Save emergency checkpoint
            save_checkpoint(trainer, epoch, checkpoint_dir, emergency=True)
            raise
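The loop above relies on checkpoint helpers that this guide does not define. A minimal sketch of what they could look like; the file naming scheme and checkpoint contents are assumptions that mirror DistributedTrainer.save_checkpoint.

import glob

def find_latest_checkpoint(checkpoint_dir):
    """Return the most recently modified checkpoint path, or None."""
    paths = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pth'))
    return max(paths, key=os.path.getmtime) if paths else None

def load_checkpoint(trainer, path):
    """Restore trainer state and return the epoch to resume from."""
    checkpoint = torch.load(path, map_location=f'cuda:{trainer.rank}')
    trainer.ddp_model.module.load_state_dict(checkpoint['model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['epoch'] + 1

def save_checkpoint(trainer, epoch, checkpoint_dir, emergency=False):
    """Write a checkpoint from rank 0 only."""
    if trainer.rank != 0:
        return
    name = f'checkpoint_epoch_{epoch}{"_emergency" if emergency else ""}.pth'
    torch.save({
        'epoch': epoch,
        'model_state_dict': trainer.ddp_model.module.state_dict(),
        'optimizer_state_dict': trainer.optimizer.state_dict(),
        'scheduler_state_dict': trainer.scheduler.state_dict(),
    }, os.path.join(checkpoint_dir, name))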
Conclusion
This guide provides a comprehensive foundation for implementing distributed training with PyTorch. Start with basic DDP for single-node multi-GPU setups, then progress to more advanced techniques like FSDP and pipeline parallelism as your models and datasets grow larger.
Key Takeaways
- Start Simple: Begin with DataParallel for single-node setups
- Scale Gradually: Move to DDP for multi-node distributed training
- Monitor Performance: Use profiling tools to identify bottlenecks
- Handle Errors: Implement robust error handling and checkpointing
- Optimize Data Loading: Use efficient data loaders and samplers
Additional Resources
For more advanced topics and latest updates, refer to:
- PyTorch Distributed Documentation
- FSDP Tutorial
- Pipeline Parallelism Guide