This guide demonstrates how to deploy a MobileNetV2 image classification model using LitServe for efficient, scalable inference.

Installation

# Install required packages
pip install litserve torch torchvision pillow requests

Basic Implementation

Simple MobileNetV2 API

# mobilenet_api.py
import io
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import litserve as ls


class MobileNetV2API(ls.LitAPI):
    """LitServe API that classifies a single image per request with MobileNetV2.

    ``predict`` transforms exactly one PIL image per call, so this API is
    meant to be served one request at a time (no ``batch``/``unbatch`` hooks
    are defined here).
    """

    # Number of highest-probability classes reported in each response.
    TOP_K = 5

    def setup(self, device):
        """Load the model and build the preprocessing pipeline.

        Args:
            device: Device identifier handed in by the server; the model is
                moved onto it.
        """
        # Load pre-trained MobileNetV2 and switch it to inference mode
        self.model = models.mobilenet_v2(pretrained=True)
        self.model.eval()
        self.model.to(device)

        # Resize -> center crop -> tensor -> per-channel normalisation
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # ImageNet class labels (first 10 for brevity); indices beyond this
        # list are reported as "class_<idx>" by encode_response().
        self.class_labels = [
            "tench", "goldfish", "great white shark", "tiger shark",
            "hammerhead", "electric ray", "stingray", "cock",
            "hen", "ostrich"
            # ... add all 1000 ImageNet classes
        ]

    def decode_request(self, request):
        """Turn an incoming request payload into an RGB PIL image.

        Accepted payloads:
            * a mapping with an ``"image"`` entry holding a base64 string;
            * a mapping with a ``"file"`` entry holding an uploaded file
              object (NOTE(review): presumably how a multipart upload sent
              with ``files={'file': ...}`` arrives -- confirm against the
              LitServe version in use);
            * raw encoded image bytes.

        Raises:
            ValueError: If the payload has none of these shapes or the
                ``"image"`` entry is not valid base64.
        """
        import base64
        from collections.abc import Mapping

        if isinstance(request, Mapping):
            if "image" in request:
                try:
                    raw = base64.b64decode(request["image"])
                except (TypeError, ValueError) as err:
                    raise ValueError(
                        "'image' must be a base64-encoded string"
                    ) from err
            elif "file" in request:
                upload = request["file"]
                # Upload wrappers expose the underlying synchronous stream
                # as `.file`; plain file-like objects are read directly.
                stream = getattr(upload, "file", upload)
                raw = stream.read()
            else:
                raise ValueError(
                    "request must contain an 'image' or 'file' field"
                )
        elif isinstance(request, (bytes, bytearray, memoryview)):
            raw = bytes(request)
        else:
            raise ValueError(
                f"unsupported request payload type: {type(request).__name__}"
            )

        return Image.open(io.BytesIO(raw)).convert('RGB')

    def predict(self, image):
        """Run the model on one PIL image.

        Returns:
            A 1-D tensor of softmax probabilities, one entry per class.
        """
        # Add the batch dimension the model expects and move the input to
        # wherever the model's parameters live.
        input_tensor = self.transform(image).unsqueeze(0)
        input_tensor = input_tensor.to(next(self.model.parameters()).device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            # outputs[0] drops the batch dimension again
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        return probabilities

    def encode_response(self, probabilities):
        """Format the top predictions as a JSON-serialisable dict.

        Args:
            probabilities: 1-D tensor of class probabilities.

        Returns:
            ``{"predictions": [...], "model": "mobilenet_v2"}`` where each
            prediction holds ``class``, ``confidence`` and ``class_id``.
        """
        # Never ask topk for more entries than the vector holds
        k = min(self.TOP_K, probabilities.numel())
        top_prob, top_indices = torch.topk(probabilities, k)

        results = []
        for prob_t, idx_t in zip(top_prob, top_indices):
            idx = idx_t.item()
            if idx < len(self.class_labels):
                label = self.class_labels[idx]
            else:
                label = f"class_{idx}"
            results.append({
                "class": label,
                "confidence": round(prob_t.item(), 4),
                "class_id": idx
            })

        return {
            "predictions": results,
            "model": "mobilenet_v2"
        }


if __name__ == "__main__":
    # MobileNetV2API.predict() transforms exactly one image per call and the
    # class defines no batch()/unbatch() hooks, so the server is started
    # without request batching.
    api = MobileNetV2API()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)

Client Code for Testing

# client.py
import requests
import base64
from PIL import Image
import io


