PyTorch to PyTorch Lightning Migration Guide

Introduction
PyTorch Lightning is a lightweight wrapper around PyTorch that eliminates boilerplate code while maintaining full control over your models. It provides a structured approach to organizing PyTorch code and includes built-in support for distributed training, logging, and experiment management.
Why Migrate?
- Reduced Boilerplate: Lightning handles training loops, device management, and distributed training
- Better Organization: Standardized code structure improves readability and maintenance
- Built-in Features: Automatic logging, checkpointing, early stopping, and more
- Scalability: Easy multi-GPU and multi-node training
- Reproducibility: Better experiment tracking and configuration management
Key Concepts
LightningModule
The core abstraction that wraps your PyTorch model. It defines:
- Model architecture (__init__)
- Forward pass (forward)
- Training step (training_step)
- Validation step (validation_step)
- Optimizer configuration (configure_optimizers)
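A minimal skeleton (the layer sizes and names here are placeholders, not taken from the examples below) shows where each of these hooks lives:

```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class SkeletonModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)        # model architecture

    def forward(self, x):                          # forward pass
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):     # one training batch -> loss
        x, y = batch
        return F.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):   # one validation batch
        x, y = batch
        self.log("val_loss", F.cross_entropy(self(x), y))

    def configure_optimizers(self):                # optimizer(s) and scheduler(s)
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```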
Trainer
Handles the training loop, device management, and various training configurations.
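A typical usage sketch (the model and dataloaders are assumed to exist already; the epoch count is arbitrary):

```python
import pytorch_lightning as pl

# model: any LightningModule; train_loader / val_loader: ordinary PyTorch DataLoaders
trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices="auto")
trainer.fit(model, train_loader, val_loader)

# The same Trainer object can run standalone evaluation
trainer.validate(model, val_loader)
```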
DataModule
Encapsulates data loading logic, including datasets and dataloaders.
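A minimal sketch, using random tensors as a stand-in dataset (a full MNIST DataModule appears in the complete example below):

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class RandomDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        # Replace with real dataset construction
        self.train_set = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
        self.val_set = TensorDataset(torch.randn(20, 8), torch.randint(0, 2, (20,)))

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=16, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=16)
```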
Basic Migration Steps
Step 1: Convert Model to LightningModule
Before (PyTorch):
```python
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
```

After (Lightning):
```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LightningModel(pl.LightningModule):
    def __init__(self, input_size, hidden_size, num_classes, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
```

Step 2: Replace Training Loop with Trainer
Before (PyTorch):
```python
# Manual training loop (assumes train_loader, val_loader, and val_dataset already exist)
model = SimpleModel(input_size=784, hidden_size=128, num_classes=10)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')

    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    print(f'Validation Loss: {val_loss/len(val_loader)}, Accuracy: {correct/len(val_dataset)}')
```

After (Lightning):
```python
# Training with Lightning
model = LightningModel(input_size=784, hidden_size=128, num_classes=10)

trainer = pl.Trainer(max_epochs=10, accelerator='auto', devices='auto')
trainer.fit(model, train_loader, val_loader)
```

Code Examples
Complete MNIST Example
```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "data/", batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # Download data
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets
        if stage == 'fit':
            mnist_full = datasets.MNIST(
                self.data_dir, train=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000]
            )
        if stage == 'test':
            self.mnist_test = datasets.MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

class MNISTClassifier(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = torch.sum(preds == y).float() / len(y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = torch.sum(preds == y).float() / len(y)
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

# Usage
if __name__ == "__main__":
    # Initialize data module and model
    dm = MNISTDataModule()
    model = MNISTClassifier()

    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=5,
        accelerator='auto',
        devices='auto',
        logger=pl.loggers.TensorBoardLogger('lightning_logs/'),
        callbacks=[
            pl.callbacks.EarlyStopping(monitor='val_loss', patience=3),
            pl.callbacks.ModelCheckpoint(monitor='val_loss', save_top_k=1)
        ]
    )

    # Train the model
    trainer.fit(model, dm)

    # Test the model
    trainer.test(model, dm)
```
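Once training finishes, the checkpoint tracked by ModelCheckpoint can be reloaded; a short sketch, assuming the Trainer above saved at least one checkpoint:

```python
# Path of the best checkpoint kept by the ModelCheckpoint callback
best_path = trainer.checkpoint_callback.best_model_path

# Hyperparameters stored via save_hyperparameters() are restored automatically
best_model = MNISTClassifier.load_from_checkpoint(best_path)
best_model.eval()
```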
Advanced Model with Custom Metrics

```python
import torchmetrics

class AdvancedClassifier(pl.LightningModule):
    def __init__(self, num_classes=10, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        # Model layers
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )
        # Metrics
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        return self.model(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # Log metrics
        self.train_accuracy(y_hat, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # Update metrics
        self.val_accuracy(y_hat, y)
        self.val_f1(y_hat, y)
        # Log metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=1e-4
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "frequency": 1
            }
        }
```
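To watch the cosine schedule above in the logs, a LearningRateMonitor callback can be attached to the Trainer; a brief sketch that reuses the MNISTDataModule defined earlier (its 28x28 images are flattened in forward; the epoch count is arbitrary):

```python
dm = MNISTDataModule()
model = AdvancedClassifier(num_classes=10)

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    callbacks=[pl.callbacks.LearningRateMonitor(logging_interval="epoch")],
)
trainer.fit(model, datamodule=dm)
```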
Advanced Features

1. Multiple Optimizers and Schedulers

```python
def configure_optimizers(self):
    # Different learning rates for different parts
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
    # Different schedulers
    sch_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=50, gamma=0.5)
    sch_d = torch.optim.lr_scheduler.StepLR(opt_d, step_size=50, gamma=0.5)
    return [opt_g, opt_d], [sch_g, sch_d]
```

Note that Lightning 2.x only supports multiple optimizers together with manual optimization: set self.automatic_optimization = False and step each optimizer yourself, as in the Manual Optimization section below.

2. Custom Callbacks
```python
class CustomCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Custom logic at end of each epoch
        if trainer.current_epoch % 10 == 0:
            print(f"Completed {trainer.current_epoch} epochs")

    def on_validation_end(self, trainer, pl_module):
        # Custom validation logic
        val_loss = trainer.callback_metrics.get('val_loss')
        if val_loss is not None and val_loss < 0.1:
            print("Excellent validation performance!")

# Usage
trainer = pl.Trainer(
    callbacks=[CustomCallback(), pl.callbacks.EarlyStopping(monitor='val_loss')]
)
```

3. Manual Optimization
```python
class ManualOptimizationModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        # ... model definition

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        # Manual optimization
        opt.zero_grad()
        loss = self.compute_loss(batch)
        self.manual_backward(loss)
        opt.step()
        self.log('train_loss', loss)
        return loss
```

Best Practices
1. Hyperparameter Management
```python
class ConfigurableModel(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        # Save all hyperparameters
        self.save_hyperparameters()
        # Access with self.hparams
        self.model = self._build_model()

    def _build_model(self):
        return nn.Sequential(
            nn.Linear(self.hparams.input_size, self.hparams.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hparams.hidden_size, self.hparams.num_classes)
        )
```
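Because save_hyperparameters() records whatever was passed to __init__, the values travel with every checkpoint; a short usage sketch (the argument names and checkpoint path are illustrative):

```python
model = ConfigurableModel(input_size=784, hidden_size=256, num_classes=10)
print(model.hparams.hidden_size)  # 256, also recorded by supported loggers

# Later, a checkpoint restores these hyperparameters without re-passing them:
# restored = ConfigurableModel.load_from_checkpoint("path/to/checkpoint.ckpt")
```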
2. Proper Logging

```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    # Log to both step and epoch
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
    # Log learning rate
    self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])
    return loss
```

3. Model Organization
```python
# Separate model definition from Lightning logic
class ResNetBackbone(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Model architecture here

class ResNetLightning(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        # Use separate model class
        self.model = ResNetBackbone(num_classes)

    def forward(self, x):
        return self.model(x)

    # Training logic here...
```

4. Testing and Validation
```python
def validation_step(self, batch, batch_idx):
    # Always include validation metrics
    loss, acc = self._shared_eval_step(batch, batch_idx)
    self.log_dict({'val_loss': loss, 'val_acc': acc}, prog_bar=True)

def test_step(self, batch, batch_idx):
    # Comprehensive test metrics
    loss, acc = self._shared_eval_step(batch, batch_idx)
    self.log_dict({'test_loss': loss, 'test_acc': acc})

def _shared_eval_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    acc = (y_hat.argmax(dim=1) == y).float().mean()
    return loss, acc
```

Common Pitfalls
1. Device Management
Wrong:
```python
def training_step(self, batch, batch_idx):
    x, y = batch
    x = x.cuda()  # Don't manually move to device
    y = y.cuda()
    # ...
```

Correct:
```python
def training_step(self, batch, batch_idx):
    x, y = batch  # Lightning handles device placement
    # ...
```
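When a step has to create new tensors (masks, noise, constants), place them with self.device or type_as instead of hard-coding .cuda(); a hedged sketch:

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    noise = torch.randn_like(x)                       # inherits x's device and dtype
    mask = torch.ones(x.size(0), device=self.device)  # or: torch.ones(x.size(0)).type_as(x)
    # ... use noise / mask in the loss computation
```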
2. Gradient Accumulation
Wrong:
```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    loss.backward()  # Don't call backward manually
    return loss
```

Correct:
```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    return loss  # Lightning handles the backward pass
```
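Actual gradient accumulation is configured on the Trainer, not in the step; a brief sketch (the factor of 4 is arbitrary):

```python
# Accumulate gradients over 4 batches before each optimizer step
trainer = pl.Trainer(max_epochs=10, accumulate_grad_batches=4)
trainer.fit(model, train_loader, val_loader)
```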
3. Metric Computation
Wrong:
```python
def validation_step(self, batch, batch_idx):
    # Computing metrics inside the step leads to incorrect averages
    acc = compute_accuracy(batch)
    self.log('val_acc', acc.mean())
```

Correct:
```python
def __init__(self):
    super().__init__()
    self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=10)

def validation_step(self, batch, batch_idx):
    # Let torchmetrics handle the averaging
    x, y = batch
    y_hat = self(x)
    self.val_acc(y_hat, y)
    self.log('val_acc', self.val_acc)
```

4. DataLoader in Model
Wrong:
```python
class Model(pl.LightningModule):
    def train_dataloader(self):
        # Don't put data loading in model
        return DataLoader(...)
```

Correct:
```python
# Use separate DataModule
class DataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(...)

# Or pass dataloaders to trainer.fit()
trainer.fit(model, train_dataloader, val_dataloader)
```

Migration Checklist
- Convert each nn.Module into a LightningModule: keep __init__ and forward, add training_step, validation_step, and configure_optimizers.
- Call self.save_hyperparameters() in __init__ so hyperparameters are logged and restored from checkpoints.
- Move loss computation and metric logging into the step methods using self.log / self.log_dict.
- Delete the manual training loop, .to(device) calls, loss.backward(), and optimizer.step(); the Trainer handles them.
- Move dataset and DataLoader code into a LightningDataModule, or pass dataloaders directly to trainer.fit().
- Configure the Trainer with the accelerator, devices, logger, and callbacks (checkpointing, early stopping) you need.
- Replace hand-rolled metric averaging with torchmetrics objects.

By following this guide, you should be able to successfully migrate your PyTorch code to PyTorch Lightning while maintaining all functionality and gaining the benefits of Lightning’s structured approach.