Skip to content

Quick Start

This guide walks you through training your first image classifier with AutoTimm.

Training Workflow

graph LR
    A[<b>1. Import</b>] --> B[<b>2. Data</b><br/>ImageDataModule]
    B --> C[<b>3. Metrics</b><br/>MetricConfig]
    C --> D[<b>4. Model</b><br/>Select backbone +<br/>num_classes]
    D --> E[<b>5. Trainer</b><br/>AutoTrainer]
    E --> F[<b>6. Train</b><br/>trainer.fit]
    F --> G[<b>7. Evaluate</b><br/>trainer.test]

    style A fill:#1565C0,stroke:#0D47A1
    style B fill:#1976D2,stroke:#1565C0
    style C fill:#1565C0,stroke:#0D47A1
    style D fill:#1976D2,stroke:#1565C0
    style E fill:#1565C0,stroke:#0D47A1
    style F fill:#1976D2,stroke:#1565C0
    style G fill:#4CAF50,stroke:#388E3C

Basic Training

1. Import Required Classes

import autotimm as at  # recommended alias
from autotimm import (
    AutoTrainer,
    ImageClassifier,
    ImageDataModule,
    MetricConfig,
    MetricManager,
)

2. Set Up Data

AutoTimm supports built-in datasets (CIFAR10, CIFAR100, MNIST, FashionMNIST) and custom folder-based datasets.

# Data module for the built-in CIFAR10 dataset, rooted at ./data.
# The name `data` is reused by the training and evaluation steps below.
data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",  # Downloads automatically
    image_size=224,
    batch_size=64,
)

3. Define Metrics with MetricManager

AutoTimm requires explicit metric configuration using MetricConfig and MetricManager:

# Accuracy, computed with the torchmetrics backend in every stage and
# reported with prog_bar enabled.
accuracy_config = MetricConfig(
    name="accuracy",
    backend="torchmetrics",
    metric_class="Accuracy",
    params={"task": "multiclass"},
    stages=["train", "val", "test"],
    prog_bar=True,
)
metric_configs = [accuracy_config]

# The manager wraps the configs together with the number of classes and
# offers programmatic access to them.
metric_manager = MetricManager(num_classes=10, configs=metric_configs)

# Look a metric up by the name given in its MetricConfig.
accuracy = metric_manager.get_metric_by_name("accuracy")

# Iterating the manager yields the configs it was built from.
for config in metric_manager:
    print(f"{config.name}: {config.stages}")

4. Create Model

Choose from 1000+ timm backbones:

# Classifier on the "resnet18" backbone. num_classes matches the value given
# to MetricManager in the previous step.
model = ImageClassifier(
    backbone="resnet18",
    num_classes=10,  # CIFAR10 has 10 classes
    metrics=metric_manager,
    lr=1e-3,
)

torch.compile Enabled by Default

AutoTimm automatically applies torch.compile() (PyTorch 2.0+) to all models for faster training and inference. To disable: compile_model=False. For custom options: compile_kwargs={"mode": "reduce-overhead"}.

5. Train

# Create a trainer with a 10-epoch budget, then fit the model using the
# data module from step 2.
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)

6. Evaluate

# Evaluate the fitted model through the same data module used for training.
trainer.test(model, datamodule=data)

Complete Example

from autotimm import (
    AutoTrainer,
    ImageClassifier,
    ImageDataModule,
    LoggerConfig,
    MetricConfig,
    MetricManager,
)


def main():
    """Train and evaluate a ResNet-18 classifier on CIFAR10 with TensorBoard logging."""
    num_classes = 10  # CIFAR10

    # Data module for the built-in CIFAR10 dataset.
    datamodule = ImageDataModule(
        data_dir="./data",
        dataset_name="CIFAR10",
        image_size=224,
        batch_size=64,
    )

    # Accuracy is tracked in every stage, with prog_bar enabled.
    accuracy_config = MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["train", "val", "test"],
        prog_bar=True,
    )
    # Macro-averaged F1 is configured for the validation and test stages only.
    f1_config = MetricConfig(
        name="f1",
        backend="torchmetrics",
        metric_class="F1Score",
        params={"task": "multiclass", "average": "macro"},
        stages=["val", "test"],
    )
    metric_manager = MetricManager(
        configs=[accuracy_config, f1_config],
        num_classes=num_classes,
    )

    classifier = ImageClassifier(
        backbone="resnet18",
        num_classes=num_classes,
        metrics=metric_manager,
        lr=1e-3,
    )

    # TensorBoard logger writing under "logs"; checkpoints monitor "val/accuracy".
    tensorboard_logger = LoggerConfig(
        backend="tensorboard",
        params={"save_dir": "logs"},
    )
    trainer = AutoTrainer(
        max_epochs=10,
        logger=[tensorboard_logger],
        checkpoint_monitor="val/accuracy",
    )

    # Fit first, then evaluate with the same data module.
    trainer.fit(classifier, datamodule=datamodule)
    trainer.test(classifier, datamodule=datamodule)


