LitServe Code Guide

LitServe is a high-performance, flexible AI model serving framework designed to deploy machine learning models with minimal code. It provides automatic batching, GPU acceleration, and easy scaling capabilities.
Installation
```bash
# Install LitServe
pip install litserve

# For GPU support
pip install litserve[gpu]

# For development dependencies
pip install litserve[dev]
```

Basic Usage
1. Creating Your First LitServe API
```python
import litserve as ls
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class SimpleTextGenerator(ls.LitAPI):
    def setup(self, device):
        # Load the model and tokenizer once, during server startup
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request):
        # Extract the prompt from the incoming request
        return request["prompt"]

    def predict(self, x):
        # Run inference
        inputs = self.tokenizer.encode(x, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_length=100,
                num_return_sequences=1,
                do_sample=True,      # sampling must be enabled for temperature to take effect
                temperature=0.7,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def encode_response(self, output):
        # Format the HTTP response
        return {"generated_text": output}

# Create and start the server
if __name__ == "__main__":
    api = SimpleTextGenerator()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)
```

2. Making Requests to Your API
```python
import requests

# Test the API
response = requests.post(
    "http://localhost:8000/predict",
    json={"prompt": "The future of AI is"},
)
print(response.json())
```

Core Concepts
LitAPI Class Structure
Every LitServe API must inherit from ls.LitAPI and implement these core methods:
```python
class MyAPI(ls.LitAPI):
    def setup(self, device):
        """Initialize models, load weights, set up preprocessing."""
        pass

    def decode_request(self, request):
        """Parse and validate incoming requests."""
        pass

    def predict(self, x):
        """Run model inference."""
        pass

    def encode_response(self, output):
        """Format model output for the HTTP response."""
        pass
```

Optional Methods
```python
class AdvancedAPI(ls.LitAPI):
    def batch(self, inputs):
        """Custom batching logic (optional)."""
        return inputs

    def unbatch(self, output):
        """Custom unbatching logic (optional)."""
        return output

    def preprocess(self, input_data):
        """Additional preprocessing (optional)."""
        return input_data

    def postprocess(self, output):
        """Additional postprocessing (optional)."""
        return output
```

Advanced Features
1. Custom Batching
```python
import base64
import io

import litserve as ls
import torch
from PIL import Image
from torchvision import models, transforms

class BatchedImageClassifier(ls.LitAPI):
    def setup(self, device):
        self.device = device
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.model.to(device)
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def decode_request(self, request):
        # Decode the base64-encoded image into a normalized tensor
        image_data = base64.b64decode(request["image"])
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        return self.transform(image).unsqueeze(0)

    def batch(self, inputs):
        # Custom batching: stack individual image tensors into one batch
        return torch.cat(inputs, dim=0)

    def predict(self, batch):
        batch = batch.to(self.device)
        with torch.no_grad():
            outputs = self.model(batch)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        return probabilities

    def unbatch(self, output):
        # Split the batch back into individual predictions
        return [pred.unsqueeze(0) for pred in output]

    def encode_response(self, output):
        # Return the top prediction for each image
        confidence, predicted = torch.max(output, 1)
        return {
            "class_id": predicted.item(),
            "confidence": confidence.item()
        }
```
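The custom batch and unbatch hooks above only take effect when the server actually batches requests. A minimal launch sketch, assuming the BatchedImageClassifier class above (the batch size and timeout values are illustrative):

```python
if __name__ == "__main__":
    api = BatchedImageClassifier()
    # Batching only kicks in when max_batch_size > 1
    server = ls.LitServer(api, accelerator="auto", max_batch_size=8, batch_timeout=0.05)
    server.run(port=8000)
```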
2. Streaming Responses

```python
import litserve as ls
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class StreamingChatAPI(ls.LitAPI):
    def setup(self, device):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
        self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request):
        return request["message"]

    def predict(self, x):
        # Generator: yield one token at a time for streaming
        inputs = self.tokenizer.encode(x, return_tensors="pt").to(self.device)
        for _ in range(50):  # Generate up to 50 tokens
            with torch.no_grad():
                outputs = self.model(inputs)
                next_token_logits = outputs.logits[:, -1, :]
                next_token = torch.multinomial(
                    torch.softmax(next_token_logits, dim=-1),
                    num_samples=1
                )
                inputs = torch.cat([inputs, next_token], dim=-1)
            # Yield each token as it is produced
            token_text = self.tokenizer.decode(next_token[0])
            yield {"token": token_text}
            # Stop at the end-of-sequence token
            if next_token.item() == self.tokenizer.eos_token_id:
                break

    def encode_response(self, output):
        # Pass each streamed item through to the client
        for out in output:
            yield out

# Enable streaming on the server
if __name__ == "__main__":
    api = StreamingChatAPI()
    server = ls.LitServer(api, accelerator="auto", stream=True)
    server.run(port=8000)
```
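On the client side, the stream can be consumed incrementally instead of waiting for the full response. A minimal sketch with requests, assuming the streaming server above is running on port 8000 and that each yielded item arrives as one line of the response body:

```python
import requests

# Read the streamed tokens as they arrive
with requests.post(
    "http://localhost:8000/predict",
    json={"message": "Hello, how are you?"},
    stream=True,
) as response:
    for line in response.iter_lines():
        if line:
            print(line.decode("utf-8"))
```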
3. Multiple GPU Support

```python
# Automatic multi-GPU scaling
server = ls.LitServer(
    api,
    accelerator="auto",
    devices="auto",       # Use all available GPUs
    max_batch_size=8
)

# Pin the server to particular GPUs
server = ls.LitServer(
    api,
    accelerator="gpu",
    devices=[0, 1, 2],    # Use GPUs 0, 1, and 2
    max_batch_size=16
)
```

4. Authentication and Middleware
```python
class AuthenticatedAPI(ls.LitAPI):
    def setup(self, device):
        # Your model setup
        pass

    def authenticate(self, request):
        """Custom authentication logic"""
        api_key = request.headers.get("Authorization")
        if not api_key or not self.validate_api_key(api_key):
            raise ls.AuthenticationError("Invalid API key")
        return True

    def validate_api_key(self, api_key):
        # Your API key validation logic
        valid_keys = ["your-secret-key-1", "your-secret-key-2"]
        return api_key.replace("Bearer ", "") in valid_keys

    def decode_request(self, request):
        return request["data"]

    def predict(self, x):
        # Your prediction logic
        return f"Processed: {x}"

    def encode_response(self, output):
        return {"result": output}
```
Configuration Options

Server Configuration
```python
server = ls.LitServer(
    api=api,
    accelerator="auto",      # "auto", "cpu", "gpu", "mps"
    devices="auto",          # Device selection
    max_batch_size=4,        # Maximum batch size
    batch_timeout=0.1,       # Batch timeout in seconds
    workers_per_device=1,    # Workers per device
    timeout=30,              # Request timeout in seconds
    stream=False,            # Enable streaming
    spec=None,               # Optional API spec (e.g., an OpenAI-compatible spec)
)
```

Environment Variables
```bash
# Set device preferences
export CUDA_VISIBLE_DEVICES=0,1,2

# Set batch configuration
export LITSERVE_MAX_BATCH_SIZE=8
export LITSERVE_BATCH_TIMEOUT=0.05

# Set worker configuration
export LITSERVE_WORKERS_PER_DEVICE=2
```
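If you prefer to wire such variables into the server in code, a sketch that reads them explicitly with os.getenv (the variable names follow the example above; the defaults are illustrative):

```python
import os

import litserve as ls

# Read the environment variables set above and pass them to the server explicitly
max_batch_size = int(os.getenv("LITSERVE_MAX_BATCH_SIZE", "4"))
batch_timeout = float(os.getenv("LITSERVE_BATCH_TIMEOUT", "0.1"))
workers_per_device = int(os.getenv("LITSERVE_WORKERS_PER_DEVICE", "1"))

server = ls.LitServer(
    api,
    max_batch_size=max_batch_size,
    batch_timeout=batch_timeout,
    workers_per_device=workers_per_device,
)
```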
Custom Configuration File

```python
# config.py
class Config:
    MAX_BATCH_SIZE = 8
    BATCH_TIMEOUT = 0.1
    WORKERS_PER_DEVICE = 2
    ACCELERATOR = "auto"
    TIMEOUT = 60
```

```python
# Use in your API
from config import Config

server = ls.LitServer(
    api,
    max_batch_size=Config.MAX_BATCH_SIZE,
    batch_timeout=Config.BATCH_TIMEOUT,
    workers_per_device=Config.WORKERS_PER_DEVICE
)
```

Examples
1. Image Classification API
```python
import base64
import io

import litserve as ls
import torch
from PIL import Image
from torchvision import models, transforms

class ImageClassificationAPI(ls.LitAPI):
    def setup(self, device):
        # Load a pre-trained ResNet model
        self.device = device
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.model.to(device)
        self.model.eval()

        # Image preprocessing
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Load ImageNet class labels (expects imagenet_classes.txt alongside the script)
        with open('imagenet_classes.txt') as f:
            self.classes = [line.strip() for line in f.readlines()]

    def decode_request(self, request):
        # Decode the base64-encoded image
        encoded_image = request["image"]
        image_bytes = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        return self.preprocess(image).unsqueeze(0)

    def predict(self, x):
        with torch.no_grad():
            output = self.model(x.to(self.device))
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
        return probabilities

    def encode_response(self, output):
        # Return the top-5 predictions
        top5_prob, top5_catid = torch.topk(output, 5)
        results = []
        for i in range(top5_prob.size(0)):
            results.append({
                "class": self.classes[top5_catid[i]],
                "probability": float(top5_prob[i])
            })
        return {"predictions": results}

if __name__ == "__main__":
    api = ImageClassificationAPI()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)
```
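A matching client base64-encodes the image before posting. A sketch (cat.jpg is a placeholder path):

```python
import base64

import requests

# Encode an image file as base64 and send it for classification
with open("cat.jpg", "rb") as f:  # placeholder path
    encoded = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    "http://localhost:8000/predict",
    json={"image": encoded},
)
print(response.json())
```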
2. Text Embedding API

```python
import litserve as ls
from sentence_transformers import SentenceTransformer

class TextEmbeddingAPI(ls.LitAPI):
    def setup(self, device):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.model.to(device)

    def decode_request(self, request):
        # Accept either a single string or a list of strings
        texts = request.get("texts", [])
        if isinstance(texts, str):
            texts = [texts]
        return texts

    def predict(self, texts):
        embeddings = self.model.encode(texts)
        return embeddings

    def encode_response(self, embeddings):
        return {
            "embeddings": embeddings.tolist(),
            "dimension": embeddings.shape[1] if len(embeddings.shape) > 1 else len(embeddings)
        }

if __name__ == "__main__":
    api = TextEmbeddingAPI()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)
```
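The returned embeddings can be used directly on the client, for example to compute cosine similarity. A sketch with numpy (the sentences are illustrative):

```python
import numpy as np
import requests

# Embed two sentences and compare them
response = requests.post(
    "http://localhost:8000/predict",
    json={"texts": ["LitServe serves models.", "LitServe deploys ML models."]},
)
a, b = (np.array(e) for e in response.json()["embeddings"])

# Cosine similarity between the two embeddings
similarity = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"cosine similarity: {similarity:.3f}")
```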
3. Multi-Modal API (Text + Image)

```python
import base64
import io

import litserve as ls
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

class ImageCaptioningAPI(ls.LitAPI):
    def setup(self, device):
        self.device = device
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request):
        # Handle an image plus an optional text prompt
        encoded_image = request["image"]
        text_prompt = request.get("text", "")
        image_bytes = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        return {"image": image, "text": text_prompt}

    def predict(self, inputs):
        processed = self.processor(
            images=inputs["image"],
            text=inputs["text"] or None,   # omit the prompt for unconditional captioning
            return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(**processed, max_length=50)
        caption = self.processor.decode(outputs[0], skip_special_tokens=True)
        return caption

    def encode_response(self, caption):
        return {"caption": caption}
```

Best Practices
1. Resource Management
```python
class OptimizedAPI(ls.LitAPI):
    def setup(self, device):
        # Use TorchScript for optimization (your_model is a placeholder)
        self.model = torch.jit.script(your_model)

        # Enable mixed precision when running on a GPU
        self.use_amp = "cuda" in str(device)

        # Pre-allocate tensors for common shapes
        self.common_shapes = {}

    def predict(self, x):
        # Use autocast for mixed-precision inference
        if self.use_amp:
            with torch.cuda.amp.autocast():
                return self.model(x)
        return self.model(x)
```

2. Error Handling
```python
class RobustAPI(ls.LitAPI):
    def decode_request(self, request):
        try:
            # Validate required fields
            if "input" not in request:
                raise ls.ValidationError("Missing required field: input")
            data = request["input"]
            # Type validation
            if not isinstance(data, (str, list)):
                raise ls.ValidationError("Input must be string or list")
            return data
        except Exception as e:
            raise ls.ValidationError(f"Request parsing failed: {str(e)}")

    def predict(self, x):
        try:
            result = self.model(x)
            return result
        except torch.cuda.OutOfMemoryError:
            raise ls.ServerError("GPU memory exhausted")
        except Exception as e:
            raise ls.ServerError(f"Prediction failed: {str(e)}")
```

3. Monitoring and Logging
```python
import logging
import time

class MonitoredAPI(ls.LitAPI):
    def setup(self, device):
        self.logger = logging.getLogger(__name__)
        self.request_count = 0
        # Your model setup

    def decode_request(self, request):
        self.request_count += 1
        self.logger.info(f"Processing request #{self.request_count}")
        return request["data"]

    def predict(self, x):
        start_time = time.time()
        result = self.model(x)
        inference_time = time.time() - start_time
        self.logger.info(f"Inference completed in {inference_time:.3f}s")
        return result
```
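Note that the logger above only produces output once logging is configured in the serving process. A minimal sketch, placed before the server starts:

```python
import logging

# Configure the root logger so the INFO messages above are printed
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)
```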
4. Model Versioning

```python
import time

class VersionedAPI(ls.LitAPI):
    def setup(self, device):
        self.version = "1.0.0"
        self.model_path = f"models/model_v{self.version}"
        # Load the versioned model

    def encode_response(self, output):
        return {
            "result": output,
            "model_version": self.version,
            "timestamp": time.time()
        }
```

Troubleshooting
Common Issues and Solutions
1. CUDA Out of Memory
```python
# Solution: Reduce the batch size
server = ls.LitServer(api, max_batch_size=2)

# Or clear cached memory in your predict method
def predict(self, x):
    torch.cuda.empty_cache()  # Release unused cached GPU memory
    result = self.model(x)
    return result
```

2. Slow Inference
```python
# Enable model optimizations in setup()
def setup(self, device):
    self.model.eval()                          # Set to evaluation mode
    self.model = torch.jit.script(self.model)  # JIT compilation
    # Use half precision if supported
    if "cuda" in str(device):
        self.model = self.model.half()
```

3. Request Timeout
```python
# Increase timeout settings
server = ls.LitServer(
    api,
    timeout=60,          # Increase the request timeout (seconds)
    batch_timeout=1.0    # Increase the batch timeout (seconds)
)
```

4. Port Already in Use
```python
# Find and kill the process holding port 8000 (Linux/macOS)
import subprocess
subprocess.run("lsof -ti:8000 | xargs kill -9", shell=True)

# Or simply start the server on a different port
server.run(port=8001)
```

Performance Optimization Tips
- Use appropriate batch sizes: start with small batches and increase gradually
- Enable GPU acceleration: use `accelerator="auto"` for automatic GPU detection
- Optimize model loading: load models once in `setup()`, not in `predict()`
- Use mixed precision: enable autocast for GPU inference
- Profile your code: use tools like `torch.profiler` to identify bottlenecks (see the sketch after this list)
- Cache preprocessed data: store frequently used transformations
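For the profiling tip above, a minimal sketch of wrapping a single forward pass in torch.profiler (the model reference and input shape are placeholders):

```python
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(1, 3, 224, 224)  # placeholder input
with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA when profiling on GPU
    record_shapes=True,
) as prof:
    with torch.no_grad():
        model(x)  # placeholder: the model loaded in setup()

# Show the operators that dominate runtime
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```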
This guide covers the essential aspects of using LitServe for deploying AI models. For the most up-to-date information, always refer to the official LitServe documentation.