LitServe Code Guide

code
mlops
beginner
Author

Krishnatheja Vanka

Published

May 27, 2025

LitServe Code Guide

LitServe is a high-performance, flexible AI model serving framework designed to deploy machine learning models with minimal code. It provides automatic batching, GPU acceleration, and easy scaling capabilities.

Table of Contents

  1. Installation
  2. Basic Usage
  3. Core Concepts
  4. Advanced Features
  5. Configuration Options
  6. Examples
  7. Best Practices
  8. Troubleshooting

Installation

# Install LitServe
pip install litserve

# For GPU support
pip install litserve[gpu]

# For development dependencies
pip install litserve[dev]

Basic Usage

1. Creating Your First LitServe API

import litserve as ls
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class SimpleTextGenerator(ls.LitAPI):
    """Minimal GPT-2 text-generation endpoint.

    Request body:  {"prompt": "<text>"}
    Response body: {"generated_text": "<decoded output sequence>"}
    """

    def setup(self, device):
        # Load model and tokenizer once per worker during server startup.
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request):
        # Process incoming request
        return request["prompt"]

    def predict(self, x):
        """Generate a continuation for a single prompt string."""
        # The tokenized prompt must live on the same device as the model;
        # otherwise generate() fails as soon as the model is on a GPU.
        inputs = self.tokenizer(x, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,  # passes input_ids together with the attention_mask
                max_length=100,
                num_return_sequences=1,
                do_sample=True,  # temperature has no effect unless sampling is on
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id,  # GPT-2 defines no pad token
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def encode_response(self, output):
        # Format response
        return {"generated_text": output}

# Create and start server.
# predict() handles one prompt per call, so batching stays at its default
# (max_batch_size=1); with a larger batch size predict() would receive a list.
if __name__ == "__main__":
    api = SimpleTextGenerator()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)

2. Making Requests to Your API

import requests

# Test the API
response = requests.post(
    "http://localhost:8000/predict",
    json={"prompt": "The future of AI is"},
    timeout=60,  # requests waits indefinitely unless a timeout is given
)
response.raise_for_status()  # fail loudly on 4xx/5xx instead of printing an error body
print(response.json())

Core Concepts

LitAPI Class Structure

Every LitServe API must inherit from ls.LitAPI and implement these core methods:

class MyAPI(ls.LitAPI):
    """Skeleton listing the four hooks a LitAPI subclass fills in."""

    def setup(self, device):
        """Run once at startup: load weights, build preprocessing objects."""

    def decode_request(self, request):
        """Turn the raw request payload into model input, validating it on the way."""

    def predict(self, x):
        """Run inference on the decoded input."""

    def encode_response(self, output):
        """Convert the model output into something JSON-serializable for the HTTP reply."""

Optional Methods

class AdvancedAPI(ls.LitAPI):
    """Optional hooks, each shown here as a pass-through that returns its argument."""

    def batch(self, inputs):
        """Combine several decoded requests into one model input (optional).

        This override hands the collected inputs to predict() untouched.
        """
        combined = inputs
        return combined

    def unbatch(self, output):
        """Split a batched model output back into per-request results (optional)."""
        per_request = output
        return per_request

    # NOTE(review): none of this guide's server examples show LitServe calling
    # preprocess/postprocess; confirm your LitServe version invokes them, or call
    # them yourself from decode_request / encode_response.
    def preprocess(self, input_data):
        """Extra input transformation step (optional)."""
        prepared = input_data
        return prepared

    def postprocess(self, output):
        """Extra output transformation step (optional)."""
        finished = output
        return finished

Advanced Features

1. Custom Batching

class BatchedImageClassifier(ls.LitAPI):
    """ResNet-50 classifier that concatenates per-request tensors into one batch.

    Request:  {"image": "<base64-encoded image bytes>"}
    Response: {"class_id": int, "confidence": float}
    """

    def setup(self, device):
        from torchvision import models, transforms

        self.device = device
        # `pretrained=True` is deprecated in torchvision; name the weights explicitly.
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.model.to(device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def decode_request(self, request):
        from PIL import Image
        import base64
        import io

        # Decode base64 image
        image_data = base64.b64decode(request["image"])
        # Force 3 channels: grayscale, RGBA or palette images would not match
        # the 3-channel Normalize above.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        return self.transform(image).unsqueeze(0)  # (1, 3, 224, 224)

    def batch(self, inputs):
        # Each decoded input has a leading dim of 1, so cat yields (N, 3, 224, 224).
        return torch.cat(inputs, dim=0)

    def predict(self, batch):
        with torch.no_grad():
            # Decoded tensors are built on the CPU; move them to the model's device.
            outputs = self.model(batch.to(self.device))
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        return probabilities

    def unbatch(self, output):
        # Split batch back to individual predictions, keeping a leading dim of 1
        return [pred.unsqueeze(0) for pred in output]

    def encode_response(self, output):
        # Get top prediction
        confidence, predicted = torch.max(output, 1)
        return {
            "class_id": predicted.item(),
            "confidence": confidence.item()
        }

2. Streaming Responses

class StreamingChatAPI(ls.LitAPI):
    """DialoGPT chat endpoint that streams the reply one sampled token at a time.

    Request: {"message": "<user text>"}; each streamed chunk is {"token": "<text>"}.
    """

    # Upper bound on generated tokens per reply.
    max_new_tokens = 50

    def setup(self, device):
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
        self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request):
        return request["message"]

    def predict(self, x):
        """Yield {"token": ...} dicts until EOS is sampled or max_new_tokens is reached."""
        # DialoGPT's model card terminates each dialogue turn with the EOS token.
        inputs = self.tokenizer.encode(
            x + self.tokenizer.eos_token, return_tensors="pt"
        ).to(self.device)

        for _ in range(self.max_new_tokens):
            # Keep `yield` outside no_grad: grad mode is thread-local, and suspending
            # the generator inside the context manager would leak it to the caller.
            with torch.no_grad():
                next_token_logits = self.model(inputs).logits[:, -1, :]
                next_token = torch.multinomial(
                    torch.softmax(next_token_logits, dim=-1),
                    num_samples=1
                )

            # Stop at the end token without sending it to the client.
            if next_token.item() == self.tokenizer.eos_token_id:
                break

            inputs = torch.cat([inputs, next_token], dim=-1)
            yield {"token": self.tokenizer.decode(next_token[0])}

    def encode_response(self, output):
        # `output` is the generator returned by predict(); chunks pass through as-is.
        return output

# Enable streaming: predict() yields chunks and the server forwards each one
api = StreamingChatAPI()
server = ls.LitServer(api, accelerator="auto", stream=True)

3. Multiple GPU Support

# Automatic multi-GPU scaling: let LitServe pick every visible GPU
auto_server_kwargs = dict(accelerator="auto", devices="auto", max_batch_size=8)
server = ls.LitServer(api, **auto_server_kwargs)

# Specify specific GPUs (0, 1 and 2).
# NOTE(review): confirm your LitServe version accepts a list for `devices`; if it
# only takes a count, launch with CUDA_VISIBLE_DEVICES=0,1,2 and pass devices=3.
pinned_server_kwargs = dict(accelerator="gpu", devices=[0, 1, 2], max_batch_size=16)
server = ls.LitServer(api, **pinned_server_kwargs)

4. Authentication and Middleware

import hmac

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer


class AuthenticatedAPI(ls.LitAPI):
    """Example API that only serves requests carrying a known bearer token."""

    # Demo keys only -- load real keys from a secret store or the environment.
    valid_keys = ("your-secret-key-1", "your-secret-key-2")

    def setup(self, device):
        # Your model setup
        pass

    def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Per-request auth hook used by LitServe (wired in as a FastAPI dependency).

        HTTPBearer rejects requests without an ``Authorization: Bearer ...`` header;
        tokens that are present but unknown get HTTP 401.
        """
        if auth.scheme.lower() != "bearer" or not self.validate_api_key(auth.credentials):
            raise HTTPException(status_code=401, detail="Invalid API key")

    def authenticate(self, request):
        """Check a request's Authorization header directly.

        Kept for callers that invoke it themselves; LitServe's hook is ``authorize``.

        Returns:
            True when the header carries a valid key.
        Raises:
            HTTPException: 401 when the header is missing or the key is unknown.
        """
        api_key = request.headers.get("Authorization")
        if not api_key or not self.validate_api_key(api_key):
            raise HTTPException(status_code=401, detail="Invalid API key")
        return True

    def validate_api_key(self, api_key):
        """Return True if *api_key* (optionally prefixed by "Bearer ") is a known key."""
        # Strip only a leading prefix; str.replace would also remove "Bearer "
        # from the middle of a crafted key.
        prefix = "Bearer "
        token = api_key[len(prefix):] if api_key.startswith(prefix) else api_key
        candidate = token.encode()
        # compare_digest keeps the comparison time independent of where keys differ.
        return any(hmac.compare_digest(candidate, key.encode()) for key in self.valid_keys)

    def decode_request(self, request):
        return request["data"]

    def predict(self, x):
        # Your prediction logic
        return f"Processed: {x}"

    def encode_response(self, output):
        return {"result": output}

Configuration Options

Server Configuration

server = ls.LitServer(
    api,                         # The LitAPI instance (first positional argument)
    accelerator="auto",          # "auto", "cpu", "gpu", "mps"
    devices="auto",              # Device selection
    max_batch_size=4,            # Maximum batch size
    batch_timeout=0.1,           # Max seconds to wait while filling a batch
    workers_per_device=1,        # Workers per device
    timeout=30,                  # Request timeout in seconds
    stream=False,                # Enable streaming
    spec=None,                   # Optional API spec, e.g. ls.OpenAISpec()
)

Environment Variables

# Set device preferences
export CUDA_VISIBLE_DEVICES=0,1,2

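# NOTE: LitServer receives batching and worker settings as constructor arguments.
# Unless your LitServe version documents these variables, read them yourself and
# pass them in, e.g. max_batch_size=int(os.environ.get("LITSERVE_MAX_BATCH_SIZE", "4")).

# Set batch configuration
export LITSERVE_MAX_BATCH_SIZE=8
export LITSERVE_BATCH_TIMEOUT=0.05

# Set worker configuration
export LITSERVE_WORKERS_PER_DEVICE=2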

Custom Configuration File

# config.py
class Config:
    """Server tuning settings kept in one place for the LitServer setup."""

    # Batching
    MAX_BATCH_SIZE: int = 8
    BATCH_TIMEOUT: float = 0.1  # seconds

    # Workers and hardware
    WORKERS_PER_DEVICE: int = 2
    ACCELERATOR: str = "auto"

    # Requests
    TIMEOUT: int = 60  # seconds

# Use in your API: pass every setting Config defines, not just the batching ones
from config import Config

server = ls.LitServer(
    api,
    accelerator=Config.ACCELERATOR,
    max_batch_size=Config.MAX_BATCH_SIZE,
    batch_timeout=Config.BATCH_TIMEOUT,
    workers_per_device=Config.WORKERS_PER_DEVICE,
    timeout=Config.TIMEOUT,
)

Examples

1. Image Classification API

import litserve as ls
import torch
from torchvision import models, transforms
from PIL import Image
import base64
import io

class ImageClassificationAPI(ls.LitAPI):
    """ResNet-50 ImageNet classifier returning the top-5 labels for an image.

    Request:  {"image": "<base64-encoded image bytes>"}
    Response: {"predictions": [{"class": str, "probability": float}, ...]}
    Works both with and without server-side batching (max_batch_size >= 1).
    """

    def setup(self, device):
        # Load pre-trained ResNet model
        weights = models.ResNet50_Weights.IMAGENET1K_V1
        self.device = device
        self.model = models.resnet50(weights=weights)
        self.model.to(device)
        self.model.eval()

        # Image preprocessing
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # ImageNet class labels ship with the weights enum, so no separate
        # imagenet_classes.txt file has to be downloaded and kept next to the server.
        self.classes = weights.meta["categories"]

    def decode_request(self, request):
        # Decode base64 image
        image_bytes = base64.b64decode(request["image"])
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Return one (3, 224, 224) tensor; when batching is enabled, LitServe's
        # default batch() stacks these into (N, 3, 224, 224).
        return self.preprocess(image)

    def predict(self, x):
        """Return class probabilities: (1000,) for one image, (N, 1000) for a batch."""
        single = x.dim() == 3  # unbatched request
        batch = (x.unsqueeze(0) if single else x).to(self.device)
        with torch.no_grad():
            probabilities = torch.nn.functional.softmax(self.model(batch), dim=1)
        return probabilities[0] if single else probabilities

    def encode_response(self, output):
        # Get top 5 predictions for a single image's (1000,) probability vector
        top5_prob, top5_catid = torch.topk(output, 5)
        predictions = [
            {"class": self.classes[class_id], "probability": float(prob)}
            for prob, class_id in zip(top5_prob.tolist(), top5_catid.tolist())
        ]
        return {"predictions": predictions}

if __name__ == "__main__":
    api = ImageClassificationAPI()
    server = ls.LitServer(
        api,
        accelerator="auto",
        max_batch_size=4,  # up to 4 requests per predict() call
    )
    server.run(port=8000)

2. Text Embedding API

import litserve as ls
from sentence_transformers import SentenceTransformer
import numpy as np

class TextEmbeddingAPI(ls.LitAPI):
    """Sentence-embedding endpoint backed by all-MiniLM-L6-v2.

    Request:  {"texts": "one sentence"} or {"texts": ["sentence", ...]}
    Response: {"embeddings": [[float, ...], ...], "dimension": int}
    """

    def setup(self, device):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.model.to(device)

    def decode_request(self, request):
        # Always hand predict() a list of strings, even for a single sentence.
        texts = request.get("texts", [])
        if isinstance(texts, str):
            texts = [texts]
        return texts

    def predict(self, texts):
        """Embed the decoded texts.

        Without batching, ``texts`` is one request's list of strings and a single
        array is returned. With max_batch_size > 1 it is a list with one list per
        request; all texts are then encoded in one call and split back into one
        array per request.
        """
        if not (texts and isinstance(texts[0], list)):
            return self.model.encode(texts)

        sizes = [len(group) for group in texts]
        flat = [text for group in texts for text in group]
        embeddings = np.asarray(self.model.encode(flat))
        # Split points are the running totals of the per-request sizes.
        return np.split(embeddings, np.cumsum(sizes)[:-1])

    def encode_response(self, embeddings):
        return {
            "embeddings": embeddings.tolist(),
            "dimension": embeddings.shape[1] if len(embeddings.shape) > 1 else len(embeddings)
        }

if __name__ == "__main__":
    api = TextEmbeddingAPI()
    server = ls.LitServer(
        api,
        accelerator="auto",
        max_batch_size=32,  # up to 32 requests per predict() call
    )
    server.run(port=8000)

3. Multi-Modal API (Text + Image)

import litserve as ls
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import base64
import io

class ImageCaptioningAPI(ls.LitAPI):
    """BLIP captioning endpoint with an optional text prompt.

    Request:  {"image": "<base64-encoded image bytes>", "text": "<optional prompt>"}
    Response: {"caption": "<generated caption>"}
    """

    def setup(self, device):
        self.device = device
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request):
        # Handle both image and optional text input
        encoded_image = request["image"]
        text_prompt = request.get("text", "")

        image_bytes = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        return {"image": image, "text": text_prompt}

    def predict(self, inputs):
        processed = self.processor(
            images=inputs["image"],
            text=inputs["text"],
            return_tensors="pt"
        ).to(self.device)  # processor tensors are created on the CPU

        with torch.no_grad():
            outputs = self.model.generate(**processed, max_length=50)

        caption = self.processor.decode(outputs[0], skip_special_tokens=True)
        return caption

    def encode_response(self, caption):
        return {"caption": caption}

Best Practices

1. Resource Management

class OptimizedAPI(ls.LitAPI):
    """Inference-tuned API: TorchScript model, autocast on CUDA, no autograd."""

    def setup(self, device):
        """Compile the model and decide whether to run under autocast.

        ``device`` may be a string such as "cpu" or "cuda:0", or a torch.device;
        torch.device() normalizes both, whereas ``.type`` only exists on the latter.
        """
        self.device = device
        # Use torch.jit.script for optimization. `your_model` is a placeholder for
        # your own nn.Module; place it on the device and switch to eval mode first.
        self.model = torch.jit.script(your_model.to(device).eval())

        # Enable mixed precision if using GPU. Inference only needs autocast;
        # GradScaler scales loss gradients and is a training-only tool.
        self.use_amp = torch.device(device).type == "cuda"

        # Pre-allocate tensors for common shapes (placeholder; not used in predict)
        self.common_shapes = {}

    def predict(self, x):
        # Assumes decode_request produced a tensor -- TODO adapt for other input types.
        x = x.to(self.device)
        with torch.inference_mode():
            if self.use_amp:
                # torch.autocast supersedes the deprecated torch.cuda.amp.autocast.
                with torch.autocast(device_type="cuda"):
                    return self.model(x)
            return self.model(x)

2. Error Handling

from fastapi import HTTPException


class RobustAPI(ls.LitAPI):
    """Shows how to turn bad input and inference failures into HTTP errors."""

    def decode_request(self, request):
        """Validate the payload and return its "input" field.

        Raises:
            HTTPException: 400 if "input" is missing or is not a string or list.
        """
        # Validate required fields
        if not isinstance(request, dict) or "input" not in request:
            raise HTTPException(status_code=400, detail="Missing required field: input")

        data = request["input"]

        # Type validation
        if not isinstance(data, (str, list)):
            raise HTTPException(status_code=400, detail="Input must be string or list")

        return data

    def predict(self, x):
        """Run the model, mapping failures to server-side HTTP errors.

        Raises:
            HTTPException: 503 when the GPU runs out of memory, 500 for other errors.
        """
        try:
            return self.model(x)
        except torch.cuda.OutOfMemoryError as e:
            # Usually transient (load-dependent), so signal "try again later".
            raise HTTPException(status_code=503, detail="GPU memory exhausted") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e

3. Monitoring and Logging

import logging
import time

class MonitoredAPI(ls.LitAPI):
    """Logs a running request count and per-request inference latency."""

    def setup(self, device):
        self.logger = logging.getLogger(__name__)
        # Counter lives on this instance; presumably each worker has its own copy,
        # so counts are per worker -- verify before aggregating.
        self.request_count = 0
        # Your model setup

    def decode_request(self, request):
        self.request_count += 1
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        self.logger.info("Processing request #%d", self.request_count)
        return request["data"]

    def predict(self, x):
        # perf_counter is monotonic; time.time() can jump with clock adjustments.
        start_time = time.perf_counter()
        result = self.model(x)
        inference_time = time.perf_counter() - start_time

        self.logger.info("Inference completed in %.3fs", inference_time)
        return result

4. Model Versioning

class VersionedAPI(ls.LitAPI):
    """Tags each response with the serving model version and a Unix timestamp."""

    def setup(self, device):
        # Single source of truth for the version; the weights path is derived from it.
        self.version = "1.0.0"
        self.model_path = f"models/model_v{self.version}"
        # Load versioned model

    def encode_response(self, output):
        payload = {"result": output, "model_version": self.version}
        payload["timestamp"] = time.time()
        return payload

Troubleshooting

Common Issues and Solutions

1. CUDA Out of Memory

# Solution: reduce the batch size (gradient checkpointing only saves memory during training)
server = ls.LitServer(api, max_batch_size=2)  # Reduce batch size

# Or run inference without autograd bookkeeping in your predict method
def predict(self, x):
    """Run ``self.model`` on *x* without recording an autograd graph.

    Under inference_mode no activations are kept for a backward pass, which lowers
    peak memory. torch.cuda.empty_cache() is deliberately not called here: it only
    releases unused cached blocks and does not give PyTorch more memory, while
    calling it on every request slows inference down.
    """
    with torch.inference_mode():
        return self.model(x)

2. Slow Inference

# Enable model optimization
def setup(self, device):
    self.model.eval()  # Set to evaluation mode
    self.model = torch.jit.script(self.model)  # JIT compilation
    
    # Use half precision if supported
    if device.type == 'cuda':
        self.model.half()

3. Request Timeout

# Increase timeout settings
server = ls.LitServer(
    api,
    timeout=60,  # Increase request timeout (seconds)
    # Keep batch_timeout small: it is how long the server waits to fill a batch,
    # so raising it adds latency to every request instead of preventing timeouts.
    batch_timeout=0.05,
)

4. Port Already in Use

# Check and kill existing processes (POSIX, requires `lsof`)
import os
import signal
import subprocess

# A list argument combined with shell=True would run bare `lsof` (the remaining
# items become shell parameters), and "|" is not a pipe outside a shell, so the
# two steps are done explicitly here instead.
result = subprocess.run(["lsof", "-ti:8000"], capture_output=True, text=True, check=False)
for pid in result.stdout.split():
    os.kill(int(pid), signal.SIGKILL)

# Or serve on another free port instead
fallback_port = 8001
server.run(port=fallback_port)

Performance Optimization Tips

  1. Use appropriate batch sizes: Start with small batches and gradually increase
  2. Enable GPU acceleration: Use accelerator="auto" for automatic GPU detection
  3. Optimize model loading: Load models once in setup(), not in predict()
  4. Use mixed precision: Enable autocast for GPU inference
  5. Profile your code: Use tools like torch.profiler to identify bottlenecks
  6. Cache preprocessed data: Store frequently used transformations

Conclusion

This guide covers the essential aspects of using LitServe for deploying AI models. For the most up-to-date information, always refer to the official LitServe documentation.