if __name__ == "__main__":
    main()

Using Custom Datasets

Organize your images in ImageFolder format:

dataset/
  train/
    class_a/
      img1.jpg
      img2.jpg
    class_b/
      img3.jpg
  val/
    class_a/
      img4.jpg
    class_b/
      img5.jpg

Then load the dataset and build a model whose class count is read from it:

from autotimm import (
    ImageClassifier,
    ImageDataModule,
    MetricConfig,
    MetricManager,
)


def main():
    """Load a folder-based dataset and build a classifier sized to its classes.

    Only the data module, metrics, and model are constructed here; no
    training is run.
    """
    # No dataset_name is given, so the data module is pointed at the custom
    # folder layout under data_dir instead of a built-in dataset.
    data = ImageDataModule(
        data_dir="./dataset",
        image_size=224,
        batch_size=32,
    )
    # setup() is called before data.num_classes is read below
    # (presumably setup populates it from the class folders; confirm against
    # the ImageDataModule API).
    data.setup("fit")

    metric_configs = [
        MetricConfig(
            name="accuracy",
            backend="torchmetrics",
            metric_class="Accuracy",
            params={"task": "multiclass"},
            stages=["train", "val", "test"],
            prog_bar=True,
        ),
    ]

    # Both the metrics and the model take their class count from the dataset
    # rather than from a hard-coded number.
    metric_manager = MetricManager(configs=metric_configs, num_classes=data.num_classes)

    model = ImageClassifier(
        backbone="efficientnet_b0",
        num_classes=data.num_classes,
        metrics=metric_manager,
    )


if __name__ == "__main__":
    main()

Using Different Backbones

Browse available backbones:

import autotimm as at  # recommended alias

# Search by wildcard pattern.
# NOTE: the results are not printed here; wrap the calls in print() when
# running this as a script rather than in an interactive session.
at.list_backbones("*resnet*")
# pretrained_only=True narrows the search (presumably to backbones that ship
# pretrained weights; confirm against the list_backbones API).
at.list_backbones("*efficientnet*", pretrained_only=True)
at.list_backbones("*vit*")

Popular choices:

Backbone Description
resnet18, resnet50 Classic ResNet models
efficientnet_b0 to efficientnet_b7 EfficientNet family
vit_base_patch16_224 Vision Transformer
convnext_tiny ConvNeXt models
swin_tiny_patch4_window7_224 Swin Transformer

Discovering Available Optimizers and Schedulers

AutoTimm provides utilities to discover all available optimizers and learning rate schedulers from both PyTorch and timm.

List Optimizers

import autotimm as at  # recommended alias

# Get all optimizers from torch and timm. The result is indexed by source:
# "torch" is read directly, "timm" with a default of [] in case the key is
# missing.
optimizers = at.list_optimizers()
print("Torch optimizers:", optimizers["torch"])
print("Timm optimizers:", optimizers.get("timm", []))

# Get only torch optimizers (this rebinds `optimizers`)
optimizers = at.list_optimizers(include_timm=False)

Available optimizers:

  • PyTorch: adamw, adam, sgd, rmsprop, adagrad, adadelta, adamax, asgd
  • Timm: adamp, sgdp, adabelief, radam, lamb, lars, madgrad, novograd

List Schedulers

# Get all schedulers from torch and timm. As with list_optimizers(), "torch"
# is read directly and "timm" with a default of [] in case the key is missing.
# `at` is the alias from `import autotimm as at`.
schedulers = at.list_schedulers()
print("Torch schedulers:", schedulers["torch"])
print("Timm schedulers:", schedulers.get("timm", []))

Available schedulers:

  • PyTorch: cosineannealinglr, cosineannealingwarmrestarts, steplr, multisteplr, exponentiallr, onecyclelr, reducelronplateau, and more (15 total)
  • Timm: cosinelrscheduler, multisteplrscheduler, plateaulrscheduler, steplrscheduler, and more (6 total)

Using Custom Optimizers and Schedulers

