Classification Examples
This page demonstrates image classification tasks using AutoTimm.
Classification Workflow
%% Top-down flowchart of the classification workflow shown on this page.
graph TD
%% Data: dataset -> ImageDataModule -> dataloaders
A[Dataset] --> A1[Download/Locate]
A1 --> A2[Verify Structure]
A2 --> B[ImageDataModule]
B --> B1[Configure Transforms]
B1 --> B2[Setup DataLoaders]
%% Model: backbone -> ImageClassifier
C[Backbone] --> C1[Select Architecture]
C1 --> C2[Load Pretrained]
C2 --> D[ImageClassifier]
D --> D1[Add Task Head]
%% Metrics: MetricConfig -> MetricManager, which feeds the classifier
E[MetricConfig] --> E1[Choose Metrics]
E1 --> E2[Set Parameters]
E2 --> F[MetricManager]
F --> F1[Initialize Metrics]
F1 --> D
%% Trainer: dataloaders, classifier and logger config all feed AutoTrainer
B2 --> G[AutoTrainer]
D --> G
H[LoggerConfig] --> H1[Configure Backend]
H1 --> H2[Set Parameters]
H2 --> G
G --> G1[Setup Callbacks]
G1 --> G2[Configure GPUs]
%% Training/validation loop; J3 decides whether to loop back or move on
G2 --> I[Training]
I --> I1[Training Loop]
I1 --> I2[Log Metrics]
I2 --> I3[Save Checkpoints]
I3 --> J[Validation]
J --> J1[Compute Metrics]
J1 --> J2[Early Stopping Check]
J2 --> J3{Continue?}
J3 -->|Yes| I1
J3 -->|No| K[Testing]
%% Testing and results
K --> K1[Load Best Model]
K1 --> K2[Test Evaluation]
K2 --> L[Results]
L --> L1[Generate Reports]
L1 --> L2[Save Predictions]
L2 --> L3[Plot Metrics]
%% Node colours: two alternating blue fills for the main stages
style B fill:#2196F3,stroke:#1976D2
style D fill:#1976D2,stroke:#1565C0
style F fill:#2196F3,stroke:#1976D2
style G fill:#1976D2,stroke:#1565C0
style I fill:#2196F3,stroke:#1976D2
style K fill:#1976D2,stroke:#1565C0
style L fill:#2196F3,stroke:#1976D2
CIFAR-10 Classification
Basic training with ResNet-18 on CIFAR-10 using MetricManager.
import autotimm as at # recommended alias
from autotimm import (
AutoTrainer,
ImageClassifier,
ImageDataModule,
LoggerConfig,
MetricConfig,
MetricManager,
)
def main():
    """Train a ResNet-18 classifier on CIFAR-10, then run the test stage."""
    # Data: CIFAR-10 selected by dataset name, stored under ./data
    datamodule = ImageDataModule(
        data_dir="./data",
        dataset_name="CIFAR10",
        image_size=224,
        batch_size=64,
    )

    # Metrics: a single multiclass accuracy, registered for all three stages
    accuracy = MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["train", "val", "test"],
        prog_bar=True,
    )
    metric_manager = MetricManager(configs=[accuracy], num_classes=10)

    # Model
    classifier = ImageClassifier(
        backbone="resnet18",
        num_classes=10,
        metrics=metric_manager,
        lr=1e-3,
    )

    # Trainer: TensorBoard logger, checkpoints monitored on "val/accuracy"
    tensorboard = LoggerConfig(backend="tensorboard", params={"save_dir": "logs"})
    trainer = AutoTrainer(
        max_epochs=10,
        logger=[tensorboard],
        checkpoint_monitor="val/accuracy",
    )

    # Fit first, then evaluate with the same model and data module
    trainer.fit(classifier, datamodule=datamodule)
    trainer.test(classifier, datamodule=datamodule)


if __name__ == "__main__":
    main()
Custom Folder Dataset
Training on your own images organized in folders.
from autotimm import (
AutoTrainer,
ImageClassifier,
ImageDataModule,
LoggerConfig,
MetricConfig,
MetricManager,
)
def main():
    """Train an EfficientNet-B3 classifier on images organized in class folders."""
    # Expected layout on disk:
    #   dataset/
    #     train/
    #       class_a/...
    #       class_b/...
    #     val/
    #       class_a/...
    #       class_b/...
    datamodule = ImageDataModule(
        data_dir="./dataset",
        image_size=384,
        batch_size=16,
    )
    # Set up the "fit" stage before reading num_classes from the data module
    datamodule.setup("fit")
    num_classes = datamodule.num_classes

    accuracy = MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["train", "val", "test"],
        prog_bar=True,
    )
    metric_manager = MetricManager(configs=[accuracy], num_classes=num_classes)

    classifier = ImageClassifier(
        backbone="efficientnet_b3",
        num_classes=num_classes,
        metrics=metric_manager,
    )

    # Weights & Biases logger; checkpoints monitored on "val/accuracy"
    wandb_logger = LoggerConfig(backend="wandb", params={"project": "my-project"})
    trainer = AutoTrainer(
        max_epochs=20,
        logger=[wandb_logger],
        checkpoint_monitor="val/accuracy",
    )
    trainer.fit(classifier, datamodule=datamodule)


