LitServe with MobileNetV2 - Complete Code Guide

This guide demonstrates how to deploy a MobileNetV2 image classification model using LitServe for efficient, scalable inference.
Installation
# Install required packages
pip install litserve torch torchvision pillow requests
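A quick, optional sanity check that the packages import correctly (this is just a one-liner, not part of the guide's code):

python -c "import litserve, torch, torchvision, PIL, requests; print(torch.__version__)"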
Basic Implementation
Simple MobileNetV2 API
# mobilenet_api.py
import io
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import litserve as ls

class MobileNetV2API(ls.LitAPI):
    def setup(self, device):
        """Initialize the model and preprocessing pipeline"""
        # Load pre-trained MobileNetV2
        # (on torchvision >= 0.13, weights=models.MobileNet_V2_Weights.DEFAULT replaces pretrained=True)
        self.model = models.mobilenet_v2(pretrained=True)
        self.model.eval()
        self.model.to(device)

        # Define image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # ImageNet class labels (first 10 for brevity)
        self.class_labels = [
            "tench", "goldfish", "great white shark", "tiger shark",
            "hammerhead", "electric ray", "stingray", "cock",
            "hen", "ostrich"
            # ... add all 1000 ImageNet classes
        ]

    def decode_request(self, request):
        """Parse incoming request and prepare image"""
        if isinstance(request, dict) and "image" in request:
            # Handle base64 encoded image
            import base64
            image_data = base64.b64decode(request["image"])
            image = Image.open(io.BytesIO(image_data)).convert('RGB')
        else:
            # Handle direct image upload (raw bytes)
            image = Image.open(io.BytesIO(request)).convert('RGB')
        return image

    def predict(self, image):
        """Run inference on the preprocessed image"""
        # Preprocess image
        input_tensor = self.transform(image).unsqueeze(0)
        input_tensor = input_tensor.to(next(self.model.parameters()).device)

        # Run inference
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        return probabilities

    def encode_response(self, probabilities):
        """Format the response"""
        # Get top 5 predictions
        top5_prob, top5_indices = torch.topk(probabilities, 5)

        results = []
        for i in range(5):
            idx = top5_indices[i].item()
            prob = top5_prob[i].item()
            label = self.class_labels[idx] if idx < len(self.class_labels) else f"class_{idx}"
            results.append({
                "class": label,
                "confidence": round(prob, 4),
                "class_id": idx
            })

        return {
            "predictions": results,
            "model": "mobilenet_v2"
        }

if __name__ == "__main__":
    # Create and run the API (batched serving is shown in the advanced example below;
    # this simple predict() handles one image at a time, so batching is left disabled here)
    api = MobileNetV2API()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)

Client Code for Testing
# client.py
import requests
import base64
from PIL import Image
import io
def encode_image(image_path):
    """Encode image to base64"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def test_image_classification(image_path, server_url="http://localhost:8000/predict"):
    """Test the MobileNetV2 API"""
    # Method 1: Send as base64 in JSON
    encoded_image = encode_image(image_path)
    response = requests.post(
        server_url,
        json={"image": encoded_image}
    )
    if response.status_code == 200:
        result = response.json()
        print("Top 5 Predictions:")
        for pred in result["predictions"]:
            print(f"  {pred['class']}: {pred['confidence']:.4f}")
    else:
        print(f"Error: {response.status_code} - {response.text}")


def test_direct_upload(image_path, server_url="http://localhost:8000/predict"):
    """Test with direct image upload (requires the server's decode_request to accept multipart uploads)"""
    with open(image_path, 'rb') as f:
        files = {'file': f}
        response = requests.post(server_url, files=files)
    if response.status_code == 200:
        result = response.json()
        print("Predictions:", result)


if __name__ == "__main__":
    # Test with a sample image
    test_image_classification("sample_image.jpg")

Advanced Features
Batch Processing with Custom Batching
# advanced_mobilenet_api.py
import base64
import io
from typing import Any, List

import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import litserve as ls


class BatchedMobileNetV2API(ls.LitAPI):
    def setup(self, device):
        """Initialize model with batch processing capabilities"""
        self.model = models.mobilenet_v2(pretrained=True)
        self.model.eval()
        self.model.to(device)
        self.device = device

        # Optimized transform for batch processing
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Load class labels from file or define them
        self.load_imagenet_labels()

    def load_imagenet_labels(self):
        """Load ImageNet class labels"""
        # You can download from: https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
        try:
            with open('imagenet_classes.txt', 'r') as f:
                self.class_labels = [line.strip() for line in f.readlines()]
        except FileNotFoundError:
            # Fallback to generic class names
            self.class_labels = [f"class_{i}" for i in range(1000)]

    def decode_request(self, request):
        """Handle both single images and batch requests"""
        if isinstance(request, dict):
            if "images" in request:  # Batch request
                images = []
                for img_data in request["images"]:
                    if isinstance(img_data, str):  # base64
                        image_data = base64.b64decode(img_data)
                        image = Image.open(io.BytesIO(image_data)).convert('RGB')
                    else:  # direct bytes
                        image = Image.open(io.BytesIO(img_data)).convert('RGB')
                    images.append(image)
                return images
            elif "image" in request:  # Single request
                image_data = base64.b64decode(request["image"])
                image = Image.open(io.BytesIO(image_data)).convert('RGB')
                return [image]  # Wrap in list for consistent handling
        # Direct upload
        image = Image.open(io.BytesIO(request)).convert('RGB')
        return [image]

    def batch(self, inputs: List[Any]) -> List[Any]:
        """Custom batching logic"""
        # Flatten all images from all requests
        all_images = []
        batch_sizes = []
        for inp in inputs:
            if isinstance(inp, list):
                all_images.extend(inp)
                batch_sizes.append(len(inp))
            else:
                all_images.append(inp)
                batch_sizes.append(1)
        return all_images, batch_sizes

    def predict(self, batch_data):
        """Process batch of images efficiently"""
        images, batch_sizes = batch_data

        # Preprocess all images
        batch_tensor = torch.stack([
            self.transform(img) for img in images
        ]).to(self.device)

        # Run batch inference
        with torch.no_grad():
            outputs = self.model(batch_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        return probabilities, batch_sizes

    def unbatch(self, output):
        """Split batch results back to individual responses"""
        probabilities, batch_sizes = output
        results = []
        start_idx = 0
        for batch_size in batch_sizes:
            batch_probs = probabilities[start_idx:start_idx + batch_size]
            results.append(batch_probs)
            start_idx += batch_size
        return results

    def encode_response(self, probabilities):
        """Format response for batch or single predictions"""
        if len(probabilities.shape) == 1:  # Single prediction
            probabilities = probabilities.unsqueeze(0)

        all_results = []
        for prob_vector in probabilities:
            top5_prob, top5_indices = torch.topk(prob_vector, 5)
            predictions = []
            for i in range(5):
                idx = top5_indices[i].item()
                prob = top5_prob[i].item()
                predictions.append({
                    "class": self.class_labels[idx],
                    "confidence": round(prob, 4),
                    "class_id": idx
                })
            all_results.append(predictions)

        return {
            "predictions": all_results[0] if len(all_results) == 1 else all_results,
            "model": "mobilenet_v2",
            "batch_size": len(all_results)
        }


if __name__ == "__main__":
    api = BatchedMobileNetV2API()
    server = ls.LitServer(
        api,
        accelerator="auto",
        max_batch_size=8,
        batch_timeout=0.1,  # 100ms timeout for batching
    )
    server.run(port=8000, num_workers=2)
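As a quick illustration of the batch request format handled by decode_request above, a minimal client sketch might look like this (it assumes the server is running locally on port 8000; sample1.jpg and sample2.jpg are placeholder image paths):

# batch_client.py -- example request against the batched endpoint (illustrative sketch)
import base64
import requests

def encode_file(path):
    """Read an image file and return its base64-encoded contents."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

# "images" carries a list of base64 strings, matching the batch branch of decode_request
payload = {"images": [encode_file("sample1.jpg"), encode_file("sample2.jpg")]}
response = requests.post("http://localhost:8000/predict", json=payload)
print(response.json())  # expects "predictions", "model", and "batch_size" fields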
Adding Model Quantization for Better Performance
# quantized_mobilenet_api.py
import torch
import torch.quantization
import torchvision.transforms as transforms
from torchvision import models
import litserve as ls


class QuantizedMobileNetV2API(ls.LitAPI):
    def setup(self, device):
        """Setup quantized MobileNetV2 for faster inference"""
        # Load pre-trained model
        self.model = models.mobilenet_v2(pretrained=True)
        self.model.eval()

        # Apply dynamic quantization for CPU inference
        # (dynamic quantization converts Linear layers to int8; Conv2d layers stay in float)
        if device == "cpu":
            self.model = torch.quantization.quantize_dynamic(
                self.model,
                {torch.nn.Linear},
                dtype=torch.qint8
            )
            print("Applied dynamic quantization for CPU")
        else:
            self.model.to(device)
            print(f"Using device: {device}")

        # Rest of setup code...
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    # ... rest of the methods remain the same
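To check whether quantization actually helps on a given machine, a rough standalone timing sketch like the one below can be used. It is not part of the serving code, uses random input data, and the numbers are only indicative; since MobileNetV2 has a single Linear classifier layer, the gain from dynamic quantization may be modest.

# benchmark_quantization.py -- rough CPU latency comparison (illustrative, not part of the API)
import time

import torch
from torchvision import models


def time_model(model, runs=20):
    """Average forward-pass latency over a few runs on a random 224x224 input."""
    x = torch.randn(1, 3, 224, 224)
    model.eval()
    with torch.no_grad():
        model(x)  # warmup
        start = time.time()
        for _ in range(runs):
            model(x)
    return (time.time() - start) / runs


fp32_model = models.mobilenet_v2(pretrained=True)
quantized_model = torch.quantization.quantize_dynamic(
    models.mobilenet_v2(pretrained=True), {torch.nn.Linear}, dtype=torch.qint8
)
print(f"fp32:      {time_model(fp32_model) * 1000:.1f} ms/image")
print(f"quantized: {time_model(quantized_model) * 1000:.1f} ms/image")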
Performance Optimization
Configuration for Production
# production_config.py
import litserve as ls
from advanced_mobilenet_api import BatchedMobileNetV2API
def create_production_server():
    """Create optimized server for production"""
    api = BatchedMobileNetV2API()
    server = ls.LitServer(
        api,
        accelerator="auto",     # Auto-detect GPU/CPU
        max_batch_size=16,      # Larger batches for throughput
        batch_timeout=0.05,     # 50ms batching timeout
        workers_per_device=2,   # Multiple workers per GPU
        timeout=30,             # Request timeout
    )
    return server


if __name__ == "__main__":
    server = create_production_server()
    server.run(
        port=8000,
        host="0.0.0.0",  # Accept external connections
        num_workers=4    # Total number of workers
    )

Monitoring and Logging
# monitored_api.py
import time
import logging
from collections import defaultdict

import litserve as ls
from advanced_mobilenet_api import BatchedMobileNetV2API


class MonitoredMobileNetV2API(BatchedMobileNetV2API):
    def setup(self, device):
        """Setup with monitoring"""
        super().setup(device)

        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Metrics tracking
        self.metrics = defaultdict(list)
        self.request_count = 0

    def predict(self, batch_data):
        """Predict with timing and metrics"""
        start_time = time.time()
        result = super().predict(batch_data)

        # Log timing
        inference_time = time.time() - start_time
        batch_size = len(batch_data[0])
        self.metrics['inference_times'].append(inference_time)
        self.metrics['batch_sizes'].append(batch_size)
        self.request_count += 1

        self.logger.info(
            f"Processed batch of {batch_size} images in {inference_time:.3f}s "
            f"(Total requests: {self.request_count})"
        )
        return result

    def encode_response(self, probabilities):
        """Add metrics to response"""
        response = super().encode_response(probabilities)

        # Add performance metrics every 100 requests
        if self.request_count % 100 == 0:
            avg_time = sum(self.metrics['inference_times'][-100:]) / min(100, len(self.metrics['inference_times']))
            response['metrics'] = {
                'avg_inference_time': round(avg_time, 4),
                'total_requests': self.request_count
            }
        return response

Deployment
Docker Deployment
# Dockerfile
FROM python:3.9-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
    wget \
    && rm -rf /var/lib/apt/lists/*
# Copy requirements
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Download ImageNet labels
RUN wget -O imagenet_classes.txt https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
EXPOSE 8000
CMD ["python", "production_config.py"]
# docker-compose.yml
version: '3.8'
services:
  mobilenet-api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - CUDA_VISIBLE_DEVICES=0  # Set GPU if available
    volumes:
      - ./models:/app/models  # Optional: for custom models
    restart: unless-stopped

  # Optional: Add nginx for load balancing
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - mobilenet-api
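With both files in place, a typical local build-and-run flow looks roughly like this (assuming Docker with the Compose plugin is installed; older setups use docker-compose instead of docker compose):

# Build the image and start the stack in the background
docker compose up --build -d

# Tail the API logs
docker compose logs -f mobilenet-api

# Stop and remove the containers
docker compose down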
Kubernetes Deployment
# k8s-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: mobilenet-api
spec:
  replicas: 3
  selector:
    matchLabels:
      app: mobilenet-api
  template:
    metadata:
      labels:
        app: mobilenet-api
    spec:
      containers:
        - name: mobilenet-api
          image: your-registry/mobilenet-api:latest
          ports:
            - containerPort: 8000
          resources:
            requests:
              memory: "1Gi"
              cpu: "500m"
            limits:
              memory: "2Gi"
              cpu: "1000m"
          env:
            - name: WORKERS
              value: "2"
---
apiVersion: v1
kind: Service
metadata:
  name: mobilenet-service
spec:
  selector:
    app: mobilenet-api
  ports:
    - port: 80
      targetPort: 8000
  type: LoadBalancer
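Applying and inspecting the deployment follows the usual kubectl flow (the manifest name and the image tag above are placeholders for your own registry):

# Apply the Deployment and Service
kubectl apply -f k8s-deployment.yaml

# Check rollout status and pods
kubectl rollout status deployment/mobilenet-api
kubectl get pods -l app=mobilenet-api

# Find the external IP of the LoadBalancer service
kubectl get service mobilenet-service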
Testing
Comprehensive Test Suite
# test_api.py
import pytest
import requests
import base64
import numpy as np
from PIL import Image
import io


class TestMobileNetV2API:
    @classmethod
    def setup_class(cls):
        """Setup test configuration"""
        cls.base_url = "http://localhost:8000"
        cls.test_image = cls.create_test_image()

    @staticmethod
    def create_test_image():
        """Create a simple in-memory test image"""
        img = Image.new('RGB', (224, 224), color='red')
        img_buffer = io.BytesIO()
        img.save(img_buffer, format='JPEG')
        img_buffer.seek(0)
        return img_buffer.getvalue()

    def test_single_prediction(self):
        """Test single image prediction"""
        encoded_image = base64.b64encode(self.test_image).decode('utf-8')
        response = requests.post(
            f"{self.base_url}/predict",
            json={"image": encoded_image}
        )
        assert response.status_code == 200
        result = response.json()
        assert "predictions" in result
        assert len(result["predictions"]) == 5
        assert all("class" in pred and "confidence" in pred for pred in result["predictions"])

    def test_batch_prediction(self):
        """Test batch prediction"""
        encoded_image = base64.b64encode(self.test_image).decode('utf-8')
        response = requests.post(
            f"{self.base_url}/predict",
            json={"images": [encoded_image, encoded_image]}
        )
        assert response.status_code == 200
        result = response.json()
        assert "batch_size" in result
        assert result["batch_size"] == 2

    def test_performance(self):
        """Test API performance"""
        import time
        encoded_image = base64.b64encode(self.test_image).decode('utf-8')

        # Warmup
        requests.post(f"{self.base_url}/predict", json={"image": encoded_image})

        # Time multiple requests
        times = []
        for _ in range(10):
            start = time.time()
            response = requests.post(f"{self.base_url}/predict", json={"image": encoded_image})
            times.append(time.time() - start)
            assert response.status_code == 200

        avg_time = sum(times) / len(times)
        print(f"Average response time: {avg_time:.3f}s")
        assert avg_time < 1.0  # Should respond within 1 second


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

Load Testing
# load_test.py
import asyncio
import aiohttp
import base64
import time
from PIL import Image
import io


async def send_request(session, url, data):
    """Send a single request"""
    try:
        async with session.post(url, json=data) as response:
            await response.json()
            return response.status, time.time()
    except Exception:
        return 500, time.time()


async def load_test(num_requests=100, concurrent=10):
    """Run load test"""
    # Create test image
    img = Image.new('RGB', (224, 224), color='blue')
    img_buffer = io.BytesIO()
    img.save(img_buffer, format='JPEG')
    encoded_image = base64.b64encode(img_buffer.getvalue()).decode('utf-8')

    url = "http://localhost:8000/predict"
    data = {"image": encoded_image}
    start_time = time.time()

    async with aiohttp.ClientSession() as session:
        # Create semaphore to limit concurrent requests
        semaphore = asyncio.Semaphore(concurrent)

        async def bounded_request():
            async with semaphore:
                return await send_request(session, url, data)

        # Send all requests
        tasks = [bounded_request() for _ in range(num_requests)]
        results = await asyncio.gather(*tasks)

    total_time = time.time() - start_time

    # Analyze results
    successful = sum(1 for status, _ in results if status == 200)
    failed = num_requests - successful

    print(f"Load Test Results:")
    print(f"  Total Requests: {num_requests}")
    print(f"  Successful: {successful}")
    print(f"  Failed: {failed}")
    print(f"  Total Time: {total_time:.2f}s")
    print(f"  Requests/sec: {num_requests/total_time:.2f}")
    print(f"  Concurrent: {concurrent}")


if __name__ == "__main__":
    asyncio.run(load_test(num_requests=200, concurrent=20))

Requirements File
# requirements.txt
litserve>=0.2.0
torch>=1.9.0
torchvision>=0.10.0
Pillow>=8.0.0
requests>=2.25.0
numpy>=1.21.0
aiohttp>=3.8.0 # For async testing
pytest>=6.0.0  # For testing

Best Practices
- Model Optimization: Use quantization and TorchScript for production (a TorchScript export sketch follows this list)
- Batch Processing: Configure appropriate batch sizes based on your hardware
- Error Handling: Implement comprehensive error handling for robustness
- Monitoring: Add logging and metrics collection for production monitoring
- Security: Implement authentication and input validation for production APIs
- Caching: Consider caching frequently requested predictions
- Scaling: Use container orchestration for high-availability deployments
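As a sketch of the first point above, a MobileNetV2 checkpoint can be traced to TorchScript ahead of deployment. The file name mobilenet_v2_ts.pt is only an example, and loading the traced model inside LitAPI.setup() is one possible integration, not the only one:

# export_torchscript.py -- trace MobileNetV2 to TorchScript (illustrative sketch)
import torch
from torchvision import models

model = models.mobilenet_v2(pretrained=True)
model.eval()

# Trace with a dummy input that matches the serving preprocessing (1x3x224x224)
example_input = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, example_input)
traced.save("mobilenet_v2_ts.pt")

# In LitAPI.setup(), the traced model could then be loaded with:
#   self.model = torch.jit.load("mobilenet_v2_ts.pt", map_location=device)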
This guide provides a complete foundation for deploying MobileNetV2 with LitServe, from basic implementation to production-ready deployment with monitoring and testing.