# Inference
This guide covers how to use trained AutoTimm models for inference and prediction.
## Inference Pipeline

### Load & Preprocess
```mermaid
graph LR
    A[Trained Model] --> B[Load Checkpoint]
    B --> C[Set Eval Mode]
    C --> D[Preprocess Image<br/>Resize + Normalize + Tensor]
    D --> E{Task-Specific<br/>Inference}
    style A fill:#1565C0,stroke:#0D47A1
    style B fill:#1976D2,stroke:#1565C0
    style C fill:#1976D2,stroke:#1565C0
    style D fill:#1976D2,stroke:#1565C0
    style E fill:#FF9800,stroke:#F57C00
```
### Task-Specific Output
```mermaid
graph LR
    A[Forward Pass] --> B1[Classification<br/>Softmax → Top-K]
    A --> B2[Multi-Label<br/>Sigmoid → Threshold]
    A --> B3[Detection<br/>Decode → NMS]
    A --> B4[Segmentation<br/>Argmax → Mask]
    B1 --> C[Results]
    B2 --> C
    B3 --> C
    B4 --> C
    style A fill:#1565C0,stroke:#0D47A1
    style B1 fill:#1976D2,stroke:#1565C0
    style B2 fill:#1976D2,stroke:#1565C0
    style B3 fill:#1976D2,stroke:#1565C0
    style B4 fill:#1976D2,stroke:#1565C0
    style C fill:#4CAF50,stroke:#388E3C
```
## Quick Links
- Classification Inference - Image classification model inference
- Object Detection Inference - Object detection model inference
- Semantic Segmentation Inference - Semantic segmentation model inference
- Model Export - Export to TorchScript, ONNX, and quantization
## Quick Start

### Classification
```python
import torch
from PIL import Image

from autotimm import ImageClassifier, MetricConfig, TransformConfig

metrics = [
    MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass", "num_classes": 10},
        stages=["val"],
    )
]

model = ImageClassifier.load_from_checkpoint(
    "checkpoint.ckpt",
    backbone="resnet50",
    num_classes=10,
    compile_model=False,  # skip compilation for inference
    metrics=metrics,  # not saved in checkpoint
    transform_config=TransformConfig(),  # not saved in checkpoint
)
model.eval()

image = Image.open("test.jpg").convert("RGB")
with torch.inference_mode():
    logits = model(model.preprocess(image))
predicted_class = logits.argmax(dim=1).item()
```
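If you want the top-K classes with confidences rather than a single `argmax`, apply softmax to the logits and take the K largest probabilities. A minimal sketch with plain `torch`, where the random `logits` tensor stands in for the model output above:

```python
import torch

# stand-in for the model output: a (1, num_classes) logits tensor
logits = torch.randn(1, 10)

probs = torch.softmax(logits, dim=1)        # convert logits to probabilities
topk = torch.topk(probs, k=3, dim=1)        # three highest-scoring classes
classes = topk.indices.squeeze(0).tolist()  # class indices, best first
scores = topk.values.squeeze(0).tolist()    # matching probabilities
```

`torch.topk` returns values in descending order, so `classes[0]` is the same index `argmax` would give.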
### Multi-Label Classification
```python
import torch
from PIL import Image

from autotimm import ImageClassifier, TransformConfig

model = ImageClassifier.load_from_checkpoint(
    "multilabel.ckpt",
    backbone="resnet50",
    num_classes=4,
    multi_label=True,
    threshold=0.5,
    compile_model=False,  # skip compilation for inference
    transform_config=TransformConfig(),  # not saved in checkpoint
)
model.eval()

image = Image.open("test.jpg").convert("RGB")
with torch.inference_mode():
    logits = model(model.preprocess(image))
probs = logits.sigmoid().squeeze(0)  # per-label probabilities
predicted = (probs > 0.5).nonzero().squeeze(-1).tolist()
```
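The predicted indices are usually mapped back to label names for display. A small sketch, assuming a hypothetical `LABELS` list for a 4-label model and a fixed `probs` tensor standing in for the sigmoid output above:

```python
import torch

LABELS = ["person", "car", "dog", "cat"]  # hypothetical label names

# stand-in for the per-label probabilities from the sigmoid above
probs = torch.tensor([0.9, 0.2, 0.7, 0.4])

predicted = (probs > 0.5).nonzero().squeeze(-1).tolist()
named = {LABELS[i]: probs[i].item() for i in predicted}  # e.g. person and dog
```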
### Object Detection
```python
import torch
from PIL import Image

from autotimm import ObjectDetector, TransformConfig

model = ObjectDetector.load_from_checkpoint(
    "detector.ckpt",
    backbone="resnet50",
    num_classes=80,
    compile_model=False,  # skip compilation for inference
    transform_config=TransformConfig(image_size=640),  # not saved in checkpoint
)
model.eval()

image = Image.open("test.jpg").convert("RGB")
with torch.inference_mode():
    detections = model.predict_step(model.preprocess(image), batch_idx=0)
```
## Model Export
```python
import torch

model.eval()
example = torch.randn(1, 3, 224, 224)  # dummy input matching the model's input shape

# TorchScript
traced = torch.jit.trace(model, example)
traced.save("model.pt")

# ONNX
torch.onnx.export(model, example, "model.onnx")
```
## Performance Tips
| Optimization | Speed Gain | When to Use |
|---|---|---|
| Batch processing | 2-5x | Multiple images |
| GPU inference | 10-50x | GPU available |
| FP16 precision | 2-3x | Tensor core GPUs |
| TorchScript | 10-20% | Production |
| ONNX Runtime | 20-40% | Cross-platform |
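Batch processing and FP16 from the table above can be combined in one forward pass. A minimal sketch with a stand-in linear model (substitute the loaded AutoTimm model and real preprocessed images):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16  # FP16 needs a GPU

# stand-in model; substitute the loaded AutoTimm model
net = torch.nn.Linear(3 * 224 * 224, 10).to(device).eval()

# stack preprocessed images into a single batch instead of looping image by image
batch = torch.randn(8, 3 * 224 * 224, device=device)

with torch.inference_mode(), torch.autocast(device_type=device, dtype=amp_dtype):
    logits = net(batch)  # one forward pass for all 8 images

preds = logits.argmax(dim=1)  # one predicted class per image
```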
Checklist:

- Use `model.eval()` before inference
- Wrap predictions in `torch.inference_mode()`
- Process multiple images in batches
- Use GPU when available
## Detailed Guides
- Classification Inference - Single/batch prediction, top-K, multi-label, pipelines
- Object Detection Inference - Bounding boxes, visualization, NMS tuning
- Model Export - TorchScript, ONNX, quantization