if __name__ == "__main__":
    main()
ViT Fine-Tuning
Two-phase fine-tuning for Vision Transformers with MetricManager.
from autotimm import (
    AutoTrainer,
    ImageClassifier,
    ImageDataModule,
    MetricConfig,
    MetricManager,
)


def main():
    """Fine-tune a Vision Transformer in two phases.

    Phase 1 trains with ``freeze_backbone=True`` (linear probe); phase 2
    re-enables gradients on the backbone and trains the whole model at a
    lower learning rate with gradient clipping.
    """
    # Data: folder-organized dataset (train/<class>/..., val/<class>/...).
    # image_size matches the 224 in the backbone name; adjust data_dir and
    # batch_size for your own setup.
    data = ImageDataModule(
        data_dir="./dataset",
        image_size=224,
        batch_size=32,
    )
    # Set up the "fit" stage first so data.num_classes can be read below.
    data.setup("fit")

    metric_configs = [
        MetricConfig(
            name="accuracy",
            backend="torchmetrics",
            metric_class="Accuracy",
            params={"task": "multiclass"},
            stages=["train", "val", "test"],
        ),
    ]
    metric_manager = MetricManager(configs=metric_configs, num_classes=data.num_classes)

    # Phase 1: Linear probe (frozen backbone)
    model = ImageClassifier(
        backbone="vit_base_patch16_224",
        num_classes=data.num_classes,
        metrics=metric_manager,
        freeze_backbone=True,
        lr=1e-2,
    )
    trainer = AutoTrainer(max_epochs=5)
    trainer.fit(model, datamodule=data)

    # Phase 2: Full fine-tune
    for param in model.backbone.parameters():
        param.requires_grad = True
    # NOTE(review): this writes a private attribute; confirm the optimizer
    # built by the next fit() picks up the new learning rate from it.
    model._lr = 1e-4
    # A fresh trainer is used for the second phase, with gradient clipping.
    trainer = AutoTrainer(max_epochs=20, gradient_clip_val=1.0)
    trainer.fit(model, datamodule=data)


if __name__ == "__main__":
    main()
Multi-Label Classification
Multi-label classification where each image can belong to multiple classes simultaneously.
Uses MultiLabelImageDataModule for CSV-based data and ImageClassifier with multi_label=True.
from autotimm import (
AutoTrainer,
ImageClassifier,
MetricConfig,
MultiLabelImageDataModule,
)
def main():
    """Train a ResNet-50 multi-label classifier from CSV-described data."""
    # Data - CSV with columns: image_path, cat, dog, outdoor, indoor
    datamodule = MultiLabelImageDataModule(
        train_csv="train.csv",
        val_csv="val.csv",
        image_dir="./images",
        image_size=224,
        batch_size=32,
    )
    datamodule.setup("fit")
    num_labels = datamodule.num_labels  # auto-detected from CSV

    # Multilabel metrics: accuracy on train/val (shown in the progress bar)
    # and macro-averaged F1 on val only
    accuracy = MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="MultilabelAccuracy",
        params={"num_labels": num_labels},
        stages=["train", "val"],
        prog_bar=True,
    )
    macro_f1 = MetricConfig(
        name="f1",
        backend="torchmetrics",
        metric_class="MultilabelF1Score",
        params={"num_labels": num_labels, "average": "macro"},
        stages=["val"],
    )
    metric_configs = [accuracy, macro_f1]

    # Model - multi_label=True switches to BCEWithLogitsLoss + sigmoid
    # Use loss_fn to override the default loss (e.g., loss_fn="focal" or a custom nn.Module)
    classifier = ImageClassifier(
        backbone="resnet50",
        num_classes=num_labels,
        multi_label=True,
        threshold=0.5,
        metrics=metric_configs,
        lr=1e-3,
    )

    # Train
    trainer = AutoTrainer(max_epochs=10)
    trainer.fit(classifier, datamodule=datamodule)


if __name__ == "__main__":
    main()
Key differences from single-label:
- Data: CSV with multi-hot label columns instead of ImageFolder directories
- Model: `multi_label=True` uses `BCEWithLogitsLoss` and sigmoid predictions
- Metrics: Use `Multilabel*` metrics (e.g., `MultilabelAccuracy`, `MultilabelF1Score`)
- `predict_step` returns per-label sigmoid probabilities (each in [0, 1]; they do not sum to 1)
Running Classification Examples
# Get the source and install it in editable mode with the "all" extras
git clone https://github.com/theja-vanka/AutoTimm.git
cd AutoTimm
pip install -e ".[all]"
# Run examples
# (paths are relative to the repository root entered above)
python examples/getting_started/classify_cifar10.py
python examples/getting_started/classify_custom_folder.py
python examples/getting_started/vit_finetuning.py
python examples/data_training/multilabel_classification.py