# Select the optimizer by name. "adamw" is one of the PyTorch optimizers;
# timm optimizers such as "adamp" or "lamb" are selected the same way.
# `metric_manager` is the MetricManager built in "Define Metrics with
# MetricManager" earlier in this guide.
model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metric_manager,
    optimizer="adamw",  # or "adamp", "lamb", etc.
    lr=1e-3,
)

# Select the learning-rate scheduler by name; its constructor arguments go
# through scheduler_kwargs.
model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metric_manager,
    lr=1e-3,
    scheduler="cosineannealinglr",
    scheduler_kwargs={"T_max": 10},
)

Object Detection Quick Start

AutoTimm also supports object detection with FCOS-style anchor-free detectors.

Basic Object Detection Example

from autotimm import (
    AutoTrainer,
    DetectionDataModule,
    MetricConfig,
    ObjectDetector,
)


def main():
    """Train and evaluate an ObjectDetector on a COCO-format detection dataset."""
    # COCO-format detection dataset
    datamodule = DetectionDataModule(
        data_dir="./coco",
        image_size=640,
        batch_size=16,
        augmentation_preset="default",
    )

    # Mean average precision over "xyxy" boxes, configured for the
    # validation and test stages only.
    map_config = MetricConfig(
        name="mAP",
        backend="torchmetrics",
        metric_class="MeanAveragePrecision",
        params={"box_format": "xyxy"},
        stages=["val", "test"],
        prog_bar=True,
    )

    # FCOS detector with any timm backbone
    detector = ObjectDetector(
        backbone="resnet50",  # Try: swin_tiny, efficientnet_b3, etc.
        num_classes=80,  # COCO has 80 classes
        metrics=[map_config],
        lr=1e-4,
    )

    trainer = AutoTrainer(max_epochs=12, gradient_clip_val=1.0)

    # Fit first, then evaluate with the same data module.
    trainer.fit(detector, datamodule=datamodule)
    trainer.test(detector, datamodule=datamodule)


if __name__ == "__main__":
    main()

Detection with Transformer Backbones

Use Vision Transformers for detection:

# The two calls below are alternatives: each rebinds `model`, so keep only the
# one you want. `metric_configs` is the list built in the basic detection
# example; only the backbone name differs from that example.

# Swin Transformer (recommended for detection)
model = ObjectDetector(
    backbone="swin_tiny_patch4_window7_224",
    num_classes=80,
    metrics=metric_configs,
    lr=1e-4,
)

# Vision Transformer
model = ObjectDetector(
    backbone="vit_base_patch16_224",
    num_classes=80,
    metrics=metric_configs,
    lr=1e-4,
)

See Object Detection Examples for more details.

Alternative: RT-DETR

For end-to-end transformer detection without NMS:

from transformers import RTDetrForObjectDetection
from autotimm import DetectionDataModule

# Use AutoTimm's data loading
data = DetectionDataModule(
    data_dir="./coco",
    image_size=640,
    batch_size=4,
)

# RT-DETR model, loaded from the "PekingU/rtdetr_r50vd" checkpoint through the
# transformers library rather than an AutoTimm task class.
# NOTE: this snippet only constructs the data module and the model; wiring
# them into a training loop is not shown here.
model = RTDetrForObjectDetection.from_pretrained(
    "PekingU/rtdetr_r50vd",
    num_labels=80,
)

See RT-DETR Example for complete integration.

Semantic Segmentation Quick Start

AutoTimm provides DeepLabV3+ and FCN architectures for semantic segmentation.

Basic Semantic Segmentation Example

from autotimm import (
    AutoTrainer,
    SemanticSegmentor,
    SegmentationDataModule,
    MetricConfig,
)


def main():
    """Train and evaluate a DeepLabV3+ semantic segmentor on Cityscapes."""
    num_classes = 19  # Cityscapes has 19 classes

    # Data module in "cityscapes" format ("png", "coco" and "voc" are the
    # other supported formats).
    datamodule = SegmentationDataModule(
        data_dir="./cityscapes",
        format="cityscapes",
        image_size=512,
        batch_size=8,
        augmentation_preset="default",
    )

    # Macro-averaged IoU with label 255 passed as ignore_index.
    iou_params = {
        "task": "multiclass",
        "num_classes": num_classes,
        "average": "macro",
        "ignore_index": 255,
    }
    iou_config = MetricConfig(
        name="iou",
        backend="torchmetrics",
        metric_class="JaccardIndex",
        params=iou_params,
        stages=["val", "test"],
        prog_bar=True,
    )

    # DeepLabV3+ with any timm backbone
    segmentor = SemanticSegmentor(
        backbone="resnet50",  # Try: swin_tiny, efficientnet_b3, etc.
        num_classes=num_classes,
        head_type="deeplabv3plus",  # or "fcn"
        loss_type="combined",  # CE + Dice
        metrics=[iou_config],
        lr=1e-4,
    )

    # Mixed precision for faster training
    trainer = AutoTrainer(max_epochs=100, precision="16-mixed")

    # Fit first, then evaluate with the same data module.
    trainer.fit(segmentor, datamodule=datamodule)
    trainer.test(segmentor, datamodule=datamodule)


