Inference

This guide covers how to use trained AutoTimm models for inference and prediction.

Inference Pipeline

Load & Preprocess

graph LR
    A[Trained Model] --> B[Load Checkpoint]
    B --> C[Set Eval Mode]
    C --> D[Preprocess Image<br/>Resize + Normalize + Tensor]
    D --> E{Task-Specific<br/>Inference}

    style A fill:#1565C0,stroke:#0D47A1
    style B fill:#1976D2,stroke:#1565C0
    style C fill:#1976D2,stroke:#1565C0
    style D fill:#1976D2,stroke:#1565C0
    style E fill:#FF9800,stroke:#F57C00

Task-Specific Output

graph LR
    A[Forward Pass] --> B1[Classification<br/>Softmax → Top-K]
    A --> B2[Multi-Label<br/>Sigmoid → Threshold]
    A --> B3[Detection<br/>Decode → NMS]
    A --> B4[Segmentation<br/>Argmax → Mask]
    B1 --> C[Results]
    B2 --> C
    B3 --> C
    B4 --> C

    style A fill:#1565C0,stroke:#0D47A1
    style B1 fill:#1976D2,stroke:#1565C0
    style B2 fill:#1976D2,stroke:#1565C0
    style B3 fill:#1976D2,stroke:#1565C0
    style B4 fill:#1976D2,stroke:#1565C0
    style C fill:#4CAF50,stroke:#388E3C

Quick Start

Classification

import autotimm as at  # recommended alias
from autotimm import ImageClassifier, MetricConfig, TransformConfig
import torch
from PIL import Image

metrics = [MetricConfig(name="accuracy", backend="torchmetrics",
                        metric_class="Accuracy", params={"task": "multiclass"},
                        stages=["val"])]
model = ImageClassifier.load_from_checkpoint(
    "checkpoint.ckpt",
    backbone="resnet50",
    num_classes=10,
    compile_model=False,                 # skip compilation for inference
    metrics=metrics,                     # not saved in checkpoint
    transform_config=TransformConfig(),  # not saved in checkpoint
)
model.eval()

image = Image.open("test.jpg").convert("RGB")
with torch.inference_mode():
    logits = model(model.preprocess(image))
    predicted_class = logits.argmax(dim=1).item()

Multi-Label Classification

from autotimm import ImageClassifier, TransformConfig
import torch
from PIL import Image

model = ImageClassifier.load_from_checkpoint(
    "multilabel.ckpt",
    backbone="resnet50",
    num_classes=4,
    multi_label=True,
    threshold=0.5,
    compile_model=False,                 # skip compilation for inference
    transform_config=TransformConfig(),  # not saved in checkpoint
)
model.eval()

image = Image.open("test.jpg").convert("RGB")
with torch.inference_mode():
    logits = model(model.preprocess(image))
    probs = logits.sigmoid().squeeze(0)        # per-label probabilities
    predicted = (probs > 0.5).nonzero().squeeze(-1).tolist()

Object Detection

from autotimm import ObjectDetector, TransformConfig
import torch
from PIL import Image

model = ObjectDetector.load_from_checkpoint(
    "detector.ckpt",
    backbone="resnet50",
    num_classes=80,
    compile_model=False,                                # skip compilation for inference
    transform_config=TransformConfig(image_size=640),   # not saved in checkpoint
)
model.eval()

image = Image.open("test.jpg").convert("RGB")
with torch.inference_mode():
    detections = model.predict_step(model.preprocess(image), batch_idx=0)

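Detection post-processing ends in non-maximum suppression (the Decode → NMS step in the diagram). In production you would use `torchvision.ops.nms`; the greedy algorithm it implements can be sketched in plain torch:

```python
import torch

def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_thresh: float) -> torch.Tensor:
    """Greedy NMS: keep boxes in descending score order, dropping any box whose
    IoU with an already-kept box exceeds iou_thresh.

    boxes: (N, 4) in (x1, y1, x2, y2) format; scores: (N,). Returns kept indices.
    """
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0].item()
        keep.append(i)
        if order.numel() == 1:
            break
        rest = order[1:]
        # Intersection of the top box with the remaining boxes
        x1 = torch.maximum(boxes[i, 0], boxes[rest, 0])
        y1 = torch.maximum(boxes[i, 1], boxes[rest, 1])
        x2 = torch.minimum(boxes[i, 2], boxes[rest, 2])
        y2 = torch.minimum(boxes[i, 3], boxes[rest, 3])
        inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_r = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_i + area_r - inter)
        order = rest[iou <= iou_thresh]  # drop heavy overlaps with the kept box
    return torch.tensor(keep)

# Two heavily overlapping boxes plus one far away: the lower-scoring overlap is dropped.
boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [100., 100., 110., 110.]])
scores = torch.tensor([0.9, 0.8, 0.7])
kept = nms(boxes, scores, iou_thresh=0.5)
```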
Model Export

import torch

# TorchScript: trace with a dummy input matching the training resolution
traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
traced.save("model.pt")

# ONNX
torch.onnx.export(model, torch.randn(1, 3, 224, 224), "model.onnx")

Performance Tips

| Optimization     | Speed Gain | When to Use      |
| ---------------- | ---------- | ---------------- |
| Batch processing | 2-5x       | Multiple images  |
| GPU inference    | 10-50x     | GPU available    |
| FP16 precision   | 2-3x       | Tensor core GPUs |
| TorchScript      | 10-20%     | Production       |
| ONNX Runtime     | 20-40%     | Cross-platform   |

Checklist:

  • Use model.eval() before inference
  • Wrap predictions in torch.inference_mode()
  • Process multiple images in batches
  • Use GPU when available
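The checklist items can be combined in one loop. A sketch of batched GPU inference with FP16 autocast, using a stand-in conv net in place of a loaded AutoTimm model:

```python
import torch

# Stand-in for a loaded AutoTimm model (any nn.Module works the same way here)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
)
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# A batch of 4 preprocessed images instead of 4 separate forward passes
batch = torch.randn(4, 3, 224, 224, device=device)

# autocast runs matmuls/convs in FP16 on GPU; it is disabled on CPU-only hosts
with torch.inference_mode(), torch.autocast(device_type=device, enabled=(device == "cuda")):
    out = model(batch)
```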

Detailed Guides