LitServe with MobileNetV2 - Complete Code Guide
This guide demonstrates how to deploy a MobileNetV2 image classification model using LitServe for efficient, scalable inference.
Installation
# Install required packages
pip install litserve torch torchvision pillow requests
Basic Implementation
Simple MobileNetV2 API
# mobilenet_api.py
import base64
import io

import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import litserve as ls


class MobileNetV2API(ls.LitAPI):
    def setup(self, device):
        """Initialize the model and preprocessing pipeline"""
        # Load pre-trained MobileNetV2
        self.model = models.mobilenet_v2(pretrained=True)
        self.model.eval()
        self.model.to(device)

        # Define image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # ImageNet class labels (first 10 for brevity)
        self.class_labels = [
            "tench", "goldfish", "great white shark", "tiger shark",
            "hammerhead", "electric ray", "stingray", "cock",
            "hen", "ostrich"
            # ... add all 1000 ImageNet classes
        ]

    def decode_request(self, request):
        """Parse incoming request and prepare image"""
        if isinstance(request, dict) and "image" in request:
            # Handle base64 encoded image
            image_data = base64.b64decode(request["image"])
            image = Image.open(io.BytesIO(image_data)).convert('RGB')
        else:
            # Handle direct image upload
            image = Image.open(io.BytesIO(request)).convert('RGB')
        return image

    def predict(self, image):
        """Run inference on the preprocessed image"""
        # Preprocess image
        input_tensor = self.transform(image).unsqueeze(0)
        input_tensor = input_tensor.to(next(self.model.parameters()).device)

        # Run inference
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        return probabilities

    def encode_response(self, probabilities):
        """Format the response"""
        # Get top 5 predictions
        top5_prob, top5_indices = torch.topk(probabilities, 5)

        results = []
        for i in range(5):
            idx = top5_indices[i].item()
            prob = top5_prob[i].item()
            label = self.class_labels[idx] if idx < len(self.class_labels) else f"class_{idx}"
            results.append({
                "class": label,
                "confidence": round(prob, 4),
                "class_id": idx
            })

        return {
            "predictions": results,
            "model": "mobilenet_v2"
        }


if __name__ == "__main__":
    # Create and run the API
    api = MobileNetV2API()
    server = ls.LitServer(api, accelerator="auto", max_batch_size=4)
    server.run(port=8000)
Client Code for Testing
# client.py
import base64

import requests


def encode_image(image_path):
    """Encode image to base64"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def test_image_classification(image_path, server_url="http://localhost:8000/predict"):
    """Test the MobileNetV2 API"""
    # Method 1: Send as base64 in JSON
    encoded_image = encode_image(image_path)
    response = requests.post(
        server_url,
        json={"image": encoded_image}
    )

    if response.status_code == 200:
        result = response.json()
        print("Top 5 Predictions:")
        for pred in result["predictions"]:
            print(f"  {pred['class']}: {pred['confidence']:.4f}")
    else:
        print(f"Error: {response.status_code} - {response.text}")


def test_direct_upload(image_path, server_url="http://localhost:8000/predict"):
    """Test with direct image upload"""
    with open(image_path, 'rb') as f:
        files = {'file': f}
        response = requests.post(server_url, files=files)

    if response.status_code == 200:
        result = response.json()
        print("Predictions:", result)


if __name__ == "__main__":
    # Test with a sample image
    test_image_classification("sample_image.jpg")
Advanced Features
Batch Processing with Custom Batching
# advanced_mobilenet_api.py
import base64
import io
from typing import Any, List

import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import litserve as ls


class BatchedMobileNetV2API(ls.LitAPI):
    def setup(self, device):
        """Initialize model with batch processing capabilities"""
        self.model = models.mobilenet_v2(pretrained=True)
        self.model.eval()
        self.model.to(device)
        self.device = device

        # Optimized transform for batch processing
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Load class labels from file or define them
        self.load_imagenet_labels()

    def load_imagenet_labels(self):
        """Load ImageNet class labels"""
        # You can download from: https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
        try:
            with open('imagenet_classes.txt', 'r') as f:
                self.class_labels = [line.strip() for line in f.readlines()]
        except FileNotFoundError:
            # Fallback to generic class names
            self.class_labels = [f"class_{i}" for i in range(1000)]

    def decode_request(self, request):
        """Handle both single images and batch requests"""
        if isinstance(request, dict):
            if "images" in request:  # Batch request
                images = []
                for img_data in request["images"]:
                    if isinstance(img_data, str):  # base64
                        image_data = base64.b64decode(img_data)
                        image = Image.open(io.BytesIO(image_data)).convert('RGB')
                    else:  # direct bytes
                        image = Image.open(io.BytesIO(img_data)).convert('RGB')
                    images.append(image)
                return images
            elif "image" in request:  # Single request
                image_data = base64.b64decode(request["image"])
                image = Image.open(io.BytesIO(image_data)).convert('RGB')
                return [image]  # Wrap in list for consistent handling

        # Direct upload
        image = Image.open(io.BytesIO(request)).convert('RGB')
        return [image]

    def batch(self, inputs: List[Any]) -> List[Any]:
        """Custom batching logic"""
        # Flatten all images from all requests
        all_images = []
        batch_sizes = []

        for inp in inputs:
            if isinstance(inp, list):
                all_images.extend(inp)
                batch_sizes.append(len(inp))
            else:
                all_images.append(inp)
                batch_sizes.append(1)

        return all_images, batch_sizes

    def predict(self, batch_data):
        """Process batch of images efficiently"""
        images, batch_sizes = batch_data

        # Preprocess all images
        batch_tensor = torch.stack([
            self.transform(img) for img in images
        ]).to(self.device)

        # Run batch inference
        with torch.no_grad():
            outputs = self.model(batch_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        return probabilities, batch_sizes

    def unbatch(self, output):
        """Split batch results back to individual responses"""
        probabilities, batch_sizes = output
        results = []
        start_idx = 0

        for batch_size in batch_sizes:
            batch_probs = probabilities[start_idx:start_idx + batch_size]
            results.append(batch_probs)
            start_idx += batch_size

        return results

    def encode_response(self, probabilities):
        """Format response for batch or single predictions"""
        if len(probabilities.shape) == 1:  # Single prediction
            probabilities = probabilities.unsqueeze(0)

        all_results = []
        for prob_vector in probabilities:
            top5_prob, top5_indices = torch.topk(prob_vector, 5)

            predictions = []
            for i in range(5):
                idx = top5_indices[i].item()
                prob = top5_prob[i].item()
                predictions.append({
                    "class": self.class_labels[idx],
                    "confidence": round(prob, 4),
                    "class_id": idx
                })
            all_results.append(predictions)

        return {
            "predictions": all_results[0] if len(all_results) == 1 else all_results,
            "model": "mobilenet_v2",
            "batch_size": len(all_results)
        }


if __name__ == "__main__":
    api = BatchedMobileNetV2API()
    server = ls.LitServer(
        api,
        accelerator="auto",
        max_batch_size=8,
        batch_timeout=0.1,  # 100ms timeout for batching
    )
    server.run(port=8000, num_workers=2)
Adding Model Quantization for Better Performance
# quantized_mobilenet_api.py
import torch
import torch.quantization
import torchvision.transforms as transforms
from torchvision import models
import litserve as ls


class QuantizedMobileNetV2API(ls.LitAPI):
    def setup(self, device):
        """Setup quantized MobileNetV2 for faster inference"""
        # Load pre-trained model
        self.model = models.mobilenet_v2(pretrained=True)
        self.model.eval()

        # Apply dynamic quantization for CPU inference
        # (note: dynamic quantization primarily affects Linear layers)
        if device == "cpu":
            self.model = torch.quantization.quantize_dynamic(
                self.model,
                {torch.nn.Linear, torch.nn.Conv2d},
                dtype=torch.qint8
            )
            print("Applied dynamic quantization for CPU")
        else:
            self.model.to(device)
            print(f"Using device: {device}")

        # Rest of setup code...
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    # ... rest of the methods remain the same
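Beyond quantization, the TorchScript option mentioned under Best Practices can also trim Python overhead at inference time. The sketch below is a minimal illustration (the file names are placeholders, and tracing assumes the fixed 224x224 input used throughout this guide); setup() could then load the traced module instead of the eager model.
# torchscript_export.py - optional sketch; file names are illustrative
import torch
from torchvision import models

model = models.mobilenet_v2(pretrained=True).eval()

# Trace with a dummy batch matching the 224x224 preprocessing used above
example = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, example)
traced.save("mobilenet_v2_traced.pt")

# In setup(), the traced module could be loaded instead of the eager model:
# self.model = torch.jit.load("mobilenet_v2_traced.pt", map_location=device)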
Performance Optimization
Configuration for Production
# production_config.py
import litserve as ls
from advanced_mobilenet_api import BatchedMobileNetV2API


def create_production_server():
    """Create an optimized server for production"""
    api = BatchedMobileNetV2API()

    server = ls.LitServer(
        api,
        accelerator="auto",     # Auto-detect GPU/CPU
        max_batch_size=16,      # Larger batches for throughput
        batch_timeout=0.05,     # 50ms batching timeout
        workers_per_device=2,   # Multiple workers per GPU
        timeout=30,             # Request timeout
    )

    return server


if __name__ == "__main__":
    server = create_production_server()

    server.run(
        port=8000,
        host="0.0.0.0",  # Accept external connections
        num_workers=4    # Total number of workers
    )
Monitoring and Logging
# monitored_api.py
import logging
import time
from collections import defaultdict

from advanced_mobilenet_api import BatchedMobileNetV2API


class MonitoredMobileNetV2API(BatchedMobileNetV2API):
    def setup(self, device):
        """Setup with monitoring"""
        super().setup(device)

        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Metrics tracking
        self.metrics = defaultdict(list)
        self.request_count = 0

    def predict(self, batch_data):
        """Predict with timing and metrics"""
        start_time = time.time()

        result = super().predict(batch_data)

        # Log timing
        inference_time = time.time() - start_time
        batch_size = len(batch_data[0])

        self.metrics['inference_times'].append(inference_time)
        self.metrics['batch_sizes'].append(batch_size)
        self.request_count += 1

        self.logger.info(
            f"Processed batch of {batch_size} images in {inference_time:.3f}s "
            f"(Total requests: {self.request_count})"
        )

        return result

    def encode_response(self, probabilities):
        """Add metrics to response"""
        response = super().encode_response(probabilities)

        # Add performance metrics every 100 requests
        if self.request_count % 100 == 0:
            avg_time = sum(self.metrics['inference_times'][-100:]) / min(100, len(self.metrics['inference_times']))
            response['metrics'] = {
                'avg_inference_time': round(avg_time, 4),
                'total_requests': self.request_count
            }

        return response
Deployment
Docker Deployment
# Dockerfile
FROM python:3.9-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
    wget \
    && rm -rf /var/lib/apt/lists/*
# Copy requirements
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Download ImageNet labels
RUN wget -O imagenet_classes.txt https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
EXPOSE 8000
CMD ["python", "production_config.py"]
# docker-compose.yml
version: '3.8'
services:
mobilenet-api:
build: .
ports:
- "8000:8000"
environment:
- CUDA_VISIBLE_DEVICES=0 # Set GPU if available
volumes:
- ./models:/app/models # Optional: for custom models
restart: unless-stopped
# Optional: Add nginx for load balancing
nginx:
image: nginx:alpine
ports:
- "80:80"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf
depends_on:
- mobilenet-api
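The compose file mounts a ./nginx.conf that is not shown above. A minimal sketch of such a config might look like the following; the upstream name and proxy settings are illustrative rather than tuned values.
# nginx.conf - minimal sketch of the load-balancing config mounted in docker-compose.yml
events {}

http {
    upstream mobilenet_backend {
        # "mobilenet-api" resolves via the compose service name
        server mobilenet-api:8000;
    }

    server {
        listen 80;

        location / {
            proxy_pass http://mobilenet_backend;
            proxy_set_header Host $host;
            proxy_read_timeout 60s;
        }
    }
}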
Kubernetes Deployment
# k8s-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: mobilenet-api
spec:
replicas: 3
selector:
matchLabels:
app: mobilenet-api
template:
metadata:
labels:
app: mobilenet-api
spec:
containers:
- name: mobilenet-api
image: your-registry/mobilenet-api:latest
ports:
- containerPort: 8000
resources:
requests:
memory: "1Gi"
cpu: "500m"
limits:
memory: "2Gi"
cpu: "1000m"
env:
- name: WORKERS
value: "2"
---
apiVersion: v1
kind: Service
metadata:
name: mobilenet-service
spec:
selector:
app: mobilenet-api
ports:
- port: 80
targetPort: 8000
type: LoadBalancer
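Optionally, liveness and readiness probes can be added under the container spec in the Deployment above. The fragment below assumes the server exposes a GET /health route (recent LitServe versions provide one); verify the path and timings for the version you deploy.
# Optional probe fragment - merge under the mobilenet-api container spec above.
# Assumes the server exposes GET /health; adjust if your LitServe version differs.
livenessProbe:
  httpGet:
    path: /health
    port: 8000
  initialDelaySeconds: 30
  periodSeconds: 15
readinessProbe:
  httpGet:
    path: /health
    port: 8000
  initialDelaySeconds: 10
  periodSeconds: 5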
Testing
Comprehensive Test Suite
# test_api.py
import base64
import io

import pytest
import requests
from PIL import Image


class TestMobileNetV2API:
    @classmethod
    def setup_class(cls):
        """Setup test configuration"""
        cls.base_url = "http://localhost:8000"
        cls.test_image = cls.create_test_image()

    @staticmethod
    def create_test_image():
        """Create a simple in-memory test image"""
        img = Image.new('RGB', (224, 224), color='red')
        img_buffer = io.BytesIO()
        img.save(img_buffer, format='JPEG')
        img_buffer.seek(0)
        return img_buffer.getvalue()

    def test_single_prediction(self):
        """Test single image prediction"""
        encoded_image = base64.b64encode(self.test_image).decode('utf-8')

        response = requests.post(
            f"{self.base_url}/predict",
            json={"image": encoded_image}
        )

        assert response.status_code == 200
        result = response.json()
        assert "predictions" in result
        assert len(result["predictions"]) == 5
        assert all("class" in pred and "confidence" in pred for pred in result["predictions"])

    def test_batch_prediction(self):
        """Test batch prediction"""
        encoded_image = base64.b64encode(self.test_image).decode('utf-8')

        response = requests.post(
            f"{self.base_url}/predict",
            json={"images": [encoded_image, encoded_image]}
        )

        assert response.status_code == 200
        result = response.json()
        assert "batch_size" in result
        assert result["batch_size"] == 2

    def test_performance(self):
        """Test API performance"""
        import time
        encoded_image = base64.b64encode(self.test_image).decode('utf-8')

        # Warmup request
        requests.post(f"{self.base_url}/predict", json={"image": encoded_image})

        # Time multiple requests
        times = []
        for _ in range(10):
            start = time.time()
            response = requests.post(f"{self.base_url}/predict", json={"image": encoded_image})
            times.append(time.time() - start)
            assert response.status_code == 200

        avg_time = sum(times) / len(times)
        print(f"Average response time: {avg_time:.3f}s")
        assert avg_time < 1.0  # Should respond within 1 second


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
Load Testing
# load_test.py
import asyncio
import base64
import io
import time

import aiohttp
from PIL import Image


async def send_request(session, url, data):
    """Send a single request"""
    try:
        async with session.post(url, json=data) as response:
            await response.json()
            return response.status, time.time()
    except Exception:
        return 500, time.time()


async def load_test(num_requests=100, concurrent=10):
    """Run load test"""
    # Create test image
    img = Image.new('RGB', (224, 224), color='blue')
    img_buffer = io.BytesIO()
    img.save(img_buffer, format='JPEG')
    encoded_image = base64.b64encode(img_buffer.getvalue()).decode('utf-8')

    url = "http://localhost:8000/predict"
    data = {"image": encoded_image}

    start_time = time.time()

    async with aiohttp.ClientSession() as session:
        # Create semaphore to limit concurrent requests
        semaphore = asyncio.Semaphore(concurrent)

        async def bounded_request():
            async with semaphore:
                return await send_request(session, url, data)

        # Send all requests
        tasks = [bounded_request() for _ in range(num_requests)]
        results = await asyncio.gather(*tasks)

    total_time = time.time() - start_time

    # Analyze results
    successful = sum(1 for status, _ in results if status == 200)
    failed = num_requests - successful

    print("Load Test Results:")
    print(f"  Total Requests: {num_requests}")
    print(f"  Successful: {successful}")
    print(f"  Failed: {failed}")
    print(f"  Total Time: {total_time:.2f}s")
    print(f"  Requests/sec: {num_requests/total_time:.2f}")
    print(f"  Concurrent: {concurrent}")


if __name__ == "__main__":
    asyncio.run(load_test(num_requests=200, concurrent=20))
Requirements File
# requirements.txt
litserve>=0.2.0
torch>=1.9.0
torchvision>=0.10.0
Pillow>=8.0.0
requests>=2.25.0
numpy>=1.21.0
aiohttp>=3.8.0  # For async load testing
pytest>=6.0.0   # For testing
Best Practices
- Model Optimization: Use quantization and TorchScript for production
- Batch Processing: Configure appropriate batch sizes based on your hardware
- Error Handling: Implement comprehensive error handling for robustness
- Monitoring: Add logging and metrics collection for production monitoring
- Security: Implement authentication and input validation for production APIs (a sketch of defensive request decoding follows this list)
- Caching: Consider caching frequently requested predictions
- Scaling: Use container orchestration for high-availability deployments
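As one concrete illustration of the error-handling and security points above, here is a hedged sketch of defensive request decoding. The helper name, size limit, and use of FastAPI's HTTPException are illustrative choices (LitServe runs on FastAPI, so raising HTTPException is one convenient way to return a clean 4xx from decode_request); adapt them to your own error-handling scheme.
# validated_decode.py - hedged sketch of defensive request decoding
import base64
import io

from fastapi import HTTPException
from PIL import Image

MAX_IMAGE_BYTES = 10 * 1024 * 1024  # illustrative limit: reject payloads over ~10 MB


def decode_base64_image(encoded: str) -> Image.Image:
    """Decode and validate a base64-encoded image, failing with a clear 4xx error."""
    try:
        raw = base64.b64decode(encoded, validate=True)
    except (ValueError, TypeError):
        raise HTTPException(status_code=400, detail="Invalid base64 image data")

    if len(raw) > MAX_IMAGE_BYTES:
        raise HTTPException(status_code=413, detail="Image payload too large")

    try:
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Payload is not a decodable image")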
This guide provides a complete foundation for deploying MobileNetV2 with LitServe, from basic implementation to production-ready deployment with monitoring and testing.