PyTorch Model Deployment on Edge Devices - Complete Code Guide
Prerequisites
# Install required packages
pip install torch torchvision
pip install torch-model-archiver
pip install onnx onnxruntime
pip install onnx-tf tensorflow   # for the TensorFlow Lite conversion example
pip install psutil               # system monitoring used in the deployment scripts
1. Model Optimization
Quantization
import torch
import torch.quantization as quantization
from torch.quantization import get_default_qconfig
import torchvision.models as models
# Load your trained model
model = models.resnet18(pretrained=True)
model.eval()

# Post-training quantization (easiest method)
def post_training_quantization(model, sample_data):
    """
    Apply post-training quantization to reduce model size
    """
    # Set model to evaluation mode
    model.eval()

    # Fuse conv, bn and relu in the stem; a full ResNet would fuse each residual block as well
    model_fused = torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu']])

    # Wrap with QuantStub/DeQuantStub so the converted model accepts float inputs.
    # (For ResNet-style models with residual additions, the quantization-ready variants in
    # torchvision.models.quantization, or FX graph mode quantization, are the more robust route.)
    model_fused = torch.quantization.QuantWrapper(model_fused)

    # Specify quantization configuration
    model_fused.qconfig = torch.quantization.get_default_qconfig('fbgemm')

    # Prepare the model for quantization
    model_prepared = torch.quantization.prepare(model_fused)

    # Calibrate with sample data
    with torch.no_grad():
        for data in sample_data:
            model_prepared(data)

    # Convert to quantized model
    quantized_model = torch.quantization.convert(model_prepared)

    return quantized_model

# Example usage
sample_data = [torch.randn(1, 3, 224, 224) for _ in range(100)]
quantized_model = post_training_quantization(model, sample_data)
Pruning
import copy
import torch.nn.utils.prune as prune

def prune_model(model, pruning_amount=0.3):
    """
    Apply magnitude-based pruning to reduce model complexity
    """
    parameters_to_prune = []

    # Collect all conv and linear layers
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            parameters_to_prune.append((module, 'weight'))

    # Apply global magnitude pruning
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=pruning_amount,
    )

    # Remove pruning reparameterization to make pruning permanent
    for module, param in parameters_to_prune:
        prune.remove(module, param)

    return model

# Apply pruning (deep-copy first so the original model is left untouched)
pruned_model = prune_model(copy.deepcopy(model), pruning_amount=0.3)
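Since pruned weights are simply zeroed out, a quick sparsity check confirms the pruning took effect. The helper below is a small sketch added for illustration (report_sparsity is not part of the original guide); it uses only standard tensor operations on the pruned layers:

def report_sparsity(model):
    """Print the fraction of zero-valued weights in each conv/linear layer."""
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            zero_fraction = float((module.weight == 0).sum()) / module.weight.nelement()
            print(f"{name}: {zero_fraction:.1%} of weights pruned")

report_sparsity(pruned_model)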
2. Model Conversion
Convert to TorchScript
def convert_to_torchscript(model, sample_input, save_path):
    """
    Convert PyTorch model to TorchScript for deployment
    """
    model.eval()

    # Method 1: Tracing (recommended for models without control flow)
    try:
        traced_model = torch.jit.trace(model, sample_input)
        traced_model.save(save_path)
        print(f"Model traced and saved to {save_path}")
        return traced_model
    except Exception as e:
        print(f"Tracing failed: {e}")

    # Method 2: Scripting (for models with control flow)
    try:
        scripted_model = torch.jit.script(model)
        scripted_model.save(save_path)
        print(f"Model scripted and saved to {save_path}")
        return scripted_model
    except Exception as e:
        print(f"Scripting also failed: {e}")
        return None

# Example usage
sample_input = torch.randn(1, 3, 224, 224)
torchscript_model = convert_to_torchscript(model, sample_input, "model.pt")
Convert to ONNX
import onnx
import onnxruntime as ort

def convert_to_onnx(model, sample_input, onnx_path):
    """
    Convert PyTorch model to ONNX format
    """
    model.eval()

    torch.onnx.export(
        model,                      # model being run
        sample_input,               # model input
        onnx_path,                  # where to save the model
        export_params=True,         # store the trained parameter weights
        opset_version=11,           # ONNX opset version to export to
        do_constant_folding=True,   # optimize with constant folding
        input_names=['input'],      # model's input names
        output_names=['output'],    # model's output names
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )

    # Verify the ONNX model
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"ONNX model saved and verified at {onnx_path}")

# Convert to ONNX
convert_to_onnx(model, sample_input, "model.onnx")

# Test ONNX Runtime inference
def test_onnx_inference(onnx_path, sample_input):
    """Test ONNX model inference"""
    ort_session = ort.InferenceSession(onnx_path)

    # Convert input to numpy
    input_np = sample_input.numpy()

    # Run inference
    outputs = ort_session.run(None, {'input': input_np})
    return outputs[0]

# Test the converted model
onnx_output = test_onnx_inference("model.onnx", sample_input)
Convert to TensorFlow Lite
import tensorflow as tf

def pytorch_to_tflite(onnx_path, tflite_path):
    """
    Convert ONNX model to TensorFlow Lite
    """
    # Convert ONNX to TensorFlow (requires the onnx-tf package)
    from onnx_tf.backend import prepare
    import onnx

    onnx_model = onnx.load(onnx_path)
    tf_rep = prepare(onnx_model)
    tf_rep.export_graph("temp_tf_model")

    # Convert to TensorFlow Lite
    converter = tf.lite.TFLiteConverter.from_saved_model("temp_tf_model")

    # Apply optimizations
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # Convert model
    tflite_model = converter.convert()

    # Save the model
    with open(tflite_path, 'wb') as f:
        f.write(tflite_model)

    print(f"TensorFlow Lite model saved to {tflite_path}")

# Convert to TensorFlow Lite
pytorch_to_tflite("model.onnx", "model.tflite")
3. Mobile Deployment
Android Deployment
// Android Java code for PyTorch Mobile
public class ModelInference {
    private Module model;

    public ModelInference(String modelPath) {
        model = LiteModuleLoader.load(modelPath);
    }

    public float[] predict(Bitmap bitmap) {
        // Preprocess image
        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
            bitmap,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
            TensorImageUtils.TORCHVISION_NORM_STD_RGB
        );

        // Run inference
        Tensor outputTensor = model.forward(IValue.from(inputTensor)).toTensor();

        // Get results
        return outputTensor.getDataAsFloatArray();
    }
}
iOS Deployment (Swift)
// iOS Swift code for PyTorch Mobile
import LibTorch
class ModelInference {
    private var model: TorchModule

    init(modelPath: String) {
        model = TorchModule(fileAtPath: modelPath)!
    }

    func predict(image: UIImage) -> [Float] {
        // Preprocess image
        guard let pixelBuffer = image.pixelBuffer() else { return [] }
        guard let inputTensor = TorchTensor.fromPixelBuffer(pixelBuffer) else { return [] }

        // Run inference
        guard let outputTensor = model.predict(inputs: [inputTensor]) else { return [] }

        // Get results
        return outputTensor[0].floatArray
    }
}
Python Mobile Preprocessing
from torch.utils.mobile_optimizer import optimize_for_mobile

def create_mobile_model(model, sample_input):
    """
    Create optimized model for mobile deployment
    """
    model.eval()

    # Convert to TorchScript
    traced_model = torch.jit.trace(model, sample_input)

    # Optimize for mobile
    optimized_model = optimize_for_mobile(traced_model)

    # Save mobile-optimized model for the lite interpreter
    optimized_model._save_for_lite_interpreter("mobile_model.ptl")

    return optimized_model

# Create mobile model
mobile_model = create_mobile_model(model, sample_input)
4. Raspberry Pi Deployment
# Raspberry Pi deployment script
import torch
import torchvision.transforms as transforms
from PIL import Image
import time
import psutil

class RaspberryPiInference:
    def __init__(self, model_path, device='cpu'):
        self.device = torch.device(device)
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()

        # Define preprocessing transforms
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Performance monitoring
        self.inference_times = []

    def preprocess_image(self, image_path):
        """Preprocess image for inference"""
        image = Image.open(image_path).convert('RGB')
        input_tensor = self.transform(image).unsqueeze(0)
        return input_tensor.to(self.device)

    def inference(self, image_path):
        """Run inference on image"""
        start_time = time.time()

        # Preprocess
        input_tensor = self.preprocess_image(image_path)

        # Inference
        with torch.no_grad():
            outputs = self.model(input_tensor)
            predictions = torch.nn.functional.softmax(outputs[0], dim=0)

        inference_time = time.time() - start_time
        self.inference_times.append(inference_time)

        return predictions.cpu().numpy(), inference_time

    def get_system_stats(self):
        """Get system performance statistics"""
        return {
            'cpu_percent': psutil.cpu_percent(),
            'memory_percent': psutil.virtual_memory().percent,
            'temperature': self.get_cpu_temperature()
        }

    def get_cpu_temperature(self):
        """Get CPU temperature (Raspberry Pi specific)"""
        try:
            with open('/sys/class/thermal/thermal_zone0/temp', 'r') as f:
                temp = float(f.read()) / 1000.0
            return temp
        except (OSError, ValueError):
            return None

# Usage example
if __name__ == "__main__":
    # Initialize inference engine
    inference_engine = RaspberryPiInference("model.pt")

    # Run inference
    predictions, inference_time = inference_engine.inference("test_image.jpg")

    print(f"Inference time: {inference_time:.3f} seconds")
    print(f"Top prediction: {predictions.max():.3f}")
    print(f"System stats: {inference_engine.get_system_stats()}")
5. NVIDIA Jetson Deployment
# NVIDIA Jetson optimized deployment
import os
import torch
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np

class JetsonTensorRTInference:
    def __init__(self, onnx_model_path, trt_engine_path=None):
        self.onnx_path = onnx_model_path
        self.engine_path = trt_engine_path or onnx_model_path.replace('.onnx', '.trt')

        # Build or load TensorRT engine
        if not os.path.exists(self.engine_path):
            self.build_engine()
        self.engine = self.load_engine()
        self.context = self.engine.create_execution_context()

        # Allocate GPU memory
        self.allocate_buffers()

    def build_engine(self):
        """Build TensorRT engine from ONNX model"""
        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, logger)

        # Parse ONNX model
        with open(self.onnx_path, 'rb') as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None

        # Build engine
        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 28  # 256MB
        # Enable FP16 precision
        config.set_flag(trt.BuilderFlag.FP16)

        engine = builder.build_engine(network, config)

        # Save engine
        with open(self.engine_path, 'wb') as f:
            f.write(engine.serialize())

        return engine

    def load_engine(self):
        """Load TensorRT engine"""
        runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
        with open(self.engine_path, 'rb') as f:
            return runtime.deserialize_cuda_engine(f.read())

    def allocate_buffers(self):
        """Allocate GPU memory buffers"""
        self.bindings = []
        self.inputs = []
        self.outputs = []

        for binding in self.engine:
            shape = self.engine.get_binding_shape(binding)
            size = trt.volume(shape) * self.engine.max_batch_size
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))

            # Allocate host and device buffers
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)

            self.bindings.append(int(device_mem))
            if self.engine.binding_is_input(binding):
                self.inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self.outputs.append({'host': host_mem, 'device': device_mem})

    def inference(self, input_data):
        """Run TensorRT inference"""
        # Copy input data to GPU
        np.copyto(self.inputs[0]['host'], input_data.ravel())
        cuda.memcpy_htod(self.inputs[0]['device'], self.inputs[0]['host'])

        # Run inference
        self.context.execute_v2(bindings=self.bindings)

        # Copy output data from GPU
        cuda.memcpy_dtoh(self.outputs[0]['host'], self.outputs[0]['device'])

        return self.outputs[0]['host']

# Usage for Jetson
jetson_inference = JetsonTensorRTInference("model.onnx")
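A minimal end-to-end call into the class above might look as follows (a sketch: the dummy input matches the (1, 3, 224, 224) shape used for the ONNX export, and the result comes back as the flat host output buffer):

# Dummy input matching the exported ONNX model
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = jetson_inference.inference(input_data)
print(f"TensorRT output (flattened): {output.shape}")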
6. Performance Optimization
Benchmarking Script
import time
import numpy as np
import torch
import psutil
from contextlib import contextmanager

@contextmanager
def timer():
    """Context manager for timing code execution"""
    start = time.perf_counter()
    yield
    end = time.perf_counter()
    print(f"Execution time: {end - start:.4f} seconds")

class ModelBenchmark:
    def __init__(self, model, input_shape, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.input_shape = input_shape

    def benchmark_inference(self, num_runs=100, warmup_runs=10):
        """Benchmark model inference performance"""
        # Generate random input
        dummy_input = torch.randn(self.input_shape).to(self.device)

        # Warmup runs
        self.model.eval()
        with torch.no_grad():
            for _ in range(warmup_runs):
                _ = self.model(dummy_input)

        # Benchmark runs
        inference_times = []
        memory_usage = []

        for i in range(num_runs):
            # Monitor memory before inference
            if self.device == 'cuda':
                torch.cuda.empty_cache()
                memory_before = torch.cuda.memory_allocated()
            else:
                memory_before = psutil.Process().memory_info().rss

            # Time inference
            start_time = time.perf_counter()
            with torch.no_grad():
                output = self.model(dummy_input)

            if self.device == 'cuda':
                torch.cuda.synchronize()

            end_time = time.perf_counter()

            # Monitor memory after inference
            if self.device == 'cuda':
                memory_after = torch.cuda.memory_allocated()
            else:
                memory_after = psutil.Process().memory_info().rss

            inference_times.append(end_time - start_time)
            memory_usage.append(memory_after - memory_before)

        # Calculate statistics
        stats = {
            'mean_time': np.mean(inference_times),
            'std_time': np.std(inference_times),
            'min_time': np.min(inference_times),
            'max_time': np.max(inference_times),
            'fps': 1.0 / np.mean(inference_times),
            'mean_memory': np.mean(memory_usage),
            'max_memory': np.max(memory_usage)
        }

        return stats

    def profile_model(self):
        """Profile model to identify bottlenecks"""
        dummy_input = torch.randn(self.input_shape).to(self.device)

        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True,
            profile_memory=True,
            with_stack=True
        ) as profiler:
            with torch.no_grad():
                self.model(dummy_input)

        # Print profiling results
        print(profiler.key_averages().table(sort_by="cuda_time_total", row_limit=10))

        return profiler

# Usage example
benchmark = ModelBenchmark(model, (1, 3, 224, 224), device='cpu')
stats = benchmark.benchmark_inference()
print(f"Average inference time: {stats['mean_time']:.4f}s")
print(f"FPS: {stats['fps']:.2f}")
Memory Optimization
def optimize_memory_usage(model):
    """Apply memory optimization techniques"""
    # Enable memory efficient attention (for transformers)
    if hasattr(model, 'enable_memory_efficient_attention'):
        model.enable_memory_efficient_attention()

    # Use gradient checkpointing during training
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()

    # Fuse operations where possible
    model = torch.jit.optimize_for_inference(torch.jit.script(model))

    return model

def batch_inference(model, data_loader, batch_size=1):
    """Perform batch inference with memory management"""
    model.eval()
    results = []

    with torch.no_grad():
        for batch in data_loader:
            # Process in smaller chunks if needed
            if batch.size(0) > batch_size:
                for i in range(0, batch.size(0), batch_size):
                    chunk = batch[i:i+batch_size]
                    output = model(chunk)
                    results.append(output.cpu())

                    # Clear GPU cache
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            else:
                output = model(batch)
                results.append(output.cpu())

    return torch.cat(results, dim=0)
7. Best Practices
Model Deployment Checklist
class DeploymentValidator:
    def __init__(self, original_model, optimized_model, test_input):
        self.original_model = original_model
        self.optimized_model = optimized_model
        self.test_input = test_input

    def validate_accuracy(self, tolerance=1e-3):
        """Validate that optimized model maintains accuracy"""
        self.original_model.eval()
        self.optimized_model.eval()

        with torch.no_grad():
            original_output = self.original_model(self.test_input)
            optimized_output = self.optimized_model(self.test_input)

        # Check if outputs are close
        if torch.allclose(original_output, optimized_output, atol=tolerance):
            print("✓ Accuracy validation passed")
            return True
        else:
            print("✗ Accuracy validation failed")
            diff = torch.abs(original_output - optimized_output).max().item()
            print(f"Maximum difference: {diff}")
            return False

    def validate_performance(self):
        """Compare performance metrics"""
        # Benchmark both models
        original_benchmark = ModelBenchmark(self.original_model, self.test_input.shape)
        optimized_benchmark = ModelBenchmark(self.optimized_model, self.test_input.shape)

        original_stats = original_benchmark.benchmark_inference(num_runs=50)
        optimized_stats = optimized_benchmark.benchmark_inference(num_runs=50)

        speedup = original_stats['mean_time'] / optimized_stats['mean_time']
        memory_reduction = (original_stats['mean_memory'] - optimized_stats['mean_memory']) / original_stats['mean_memory'] * 100

        print(f"Performance improvement: {speedup:.2f}x speedup")
        print(f"Memory reduction: {memory_reduction:.1f}%")

        return {
            'speedup': speedup,
            'memory_reduction': memory_reduction,
            'original_fps': original_stats['fps'],
            'optimized_fps': optimized_stats['fps']
        }

    def check_model_size(self):
        """Compare model file sizes"""
        import os

        # Save both models temporarily
        torch.save(self.original_model.state_dict(), 'temp_original.pth')
        torch.jit.save(torch.jit.script(self.optimized_model), 'temp_optimized.pt')

        original_size = os.path.getsize('temp_original.pth')
        optimized_size = os.path.getsize('temp_optimized.pt')

        size_reduction = (original_size - optimized_size) / original_size * 100

        print(f"Original model size: {original_size / 1024 / 1024:.2f} MB")
        print(f"Optimized model size: {optimized_size / 1024 / 1024:.2f} MB")
        print(f"Size reduction: {size_reduction:.1f}%")

        # Clean up temporary files
        os.remove('temp_original.pth')
        os.remove('temp_optimized.pt')

        return size_reduction

# Example usage
validator = DeploymentValidator(model, quantized_model, sample_input)
validator.validate_accuracy()
performance_metrics = validator.validate_performance()
size_reduction = validator.check_model_size()
Error Handling and Logging
import time
import logging
import torch
from functools import wraps

def setup_logging():
    """Setup logging for deployment"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler('model_deployment.log'),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)

def handle_inference_errors(func):
    """Decorator for handling inference errors"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except torch.cuda.OutOfMemoryError:
            logging.error("CUDA out of memory. Try reducing batch size.")
            torch.cuda.empty_cache()
            raise
        except Exception as e:
            logging.error(f"Inference error: {str(e)}")
            raise
    return wrapper

class RobustInference:
    def __init__(self, model_path, device='cpu'):
        self.logger = setup_logging()
        self.device = torch.device(device)

        try:
            self.model = torch.jit.load(model_path, map_location=self.device)
            self.model.eval()
            self.logger.info(f"Model loaded successfully on {device}")
        except Exception as e:
            self.logger.error(f"Failed to load model: {e}")
            raise

    @handle_inference_errors
    def inference(self, input_data):
        """Robust inference with error handling"""
        start_time = time.time()

        with torch.no_grad():
            output = self.model(input_data)

        inference_time = time.time() - start_time
        self.logger.info(f"Inference completed in {inference_time:.3f}s")

        return output
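A minimal usage sketch (assuming the TorchScript file "model.pt" exported earlier and the same 224x224 input shape):

engine = RobustInference("model.pt", device='cpu')
input_data = torch.randn(1, 3, 224, 224)
output = engine.inference(input_data)
print(f"Output shape: {tuple(output.shape)}")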
Conclusion
This guide provides a comprehensive approach to deploying PyTorch models on edge devices. Key takeaways:
- Model Optimization: Always quantize and prune models before deployment
- Format Selection: Choose the right format (TorchScript, ONNX, TensorRT) based on your target device
- Performance Monitoring: Continuously monitor inference time, memory usage, and accuracy
- Device-Specific Optimization: Leverage device-specific optimizations (TensorRT for NVIDIA, Core ML for iOS)
- Robust Deployment: Implement proper error handling and logging for production systems
Remember to validate your optimized models thoroughly before deployment and monitor their performance in production environments.