Multi-Label Classification Data
The MultiLabelImageDataModule loads multi-label image classification datasets from CSV files. Each image can belong to several classes at once, so its labels are encoded as a multi-hot vector rather than a single class index.
Data Loading Flow
```mermaid
graph TD
    A[CSV File] --> A1[Read CSV]
    A1 --> A2[Parse Headers]
    A2 --> A3[Load Paths & Labels]
    A3 --> B[MultiLabelImageDataset]
    C[Image Dir] --> C1[Verify Directory]
    C1 --> C2[List Image Files]
    C2 --> B
    B --> B1[Create Dataset]
    B1 --> B2[Validate Labels]
    B2 --> B3[Multi-hot Encoding]
    B3 --> D[MultiLabelImageDataModule]
    D --> D1[Initialize Module]
    D1 --> D2[Configure Params]
    D2 --> E{Split Strategy}
    E -->|val_csv provided| F1[Separate val set]
    F1 --> F1a[Load val CSV]
    F1a --> F1b[Create Val Dataset]
    E -->|No val_csv| F2[Auto split val_split]
    F2 --> F2a[Random Split]
    F2a --> F2b[Create Val Subset]
    F1b --> G[Train / Val / Test Splits]
    F2b --> G
    G --> G1[Assign Indices]
    G1 --> G2[Verify Splits]
    G2 --> H[Transforms]
    H -->|Train| I1[Augmentation]
    I1 --> I1a[RandomResizedCrop]
    I1a --> I1b[RandomHorizontalFlip]
    I1b --> I1c[ColorJitter]
    I1c --> I1d[Normalize]
    H -->|Val / Test| I2[Resize + Normalize]
    I2 --> I2a[Resize]
    I2a --> I2b[CenterCrop]
    I2b --> I2c[Normalize]
    I1d --> J[DataLoader]
    I2c --> J
    J --> J1[Create Loaders]
    J1 --> J2[Set Batch Size]
    J2 --> J3[Configure Workers]
    J3 --> K[Batches: image, multi-hot labels]
    K --> K1[Stack Images]
    K1 --> K2[Stack Labels]
    K2 --> K3[Ready for Model]

    style A fill:#2196F3,stroke:#1976D2
    style B fill:#1976D2,stroke:#1565C0
    style D fill:#2196F3,stroke:#1976D2
    style H fill:#1976D2,stroke:#1565C0
    style J fill:#2196F3,stroke:#1976D2
    style K fill:#1976D2,stroke:#1565C0
    style K3 fill:#2196F3,stroke:#1976D2
```
CSV Format
Each CSV starts with a header row, and every following row describes one image:

- First column: image path, relative to image_dir
- Remaining columns: binary label indicators (0 or 1), one column per label; their header names become the label names
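The layout can be illustrated with a small standard-library sketch. It is not AutoTimm's loader; it only assumes the layout described above (a header row, the image path in the first column, one 0/1 column per label) and shows how each row turns into a multi-hot vector. The helper name parse_multilabel_csv is hypothetical, and rejecting non-binary values is an assumed validation rule.

```python
import csv
import io

# A tiny CSV in the layout described above: header row, path column, 0/1 label columns
SAMPLE_CSV = """filepath,cat,dog,outdoor,indoor
img_001.jpg,1,0,1,0
img_002.jpg,1,1,0,1
img_003.jpg,0,0,1,0
"""


def parse_multilabel_csv(text):
    """Hypothetical helper: return (label_names, [(path, multi_hot), ...])."""
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    label_names = header[1:]  # every column after the first is a label
    samples = []
    for row in reader:
        path, raw_labels = row[0], row[1:]
        multi_hot = [int(value) for value in raw_labels]
        # Reject anything that is not a binary indicator
        if any(value not in (0, 1) for value in multi_hot):
            raise ValueError(f"non-binary label in row for {path}: {raw_labels}")
        samples.append((path, multi_hot))
    return label_names, samples


label_names, samples = parse_multilabel_csv(SAMPLE_CSV)
print(label_names)  # ['cat', 'dog', 'outdoor', 'indoor']
print(samples[1])   # ('img_002.jpg', [1, 1, 0, 1])
```

In this sketch an empty cell or a value such as 2 raises ValueError; how the real module treats such rows is not specified on this page.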
Basic Usage
```python
import autotimm as at  # recommended alias
from autotimm import MultiLabelImageDataModule

data = MultiLabelImageDataModule(
    train_csv="train.csv",
    image_dir="./images",
    val_csv="val.csv",
    image_size=224,
    batch_size=32,
)
data.setup("fit")

# The outputs below assume train.csv has the label columns cat, dog, outdoor, indoor
print(f"Num labels: {data.num_labels}")  # 4
print(f"Label names: {data.label_names}")  # ['cat', 'dog', 'outdoor', 'indoor']
print(f"Train samples: {len(data.train_dataset)}")
```
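Auto Validation Split
If val_csv is not provided, a random fraction of the training rows (val_split, 0.1 by default) is held out for validation:
```python
from autotimm import MultiLabelImageDataModule

data = MultiLabelImageDataModule(
    train_csv="train.csv",
    image_dir="./images",
    val_split=0.2,  # 20% for validation
)
```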
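A held-out split of this kind amounts to shuffling the row indices and cutting off a fraction. The sketch below shows that idea with the standard library only; the rounding rule (int(), i.e. round down) and the fixed seed are assumptions rather than AutoTimm's documented behavior, and split_indices is a hypothetical name.

```python
import random


def split_indices(num_samples, val_split, seed=42):
    """Hypothetical helper: shuffle indices, then hold out a val_split fraction."""
    if not 0.0 < val_split < 1.0:
        raise ValueError("val_split must be between 0 and 1")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    num_val = int(num_samples * val_split)  # assumption: round down
    return indices[num_val:], indices[:num_val]  # (train, val)


train_idx, val_idx = split_indices(1000, 0.2)
print(len(train_idx), len(val_idx))  # 800 200
```

Seeding the shuffle keeps the validation subset identical from run to run, so validation metrics stay comparable between experiments.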
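Explicit Label Columns
By default, the first column is treated as the image path and every other column as a label. Use label_columns to keep only some labels, and image_column to name the image path column explicitly:
```python
from autotimm import MultiLabelImageDataModule

data = MultiLabelImageDataModule(
    train_csv="train.csv",
    image_dir="./images",
    label_columns=["cat", "dog"],  # Only use these 2 labels
    image_column="filepath",  # Column holding the image paths
)
```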
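Selecting columns by header name can be sketched with csv.DictReader. This illustrates the idea only, not AutoTimm's code: the helper select_labels is hypothetical, and raising on an unknown column name is an assumed behavior. Label order follows label_columns, so selecting ["cat", "dog"] yields vectors of length 2.

```python
import csv
import io

SAMPLE_CSV = """filepath,cat,dog,outdoor,indoor
img_001.jpg,1,0,1,0
img_002.jpg,0,1,0,1
"""


def select_labels(text, label_columns, image_column):
    """Hypothetical helper: keep only the named label columns, in the given order."""
    reader = csv.DictReader(io.StringIO(text))
    missing = [c for c in [image_column, *label_columns] if c not in reader.fieldnames]
    if missing:
        raise KeyError(f"columns not found in CSV header: {missing}")
    return [
        (row[image_column], [int(row[name]) for name in label_columns])
        for row in reader
    ]


print(select_labels(SAMPLE_CSV, ["cat", "dog"], "filepath"))
# [('img_001.jpg', [1, 0]), ('img_002.jpg', [0, 1])]
```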
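Transform Backends
Select the augmentation library with transform_backend and a named augmentation policy with augmentation_preset.
Torchvision (Default)
```python
from autotimm import MultiLabelImageDataModule

data = MultiLabelImageDataModule(
    train_csv="train.csv",
    image_dir="./images",
    transform_backend="torchvision",
    augmentation_preset="randaugment",
)
```
Albumentations
```python
data = MultiLabelImageDataModule(
    train_csv="train.csv",
    image_dir="./images",
    transform_backend="albumentations",
    augmentation_preset="strong",
)
```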
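Custom Transforms
Pass your own pipelines through train_transforms and eval_transforms:
```python
from torchvision import transforms
from autotimm import MultiLabelImageDataModule

# ImageNet mean/std normalization
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

custom_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Deterministic eval pipeline: Resize -> CenterCrop -> Normalize
custom_eval = transforms.Compose([
    transforms.Resize(256),  # a common choice for 224-pixel crops
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

data = MultiLabelImageDataModule(
    train_csv="train.csv",
    image_dir="./images",
    train_transforms=custom_train,
    eval_transforms=custom_eval,
)
```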
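DataLoader Options
These settings control batching and worker processes for the data loaders. Keep num_workers > 0 when enabling persistent_workers or prefetch_factor, since PyTorch's DataLoader only accepts those options with worker processes.
```python
from autotimm import MultiLabelImageDataModule

data = MultiLabelImageDataModule(
    train_csv="train.csv",
    image_dir="./images",
    batch_size=64,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)
```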
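In recent PyTorch releases, DataLoader raises a ValueError when persistent_workers=True or prefetch_factor is set while num_workers=0. The hypothetical helper below sketches one way to build DataLoader keyword arguments that respect that rule; it does not import PyTorch and is not part of AutoTimm.

```python
def dataloader_kwargs(batch_size, num_workers, pin_memory=True,
                      persistent_workers=False, prefetch_factor=None):
    """Hypothetical helper: build DataLoader kwargs, dropping worker-only options
    when loading runs in the main process (num_workers == 0)."""
    kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    if num_workers > 0:
        # Both options are only valid when worker processes exist
        kwargs["persistent_workers"] = persistent_workers
        if prefetch_factor is not None:
            kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


print(dataloader_kwargs(64, 8, persistent_workers=True, prefetch_factor=4))
# {'batch_size': 64, 'num_workers': 8, 'pin_memory': True, 'persistent_workers': True, 'prefetch_factor': 4}
print(dataloader_kwargs(64, 0, persistent_workers=True, prefetch_factor=4))
# {'batch_size': 64, 'num_workers': 0, 'pin_memory': True}
```

Dropping the options silently is one design choice; raising an error instead would surface the misconfiguration earlier.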
Full Parameter Reference
```python
from autotimm import MultiLabelImageDataModule

MultiLabelImageDataModule(
    train_csv="train.csv",            # Path to the training CSV
    image_dir="./images",             # Root directory that image paths are resolved against
    val_csv=None,                     # Optional validation CSV
    test_csv=None,                    # Optional test CSV
    label_columns=None,               # Label column names (None: all columns except the image column)
    image_column=None,                # Image path column name (None: first column)
    image_size=224,                   # Target image size in pixels
    batch_size=32,                    # Samples per batch
    num_workers=4,                    # Data loading worker processes
    val_split=0.1,                    # Fraction of training rows held out when val_csv is None
    train_transforms=None,            # Custom train transforms
    eval_transforms=None,             # Custom val/test transforms
    augmentation_preset=None,         # Augmentation preset name
    transform_backend="torchvision",  # "torchvision" or "albumentations"
    transform_config=None,            # TransformConfig for model normalization
    backbone=None,                    # Backbone for model-specific normalization
    pin_memory=True,                  # Pin host memory for faster GPU transfer
    persistent_workers=False,         # Keep workers alive between epochs
    prefetch_factor=None,             # Batches prefetched per worker
)
```
Complete Example
```python
from autotimm import (
    AutoTrainer,
    ImageClassifier,
    MetricConfig,
    MultiLabelImageDataModule,
)

# Data
data = MultiLabelImageDataModule(
    train_csv="train.csv",
    image_dir="./images",
    val_csv="val.csv",
    image_size=224,
    batch_size=32,
    num_workers=4,
)
data.setup("fit")

# Model
model = ImageClassifier(
    backbone="resnet50",
    num_classes=data.num_labels,
    multi_label=True,
    threshold=0.5,
    metrics=[
        MetricConfig(
            name="accuracy",
            backend="torchmetrics",
            metric_class="MultilabelAccuracy",
            params={"num_labels": data.num_labels},
            stages=["train", "val"],
            prog_bar=True,
        ),
        MetricConfig(
            name="f1",
            backend="torchmetrics",
            metric_class="MultilabelF1Score",
            params={"num_labels": data.num_labels, "average": "macro"},
            stages=["val"],
        ),
    ],
    lr=1e-3,
)

# Train
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)
```
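With multi_label=True, each label is typically scored independently: a sigmoid maps each logit to a probability, and threshold=0.5 turns it into a 0/1 prediction. The pure-Python sketch below assumes that behavior and reproduces by hand what the two configured metrics measure: per-label accuracy averaged over labels, and macro F1 (F1 per label, then the unweighted mean). It is an illustration on toy values, not torchmetrics' implementation; scoring an undefined F1 as 0.0 is an assumption.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def predict(logits, threshold=0.5):
    """Multi-hot prediction: a label is on when its probability exceeds threshold."""
    return [[int(sigmoid(z) > threshold) for z in row] for row in logits]


def label_accuracy(preds, targets):
    """Accuracy per label, averaged over labels."""
    num_labels = len(targets[0])
    per_label = []
    for j in range(num_labels):
        correct = sum(p[j] == t[j] for p, t in zip(preds, targets))
        per_label.append(correct / len(targets))
    return sum(per_label) / num_labels


def macro_f1(preds, targets):
    """F1 per label, then the unweighted mean over labels."""
    num_labels = len(targets[0])
    scores = []
    for j in range(num_labels):
        tp = sum(p[j] == 1 and t[j] == 1 for p, t in zip(preds, targets))
        fp = sum(p[j] == 1 and t[j] == 0 for p, t in zip(preds, targets))
        fn = sum(p[j] == 0 and t[j] == 1 for p, t in zip(preds, targets))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)  # assumption: undefined F1 -> 0.0
    return sum(scores) / num_labels


# Toy batch: 4 images, 3 labels
logits = [[2.0, -1.0, 0.5], [-0.3, 1.2, -2.0], [1.5, 0.8, -0.1], [-1.0, -0.5, 3.0]]
targets = [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]

preds = predict(logits)
print(preds)  # [[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]]
print(round(label_accuracy(preds, targets), 4))  # 0.8333
print(round(macro_f1(preds, targets), 4))  # 0.7778
```

Because sigmoid(0) = 0.5, a threshold of 0.5 is the same as checking whether each logit is positive; raising the threshold trades recall for precision on every label at once.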
See Also
- Image Classification Data: single-label classification data loading
- Image Classifier Guide: multi-label model configuration
- Multi-Label Example: runnable example
- API Reference: full API docs