LitServe Code Guide
LitServe is a high-performance, flexible AI model serving framework designed to deploy machine learning models with minimal code. It provides automatic batching, GPU acceleration, and easy scaling capabilities.
Table of Contents
- Installation
- Basic Usage
- Core Concepts
- Advanced Features
- Configuration Options
- Examples
- Best Practices
- Troubleshooting
- Performance Optimization Tips
- Deployment Checklist
Installation
# Install LitServe
pip install litserve
# For GPU support
pip install litserve[gpu]
# For development dependencies
pip install litserve[dev]
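To confirm the install worked, a quick sanity check from Python (this assumes litserve exposes a __version__ attribute, as recent releases do):

# Verify the installation
import litserve
print(litserve.__version__)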
Basic Usage
1. Creating Your First LitServe API
import litserve as ls
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class SimpleTextGenerator(ls.LitAPI):
    def setup(self, device):
        # Load model and tokenizer during server startup
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request):
        # Process incoming request
        return request["prompt"]

    def predict(self, x):
        # Run inference (move inputs to the same device as the model)
        inputs = self.tokenizer.encode(x, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_length=100,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def encode_response(self, output):
        # Format response
        return {"generated_text": output}

# Create and start server
if __name__ == "__main__":
    api = SimpleTextGenerator()
    server = ls.LitServer(api, accelerator="auto", max_batch_size=4)
    server.run(port=8000)
2. Making Requests to Your API
import requests

# Test the API
response = requests.post(
    "http://localhost:8000/predict",
    json={"prompt": "The future of AI is"}
)
print(response.json())
Core Concepts
LitAPI Class Structure
Every LitServe API must inherit from ls.LitAPI and implement these core methods:
class MyAPI(ls.LitAPI):
    def setup(self, device):
        """Initialize models, load weights, set up preprocessing"""
        pass

    def decode_request(self, request):
        """Parse and validate incoming requests"""
        pass

    def predict(self, x):
        """Run model inference"""
        pass

    def encode_response(self, output):
        """Format model output for HTTP response"""
        pass
Optional Methods
class AdvancedAPI(ls.LitAPI):
    def batch(self, inputs):
        """Custom batching logic (optional)"""
        return inputs

    def unbatch(self, output):
        """Custom unbatching logic (optional)"""
        return output

    def preprocess(self, input_data):
        """Additional preprocessing (optional)"""
        return input_data

    def postprocess(self, output):
        """Additional postprocessing (optional)"""
        return output
Advanced Features
1. Custom Batching
import torch

class BatchedImageClassifier(ls.LitAPI):
    def setup(self, device):
        from torchvision import models, transforms
        self.device = device
        self.model = models.resnet50(pretrained=True)
        self.model.to(device)
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def decode_request(self, request):
        from PIL import Image
        import base64
        import io

        # Decode base64 image
        image_data = base64.b64decode(request["image"])
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        return self.transform(image).unsqueeze(0)

    def batch(self, inputs):
        # Custom batching for images
        return torch.cat(inputs, dim=0)

    def predict(self, batch):
        with torch.no_grad():
            outputs = self.model(batch.to(self.device))
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        return probabilities

    def unbatch(self, output):
        # Split batch back to individual predictions
        return [pred.unsqueeze(0) for pred in output]

    def encode_response(self, output):
        # Get top prediction
        confidence, predicted = torch.max(output, 1)
        return {
            "class_id": predicted.item(),
            "confidence": confidence.item()
        }
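To exercise the batch/unbatch hooks, start the server with a batch size greater than 1; a minimal sketch (the batch_timeout value here is illustrative):

if __name__ == "__main__":
    api = BatchedImageClassifier()
    server = ls.LitServer(api, accelerator="auto", max_batch_size=8, batch_timeout=0.05)
    server.run(port=8000)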
2. Streaming Responses
import torch

class StreamingChatAPI(ls.LitAPI):
    def setup(self, device):
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
        self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
        self.model.to(device)

    def decode_request(self, request):
        return request["message"]

    def predict(self, x):
        # Generator for streaming: yield one token at a time
        inputs = self.tokenizer.encode(x, return_tensors="pt").to(self.device)

        for i in range(50):  # Generate up to 50 tokens
            with torch.no_grad():
                outputs = self.model(inputs)
            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.multinomial(
                torch.softmax(next_token_logits, dim=-1),
                num_samples=1
            )
            inputs = torch.cat([inputs, next_token], dim=-1)

            # Yield each token
            token_text = self.tokenizer.decode(next_token[0])
            yield {"token": token_text}

            # Stop if end token
            if next_token.item() == self.tokenizer.eos_token_id:
                break

    def encode_response(self, output):
        return output

# Enable streaming
api = StreamingChatAPI()
server = ls.LitServer(api, accelerator="auto", stream=True)
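A client-side sketch for consuming the stream: post with stream=True and iterate over the chunks as they arrive (the exact framing of each chunk may vary between LitServe versions):

import requests

response = requests.post(
    "http://localhost:8000/predict",
    json={"message": "Hello there!"},
    stream=True
)
for line in response.iter_lines():
    if line:
        print(line.decode())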
3. Multiple GPU Support
# Automatic multi-GPU scaling
server = ls.LitServer(
    api,
    accelerator="auto",
    devices="auto",       # Use all available GPUs
    max_batch_size=8
)

# Specify specific GPUs
server = ls.LitServer(
    api,
    accelerator="gpu",
    devices=[0, 1, 2],    # Use GPUs 0, 1, and 2
    max_batch_size=16
)
4. Authentication and Middleware
class AuthenticatedAPI(ls.LitAPI):
    def setup(self, device):
        # Your model setup
        pass

    def authenticate(self, request):
        """Custom authentication logic"""
        api_key = request.headers.get("Authorization")
        if not api_key or not self.validate_api_key(api_key):
            raise ls.AuthenticationError("Invalid API key")
        return True

    def validate_api_key(self, api_key):
        # Your API key validation logic
        valid_keys = ["your-secret-key-1", "your-secret-key-2"]
        return api_key.replace("Bearer ", "") in valid_keys

    def decode_request(self, request):
        return request["data"]

    def predict(self, x):
        # Your prediction logic
        return f"Processed: {x}"

    def encode_response(self, output):
        return {"result": output}
Configuration Options
Server Configuration
server = ls.LitServer(
    api,
    accelerator="auto",       # "auto", "cpu", "gpu", "mps"
    devices="auto",           # Device selection
    max_batch_size=4,         # Maximum batch size
    batch_timeout=0.1,        # Batch timeout in seconds
    workers_per_device=1,     # Workers per device
    timeout=30,               # Request timeout in seconds
    stream=False,             # Enable streaming
    spec=None                 # Custom OpenAPI spec
)
Environment Variables
# Set device preferences
export CUDA_VISIBLE_DEVICES=0,1,2
# Set batch configuration
export LITSERVE_MAX_BATCH_SIZE=8
export LITSERVE_BATCH_TIMEOUT=0.05
# Set worker configuration
export LITSERVE_WORKERS_PER_DEVICE=2
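Whether LitServe picks these variables up automatically may depend on the version, so here is a sketch of wiring them in explicitly from your own launch script (api is assumed to be defined as in the earlier examples):

import os

# Fall back to the defaults used elsewhere in this guide
max_batch_size = int(os.getenv("LITSERVE_MAX_BATCH_SIZE", "4"))
batch_timeout = float(os.getenv("LITSERVE_BATCH_TIMEOUT", "0.1"))
workers_per_device = int(os.getenv("LITSERVE_WORKERS_PER_DEVICE", "1"))

server = ls.LitServer(
    api,
    max_batch_size=max_batch_size,
    batch_timeout=batch_timeout,
    workers_per_device=workers_per_device
)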
Custom Configuration File
# config.py
class Config:
    MAX_BATCH_SIZE = 8
    BATCH_TIMEOUT = 0.1
    WORKERS_PER_DEVICE = 2
    ACCELERATOR = "auto"
    TIMEOUT = 60

# Use in your API
from config import Config

server = ls.LitServer(
    api,
    max_batch_size=Config.MAX_BATCH_SIZE,
    batch_timeout=Config.BATCH_TIMEOUT,
    workers_per_device=Config.WORKERS_PER_DEVICE
)
Examples
1. Image Classification API
import litserve as ls
import torch
from torchvision import models, transforms
from PIL import Image
import base64
import io
class ImageClassificationAPI(ls.LitAPI):
    def setup(self, device):
        # Load pre-trained ResNet model
        self.device = device
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.model.to(device)
        self.model.eval()

        # Image preprocessing
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Load ImageNet class labels
        with open('imagenet_classes.txt') as f:
            self.classes = [line.strip() for line in f.readlines()]

    def decode_request(self, request):
        # Decode base64 image
        encoded_image = request["image"]
        image_bytes = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        return self.preprocess(image).unsqueeze(0)

    def predict(self, x):
        with torch.no_grad():
            output = self.model(x.to(self.device))
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
        return probabilities

    def encode_response(self, output):
        # Get top 5 predictions
        top5_prob, top5_catid = torch.topk(output, 5)
        results = []
        for i in range(top5_prob.size(0)):
            results.append({
                "class": self.classes[top5_catid[i]],
                "probability": float(top5_prob[i])
            })
        return {"predictions": results}

if __name__ == "__main__":
    api = ImageClassificationAPI()
    server = ls.LitServer(api, accelerator="auto", max_batch_size=4)
    server.run(port=8000)
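A client sketch for this API: the image file path is illustrative, and the payload matches what decode_request expects (a base64-encoded string under the "image" key):

import base64
import requests

# Read and base64-encode an image file
with open("cat.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    "http://localhost:8000/predict",
    json={"image": encoded}
)
print(response.json())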
2. Text Embedding API
import litserve as ls
from sentence_transformers import SentenceTransformer
import numpy as np
class TextEmbeddingAPI(ls.LitAPI):
    def setup(self, device):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.model.to(device)

    def decode_request(self, request):
        texts = request.get("texts", [])
        if isinstance(texts, str):
            texts = [texts]
        return texts

    def predict(self, texts):
        embeddings = self.model.encode(texts)
        return embeddings

    def encode_response(self, embeddings):
        return {
            "embeddings": embeddings.tolist(),
            "dimension": embeddings.shape[1] if len(embeddings.shape) > 1 else len(embeddings)
        }

if __name__ == "__main__":
    api = TextEmbeddingAPI()
    server = ls.LitServer(api, accelerator="auto", max_batch_size=32)
    server.run(port=8000)
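A client sketch that compares two returned embeddings with cosine similarity (the sentences are illustrative):

import numpy as np
import requests

response = requests.post(
    "http://localhost:8000/predict",
    json={"texts": ["LitServe serves models", "LitServe deploys ML models"]}
)
embeddings = np.array(response.json()["embeddings"])

# Cosine similarity between the two sentence embeddings
a, b = embeddings[0], embeddings[1]
similarity = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"Cosine similarity: {similarity:.3f}")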
3. Multi-Modal API (Text + Image)
import litserve as ls
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import base64
import io

class ImageCaptioningAPI(ls.LitAPI):
    def setup(self, device):
        self.device = device
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model.to(device)

    def decode_request(self, request):
        # Handle both image and optional text input
        encoded_image = request["image"]
        text_prompt = request.get("text", "")

        image_bytes = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        return {"image": image, "text": text_prompt}

    def predict(self, inputs):
        processed = self.processor(
            images=inputs["image"],
            text=inputs["text"],
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(**processed, max_length=50)

        caption = self.processor.decode(outputs[0], skip_special_tokens=True)
        return caption

    def encode_response(self, caption):
        return {"caption": caption}
Best Practices
1. Resource Management
class OptimizedAPI(ls.LitAPI):
    def setup(self, device):
        # Use torch.jit.script for optimization
        self.model = torch.jit.script(your_model)
        self.model.to(device)
        self.model.eval()

        # Enable mixed precision (autocast) when running on a CUDA device;
        # device is passed in as a string such as "cuda:0" or "cpu"
        self.use_amp = str(device).startswith("cuda")

        # Pre-allocate tensors for common shapes
        self.common_shapes = {}

    def predict(self, x):
        # Use autocast for mixed-precision inference
        if self.use_amp:
            with torch.cuda.amp.autocast():
                return self.model(x)
        return self.model(x)
2. Error Handling
class RobustAPI(ls.LitAPI):
    def decode_request(self, request):
        try:
            # Validate required fields
            if "input" not in request:
                raise ls.ValidationError("Missing required field: input")
            data = request["input"]

            # Type validation
            if not isinstance(data, (str, list)):
                raise ls.ValidationError("Input must be string or list")
            return data
        except Exception as e:
            raise ls.ValidationError(f"Request parsing failed: {str(e)}")

    def predict(self, x):
        try:
            result = self.model(x)
            return result
        except torch.cuda.OutOfMemoryError:
            raise ls.ServerError("GPU memory exhausted")
        except Exception as e:
            raise ls.ServerError(f"Prediction failed: {str(e)}")
3. Monitoring and Logging
import logging
import time
class MonitoredAPI(ls.LitAPI):
    def setup(self, device):
        self.logger = logging.getLogger(__name__)
        self.request_count = 0
        # Your model setup

    def decode_request(self, request):
        self.request_count += 1
        self.logger.info(f"Processing request #{self.request_count}")
        return request["data"]

    def predict(self, x):
        start_time = time.time()
        result = self.model(x)
        inference_time = time.time() - start_time

        self.logger.info(f"Inference completed in {inference_time:.3f}s")
        return result
4. Model Versioning
import time

class VersionedAPI(ls.LitAPI):
    def setup(self, device):
        self.version = "1.0.0"
        self.model_path = f"models/model_v{self.version}"
        # Load versioned model

    def encode_response(self, output):
        return {
            "result": output,
            "model_version": self.version,
            "timestamp": time.time()
        }
Troubleshooting
Common Issues and Solutions
1. CUDA Out of Memory
# Solution: Reduce the batch size
server = ls.LitServer(api, max_batch_size=2)  # Reduce batch size

# Or clear cached memory in your predict method
def predict(self, x):
    # Release unused cached GPU memory
    torch.cuda.empty_cache()
    result = self.model(x)
    return result
2. Slow Inference
# Enable model optimization
def setup(self, device):
    self.model.eval()  # Set to evaluation mode
    self.model = torch.jit.script(self.model)  # JIT compilation

    # Use half precision if supported (device is a string such as "cuda:0")
    if str(device).startswith("cuda"):
        self.model.half()
3. Request Timeout
# Increase timeout settings
server = ls.LitServer(
    api,
    timeout=60,         # Increase request timeout
    batch_timeout=1.0   # Increase batch timeout
)
4. Port Already in Use
# Check and kill existing processes (shell=True requires a single command string)
import subprocess
subprocess.run("lsof -ti:8000 | xargs kill -9", shell=True)

# Or use a different port
server.run(port=8001)
Performance Optimization Tips
- Use appropriate batch sizes: Start with small batches and gradually increase
- Enable GPU acceleration: Use accelerator="auto" for automatic GPU detection
- Optimize model loading: Load models once in setup(), not in predict()
- Use mixed precision: Enable autocast for GPU inference
- Profile your code: Use tools like torch.profiler to identify bottlenecks (see the sketch below)
- Cache preprocessed data: Store frequently used transformations
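A minimal profiling sketch using torch.profiler, run standalone outside the server; the toy model and input shapes are placeholders for your own model and a representative request:

import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(512, 512).eval()   # placeholder model
x = torch.randn(8, 512)                    # representative input batch

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with torch.no_grad():
        model(x)

# Print the most expensive operations
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))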
Deployment Checklist
This guide covers the essential aspects of using LitServe for deploying AI models. For the most up-to-date information, always refer to the official LitServe documentation.