PyTorch to PyTorch Lightning Migration Guide
Table of Contents
- Introduction
- Why Migrate?
- Key Concepts
- Basic Migration Steps
- Code Examples
- Advanced Features
- Best Practices
- Common Pitfalls
- Migration Checklist
Introduction
PyTorch Lightning is a lightweight wrapper around PyTorch that eliminates boilerplate code while maintaining full control over your models. It provides a structured approach to organizing PyTorch code and includes built-in support for distributed training, logging, and experiment management.
Why Migrate?
- Reduced Boilerplate: Lightning handles training loops, device management, and distributed training
- Better Organization: Standardized code structure improves readability and maintenance
- Built-in Features: Automatic logging, checkpointing, early stopping, and more
- Scalability: Easy multi-GPU and multi-node training
- Reproducibility: Better experiment tracking and configuration management
Key Concepts
LightningModule
The core abstraction that wraps your PyTorch model. It defines:
- Model architecture (__init__)
- Forward pass (forward)
- Training step (training_step)
- Validation step (validation_step)
- Optimizer configuration (configure_optimizers)
Trainer
Handles the training loop, device management, and various training configurations.
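For example, the same Trainer call can switch between CPU, single-GPU, and multi-GPU training just by changing a few arguments. A minimal sketch (the specific flag values are chosen for illustration):

import pytorch_lightning as pl

# 'auto' lets Lightning pick whatever hardware is available; explicit
# values enable multi-GPU training with distributed data parallel (DDP).
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",   # or "cpu", "auto"
    devices=2,           # number of GPUs; "auto" uses all visible devices
    strategy="ddp",      # distributed data parallel across the devices
)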
DataModule
Encapsulates data loading logic, including datasets and dataloaders.
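A minimal skeleton looks like the sketch below (the dataset attributes are placeholders); a complete MNIST version appears later in this guide.

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage: str):
        # Build/assign datasets here; self.train_set and self.val_set
        # are placeholders for your own Dataset objects.
        ...

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)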
Basic Migration Steps
Step 1: Convert Model to LightningModule
Before (PyTorch):
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
After (Lightning):
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LightningModel(pl.LightningModule):
    def __init__(self, input_size, hidden_size, num_classes, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
Step 2: Replace Training Loop with Trainer
Before (PyTorch):
# Training loop
model = SimpleModel(input_size=784, hidden_size=128, num_classes=10)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')

    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    print(f'Validation Loss: {val_loss/len(val_loader)}, Accuracy: {correct/len(val_dataset)}')
After (Lightning):
# Training with Lightning
model = LightningModel(input_size=784, hidden_size=128, num_classes=10)
trainer = pl.Trainer(max_epochs=10, accelerator='auto', devices='auto')
trainer.fit(model, train_loader, val_loader)
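Lightning also checkpoints the run for you. A hedged sketch of reloading the saved weights afterwards (the checkpoint path below is only an example of where Lightning writes them by default):

# load_from_checkpoint rebuilds the model with the hyperparameters that
# were stored via save_hyperparameters(); the path is illustrative.
model = LightningModel.load_from_checkpoint(
    "lightning_logs/version_0/checkpoints/epoch=9-step=1000.ckpt"
)
model.eval()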
Code Examples
Complete MNIST Example
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "data/", batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # Download data
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets
        if stage == 'fit':
            mnist_full = datasets.MNIST(
                self.data_dir, train=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000]
            )
        if stage == 'test':
            self.mnist_test = datasets.MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
class MNISTClassifier(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = torch.sum(preds == y).float() / len(y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = torch.sum(preds == y).float() / len(y)
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
# Usage
if __name__ == "__main__":
    # Initialize data module and model
    dm = MNISTDataModule()
    model = MNISTClassifier()

    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=5,
        accelerator='auto',
        devices='auto',
        logger=pl.loggers.TensorBoardLogger('lightning_logs/'),
        callbacks=[
            pl.callbacks.EarlyStopping(monitor='val_loss', patience=3),
            pl.callbacks.ModelCheckpoint(monitor='val_loss', save_top_k=1)
        ]
    )

    # Train the model
    trainer.fit(model, dm)

    # Test the model
    trainer.test(model, dm)
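After training, the checkpoint tracked by ModelCheckpoint can be reused. A small sketch, assuming the trainer and data module above are still in scope:

# ModelCheckpoint records the best-scoring checkpoint path on the trainer.
best_path = trainer.checkpoint_callback.best_model_path
best_model = MNISTClassifier.load_from_checkpoint(best_path)

# Or evaluate the best checkpoint directly without reloading by hand.
trainer.test(ckpt_path="best", datamodule=dm)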
Advanced Model with Custom Metrics
import torchmetrics
class AdvancedClassifier(pl.LightningModule):
    def __init__(self, num_classes=10, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        # Model layers
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

        # Metrics
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        return self.model(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # Log metrics
        self.train_accuracy(y_hat, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # Update metrics
        self.val_accuracy(y_hat, y)
        self.val_f1(y_hat, y)

        # Log metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=1e-4
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "frequency": 1
            }
        }
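The "monitor" key mainly matters for schedulers that react to a metric, such as ReduceLROnPlateau. A hedged sketch of that variant:

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    # ReduceLROnPlateau lowers the learning rate when the monitored
    # metric stops improving, so "monitor" is required here.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
    }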
Advanced Features
1. Multiple Optimizers and Schedulers
def configure_optimizers(self):
    # Different learning rates for different parts
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)

    # Different schedulers
    sch_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=50, gamma=0.5)
    sch_d = torch.optim.lr_scheduler.StepLR(opt_d, step_size=50, gamma=0.5)

    return [opt_g, opt_d], [sch_g, sch_d]
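Note that recent Lightning releases (2.x) support multiple optimizers only with manual optimization (see the Manual Optimization section below). A hedged sketch of the matching training_step for a GAN-style model; generator_loss and discriminator_loss are placeholder helpers:

def training_step(self, batch, batch_idx):
    # Requires self.automatic_optimization = False in __init__.
    opt_g, opt_d = self.optimizers()

    # Generator update (generator_loss is a placeholder helper).
    g_loss = self.generator_loss(batch)
    opt_g.zero_grad()
    self.manual_backward(g_loss)
    opt_g.step()

    # Discriminator update (discriminator_loss is a placeholder helper).
    d_loss = self.discriminator_loss(batch)
    opt_d.zero_grad()
    self.manual_backward(d_loss)
    opt_d.step()

    self.log_dict({'g_loss': g_loss, 'd_loss': d_loss})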
2. Custom Callbacks
class CustomCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Custom logic at end of each epoch
        if trainer.current_epoch % 10 == 0:
            print(f"Completed {trainer.current_epoch} epochs")

    def on_validation_end(self, trainer, pl_module):
        # Custom validation logic
        val_loss = trainer.callback_metrics.get('val_loss')
        if val_loss and val_loss < 0.1:
            print("Excellent validation performance!")

# Usage
trainer = pl.Trainer(
    callbacks=[CustomCallback(), pl.callbacks.EarlyStopping(monitor='val_loss')]
)
3. Manual Optimization
class ManualOptimizationModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        # ... model definition

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        # Manual optimization
        opt.zero_grad()
        loss = self.compute_loss(batch)
        self.manual_backward(loss)
        opt.step()

        self.log('train_loss', loss)
        return loss
Best Practices
1. Hyperparameter Management
class ConfigurableModel(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        # Save all hyperparameters
        self.save_hyperparameters()
        # Access with self.hparams
        self.model = self._build_model()

    def _build_model(self):
        return nn.Sequential(
            nn.Linear(self.hparams.input_size, self.hparams.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hparams.hidden_size, self.hparams.num_classes)
        )
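A brief usage sketch: because save_hyperparameters() stores whatever keyword arguments were passed, the same values travel with the checkpoint and are restored on load (the path below is illustrative):

# Keyword arguments become entries in self.hparams.
model = ConfigurableModel(input_size=784, hidden_size=128, num_classes=10)

# Later, no arguments need to be repeated; the hyperparameters are read
# back from the checkpoint (example path).
restored = ConfigurableModel.load_from_checkpoint("checkpoints/last.ckpt")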
2. Proper Logging
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)

    # Log to both step and epoch
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)

    # Log learning rate
    self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])
    return loss
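Instead of logging the learning rate by hand, Lightning also ships a LearningRateMonitor callback; a minimal sketch:

# LearningRateMonitor logs the learning rate of every optimizer/scheduler
# automatically, making the manual self.log('lr', ...) call above optional.
trainer = pl.Trainer(
    callbacks=[pl.callbacks.LearningRateMonitor(logging_interval='step')]
)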
3. Model Organization
# Separate model definition from Lightning logic
class ResNetBackbone(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Model architecture here

class ResNetLightning(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        # Use separate model class
        self.model = ResNetBackbone(num_classes)

    def forward(self, x):
        return self.model(x)

    # Training logic here...
4. Testing and Validation
def validation_step(self, batch, batch_idx):
    # Always include validation metrics
    loss, acc = self._shared_eval_step(batch, batch_idx)
    self.log_dict({'val_loss': loss, 'val_acc': acc}, prog_bar=True)

def test_step(self, batch, batch_idx):
    # Comprehensive test metrics
    loss, acc = self._shared_eval_step(batch, batch_idx)
    self.log_dict({'test_loss': loss, 'test_acc': acc})

def _shared_eval_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    acc = (y_hat.argmax(dim=1) == y).float().mean()
    return loss, acc
Common Pitfalls
1. Device Management
Wrong:
def training_step(self, batch, batch_idx):
    x, y = batch
    x = x.cuda()  # Don't manually move to device
    y = y.cuda()
    # ...
Correct:
def training_step(self, batch, batch_idx):
    x, y = batch  # Lightning handles device placement
    # ...
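If you do need to create new tensors inside a step, a hedged sketch of keeping them on the right device without hard-coding cuda:

def training_step(self, batch, batch_idx):
    x, y = batch
    # self.device follows whatever accelerator the Trainer selected.
    noise = torch.randn(x.size(0), 10, device=self.device)
    # Alternatively, type_as moves and casts a tensor to match another one.
    mask = torch.ones(x.size(0)).type_as(x)
    # ...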
2. Gradient Accumulation
Wrong:
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    loss.backward()  # Don't call backward manually
    return loss
Correct:
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    return loss  # Lightning handles backward pass
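To actually accumulate gradients across batches, configure it on the Trainer rather than in the step; a minimal sketch:

# Gradients are accumulated over 4 batches before each optimizer step,
# which emulates a 4x larger effective batch size.
trainer = pl.Trainer(accumulate_grad_batches=4)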
3. Metric Computation
Wrong:
def validation_step(self, batch, batch_idx):
    # Computing metrics inside step leads to incorrect averages
    acc = compute_accuracy(batch)
    self.log('val_acc', acc.mean())
Correct:
def __init__(self):
    super().__init__()
    self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=10)

def validation_step(self, batch, batch_idx):
    # Let torchmetrics handle the averaging
    x, y = batch
    y_hat = self(x)
    self.val_acc(y_hat, y)
    self.log('val_acc', self.val_acc)
4. DataLoader in Model
Wrong:
class Model(pl.LightningModule):
    def train_dataloader(self):
        # Don't put data loading in the model
        return DataLoader(...)
Correct:
# Use a separate DataModule
class DataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(...)

# Or pass dataloaders to trainer.fit()
trainer.fit(model, train_dataloader, val_dataloader)
Migration Checklist
- Move the model architecture and forward pass into a LightningModule.
- Move loss computation into training_step and metrics into validation_step / test_step.
- Move optimizer (and scheduler) creation into configure_optimizers.
- Move dataset and DataLoader setup into a LightningDataModule, or pass dataloaders to trainer.fit().
- Delete the manual training loop, .to(device) / .cuda() calls, and manual backward()/step() calls.
- Replace print-based tracking with self.log, and add callbacks for checkpointing and early stopping.
By following this guide, you should be able to successfully migrate your PyTorch code to PyTorch Lightning while maintaining all functionality and gaining the benefits of Lightning’s structured approach.