⚡ PyTorch Lightning: A Comprehensive Guide
PyTorch Lightning is a lightweight wrapper for PyTorch that helps organize code and reduce boilerplate while adding powerful features for research and production. This guide walks you through everything from the basics to advanced techniques.
Table of Contents
- Introduction to PyTorch Lightning
- Installation
- Basic Structure: The LightningModule
- DataModules
- Training with Trainer
- Callbacks
- Logging
- Distributed Training
- Hyperparameter Tuning
- Model Checkpointing
- Production Deployment
- Best Practices
- Advanced Features
- Conclusion
Introduction to PyTorch Lightning
PyTorch Lightning separates research code from engineering code, making models more:
- Reproducible: The same code works across different hardware
- Readable: Standard project structure makes collaboration easier
- Scalable: Train on CPUs, GPUs, TPUs or clusters with no code changes
Lightning helps you focus on the science by handling the engineering details.
Installation
pip install pytorch-lightning
For the latest features, you can install from source:
pip install git+https://github.com/Lightning-AI/lightning.git
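To confirm the installation, you can check the version from Python (a quick sanity check; the pl alias matches the rest of this guide):

import pytorch_lightning as pl
print(pl.__version__)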
Basic Structure: The LightningModule
The core component in Lightning is the LightningModule, which organizes your PyTorch code into a standardized structure:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
class MNISTModel(pl.LightningModule):
    def __init__(self, lr=0.001):
        super().__init__()
        self.save_hyperparameters()  # Saves lr to self.hparams
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)
        self.lr = lr

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
This basic structure includes:
- Model architecture (__init__ and forward)
- Training logic (training_step)
- Validation logic (validation_step)
- Test logic (test_step)
- Optimization setup (configure_optimizers)
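As a quick sanity check (a sketch, not part of the original example), the module behaves like a plain nn.Module and can be called directly; its forward pass expects MNIST-shaped input of (batch, 1, 28, 28):

import torch

model = MNISTModel()
dummy_batch = torch.randn(4, 1, 28, 28)  # four fake MNIST images
logits = model(dummy_batch)
print(logits.shape)  # torch.Size([4, 10])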
DataModules
Lightning’s LightningDataModule encapsulates all data-related logic:
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir='./data', batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # Download data if needed
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val/test datasets
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
Benefits of LightningDataModule:
- Encapsulates all data preparation logic
- Makes data pipeline portable and reproducible
- Simplifies sharing data pipelines between projects
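Because all data logic lives in one class, a DataModule can also be exercised on its own, outside the Trainer; a minimal sketch (the stage name follows the setup() method above):

dm = MNISTDataModule(batch_size=64)
dm.prepare_data()      # downloads MNIST if needed
dm.setup(stage='fit')  # builds mnist_train / mnist_val

images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])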
Training with Trainer
The Lightning Trainer handles the training loop and validation:
from pytorch_lightning import Trainer
# Create model and data module
model = MNISTModel()
data_module = MNISTDataModule()

# Initialize trainer
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",  # Automatically use available GPU
    devices="auto",
    logger=True  # Use TensorBoard logger by default
)

# Train model
trainer.fit(model, datamodule=data_module)

# Test model
trainer.test(model, datamodule=data_module)
The Trainer automatically handles:
- Epoch and batch iteration
- Optimizer steps
- Logging metrics
- Hardware acceleration (CPU, GPU, TPU)
- Early stopping
- Checkpointing
- Multi-GPU training
Common Trainer Parameters
trainer = Trainer(
    max_epochs=10,               # Maximum number of epochs
    min_epochs=1,                # Minimum number of epochs
    accelerator="cpu",           # Use CPU acceleration
    devices=1,                   # Use 1 CPU
    precision="16-mixed",        # Use mixed precision for faster training
    gradient_clip_val=0.5,       # Clip gradients
    accumulate_grad_batches=4,   # Accumulate gradients over 4 batches
    log_every_n_steps=50,        # Log metrics every 50 steps
    val_check_interval=0.25,     # Run validation 4 times per epoch
    fast_dev_run=False,          # Debug mode (only run a few batches)
    deterministic=True           # Make training deterministic
)
Callbacks
Callbacks add functionality to the training loop without modifying the core code:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
# Save the best model based on validation accuracy
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    dirpath='./checkpoints',
    filename='mnist-{epoch:02d}-{val_acc:.2f}',
    save_top_k=3,
    mode='max'
)

# Stop training when validation loss stops improving
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=True,
    mode='min'
)

# Monitor learning rate
lr_monitor = LearningRateMonitor(logging_interval='step')

# Initialize trainer with callbacks
trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor]
)
Custom Callback
You can create custom callbacks by extending the base Callback class:
from pytorch_lightning.callbacks import Callback
class PrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training has started!")

    def on_train_end(self, trainer, pl_module):
        print("Training has finished!")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx} completed")
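A custom callback is registered the same way as the built-in ones, by passing an instance to the Trainer; for example, combined with the callbacks defined above:

trainer = Trainer(
    max_epochs=10,
    callbacks=[PrintingCallback(), checkpoint_callback, early_stop_callback, lr_monitor]
)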
Logging
Lightning supports various loggers:
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
# TensorBoard logger
tensorboard_logger = TensorBoardLogger(save_dir="logs/", name="mnist")

# Initialize trainer with loggers
trainer = Trainer(
    max_epochs=10,
    logger=[tensorboard_logger]
)
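The WandbLogger imported above works the same way and can be combined with TensorBoard by passing a list of loggers; a sketch (the project name "mnist" is an arbitrary placeholder, and Weights & Biases must be installed and logged in separately):

wandb_logger = WandbLogger(project="mnist")

trainer = Trainer(
    max_epochs=10,
    logger=[tensorboard_logger, wandb_logger]  # log to both backends
)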
Adding metrics in your model:
def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)

    # Log scalar values
    self.log('train_loss', loss)

    # Log learning rate
    lr = self.optimizers().param_groups[0]['lr']
    self.log('learning_rate', lr)

    # Log histograms (to TensorBoard)
    if batch_idx % 100 == 0:
        self.logger.experiment.add_histogram('logits', logits, self.global_step)

    return loss
Distributed Training
Lightning makes distributed training simple:
# Single GPU
trainer = Trainer(accelerator="gpu", devices=1)

# Multiple GPUs (DDP strategy automatically selected)
trainer = Trainer(accelerator="gpu", devices=4)

# Specific GPU indices
trainer = Trainer(accelerator="gpu", devices=[0, 1, 3])

# TPU cores
trainer = Trainer(accelerator="tpu", devices=8)

# Advanced distributed training
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",  # Distributed Data Parallel
    num_nodes=2      # Use 2 machines
)
Other distribution strategies:
- ddp_spawn: Similar to DDP but uses spawn for multiprocessing
- deepspeed: For very large models using DeepSpeed
- fsdp: Fully Sharded Data Parallel for huge models
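These strategies are selected by name through the strategy argument, just like ddp above; for example (a sketch, assuming a multi-GPU machine and, for DeepSpeed, the deepspeed package installed):

# Fully Sharded Data Parallel for models too large for one GPU's memory
trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp")

# DeepSpeed
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed")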
Hyperparameter Tuning
Lightning integrates well with optimization libraries:
from pytorch_lightning import Trainer
from pytorch_lightning.tuner import Tuner
# Create model
model = MNISTModel()
data_module = MNISTDataModule()

# Create trainer
trainer = Trainer(max_epochs=10)

# Use Lightning's Tuner for auto-scaling batch size
tuner = Tuner(trainer)
tuner.scale_batch_size(model, datamodule=data_module)

# Find optimal learning rate
lr_finder = tuner.lr_find(model, datamodule=data_module)
model.lr = lr_finder.suggestion()

# Train with optimized parameters
trainer.fit(model, datamodule=data_module)
Model Checkpointing
Saving and loading model checkpoints:
# Save checkpoints during training
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints',
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    save_top_k=3,
    mode='min'
)

trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback]
)

trainer.fit(model, datamodule=data_module)

# Path to best checkpoint
best_model_path = checkpoint_callback.best_model_path

# Load checkpoint into model
model = MNISTModel.load_from_checkpoint(best_model_path)

# Continue training from the loaded weights
# (pass ckpt_path=best_model_path to also restore optimizer and trainer state)
trainer.fit(model, datamodule=data_module)
Checkpointing for production:
# Save model in production-ready format
trainer.save_checkpoint("model.ckpt")

# Extract state dict for production
checkpoint = torch.load("model.ckpt")
model_state_dict = checkpoint["state_dict"]

# Save just the PyTorch model for production
model = MNISTModel()
model.load_state_dict(model_state_dict)
torch.save(model.state_dict(), "production_model.pt")
Production Deployment
Converting Lightning models to production:
# Option 1: Use the Lightning model directly
model = MNISTModel.load_from_checkpoint("model.ckpt")
model.eval()

# Freeze the model parameters
model.freeze()

# Make predictions
with torch.no_grad():
    x = torch.randn(1, 1, 28, 28)
    y_hat = model(x)

# Option 2: Extract the PyTorch model
class ProductionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x

# Load weights from Lightning model
lightning_model = MNISTModel.load_from_checkpoint("model.ckpt")
production_model = ProductionModel()

# Copy weights
production_model.layer_1.weight.data = lightning_model.layer_1.weight.data
production_model.layer_1.bias.data = lightning_model.layer_1.bias.data
production_model.layer_2.weight.data = lightning_model.layer_2.weight.data
production_model.layer_2.bias.data = lightning_model.layer_2.bias.data

# Save production model
torch.save(production_model.state_dict(), "production_model.pt")

# Export to ONNX
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    production_model,
    dummy_input,
    "model.onnx",
    export_params=True,
    opset_version=11,
    input_names=['input'],
    output_names=['output']
)
Best Practices
Code Organization
A well-organized Lightning project structure:
project/
├── configs/ # Configuration files
├── data/ # Data files
├── lightning_logs/ # Generated logs
├── models/ # Model definitions
│ ├── __init__.py
│ └── mnist_model.py # LightningModule
├── data_modules/ # Data modules
│ ├── __init__.py
│ └── mnist_data.py # LightningDataModule
├── callbacks/ # Custom callbacks
├── utils/ # Utility functions
├── main.py # Training script
└── README.md
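Under this layout, main.py only wires the pieces together; a minimal manual sketch (the import paths assume the tree above; the next section shows a config-driven alternative using the Lightning CLI):

# main.py
from pytorch_lightning import Trainer
from models.mnist_model import MNISTModel
from data_modules.mnist_data import MNISTDataModule

def main():
    model = MNISTModel()
    data_module = MNISTDataModule()
    trainer = Trainer(max_epochs=10)
    trainer.fit(model, datamodule=data_module)
    trainer.test(model, datamodule=data_module)

if __name__ == "__main__":
    main()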
Lightning CLI
Lightning provides a CLI for running experiments from config files:
# main.py
from pytorch_lightning.cli import LightningCLI

def cli_main():
    # Create CLI with LightningModule and LightningDataModule
    cli = LightningCLI(
        MNISTModel,
        MNISTDataModule,
        save_config_callback=None,
    )

if __name__ == "__main__":
    cli_main()
Run with:
python main.py fit --config configs/mnist.yaml
Example config file (mnist.yaml):
model:
  lr: 0.001
data:
  data_dir: ./data
  batch_size: 64
trainer:
  max_epochs: 10
  accelerator: gpu
  devices: 1
Profiling
Lightning includes tools to profile your code:
from pytorch_lightning.profilers import PyTorchProfiler, SimpleProfiler
# PyTorch Profiler
profiler = PyTorchProfiler(
    on_trace_ready=torch.profiler.tensorboard_trace_handler("logs/profiler"),
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2)
)

# Simple Profiler
# profiler = SimpleProfiler()

model = MNISTModel()
data_module = MNISTDataModule()
trainer = Trainer(
    max_epochs=5,
    profiler=profiler
)
trainer.fit(model, datamodule=data_module)
Advanced Features
Gradient Accumulation
trainer = Trainer(
    accumulate_grad_batches=4  # Accumulate over 4 batches
)
Learning Rate Scheduling
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1
        }
    }
Mixed Precision Training
trainer = Trainer(
    precision="16-mixed"  # Use 16-bit mixed precision
)
Transfer Learning
import torchvision

class TransferLearningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Load pretrained model
        self.feature_extractor = torchvision.models.resnet18(pretrained=True)

        # Freeze feature extractor
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # Replace final layer
        num_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Linear(num_features, 10)

    def unfreeze_features(self):
        # Unfreeze model after some training
        for param in self.feature_extractor.parameters():
            param.requires_grad = True
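One way to trigger unfreeze_features during training (a sketch, with an arbitrary epoch threshold rather than anything prescribed by Lightning) is a small callback that calls it from the epoch-start hook:

from pytorch_lightning.callbacks import Callback

class UnfreezeCallback(Callback):
    # Unfreezes the backbone once the chosen epoch is reached
    def __init__(self, unfreeze_at_epoch=10):
        self.unfreeze_at_epoch = unfreeze_at_epoch

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch == self.unfreeze_at_epoch:
            pl_module.unfreeze_features()

trainer = Trainer(max_epochs=20, callbacks=[UnfreezeCallback(unfreeze_at_epoch=10)])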
Conclusion
PyTorch Lightning provides an organized framework that separates research code from engineering boilerplate, making deep learning projects easier to develop, share, and scale. It’s especially valuable for research projects that need to be reproducible and scalable across different hardware configurations.
For further information, see the official PyTorch Lightning documentation and the Lightning-AI GitHub repository.