Transforms

AutoTimm provides a flexible transform system for image preprocessing and augmentation, supporting both torchvision and albumentations backends with model-specific normalization.

Transform Pipeline

graph TD
    A[Raw Image] --> A1[Load from Disk]
    A1 --> A2[Decode Format]
    A2 --> B{TransformConfig}

    B -->|Backend| C1[torchvision]
    C1 --> C1a[PIL Backend]
    C1a --> C1b[Import Transforms]

    B -->|Backend| C2[albumentations]
    C2 --> C2a[OpenCV Backend]
    C2a --> C2b[Import Transforms]

    C1b --> D1[Preset Transforms]
    C2b --> D2[Preset Transforms]

    D1 --> D1a[weak/medium/strong]
    D1a --> D1b[auto_augment]
    D1b --> D1c[rand_augment]

    D2 --> D2a[weak/medium/strong]
    D2a --> D2b[auto_augment]
    D2b --> D2c[rand_augment]

    D1c --> E[Resize]
    D2c --> E

    E --> E1[Determine Size]
    E1 --> E2[Maintain Aspect]
    E2 --> E3[Apply Interpolation]
    E3 --> F[Augmentation]

    F -->|Training| G1[RandAugment/Strong]
    G1 --> G1a[Random Crop]
    G1a --> G1b[Random Flip]
    G1b --> G1c[Color Jitter]
    G1c --> G1d[Rotation]
    G1d --> G1e[Mixup/Cutmix]

    F -->|Evaluation| G2[CenterCrop]
    G2 --> G2a[Center Crop]
    G2a --> G2b[Resize to Target]

    G1e --> H[Normalization]
    G2b --> H

    H -->|timm config| I1[Model-specific]
    I1 --> I1a[Load Pretrained Stats]
    I1a --> I1b[Get Mean/Std]
    I1b --> I1c[Apply Normalization]

    H -->|custom| I2[User-defined]
    I2 --> I2a[Custom Mean/Std]
    I2a --> I2b[Apply Normalization]

    I1c --> J[Tensor]
    I2b --> J

    J --> J1[Convert to Tensor]
    J1 --> J2[Permute Dimensions]
    J2 --> J3[Ensure Float32]
    J3 --> J4[Ready for Model]

    style B fill:#2196F3,stroke:#1976D2
    style C1 fill:#1976D2,stroke:#1565C0
    style C2 fill:#2196F3,stroke:#1976D2
    style E fill:#1976D2,stroke:#1565C0
    style F fill:#2196F3,stroke:#1976D2
    style H fill:#1976D2,stroke:#1565C0
    style J fill:#2196F3,stroke:#1976D2
    style J4 fill:#1976D2,stroke:#1565C0

Overview

The transform system consists of three main components:

  1. TransformConfig: Unified configuration dataclass for all transform settings
  2. Augmentation Presets: Built-in transform pipelines for training and evaluation
  3. Model-Specific Normalization: Automatic normalization using timm's pretrained statistics

TransformConfig

TransformConfig is the central configuration class for all transforms in AutoTimm. It provides a consistent interface across models and data modules.

Basic Usage

from autotimm import TransformConfig

# Default configuration
config = TransformConfig()

# Custom configuration
config = TransformConfig(
    preset="randaugment",
    backend="torchvision",
    image_size=384,
    use_timm_config=True,
)

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| preset | str | "default" | Augmentation preset name |
| backend | str | "torchvision" | Transform backend ("torchvision" or "albumentations") |
| image_size | int | 224 | Target image size (square) |
| use_timm_config | bool | True | Use the model's pretrained normalization |
| mean | tuple | None | Override normalization mean |
| std | tuple | None | Override normalization std |
| interpolation | str | "bicubic" | Resize interpolation mode |
| crop_pct | float | 0.875 | Center crop percentage for eval |
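
For example, evaluation-time resizing behavior can be tuned through interpolation and crop_pct. A minimal sketch (both field names come from the table above):

# Eval-oriented overrides: bilinear resize, no extra center-crop margin
config = TransformConfig(
    interpolation="bilinear",
    crop_pct=1.0,  # resize target equals image_size, so the crop removes nothing
)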

Detection-Specific Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| min_bbox_area | float | 0.0 | Minimum bbox area to keep |
| min_visibility | float | 0.0 | Minimum visibility ratio (0.0-1.0) |
| bbox_format | str | "coco" | Bbox format ("coco", "pascal_voc", "yolo") |
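
Detection configs are typically paired with the albumentations backend, which transforms images and boxes together. A hedged sketch combining the fields above:

# Hypothetical detection setup (field names from the table above)
det_config = TransformConfig(
    backend="albumentations",
    bbox_format="pascal_voc",
    min_bbox_area=16.0,   # drop very small boxes after transforms
    min_visibility=0.2,   # drop boxes with less than 20% of their area still visible
)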

Segmentation-Specific Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| ignore_index | int | 255 | Label index to ignore in masks |
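
A minimal segmentation sketch, assuming the common convention of marking void pixels with 255:

seg_config = TransformConfig(
    backend="albumentations",  # mask-aware transforms
    ignore_index=255,          # matches the default; shown explicitly here
)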

RandAugment Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| randaugment_num_ops | int | 2 | Number of augmentation operations |
| randaugment_magnitude | int | 9 | Magnitude of augmentations (0-30) |

Configuration Methods

# Create a new config with overrides
base_config = TransformConfig(image_size=224)
large_config = base_config.with_overrides(image_size=384)

# Convert to a dictionary
config_dict = large_config.to_dict()

Augmentation Presets

AutoTimm provides several built-in augmentation presets for different training scenarios.

Torchvision Presets

| Preset | Description | Use Case |
|---|---|---|
| default | RandomResizedCrop, HorizontalFlip, ColorJitter | General training |
| autoaugment | AutoAugment (ImageNet policy) | Proven augmentation |
| randaugment | RandAugment (configurable ops/magnitude) | Flexible augmentation |
| trivialaugment | TrivialAugmentWide | Simple but effective |
| light | RandomResizedCrop, HorizontalFlip only | Minimal augmentation |

Albumentations Presets

| Preset | Description | Use Case |
|---|---|---|
| default | RandomResizedCrop, HorizontalFlip, ColorJitter | General training |
| strong | Affine, blur/noise, ColorJitter, CoarseDropout | Heavy augmentation |
| light | RandomResizedCrop, HorizontalFlip only | Minimal augmentation |

Preset Examples

from autotimm import ImageDataModule, TransformConfig

# Standard training
data = ImageDataModule(
    data_dir="./data",
    transform_config=TransformConfig(preset="default"),
    backbone="resnet50",
)

# Strong augmentation with RandAugment
data = ImageDataModule(
    data_dir="./data",
    transform_config=TransformConfig(
        preset="randaugment",
        randaugment_num_ops=3,
        randaugment_magnitude=12,
    ),
    backbone="resnet50",
)

# Heavy augmentation with albumentations
data = ImageDataModule(
    data_dir="./data",
    transform_config=TransformConfig(
        preset="strong",
        backend="albumentations",
    ),
    backbone="resnet50",
)

Default Transform Pipelines

Training Transforms (default preset)

1. RandomResizedCrop(image_size)
2. RandomHorizontalFlip(p=0.5)
3. ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)
4. ToTensor()
5. Normalize(mean, std)

Evaluation Transforms

1. Resize(image_size / crop_pct)  # e.g., 256 for 224 with 0.875 crop_pct
2. CenterCrop(image_size)
3. ToTensor()
4. Normalize(mean, std)

AutoAugment Pipeline

1. RandomResizedCrop(image_size)
2. RandomHorizontalFlip(p=0.5)
3. AutoAugment(policy=IMAGENET)
4. ToTensor()
5. Normalize(mean, std)

RandAugment Pipeline

1. RandomResizedCrop(image_size)
2. RandomHorizontalFlip(p=0.5)
3. RandAugment(num_ops=2, magnitude=9)
4. ToTensor()
5. Normalize(mean, std)

Strong Albumentations Pipeline

1. RandomResizedCrop(image_size)
2. HorizontalFlip(p=0.5)
3. Affine(translate, scale, rotate, p=0.5)
4. OneOf([MotionBlur, GaussianBlur, GaussNoise], p=0.3)
5. ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)
6. CoarseDropout(num_holes=1-3, p=0.3)
7. Normalize(mean, std)
8. ToTensorV2()
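
A trimmed-down albumentations sketch of the steps above. Exact transform signatures vary across albumentations versions, so treat this as illustrative rather than the literal preset implementation:

import albumentations as A
from albumentations.pytorch import ToTensorV2

strong = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.OneOf([A.MotionBlur(), A.GaussianBlur(), A.GaussNoise()], p=0.3),
    A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# albumentations operates on numpy arrays and returns a dict:
# tensor = strong(image=np_image)["image"]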

Model-Specific Normalization

Different pretrained models use different normalization statistics. AutoTimm automatically uses the correct normalization for each model.

How It Works

  1. When use_timm_config=True (the default), AutoTimm queries timm for the model's pretrained data config
  2. The normalization mean/std are extracted from that config
  3. Transforms are created with the correct normalization, as sketched below
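
Under the hood this relies on timm's data-config resolution. A minimal standalone sketch using timm directly (AutoTimm's own wrapper, resolve_backbone_data_config, is shown further below):

import timm
from timm.data import resolve_data_config

# pretrained=False still carries the model's default pretrained config,
# so the stats resolve without downloading weights
model = timm.create_model("resnet50", pretrained=False)
cfg = resolve_data_config({}, model=model)
print(cfg["mean"], cfg["std"])  # (0.485, 0.456, 0.406) (0.229, 0.224, 0.225)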

Common Model Normalizations

| Model Family | Mean | Std |
|---|---|---|
| ResNet, EfficientNet | (0.485, 0.456, 0.406) | (0.229, 0.224, 0.225) |
| ViT (CLIP) | (0.5, 0.5, 0.5) | (0.5, 0.5, 0.5) |
| Inception | (0.5, 0.5, 0.5) | (0.5, 0.5, 0.5) |

Example: Get Model Data Config

from autotimm.data import resolve_backbone_data_config

# Get config for a specific model
config = resolve_backbone_data_config("efficientnet_b4")
print(f"Mean: {config['mean']}")
print(f"Std: {config['std']}")
print(f"Input size: {config['input_size']}")
print(f"Interpolation: {config['interpolation']}")
print(f"Crop percentage: {config['crop_pct']}")

Override Normalization

# Use ImageNet normalization regardless of model
config = TransformConfig(
    use_timm_config=False,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

# Or just override specific values
config = TransformConfig(
    mean=(0.5, 0.5, 0.5),  # Override mean, get std from model
)

Using Transforms with DataModules

Method 1: TransformConfig (Recommended)

from autotimm import ImageDataModule, TransformConfig

config = TransformConfig(
    preset="randaugment",
    image_size=384,
)

data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    transform_config=config,
    backbone="efficientnet_b4",  # Required for model-specific normalization
)

Method 2: Augmentation Preset Only

# Simple preset selection
data = ImageDataModule(
    data_dir="./data",
    augmentation_preset="randaugment",
    image_size=224,
)

Method 3: Custom Transforms

from torchvision import transforms

custom_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

custom_eval = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

data = ImageDataModule(
    data_dir="./data",
    train_transforms=custom_train,
    eval_transforms=custom_eval,
)

Using Transforms with Models

Models can use TransformConfig for inference-time preprocessing:

from autotimm import ImageClassifier, TransformConfig, MetricConfig
from PIL import Image

# Create model with TransformConfig
model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=[MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["val"],
    )],
    transform_config=TransformConfig(),
)

# Preprocess raw images
image = Image.open("test.jpg")
tensor = model.preprocess(image)  # Uses model's normalization

# Batch preprocessing
images = [Image.open(f"img{i}.jpg") for i in range(4)]
batch = model.preprocess(images)  # Returns (4, 3, 224, 224)

# Get model's data config
config = model.get_data_config()
print(f"Mean: {config['mean']}")
print(f"Std: {config['std']}")

Shared Config for Model and Data

Ensure consistent preprocessing between training and inference:

from autotimm import (
    ImageClassifier,
    ImageDataModule,
    TransformConfig,
    MetricConfig,
)

# Shared configuration
config = TransformConfig(
    preset="randaugment",
    image_size=384,
)
backbone_name = "efficientnet_b4"

# DataModule uses model's normalization during training
data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    transform_config=config,
    backbone=backbone_name,
)
data.setup("fit")

# Model uses same config for inference preprocessing
model = ImageClassifier(
    backbone=backbone_name,
    num_classes=data.num_classes,
    metrics=[MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["val"],
    )],
    transform_config=config,
)

Utility Functions

get_transforms_from_backbone

Create transforms using model-specific normalization:

from autotimm.data import get_transforms_from_backbone, TransformConfig

config = TransformConfig(preset="randaugment", image_size=384)

# Training transforms
train_transform = get_transforms_from_backbone(
    backbone="efficientnet_b4",
    transform_config=config,
    is_train=True,
)

# Evaluation transforms
eval_transform = get_transforms_from_backbone(
    backbone="efficientnet_b4",
    transform_config=config,
    is_train=False,
)

create_inference_transform

Convenience function for inference:

from PIL import Image

from autotimm.data import create_inference_transform

# Quick inference transform with model normalization
pil_image = Image.open("test.jpg")
transform = create_inference_transform("resnet50")
tensor = transform(pil_image)

# With custom config
from autotimm import TransformConfig
config = TransformConfig(image_size=384)
transform = create_inference_transform("efficientnet_b4", config)

resolve_backbone_data_config

Get model's pretrained data configuration:

from autotimm.data import resolve_backbone_data_config

config = resolve_backbone_data_config("vit_base_patch16_224")
print(config)
# {'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'input_size': (3, 224, 224), ...}

# With overrides
config = resolve_backbone_data_config(
    "resnet50",
    override_mean=(0.5, 0.5, 0.5),
    override_input_size=(3, 384, 384),
)

Backend Comparison

Torchvision

Pros:

  • Native PyTorch integration
  • No additional dependencies
  • PIL-based (good for web images)
  • Faster for simple transforms

Cons:

  • Limited augmentation variety
  • No spatial transforms that preserve bboxes

Best for: Standard image classification

Albumentations

Pros:

  • Rich augmentation library
  • Faster for complex transforms (OpenCV backend)
  • Built-in bbox/mask handling
  • Better for detection/segmentation
  • Included by default in AutoTimm

Best for: Object detection, segmentation, advanced augmentation
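
The built-in bbox handling works by declaring box parameters on the compose, so spatial transforms update boxes alongside pixels. A minimal sketch (COCO-format [x, y, width, height] boxes assumed):

import albumentations as A
import numpy as np

tf = A.Compose(
    [A.HorizontalFlip(p=1.0)],
    bbox_params=A.BboxParams(format="coco", min_visibility=0.2, label_fields=["labels"]),
)
image = np.zeros((224, 224, 3), dtype=np.uint8)
out = tf(image=image, bboxes=[[10, 20, 50, 80]], labels=[1])
print(out["bboxes"])  # x-coordinate mirrored by the flip: [(164.0, 20.0, 50.0, 80.0)]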

Switching Backends

# Torchvision (default)
config = TransformConfig(backend="torchvision")

# Albumentations
config = TransformConfig(backend="albumentations")

Best Practices

1. Use TransformConfig for Consistency

# Good: Shared config ensures same preprocessing
config = TransformConfig(preset="randaugment", image_size=384)
data = ImageDataModule(..., transform_config=config, backbone="resnet50")
model = ImageClassifier(..., transform_config=config)

2. Match Training and Inference Preprocessing

# Always use the same normalization for training and inference
model = ImageClassifier(
    backbone="efficientnet_b4",
    transform_config=TransformConfig(use_timm_config=True),  # Uses model's pretrained stats
)

3. Choose Appropriate Augmentation Strength

# Light augmentation for small datasets or transfer learning
config = TransformConfig(preset="light")

# Standard augmentation for most cases
config = TransformConfig(preset="default")

# Strong augmentation for large datasets or preventing overfitting
config = TransformConfig(preset="strong", backend="albumentations")

4. Consider Image Size Trade-offs

# Smaller size: faster training, less GPU memory
config = TransformConfig(image_size=224)

# Larger size: better accuracy, slower training
config = TransformConfig(image_size=384)
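
As a rough rule of thumb, a 384x384 input carries (384/224)^2 ≈ 2.9x the pixels of a 224x224 input, so expect activation memory and per-image compute to grow by roughly that factor.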

Troubleshooting

For transform-related issues, see Troubleshooting - Augmentation, which covers:

  • Wrong predictions after training
  • Bounding boxes not preserved
  • Albumentations transform errors
  • Normalization mismatches

See Also