LitServe is a high-performance, flexible AI model serving framework designed to deploy machine learning models with minimal code. It provides automatic batching, GPU acceleration, and easy scaling capabilities.

Installation

# Install LitServe
pip install litserve

# For GPU support
# Quote requirements with extras so shells such as zsh don't expand [...] as a glob.
# NOTE(review): confirm the "gpu" and "dev" extras exist in the LitServe release you install.
pip install "litserve[gpu]"

# For development dependencies
pip install "litserve[dev]"

Basic Usage

1. Creating Your First LitServe API

import litserve as ls
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class SimpleTextGenerator(ls.LitAPI):
    """Serve GPT-2 text completion: ``{"prompt": ...}`` in, ``{"generated_text": ...}`` out."""

    def setup(self, device):
        """Load tokenizer and model once at server startup and place the model on ``device``."""
        self.device = device  # remembered so predict() can put inputs next to the model
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request):
        """Extract the prompt string from the request body."""
        return request["prompt"]

    def predict(self, x):
        """Generate a continuation of prompt ``x`` and return it as text."""
        # Tokenize to input_ids + attention_mask and move them to the model's device;
        # CPU tensors fed to a GPU model raise a device-mismatch error.
        inputs = self.tokenizer(x, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,  # budget for new tokens; max_length would also count the prompt
                num_return_sequences=1,
                do_sample=True,  # temperature only has an effect when sampling
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id,  # GPT-2 defines no pad token
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def encode_response(self, output):
        """Wrap the generated text in a JSON-serialisable dict."""
        return {"generated_text": output}

# Create and start server
if __name__ == "__main__":
    api = SimpleTextGenerator()
    # predict() above handles one prompt per call, so leave batching off (the default).
    # With max_batch_size > 1, LitServe hands predict() a list of decoded requests instead.
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)

2. Making Requests to Your API

import requests

# Test the API
response = requests.post(
    "http://localhost:8000/predict",
    json={"prompt": "The future of AI is"},
    timeout=60,  # generation can be slow, but never wait forever on a stuck server
)
response.raise_for_status()  # surface HTTP errors instead of parsing an error body as a result
print(response.json())

Core Concepts

LitAPI Class Structure

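Every LitServe API must inherit from ls.LitAPI and implement these core methods. setup() runs once when the server starts; for each request, LitServe then calls decode_request(), predict(), and encode_response(), in that order: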

class MyAPI(ls.LitAPI):
    def setup(self, device):
        """Prepare everything inference needs on ``device``: models, weights, preprocessing."""
        ...

    def decode_request(self, request):
        """Turn the raw request payload into model input, validating it along the way."""
        ...

    def predict(self, x):
        """Run the model on the decoded input ``x`` and hand back its raw output."""
        ...

    def encode_response(self, output):
        """Convert raw model output into something that can be returned over HTTP."""
        ...

Optional Methods

class AdvancedAPI(ls.LitAPI):
    # Each hook below is an optional pass-through: overriding it here changes
    # nothing until real logic replaces the identity return.

    def batch(self, inputs):
        """Merge a list of decoded requests into one model input (identity here)."""
        return inputs

    def unbatch(self, output):
        """Split a batched model output back into per-request outputs (identity here)."""
        return output

    def preprocess(self, input_data):
        """Extra input-transformation hook (identity here).

        NOTE(review): the request flow used elsewhere in this guide is
        decode_request -> batch -> predict -> unbatch -> encode_response;
        confirm your LitServe version calls preprocess, or invoke it yourself.
        """
        return input_data

    def postprocess(self, output):
        """Extra output-transformation hook (identity here).

        NOTE(review): as with preprocess, confirm LitServe invokes this hook.
        """
        return output

Advanced Features

1. Custom Batching

class BatchedImageClassifier(ls.LitAPI):
    """ResNet-50 ImageNet classifier with explicit batch/unbatch hooks."""

    def setup(self, device):
        """Load pretrained ResNet-50 onto ``device`` and build the preprocessing pipeline."""
        from torchvision import models, transforms

        self.device = device  # predict() moves each batch here
        # `pretrained=True` is deprecated in torchvision >= 0.13; request explicit weights.
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.model.to(device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def decode_request(self, request):
        """Decode the base64 ``image`` field into a (1, 3, 224, 224) tensor."""
        from PIL import Image
        import base64
        import io

        image_data = base64.b64decode(request["image"])
        # Force 3 channels: grayscale, palette or RGBA uploads would otherwise
        # not match the 3-channel Normalize above.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        return self.transform(image).unsqueeze(0)

    def batch(self, inputs):
        """Concatenate per-request (1, C, H, W) tensors into one (N, C, H, W) batch."""
        return torch.cat(inputs, dim=0)

    def predict(self, batch):
        """Return class probabilities of shape (N, num_classes) for the batch."""
        with torch.no_grad():
            # Decoded tensors are built on the CPU; the model lives on self.device.
            outputs = self.model(batch.to(self.device))
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        return probabilities

    def unbatch(self, output):
        """Split (N, num_classes) probabilities into one (1, num_classes) tensor per request."""
        return [pred.unsqueeze(0) for pred in output]

    def encode_response(self, output):
        """Report the top-1 class index and its probability."""
        confidence, predicted = torch.max(output, 1)
        return {
            "class_id": predicted.item(),
            "confidence": confidence.item(),
        }

2. Streaming Responses

class StreamingChatAPI(ls.LitAPI):
    """Stream DialoGPT replies one token at a time (serve with ``stream=True``)."""

    # Upper bound on tokens generated per request.
    max_new_tokens = 50

    def setup(self, device):
        """Load DialoGPT once and place it on ``device`` in eval mode."""
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.device = device  # predict() moves input ids here
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
        self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
        self.model.to(device)
        self.model.eval()  # disable dropout so sampling depends only on the logits

    def decode_request(self, request):
        """Extract the user's message from the request body."""
        return request["message"]

    def predict(self, x):
        """Yield ``{"token": text}`` dicts, one per sampled token, until EOS or the token budget."""
        # DialoGPT expects each turn to end with EOS; without it the model keeps
        # writing the user's message instead of replying.
        input_ids = self.tokenizer.encode(
            x + self.tokenizer.eos_token, return_tensors="pt"
        ).to(self.device)

        for _ in range(self.max_new_tokens):
            # Keep no_grad scoped to the computation: a `with` left open across
            # `yield` would keep autograd disabled for the consumer too.
            # (The full sequence is re-run each step, no KV cache, for simplicity.)
            with torch.no_grad():
                next_token_logits = self.model(input_ids).logits[:, -1, :]
                next_token = torch.multinomial(
                    torch.softmax(next_token_logits, dim=-1),
                    num_samples=1,
                )

            # Stop before emitting the end-of-sequence marker itself.
            if next_token.item() == self.tokenizer.eos_token_id:
                break

            input_ids = torch.cat([input_ids, next_token], dim=-1)
            yield {"token": self.tokenizer.decode(next_token[0])}

    def encode_response(self, output):
        """Forward each streamed chunk produced by predict() unchanged."""
        for chunk in output:
            yield chunk

# Enable streaming
if __name__ == "__main__":
    api = StreamingChatAPI()
    server = ls.LitServer(api, accelerator="auto", stream=True)
    server.run(port=8000)

3. Multiple GPU Support

# Automatic multi-GPU scaling
server = ls.LitServer(
    api,
    accelerator="auto",
    devices="auto",  # Use all available GPUs
    max_batch_size=8
)

# Specific GPUs: LitServe's `devices` is a device count (or "auto"), not a list
# of indices. Choose the physical GPUs with CUDA_VISIBLE_DEVICES instead, e.g.
#   CUDA_VISIBLE_DEVICES=0,1,2 python server.py
server = ls.LitServer(
    api,
    accelerator="gpu",
    devices=3,  # the three visible GPUs (physical 0, 1 and 2 with the env var above)
    max_batch_size=16
)

4. Authentication and Middleware

class AuthenticatedAPI(ls.LitAPI):
    """Example API that rejects requests lacking a valid API key.

    NOTE(review): confirm with your LitServe version which hook it calls for
    authentication (recent releases document ``authorize``) and which
    exception it maps to HTTP 401. This example keeps the
    ``authenticate`` / ``ls.AuthenticationError`` names used in this guide.
    """

    # Demo keys only: load real keys from an environment variable or secret store.
    VALID_KEYS = ("your-secret-key-1", "your-secret-key-2")

    def setup(self, device):
        # Your model setup
        pass

    def authenticate(self, request):
        """Raise ``ls.AuthenticationError`` unless the Authorization header holds a valid key."""
        api_key = request.headers.get("Authorization")
        if not api_key or not self.validate_api_key(api_key):
            raise ls.AuthenticationError("Invalid API key")
        return True

    def validate_api_key(self, api_key):
        """Return True if ``api_key`` (optionally prefixed with ``"Bearer "``) is an accepted key."""
        import hmac

        prefix = "Bearer "
        # Strip only a leading scheme; str.replace would delete "Bearer " anywhere in the value.
        token = api_key[len(prefix):] if api_key.startswith(prefix) else api_key
        # compare_digest runs in time independent of where the strings differ, and
        # every key is checked (no short-circuit), so timing does not reveal key contents.
        matches = [
            hmac.compare_digest(token.encode(), key.encode()) for key in self.VALID_KEYS
        ]
        return any(matches)

    def decode_request(self, request):
        return request["data"]

    def predict(self, x):
        # Your prediction logic
        return f"Processed: {x}"

    def encode_response(self, output):
        return {"result": output}

Configuration Options

Server Configuration

server = ls.LitServer(
    api,                          # the LitAPI instance; pass it positionally (the parameter is not named `api`)
    accelerator="auto",           # "auto", "cpu", "gpu", "mps"
    devices="auto",               # Number of devices, or "auto"
    max_batch_size=4,             # Maximum batch size; > 1 means predict() receives a list
    batch_timeout=0.1,            # Seconds to wait while filling a batch
    workers_per_device=1,         # Workers per device
    timeout=30,                   # Request timeout in seconds
    stream=False,                 # Enable streaming
    spec=None,                    # Optional API spec object; None keeps the plain /predict endpoint
)

Environment Variables

# Set device preferences (limits which GPUs the process can see)
export CUDA_VISIBLE_DEVICES=0,1,2

# Set batch configuration
# NOTE(review): confirm your LitServe version actually reads the LITSERVE_*
# variables below; if it does not, read them yourself (e.g. os.environ) and
# pass the values to ls.LitServer explicitly.
export LITSERVE_MAX_BATCH_SIZE=8
export LITSERVE_BATCH_TIMEOUT=0.05

# Set worker configuration (same NOTE(review) as above)
export LITSERVE_WORKERS_PER_DEVICE=2

Custom Configuration File

# config.py
class Config:
    """Central place for the LitServer settings used when building the server."""

    # Hardware placement
    ACCELERATOR: str = "auto"

    # Request handling (seconds)
    TIMEOUT: int = 60

    # Batching and worker layout
    MAX_BATCH_SIZE: int = 8
    BATCH_TIMEOUT: float = 0.1
    WORKERS_PER_DEVICE: int = 2

# Use in your API: pass every Config setting through, so ACCELERATOR and
# TIMEOUT are not silently ignored
from config import Config

server = ls.LitServer(
    api,
    accelerator=Config.ACCELERATOR,
    max_batch_size=Config.MAX_BATCH_SIZE,
    batch_timeout=Config.BATCH_TIMEOUT,
    workers_per_device=Config.WORKERS_PER_DEVICE,
    timeout=Config.TIMEOUT,
)

Examples

1. Image Classification API

import litserve as ls
import torch
from torchvision import models, transforms
from PIL import Image
import base64
import io

class ImageClassificationAPI(ls.LitAPI):
    """Top-5 ImageNet classification with ResNet-50; works with or without batching."""

    def setup(self, device):
        """Load the model onto ``device``, build preprocessing, and read the class labels."""
        self.device = device  # predict() moves inputs here
        # Load pre-trained ResNet model
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.model.to(device)
        self.model.eval()

        # Image preprocessing
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Load ImageNet class labels, one per line
        with open('imagenet_classes.txt', encoding='utf-8') as f:
            self.classes = [line.strip() for line in f]

    def decode_request(self, request):
        """Decode the base64 ``image`` field into an unbatched (3, 224, 224) tensor."""
        encoded_image = request["image"]
        image_bytes = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # No unsqueeze here: when max_batch_size > 1, LitServe's default batch()
        # stacks these tensors into (N, 3, 224, 224) itself.
        return self.preprocess(image)

    def predict(self, x):
        """Return softmax probabilities: (num_classes,) for one image, (N, num_classes) for a batch."""
        single = x.dim() == 3
        if single:
            x = x.unsqueeze(0)  # the model always expects a batch dimension
        with torch.no_grad():
            logits = self.model(x.to(self.device))
            probabilities = torch.nn.functional.softmax(logits, dim=1)
        return probabilities[0] if single else probabilities

    def encode_response(self, output):
        """Return the five most likely classes with their probabilities."""
        top5_prob, top5_catid = torch.topk(output, 5)
        results = [
            {"class": self.classes[cat_id], "probability": prob}
            for prob, cat_id in zip(top5_prob.tolist(), top5_catid.tolist())
        ]
        return {"predictions": results}

if __name__ == "__main__":
    # Build the API, wrap it in a server configured for batches of up to 4, listen on 8000
    api = ImageClassificationAPI()
    server = ls.LitServer(
        api,
        accelerator="auto",
        max_batch_size=4,
    )
    server.run(port=8000)

2. Text Embedding API

import litserve as ls
from sentence_transformers import SentenceTransformer
import numpy as np

class TextEmbeddingAPI(ls.LitAPI):
    """Sentence embeddings with all-MiniLM-L6-v2; works with or without LitServe batching."""

    def setup(self, device):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.model.to(device)

    def decode_request(self, request):
        """Return the request's ``texts`` as a list of strings (a bare string is wrapped).

        Raises:
            ValueError: if ``texts`` is neither a string nor a list of strings.
        """
        texts = request.get("texts", [])
        if isinstance(texts, str):
            texts = [texts]
        # Strict typing also keeps predict()'s batched/unbatched check unambiguous.
        if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
            raise ValueError("texts must be a string or a list of strings")
        return texts

    def predict(self, texts):
        """Embed one request's texts, or a whole batch of requests.

        Without batching, ``texts`` is one request's list of strings and one
        array is returned. With max_batch_size > 1, LitServe passes a list of
        such lists; every text is then encoded in a single call and the result
        is split back into one array per request.
        """
        if texts and isinstance(texts[0], list):
            sizes = [len(group) for group in texts]
            flat = [text for group in texts for text in group]
            embeddings = self.model.encode(flat)
            per_request, start = [], 0
            for size in sizes:
                per_request.append(embeddings[start:start + size])
                start += size
            return per_request
        return self.model.encode(texts)

    def encode_response(self, embeddings):
        """Serialise embeddings; ``dimension`` comes from the model, so it is right even for empty input."""
        return {
            "embeddings": embeddings.tolist(),
            "dimension": self.model.get_sentence_embedding_dimension(),
        }

if __name__ == "__main__":
    # Build the embedding API and serve it on port 8000 with batches of up to 32
    api = TextEmbeddingAPI()
    server = ls.LitServer(
        api,
        accelerator="auto",
        max_batch_size=32,
    )
    server.run(port=8000)

3. Multi-Modal API (Text + Image)

import litserve as ls
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import base64
import io

class ImageCaptioningAPI(ls.LitAPI):
    """Caption an image with BLIP, optionally conditioned on a text prefix."""

    def setup(self, device):
        """Load the BLIP processor and model, placing the model on ``device`` in eval mode."""
        self.device = device  # predict() moves processor outputs here
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request):
        """Decode the base64 ``image`` field; ``text`` is an optional caption prefix."""
        encoded_image = request["image"]
        text_prompt = request.get("text", "")

        image_bytes = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        return {"image": image, "text": text_prompt}

    def predict(self, inputs):
        """Generate a caption for ``inputs["image"]`` (conditioned on ``inputs["text"]`` if non-empty)."""
        import torch  # imported here so this example does not depend on torch being imported elsewhere

        processed = self.processor(
            images=inputs["image"],
            # An empty prompt means unconditional captioning: pass no text at all.
            text=inputs["text"] or None,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(**processed, max_length=50)

        return self.processor.decode(outputs[0], skip_special_tokens=True)

    def encode_response(self, caption):
        return {"caption": caption}

Best Practices

1. Resource Management

class OptimizedAPI(ls.LitAPI):
    """Serve a TorchScript-compiled model, using mixed precision on CUDA devices."""

    def setup(self, device):
        """Compile the model, place it on ``device``, and decide whether to use autocast."""
        self.device = device  # predict() moves inputs here
        # `your_model` is a placeholder for the nn.Module you construct or load.
        self.model = torch.jit.script(your_model.to(device).eval())

        # LitServe passes the device as a string such as "cuda:0", which has no
        # `.type` attribute; str() also works for a torch.device.
        # No GradScaler: it scales losses for backward passes and has no role in inference.
        self.use_amp = str(device).startswith("cuda")

        # Pre-allocate tensors for common shapes
        self.common_shapes = {}

    def predict(self, x):
        """Run the model on ``x`` without autograd, under CUDA autocast when enabled."""
        x = x.to(self.device)
        with torch.inference_mode():
            if self.use_amp:
                # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
                with torch.autocast(device_type="cuda"):
                    return self.model(x)
            return self.model(x)

2. Error Handling

class RobustAPI(ls.LitAPI):
    """Validate inputs strictly and convert inference failures into server errors.

    NOTE(review): confirm ``ls.ValidationError`` and ``ls.ServerError`` exist in
    your LitServe version and map to the HTTP status codes you expect.
    """

    def decode_request(self, request):
        """Return ``request["input"]`` after checking it is a string or a list.

        Raises:
            ls.ValidationError: if the body is not a JSON object, lacks ``input``,
                or ``input`` has the wrong type.
        """
        # Check the container type first, so `"input" in request` is a key lookup
        # and not a substring test on a string body.
        if not isinstance(request, dict):
            raise ls.ValidationError("Request parsing failed: body must be a JSON object")
        if "input" not in request:
            raise ls.ValidationError("Missing required field: input")

        data = request["input"]

        # Type validation
        if not isinstance(data, (str, list)):
            raise ls.ValidationError("Input must be string or list")

        return data

    def predict(self, x):
        """Run ``self.model`` on ``x``, mapping failures to ``ls.ServerError``."""
        try:
            return self.model(x)
        except torch.cuda.OutOfMemoryError as e:
            raise ls.ServerError("GPU memory exhausted") from e
        except Exception as e:
            # Chain the original exception so logs keep the root-cause traceback.
            raise ls.ServerError(f"Prediction failed: {e}") from e

3. Monitoring and Logging

import logging
import time

class MonitoredAPI(ls.LitAPI):
    """Log a running request counter and the latency of each inference call."""

    def setup(self, device):
        self.logger = logging.getLogger(__name__)
        # Plain instance attribute: presumably each worker keeps its own count
        # (NOTE(review): confirm how your deployment maps workers to processes).
        self.request_count = 0
        # Your model setup (must assign self.model, which predict() calls)

    def decode_request(self, request):
        self.request_count += 1
        # %-style arguments let logging skip formatting when INFO is disabled.
        self.logger.info("Processing request #%d", self.request_count)
        return request["data"]

    def predict(self, x):
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock, which can jump and corrupt a measured duration.
        start_time = time.perf_counter()
        result = self.model(x)
        inference_time = time.perf_counter() - start_time

        self.logger.info("Inference completed in %.3fs", inference_time)
        return result

4. Model Versioning

class VersionedAPI(ls.LitAPI):
    # Every response is stamped with the model version and the time it was produced.

    def setup(self, device):
        version = "1.0.0"
        self.version = version
        # Weights live in a directory named after the version.
        self.model_path = "models/model_v" + version
        # Load versioned model

    def encode_response(self, output):
        payload = {"result": output}
        payload["model_version"] = self.version
        payload["timestamp"] = time.time()
        return payload

Troubleshooting

Common Issues and Solutions

1. CUDA Out of Memory

# Solution: reduce the batch size so fewer requests share one forward pass
# (gradient checkpointing only saves memory during training, not inference)
server = ls.LitServer(api, max_batch_size=2)  # Reduce batch size

# Or run inference without autograd, and clear the CUDA cache only after an OOM
def predict(self, x):
    """Run ``self.model`` on ``x`` without building an autograd graph.

    Emptying the CUDA cache before every request only hands cached blocks back
    to the driver and forces the allocator to re-acquire them, so it is done
    here only after an out-of-memory error, before that error is re-raised.
    """
    with torch.inference_mode():  # nothing is saved for backward -> lower peak memory
        try:
            return self.model(x)
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()  # release cached blocks so later requests can fit
            raise

2. Slow Inference

# Enable model optimization
def setup(self, device):
    self.model.eval()  # Set to evaluation mode
    self.model = torch.jit.script(self.model)  # JIT compilation
    
    # Use half precision if supported
    if device.type == 'cuda':
        self.model.half()

3. Request Timeout

# Give slow requests more room: a longer per-request timeout and a longer batch-filling window
server = ls.LitServer(
    api,
    batch_timeout=1.0,  # seconds to wait while a batch fills up
    timeout=60,         # seconds before an individual request times out
)

4. Port Already in Use

# Check and kill existing processes listening on port 8000.
# Pass plain argument lists with no shell: a list combined with shell=True runs
# only its first item ("lsof"), so the "| xargs kill" part never executed.
import subprocess

found = subprocess.run(["lsof", "-ti:8000"], capture_output=True, text=True, check=False)
for pid in found.stdout.split():
    subprocess.run(["kill", "-9", pid], check=False)

# Or use a different port
server.run(port=8001)

Performance Optimization Tips

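  1. Use appropriate batch sizes: start with small batches and increase gradually. With max_batch_size > 1, predict() receives a batch of inputs rather than a single one, so make sure it handles that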
  2. Enable GPU acceleration: Use accelerator="auto" for automatic GPU detection
  3. Optimize model loading: Load models once in setup(), not in predict()
  4. Use mixed precision: Enable autocast for GPU inference
  5. Profile your code: Use tools like torch.profiler to identify bottlenecks
  6. Cache preprocessed data: Store frequently used transformations

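Deployment Checklist

  1. Load models once in setup(), and move both the model and each request's tensors to the device LitServe assigns.
  2. Set max_batch_size > 1 only when predict() handles a batch of inputs.
  3. Set explicit timeouts on the server (timeout, batch_timeout) and on every client request.
  4. Protect the endpoint with authentication and keep API keys out of source code.
  5. Send a real client request to the running server before deploying.

This guide covers the essential aspects of using LitServe for deploying AI models. For the most up-to-date information, always refer to the official LitServe documentation.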