PyTorch Lightning Fabric Code Guide
A comprehensive code guide for PyTorch Lightning Fabric, covering everything from basic setup to advanced distributed training features.
Table of Contents
- Introduction
- Installation
- Basic Setup
- Core Components
- Training Loop
- Multi-GPU Training
- Mixed Precision
- Logging and Checkpointing
- Advanced Features
- Best Practices
Introduction
Lightning Fabric is a lightweight PyTorch wrapper that provides essential training utilities without the overhead of the full Lightning framework. It’s perfect when you want more control over your training loop while still benefiting from distributed training, mixed precision, and other optimizations.
Installation
pip install lightning
# or
pip install pytorch-lightning
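The examples in this guide import from lightning.fabric, which ships with the lightning package. A quick, optional sanity check after installing (assumes a recent 2.x release):

import lightning
from lightning.fabric import Fabric

print(lightning.__version__)   # e.g. 2.x
print(Fabric)                  # confirms the Fabric class is importable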
Basic Setup
Minimal Example
import torch
import torch.nn as nn
from lightning.fabric import Fabric
# Initialize Fabric
fabric = Fabric()

# Your model
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Setup model and optimizer with Fabric
model, optimizer = fabric.setup(model, optimizer)

# Training step (dataloader is assumed to be defined; see DataLoader Setup below)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch).mean()
    fabric.backward(loss)
    optimizer.step()
Core Components
1. Fabric Initialization
from lightning.fabric import Fabric
# Basic initialization
fabric = Fabric()

# With specific configuration
fabric = Fabric(
    accelerator="gpu",       # "cpu", "gpu", "tpu", "auto"
    strategy="ddp",          # "ddp", "fsdp", "deepspeed", etc.
    devices=2,               # Number of devices
    precision="16-mixed",    # "32-true", "16-mixed", "bf16-mixed"
    plugins=[],              # Custom plugins
)

# Launch the fabric
fabric.launch()
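fabric.launch() can also be given a function to run in every spawned process, which is useful in notebooks or when the whole entry point should run once per device. A minimal sketch (train_fn is an illustrative name; the launched function receives the Fabric object as its first argument):

from lightning.fabric import Fabric

def train_fn(fabric):
    # Runs once per process/device
    fabric.print(f"Hello from rank {fabric.global_rank} of {fabric.world_size}")

fabric = Fabric(accelerator="cpu", devices=2)
fabric.launch(train_fn)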
2. Model and Optimizer Setup
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
# Create model and optimizer
model = SimpleModel(784, 128, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Setup with Fabric
model, optimizer = fabric.setup(model, optimizer)
# The scheduler does not go through fabric.setup(); it keeps stepping the wrapped optimizer as usual.
3. DataLoader Setup
from torch.utils.data import DataLoader, TensorDataset
# Create your dataset
dataset = TensorDataset(torch.randn(1000, 784), torch.randint(0, 10, (1000,)))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Setup with Fabric (moves batches to the right device and injects a distributed sampler when needed)
dataloader = fabric.setup_dataloaders(dataloader)
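setup_dataloaders also accepts several loaders in one call and returns them in the same order (a sketch assuming train_loader and val_loader are plain torch DataLoaders):

train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)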
Training Loop
Basic Training Loop
def train_epoch(fabric, model, optimizer, dataloader, criterion):
    model.train()
    total_loss = 0

    for batch_idx, (data, target) in enumerate(dataloader):
        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass with Fabric
        fabric.backward(loss)

        # Optimizer step
        optimizer.step()

        total_loss += loss.item()

        # Log every 100 batches
        if batch_idx % 100 == 0:
            fabric.print(f'Batch {batch_idx}, Loss: {loss.item():.4f}')

    return total_loss / len(dataloader)

# Training loop
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    avg_loss = train_epoch(fabric, model, optimizer, dataloader, criterion)
    scheduler.step()
    fabric.print(f'Epoch {epoch}: Average Loss = {avg_loss:.4f}')
Training with Validation
def validate(fabric, model, val_dataloader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in val_dataloader:
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()

            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)

    accuracy = correct / total
    avg_loss = total_loss / len(val_dataloader)

    return avg_loss, accuracy

# Complete training with validation
train_loader = fabric.setup_dataloaders(train_dataloader)
val_loader = fabric.setup_dataloaders(val_dataloader)

for epoch in range(num_epochs):
    # Training
    train_loss = train_epoch(fabric, model, optimizer, train_loader, criterion)

    # Validation
    val_loss, val_acc = validate(fabric, model, val_loader, criterion)

    fabric.print(f'Epoch {epoch}:')
    fabric.print(f'  Train Loss: {train_loss:.4f}')
    fabric.print(f'  Val Loss: {val_loss:.4f}')
    fabric.print(f'  Val Acc: {val_acc:.4f}')

    scheduler.step()
Multi-GPU Training
Distributed Data Parallel (DDP)
# Initialize Fabric for multi-GPU
fabric = Fabric(
    accelerator="gpu",
    strategy="ddp",
    devices=4,  # Use 4 GPUs
)
fabric.launch()

# All-reduce for metrics across processes
def all_reduce_mean(fabric, tensor):
    """Average a tensor across all processes"""
    return fabric.all_reduce(tensor, reduce_op="mean")
# Training with distributed metrics
def train_distributed(fabric, model, optimizer, dataloader, criterion):
    model.train()
    total_loss = torch.tensor(0.0, device=fabric.device)
    num_batches = 0

    for data, target in dataloader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        fabric.backward(loss)
        optimizer.step()

        total_loss += loss.detach()
        num_batches += 1

    # Average loss across all processes
    avg_loss = total_loss / num_batches
    avg_loss = all_reduce_mean(fabric, avg_loss)

    return avg_loss.item()
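Gradient accumulation pairs naturally with DDP: fabric.no_backward_sync() skips the gradient all-reduce on accumulation steps and only synchronizes when the optimizer actually steps. A sketch (accumulate is an illustrative parameter, and model must have been set up with fabric.setup):

def train_with_accumulation(fabric, model, optimizer, dataloader, criterion, accumulate=4):
    model.train()
    for batch_idx, (data, target) in enumerate(dataloader):
        is_accumulating = (batch_idx + 1) % accumulate != 0

        # Skip the gradient all-reduce while accumulating
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(data)
            loss = criterion(output, target) / accumulate
            fabric.backward(loss)

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()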
Fully Sharded Data Parallel (FSDP)
# For very large models
fabric = Fabric(
    accelerator="gpu",
    strategy="fsdp",
    devices=8,
    precision="bf16-mixed",
)
fabric.launch()

# FSDP automatically shards model parameters (and optimizer state) across devices
model, optimizer = fabric.setup(model, optimizer)
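For more control over sharding you can pass a configured FSDPStrategy instead of the "fsdp" shorthand. A sketch assuming a recent Lightning 2.x release; nn.TransformerEncoderLayer is just an illustrative block class to shard on:

import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

strategy = FSDPStrategy(
    auto_wrap_policy={nn.TransformerEncoderLayer},  # shard at this block granularity
    state_dict_type="full",                         # gather a full state dict when checkpointing
)

fabric = Fabric(accelerator="gpu", devices=8, strategy=strategy, precision="bf16-mixed")
fabric.launch()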
Mixed Precision
Automatic Mixed Precision
# Enable mixed precision
fabric = Fabric(precision="16-mixed")  # or "bf16-mixed"
fabric.launch()

# Training remains the same - Fabric handles precision automatically
def train_with_amp(fabric, model, optimizer, dataloader, criterion):
    model.train()
    for data, target in dataloader:
        optimizer.zero_grad()

        # Forward pass (automatically runs under autocast)
        output = model(data)
        loss = criterion(output, target)

        # Backward pass (Fabric handles gradient scaling)
        fabric.backward(loss)
        optimizer.step()
Manual Precision Control
from lightning.fabric.utilities import rank_zero_only
@rank_zero_only
def log_model_precision(model):
    """Log model parameter precisions (only on rank 0)"""
    for name, param in model.named_parameters():
        print(f"{name}: {param.dtype}")

# Check model precision after setup
model, optimizer = fabric.setup(model, optimizer)
log_model_precision(model)
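For computations outside the wrapped model's forward pass (custom metrics, auxiliary losses), fabric.autocast() applies the configured precision context manually. A small sketch:

def compute_probs(fabric, model, data):
    # Run the forward pass under the same autocast context Fabric uses internally
    with fabric.autocast():
        logits = model(data)
    # Cast back to full precision for numerically sensitive post-processing
    return logits.float().softmax(dim=1)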
Logging and Checkpointing
Checkpointing
import os
def save_checkpoint(fabric, model, optimizer, epoch, loss, path):
    """Save model checkpoint (Fabric stores state dicts for the model and optimizer)"""
    checkpoint = {
        "model": model,
        "optimizer": optimizer,
        "epoch": epoch,
        "loss": loss,
    }
    fabric.save(path, checkpoint)

def load_checkpoint(fabric, path):
    """Load a checkpoint and return its contents"""
    checkpoint = fabric.load(path)
    return checkpoint
# Save checkpoint
= f"checkpoint_epoch_{epoch}.ckpt"
checkpoint_path
save_checkpoint(fabric, model, optimizer, epoch, train_loss, checkpoint_path)
# Load checkpoint
if os.path.exists("checkpoint_epoch_5.ckpt"):
= load_checkpoint(fabric, "checkpoint_epoch_5.ckpt")
checkpoint = checkpoint["model"]
model = checkpoint["optimizer"]
optimizer = checkpoint["epoch"] + 1 start_epoch
Logging with External Loggers
from lightning.fabric.loggers import TensorBoardLogger, CSVLogger
# Initialize logger
logger = TensorBoardLogger("logs", name="my_experiment")

# Setup Fabric with logger
fabric = Fabric(loggers=[logger])
fabric.launch()

# Log metrics
def log_metrics(fabric, metrics, step):
    for logger in fabric.loggers:
        logger.log_metrics(metrics, step)

# Usage in training loop
for epoch in range(num_epochs):
    train_loss = train_epoch(fabric, model, optimizer, train_loader, criterion)
    val_loss, val_acc = validate(fabric, model, val_loader, criterion)

    # Log metrics
    metrics = {
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_accuracy": val_acc,
        "learning_rate": optimizer.param_groups[0]['lr'],
    }
    log_metrics(fabric, metrics, epoch)
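Fabric also exposes log and log_dict, which forward a single metric or a dict of metrics to every attached logger, so the helper above is optional:

fabric.log("train_loss", train_loss, step=epoch)
fabric.log_dict(metrics, step=epoch)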
Advanced Features
Custom Precision Plugin
import torch
from lightning.fabric.plugins import MixedPrecision

# Custom precision configuration (Lightning 2.x Fabric API; pass a pre-configured GradScaler)
precision_plugin = MixedPrecision(
    precision="16-mixed",
    device="cuda",
    scaler=torch.cuda.amp.GradScaler(init_scale=2**16),
)

fabric = Fabric(plugins=[precision_plugin])
Gradient Clipping
def train_with_grad_clipping(fabric, model, optimizer, dataloader, criterion, max_norm=1.0):
    model.train()
    for data, target in dataloader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        fabric.backward(loss)

        # Gradient clipping (by norm)
        fabric.clip_gradients(model, optimizer, max_norm=max_norm)

        optimizer.step()
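clip_gradients can also clip by value instead of by norm, a one-line variant of the call above:

fabric.clip_gradients(model, optimizer, clip_val=0.5)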
Early Stopping
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

# Usage
early_stopping = EarlyStopping(patience=5)

for epoch in range(num_epochs):
    train_loss = train_epoch(fabric, model, optimizer, train_loader, criterion)
    val_loss, val_acc = validate(fabric, model, val_loader, criterion)

    if early_stopping(val_loss):
        fabric.print(f"Early stopping at epoch {epoch}")
        break
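In multi-process runs each rank sees a different validation shard, so the stop decision should be made once and shared; fabric.broadcast keeps every process in step. A sketch of the loop above with that added:

for epoch in range(num_epochs):
    train_loss = train_epoch(fabric, model, optimizer, train_loader, criterion)
    val_loss, val_acc = validate(fabric, model, val_loader, criterion)

    # Rank 0 decides, then the flag is broadcast so all ranks break together
    stop = early_stopping(val_loss) if fabric.global_rank == 0 else False
    stop = fabric.broadcast(stop, src=0)
    if stop:
        fabric.print(f"Early stopping at epoch {epoch}")
        break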
Best Practices
1. Proper Fabric Launch
# Always use fabric.launch() for proper initialization
def main():
    fabric = Fabric(accelerator="gpu", devices=2)
    fabric.launch()

    # Your training code here
    model = create_model()
    # ... rest of training

if __name__ == "__main__":
    main()
2. Rank-specific Operations
from lightning.fabric.utilities import rank_zero_only
@rank_zero_only
def save_model_artifacts(model, path):
"""Only save on rank 0 to avoid conflicts"""
torch.save(model.state_dict(), path)
@rank_zero_only
def print_training_info(epoch, loss):
"""Only print on rank 0 to avoid duplicate outputs"""
print(f"Epoch {epoch}, Loss: {loss}")
3. Proper Device Management
# Let Fabric handle device placement
fabric = Fabric()
fabric.launch()

# Don't manually move to device - Fabric handles this
# BAD: model.to(device), data.to(device)
# GOOD: Let fabric.setup() handle device placement
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)
4. Memory Efficient Training
def memory_efficient_training(fabric, model, optimizer, dataloader, criterion):
    model.train()

    # Enable gradient checkpointing once, before the loop (e.g. Hugging Face models expose this hook)
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()

    for batch_idx, (data, target) in enumerate(dataloader):
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)

        fabric.backward(loss)
        optimizer.step()

        # Clear cache periodically
        if batch_idx % 100 == 0:
            torch.cuda.empty_cache()
5. Complete Training Script Template
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric import Fabric
from lightning.fabric.utilities import rank_zero_only
def create_model():
    return nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )

def train_epoch(fabric, model, optimizer, dataloader, criterion):
    model.train()
    total_loss = 0

    for data, target in dataloader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        fabric.backward(loss)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(dataloader)
def main():
    # Initialize Fabric
    fabric = Fabric(
        accelerator="auto",
        strategy="auto",
        devices="auto",
        precision="16-mixed",
    )
    fabric.launch()

    # Create model, optimizer, data
    model = create_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    dataset = TensorDataset(torch.randn(1000, 784), torch.randint(0, 10, (1000,)))
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Setup with Fabric
    model, optimizer = fabric.setup(model, optimizer)
    dataloader = fabric.setup_dataloaders(dataloader)

    # Training loop
    criterion = nn.CrossEntropyLoss()
    for epoch in range(10):
        avg_loss = train_epoch(fabric, model, optimizer, dataloader, criterion)

        if fabric.is_global_zero:
            print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
if __name__ == "__main__":
main()
This guide covers the essential aspects of using Lightning Fabric for efficient PyTorch training. Fabric provides the perfect balance between control and convenience, making it ideal for researchers and practitioners who want distributed training capabilities without giving up flexibility in their training loops.