if __name__ == "__main__":
    main()

Using Import Aliases

AutoTimm supports cleaner imports:

from autotimm.task import SemanticSegmentor
from autotimm.loss import DiceLoss, CombinedSegmentationLoss
from autotimm.head import DeepLabV3PlusHead
from autotimm.metric import MetricConfig

# Create model using aliases: SemanticSegmentor here comes from autotimm.task
# instead of the top-level autotimm package.
# NOTE(review): unlike the basic segmentation example, no `metrics` argument is
# passed here — confirm that it is optional for SemanticSegmentor.
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
)

See Semantic Segmentation Guide for more details.

Instance Segmentation Quick Start

AutoTimm supports instance segmentation with Mask R-CNN style architecture.

Basic Instance Segmentation Example

from autotimm import (
    AutoTrainer,
    InstanceSegmentor,
    InstanceSegmentationDataModule,
    MetricConfig,
)


def main():
    """Train and evaluate an InstanceSegmentor on a COCO-format dataset with masks."""
    # COCO format with masks
    datamodule = InstanceSegmentationDataModule(
        data_dir="./coco",
        image_size=640,
        batch_size=4,
        augmentation_preset="default",
    )

    # Mean average precision with iou_type "segm", configured for the
    # validation and test stages only.
    mask_map_config = MetricConfig(
        name="mask_mAP",
        backend="torchmetrics",
        metric_class="MeanAveragePrecision",
        params={"box_format": "xyxy", "iou_type": "segm"},
        stages=["val", "test"],
        prog_bar=True,
    )

    # FCOS detection + mask head
    segmentor = InstanceSegmentor(
        backbone="resnet50",
        num_classes=80,  # COCO has 80 classes
        metrics=[mask_map_config],
        lr=1e-4,
        mask_loss_weight=1.0,
    )

    trainer = AutoTrainer(max_epochs=12, gradient_clip_val=1.0)

    # Fit first, then evaluate with the same data module.
    trainer.fit(segmentor, datamodule=datamodule)
    trainer.test(segmentor, datamodule=datamodule)


if __name__ == "__main__":
    main()

See Instance Segmentation Guide for more details.

Inference with Preprocessing

AutoTimm models can preprocess raw images using model-specific normalization:

import torch
from PIL import Image

from autotimm import ImageClassifier, MetricConfig, TransformConfig

# Create model with TransformConfig
model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=[
        MetricConfig(
            name="accuracy",
            backend="torchmetrics",
            metric_class="Accuracy",
            params={"task": "multiclass"},
            stages=["val"],
        ),
    ],
    transform_config=TransformConfig(),  # Enable preprocessing
)

# Preprocess a raw image for inference. The file is opened in a context
# manager so the handle is closed, and converted to RGB so grayscale, palette,
# or RGBA files still yield three channels.
with Image.open("test.jpg") as raw_image:
    image = raw_image.convert("RGB")
tensor = model.preprocess(image)  # Uses model's pretrained normalization

# Run inference: eval mode plus inference_mode disables training-time
# behaviour and gradient tracking.
model.eval()
with torch.inference_mode():
    predictions = model(tensor).softmax(dim=1)
    predicted_class = predictions.argmax(dim=1).item()
    print(f"Predicted class: {predicted_class}")

Shared Config for Training and Inference

from autotimm import (
    ImageClassifier,
    ImageDataModule,
    MetricConfig,
    TransformConfig,
)

# Create shared config
config = TransformConfig(preset="randaugment", image_size=384)
backbone = "efficientnet_b4"

# Metrics for the classifier, given as a list of MetricConfig as in the
# inference example above.
metrics = [
    MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["val"],
    ),
]

# DataModule uses model's normalization for training; it receives the same
# backbone name and TransformConfig as the model below.
data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    transform_config=config,
    backbone=backbone,
)

# Model uses same config for inference preprocessing
model = ImageClassifier(
    backbone=backbone,
    num_classes=10,
    metrics=metrics,
    transform_config=config,
)

See TransformConfig API for full details.

Next Steps