MLflow for PyTorch - Complete Guide
MLflow is an open-source platform for managing the machine learning lifecycle, including experimentation, reproducibility, deployment, and a central model registry. This guide covers how to integrate MLflow with PyTorch for end-to-end ML workflow management.
## Installation and Setup
Install MLflow
# Install MLflow, then PyTorch and torchvision
pip install mlflow
pip install torch torchvision
Start MLflow UI
# Launch the MLflow tracking UI locally
mlflow ui
This starts the MLflow UI, served by default at http://localhost:5000.
Basic Configuration
import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# Set tracking URI (optional - defaults to local)
mlflow.set_tracking_uri("http://localhost:5000")

# Set experiment name: runs started afterwards are grouped under it
mlflow.set_experiment("pytorch_experiments")
Basic MLflow Concepts
- Experiment: A collection of runs for a particular task
- Run: A single execution of your ML code
- Artifact: Files generated during a run (models, plots, data)
- Metric: Numerical values tracked over time
- Parameter: Input configurations for your run
Experiment Tracking
Basic Run Structure
import mlflow
with mlflow.start_run():
    # Your training code here
    mlflow.log_param("learning_rate", 0.001)  # input configuration
    mlflow.log_metric("accuracy", 0.95)       # numeric result
    # Uploads an existing local file; model.pth must already exist
    # (e.g. written earlier with torch.save)
    mlflow.log_artifact("model.pth")
