PyTorch Lightning: A Comprehensive Guide

Categories: code, tutorial, advanced

Author: Krishnatheja Vanka

Published: March 29, 2025


PyTorch Lightning is a lightweight wrapper for PyTorch that helps organize code and reduce boilerplate while adding powerful features for research and production. This guide walks you from the basics through to advanced techniques.

Table of Contents

  1. Introduction to PyTorch Lightning
  2. Installation
  3. Basic Structure: The LightningModule
  4. DataModules
  5. Training with Trainer
  6. Callbacks
  7. Logging
  8. Distributed Training
  9. Hyperparameter Tuning
  10. Model Checkpointing
  11. Production Deployment
  12. Best Practices

Introduction to PyTorch Lightning

PyTorch Lightning separates research code from engineering code, making models more:

  • Reproducible: The same code works across different hardware
  • Readable: Standard project structure makes collaboration easier
  • Scalable: Train on CPUs, GPUs, TPUs or clusters with no code changes

Lightning helps you focus on the science by handling the engineering details.

Installation

pip install pytorch-lightning

For the latest features, you can install from source:

pip install git+https://github.com/Lightning-AI/lightning.git

Basic Structure: The LightningModule

The core component in Lightning is the LightningModule, which organizes your PyTorch code into a standardized structure:

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class MNISTModel(pl.LightningModule):
    def __init__(self, lr=0.001):
        super().__init__()
        self.save_hyperparameters()  # Saves lr to self.hparams
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)
        self.lr = lr
        
    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc)
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

This basic structure includes:

  • Model architecture (__init__ and forward)
  • Training logic (training_step)
  • Validation logic (validation_step)
  • Test logic (test_step)
  • Optimization setup (configure_optimizers)
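
Before wiring the module into a Trainer, it can help to sanity-check the forward pass on a dummy batch; a minimal sketch (the batch size of 4 is arbitrary):

# Quick sanity check of the forward pass on an MNIST-shaped dummy batch
model = MNISTModel(lr=1e-3)
dummy = torch.randn(4, 1, 28, 28)
logits = model(dummy)
print(logits.shape)  # torch.Size([4, 10])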

DataModules

Lightning’s LightningDataModule encapsulates all data-related logic:

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir='./data', batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        
    def prepare_data(self):
        # Download data if needed
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
        
    def setup(self, stage=None):
        # Assign train/val/test datasets
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
            
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)
        
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)
        
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

Benefits of LightningDataModule:

  • Encapsulates all data preparation logic
  • Makes data pipeline portable and reproducible
  • Simplifies sharing data pipelines between projects
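
The module can also be exercised on its own, which is handy for debugging the data pipeline before training; a minimal sketch:

# Use the DataModule outside the Trainer for quick inspection
dm = MNISTDataModule(batch_size=64)
dm.prepare_data()
dm.setup('fit')
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])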

Training with Trainer

The Lightning Trainer handles the training loop and validation:

from pytorch_lightning import Trainer

# Create model and data module
model = MNISTModel()
data_module = MNISTDataModule()

# Initialize trainer
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",  # Automatically use available GPU
    devices="auto",
    logger=True,         # Use TensorBoard logger by default
)

# Train model
trainer.fit(model, datamodule=data_module)

# Test model
trainer.test(model, datamodule=data_module)

The Trainer automatically handles:

  • Epoch and batch iteration
  • Optimizer steps
  • Logging metrics
  • Hardware acceleration (CPU, GPU, TPU)
  • Early stopping
  • Checkpointing
  • Multi-GPU training

Common Trainer Parameters

trainer = Trainer(
    max_epochs=10,                    # Maximum number of epochs
    min_epochs=1,                     # Minimum number of epochs
    accelerator="cpu",                # Use CPU acceleration
    devices=1,                        # Use 1 CPUs
    precision="16-mixed",             # Use mixed precision for faster training
    gradient_clip_val=0.5,            # Clip gradients
    accumulate_grad_batches=4,        # Accumulate gradients over 4 batches
    log_every_n_steps=50,             # Log metrics every 50 steps
    val_check_interval=0.25,          # Run validation 4 times per epoch
    fast_dev_run=False,               # Debug mode (only run a few batches)
    deterministic=True,               # Make training deterministic
)

Callbacks

Callbacks add functionality to the training loop without modifying the core code:

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Save the best model based on validation accuracy
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    dirpath='./checkpoints',
    filename='mnist-{epoch:02d}-{val_acc:.2f}',
    save_top_k=3,
    mode='max',
)

# Stop training when validation loss stops improving
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=True,
    mode='min'
)

# Monitor learning rate
lr_monitor = LearningRateMonitor(logging_interval='step')

# Initialize trainer with callbacks
trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor]
)

Custom Callback

You can create custom callbacks by extending the base Callback class:

from pytorch_lightning.callbacks import Callback

class PrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training has started!")
        
    def on_train_end(self, trainer, pl_module):
        print("Training has finished!")
        
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx} completed")

Logging

Lightning supports various loggers:

from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

# TensorBoard logger
tensorboard_logger = TensorBoardLogger(save_dir="logs/", name="mnist")

# Initialize trainer with loggers
trainer = Trainer(
    max_epochs=10,
    logger=[tensorboard_logger]
)
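
The imported WandbLogger works the same way; a minimal sketch, assuming the wandb package is installed and you are logged in (the project and run names are illustrative):

# Weights & Biases logger
wandb_logger = WandbLogger(project="mnist", name="baseline-run")

trainer = Trainer(
    max_epochs=10,
    logger=[tensorboard_logger, wandb_logger]
)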

Adding metrics in your model:

def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    
    # Log scalar values
    self.log('train_loss', loss)
    
    # Log learning rate
    lr = self.optimizers().param_groups[0]['lr']
    self.log('learning_rate', lr)
    
    # Log a histogram to TensorBoard every 100 batches
    if batch_idx % 100 == 0:
        self.logger.experiment.add_histogram('logits', logits, self.global_step)
    
    return loss

Distributed Training

Lightning makes distributed training simple:

# Single GPU
trainer = Trainer(accelerator="gpu", devices=1)

# Multiple GPUs (DDP strategy automatically selected)
trainer = Trainer(accelerator="gpu", devices=4)

# Specific GPU indices
trainer = Trainer(accelerator="gpu", devices=[0, 1, 3])

# TPU cores
trainer = Trainer(accelerator="tpu", devices=8)

# Advanced distributed training
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",  # Distributed Data Parallel
    num_nodes=2      # Use 2 machines
)

Other distribution strategies:

  • ddp_spawn: Similar to DDP but uses spawn for multiprocessing
  • deepspeed: For very large models using DeepSpeed
  • fsdp: Fully Sharded Data Parallel for huge models
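
These are selected the same way through the strategy argument; a minimal sketch, assuming Lightning 2.x and, for DeepSpeed, that the optional deepspeed package is installed:

# Spawn-based DDP
trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp_spawn")

# DeepSpeed (ZeRO stage 2)
trainer = Trainer(accelerator="gpu", devices=8, strategy="deepspeed_stage_2")

# Fully Sharded Data Parallel
trainer = Trainer(accelerator="gpu", devices=8, strategy="fsdp")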

Hyperparameter Tuning

Lightning integrates well with optimization libraries:

from pytorch_lightning import Trainer
from pytorch_lightning.tuner import Tuner

# Create model
model = MNISTModel()
data_module = MNISTDataModule()

# Create trainer
trainer = Trainer(max_epochs=10)

# Use Lightning's Tuner for auto-scaling batch size
tuner = Tuner(trainer)
tuner.scale_batch_size(model, datamodule=data_module)

# Find optimal learning rate
lr_finder = tuner.lr_find(model, datamodule=data_module)
model.lr = lr_finder.suggestion()

# Train with optimized parameters
trainer.fit(model, datamodule=data_module)
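
For a full hyperparameter search, Lightning also pairs naturally with external libraries such as Optuna; a minimal sketch, assuming optuna is installed (the search space and trial count are illustrative):

import optuna

def objective(trial):
    # Sample a learning rate for this trial
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    model = MNISTModel(lr=lr)
    data_module = MNISTDataModule()
    trainer = Trainer(max_epochs=3, enable_progress_bar=False)
    trainer.fit(model, datamodule=data_module)
    # Minimize the validation loss logged in validation_step
    return trainer.callback_metrics["val_loss"].item()

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)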

Model Checkpointing

Saving and loading model checkpoints:

# Save checkpoints during training
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints',
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    save_top_k=3,
    mode='min'
)

trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback]
)

trainer.fit(model, datamodule=data_module)

# Path to best checkpoint
best_model_path = checkpoint_callback.best_model_path

# Load checkpoint into model
model = MNISTModel.load_from_checkpoint(best_model_path)

# Continue training from the checkpoint (restores optimizer state, epoch, etc.)
trainer.fit(model, datamodule=data_module, ckpt_path=best_model_path)

Checkpointing for production:

# Save model in production-ready format
trainer.save_checkpoint("model.ckpt")

# Extract state dict for production
checkpoint = torch.load("model.ckpt", map_location="cpu")
model_state_dict = checkpoint["state_dict"]

# Save just the PyTorch model for production
model = MNISTModel()
model.load_state_dict(model_state_dict)
torch.save(model.state_dict(), "production_model.pt")

Production Deployment

Converting Lightning models to production:

# Option 1: Use the Lightning model directly
model = MNISTModel.load_from_checkpoint("model.ckpt")
model.eval()
model.freeze()  # Freeze the model parameters

# Make predictions
with torch.no_grad():
    x = torch.randn(1, 1, 28, 28)
    y_hat = model(x)

# Option 2: Extract the PyTorch model
class ProductionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x

# Load weights from Lightning model
lightning_model = MNISTModel.load_from_checkpoint("model.ckpt")
production_model = ProductionModel()

# Copy weights
production_model.layer_1.weight.data = lightning_model.layer_1.weight.data
production_model.layer_1.bias.data = lightning_model.layer_1.bias.data
production_model.layer_2.weight.data = lightning_model.layer_2.weight.data
production_model.layer_2.bias.data = lightning_model.layer_2.bias.data

# Save production model
torch.save(production_model.state_dict(), "production_model.pt")

# Export to ONNX
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    production_model,
    dummy_input,
    "model.onnx",
    export_params=True,
    opset_version=11,
    input_names=['input'],
    output_names=['output']
)
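
To sanity-check the exported graph, here is a minimal inference sketch with ONNX Runtime (assumes the onnxruntime package is installed):

import numpy as np
import onnxruntime as ort

# Run the exported model on a dummy batch
session = ort.InferenceSession("model.onnx")
dummy = np.random.randn(1, 1, 28, 28).astype(np.float32)
outputs = session.run(["output"], {"input": dummy})
print(outputs[0].shape)  # (1, 10)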

Best Practices

Code Organization

A well-organized Lightning project structure:

project/
├── configs/              # Configuration files
├── data/                 # Data files
├── lightning_logs/       # Generated logs
├── models/               # Model definitions
│   ├── __init__.py
│   └── mnist_model.py    # LightningModule
├── data_modules/         # Data modules
│   ├── __init__.py
│   └── mnist_data.py     # LightningDataModule
├── callbacks/            # Custom callbacks
├── utils/                # Utility functions
├── main.py               # Training script
└── README.md

Lightning CLI

Lightning provides a CLI for running experiments from config files:

# main.py
from pytorch_lightning.cli import LightningCLI

# Import the model and data module (paths assume the project layout above)
from models.mnist_model import MNISTModel
from data_modules.mnist_data import MNISTDataModule

def cli_main():
    # Create CLI with LightningModule and LightningDataModule
    cli = LightningCLI(
        MNISTModel,
        MNISTDataModule,
        save_config_callback=None,
    )

if __name__ == "__main__":
    cli_main()

Run with:

python main.py fit --config configs/mnist.yaml

Example config file (mnist.yaml):

model:
  lr: 0.001
data:
  data_dir: ./data
  batch_size: 64
trainer:
  max_epochs: 10
  accelerator: gpu
  devices: 1
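
Individual settings can also be overridden from the command line, for example:

python main.py fit --config configs/mnist.yaml --trainer.max_epochs 20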

Profiling

Lightning includes tools to profile your code:

from pytorch_lightning.profilers import PyTorchProfiler, SimpleProfiler

# PyTorch Profiler
profiler = PyTorchProfiler(
    on_trace_ready=torch.profiler.tensorboard_trace_handler("logs/profiler"),
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2)
)

# Simple Profiler
# profiler = SimpleProfiler()
model = MNISTModel()
data_module = MNISTDataModule()
trainer = Trainer(
    max_epochs=5,
    profiler=profiler
)

trainer.fit(model, datamodule=data_module)

Advanced Features

Gradient Accumulation

trainer = Trainer(
    accumulate_grad_batches=4  # Accumulate over 4 batches
)

Learning Rate Scheduling

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min', 
        factor=0.1, 
        patience=5
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1
        }
    }

Mixed Precision Training

trainer = Trainer(
    precision="16-mixed"  # Use 16-bit mixed precision
)

Transfer Learning

import torchvision

class TransferLearningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Load a pretrained backbone
        self.feature_extractor = torchvision.models.resnet18(weights="IMAGENET1K_V1")
        
        # Freeze the backbone
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
            
        # Replace the final layer (its new parameters stay trainable)
        num_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Linear(num_features, 10)
        
    def forward(self, x):
        return self.feature_extractor(x)
        
    def unfreeze_features(self):
        # Unfreeze the backbone after some training
        for param in self.feature_extractor.parameters():
            param.requires_grad = True

Conclusion

PyTorch Lightning provides an organized framework that separates research code from engineering boilerplate, making deep learning projects easier to develop, share, and scale. It’s especially valuable for research projects that need to be reproducible and scalable across different hardware configurations.

For further information, see the official PyTorch Lightning documentation and the Lightning-AI GitHub repository.