def encode_image(image_path):
    """Return the contents of the file at *image_path* as base64 text."""
    with open(image_path, "rb") as source:
        raw_bytes = source.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')


def test_image_classification(image_path, server_url="http://localhost:8000/predict", timeout=30):
    """Send an image as base64 JSON and print the returned predictions.

    Args:
        image_path: Path of the image file to classify.
        server_url: Full URL of the prediction endpoint.
        timeout: Seconds to wait for the server before giving up; without
            it the request could block indefinitely.
    """
    # Method 1: Send as base64 in JSON
    encoded_image = encode_image(image_path)
    response = requests.post(
        server_url,
        json={"image": encoded_image},
        timeout=timeout,
    )

    if response.status_code != 200:
        print(f"Error: {response.status_code} - {response.text}")
        return

    result = response.json()
    print("Top 5 Predictions:")
    for pred in result["predictions"]:
        print(f"  {pred['class']}: {pred['confidence']:.4f}")


def test_direct_upload(image_path, server_url="http://localhost:8000/predict", timeout=30):
    """Send an image as a multipart upload and print the raw response.

    Args:
        image_path: Path of the image file to upload.
        server_url: Full URL of the prediction endpoint.
        timeout: Seconds to wait for the server before giving up.
    """
    # The file is posted as the multipart form field named 'file'
    with open(image_path, 'rb') as f:
        files = {'file': f}
        response = requests.post(server_url, files=files, timeout=timeout)

    if response.status_code == 200:
        result = response.json()
        print("Predictions:", result)
    else:
        # Report failures the same way test_image_classification does
        # instead of returning silently.
        print(f"Error: {response.status_code} - {response.text}")


if __name__ == "__main__":
    # Classify a sample image through the base64/JSON request path
    sample_path = "sample_image.jpg"
    test_image_classification(sample_path)

Advanced Features

Batch Processing with Custom Batching

# advanced_mobilenet_api.py
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import litserve as ls
import json
from typing import List, Any


