Skip to content

Classification Model Inference

This guide covers how to use trained classification models for inference and prediction.

Loading a Trained Model

import autotimm as at  # recommended alias
import torch
from autotimm import ImageClassifier, MetricConfig, TransformConfig

# Define metrics (required for loading)
metrics = [
    MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["train", "val", "test"],
    ),
]

# Load from checkpoint with TransformConfig for preprocessing
model = ImageClassifier.load_from_checkpoint(
    "path/to/checkpoint.ckpt",
    backbone="resnet50",
    num_classes=10,
    compile_model=False,                 # skip compilation for inference
    metrics=metrics,                     # not saved in checkpoint
    transform_config=TransformConfig(),  # not saved in checkpoint; enables preprocessing
)
model.eval()

For a single image, use the built-in preprocess() method, which applies the correct model-specific normalization:

from PIL import Image

# Load image; convert to RGB so grayscale, palette, or RGBA files yield 3 channels
# (the same loading step the other examples in this guide use)
image = Image.open("image.jpg").convert("RGB")

# Preprocess using model's native normalization
input_tensor = model.preprocess(image)  # Returns (1, 3, 224, 224)

# Predict
with torch.inference_mode():
    logits = model(input_tensor)
    probabilities = torch.softmax(logits, dim=1)
    predicted_class = probabilities.argmax(dim=1).item()
    # Single-image batch, so the global max is the predicted class's probability
    confidence = probabilities.max().item()

print(f"Predicted class: {predicted_class}")
print(f"Confidence: {confidence:.2%}")

Single Image Prediction (Manual)

If you need manual control over transforms:

from PIL import Image
from torchvision import transforms

# Prepare transform
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and transform image
image = Image.open("image.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# Predict
with torch.inference_mode():
    logits = model(input_tensor)
    probabilities = torch.softmax(logits, dim=1)
    predicted_class = probabilities.argmax(dim=1).item()
    confidence = probabilities.max().item()

print(f"Predicted class: {predicted_class}")
print(f"Confidence: {confidence:.2%}")

Tip: Use model.get_data_config() to get the correct normalization values:

config = model.get_data_config()
print(f"Mean: {config['mean']}")  # Use these values in your transforms
print(f"Std: {config['std']}")

Batch Prediction

from PIL import Image

# Load multiple images; convert each to RGB so every image has 3 channels
# and the batch can be stacked regardless of the source file's mode
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg", "img4.jpg"]
images = [Image.open(p).convert("RGB") for p in image_paths]

# Preprocess all at once
input_tensor = model.preprocess(images)  # Returns (4, 3, 224, 224)

# Predict
model.eval()
with torch.inference_mode():
    logits = model(input_tensor)
    probs = torch.softmax(logits, dim=1)
    preds = probs.argmax(dim=1)

# Report each image's predicted class with that class's probability
for path, pred, prob in zip(image_paths, preds, probs):
    print(f"{path}: class {pred.item()} ({prob[pred].item():.2%})")

Using DataLoader

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Use the model's own normalization statistics rather than hard-coded values,
# so preprocessing matches what the backbone was trained with
config = model.get_data_config()

# Prepare dataset
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=config["mean"], std=config["std"]),
])

dataset = datasets.ImageFolder("path/to/images", transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

# Predict all
model.eval()
device = next(model.parameters()).device  # run batches on whichever device holds the model
all_predictions = []
all_probabilities = []

with torch.inference_mode():
    for images, _ in dataloader:
        images = images.to(device)
        logits = model(images)
        probs = torch.softmax(logits, dim=1)
        preds = probs.argmax(dim=1)

        all_predictions.extend(preds.tolist())
        all_probabilities.extend(probs.tolist())

Using Trainer.predict

from autotimm import AutoTrainer, ImageDataModule

data = ImageDataModule(
    data_dir="./test_images",
    image_size=224,
    batch_size=32,
)
data.setup("test")

trainer = AutoTrainer()
predictions = trainer.predict(model, dataloaders=data.test_dataloader())

# predictions is a list of batched probability tensors
all_probs = torch.cat(predictions, dim=0)
all_preds = all_probs.argmax(dim=1)

GPU Inference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

with torch.inference_mode():
    input_tensor = input_tensor.to(device)
    logits = model(input_tensor)
    probabilities = torch.softmax(logits, dim=1)

Performance Tips:

  • Always move model to GPU before inference
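  • Use .to(device, non_blocking=True) for faster transfers (only effective when the source tensor is in pinned memory, e.g. from a DataLoader created with pin_memory=True)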
  • Process larger batches on GPU for better efficiency

Top-K Predictions

Get the top K most likely classes:

def get_topk_predictions(model, image_tensor, k=5, class_names=None):
    """Get top-k predictions with class names and probabilities.

    Args:
        model: Classifier called as ``model(image_tensor)``; its output is
            treated as logits with classes along dim 1.
        image_tensor: Input batch. Only the first item of the batch is reported.
        k: Number of classes to return. Clamped to the number of classes the
            model outputs, so asking for more than exist does not raise.
        class_names: Optional sequence mapping class index to a display name.
            Indices that have no entry fall back to the index as a string.

    Returns:
        A list of ``{"class": name, "probability": float}`` dicts ordered from
        most to least likely.
    """
    model.eval()
    with torch.inference_mode():
        logits = model(image_tensor)
        probs = torch.softmax(logits, dim=1)
        # topk raises if k exceeds the size of the class dimension, so clamp first
        k = min(k, probs.shape[1])
        topk_probs, topk_indices = probs.topk(k, dim=1)

    results = []
    for idx, prob in zip(topk_indices[0].tolist(), topk_probs[0].tolist()):
        # Guard against a class_names list shorter than the model's class count
        has_name = bool(class_names) and idx < len(class_names)
        name = class_names[idx] if has_name else str(idx)
        results.append({"class": name, "probability": prob})

    return results

# Usage
results = get_topk_predictions(
    model, 
    input_tensor, 
    k=5, 
    class_names=["cat", "dog", "bird", "fish", "hamster"]
)

for r in results:
    print(f"{r['class']}: {r['probability']:.2%}")

Output Example:

cat: 95.32%
dog: 3.21%
hamster: 1.15%
bird: 0.28%
fish: 0.04%


Complete Inference Pipeline

Production-ready inference pipeline with TransformConfig:

import torch
from PIL import Image
from autotimm import ImageClassifier, MetricConfig, TransformConfig


class InferencePipeline:
    """End-to-end inference pipeline for classification.

    Loads an ImageClassifier checkpoint once, moves it to the GPU when one is
    available, and exposes helpers that go from an image path to predictions.
    All helpers share the same preprocessing path (the model's preprocess()).
    """

    def __init__(self, checkpoint_path, backbone, num_classes, class_names=None):
        """Load the checkpoint and prepare the model for inference.

        Args:
            checkpoint_path: Path to the checkpoint file to load.
            backbone: Backbone name forwarded to load_from_checkpoint.
            num_classes: Number of output classes.
            class_names: Optional sequence mapping class index to a name.
        """
        # Define minimal metrics for loading
        metrics = [
            MetricConfig(
                name="accuracy",
                backend="torchmetrics",
                metric_class="Accuracy",
                params={"task": "multiclass"},
                stages=["val"],
            ),
        ]

        # Load model with TransformConfig for preprocessing
        self.model = ImageClassifier.load_from_checkpoint(
            checkpoint_path,
            backbone=backbone,
            num_classes=num_classes,
            compile_model=False,                 # skip compilation for inference
            metrics=metrics,                     # not saved in checkpoint
            transform_config=TransformConfig(),  # not saved in checkpoint; enables preprocess()
        )
        self.model.eval()

        # Setup device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        self.class_names = class_names
        self.num_classes = num_classes

    def _load_input(self, image_path):
        """Open image_path as RGB and return the preprocessed tensor on self.device."""
        # The context manager closes the file handle; convert() returns an
        # independent in-memory image that stays valid afterwards.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return self.model.preprocess(image).to(self.device)

    def _forward_probs(self, image_path):
        """Return softmax probabilities for one image, batch dimension kept."""
        input_tensor = self._load_input(image_path)
        with torch.inference_mode():
            logits = self.model(input_tensor)
            return torch.softmax(logits, dim=1)

    def predict(self, image_path, top_k=1):
        """Predict class for a single image.

        Args:
            image_path: Path of the image to classify.
            top_k: Number of most likely classes to report. Clamped to the
                number of classes the model outputs.

        Returns:
            A single result dict when top_k == 1, otherwise a list of result
            dicts ordered from most to least likely. Each dict has the keys
            "class", "class_index" and "confidence".
        """
        probs = self._forward_probs(image_path)

        # topk raises if asked for more entries than there are classes
        k = min(top_k, probs.shape[1])
        topk_probs, topk_indices = probs.topk(k, dim=1)

        results = []
        for idx, prob in zip(topk_indices[0].tolist(), topk_probs[0].tolist()):
            # Fall back to the raw index when no name is available for it
            has_name = bool(self.class_names) and idx < len(self.class_names)
            class_name = self.class_names[idx] if has_name else idx

            results.append({
                "class": class_name,
                "class_index": idx,
                "confidence": prob,
            })

        # Return single result or list
        if top_k == 1:
            return results[0]
        return results

    def predict_batch(self, image_paths, top_k=1):
        """Predict classes for multiple images, one predict() call per path."""
        return [self.predict(path, top_k=top_k) for path in image_paths]

    def get_all_probabilities(self, image_path):
        """Get probabilities for all classes.

        Uses the same preprocessing as predict(), so the returned values are
        consistent with the confidences predict() reports.

        Returns:
            A list of floats, one per class, in class-index order.
        """
        probs = self._forward_probs(image_path)
        return probs[0].cpu().tolist()


# Usage
pipeline = InferencePipeline(
    checkpoint_path="best-model.ckpt",
    backbone="resnet50",
    num_classes=10,
    class_names=["cat", "dog", "bird", "fish", "hamster", "rabbit", "mouse", "snake", "turtle", "frog"],
)

# Single prediction
result = pipeline.predict("test_image.jpg")
print(f"Predicted: {result['class']} ({result['confidence']:.2%})")

# Top-5 predictions
top5 = pipeline.predict("test_image.jpg", top_k=5)
for i, r in enumerate(top5, 1):
    print(f"{i}. {r['class']}: {r['confidence']:.2%}")

# Batch predictions
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
results = pipeline.predict_batch(image_paths)
for path, result in zip(image_paths, results):
    print(f"{path}: {result['class']} ({result['confidence']:.2%})")

Multi-Label Inference

Multi-label models output per-label sigmoid probabilities (each in [0, 1], independent of each other) instead of softmax probabilities.

Loading a Multi-Label Model

import torch
from autotimm import ImageClassifier, MetricConfig, TransformConfig

metrics = [
    MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="MultilabelAccuracy",
        params={"num_labels": 4},
        stages=["val"],
    ),
]

model = ImageClassifier.load_from_checkpoint(
    "path/to/checkpoint.ckpt",
    backbone="resnet50",
    num_classes=4,
    multi_label=True,
    threshold=0.5,
    compile_model=False,                 # skip compilation for inference
    metrics=metrics,                     # not saved in checkpoint
    transform_config=TransformConfig(),  # not saved in checkpoint
)
model.eval()

Single Image — Per-Label Probabilities

from PIL import Image

image = Image.open("image.jpg").convert("RGB")
input_tensor = model.preprocess(image)

with torch.inference_mode():
    logits = model(input_tensor)
    probs = logits.sigmoid().squeeze(0)  # (num_labels,)

label_names = ["cat", "dog", "outdoor", "indoor"]
threshold = 0.5

for name, prob in zip(label_names, probs):
    marker = "+" if prob > threshold else " "
    print(f"  [{marker}] {name}: {prob:.2%}")

# Get predicted labels
predicted = [n for n, p in zip(label_names, probs) if p > threshold]
print(f"Predicted: {predicted}")

Output Example:

  [+] cat: 92.15%
  [ ] dog: 12.40%
  [+] outdoor: 87.63%
  [ ] indoor: 3.11%
Predicted: ['cat', 'outdoor']

Batch Prediction

model.eval()
device = next(model.parameters()).device
threshold = 0.5

all_preds = []
all_probs = []

with torch.inference_mode():
    for images, _ in dataloader:
        images = images.to(device)
        logits = model(images)
        probs = logits.sigmoid()
        preds = (probs > threshold).int()

        all_preds.extend(preds.cpu().tolist())
        all_probs.extend(probs.cpu().tolist())

Using Trainer.predict

predict_step returns sigmoid probabilities automatically for multi-label models:

from autotimm import AutoTrainer, MultiLabelImageDataModule

data = MultiLabelImageDataModule(
    train_csv="data.csv",
    image_dir="./images",
    image_size=224,
    batch_size=32,
)
data.setup("fit")

trainer = AutoTrainer()
predictions = trainer.predict(model, dataloaders=data.val_dataloader())

# predictions is a list of batched sigmoid probability tensors
all_probs = torch.cat(predictions, dim=0)   # (N, num_labels)
all_preds = (all_probs > 0.5).int()          # binary predictions

Adjusting the Threshold

Different thresholds trade off precision and recall per label:

# Stricter — fewer but more confident predictions
strict_preds = (probs > 0.7).int()

# Lenient — more predictions, some less confident
lenient_preds = (probs > 0.3).int()

# Per-label thresholds (tuned on validation set)
# Create them on the same device as probs: comparing a GPU tensor against a
# CPU tensor raises a device-mismatch error
thresholds = torch.tensor([0.4, 0.6, 0.5, 0.3], device=probs.device)
custom_preds = (probs > thresholds).int()

Exporting Multi-Label Predictions

import csv

label_names = ["cat", "dog", "outdoor", "indoor"]

with open("predictions.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["index"] + label_names + [f"{n}_prob" for n in label_names])
    for i, (pred, prob) in enumerate(zip(all_preds, all_probs)):
        row = [i] + [int(p) for p in pred] + [f"{p:.4f}" for p in prob]
        writer.writerow(row)

Key Differences from Single-Label

| | Single-Label | Multi-Label |
|---|---|---|
| Output activation | softmax (sums to 1) | sigmoid (each in [0, 1]) |
| Prediction | argmax → one class | sigmoid > threshold → multiple labels |
| Loss | CrossEntropyLoss | BCEWithLogitsLoss |
| Targets | Integer class index | Multi-hot float vector |
| Confidence | Max softmax probability | Per-label sigmoid probability |

Performance Optimization

1. Batch Processing

Process multiple images at once for GPU efficiency:

def predict_batch_efficient(model, image_paths, transform, device, batch_size=32):
    """Efficient batch prediction.

    Args:
        model: Classifier called on a stacked batch; its output is treated as
            logits with classes along dim 1.
        image_paths: Sequence of image file paths to classify.
        transform: Callable applied to each RGB PIL image; its results are
            combined with torch.stack, so they must be same-shaped tensors.
        device: Device the batch is moved to before the forward pass.
        batch_size: Maximum number of images per forward pass.

    Returns:
        A list with one predicted class index per input path, in input order.
        Empty when image_paths is empty.
    """
    model.eval()
    all_predictions = []

    for start in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[start:start + batch_size]

        # Load and transform batch; the context manager closes each file
        # handle as soon as its pixels have been converted
        images = []
        for path in batch_paths:
            with Image.open(path) as img:
                images.append(transform(img.convert("RGB")))

        # Stack into batch
        batch_tensor = torch.stack(images).to(device)

        # Predict
        with torch.inference_mode():
            logits = model(batch_tensor)
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)

        all_predictions.extend(preds.cpu().tolist())

    return all_predictions

2. Half Precision (FP16)

Reduce memory usage and increase speed:

# Convert model to half precision
model = model.half()

# Convert input to half precision
input_tensor = input_tensor.half()

with torch.inference_mode():
    logits = model(input_tensor)

Note: Run FP16 inference on a GPU; the speedup requires tensor cores (V100, A100, RTX series)

3. Compiled Models

Use torch.compile for optimized inference (PyTorch 2.0+):

import torch

model = ImageClassifier.load_from_checkpoint(..., compile_model=False)
model.eval()

# Compile model
model = torch.compile(model, mode="reduce-overhead")

# First run is slower (compilation)
# Subsequent runs are faster
with torch.inference_mode():
    output = model(input_tensor)

4. Disable Gradient Tracking

Always use torch.inference_mode() for inference:

# Good - saves memory
with torch.inference_mode():
    output = model(input_tensor)

# Also good
@torch.inference_mode()
def predict(model, input_tensor):
    """Run one forward pass; the decorator puts the whole call in inference mode."""
    output = model(input_tensor)
    return output

# Bad - wastes memory tracking gradients
output = model(input_tensor)  # Don't do this for inference!

Common Issues

For classification inference issues, see the Troubleshooting - Export & Inference guide, which covers:

  • Out of memory
  • Slow inference
  • Wrong predictions
  • Batch processing issues

Quick Troubleshooting Checklist

# 1. Use exact same normalization as training
config = model.get_data_config()
transform = transforms.Normalize(mean=config['mean'], std=config['std'])

# 2. Ensure model is in eval mode
model.eval()

# 3. Check input preprocessing
# Image should be RGB, not BGR
image = Image.open("img.jpg").convert("RGB")

# 4. Verify class mapping
# Ensure class names match training order

See Also