TransformConfig

Unified transform configuration for models and data modules.

Overview

TransformConfig provides a single configuration interface for:

  • Transform presets (default, randaugment, autoaugment, etc.)
  • Model-specific normalization from timm
  • Inference-time preprocessing via model.preprocess()
  • Shared configuration between models and data modules

API Reference

autotimm.TransformConfig dataclass

Unified configuration for image transforms.

This dataclass provides a consistent interface for configuring transforms across models and data modules. When used with models, it enables automatic preprocessing using model-specific normalization from timm.

Attributes:

preset (str)

Transform preset name. Common presets:

  • "default": Standard augmentation (random crop, flip, color jitter)
  • "autoaugment": AutoAugment policy
  • "randaugment": RandAugment with configurable ops/magnitude
  • "trivialaugment": TrivialAugmentWide
  • "strong": Heavy augmentation (albumentations only)
  • "light": Light augmentation (minimal transforms)

backend (Literal['torchvision', 'albumentations'])

Transform backend to use. Either "torchvision" (PIL-based) or "albumentations" (OpenCV-based).

image_size (int)

Target image size (square). Used for both training (random resized crop) and evaluation (resize + center crop).

use_timm_config (bool)

If True, take mean/std/input_size from the timm model's data config. This ensures the model receives inputs normalized with the same statistics it was pretrained with.

mean (tuple[float, float, float] | None)

Override normalization mean. If None and use_timm_config is True, uses the model's pretrained mean; otherwise defaults to ImageNet.

std (tuple[float, float, float] | None)

Override normalization std. If None and use_timm_config is True, uses the model's pretrained std; otherwise defaults to ImageNet.

interpolation (str)

Interpolation mode for resizing. Common values: "bilinear", "bicubic", "lanczos".

crop_pct (float)

Center crop percentage for evaluation transforms. For a 224x224 image with crop_pct=0.875, the image is first resized to 256x256 (224/0.875) and then center cropped to 224x224.

Detection-specific options

min_bbox_area (float)

Minimum bounding box area to keep after transforms.

min_visibility (float)

Minimum visibility ratio for bboxes (0.0-1.0).

bbox_format (str)

Bounding box format ("coco", "pascal_voc", "yolo").

Segmentation-specific options

ignore_index (int)

Label index to ignore in segmentation masks.
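
The crop_pct arithmetic described above can be sketched concretely in plain Python; this is a stand-alone illustration with the resize size computed by hand, not taken from the library:

```python
# Eval-time resize/crop arithmetic: resize to image_size / crop_pct,
# then center crop back down to image_size.
image_size = 224
crop_pct = 0.875

resize_size = round(image_size / crop_pct)
print(resize_size)  # 256 (= 224 / 0.875)
print(image_size)   # 224 after the center crop
```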

Example

from autotimm import ImageClassifier, TransformConfig

config = TransformConfig(
    preset="randaugment",
    image_size=384,
    use_timm_config=True,
)
model = ImageClassifier(
    backbone="efficientnet_b4",
    num_classes=10,
    metrics=[...],
    transform_config=config,
)

# Now model.preprocess() uses the correct normalization
tensor = model.preprocess(pil_image)

Source code in src/autotimm/data/transform_config.py
@dataclass
class TransformConfig:
    """Unified configuration for image transforms.

    This dataclass provides a consistent interface for configuring transforms
    across models and data modules. When used with models, it enables
    automatic preprocessing using model-specific normalization from timm.

    Attributes:
        preset: Transform preset name. Common presets:
            - ``"default"``: Standard augmentation (random crop, flip, color jitter)
            - ``"autoaugment"``: AutoAugment policy
            - ``"randaugment"``: RandAugment with configurable ops/magnitude
            - ``"trivialaugment"``: TrivialAugmentWide
            - ``"strong"``: Heavy augmentation (albumentations only)
            - ``"light"``: Light augmentation (minimal transforms)
        backend: Transform backend to use. Either ``"torchvision"`` (PIL-based)
            or ``"albumentations"`` (OpenCV-based).
        image_size: Target image size (square). This is used for both training
            (random resized crop) and evaluation (resize + center crop).
        use_timm_config: If True, get mean/std/input_size from the timm model's
            data config. This ensures the model receives inputs normalized
            with the same statistics it was pretrained with.
        mean: Override normalization mean. If None and use_timm_config is True,
            uses the model's pretrained mean. Otherwise defaults to ImageNet.
        std: Override normalization std. If None and use_timm_config is True,
            uses the model's pretrained std. Otherwise defaults to ImageNet.
        interpolation: Interpolation mode for resizing. Common values:
            ``"bilinear"``, ``"bicubic"``, ``"lanczos"``.
        crop_pct: Center crop percentage for evaluation transforms.
            For a 224x224 image with crop_pct=0.875, the image is first
            resized to 256x256 (224/0.875) then center cropped.

        Detection-specific options:
        min_bbox_area: Minimum bounding box area to keep after transforms.
        min_visibility: Minimum visibility ratio for bboxes (0.0-1.0).
        bbox_format: Bounding box format (``"coco"``, ``"pascal_voc"``, ``"yolo"``).

        Segmentation-specific options:
        ignore_index: Label index to ignore in segmentation masks.

    Example:
        >>> from autotimm import ImageClassifier, TransformConfig
        >>> config = TransformConfig(
        ...     preset="randaugment",
        ...     image_size=384,
        ...     use_timm_config=True,
        ... )
        >>> model = ImageClassifier(
        ...     backbone="efficientnet_b4",
        ...     num_classes=10,
        ...     metrics=[...],
        ...     transform_config=config,
        ... )
        >>> # Now model.preprocess() uses the correct normalization
        >>> tensor = model.preprocess(pil_image)
    """

    # General transform options
    preset: str = "default"
    backend: Literal["torchvision", "albumentations"] = "torchvision"
    image_size: int = 224
    use_timm_config: bool = True
    mean: tuple[float, float, float] | None = None
    std: tuple[float, float, float] | None = None
    interpolation: str = "bicubic"
    crop_pct: float = 0.875

    # Detection-specific options
    min_bbox_area: float = 0.0
    min_visibility: float = 0.0
    bbox_format: str = "coco"

    # Segmentation-specific options
    ignore_index: int = 255

    # RandAugment-specific options (when preset="randaugment")
    randaugment_num_ops: int = 2
    randaugment_magnitude: int = 9

    def __post_init__(self):
        """Validate configuration values."""
        valid_backends = ("torchvision", "albumentations")
        if self.backend not in valid_backends:
            raise ValueError(
                f"Invalid backend '{self.backend}'. Choose from: {valid_backends}"
            )

        if self.image_size <= 0:
            raise ValueError(f"image_size must be positive, got {self.image_size}")

        if not 0.0 < self.crop_pct <= 1.0:
            raise ValueError(f"crop_pct must be in (0, 1], got {self.crop_pct}")

        if self.mean is not None and len(self.mean) != 3:
            raise ValueError(f"mean must have 3 values, got {len(self.mean)}")

        if self.std is not None and len(self.std) != 3:
            raise ValueError(f"std must have 3 values, got {len(self.std)}")

    def with_overrides(self, **kwargs: object) -> "TransformConfig":
        """Create a new TransformConfig with specified overrides.

        Args:
            **kwargs: Fields to override in the new config.

        Returns:
            New TransformConfig with the overrides applied.

        Example:
            >>> base_config = TransformConfig(image_size=224)
            >>> large_config = base_config.with_overrides(image_size=384)
        """
        from dataclasses import asdict

        current = asdict(self)
        current.update(kwargs)
        return TransformConfig(**current)

    def to_dict(self) -> dict:
        """Convert config to a dictionary.

        Returns:
            Dictionary representation of the config.
        """
        from dataclasses import asdict

        return asdict(self)
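
The __post_init__ checks above mean invalid values fail fast at construction time. A stand-alone sketch of that behavior, mirroring the checks in a hypothetical ValidatedConfig stand-in rather than importing the library:

```python
from dataclasses import dataclass


@dataclass
class ValidatedConfig:  # hypothetical stand-in mirroring TransformConfig's checks
    backend: str = "torchvision"
    image_size: int = 224
    crop_pct: float = 0.875

    def __post_init__(self):
        if self.backend not in ("torchvision", "albumentations"):
            raise ValueError(f"Invalid backend '{self.backend}'")
        if self.image_size <= 0:
            raise ValueError(f"image_size must be positive, got {self.image_size}")
        if not 0.0 < self.crop_pct <= 1.0:
            raise ValueError(f"crop_pct must be in (0, 1], got {self.crop_pct}")


ValidatedConfig()  # defaults pass validation

try:
    ValidatedConfig(crop_pct=1.5)
except ValueError as e:
    print(e)  # crop_pct must be in (0, 1], got 1.5
```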

__init__

__init__(
    preset: str = 'default',
    backend: Literal['torchvision', 'albumentations'] = 'torchvision',
    image_size: int = 224,
    use_timm_config: bool = True,
    mean: tuple[float, float, float] | None = None,
    std: tuple[float, float, float] | None = None,
    interpolation: str = 'bicubic',
    crop_pct: float = 0.875,
    min_bbox_area: float = 0.0,
    min_visibility: float = 0.0,
    bbox_format: str = 'coco',
    ignore_index: int = 255,
    randaugment_num_ops: int = 2,
    randaugment_magnitude: int = 9,
) -> None

with_overrides

with_overrides(**kwargs: object) -> 'TransformConfig'

Create a new TransformConfig with specified overrides.

Parameters:

**kwargs (object)

Fields to override in the new config.

Returns:

TransformConfig

New TransformConfig with the overrides applied.

Example

base_config = TransformConfig(image_size=224)
large_config = base_config.with_overrides(image_size=384)

Source code in src/autotimm/data/transform_config.py
def with_overrides(self, **kwargs: object) -> "TransformConfig":
    """Create a new TransformConfig with specified overrides.

    Args:
        **kwargs: Fields to override in the new config.

    Returns:
        New TransformConfig with the overrides applied.

    Example:
        >>> base_config = TransformConfig(image_size=224)
        >>> large_config = base_config.with_overrides(image_size=384)
    """
    from dataclasses import asdict

    current = asdict(self)
    current.update(kwargs)
    return TransformConfig(**current)
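
For a flat dataclass like this one, the asdict-then-update pattern above behaves the same as the standard library's dataclasses.replace. A stand-alone sketch using a hypothetical MiniConfig stand-in (not the real TransformConfig):

```python
from dataclasses import asdict, dataclass, replace


@dataclass
class MiniConfig:  # hypothetical stand-in for TransformConfig
    image_size: int = 224
    crop_pct: float = 0.875


base = MiniConfig()

# The pattern used by with_overrides: dump to a dict, apply overrides, rebuild.
overridden = MiniConfig(**{**asdict(base), "image_size": 384})

# dataclasses.replace does the same for flat dataclasses.
assert overridden == replace(base, image_size=384)
assert base.image_size == 224  # the original config is untouched
```

Note that asdict recursively converts nested dataclasses while replace is shallow, so the two only coincide for flat configs such as this one.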

to_dict

to_dict() -> dict

Convert config to a dictionary.

Returns:

dict

Dictionary representation of the config.

Source code in src/autotimm/data/transform_config.py
def to_dict(self) -> dict:
    """Convert config to a dictionary.

    Returns:
        Dictionary representation of the config.
    """
    from dataclasses import asdict

    return asdict(self)

Usage Examples

Basic Usage

from autotimm import ImageClassifier, TransformConfig, MetricConfig

# Create model with transform config
model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=[
        MetricConfig(
            name="accuracy",
            backend="torchmetrics",
            metric_class="Accuracy",
            params={"task": "multiclass"},
            stages=["val"],
        ),
    ],
    transform_config=TransformConfig(),  # Uses model's pretrained normalization
)

# Preprocess raw images for inference
from PIL import Image
image = Image.open("test.jpg")
tensor = model.preprocess(image)  # Returns (1, 3, 224, 224)
output = model(tensor)

With Custom Image Size

config = TransformConfig(
    image_size=384,
    use_timm_config=True,  # Use model's mean/std
)

model = ImageClassifier(
    backbone="efficientnet_b4",
    num_classes=100,
    metrics=metrics,
    transform_config=config,
)

# Preprocessing now uses 384x384
tensor = model.preprocess(image)  # Returns (1, 3, 384, 384)

With Augmentation Preset

config = TransformConfig(
    preset="randaugment",
    image_size=224,
    randaugment_num_ops=2,
    randaugment_magnitude=9,
)

# For training with the same config
datamodule = ImageDataModule(
    data_dir="./data",
    transform_config=config,
    backbone="resnet50",
)

Custom Normalization

# Override model's normalization (not recommended for pretrained models)
config = TransformConfig(
    mean=(0.5, 0.5, 0.5),
    std=(0.5, 0.5, 0.5),
    use_timm_config=False,  # Use our mean/std instead
)

Shared Config for Model and DataModule

from autotimm import ImageClassifier, ImageDataModule, TransformConfig

# Create shared config
config = TransformConfig(
    preset="randaugment",
    image_size=384,
    use_timm_config=True,
)

backbone_name = "efficientnet_b4"

# DataModule uses the same transforms as model
datamodule = ImageDataModule(
    data_dir="./data",
    transform_config=config,
    backbone=backbone_name,
)

# Model uses same preprocessing
model = ImageClassifier(
    backbone=backbone_name,
    num_classes=datamodule.num_classes,
    metrics=metrics,
    transform_config=config,
)

Get Model's Data Config

model = ImageClassifier(
    backbone="vit_base_patch16_224",
    num_classes=100,
    metrics=metrics,
    transform_config=TransformConfig(),
)

# Get the normalization config
data_config = model.get_data_config()
print(f"Mean: {data_config['mean']}")      # (0.5, 0.5, 0.5) for ViT
print(f"Std: {data_config['std']}")        # (0.5, 0.5, 0.5) for ViT
print(f"Input size: {data_config['input_size']}")  # (3, 224, 224)

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| preset | str | "default" | Augmentation preset name |
| backend | str | "torchvision" | "torchvision" or "albumentations" |
| image_size | int | 224 | Target image size (square) |
| use_timm_config | bool | True | Use model's pretrained mean/std |
| mean | tuple[float, ...] | None | Override normalization mean |
| std | tuple[float, ...] | None | Override normalization std |
| interpolation | str | "bicubic" | Resize interpolation mode |
| crop_pct | float | 0.875 | Center crop percentage for eval |
| min_bbox_area | float | 0.0 | Detection: min bbox area |
| min_visibility | float | 0.0 | Detection: min visibility |
| bbox_format | str | "coco" | Detection: bbox format |
| ignore_index | int | 255 | Segmentation: ignore index |
| randaugment_num_ops | int | 2 | RandAugment: number of ops |
| randaugment_magnitude | int | 9 | RandAugment: magnitude |

Augmentation Presets

Torchvision Backend

| Preset | Description |
|---|---|
| default | RandomResizedCrop, HorizontalFlip, ColorJitter |
| autoaugment | AutoAugment (ImageNet policy) |
| randaugment | RandAugment with configurable ops/magnitude |
| trivialaugment | TrivialAugmentWide |
| light | RandomResizedCrop, HorizontalFlip only |

Albumentations Backend

| Preset | Description |
|---|---|
| default | RandomResizedCrop, HorizontalFlip, ColorJitter |
| strong | Affine, blur/noise, ColorJitter, CoarseDropout |
| light | RandomResizedCrop, HorizontalFlip only |

Model Preprocessing Methods

When a model is created with transform_config, it gains these methods:

model.preprocess(images, is_train=False)

Preprocess raw images for model inference.

from PIL import Image
import numpy as np
import torch

# Single PIL image
image = Image.open("test.jpg")
tensor = model.preprocess(image)  # (1, 3, H, W)

# List of PIL images
images = [Image.open(f"img{i}.jpg") for i in range(4)]
tensor = model.preprocess(images)  # (4, 3, H, W)

# Numpy array
img_np = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
tensor = model.preprocess(img_np)  # (1, 3, H, W)

# Tensor (passes through unchanged)
tensor_in = torch.randn(2, 3, 224, 224)
tensor_out = model.preprocess(tensor_in)  # (2, 3, 224, 224)

model.get_data_config()

Get the model's normalization configuration.

config = model.get_data_config()
# Returns:
# {
#     'mean': (0.485, 0.456, 0.406),
#     'std': (0.229, 0.224, 0.225),
#     'input_size': (3, 224, 224),
#     'interpolation': 'bicubic',
#     'crop_pct': 0.875,
# }

model.get_transform(is_train=False)

Get the transform pipeline directly.

eval_transform = model.get_transform(is_train=False)
train_transform = model.get_transform(is_train=True)

Utility Functions

resolve_backbone_data_config

Get model-specific preprocessing config from timm.

from autotimm import resolve_backbone_data_config

config = resolve_backbone_data_config("efficientnet_b0")
print(config["mean"])       # (0.485, 0.456, 0.406)
print(config["std"])        # (0.229, 0.224, 0.225)
print(config["input_size"]) # (3, 224, 224)

# With overrides
config = resolve_backbone_data_config(
    "efficientnet_b0",
    override_mean=(0.5, 0.5, 0.5),
    override_std=(0.5, 0.5, 0.5),
)

get_transforms_from_backbone

Create transforms using model-specific normalization.

from autotimm import get_transforms_from_backbone, TransformConfig

config = TransformConfig(preset="randaugment", image_size=384)

train_transforms = get_transforms_from_backbone(
    backbone="efficientnet_b4",
    transform_config=config,
    is_train=True,
)

eval_transforms = get_transforms_from_backbone(
    backbone="efficientnet_b4",
    transform_config=config,
    is_train=False,
)

create_inference_transform

Convenience function for creating inference transforms.

from autotimm import create_inference_transform

# Simple usage
transform = create_inference_transform("resnet50")
tensor = transform(pil_image)

# With custom config
config = TransformConfig(image_size=384)
transform = create_inference_transform("resnet50", transform_config=config)

list_transform_presets

List available transform presets for a given backend.

from autotimm import list_transform_presets

# List all torchvision presets
presets = list_transform_presets(backend="torchvision")
print(presets)
# ['default', 'autoaugment', 'randaugment', 'trivialaugment', 'light']

# List albumentations presets
presets = list_transform_presets(backend="albumentations")
print(presets)
# ['default', 'strong', 'light']

# Get preset details
presets = list_transform_presets(backend="torchvision", verbose=True)
for name, description in presets:
    print(f"{name}: {description}")
# default: RandomResizedCrop, HorizontalFlip, ColorJitter
# autoaugment: AutoAugment (ImageNet policy)
# randaugment: RandAugment with configurable ops/magnitude
# trivialaugment: TrivialAugmentWide
# light: RandomResizedCrop, HorizontalFlip only

Integration with DataModules

All AutoTimm DataModules support transform_config and backbone parameters:

from autotimm import (
    ImageDataModule,
    DetectionDataModule,
    SegmentationDataModule,
    InstanceSegmentationDataModule,
    TransformConfig,
)

config = TransformConfig(preset="randaugment", image_size=384)

# Classification
data = ImageDataModule(
    data_dir="./data",
    transform_config=config,
    backbone="efficientnet_b4",
)

# Detection
data = DetectionDataModule(
    data_dir="./coco",
    transform_config=config,
    backbone="resnet50",
)

# Segmentation
data = SegmentationDataModule(
    data_dir="./cityscapes",
    format="cityscapes",
    transform_config=config,
    backbone="resnet50",
)

# Instance Segmentation
data = InstanceSegmentationDataModule(
    data_dir="./coco",
    transform_config=config,
    backbone="resnet50",
)

Best Practices

1. Use use_timm_config=True (Default)

Always use the model's pretrained normalization for best results:

# Good - uses model's pretrained normalization
config = TransformConfig(use_timm_config=True)

# Not recommended - may hurt pretrained model performance
config = TransformConfig(
    mean=(0.5, 0.5, 0.5),
    std=(0.5, 0.5, 0.5),
    use_timm_config=False,
)

2. Share Config Between Model and DataModule

config = TransformConfig(preset="randaugment", image_size=384)
backbone = "efficientnet_b4"

# Same normalization for training and inference
datamodule = ImageDataModule(..., transform_config=config, backbone=backbone)
model = ImageClassifier(..., transform_config=config)

3. Use preprocess() for Inference

# Simple and correct - uses model's exact preprocessing
model = ImageClassifier(..., transform_config=TransformConfig())
tensor = model.preprocess(pil_image)
output = model(tensor)

See Also