PyTorch to PyTorch Lightning Migration Guide

code
tutorial
advanced
Author

Krishnatheja Vanka

Published

May 25, 2025


Table of Contents

  1. Introduction
  2. Key Concepts
  3. Basic Migration Steps
  4. Code Examples
  5. Advanced Features
  6. Best Practices
  7. Common Pitfalls

Introduction

PyTorch Lightning is a lightweight wrapper around PyTorch that eliminates boilerplate code while maintaining full control over your models. It provides a structured approach to organizing PyTorch code and includes built-in support for distributed training, logging, and experiment management.

Why Migrate?

  • Reduced Boilerplate: Lightning handles training loops, device management, and distributed training
  • Better Organization: Standardized code structure improves readability and maintenance
  • Built-in Features: Automatic logging, checkpointing, early stopping, and more
  • Scalability: Easy multi-GPU and multi-node training
  • Reproducibility: Better experiment tracking and configuration management

Key Concepts

LightningModule

The core abstraction that wraps your PyTorch model. It defines:

  • Model architecture (__init__)
  • Forward pass (forward)
  • Training step (training_step)
  • Validation step (validation_step)
  • Optimizer configuration (configure_optimizers)

Trainer

Runs the training, validation, and test loops and handles device placement, precision, checkpointing, and logging, all configured through its arguments and callbacks.

DataModule

Encapsulates data loading: downloading and preparing datasets, splitting them, and building the train, validation, and test dataloaders.

Basic Migration Steps

Step 1: Convert Model to LightningModule

Before (PyTorch):

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

After (Lightning):

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LightningModel(pl.LightningModule):
    def __init__(self, input_size, hidden_size, num_classes, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

Step 2: Replace Training Loop with Trainer

Before (PyTorch):

# Training loop (assumes num_epochs, train_loader, val_loader and val_dataset are already defined)
model = SimpleModel(input_size=784, hidden_size=128, num_classes=10)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')
    
    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    
    print(f'Validation Loss: {val_loss/len(val_loader)}, Accuracy: {correct/len(val_dataset)}')

After (Lightning):

# Training with Lightning
model = LightningModel(input_size=784, hidden_size=128, num_classes=10)
trainer = pl.Trainer(max_epochs=10, accelerator='auto', devices='auto')
trainer.fit(model, train_loader, val_loader)
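The three Lightning lines above replace the whole hand-written loop. To make that concrete, here is a toy, pure-Python sketch of the control flow that trainer.fit() takes over. ToyModule and ToyTrainer are hypothetical names, not Lightning classes; the real Trainer additionally runs a short validation sanity check before training and handles device transfer, precision, checkpointing and logging.

```python
from typing import Any, List, Sequence


class ToyModule:
    """Hypothetical stand-in for a LightningModule that records which hooks ran."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.training = True

    def train(self) -> None:
        self.training = True

    def eval(self) -> None:
        self.training = False

    def training_step(self, batch: Sequence[float], batch_idx: int) -> float:
        self.calls.append(f"train:{batch_idx}")
        return float(sum(batch))  # pretend this is the loss

    def validation_step(self, batch: Sequence[float], batch_idx: int) -> None:
        mode = "train" if self.training else "eval"
        self.calls.append(f"val:{batch_idx}:{mode}")


class ToyTrainer:
    """Heavily simplified sketch of the loop Trainer.fit() runs for you."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs
        self.current_epoch = 0

    def fit(self, module: ToyModule, train_loader: Any, val_loader: Any) -> None:
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            module.train()
            for batch_idx, batch in enumerate(train_loader):
                loss = module.training_step(batch, batch_idx)
                # The real Trainer would move the batch to the device first, then
                # run zero_grad(), loss.backward() and optimizer.step() here.
                module.calls.append(f"optimizer_step(loss={loss})")
            module.eval()
            # The real Trainer also disables gradient tracking for validation.
            for batch_idx, batch in enumerate(val_loader):
                module.validation_step(batch, batch_idx)


module = ToyModule()
ToyTrainer(max_epochs=1).fit(module, train_loader=[[1, 2], [3]], val_loader=[[4]])
print(module.calls)
```

Your LightningModule only supplies the step methods; the ordering, mode switching and optimizer calls live in the Trainer.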

Code Examples

Complete MNIST Example

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "data/", batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # Download data
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets
        if stage == 'fit':
            mnist_full = datasets.MNIST(
                self.data_dir, train=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000],
                generator=torch.Generator().manual_seed(42)  # reproducible split
            )

        if stage == 'test':
            self.mnist_test = datasets.MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

class MNISTClassifier(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = torch.sum(preds == y).float() / len(y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = torch.sum(preds == y).float() / len(y)
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        # StepLR multiplies the LR by gamma every 7 epochs; with max_epochs=5 below it never fires
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

# Usage
if __name__ == "__main__":
    # Initialize data module and model
    dm = MNISTDataModule()
    model = MNISTClassifier()
    
    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=5,
        accelerator='auto',
        devices='auto',
        logger=pl.loggers.TensorBoardLogger('lightning_logs/'),
        callbacks=[
            pl.callbacks.EarlyStopping(monitor='val_loss', patience=3),
            pl.callbacks.ModelCheckpoint(monitor='val_loss', save_top_k=1)
        ]
    )
    
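    # Train the model
    trainer.fit(model, datamodule=dm)
    
    # Test the best checkpoint saved by ModelCheckpoint rather than the last-epoch weights
    trainer.test(model, datamodule=dm, ckpt_path="best")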
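The 9216 in nn.Linear(9216, 128) is the flattened size of the feature map after conv1, conv2 and the 2x2 max-pool. A quick pure-Python check, using the standard output-size formula for unpadded, undilated layers (conv_out is a hypothetical helper name):

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of an undilated convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                   # MNIST images are 28x28
side = conv_out(side, kernel=3)             # conv1: 3x3, stride 1 -> 26
side = conv_out(side, kernel=3)             # conv2: 3x3, stride 1 -> 24
side = conv_out(side, kernel=2, stride=2)   # F.max_pool2d(x, 2): stride defaults to kernel -> 12
flat_features = 64 * side * side            # conv2 has 64 output channels
print(side, flat_features)  # 12 9216
```

If you change the input resolution or add layers, recompute this number (for example with a dummy forward pass); otherwise fc1 fails with a shape mismatch.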

Advanced Model with Custom Metrics

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics

class AdvancedClassifier(pl.LightningModule):
    def __init__(self, num_classes=10, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        # Model layers
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )
        
        # Metrics
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes)
        
    def forward(self, x):
        return self.model(x.view(x.size(0), -1))
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        
        # Log metrics
        self.train_accuracy(y_hat, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        
        # Update metrics
        self.val_accuracy(y_hat, y)
        self.val_f1(y_hat, y)
        
        # Log metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)
        
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), 
            lr=self.hparams.learning_rate,
            weight_decay=1e-4
        )
        
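        # T_max is counted in scheduler steps; with interval="epoch" it should
        # usually match the Trainer's max_epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",  # step once per epoch (the default)
                "frequency": 1,
                # "monitor" is only needed for schedulers such as ReduceLROnPlateau
            }
        }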
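With the scheduler stepped once per epoch, T_max=100 means the learning rate only reaches its minimum after 100 epochs. The closed-form cosine annealing schedule described in the PyTorch CosineAnnealingLR documentation (without restarts) is easy to evaluate in pure Python; cosine_annealing_lr is a hypothetical helper, not a PyTorch function:

```python
import math


def cosine_annealing_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


base_lr, t_max = 1e-3, 100
for epoch in (0, 5, 25, 50, 100):
    print(epoch, round(cosine_annealing_lr(epoch, base_lr, t_max), 8))
```

A short run barely moves the learning rate under this setting (epoch 5 is still above 99% of base_lr), which is why T_max is usually set to the planned number of epochs.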

Advanced Features

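1. Multiple Optimizers and Schedulers

Since Lightning 2.0, returning more than one optimizer is only supported with manual optimization: set self.automatic_optimization = False in __init__, then step each optimizer, and each scheduler via self.lr_schedulers(), yourself in training_step (see Manual Optimization below).

def configure_optimizers(self):
    # Assumes the module defines self.generator and self.discriminator (e.g. a GAN)
    # One optimizer per sub-network (same learning rate here, but they can differ)
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
    
    # One scheduler per optimizer
    sch_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=50, gamma=0.5)
    sch_d = torch.optim.lr_scheduler.StepLR(opt_d, step_size=50, gamma=0.5)
    
    return [opt_g, opt_d], [sch_g, sch_d]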
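Both schedulers use StepLR, which multiplies the learning rate by gamma once every step_size scheduler steps. Under manual optimization Lightning does not step schedulers for you, so the count is whatever number of times you call sch.step() (commonly once per epoch). The arithmetic as a small pure-Python sketch (step_lr is a hypothetical helper, not a PyTorch function):

```python
def step_lr(num_steps: int, base_lr: float, step_size: int, gamma: float) -> float:
    """Learning rate after num_steps scheduler steps under StepLR-style decay."""
    return base_lr * gamma ** (num_steps // step_size)


for epoch in (0, 49, 50, 100, 149):
    print(epoch, step_lr(epoch, base_lr=0.0002, step_size=50, gamma=0.5))
```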

2. Custom Callbacks

class CustomCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Custom logic at the end of each training epoch.
        # current_epoch is 0-indexed here, so add 1 to count finished epochs.
        epochs_done = trainer.current_epoch + 1
        if epochs_done % 10 == 0:
            print(f"Completed {epochs_done} epochs")
    
    def on_validation_end(self, trainer, pl_module):
        # Custom validation logic; callback_metrics holds the latest logged values
        val_loss = trainer.callback_metrics.get('val_loss')
        if val_loss is not None and val_loss < 0.1:
            print("Excellent validation performance!")

# Usage
trainer = pl.Trainer(
    callbacks=[CustomCallback(), pl.callbacks.EarlyStopping(monitor='val_loss')]
)
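Conceptually, the Trainer calls each hook on every callback in its list, in turn, at fixed points of the loop. A toy, pure-Python model of that dispatch (ToyCallback, EpochRecorder and ToyTrainer are hypothetical, not Lightning's implementation) also shows that current_epoch is 0-indexed while the epoch-end hooks run:

```python
from typing import List


class ToyCallback:
    """Hypothetical base class: hooks do nothing unless overridden."""

    def on_train_epoch_end(self, trainer: "ToyTrainer", pl_module: object) -> None:
        pass


class EpochRecorder(ToyCallback):
    def __init__(self) -> None:
        self.seen: List[int] = []

    def on_train_epoch_end(self, trainer: "ToyTrainer", pl_module: object) -> None:
        self.seen.append(trainer.current_epoch)


class ToyTrainer:
    def __init__(self, callbacks: List[ToyCallback], max_epochs: int) -> None:
        self.callbacks = callbacks
        self.max_epochs = max_epochs
        self.current_epoch = 0

    def _call_hook(self, name: str) -> None:
        # Every callback receives every hook; the base class supplies no-op defaults
        for callback in self.callbacks:
            getattr(callback, name)(self, None)

    def fit(self) -> None:
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch  # 0-indexed while this epoch's hooks run
            # ... training batches would run here ...
            self._call_hook("on_train_epoch_end")


recorder = EpochRecorder()
ToyTrainer(callbacks=[recorder, ToyCallback()], max_epochs=3).fit()
print(recorder.seen)  # [0, 1, 2]
```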

3. Manual Optimization

class ManualOptimizationModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
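        # ... model definition; configure_optimizers() must still be defined
        
    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        
        # Manual optimization: you own zero_grad, backward and step
        opt.zero_grad()
        loss = self.compute_loss(batch)  # compute_loss is a placeholder for your own loss code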
        self.manual_backward(loss)
        opt.step()
        
        self.log('train_loss', loss)
        return loss

Best Practices

1. Hyperparameter Management

class ConfigurableModel(pl.LightningModule):
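    def __init__(self, **kwargs):
        # e.g. ConfigurableModel(input_size=784, hidden_size=128, num_classes=10)
        super().__init__()
        # Store every keyword argument in self.hparams (also saved in checkpoints)
        self.save_hyperparameters()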
        
        # Access with self.hparams
        self.model = self._build_model()
        
    def _build_model(self):
        return nn.Sequential(
            nn.Linear(self.hparams.input_size, self.hparams.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hparams.hidden_size, self.hparams.num_classes)
        )

2. Proper Logging

def training_step(self, batch, batch_idx):
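    loss = self.compute_loss(batch)  # compute_loss is a placeholder for your own loss code
    
    # Log per step and as an epoch average; also shown in the progress bar
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
    
    # Log the current learning rate (pl.callbacks.LearningRateMonitor can do this for you)
    self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])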
    
    return loss
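With on_step=True and on_epoch=True, Lightning logs both a per-step value and an epoch-level aggregate (suffixed _step and _epoch in the logger). As we understand the default mean reduction, the epoch value is weighted by batch size, which matters when the last batch is smaller. The arithmetic in pure Python, with made-up numbers (epoch_mean is a hypothetical helper):

```python
from typing import Sequence


def epoch_mean(step_values: Sequence[float], batch_sizes: Sequence[int]) -> float:
    """Batch-size-weighted mean of per-step values."""
    total = sum(v * n for v, n in zip(step_values, batch_sizes))
    return total / sum(batch_sizes)


step_losses = [0.9, 0.7, 0.1]  # made-up per-step mean losses
batch_sizes = [64, 64, 8]      # the last batch is smaller

naive = sum(step_losses) / len(step_losses)
weighted = epoch_mean(step_losses, batch_sizes)
print(f"naive={naive:.4f} weighted={weighted:.4f}")
```

The small final batch pulls the naive average down; the weighted value is the true per-sample mean over the epoch.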

3. Model Organization

# Separate model definition from Lightning logic
class ResNetBackbone(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Model architecture here (plain PyTorch, no Lightning dependency)

    def forward(self, x):
        # Forward pass here
        ...

class ResNetLightning(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        # Use separate model class
        self.model = ResNetBackbone(num_classes)
        
    def forward(self, x):
        return self.model(x)
    
    # Training logic here...

4. Testing and Validation

def validation_step(self, batch, batch_idx):
    # Always include validation metrics
    loss, acc = self._shared_eval_step(batch, batch_idx)
    self.log_dict({'val_loss': loss, 'val_acc': acc}, prog_bar=True)
    
def test_step(self, batch, batch_idx):
    # Comprehensive test metrics
    loss, acc = self._shared_eval_step(batch, batch_idx)
    self.log_dict({'test_loss': loss, 'test_acc': acc})
    
def _shared_eval_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    acc = (y_hat.argmax(dim=1) == y).float().mean()
    return loss, acc

Common Pitfalls

1. Device Management

Wrong:

def training_step(self, batch, batch_idx):
    x, y = batch
    x = x.cuda()  # Don't manually move to device
    y = y.cuda()
    # ...

Correct:

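def training_step(self, batch, batch_idx):
    x, y = batch  # Lightning has already moved the batch to the right device
    # New tensors created here should use self.device, e.g. torch.zeros(3, device=self.device)
    # ...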

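2. Backward Pass and Gradient Accumulation

Wrong:

def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    loss.backward()  # Don't call backward yourself: with automatic optimization Lightning already does
    return loss

Correct:

def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    # Lightning runs backward and the optimizer step; for gradient accumulation,
    # set pl.Trainer(accumulate_grad_batches=4) instead of coding it here
    return loss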
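Gradient accumulation with pl.Trainer(accumulate_grad_batches=N) runs N batches before each optimizer step. For a mean loss and equally sized micro-batches, summing the gradients of each micro-batch loss divided by N reproduces the full-batch gradient (Lightning applies this 1/N loss scaling for you in automatic optimization). A pure-Python check on a one-parameter linear model with made-up data (grad_mean_sq_error is a hypothetical helper):

```python
from typing import List


def grad_mean_sq_error(w: float, xs: List[float], ys: List[float]) -> float:
    """d/dw of mean((w*x - y)**2) for a one-parameter linear model."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
w = 0.5

# One optimizer step computed on the full batch of 4 samples
full_batch_grad = grad_mean_sq_error(w, xs, ys)

# The same step from 2 micro-batches of 2, each loss scaled by 1/N before "backward"
n_accumulate = 2
accumulated_grad = 0.0
for i in range(n_accumulate):
    chunk = slice(2 * i, 2 * i + 2)
    accumulated_grad += grad_mean_sq_error(w, xs[chunk], ys[chunk]) / n_accumulate

print(full_batch_grad, accumulated_grad)
```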

3. Metric Computation

Wrong:

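def validation_step(self, batch, batch_idx):
    # Logging per-batch values means they are only averaged at epoch end: wrong for
    # metrics like F1 or precision, and without sync_dist=True only the current device counts
    acc = compute_accuracy(batch)  # hypothetical helper
    self.log('val_acc', acc.mean())

Correct:

def __init__(self, num_classes=10):
    super().__init__()
    self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    # The metric accumulates state across batches (and syncs it across devices);
    # logging the metric object lets Lightning compute and reset it each epoch
    self.val_acc(y_hat, y)
    self.log('val_acc', self.val_acc)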
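For accuracy, a batch-size-weighted average of per-batch values happens to equal the epoch value, but for F1 it does not, even with equal batch sizes. Stateful metrics avoid this by accumulating counts in update() and computing once per epoch. A pure-Python toy of that pattern (ToyBinaryF1 is hypothetical; torchmetrics' real classes also synchronize their state across devices):

```python
from typing import Sequence


class ToyBinaryF1:
    """Hypothetical metric mimicking the update/compute/reset pattern of torchmetrics."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.tp = self.fp = self.fn = 0

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        # Accumulate counts only; no averaging happens here
        for p, t in zip(preds, target):
            self.tp += int(p == 1 and t == 1)
            self.fp += int(p == 1 and t == 0)
            self.fn += int(p == 0 and t == 1)

    def compute(self) -> float:
        denom = 2 * self.tp + self.fp + self.fn
        return 2 * self.tp / denom if denom else 0.0


batches = [
    ([1, 1, 0, 0], [1, 0, 0, 0]),  # (preds, target), made-up data
    ([0, 0, 0, 1], [1, 1, 1, 1]),
]

# Wrong: compute F1 per batch, then average the per-batch values
per_batch = []
for preds, target in batches:
    m = ToyBinaryF1()
    m.update(preds, target)
    per_batch.append(m.compute())
mean_of_batches = sum(per_batch) / len(per_batch)

# Correct: accumulate counts over the epoch, compute once at the end
epoch_metric = ToyBinaryF1()
for preds, target in batches:
    epoch_metric.update(preds, target)
epoch_f1 = epoch_metric.compute()

print(round(mean_of_batches, 4), epoch_f1)  # 0.5333 0.5
```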

4. DataLoader in Model

Wrong:

class Model(pl.LightningModule):
    def train_dataloader(self):
        # Works, but ties data loading to the model; a DataModule is easier to reuse and test
        return DataLoader(...)

Correct:

# Use separate DataModule
class DataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(...)

# Or pass dataloaders to trainer.fit()
trainer.fit(model, train_dataloader, val_dataloader)

Migration Checklist

By following this guide, you should be able to successfully migrate your PyTorch code to PyTorch Lightning while maintaining all functionality and gaining the benefits of Lightning’s structured approach.