⚡ PyTorch Lightning: A Comprehensive Guide

PyTorch Lightning is a lightweight wrapper for PyTorch that helps organize code and reduce boilerplate while adding powerful features for research and production. This guide walks you from the basics through advanced techniques.
Introduction to PyTorch Lightning
PyTorch Lightning separates research code from engineering code, making models more:
- Reproducible: The same code works across different hardware
- Readable: Standard project structure makes collaboration easier
- Scalable: Train on CPUs, GPUs, TPUs or clusters with no code changes
Lightning helps you focus on the science by handling the engineering details.
Installation
```bash
pip install pytorch-lightning
```
For the latest features, you can install from source:
```bash
pip install git+https://github.com/Lightning-AI/lightning.git
```
Basic Structure: The LightningModule
The core component in Lightning is the LightningModule, which organizes your PyTorch code into a standardized structure:
```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class MNISTModel(pl.LightningModule):
    def __init__(self, lr=0.001):
        super().__init__()
        self.save_hyperparameters()  # Saves lr to self.hparams
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)
        self.lr = lr

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
```
This basic structure includes:
- Model architecture (`__init__` and `forward`)
- Training logic (`training_step`)
- Validation logic (`validation_step`)
- Test logic (`test_step`)
- Optimization setup (`configure_optimizers`)
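To sanity-check the module outside the Trainer, you can instantiate it and run a forward pass on a dummy batch. A minimal sketch using the MNISTModel defined above:
```python
# Quick sanity check of the model outside the Trainer
model = MNISTModel(lr=0.001)

# Fake batch of 4 MNIST-sized images: (batch, channels, height, width)
dummy_images = torch.randn(4, 1, 28, 28)
logits = model(dummy_images)
print(logits.shape)  # torch.Size([4, 10])
```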
DataModules
Lightning’s LightningDataModule encapsulates all data-related logic:
```python
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir='./data', batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # Download data if needed
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val/test datasets
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
```
Benefits of LightningDataModule:
- Encapsulates all data preparation logic
- Makes data pipeline portable and reproducible
- Simplifies sharing data pipelines between projects
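Because a DataModule is self-contained, it can also be used on its own, for example to inspect a batch before training. A short sketch using the MNISTDataModule defined above:
```python
# Use the DataModule standalone to inspect a batch
dm = MNISTDataModule(batch_size=32)
dm.prepare_data()      # download MNIST if needed
dm.setup(stage='fit')  # create the train/val splits

images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
```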
Training with Trainer
The Lightning Trainer handles the training loop and validation:
```python
from pytorch_lightning import Trainer

# Create model and data module
model = MNISTModel()
data_module = MNISTDataModule()

# Initialize trainer
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",  # Automatically pick the available accelerator (GPU if present)
    devices="auto",
    logger=True,         # Use the TensorBoard logger by default
)

# Train model
trainer.fit(model, datamodule=data_module)

# Test model
trainer.test(model, datamodule=data_module)
```
The Trainer automatically handles:
- Epoch and batch iteration
- Optimizer steps
- Logging metrics
- Hardware acceleration (CPU, GPU, TPU)
- Early stopping
- Checkpointing
- Multi-GPU training
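Beyond fit and test, the Trainer also exposes a standalone validation entry point. A minimal sketch reusing the model and data module created above:
```python
# Run only the validation loop and collect the logged metrics
val_results = trainer.validate(model, datamodule=data_module)
print(val_results)  # e.g. [{'val_loss': ..., 'val_acc': ...}]
```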
Common Trainer Parameters
```python
trainer = Trainer(
    max_epochs=10,              # Maximum number of epochs
    min_epochs=1,               # Minimum number of epochs
    accelerator="auto",         # "cpu", "gpu", "tpu", or "auto"
    devices=1,                  # Number of devices to use
    precision="16-mixed",       # Use mixed precision for faster training
    gradient_clip_val=0.5,      # Clip gradients
    accumulate_grad_batches=4,  # Accumulate gradients over 4 batches
    log_every_n_steps=50,       # Log metrics every 50 steps
    val_check_interval=0.25,    # Run validation 4 times per epoch
    fast_dev_run=False,         # Debug mode (only run a few batches)
    deterministic=True,         # Make training deterministic
)
```
Callbacks
Callbacks add functionality to the training loop without modifying the core code:
```python
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Save the top 3 checkpoints by validation accuracy
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    dirpath='./checkpoints',
    filename='mnist-{epoch:02d}-{val_acc:.2f}',
    save_top_k=3,
    mode='max',
)

# Stop training when validation loss stops improving
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=True,
    mode='min'
)

# Monitor learning rate
lr_monitor = LearningRateMonitor(logging_interval='step')

# Initialize trainer with callbacks
trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor]
)
```
Custom Callback
You can create custom callbacks by extending the base Callback class:
```python
from pytorch_lightning.callbacks import Callback


class PrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training has started!")

    def on_train_end(self, trainer, pl_module):
        print("Training has finished!")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx} completed")
```
Logging
Lightning supports various loggers:
```python
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

# TensorBoard logger
tensorboard_logger = TensorBoardLogger(save_dir="logs/", name="mnist")

# Initialize trainer with loggers
trainer = Trainer(
    max_epochs=10,
    logger=[tensorboard_logger]
)
```
Adding metrics in your model:
```python
def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)

    # Log scalar values
    self.log('train_loss', loss)

    # Log the learning rate
    lr = self.optimizers().param_groups[0]['lr']
    self.log('learning_rate', lr)

    # Log histograms (TensorBoard logger)
    if batch_idx % 100 == 0:
        self.logger.experiment.add_histogram('logits', logits, self.global_step)

    return loss
```
Distributed Training
Lightning makes distributed training simple:
```python
# Single GPU
trainer = Trainer(accelerator="gpu", devices=1)

# Multiple GPUs (DDP strategy automatically selected)
trainer = Trainer(accelerator="gpu", devices=4)

# Specific GPU indices
trainer = Trainer(accelerator="gpu", devices=[0, 1, 3])

# TPU cores
trainer = Trainer(accelerator="tpu", devices=8)

# Advanced distributed training
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",  # Distributed Data Parallel
    num_nodes=2      # Use 2 machines
)
```
Other distribution strategies:
- `ddp_spawn`: Similar to DDP but uses spawn for multiprocessing
- `deepspeed`: For very large models using DeepSpeed
- `fsdp`: Fully Sharded Data Parallel for huge models
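These are selected through the same strategy argument. A sketch (DeepSpeed and FSDP need recent Lightning versions, and DeepSpeed additionally requires the deepspeed package):
```python
# Spawn-based DDP (useful in notebooks and interactive sessions)
trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp_spawn")

# Fully Sharded Data Parallel for very large models
trainer = Trainer(accelerator="gpu", devices=8, strategy="fsdp")

# DeepSpeed ZeRO stage 2 (requires `pip install deepspeed`)
trainer = Trainer(accelerator="gpu", devices=8, strategy="deepspeed_stage_2", precision="16-mixed")
```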
Hyperparameter Tuning
Lightning integrates well with optimization libraries:
```python
from pytorch_lightning import Trainer
from pytorch_lightning.tuner import Tuner

# Create model and data module
model = MNISTModel()
data_module = MNISTDataModule()

# Create trainer
trainer = Trainer(max_epochs=10)

# Use Lightning's Tuner to auto-scale the batch size
tuner = Tuner(trainer)
tuner.scale_batch_size(model, datamodule=data_module)

# Find a good learning rate
lr_finder = tuner.lr_find(model, datamodule=data_module)
model.lr = lr_finder.suggestion()

# Train with the tuned parameters
trainer.fit(model, datamodule=data_module)
```
Model Checkpointing
Saving and loading model checkpoints:
```python
# Save checkpoints during training
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints',
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    save_top_k=3,
    mode='min'
)

trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback]
)
trainer.fit(model, datamodule=data_module)

# Path to best checkpoint
best_model_path = checkpoint_callback.best_model_path

# Load checkpoint weights into a model
model = MNISTModel.load_from_checkpoint(best_model_path)

# Resume training from the checkpoint (also restores optimizer and epoch state)
trainer.fit(model, datamodule=data_module, ckpt_path=best_model_path)
```
Checkpointing for production:
```python
# Save model in production-ready format
trainer.save_checkpoint("model.ckpt")

# Extract state dict for production
checkpoint = torch.load("model.ckpt")
model_state_dict = checkpoint["state_dict"]

# Save just the PyTorch model for production
model = MNISTModel()
model.load_state_dict(model_state_dict)
torch.save(model.state_dict(), "production_model.pt")
```
Production Deployment
Converting Lightning models to production:
```python
# Option 1: Use the Lightning model directly
model = MNISTModel.load_from_checkpoint("model.ckpt")
model.eval()
model.freeze()  # Freeze the model parameters

# Make predictions
with torch.no_grad():
    x = torch.randn(1, 1, 28, 28)
    y_hat = model(x)

# Option 2: Extract the PyTorch model
class ProductionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x

# Load weights from the Lightning model
lightning_model = MNISTModel.load_from_checkpoint("model.ckpt")
production_model = ProductionModel()

# Copy weights
production_model.layer_1.weight.data = lightning_model.layer_1.weight.data
production_model.layer_1.bias.data = lightning_model.layer_1.bias.data
production_model.layer_2.weight.data = lightning_model.layer_2.weight.data
production_model.layer_2.bias.data = lightning_model.layer_2.bias.data

# Save production model
torch.save(production_model.state_dict(), "production_model.pt")

# Export to ONNX
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    production_model,
    dummy_input,
    "model.onnx",
    export_params=True,
    opset_version=11,
    input_names=['input'],
    output_names=['output']
)
```
Best Practices
Code Organization
A well-organized Lightning project structure:
```
project/
├── configs/            # Configuration files
├── data/               # Data files
├── lightning_logs/     # Generated logs
├── models/             # Model definitions
│   ├── __init__.py
│   └── mnist_model.py  # LightningModule
├── data_modules/       # Data modules
│   ├── __init__.py
│   └── mnist_data.py   # LightningDataModule
├── callbacks/          # Custom callbacks
├── utils/              # Utility functions
├── main.py             # Training script
└── README.md
```
Lightning CLI
Lightning provides a CLI for running experiments from config files:
```python
# main.py
from pytorch_lightning.cli import LightningCLI

from models.mnist_model import MNISTModel
from data_modules.mnist_data import MNISTDataModule


def cli_main():
    # Create CLI wired to the LightningModule and LightningDataModule
    cli = LightningCLI(
        MNISTModel,
        MNISTDataModule,
        save_config_callback=None,
    )


if __name__ == "__main__":
    cli_main()
```
Run with:
```bash
python main.py fit --config configs/mnist.yaml
```
Example config file (mnist.yaml):
```yaml
model:
  lr: 0.001        # Keys must match the LightningModule's __init__ arguments
data:
  data_dir: ./data
  batch_size: 64
trainer:
  max_epochs: 10
  accelerator: gpu
  devices: 1
```
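Individual config values can also be overridden on the command line using dotted keys, which LightningCLI's argument parser supports:
```bash
# Override selected config entries at launch time
python main.py fit --config configs/mnist.yaml --trainer.max_epochs 20 --data.batch_size 128
```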
Profiling
Lightning includes tools to profile your code:
```python
from pytorch_lightning.profilers import PyTorchProfiler, SimpleProfiler

# PyTorch Profiler
profiler = PyTorchProfiler(
    on_trace_ready=torch.profiler.tensorboard_trace_handler("logs/profiler"),
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2)
)

# Simple Profiler
# profiler = SimpleProfiler()

model = MNISTModel()
data_module = MNISTDataModule()

trainer = Trainer(
    max_epochs=5,
    profiler=profiler
)
trainer.fit(model, datamodule=data_module)
```
Advanced Features
Gradient Accumulation
```python
trainer = Trainer(
    accumulate_grad_batches=4  # Accumulate gradients over 4 batches before each optimizer step
)
```
With a DataLoader batch size of 32, accumulating over 4 batches gives an effective batch size of 128.
Learning Rate Scheduling
```python
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1
        }
    }
```
Mixed Precision Training
```python
trainer = Trainer(
    precision="16-mixed"  # Use 16-bit mixed precision
)
```
Transfer Learning
```python
import torchvision


class TransferLearningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Load a pretrained backbone
        # (newer torchvision versions use weights="IMAGENET1K_V1" instead of pretrained=True)
        self.feature_extractor = torchvision.models.resnet18(pretrained=True)

        # Freeze the feature extractor
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # Replace the final layer (the new head is trainable)
        num_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Linear(num_features, 10)

    def forward(self, x):
        return self.feature_extractor(x)

    def unfreeze_features(self):
        # Unfreeze the backbone after some training
        for param in self.feature_extractor.parameters():
            param.requires_grad = True
```
Conclusion
PyTorch Lightning provides an organized framework that separates research code from engineering boilerplate, making deep learning projects easier to develop, share, and scale. It’s especially valuable for research projects that need to be reproducible and scalable across different hardware configurations.
For further information, see the official PyTorch Lightning documentation at https://lightning.ai/docs/pytorch/stable/.