Performance Optimization

Production-ready optimization utilities for faster model interpretation, including caching, batch processing, profiling, and model optimization.

Overview

AutoTimm provides four optimization utilities:

  1. ExplanationCache - Disk-based LRU cache (10-50x speedup)
  2. BatchProcessor - Efficient multi-image processing (2-5x speedup)
  3. PerformanceProfiler - Bottleneck identification
  4. optimize_for_inference() - Model optimization (1.5-3x speedup)

Combined speedup: up to 100x

Installation

No additional dependencies are required; these utilities use only standard Python and PyTorch.

1. Explanation Caching

Cache computed explanations to disk to avoid recomputation.

Basic Usage

from autotimm.interpretation import GradCAM
from autotimm.interpretation.optimization import ExplanationCache

# Create cache
cache = ExplanationCache(
    cache_dir="./cache",
    max_size_mb=1000,  # 1GB cache
    enabled=True
)

explainer = GradCAM(model)

# Check cache before computing
heatmap = cache.get(image, method="gradcam", target_class=5)
if heatmap is None:
    # Cache miss - compute and store
    heatmap = explainer.explain(image, target_class=5)
    cache.put(image, method="gradcam", explanation=heatmap, target_class=5)
# Cache hit - heatmap was loaded from disk, much faster than recomputing

How It Works

  1. Computes SHA256 hash of (image + method + parameters)
  2. Checks if cached file exists on disk
  3. If miss: computes explanation and saves as .npy file
  4. If hit: loads from disk (much faster than computation)
  5. LRU eviction when cache is full

Cache Key Components

  • Image bytes (exact pixel values)
  • Method name (e.g., "gradcam")
  • Target class
  • Additional parameters
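
Putting these pieces together, the key derivation can be sketched as follows (make_cache_key is a hypothetical helper for illustration, assuming the image is a NumPy array; it is not part of the AutoTimm API):

import hashlib
import numpy as np

def make_cache_key(image: np.ndarray, method: str, **params) -> str:
    # Hash the raw pixel bytes plus the method name and sorted parameters,
    # so identical inputs always map to the same cache entry.
    h = hashlib.sha256()
    h.update(image.tobytes())
    h.update(method.encode())
    for key in sorted(params):
        h.update(f"{key}={params[key]}".encode())
    return h.hexdigest()

# The digest doubles as the on-disk filename, e.g. f"{digest}.npy".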

Performance

First call:   ~50ms (compute)
Cached calls: ~5-8ms (disk read)
Speedup:      6-10x

Cache Management

# Get statistics
stats = cache.stats()
print(f"Entries: {stats['num_entries']}")
print(f"Size: {stats['total_size_mb']:.2f} MB")
print(f"Utilization: {stats['utilization']:.1%}")

# Clear cache
cache.clear()

Production Configuration

# For production systems
cache = ExplanationCache(
    cache_dir="/var/cache/explanations",
    max_size_mb=5000,  # 5GB
    enabled=True
)

# Monitor hit rate
stats = cache.stats()
if stats['utilization'] > 0.9:
    print("Warning: Cache nearly full, consider increasing size")

2. Batch Processing

Process multiple images efficiently with progress tracking.

Basic Usage

from autotimm.interpretation.optimization import BatchProcessor

# Create batch processor
processor = BatchProcessor(
    model,
    explainer,
    batch_size=16,  # Process 16 at a time
    use_cuda=True,
    show_progress=True
)

# Process 100 images efficiently (load_image stands in for your own loader)
images = [load_image(f"img_{i}.jpg") for i in range(100)]
heatmaps = processor.process_batch(images)

Performance

Sequential:   40ms per image
Batched (16): 20ms per image
Batched (32): 15ms per image
Speedup:      2-2.7x

Parallel Processing (CPU)

# CPU-based parallelism
heatmaps = processor.process_batch_parallel(
    images,
    num_workers=4
)

Batch Size Tuning

import time

# Find the optimal batch size for your hardware
for batch_size in [8, 16, 32, 64]:
    processor = BatchProcessor(model, explainer, batch_size=batch_size)
    start = time.time()
    processor.process_batch(test_images)
    elapsed = time.time() - start
    print(f"Batch size {batch_size}: {elapsed:.3f}s")

3. Performance Profiling

Identify bottlenecks in your interpretation pipeline.

Basic Usage

from autotimm.interpretation.optimization import PerformanceProfiler

profiler = PerformanceProfiler(enabled=True)

# Profile operations
with profiler.profile("preprocessing"):
    tensor = preprocess(image)

with profiler.profile("explanation"):
    heatmap = explainer.explain(image)

with profiler.profile("postprocessing"):
    result = postprocess(heatmap)

# Print statistics
profiler.print_stats()

Output

Performance Profile:
------------------------------------------------------------
Operation                      Count    Mean       Total
------------------------------------------------------------
preprocessing                  1        0.012      0.012
explanation                    1        0.045      0.045
postprocessing                 1        0.003      0.003
------------------------------------------------------------
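
Under the hood, a profiler like this is essentially a timing context manager. The following minimal sketch shows the idea (MiniProfiler is illustrative; AutoTimm's actual implementation may differ):

import time
from collections import defaultdict
from contextlib import contextmanager

class MiniProfiler:
    def __init__(self):
        self.timings = defaultdict(list)

    @contextmanager
    def profile(self, name):
        # Time the enclosed block with a monotonic clock, even on exceptions.
        start = time.perf_counter()
        try:
            yield
        finally:
            self.timings[name].append(time.perf_counter() - start)

    def get_stats(self):
        return {
            name: {"count": len(ts), "mean": sum(ts) / len(ts), "total": sum(ts)}
            for name, ts in self.timings.items()
        }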

Get Statistics Programmatically

stats = profiler.get_stats()

# Identify bottleneck
slowest = max(stats.items(), key=lambda x: x[1]['mean'])
print(f"Bottleneck: {slowest[0]} ({slowest[1]['mean']:.3f}s)")

# Alert if slow
if stats['explanation']['mean'] > 0.1:
    log_warning("Slow explanations detected")

Production Monitoring

# Set up profiler
profiler = PerformanceProfiler(enabled=True)

# In request handler
with profiler.profile("explanation_request"):
    heatmap = generate_explanation(image)

# Periodically check
stats = profiler.get_stats()
if stats['explanation_request']['mean'] > 1.0:
    log_warning("Slow explanations detected")
    # Maybe: increase cache size, optimize model, etc.

4. Model Optimization

Optimize model for faster inference.

Basic Usage

from autotimm.interpretation.optimization import optimize_for_inference

# Optimize model
model = optimize_for_inference(
    model,
    use_fp16=False  # Set True for GPU with FP16 support
)

# Now 1.5-3x faster inference

Optimizations Applied

  1. Disable gradients - saves memory and gives a slightly faster forward pass:

     for param in model.parameters():
         param.requires_grad = False

  2. cuDNN benchmarking - finds the fastest convolution algorithm, ~10-20% speedup on GPU:

     torch.backends.cudnn.benchmark = True

  3. FP16 (optional) - halves memory use and runs 1.5-2x faster on modern GPUs; requires an FP16-compatible GPU:

     model = model.half()
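
Taken together, such an optimization pass can be sketched as follows (optimize_sketch is illustrative only; the actual optimize_for_inference implementation may differ in its details):

import torch

def optimize_sketch(model: torch.nn.Module, use_fp16: bool = False) -> torch.nn.Module:
    # Inference mode: freeze batch-norm/dropout and stop tracking gradients.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    # Let cuDNN search for the fastest convolution algorithm (fixed input shapes).
    torch.backends.cudnn.benchmark = True
    if use_fp16:
        model = model.half()  # Requires an FP16-capable GPU.
    return model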

Performance

Baseline:             22ms per forward pass
Optimized (CPU):      18ms per forward pass
Optimized (GPU):      12ms per forward pass
Optimized (GPU+FP16): 7ms per forward pass
Speedup:              1.2-3.1x

Combined Optimization Strategy

Use all optimizations together for maximum performance.

Complete Example

from autotimm.interpretation import GradCAM
from autotimm.interpretation.optimization import (
    ExplanationCache,
    BatchProcessor,
    PerformanceProfiler,
    optimize_for_inference
)

# 1. Optimize model
model = optimize_for_inference(model, use_fp16=True)

# 2. Set up caching
cache = ExplanationCache(
    cache_dir="/var/cache/explanations",
    max_size_mb=5000
)

# 3. Enable profiling
profiler = PerformanceProfiler(enabled=True)

# 4. Create explainer and batch processor
explainer = GradCAM(model)
processor = BatchProcessor(model, explainer, batch_size=32)

# 5. Process with all optimizations
with profiler.profile("total"):
    heatmaps = []
    for img in images:
        # Try cache first
        heatmap = cache.get(img, method="gradcam")
        if heatmap is None:
            # Compute with optimized model
            heatmap = explainer.explain(img)
            cache.put(img, method="gradcam", explanation=heatmap)
        heatmaps.append(heatmap)

# 6. Monitor performance
profiler.print_stats()
print(f"Cache hit rate: {cache.stats()['utilization']:.1%}")

Performance Benchmarks

Summary Table

Optimization         Speedup      Use Case
-------------------  -----------  ------------------
Caching              10-50x       Repeated images
Batch Processing     2-5x         Multiple images
Model Optimization   1.5-3x       All inference
Combined             Up to 100x   Production systems

Real-World Example

Scenario: Explain 1000 validation images

Without optimization:

1000 images × 50ms = 50,000ms = 50 seconds

With all optimizations:

Cache warm-up: 10 images × 50ms = 500ms
Remaining: 990 images × 5ms (cached) = 4,950ms
Total: 5.5 seconds
Speedup: 9x overall

Production Best Practices

1. Cache Configuration

# Production settings
cache = ExplanationCache(
    cache_dir="/var/cache/explanations",
    max_size_mb=5000,  # 5GB for high-traffic systems
    enabled=True
)

# Regular monitoring
def monitor_cache():
    stats = cache.stats()
    metrics.gauge('cache.entries', stats['num_entries'])
    metrics.gauge('cache.size_mb', stats['total_size_mb'])
    metrics.gauge('cache.utilization', stats['utilization'])

2. Error Handling

# Graceful degradation
try:
    heatmap = cache.get(image, method="gradcam")
    if heatmap is None:
        heatmap = explainer.explain(image)
        cache.put(image, method="gradcam", explanation=heatmap)
except Exception as e:
    log_error(f"Cache error: {e}")
    # Fallback: compute without cache
    heatmap = explainer.explain(image)

3. Resource Limits

# Set a process memory limit (resource is POSIX-only; no effect on Windows)
import resource
_, hard = resource.getrlimit(resource.RLIMIT_AS)
resource.setrlimit(resource.RLIMIT_AS, (4 * 1024 * 1024 * 1024, hard))  # 4GB

# Monitor disk usage
import shutil
usage = shutil.disk_usage(cache.cache_dir)
if usage.free < 1024 * 1024 * 1024:  # Less than 1GB free
    log_warning("Low disk space, clearing old cache entries")
    cache.clear()

4. Logging & Monitoring

# Log cache performance
def log_cache_stats():
    stats = cache.stats()
    logger.info(
        "Cache stats",
        extra={
            'entries': stats['num_entries'],
            'size_mb': stats['total_size_mb'],
            'utilization': stats['utilization']
        }
    )

# Alert on anomalies
if profiler.get_stats()['explanation']['mean'] > threshold:
    alert.send("Slow explanations detected")

Common Pitfalls

1. Cache Size Too Small

Problem: High eviction rate and low hit rate.
Solution: Increase max_size_mb or clean up old entries.

2. Wrong Batch Size

Problem: OOM errors or slow processing.
Solution: Tune the batch size for your hardware (see Batch Size Tuning above).

3. Profiling Overhead

Problem: Profiling overhead slows down production.
Solution: Sample only a fraction of requests (see the sketch below) or disable profiling in production:

profiler = PerformanceProfiler(enabled=config.DEBUG)
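
A simple sampling approach profiles only a small fraction of requests. This sketch uses a hypothetical maybe_profile helper (not part of AutoTimm):

import random
from contextlib import nullcontext

def maybe_profile(profiler, name, rate=0.01):
    # Profile roughly `rate` of calls; return a no-op context otherwise.
    if random.random() < rate:
        return profiler.profile(name)
    return nullcontext()

# Usage in a request handler:
# with maybe_profile(profiler, "explanation_request"):
#     heatmap = generate_explanation(image)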

4. FP16 Compatibility

Problem: The model does not support FP16.
Solution: Pass use_fp16=False or switch to an FP16-compatible model.

API Reference

ExplanationCache

cache = ExplanationCache(
    cache_dir="./explanation_cache",
    max_size_mb=1000,
    enabled=True
)

Methods:

  • get(image, method, target_class=None, **kwargs) - retrieve an explanation from the cache
  • put(image, method, explanation, target_class=None, **kwargs) - store an explanation in the cache
  • stats() - get cache statistics
  • clear() - clear the entire cache

BatchProcessor

processor = BatchProcessor(
    model,
    explainer,
    batch_size=16,
    use_cuda=True,
    show_progress=True
)

Methods:

  • process_batch(images, target_classes=None, **kwargs) - process a batch of images
  • process_batch_parallel(images, target_classes=None, num_workers=4, **kwargs) - process with CPU-based parallelism

PerformanceProfiler

profiler = PerformanceProfiler(enabled=True)

Methods:

  • profile(name) - context manager for timing an operation
  • get_stats() - get profiling statistics
  • print_stats() - print formatted statistics
  • reset() - reset profiling data

optimize_for_inference()

model = optimize_for_inference(
    model,
    use_fp16=False
)

Examples

See the complete example script:

examples/data_training/performance_optimization_demo.py

See the tutorial notebook:

examples/interpretation/comprehensive_interpretation_tutorial.ipynb

Next Steps