Complete Training Example
import torch
import torch.nn as nn
import torch.optim as optim
import mlflow
import mlflow.pytorch
from sklearn.metrics import accuracy_score
import numpy as np
class SimpleNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output scores produced per sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw, unnormalized class scores (no softmax is applied)."""
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
def train_model():
    """Train SimpleNet on random dummy data and track the run with MLflow.

    Logs the hyperparameters as params and the per-epoch loss and accuracy
    as step-indexed metrics. The trained model is logged twice: as an MLflow
    PyTorch model under "model", and as a raw ``state_dict`` file artifact
    named "model_state_dict.pth".
    """
    import os
    import tempfile

    # Hyperparameters
    input_size = 784
    hidden_size = 128
    num_classes = 10
    learning_rate = 0.001
    batch_size = 64
    num_epochs = 10
    num_batches = 100  # synthetic batches per epoch (also the loss divisor)

    # Start MLflow run
    with mlflow.start_run():
        # Log all hyperparameters in one call
        mlflow.log_params({
            "input_size": input_size,
            "hidden_size": hidden_size,
            "num_classes": num_classes,
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
        })

        # Initialize model
        model = SimpleNet(input_size, hidden_size, num_classes)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Training loop
        for epoch in range(num_epochs):
            running_loss = 0.0
            correct = 0
            total = 0

            for _ in range(num_batches):
                # Random inputs/labels stand in for a real DataLoader
                inputs = torch.randn(batch_size, input_size)
                labels = torch.randint(0, num_classes, (batch_size,))

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            # Calculate metrics
            epoch_loss = running_loss / num_batches
            epoch_acc = 100 * correct / total

            # Log metrics with the epoch as the step so they plot as curves
            mlflow.log_metric("loss", epoch_loss, step=epoch)
            mlflow.log_metric("accuracy", epoch_acc, step=epoch)

            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

        # Log model
        mlflow.pytorch.log_model(model, "model")

        # Log the raw state_dict as an additional artifact. Write it to a
        # temporary directory so no file is left in the working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            state_path = os.path.join(tmp_dir, "model_state_dict.pth")
            torch.save(model.state_dict(), state_path)
            mlflow.log_artifact(state_path)
# Run training (executes as soon as this snippet is run)
train_model()
Model Logging
Different Ways to Log PyTorch Models
1. Log Complete Model
# Log the entire model in MLflow's PyTorch flavor under "complete_model"
mlflow.pytorch.log_model(model, "complete_model")
2. Log Model State Dict
# Save the state dict to a local file, then upload that file as a run artifact
torch.save(model.state_dict(), "model_state_dict.pth")
mlflow.log_artifact("model_state_dict.pth")
3. Log with Custom Code
# Log model with custom code for loading: the files listed in code_paths are
# stored alongside the model
mlflow.pytorch.log_model(
    model,
    "model",
    code_paths=["model_definition.py"],  # Include custom model definition
)
4. Log with Conda Environment
import mlflow.pytorch

# Create conda environment specification.
# Pin versions to match the environment the model was trained in.
conda_env = {
    'channels': ['defaults', 'pytorch'],
    'dependencies': [
        'python=3.8',
        'pytorch',
        'torchvision',
        {'pip': ['mlflow']},  # pip-only packages go in a nested dict
    ],
    'name': 'pytorch_env',
}

mlflow.pytorch.log_model(
    model,
    "model",
    conda_env=conda_env,
)
Model Registry
Register Model
# Register model during logging
mlflow.pytorch.log_model(
    model,
    "model",
    registered_model_name="MyPyTorchModel",
)

# Or register the model already logged by an existing run
# (replace your_run_id with the actual run ID)
model_uri = "runs:/your_run_id/model"
mlflow.register_model(model_uri, "MyPyTorchModel")
Model Versioning and Stages
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Transition model to different stages.
# NOTE(review): recent MLflow releases deprecate stages in favor of model
# version aliases; check the API for your installed version.
client.transition_model_version_stage(
    name="MyPyTorchModel",
    version=1,
    stage="Production",
)

# Get the latest version in the Production stage
# ([0] raises IndexError if no version is in that stage)
model_version = client.get_latest_versions(
    "MyPyTorchModel",
    stages=["Production"],
)[0]
Load Registered Model
# Load model from registry by stage
model = mlflow.pytorch.load_model(
    model_uri="models:/MyPyTorchModel/Production"
)

# Or load a specific version number
model = mlflow.pytorch.load_model(
    model_uri="models:/MyPyTorchModel/1"
)
Model Deployment
Local Serving
# Serve model locally
# Run in terminal:
# mlflow models serve -m models:/MyPyTorchModel/Production -p 1234
# (-p sets the port; the prediction example below POSTs to
#  http://localhost:1234/invocations)
Prediction with Served Model
import requests
import json

# Prepare data: each inner list is one row of input features; its length
# must match the number of features the served model expects
data = {
    "inputs": [[1.0, 2.0, 3.0, 4.0]]  # Your input features
}

# Make prediction request. The timeout prevents hanging forever on an
# unresponsive server.
response = requests.post(
    "http://localhost:1234/invocations",
    headers={"Content-Type": "application/json"},
    data=json.dumps(data),
    timeout=30,
)
# Surface HTTP errors instead of trying to parse an error page as predictions
response.raise_for_status()
predictions = response.json()
print(predictions)
Docker Deployment
# Build a Docker image named my-pytorch-model that serves the registered model
mlflow models build-docker -m models:/MyPyTorchModel/Production -n my-pytorch-model
# Run Docker container, mapping host port 8080 to container port 8080
docker run -p 8080:8080 my-pytorch-model
Advanced Features
Custom MLflow Model
import mlflow.pyfunc


class PyTorchModelWrapper(mlflow.pyfunc.PythonModel):
    """pyfunc wrapper that runs a PyTorch model on tabular input.

    Args:
        model: The PyTorch module to wrap.
    """

    def __init__(self, model):
        self.model = model

    def predict(self, context, model_input):
        """Return the model's outputs for ``model_input`` as a NumPy array.

        Assumes ``model_input`` exposes ``.values`` (e.g. a pandas DataFrame)
        whose columns are the model's input features — TODO confirm for
        your input schema.
        """
        # Switch to inference behavior (e.g. dropout off) before predicting
        self.model.eval()
        with torch.no_grad():
            tensor_input = torch.as_tensor(model_input.values, dtype=torch.float32)
            predictions = self.model(tensor_input)
        # .cpu() so .numpy() also works if the model produced a CUDA tensor
        return predictions.cpu().numpy()


# Log custom model
wrapped_model = PyTorchModelWrapper(model)

mlflow.pyfunc.log_model(
    "custom_model",
    python_model=wrapped_model,
)
Automatic Logging
# Enable automatic logging
mlflow.pytorch.autolog()

# NOTE(review): per the MLflow docs, mlflow.pytorch.autolog() hooks into
# PyTorch Lightning training (Trainer.fit). A hand-written PyTorch loop is not
# logged automatically; keep explicit mlflow.log_param / log_metric /
# pytorch.log_model calls there. Verify this against your MLflow version.
with mlflow.start_run():
    # Training happens here (e.g. a Lightning trainer.fit(...) call)
    pass
Logging Hyperparameter Sweeps
import itertools

# Define hyperparameter grid
param_grid = {
    'learning_rate': [0.001, 0.01, 0.1],
    'hidden_size': [64, 128, 256],
    'batch_size': [32, 64, 128],
}

# Run one experiment per combination in the grid's Cartesian product
# (3 * 3 * 3 = 27 runs here)
grid_keys = list(param_grid)
for values in itertools.product(*param_grid.values()):
    params = dict(zip(grid_keys, values))
    with mlflow.start_run():
        # Log parameters
        mlflow.log_params(params)

        # Train model with these parameters. train_with_params is your own
        # training function; it is assumed to return (model, accuracy).
        model, accuracy = train_with_params(params)

        # Log results
        mlflow.log_metric("final_accuracy", accuracy)
        mlflow.pytorch.log_model(model, "model")
Logging Artifacts and Plots
import matplotlib.pyplot as plt
import seaborn as sns
# Create and log plots
def log_training_plots(train_losses, val_losses):
    """Plot training vs. validation loss curves and log them as "loss_plot.png".

    Args:
        train_losses: Sequence of training losses, one per epoch.
        val_losses: Sequence of validation losses, one per epoch.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(train_losses, label='Training Loss')
        ax.plot(val_losses, label='Validation Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.set_title('Training and Validation Loss')
        # log_figure uploads the figure directly, so no file is written to
        # (or overwritten in) the working directory
        mlflow.log_figure(fig, 'loss_plot.png')
    finally:
        plt.close(fig)  # release the figure even if logging fails
# Log confusion matrix
def log_confusion_matrix(y_true, y_pred, class_names):
    """Render a confusion-matrix heatmap and log it as "confusion_matrix.png".

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Tick labels for both axes. Assumes the labels present in
            y_true/y_pred, sorted, line up one-to-one with class_names —
            TODO confirm for your label encoding.
    """
    from sklearn.metrics import confusion_matrix
    import seaborn as sns

    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_xlabel('Predicted label')
        ax.set_ylabel('True label')
        ax.set_title('Confusion Matrix')
        mlflow.log_figure(fig, 'confusion_matrix.png')
    finally:
        plt.close(fig)  # release the figure even if logging fails
Best Practices
1. Organize Experiments
# Use descriptive experiment names
mlflow.set_experiment("image_classification_resnet")

# Use run names for specific configurations
with mlflow.start_run(run_name="resnet50_adam_lr001"):
    pass
2. Comprehensive Logging
def comprehensive_logging(model, optimizer, criterion, config):
    """Log hyperparameters, model/optimizer/loss metadata and CUDA info to the active run.

    Args:
        model: The PyTorch module being trained.
        optimizer: Its optimizer; only the class name is logged.
        criterion: The loss module; only the class name is logged.
        config: Dict of hyperparameters, logged as run params.
    """
    # Log hyperparameters
    mlflow.log_params(config)

    # Log model architecture info
    total_params = sum(p.numel() for p in model.parameters())
    mlflow.log_param("total_parameters", total_params)
    # str(model) grows with the network and can exceed MLflow's param-value
    # length limit, so store it as a text artifact rather than a param
    mlflow.log_text(str(model), "model_architecture.txt")

    # Log optimizer info
    mlflow.log_param("optimizer", type(optimizer).__name__)
    mlflow.log_param("criterion", type(criterion).__name__)

    # Log system info
    cuda_available = torch.cuda.is_available()
    mlflow.log_param("cuda_available", cuda_available)
    if cuda_available:
        mlflow.log_param("gpu_name", torch.cuda.get_device_name(0))
3. Error Handling
def safe_mlflow_run(training_function, **kwargs):
    """Run ``training_function(**kwargs)`` inside an MLflow run and tag the outcome.

    On success the run is tagged ``status=success``. On failure it is tagged
    ``status=failed`` plus a (truncated) ``error`` message, and the original
    exception is re-raised.

    Returns:
        Whatever ``training_function`` returns.
    """
    # The try must sit *inside* the with-block. Once start_run()'s context has
    # exited, the run is closed, and a later log call would silently start
    # a new, separate run.
    with mlflow.start_run():
        try:
            result = training_function(**kwargs)
        except Exception as e:
            # Tags rather than params: params describe input configuration
            mlflow.set_tag("status", "failed")
            # Truncate so a long message can't trip MLflow's value-size limits
            # and mask the original exception
            mlflow.set_tag("error", str(e)[:500])
            raise
        mlflow.set_tag("status", "success")
        return result
4. Model Comparison
def compare_models(experiment_name="pytorch_experiments", top_n=5):
    """Print and return the best runs of an experiment ranked by accuracy.

    Args:
        experiment_name: Name of the experiment to search.
        top_n: How many runs to show.

    Returns:
        A DataFrame holding the ``top_n`` runs with the highest
        ``metrics.accuracy``, or ``None`` if the experiment does not exist or
        no run in it has logged an ``accuracy`` metric.
    """
    # Get experiment (None if no experiment has this name)
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if experiment is None:
        print(f"Experiment '{experiment_name}' not found")
        return None

    runs = mlflow.search_runs(experiment_ids=[experiment.experiment_id])
    if "metrics.accuracy" not in runs.columns:
        print("No runs with an 'accuracy' metric found")
        return None

    # Sort by accuracy; runs that never logged it (NaN) go to the end
    best_runs = runs.sort_values("metrics.accuracy", ascending=False)
    top_runs = best_runs.head(top_n)

    # Only display columns that actually exist (not every run logs learning_rate)
    wanted = ("run_id", "metrics.accuracy", "params.learning_rate")
    columns = [c for c in wanted if c in top_runs.columns]

    print(f"Top {top_n} models by accuracy:")
    print(top_runs[columns])
    return top_runs
5. Model Loading Best Practices
def load_model_safely(model_uri):
    """Load a PyTorch model via MLflow and switch it to evaluation mode.

    Args:
        model_uri: URI passed to ``mlflow.pytorch.load_model``
            (e.g. ``"models:/MyPyTorchModel/Production"``).

    Returns:
        The loaded model, or ``None`` if anything failed (the error is printed).
    """
    try:
        model = mlflow.pytorch.load_model(model_uri)
        model.eval()  # Set to evaluation mode
        return model
    except Exception as e:  # deliberately broad: this helper never raises
        print(f"Error loading model: {e}")
        return None
# Usage
model = load_model_safely("models:/MyPyTorchModel/Production")

# Compare against None explicitly: container modules such as nn.Sequential
# define __len__, so `if model:` could be False for a successfully loaded model
if model is not None:
    # Use model for inference
    pass
Summary
MLflow provides a comprehensive solution for managing PyTorch ML workflows:
- Experiment Tracking: Log parameters, metrics, and artifacts
- Model Management: Version and organize your models
- Model Registry: Centralized model store with lifecycle management
- Deployment: Easy model serving and deployment options
- Reproducibility: Track everything needed to reproduce experiments
Start with basic experiment tracking, then gradually adopt more advanced features like the model registry and deployment capabilities as your ML workflow matures.