class BatchedMobileNetV2API(ls.LitAPI):
    """LitServe API that classifies one or several images per request.

    ``decode_request`` always yields a list of images. ``batch`` flattens the
    lists of all requests into one model call, and ``unbatch`` splits the
    resulting probabilities back into one slice per request.
    """

    def setup(self, device):
        """Load the model, the preprocessing pipeline and the class labels.

        Args:
            device: Device identifier handed in by the server; the model and
                every input batch are moved onto it.
        """
        self.model = models.mobilenet_v2(pretrained=True)
        self.model.eval()
        self.model.to(device)
        self.device = device

        # Every image goes through the same pipeline so the resulting
        # tensors share one shape and can be stacked in predict().
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Load class labels from file or define them
        self.load_imagenet_labels()

    def load_imagenet_labels(self):
        """Load ImageNet class labels into ``self.class_labels``.

        Reads one label per line from ``imagenet_classes.txt`` in the working
        directory; if the file is missing, 1000 placeholder names
        (``class_0`` ... ``class_999``) are used instead.
        """
        # You can download from: https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
        try:
            with open('imagenet_classes.txt', 'r', encoding='utf-8') as f:
                self.class_labels = [line.strip() for line in f]
        except FileNotFoundError:
            # No label file available: fall back to placeholder names
            self.class_labels = [f"class_{i}" for i in range(1000)]

    @staticmethod
    def _to_rgb_image(raw_bytes):
        """Decode encoded image bytes into an RGB PIL image."""
        # Imported locally because `io` is not among this module's imports.
        import io

        return Image.open(io.BytesIO(raw_bytes)).convert('RGB')

    def _label_for(self, class_id):
        """Return the label for *class_id*, or a placeholder if unknown."""
        if class_id < len(self.class_labels):
            return self.class_labels[class_id]
        return f"class_{class_id}"

    def decode_request(self, request):
        """Turn a request payload into a list of RGB PIL images.

        Accepted payloads:
            * ``{"images": [...]}`` -- each item a base64 string or raw bytes;
            * ``{"image": "<base64>"}`` -- a single image;
            * raw encoded image bytes.

        Returns:
            A list of images (length 1 for single-image payloads).

        Raises:
            ValueError: If a dict payload has neither key, or ``"images"``
                is not a non-empty list.
        """
        import base64

        if isinstance(request, dict):
            if "images" in request:  # Batch request
                items = request["images"]
                if not isinstance(items, (list, tuple)) or not items:
                    raise ValueError("'images' must be a non-empty list")
                return [
                    self._to_rgb_image(
                        base64.b64decode(item) if isinstance(item, str) else item
                    )
                    for item in items
                ]
            if "image" in request:  # Single request
                # Wrapped in a list for consistent handling downstream
                return [self._to_rgb_image(base64.b64decode(request["image"]))]
            raise ValueError("request must contain an 'image' or 'images' field")

        # Direct upload
        return [self._to_rgb_image(request)]

    def batch(self, inputs: List[Any]):
        """Flatten the images of all requests into one list.

        Args:
            inputs: One entry per request; normally the list produced by
                ``decode_request``, but a bare image is accepted as well.

        Returns:
            A tuple ``(all_images, batch_sizes)`` where ``batch_sizes[i]`` is
            the number of images contributed by request ``i``.
        """
        all_images = []
        batch_sizes = []

        for inp in inputs:
            if isinstance(inp, list):
                all_images.extend(inp)
                batch_sizes.append(len(inp))
            else:
                all_images.append(inp)
                batch_sizes.append(1)

        return all_images, batch_sizes

    def predict(self, batch_data):
        """Run the model once on every image of the flattened batch.

        Args:
            batch_data: The ``(all_images, batch_sizes)`` tuple from ``batch``.

        Returns:
            A tuple ``(probabilities, batch_sizes)``; ``probabilities`` has one
            row of softmax scores per image, in input order.
        """
        images, batch_sizes = batch_data

        # Preprocess all images and stack them along a new leading dimension
        batch_tensor = torch.stack([
            self.transform(img) for img in images
        ]).to(self.device)

        with torch.no_grad():
            outputs = self.model(batch_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        # batch_sizes is passed through so unbatch() can split the rows
        return probabilities, batch_sizes

    def unbatch(self, output):
        """Split the stacked probabilities back into one slice per request.

        Args:
            output: The ``(probabilities, batch_sizes)`` tuple from ``predict``.

        Returns:
            A list with one tensor per request, holding that request's rows.
        """
        probabilities, batch_sizes = output
        results = []
        start_idx = 0

        for batch_size in batch_sizes:
            results.append(probabilities[start_idx:start_idx + batch_size])
            start_idx += batch_size

        return results

    def encode_response(self, probabilities):
        """Format one request's probabilities as a JSON-serialisable dict.

        Args:
            probabilities: Tensor with one row per image of the request
                (a 1-D tensor is treated as a single row).

        Returns:
            A dict with ``predictions`` (a flat list when the request held
            one image, otherwise a list of lists), ``model`` and
            ``batch_size``.
        """
        if len(probabilities.shape) == 1:  # Single prediction
            probabilities = probabilities.unsqueeze(0)

        # Never ask topk for more entries than a row holds
        top_k = min(5, probabilities.shape[1])

        all_results = []
        for prob_vector in probabilities:
            top_prob, top_indices = torch.topk(prob_vector, top_k)

            predictions = []
            for prob_t, idx_t in zip(top_prob, top_indices):
                idx = idx_t.item()
                predictions.append({
                    # Placeholder label if the labels file is shorter than
                    # the model's output, instead of an IndexError
                    "class": self._label_for(idx),
                    "confidence": round(prob_t.item(), 4),
                    "class_id": idx
                })

            all_results.append(predictions)

        return {
            "predictions": all_results[0] if len(all_results) == 1 else all_results,
            "model": "mobilenet_v2",
            "batch_size": len(all_results)
        }


if __name__ == "__main__":
    api = BatchedMobileNetV2API()
    server = ls.LitServer(
        api,
        accelerator="auto",
        max_batch_size=8,
        batch_timeout=0.1,  # 100ms timeout for batching
    )
    # LitServer.run() takes the number of API server processes as
    # `num_api_servers`; `num_workers` is not one of its options.
    server.run(port=8000, num_api_servers=2)

Adding Model Quantization for Better Performance

# quantized_mobilenet_api.py
import torch
import torch.quantization
from torchvision import models
import litserve as ls


class QuantizedMobileNetV2API(ls.LitAPI):
    """MobileNetV2 API that applies dynamic quantization when run on CPU."""

    def setup(self, device):
        """Load the model, quantize it on CPU, and build the transform.

        Args:
            device: Device identifier handed in by the server. The exact
                string ``"cpu"`` selects the quantized path; any other value
                is used as the target device for the float model.
        """
        # `transforms` is not among this module's imports, so it is brought
        # into scope here.
        import torchvision.transforms as transforms

        # Load pre-trained model
        self.model = models.mobilenet_v2(pretrained=True)
        self.model.eval()

        # Apply dynamic quantization for CPU inference
        if device == "cpu":
            # Only Linear is listed: dynamic quantization has no Conv2d
            # counterpart, so naming Conv2d here would leave the convolutions
            # unquantized while suggesting otherwise.
            self.model = torch.quantization.quantize_dynamic(
                self.model,
                {torch.nn.Linear},
                dtype=torch.qint8
            )
            print("Applied dynamic quantization for CPU")
        else:
            self.model.to(device)
            print(f"Using device: {device}")

        # Rest of setup code...
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    # ... rest of the methods remain the same

Performance Optimization

Configuration for Production

# production_config.py
import litserve as ls
from advanced_mobilenet_api import BatchedMobileNetV2API

def create_production_server():
    """Build the LitServer instance used for production serving."""
    server_options = {
        "accelerator": "auto",    # Auto-detect GPU/CPU
        "max_batch_size": 16,     # Larger batches for throughput
        "batch_timeout": 0.05,    # 50ms batching timeout
        "workers_per_device": 2,  # Multiple workers per GPU
        "timeout": 30,            # Request timeout
    }
    return ls.LitServer(BatchedMobileNetV2API(), **server_options)

if __name__ == "__main__":
    server = create_production_server()
    # LitServer.run() takes the number of API server processes as
    # `num_api_servers`; `num_workers` is not one of its options.
    server.run(
        port=8000,
        host="0.0.0.0",  # Accept external connections
        num_api_servers=4,
    )

Monitoring and Logging

# monitored_api.py
import logging
import time
from collections import defaultdict

import litserve as ls

from advanced_mobilenet_api import BatchedMobileNetV2API


class MonitoredMobileNetV2API(BatchedMobileNetV2API):
    """Batched MobileNetV2 API that logs and reports inference timings."""

    # Number of most recent samples kept per metric and used for the average.
    METRICS_WINDOW = 100

    def setup(self, device):
        """Run the parent setup, then initialise logging and metric storage."""
        super().setup(device)

        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Metrics tracking: metric name -> list of recent samples
        self.metrics = defaultdict(list)
        # Counts predict() calls, i.e. processed batches
        self.request_count = 0

    def _record(self, name, value):
        """Append *value* to metric *name*, keeping only the recent window."""
        samples = self.metrics[name]
        samples.append(value)
        # Drop samples older than the window so the lists cannot grow
        # without bound on a long-running server (no-op while short).
        del samples[:-self.METRICS_WINDOW]

    def predict(self, batch_data):
        """Run the parent prediction while timing and logging it.

        Args:
            batch_data: The ``(images, batch_sizes)`` tuple expected by the
                parent ``predict``.

        Returns:
            Whatever the parent ``predict`` returns, unchanged.
        """
        # perf_counter is monotonic, so the measured duration is not
        # affected by system clock adjustments.
        start_time = time.perf_counter()

        result = super().predict(batch_data)

        inference_time = time.perf_counter() - start_time
        batch_size = len(batch_data[0])

        self._record('inference_times', inference_time)
        self._record('batch_sizes', batch_size)
        self.request_count += 1

        self.logger.info(
            "Processed batch of %d images in %.3fs (Total requests: %d)",
            batch_size,
            inference_time,
            self.request_count,
        )

        return result

    def encode_response(self, probabilities):
        """Build the parent response and periodically attach metrics.

        A ``metrics`` entry (average recent inference time and the running
        count) is added whenever the count is a positive multiple of 100.
        """
        response = super().encode_response(probabilities)

        recent = self.metrics['inference_times'][-self.METRICS_WINDOW:]
        # `recent` is checked so the average can never divide by zero
        if recent and self.request_count % 100 == 0:
            avg_time = sum(recent) / len(recent)
            response['metrics'] = {
                'avg_inference_time': round(avg_time, 4),
                'total_requests': self.request_count
            }

        return response

Deployment

Docker Deployment

# Dockerfile
# Base image: slim variant of the official Python 3.9 image
FROM python:3.9-slim

# All following relative paths resolve against /app inside the image
WORKDIR /app

# Install system dependencies
# wget is used by the label download step below; the apt package lists are
# removed in the same RUN step so they do not remain in the layer.
RUN apt-get update && apt-get install -y \
    wget \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements
# Copied and installed before the application code so that this layer only
# has to be rebuilt when requirements.txt changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Download ImageNet labels
# Runs at build time (needs network access) and writes imagenet_classes.txt
# into /app, overwriting any copy brought in by the COPY above.
RUN wget -O imagenet_classes.txt https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

# Port the server in production_config.py listens on
EXPOSE 8000

# Start the production server
CMD ["python", "production_config.py"]

# docker-compose.yml
# NOTE(review): recent Docker Compose releases treat the top-level `version`
# key as obsolete -- confirm whether it can be dropped.
version: '3.8'

services:
  # Inference service built from the Dockerfile in this directory
  mobilenet-api:
    build: .
    ports:
      - "8000:8000"
    environment:
      # NOTE(review): this only selects among GPUs already visible inside the
      # container; this file does not itself request a GPU device.
      - CUDA_VISIBLE_DEVICES=0  # Set GPU if available
    volumes:
      - ./models:/app/models  # Optional: for custom models
    restart: unless-stopped
    
  # Optional: Add nginx for load balancing
  # Expects an nginx.conf next to this file; none is shown in this guide.
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - mobilenet-api

Kubernetes Deployment

# k8s-deployment.yaml
# Deployment: runs three replicas of the inference container
apiVersion: apps/v1
kind: Deployment
metadata:
  name: mobilenet-api
spec:
  replicas: 3
  selector:
    matchLabels:
      app: mobilenet-api
  template:
    metadata:
      labels:
        # Must match spec.selector.matchLabels above and the Service selector below
        app: mobilenet-api
    spec:
      containers:
      - name: mobilenet-api
        # Placeholder image reference -- replace with your own registry path
        image: your-registry/mobilenet-api:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "1Gi"
            cpu: "500m"
          limits:
            memory: "2Gi"
            cpu: "1000m"
        env:
        # NOTE(review): none of the Python code shown in this guide reads
        # WORKERS -- confirm it is actually consumed by the image.
        - name: WORKERS
          value: "2"
---
# Service: exposes the pods on port 80 and forwards to container port 8000
apiVersion: v1
kind: Service
metadata:
  name: mobilenet-service
spec:
  selector:
    app: mobilenet-api
  ports:
  - port: 80
    targetPort: 8000
  type: LoadBalancer

Testing

Comprehensive Test Suite

# test_api.py
import pytest
import requests
import base64
import numpy as np
from PIL import Image
import io


class TestMobileNetV2API:
    """Integration tests against a server listening on localhost:8000."""

    # Seconds to wait for any single HTTP request before failing the test.
    REQUEST_TIMEOUT = 30

    @classmethod
    def setup_class(cls):
        """Set up shared test configuration once for the whole class.

        pytest calls this hook with the class itself, so it has to be a
        classmethod; the attributes set here are then readable through
        ``self`` in every test.
        """
        cls.base_url = "http://localhost:8000"
        cls.test_image = cls.create_test_image()

    @staticmethod
    def create_test_image():
        """Return the JPEG bytes of a solid red 224x224 image."""
        img = Image.new('RGB', (224, 224), color='red')
        img_buffer = io.BytesIO()
        img.save(img_buffer, format='JPEG')
        return img_buffer.getvalue()

    def _post(self, payload):
        """POST *payload* as JSON to the /predict endpoint."""
        return requests.post(
            f"{self.base_url}/predict",
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )

    def test_single_prediction(self):
        """A single image yields five predictions with class and confidence."""
        encoded_image = base64.b64encode(self.test_image).decode('utf-8')

        response = self._post({"image": encoded_image})

        assert response.status_code == 200
        result = response.json()
        assert "predictions" in result
        assert len(result["predictions"]) == 5
        assert all("class" in pred and "confidence" in pred for pred in result["predictions"])

    def test_batch_prediction(self):
        """Two images in one request are reported with batch_size == 2.

        Relies on the server accepting an ``"images"`` list and returning a
        ``"batch_size"`` field.
        """
        encoded_image = base64.b64encode(self.test_image).decode('utf-8')

        response = self._post({"images": [encoded_image, encoded_image]})

        assert response.status_code == 200
        result = response.json()
        assert "batch_size" in result
        assert result["batch_size"] == 2

    def test_performance(self):
        """Ten sequential requests average under one second each."""
        import time

        encoded_image = base64.b64encode(self.test_image).decode('utf-8')
        payload = {"image": encoded_image}

        # Warmup request, excluded from the timing
        self._post(payload)

        times = []
        for _ in range(10):
            # perf_counter is monotonic, unlike the wall clock
            start = time.perf_counter()
            response = self._post(payload)
            times.append(time.perf_counter() - start)
            assert response.status_code == 200

        avg_time = sum(times) / len(times)
        print(f"Average response time: {avg_time:.3f}s")
        assert avg_time < 1.0  # Should respond within 1 second


if __name__ == "__main__":
    # Propagate pytest's return code as the process exit status so that
    # running this file directly fails when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))

Load Testing

# load_test.py
import asyncio
import aiohttp
import base64
import time
from PIL import Image
import io


async def send_request(session, url, data):
    """POST *data* as JSON to *url* and report how the request ended.

    Args:
        session: Client session object providing ``post(url, json=...)`` as
            an async context manager.
        url: Endpoint to call.
        data: JSON-serialisable request body.

    Returns:
        A tuple ``(status, finished_at)``: the HTTP status code of the
        response (or 500 if the request failed at the transport level) and
        the ``time.time()`` timestamp taken when the request finished.
    """
    try:
        async with session.post(url, json=data) as response:
            # Drain the body so the request has fully completed before it is
            # timestamped; the payload itself is not needed, and reading raw
            # bytes keeps a non-JSON body from being counted as a failure.
            await response.read()
            return response.status, time.time()
    except (aiohttp.ClientError, asyncio.TimeoutError):
        # Connection or timeout failure: reported as a 500 so the caller
        # counts the request as unsuccessful.
        return 500, time.time()


async def load_test(num_requests=100, concurrent=10, url="http://localhost:8000/predict"):
    """Fire *num_requests* prediction requests and print a summary.

    Args:
        num_requests: Total number of requests to send.
        concurrent: Maximum number of requests in flight at once.
        url: Prediction endpoint to hit.

    Raises:
        ValueError: If *concurrent* is smaller than 1 (a semaphore of size 0
            would block every request forever).
    """
    if concurrent < 1:
        raise ValueError("concurrent must be at least 1")

    # Create test image: a solid blue 224x224 JPEG, base64-encoded
    img = Image.new('RGB', (224, 224), color='blue')
    img_buffer = io.BytesIO()
    img.save(img_buffer, format='JPEG')
    encoded_image = base64.b64encode(img_buffer.getvalue()).decode('utf-8')

    data = {"image": encoded_image}

    # perf_counter is monotonic, so the elapsed time is not affected by
    # system clock adjustments.
    start_time = time.perf_counter()

    async with aiohttp.ClientSession() as session:
        # Semaphore caps the number of requests in flight
        semaphore = asyncio.Semaphore(concurrent)

        async def bounded_request():
            async with semaphore:
                return await send_request(session, url, data)

        # Schedule every request; the semaphore throttles actual execution
        tasks = [bounded_request() for _ in range(num_requests)]
        results = await asyncio.gather(*tasks)

    total_time = time.perf_counter() - start_time

    # Analyze results
    successful = sum(1 for status, _ in results if status == 200)
    failed = num_requests - successful
    # Guard against a zero elapsed time (e.g. when no requests were sent)
    rate = num_requests / total_time if total_time > 0 else 0.0

    print("Load Test Results:")
    print(f"  Total Requests: {num_requests}")
    print(f"  Successful: {successful}")
    print(f"  Failed: {failed}")
    print(f"  Total Time: {total_time:.2f}s")
    print(f"  Requests/sec: {rate:.2f}")
    print(f"  Concurrent: {concurrent}")


if __name__ == "__main__":
    # Run 200 requests with at most 20 in flight at a time
    load_test_run = load_test(num_requests=200, concurrent=20)
    asyncio.run(load_test_run)

Requirements File

# requirements.txt
litserve>=0.2.0
torch>=1.9.0
torchvision>=0.10.0
Pillow>=8.0.0
requests>=2.25.0
numpy>=1.21.0
aiohttp>=3.8.0  # For async testing
pytest>=6.0.0  # For testing

Best Practices

  1. Model Optimization: Use quantization and TorchScript for production
  2. Batch Processing: Configure appropriate batch sizes based on your hardware
  3. Error Handling: Implement comprehensive error handling for robustness
  4. Monitoring: Add logging and metrics collection for production monitoring
  5. Security: Implement authentication and input validation for production APIs
  6. Caching: Consider caching frequently requested predictions
  7. Scaling: Use container orchestration for high-availability deployments

This guide provides a complete foundation for deploying MobileNetV2 with LitServe, from basic implementation to production-ready deployment with monitoring and testing.