MLflow for PyTorch - Complete Guide
MLflow is an open-source platform for managing the machine learning lifecycle, including experimentation, reproducibility, deployment, and model registry. This guide covers how to integrate MLflow with PyTorch for comprehensive ML workflow management.
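Table of Contents
- Installation and Setup
- Basic MLflow Concepts
- Experiment Tracking
- Model Logging
- Model Registry
- Model Deployment
- Advanced Features
- Best Practices
- Summary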
Installation and Setup
Install MLflow
pip install mlflow
pip install torch torchvision
Start MLflow UI
mlflow ui
By default, this starts the MLflow UI at http://localhost:5000.
Basic Configuration
import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# Set tracking URI (optional - defaults to a local store)
mlflow.set_tracking_uri("http://localhost:5000")
# Set experiment name (the experiment is created if it does not exist yet)
mlflow.set_experiment("pytorch_experiments")
Basic MLflow Concepts
- Experiment: A collection of runs for a particular task
- Run: A single execution of your ML code
- Artifact: Files generated during a run (models, plots, data)
- Metric: Numerical values tracked over time
- Parameter: Input configurations for your run
Experiment Tracking
Basic Run Structure
import mlflow
with mlflow.start_run():
    # Your training code here
    mlflow.log_param("learning_rate", 0.001)
    mlflow.log_metric("accuracy", 0.95)
    # log_artifact uploads an existing local file, so "model.pth" must
    # already have been written (e.g. with torch.save) before this call.
    mlflow.log_artifact("model.pth")
Complete Training Example
import torch
import torch.nn as nn
import torch.optim as optim
import mlflow
import mlflow.pytorch
from sklearn.metrics import accuracy_score
import numpy as np
class SimpleNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits for ``x``."""
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
def train_model():
    """Train SimpleNet on random dummy data and track the run with MLflow.

    Logs the hyperparameters, per-epoch loss and accuracy, the model itself
    (MLflow PyTorch flavor) and its state dict as a separate artifact.

    Returns:
        The trained ``SimpleNet`` instance.
    """
    import os
    import tempfile

    # Hyperparameters
    input_size = 784
    hidden_size = 128
    num_classes = 10
    learning_rate = 0.001
    batch_size = 64
    num_epochs = 10
    num_batches = 100  # dummy batches per epoch; also the loss divisor below

    # Start MLflow run
    with mlflow.start_run():
        # Log hyperparameters
        mlflow.log_params({
            "input_size": input_size,
            "hidden_size": hidden_size,
            "num_classes": num_classes,
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "num_batches": num_batches,
        })

        # Initialize model
        model = SimpleNet(input_size, hidden_size, num_classes)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Training loop
        model.train()
        for epoch in range(num_epochs):
            running_loss = 0.0
            correct = 0
            total = 0

            for _ in range(num_batches):
                # Random inputs/labels stand in for a real DataLoader
                inputs = torch.randn(batch_size, input_size)
                labels = torch.randint(0, num_classes, (batch_size,))

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            # Calculate metrics
            epoch_loss = running_loss / num_batches
            epoch_acc = 100 * correct / total

            # Log metrics
            mlflow.log_metric("loss", epoch_loss, step=epoch)
            mlflow.log_metric("accuracy", epoch_acc, step=epoch)

            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

        # Log model
        mlflow.pytorch.log_model(model, "model")

        # Log additional artifacts. Writing into a temporary directory means
        # no stray .pth file is left in the working directory; the artifact
        # is still stored under the name "model_state_dict.pth".
        with tempfile.TemporaryDirectory() as tmp_dir:
            state_path = os.path.join(tmp_dir, "model_state_dict.pth")
            torch.save(model.state_dict(), state_path)
            mlflow.log_artifact(state_path)

    return model


# Run training
if __name__ == "__main__":
    train_model()
Model Logging
Different Ways to Log PyTorch Models
1. Log Complete Model
# Log the entire model (architecture + weights) with the MLflow PyTorch flavor
mlflow.pytorch.log_model(model, "complete_model")
2. Log Model State Dict
# Save and log state dict (weights only; reloading it requires the model
# class definition to rebuild the architecture first)
torch.save(model.state_dict(), "model_state_dict.pth")
mlflow.log_artifact("model_state_dict.pth")
3. Log with Custom Code
# Log model with custom code for loading
mlflow.pytorch.log_model(
    model,
    "model",
    code_paths=["model_definition.py"],  # Include custom model definition
)
4. Log with Conda Environment
import mlflow.pytorch

# Create conda environment specification
conda_env = {
    'name': 'pytorch_env',
    'channels': ['defaults', 'pytorch'],
    'dependencies': [
        'python=3.10',  # match the Python version used for training
        'pytorch',
        'torchvision',
        'pip',  # conda needs pip itself listed to install the pip section
        {'pip': ['mlflow']},
    ],
}

mlflow.pytorch.log_model(
    model,
    "model",
    conda_env=conda_env,
)
Model Registry
Register Model
# Register model during logging
mlflow.pytorch.log_model(
    model,
    "model",
    registered_model_name="MyPyTorchModel",
)

# Or register a model already logged in an existing run
# (replace your_run_id with the actual run ID)
model_uri = "runs:/your_run_id/model"
mlflow.register_model(model_uri, "MyPyTorchModel")
Model Versioning and Stages
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Transition model to different stages.
# NOTE: model version stages are deprecated in recent MLflow 2.x releases in
# favor of aliases, e.g.:
#   client.set_registered_model_alias("MyPyTorchModel", "production", 1)
#   mlflow.pytorch.load_model("models:/MyPyTorchModel@production")
client.transition_model_version_stage(
    name="MyPyTorchModel",
    version=1,
    stage="Production",
)

# Get model by stage (returns a list of versions; take the first)
model_version = client.get_latest_versions(
    "MyPyTorchModel",
    stages=["Production"],
)[0]
Load Registered Model
# Load model from registry (the version currently in the "Production" stage)
model = mlflow.pytorch.load_model(
    model_uri="models:/MyPyTorchModel/Production"
)

# Or load a specific version
model = mlflow.pytorch.load_model(
    model_uri="models:/MyPyTorchModel/1"
)
Model Deployment
Local Serving
# Serve model locally
# Run in terminal:
# mlflow models serve -m models:/MyPyTorchModel/Production -p 1234
Prediction with Served Model
import requests
import json

# Prepare data: "inputs" is a batch of rows; each row must have as many
# features as the served model expects (placeholder values shown here)
data = {
    "inputs": [[1.0, 2.0, 3.0, 4.0]]  # Your input features
}

# Make prediction request
response = requests.post(
    "http://localhost:1234/invocations",
    headers={"Content-Type": "application/json"},
    data=json.dumps(data),
)
# Fail loudly on HTTP errors instead of parsing an error body as predictions
response.raise_for_status()

predictions = response.json()
print(predictions)
Docker Deployment
# Build Docker image
mlflow models build-docker -m models:/MyPyTorchModel/Production -n my-pytorch-model
# Run Docker container
docker run -p 8080:8080 my-pytorch-model
Advanced Features
Custom MLflow Model
import mlflow.pyfunc
class PyTorchModelWrapper(mlflow.pyfunc.PythonModel):
    """pyfunc wrapper that runs a PyTorch model on tabular input.

    Args:
        model: The ``torch.nn.Module`` to wrap.
    """

    def __init__(self, model):
        self.model = model

    def predict(self, context, model_input, params=None):
        """Run inference on ``model_input`` and return a NumPy array.

        ``model_input`` must expose ``.values`` (e.g. a pandas DataFrame);
        those values are converted to a float32 tensor. ``params`` is
        accepted for compatibility with newer MLflow ``PythonModel.predict``
        signatures and is not used.
        """
        # Custom prediction logic. eval() disables dropout and batch-norm
        # statistic updates so inference is deterministic.
        self.model.eval()
        with torch.no_grad():
            tensor_input = torch.as_tensor(model_input.values, dtype=torch.float32)
            predictions = self.model(tensor_input)
        return predictions.numpy()
# Log custom model
wrapped_model = PyTorchModelWrapper(model)

mlflow.pyfunc.log_model(
    "custom_model",
    python_model=wrapped_model,
)
Automatic Logging
# Enable automatic logging.
# NOTE: mlflow.pytorch.autolog() primarily hooks into PyTorch Lightning's
# Trainer.fit(); a hand-written training loop generally still needs explicit
# mlflow.log_* calls (see the MLflow autolog documentation for your version).
mlflow.pytorch.autolog()

# Your training code - with Lightning, metrics and models are logged automatically
with mlflow.start_run():
    # Training happens here (e.g. trainer.fit(model, datamodule))
    pass
Logging Hyperparameter Sweeps
import itertools

# Define hyperparameter grid
param_grid = {
    'learning_rate': [0.001, 0.01, 0.1],
    'hidden_size': [64, 128, 256],
    'batch_size': [32, 64, 128]
}

# Run one MLflow run per combination in the grid
keys = list(param_grid)
for values in itertools.product(*param_grid.values()):
    params = dict(zip(keys, values))
    with mlflow.start_run():
        # Log parameters
        mlflow.log_params(params)

        # Train model with these parameters. train_with_params is your own
        # training function; here it is assumed to return (model, accuracy).
        model, accuracy = train_with_params(params)

        # Log results
        mlflow.log_metric("final_accuracy", accuracy)
        mlflow.pytorch.log_model(model, "model")
Logging Artifacts and Plots
import matplotlib.pyplot as plt
import seaborn as sns
# Create and log plots
def log_training_plots(train_losses, val_losses):
    """Plot training and validation loss curves and log the figure to MLflow.

    The figure is stored in the active run as ``loss_plot.png``.

    Args:
        train_losses: Sequence of training losses, one per epoch.
        val_losses: Sequence of validation losses, one per epoch.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(train_losses, label='Training Loss')
        ax.plot(val_losses, label='Validation Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.set_title('Training and Validation Loss')
        # log_figure uploads the figure directly, so no PNG is left in the CWD
        mlflow.log_figure(fig, 'loss_plot.png')
    finally:
        # Always release the figure, even if logging fails
        plt.close(fig)
# Log confusion matrix
def log_confusion_matrix(y_true, y_pred, class_names):
    """Plot a confusion matrix heatmap and log it to MLflow.

    The figure is stored in the active run as ``confusion_matrix.png``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Tick labels for both axes, in confusion-matrix order.
    """
    from sklearn.metrics import confusion_matrix

    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title('Confusion Matrix')
        # log_figure uploads the figure directly, so no PNG is left in the CWD
        mlflow.log_figure(fig, 'confusion_matrix.png')
    finally:
        # Always release the figure, even if plotting or logging fails
        plt.close(fig)
Best Practices
1. Organize Experiments
# Use descriptive experiment names
mlflow.set_experiment("image_classification_resnet")

# Use run names for specific configurations
with mlflow.start_run(run_name="resnet50_adam_lr001"):
    pass
2. Comprehensive Logging
def comprehensive_logging(model, optimizer, criterion, config):
    """Log hyperparameters, model/optimizer/loss info and system info to MLflow.

    Args:
        model: The ``torch.nn.Module`` being trained.
        optimizer: The optimizer instance (only its class name is logged).
        criterion: The loss object (only its class name is logged).
        config: Dict of hyperparameters passed to ``mlflow.log_params``.
    """
    # Log hyperparameters
    mlflow.log_params(config)

    # Log model architecture info
    total_params = sum(p.numel() for p in model.parameters())
    mlflow.log_param("total_parameters", total_params)
    # The printed architecture of a non-trivial model easily exceeds MLflow's
    # parameter-value length limit, so store it as a text artifact instead.
    mlflow.log_text(str(model), "model_architecture.txt")

    # Log optimizer info
    mlflow.log_param("optimizer", type(optimizer).__name__)
    mlflow.log_param("criterion", type(criterion).__name__)

    # Log system info
    cuda_available = torch.cuda.is_available()
    mlflow.log_param("cuda_available", cuda_available)
    if cuda_available:
        mlflow.log_param("gpu_name", torch.cuda.get_device_name(0))
3. Error Handling
def safe_mlflow_run(training_function, **kwargs):
    """Run ``training_function(**kwargs)`` inside an MLflow run and record the outcome.

    On success the run is tagged ``status=success`` and the result is
    returned. On failure the same run is tagged ``status=failed`` plus the
    error message, and the original exception is re-raised.
    """
    with mlflow.start_run():
        try:
            result = training_function(**kwargs)
        except Exception as e:
            # Record the failure while the run is still active. Logging after
            # the `with` block has exited would make MLflow auto-start a new,
            # unrelated run for these calls.
            mlflow.set_tag("status", "failed")
            mlflow.set_tag("error", str(e))
            raise  # bare raise preserves the original traceback
        mlflow.set_tag("status", "success")
        return result
4. Model Comparison
def compare_models():
    """Print and return the five runs in "pytorch_experiments" with the highest accuracy.

    Returns:
        A pandas DataFrame of up to five runs, or ``None`` if the experiment
        does not exist.
    """
    # Get experiment
    experiment = mlflow.get_experiment_by_name("pytorch_experiments")
    if experiment is None:
        print('Experiment "pytorch_experiments" not found.')
        return None

    # Sort by accuracy on the tracking server instead of fetching every run
    best_runs = mlflow.search_runs(
        experiment_ids=[experiment.experiment_id],
        order_by=["metrics.accuracy DESC"],
        max_results=5,
    )
    if best_runs.empty:
        print("No runs found.")
        return best_runs

    print("Top 5 models by accuracy:")
    print(best_runs[["run_id", "metrics.accuracy", "params.learning_rate"]])
    return best_runs
5. Model Loading Best Practices
def load_model_safely(model_uri):
    """Load a PyTorch model from MLflow and switch it to evaluation mode.

    Args:
        model_uri: Any MLflow model URI (e.g. ``models:/Name/Production``).

    Returns:
        The loaded model, or ``None`` if loading failed (the error is printed).
    """
    try:
        model = mlflow.pytorch.load_model(model_uri)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    model.eval()  # Set to evaluation mode
    return model
# Usage
model = load_model_safely("models:/MyPyTorchModel/Production")
# Compare against None explicitly: nn.Module truthiness is unreliable
# (an empty nn.Sequential has len() == 0 and is therefore falsy).
if model is not None:
    # Use model for inference
    pass
Summary
MLflow provides a comprehensive solution for managing PyTorch ML workflows:
- Experiment Tracking: Log parameters, metrics, and artifacts
- Model Management: Version and organize your models
- Model Registry: Centralized model store with lifecycle management
- Deployment: Easy model serving and deployment options
- Reproducibility: Track everything needed to reproduce experiments
Start with basic experiment tracking, then gradually adopt more advanced features like the model registry and deployment capabilities as your ML workflow matures.