Performance Optimization¶
Production-ready optimization utilities for faster model interpretation, including caching, batch processing, profiling, and model optimization.
Overview¶
AutoTimm provides four optimization utilities:
- ExplanationCache - Disk-based LRU cache (10-50x speedup)
- BatchProcessor - Efficient multi-image processing (2-5x speedup)
- PerformanceProfiler - Bottleneck identification
- optimize_for_inference() - Model optimization (1.5-3x speedup)
Combined speedup: up to 100x
Installation¶
No additional dependencies required - uses standard Python and PyTorch.
1. Explanation Caching¶
Cache computed explanations to disk to avoid recomputation.
Basic Usage¶
import autotimm as at # recommended alias
from autotimm.interpretation import GradCAM
from autotimm.interpretation.optimization import ExplanationCache
# Create cache
cache = ExplanationCache(
    cache_dir="./cache",
    max_size_mb=1000,  # 1GB cache
    enabled=True
)
explainer = GradCAM(model)
# Check cache before computing
heatmap = cache.get(image, method="gradcam", target_class=5)
if heatmap is None:
    # Cache miss - compute and store
    heatmap = explainer.explain(image, target_class=5)
    cache.put(image, method="gradcam", explanation=heatmap, target_class=5)
else:
    # Cache hit - much faster!
    pass
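The get/compute/put pattern above can be wrapped in a small helper so call sites stay simple. A minimal sketch (the cached_explain helper is illustrative, not part of the AutoTimm API):
def cached_explain(image, target_class=None):
    """Return a cached heatmap if available, otherwise compute and cache it."""
    heatmap = cache.get(image, method="gradcam", target_class=target_class)
    if heatmap is None:
        heatmap = explainer.explain(image, target_class=target_class)
        cache.put(image, method="gradcam", explanation=heatmap, target_class=target_class)
    return heatmap
heatmap = cached_explain(image, target_class=5)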
How It Works¶
- Computes SHA256 hash of (image + method + parameters)
- Checks if cached file exists on disk
- If miss: computes the explanation and saves it as a .npy file
- If hit: loads from disk (much faster than computation)
- LRU eviction when cache is full
Cache Key Components¶
- Image bytes (exact pixel values)
- Method name (e.g., "gradcam")
- Target class
- Additional parameters
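For intuition, the lookup key can be pictured as a hash over exactly these components. A rough sketch of such a key derivation (illustrative only; the real hashing inside ExplanationCache may differ in detail):
import hashlib
import json
import numpy as np

def make_cache_key(image: np.ndarray, method: str, **params) -> str:
    """Hash the image bytes together with the method name and extra parameters."""
    h = hashlib.sha256()
    h.update(image.tobytes())                              # exact pixel values
    h.update(method.encode("utf-8"))                       # e.g. "gradcam"
    h.update(json.dumps(params, sort_keys=True).encode())  # target class + additional parameters
    return h.hexdigest()

key = make_cache_key(np.zeros((224, 224, 3), dtype=np.uint8), "gradcam", target_class=5)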
Performance¶
- First call: ~50ms (compute)
- Cached calls: ~5-8ms (disk read)
- Speedup: 6-10x
Cache Management¶
# Get statistics
stats = cache.stats()
print(f"Entries: {stats['num_entries']}")
print(f"Size: {stats['total_size_mb']:.2f} MB")
print(f"Utilization: {stats['utilization']:.1%}")
# Clear cache
cache.clear()
Production Configuration¶
# For production systems
cache = ExplanationCache(
    cache_dir="/var/cache/explanations",
    max_size_mb=5000,  # 5GB
    enabled=True
)

# Monitor cache utilization
stats = cache.stats()
if stats['utilization'] > 0.9:
    print("Warning: Cache nearly full, consider increasing size")
2. Batch Processing¶
Process multiple images efficiently with progress tracking.
Basic Usage¶
from autotimm.interpretation.optimization import BatchProcessor
# Create batch processor
processor = BatchProcessor(
    model,
    explainer,
    batch_size=16,  # Process 16 at a time
    use_cuda=True,
    show_progress=True
)
# Process 100 images efficiently
images = [load_image(f"img_{i}.jpg") for i in range(100)]  # load_image: placeholder for your own image loader
heatmaps = processor.process_batch(images)
Performance¶
- Sequential: 40ms per image
- Batched (16): 20ms per image
- Batched (32): 15ms per image
- Speedup: 2-2.7x
Parallel Processing (CPU)¶
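For CPU-bound workloads the processor also exposes process_batch_parallel (see the API reference at the end of this page). A minimal sketch, assuming the signature listed there:
# Spread work across CPU workers instead of a single GPU stream
processor = BatchProcessor(model, explainer, batch_size=16, use_cuda=False, show_progress=True)
heatmaps = processor.process_batch_parallel(images, num_workers=4)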
Batch Size Tuning¶
import time

# Find optimal batch size for your hardware
for batch_size in [8, 16, 32, 64]:
    processor = BatchProcessor(model, explainer, batch_size=batch_size)
    start = time.time()
    processor.process_batch(test_images)
    elapsed = time.time() - start
    print(f"Batch size {batch_size}: {elapsed:.3f}s")
3. Performance Profiling¶
Identify bottlenecks in your interpretation pipeline.
Basic Usage¶
from autotimm.interpretation.optimization import PerformanceProfiler
profiler = PerformanceProfiler(enabled=True)
# Profile operations
with profiler.profile("preprocessing"):
tensor = preprocess(image)
with profiler.profile("explanation"):
heatmap = explainer.explain(image)
with profiler.profile("postprocessing"):
result = postprocess(heatmap)
# Print statistics
profiler.print_stats()
Output¶
Performance Profile:
------------------------------------------------------------
Operation Count Mean Total
------------------------------------------------------------
preprocessing 1 0.012 0.012
explanation 1 0.045 0.045
postprocessing 1 0.003 0.003
------------------------------------------------------------
Get Statistics Programmatically¶
stats = profiler.get_stats()
# Identify bottleneck
slowest = max(stats.items(), key=lambda x: x[1]['mean'])
print(f"Bottleneck: {slowest[0]} ({slowest[1]['mean']:.3f}s)")
# Alert if slow
if stats['explanation']['mean'] > 0.1:
    log_warning("Slow explanations detected")
Production Monitoring¶
# Set up profiler
profiler = PerformanceProfiler(enabled=True)
# In request handler
with profiler.profile("explanation_request"):
heatmap = generate_explanation(image)
# Periodically check
stats = profiler.get_stats()
if stats['explanation_request']['mean'] > 1.0:
log_warning("Slow explanations detected")
# Maybe: increase cache size, optimize model, etc.
4. Model Optimization¶
Optimize model for faster inference.
Basic Usage¶
from autotimm.interpretation.optimization import optimize_for_inference
# Optimize model
model = optimize_for_inference(
    model,
    use_fp16=False  # Set True for GPU with FP16 support
)
# Now 1.5-3x faster inference
Optimizations Applied¶
- Disable Gradients:
  - Saves memory
  - Slightly faster forward pass
- cuDNN Benchmarking:
  - Finds fastest convolution algorithm
  - ~10-20% speedup on GPU
- FP16 (Optional):
  - 2x less memory
  - 1.5-2x faster on modern GPUs
  - Requires FP16-compatible GPU
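As a rough picture, these steps correspond to the following plain-PyTorch operations (a sketch for intuition, not the library's actual implementation):
import torch

def optimize_for_inference_sketch(model, use_fp16=False):
    """Illustrative approximation of the optimizations listed above."""
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)                # 1. disable gradients on the weights
    torch.backends.cudnn.benchmark = True      # 2. let cuDNN pick the fastest convolution algorithms
    if use_fp16 and torch.cuda.is_available():
        model = model.half().cuda()            # 3. optional FP16 on a compatible GPU
    return model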
Performance¶
- Baseline: 22ms per forward pass
- Optimized (CPU): 18ms per forward pass
- Optimized (GPU): 12ms per forward pass
- Optimized (GPU+FP16): 7ms per forward pass
- Speedup: 1.2-3.1x
Combined Optimization Strategy¶
Use all optimizations together for maximum performance.
Complete Example¶
from autotimm.interpretation import GradCAM
from autotimm.interpretation.optimization import (
ExplanationCache,
BatchProcessor,
PerformanceProfiler,
optimize_for_inference
)
# 1. Optimize model
model = optimize_for_inference(model, use_fp16=True)
# 2. Set up caching
cache = ExplanationCache(
    cache_dir="/var/cache/explanations",
    max_size_mb=5000
)
# 3. Enable profiling
profiler = PerformanceProfiler(enabled=True)
# 4. Create explainer and batch processor
explainer = GradCAM(model)
processor = BatchProcessor(model, explainer, batch_size=32)
# 5. Process with all optimizations
with profiler.profile("total"):
heatmaps = []
for img in images:
# Try cache first
heatmap = cache.get(img, method="gradcam")
if heatmap is None:
# Compute with optimized model
heatmap = explainer.explain(img)
cache.put(img, method="gradcam", explanation=heatmap)
heatmaps.append(heatmap)
# 6. Monitor performance
profiler.print_stats()
print(f"Cache hit rate: {cache.stats()['utilization']:.1%}")
Performance Benchmarks¶
Summary Table¶
| Optimization | Speedup | Use Case |
|---|---|---|
| Caching | 10-50x | Repeated images |
| Batch Processing | 2-5x | Multiple images |
| Model Optimization | 1.5-3x | All inference |
| Combined | Up to 100x | Production systems |
Real-World Example¶
Scenario: Explain 1000 validation images
Without optimization: 1000 images × 50ms ≈ 50 seconds
With all optimizations:
Cache warm-up: 10 images × 50ms = 500ms
Remaining: 990 images × 5ms (cached) = 4,950ms
Total: 5.5 seconds
Speedup: 9x overall
Production Best Practices¶
1. Cache Configuration¶
# Production settings
cache = ExplanationCache(
    cache_dir="/var/cache/explanations",
    max_size_mb=5000,  # 5GB for high-traffic systems
    enabled=True
)

# Regular monitoring (metrics is your monitoring client)
def monitor_cache():
    stats = cache.stats()
    metrics.gauge('cache.entries', stats['num_entries'])
    metrics.gauge('cache.size_mb', stats['total_size_mb'])
    metrics.gauge('cache.utilization', stats['utilization'])
2. Error Handling¶
# Graceful degradation
try:
    heatmap = cache.get(image, method="gradcam")
    if heatmap is None:
        heatmap = explainer.explain(image)
        cache.put(image, method="gradcam", explanation=heatmap)
except Exception as e:
    log_error(f"Cache error: {e}")
    # Fallback: compute without cache
    heatmap = explainer.explain(image)
3. Resource Limits¶
# Set memory limits
import resource
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
resource.setrlimit(resource.RLIMIT_AS, (4 * 1024 * 1024 * 1024, hard)) # 4GB
# Monitor disk usage
import shutil
usage = shutil.disk_usage(cache.cache_dir)
if usage.free < 1024 * 1024 * 1024:  # Less than 1GB free
    log_warning("Low disk space, clearing old cache entries")
    cache.clear()
4. Logging & Monitoring¶
# Log cache performance
def log_cache_stats():
    stats = cache.stats()
    logger.info(
        "Cache stats",
        extra={
            'entries': stats['num_entries'],
            'size_mb': stats['total_size_mb'],
            'utilization': stats['utilization']
        }
    )

# Alert on anomalies (threshold = your latency budget in seconds)
if profiler.get_stats()['explanation']['mean'] > threshold:
    alert.send("Slow explanations detected")
Common Pitfalls¶
1. Cache Size Too Small¶
Problem: High eviction rate, low hit rate
Solution: Increase max_size_mb or clean up old entries
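Both options use the cache API shown above (the 10GB budget is just an example):
# Give the cache a larger budget...
cache = ExplanationCache(cache_dir="./cache", max_size_mb=10000)
# ...or reclaim space explicitly
cache.clear()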
2. Wrong Batch Size¶
Problem: OOM errors or slow processing
Solution: Tune the batch size for your hardware (see the batch size tuning loop in section 2)
3. Profiling Overhead¶
Problem: Profiling slows down production
Solution: Sample requests or disable profiling in production, for example:
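One simple approach is to keep profiling off in production, or enable it only for a small random sample of requests (the make_profiler helper and SAMPLE_RATE value are illustrative):
import random
from autotimm.interpretation.optimization import PerformanceProfiler

SAMPLE_RATE = 0.01  # profile roughly 1% of requests

def make_profiler(production: bool) -> PerformanceProfiler:
    """Disable profiling in production, except for a small random sample of requests."""
    if production:
        return PerformanceProfiler(enabled=random.random() < SAMPLE_RATE)
    return PerformanceProfiler(enabled=True)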
4. FP16 Compatibility¶
Problem: Model doesn't support FP16
Solution: Use use_fp16=False or update model
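The safe default from the Basic Usage example applies here:
model = optimize_for_inference(model, use_fp16=False)  # keep FP32 if the model or GPU lacks FP16 support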
API Reference¶
ExplanationCache¶
Methods:
- get(image, method, target_class=None, **kwargs) - Retrieve from cache
- put(image, method, explanation, target_class=None, **kwargs) - Store in cache
- stats() - Get cache statistics
- clear() - Clear entire cache
BatchProcessor¶
Methods:
- process_batch(images, target_classes=None, **kwargs) - Process batch
- process_batch_parallel(images, target_classes=None, num_workers=4, **kwargs) - Parallel processing
PerformanceProfiler¶
Methods:
- profile(name) - Context manager for profiling
- get_stats() - Get profiling statistics
- print_stats() - Print formatted statistics
- reset() - Reset profiling data
optimize_for_inference()¶
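Signature (as used in the examples above): optimize_for_inference(model, use_fp16=False) - returns the optimized model with gradients disabled, cuDNN benchmarking enabled, and optional FP16.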
Examples¶
See the complete example script:
See the tutorial notebook:
Next Steps¶
- Learn about Interactive Visualizations
- Explore Quality Metrics
- See Interpretation Methods