
ImageDataModule

Lightning data module for image classification datasets.

Overview

ImageDataModule is a PyTorch Lightning data module that supports:

- Built-in torchvision datasets (CIFAR10, CIFAR100, MNIST, FashionMNIST)
- Custom `ImageFolder` datasets
- Torchvision and albumentations transform backends
- Automatic validation splits
- Balanced sampling for imbalanced datasets

API Reference

autotimm.ImageDataModule

Bases: LightningDataModule

Lightning data module for image classification.

Supports three modes:

1. **Folder mode** -- point `data_dir` at a directory with `train/`, `val/`, and optionally `test/` subdirectories, each containing one sub-folder per class (ImageFolder layout).
2. **Built-in dataset mode** -- set `dataset_name` to a torchvision dataset (`"CIFAR10"`, `"CIFAR100"`, `"FashionMNIST"`, `"MNIST"`) and `data_dir` to the download root.
3. **CSV mode** -- provide `train_csv` pointing to a CSV file with `image_path,label` columns. Optionally provide `val_csv` and `test_csv` for separate splits.
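The three modes are mutually exclusive; when several arguments are given, the precedence follows the dispatch in `setup()`: CSV mode wins over built-in dataset mode, which wins over folder mode. A stdlib-only sketch of that selection (the helper name is illustrative, not part of the API):

```python
# Mirrors the mode dispatch in ImageDataModule.setup():
# CSV mode > built-in dataset mode > folder mode.
BUILTIN_DATASETS = {"CIFAR10", "CIFAR100", "FashionMNIST", "MNIST"}

def resolve_mode(train_csv=None, dataset_name=None):
    if train_csv is not None:
        return "csv"
    if dataset_name in BUILTIN_DATASETS:
        return "builtin"
    return "folder"

print(resolve_mode(train_csv="train.csv"))   # csv
print(resolve_mode(dataset_name="CIFAR10"))  # builtin
print(resolve_mode())                        # folder
```

Note that an unrecognized `dataset_name` falls through to folder mode rather than raising, matching the `setup()` source below.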

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_dir` | `str \| Path` | Root directory for image data or download root. | `'./data'` |
| `dataset_name` | `str \| None` | Optional name of a torchvision dataset class. | `None` |
| `train_csv` | `str \| Path \| None` | Path to training CSV file (CSV mode). | `None` |
| `val_csv` | `str \| Path \| None` | Path to validation CSV file (CSV mode). | `None` |
| `test_csv` | `str \| Path \| None` | Path to test CSV file (CSV mode). | `None` |
| `image_dir` | `str \| Path \| None` | Root directory for resolving image paths in CSV mode. Defaults to the parent directory of `train_csv`. | `None` |
| `image_column` | `str \| None` | Name of the CSV column containing image paths. | `None` |
| `label_column` | `str \| None` | Name of the CSV column containing class labels. | `None` |
| `image_size` | `int` | Target image size (square). | `224` |
| `batch_size` | `int` | Batch size for all dataloaders. | `32` |
| `num_workers` | `int` | Number of data-loading workers; capped at 4 by default. | `min(cpu_count() or 4, 4)` |
| `val_split` | `float` | Fraction of training data used for validation when no explicit val set exists. | `0.1` |
| `train_transforms` | `Callable \| None` | Custom training transforms; defaults used if `None`. Mutually exclusive with `augmentation_preset`. | `None` |
| `eval_transforms` | `Callable \| None` | Custom eval transforms; defaults used if `None`. | `None` |
| `augmentation_preset` | `str \| None` | Name of a built-in augmentation preset. For torchvision: `"default"`, `"autoaugment"`, `"randaugment"`, `"trivialaugment"`. For albumentations: `"default"`, `"strong"`. Ignored when `train_transforms` is provided. | `None` |
| `transform_backend` | `str` | `"torchvision"` (PIL-based) or `"albumentations"` (OpenCV-based). When `"albumentations"` is selected, folder-mode datasets load images with OpenCV and built-in datasets convert PIL images to numpy for the augmentation pipeline. | `'torchvision'` |
| `transform_config` | `TransformConfig \| None` | Optional `TransformConfig` for unified transform configuration. When provided along with `backbone`, uses model-specific normalization from timm. Takes precedence over individual transform args. | `None` |
| `backbone` | `str \| Module \| None` | Optional backbone name or module. Used with `transform_config` to resolve model-specific normalization (mean, std, input_size). | `None` |
| `pin_memory` | `bool` | Pin memory for GPU transfer. | `True` |
| `persistent_workers` | `bool` | Keep worker processes alive between epochs. Reduces overhead when `num_workers > 0`. | `False` |
| `prefetch_factor` | `int \| None` | Number of batches prefetched per worker. | `None` |
| `balanced_sampling` | `bool` | Use `WeightedRandomSampler` to counter class imbalance in the training set. | `False` |
Source code in src/autotimm/data/datamodule.py
class ImageDataModule(pl.LightningDataModule):
    """Lightning data module for image classification.

    Supports three modes:

    1. **Folder mode** -- point ``data_dir`` at a directory with ``train/``,
       ``val/``, and optionally ``test/`` subdirectories, each containing one
       sub-folder per class (ImageFolder layout).
    2. **Built-in dataset mode** -- set ``dataset_name`` to a torchvision
       dataset (``"CIFAR10"``, ``"CIFAR100"``, ``"FashionMNIST"``, ``"MNIST"``)
       and ``data_dir`` to the download root.
    3. **CSV mode** -- provide ``train_csv`` pointing to a CSV file with
       ``image_path,label`` columns. Optionally provide ``val_csv`` and
       ``test_csv`` for separate splits.

    Parameters:
        data_dir: Root directory for image data or download root.
        dataset_name: Optional name of a torchvision dataset class.
        train_csv: Path to training CSV file (CSV mode).
        val_csv: Path to validation CSV file (CSV mode).
        test_csv: Path to test CSV file (CSV mode).
        image_dir: Root directory for resolving image paths in CSV mode.
            Defaults to the parent directory of ``train_csv``.
        image_column: Name of the CSV column containing image paths.
        label_column: Name of the CSV column containing class labels.
        image_size: Target image size (square).
        batch_size: Batch size for all dataloaders.
        num_workers: Number of data-loading workers. Defaults to
            ``min(os.cpu_count() or 4, 4)``.
        val_split: Fraction of training data used for validation when
            no explicit val set exists.
        train_transforms: Custom training transforms; defaults used if None.
            Mutually exclusive with ``augmentation_preset``.
        eval_transforms: Custom eval transforms; defaults used if None.
        augmentation_preset: Name of a built-in augmentation preset.
            For ``torchvision``: ``"default"``, ``"autoaugment"``,
            ``"randaugment"``, ``"trivialaugment"``.
            For ``albumentations``: ``"default"``, ``"strong"``.
            Ignored when ``train_transforms`` is provided.
        transform_backend: ``"torchvision"`` (PIL-based) or
            ``"albumentations"`` (OpenCV-based). Defaults to
            ``"torchvision"``. When ``"albumentations"`` is selected,
            folder-mode datasets load images with OpenCV and built-in
            datasets convert PIL images to numpy for the augmentation
            pipeline.
        transform_config: Optional :class:`TransformConfig` for unified transform
            configuration. When provided along with ``backbone``, uses model-specific
            normalization from timm. Takes precedence over individual transform args.
        backbone: Optional backbone name or module. Used with ``transform_config``
            to resolve model-specific normalization (mean, std, input_size).
        pin_memory: Pin memory for GPU transfer.
        persistent_workers: Keep worker processes alive between epochs.
            Reduces overhead when ``num_workers > 0``.
        prefetch_factor: Number of batches prefetched per worker.
        balanced_sampling: Use ``WeightedRandomSampler`` to counter
            class imbalance in the training set.
    """

    BUILTIN_DATASETS: dict[str, type] = {
        "CIFAR10": datasets.CIFAR10,
        "CIFAR100": datasets.CIFAR100,
        "FashionMNIST": datasets.FashionMNIST,
        "MNIST": datasets.MNIST,
    }

    def __init__(
        self,
        data_dir: str | Path = "./data",
        dataset_name: str | None = None,
        train_csv: str | Path | None = None,
        val_csv: str | Path | None = None,
        test_csv: str | Path | None = None,
        image_dir: str | Path | None = None,
        image_column: str | None = None,
        label_column: str | None = None,
        image_size: int = 224,
        batch_size: int = 32,
        num_workers: int = min(os.cpu_count() or 4, 4),
        val_split: float = 0.1,
        train_transforms: Callable | None = None,
        eval_transforms: Callable | None = None,
        augmentation_preset: str | None = None,
        transform_backend: str = "torchvision",
        transform_config: TransformConfig | None = None,
        backbone: str | nn.Module | None = None,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        prefetch_factor: int | None = None,
        balanced_sampling: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["backbone"])
        self.data_dir = Path(data_dir)
        self.dataset_name = dataset_name
        self.train_csv = Path(train_csv) if train_csv else None
        self.val_csv = Path(val_csv) if val_csv else None
        self.test_csv = Path(test_csv) if test_csv else None
        self.image_dir = Path(image_dir) if image_dir else None
        self.image_column = image_column
        self.label_column = label_column
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers and num_workers > 0
        self.prefetch_factor = prefetch_factor
        self.balanced_sampling = balanced_sampling

        if transform_backend not in ("torchvision", "albumentations"):
            raise ValueError(
                f"Unknown transform_backend '{transform_backend}'. "
                f"Choose from: torchvision, albumentations."
            )
        self.transform_backend = transform_backend
        self.transform_config = transform_config
        self.backbone = backbone

        # Resolve transforms - TransformConfig takes precedence
        if transform_config is not None and backbone is not None:
            from autotimm.data.timm_transforms import get_transforms_from_backbone

            self.train_transforms = get_transforms_from_backbone(
                backbone=backbone,
                transform_config=transform_config,
                is_train=True,
                task="classification",
            )
            self.eval_transforms = get_transforms_from_backbone(
                backbone=backbone,
                transform_config=transform_config,
                is_train=False,
                task="classification",
            )
        elif train_transforms is not None:
            self.train_transforms = train_transforms
        elif augmentation_preset is not None:
            self.train_transforms = get_train_transforms(
                augmentation_preset,
                backend=transform_backend,
                image_size=image_size,
            )
        else:
            self.train_transforms = self._default_train_transforms()

        if eval_transforms is not None:
            self.eval_transforms = eval_transforms
        else:
            self.eval_transforms = self._default_eval_transforms()

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.num_classes: int | None = None
        self.class_names: list[str] | None = None
        self._train_targets: list[int] | None = None

    def _default_train_transforms(self) -> Callable:
        if self.transform_backend == "albumentations":
            return albu_default_train_transforms(self.image_size)
        return default_train_transforms(self.image_size)

    def _default_eval_transforms(self) -> Callable:
        if self.transform_backend == "albumentations":
            return albu_default_eval_transforms(self.image_size)
        return default_eval_transforms(self.image_size)

    def prepare_data(self) -> None:
        if self.dataset_name and self.dataset_name in self.BUILTIN_DATASETS:
            cls = self.BUILTIN_DATASETS[self.dataset_name]
            cls(str(self.data_dir), train=True, download=True)
            cls(str(self.data_dir), train=False, download=True)

    def setup(self, stage: str | None = None) -> None:
        if self.train_csv is not None:
            self._setup_csv(stage)
        elif self.dataset_name and self.dataset_name in self.BUILTIN_DATASETS:
            self._setup_builtin(stage)
        elif self.transform_backend == "albumentations":
            self._setup_folder_cv2(stage)
        else:
            self._setup_folder(stage)

    def _setup_builtin(self, stage: str | None) -> None:
        cls = self.BUILTIN_DATASETS[self.dataset_name]

        if self.transform_backend == "albumentations":
            wrapper_train = _AlbumentationsBuiltinWrapper(self.train_transforms)
            wrapper_eval = _AlbumentationsBuiltinWrapper(self.eval_transforms)
        else:
            wrapper_train = self.train_transforms
            wrapper_eval = self.eval_transforms

        if stage in ("fit", None):
            full_train = cls(str(self.data_dir), train=True, transform=wrapper_train)
            n_val = int(len(full_train) * self.val_split)
            n_train = len(full_train) - n_val
            self.train_dataset, self.val_dataset = random_split(
                full_train, [n_train, n_val]
            )
            self.num_classes = (
                len(full_train.classes) if hasattr(full_train, "classes") else 10
            )
            self.class_names = (
                list(full_train.classes) if hasattr(full_train, "classes") else None
            )
            self._train_targets = [
                full_train.targets[i] for i in self.train_dataset.indices
            ]
        if stage in ("test", None):
            self.test_dataset = cls(
                str(self.data_dir), train=False, transform=wrapper_eval
            )

    def _setup_folder(self, stage: str | None) -> None:
        train_dir = self.data_dir / "train"
        val_dir = self.data_dir / "val"
        test_dir = self.data_dir / "test"

        if stage in ("fit", None):
            self.train_dataset = datasets.ImageFolder(
                str(train_dir), transform=self.train_transforms
            )
            self.num_classes = len(self.train_dataset.classes)
            self.class_names = list(self.train_dataset.classes)
            self._train_targets = [s[1] for s in self.train_dataset.samples]

            if val_dir.exists():
                self.val_dataset = datasets.ImageFolder(
                    str(val_dir), transform=self.eval_transforms
                )
            else:
                n_val = int(len(self.train_dataset) * self.val_split)
                n_train = len(self.train_dataset) - n_val
                self.train_dataset, self.val_dataset = random_split(
                    self.train_dataset, [n_train, n_val]
                )
        if stage in ("test", None) and test_dir.exists():
            self.test_dataset = datasets.ImageFolder(
                str(test_dir), transform=self.eval_transforms
            )

    def _setup_folder_cv2(self, stage: str | None) -> None:
        from autotimm.data.dataset import ImageFolderCV2

        train_dir = self.data_dir / "train"
        val_dir = self.data_dir / "val"
        test_dir = self.data_dir / "test"

        if stage in ("fit", None):
            self.train_dataset = ImageFolderCV2(
                str(train_dir), transform=self.train_transforms
            )
            self.num_classes = len(self.train_dataset.classes)
            self.class_names = list(self.train_dataset.classes)
            self._train_targets = [s[1] for s in self.train_dataset.samples]

            if val_dir.exists():
                self.val_dataset = ImageFolderCV2(
                    str(val_dir), transform=self.eval_transforms
                )
            else:
                n_val = int(len(self.train_dataset) * self.val_split)
                n_train = len(self.train_dataset) - n_val
                self.train_dataset, self.val_dataset = random_split(
                    self.train_dataset, [n_train, n_val]
                )
        if stage in ("test", None) and test_dir.exists():
            self.test_dataset = ImageFolderCV2(
                str(test_dir), transform=self.eval_transforms
            )

    def _setup_csv(self, stage: str | None) -> None:
        from autotimm.data.dataset import CSVImageDataset

        use_albu = self.transform_backend == "albumentations"
        # Default image_dir to parent of train_csv
        img_dir = self.image_dir or self.train_csv.parent

        if stage in ("fit", None):
            self.train_dataset = CSVImageDataset(
                csv_path=self.train_csv,
                image_dir=img_dir,
                image_column=self.image_column,
                label_column=self.label_column,
                transform=self.train_transforms,
                use_albumentations=use_albu,
            )
            self.num_classes = self.train_dataset.num_classes
            self.class_names = list(self.train_dataset.classes)
            self._train_targets = [s[1] for s in self.train_dataset.samples]

            if self.val_csv is not None:
                self.val_dataset = CSVImageDataset(
                    csv_path=self.val_csv,
                    image_dir=img_dir,
                    image_column=self.image_column,
                    label_column=self.label_column,
                    transform=self.eval_transforms,
                    use_albumentations=use_albu,
                )
            else:
                n_val = int(len(self.train_dataset) * self.val_split)
                n_train = len(self.train_dataset) - n_val
                self.train_dataset, self.val_dataset = random_split(
                    self.train_dataset, [n_train, n_val]
                )

        if stage in ("test", None) and self.test_csv is not None:
            self.test_dataset = CSVImageDataset(
                csv_path=self.test_csv,
                image_dir=img_dir,
                image_column=self.image_column,
                label_column=self.label_column,
                transform=self.eval_transforms,
                use_albumentations=use_albu,
            )

    def _make_sampler(self) -> WeightedRandomSampler | None:
        if not self.balanced_sampling or self._train_targets is None:
            return None

        counts = Counter(self._train_targets)
        weight_per_class = {cls: 1.0 / cnt for cls, cnt in counts.items()}
        sample_weights = [weight_per_class[t] for t in self._train_targets]
        return WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True,
        )

    def _loader_kwargs(self) -> dict:
        kwargs: dict = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": self.pin_memory,
            "persistent_workers": self.persistent_workers,
        }
        if self.prefetch_factor is not None and self.num_workers > 0:
            kwargs["prefetch_factor"] = self.prefetch_factor
        return kwargs

    def train_dataloader(self) -> DataLoader:
        sampler = self._make_sampler()
        return DataLoader(
            self.train_dataset,
            shuffle=sampler is None,
            sampler=sampler,
            **self._loader_kwargs(),
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            shuffle=False,
            **self._loader_kwargs(),
        )

    def test_dataloader(self) -> DataLoader:
        if self.test_dataset is None:
            raise RuntimeError(
                "No test split found. Provide a 'test/' directory or use a built-in dataset."
            )
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            **self._loader_kwargs(),
        )

__init__

```python
__init__(
    data_dir: str | Path = './data',
    dataset_name: str | None = None,
    train_csv: str | Path | None = None,
    val_csv: str | Path | None = None,
    test_csv: str | Path | None = None,
    image_dir: str | Path | None = None,
    image_column: str | None = None,
    label_column: str | None = None,
    image_size: int = 224,
    batch_size: int = 32,
    num_workers: int = min(os.cpu_count() or 4, 4),
    val_split: float = 0.1,
    train_transforms: Callable | None = None,
    eval_transforms: Callable | None = None,
    augmentation_preset: str | None = None,
    transform_backend: str = 'torchvision',
    transform_config: TransformConfig | None = None,
    backbone: str | Module | None = None,
    pin_memory: bool = True,
    persistent_workers: bool = False,
    prefetch_factor: int | None = None,
    balanced_sampling: bool = False,
)
```

prepare_data

```python
prepare_data() -> None
```

setup

```python
setup(stage: str | None = None) -> None
```

train_dataloader

```python
train_dataloader() -> DataLoader
```

val_dataloader

```python
val_dataloader() -> DataLoader
```

test_dataloader

```python
test_dataloader() -> DataLoader
```

Usage Examples

Built-in Dataset

```python
from autotimm import ImageDataModule

data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    image_size=224,
    batch_size=64,
)
```

Custom Folder Dataset

```python
data = ImageDataModule(
    data_dir="./my_dataset",
    image_size=384,
    batch_size=32,
)
data.setup("fit")
print(f"Classes: {data.num_classes}")
print(f"Class names: {data.class_names}")
```

With Albumentations

```python
data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    transform_backend="albumentations",
    augmentation_preset="strong",
)
```

With Augmentation Preset

```python
data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    augmentation_preset="randaugment",
)
```

With Custom Transforms

```python
from torchvision import transforms

custom_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandAugment(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

data = ImageDataModule(
    data_dir="./dataset",
    train_transforms=custom_train,
)
```

With Balanced Sampling

```python
data = ImageDataModule(
    data_dir="./imbalanced_dataset",
    balanced_sampling=True,
)
```
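Under the hood, each training sample is weighted by the inverse frequency of its class before the weights are handed to torch's `WeightedRandomSampler` (see `_make_sampler` in the source above). The weight computation itself is plain Python and can be sketched standalone (the helper name is illustrative):

```python
from collections import Counter

def sample_weights(targets):
    """Inverse-frequency weight per sample, as used for balanced sampling."""
    counts = Counter(targets)
    weight_per_class = {cls: 1.0 / cnt for cls, cnt in counts.items()}
    return [weight_per_class[t] for t in targets]

# 3:1 imbalanced toy labels: each majority sample weighs 1/3,
# the single minority sample weighs 1.0, so classes are drawn
# with equal total probability under replacement sampling.
print(sample_weights([0, 0, 0, 1]))
```

Because the sampler draws with replacement, minority-class images repeat within an epoch; the epoch length stays equal to the dataset size.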

Performance Optimization

```python
data = ImageDataModule(
    data_dir="./dataset",
    batch_size=64,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)
```
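Note that `persistent_workers` and `prefetch_factor` only take effect when `num_workers > 0`; PyTorch's `DataLoader` rejects a non-default `prefetch_factor` for single-process loading. A sketch of the kwarg assembly, mirroring the module's `_loader_kwargs` (the standalone function here is illustrative):

```python
def loader_kwargs(batch_size, num_workers, pin_memory=True,
                  persistent_workers=False, prefetch_factor=None):
    # Worker-only options are gated on num_workers: DataLoader raises if
    # prefetch_factor is set with num_workers == 0, and persistent_workers
    # is meaningless without worker processes.
    kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "persistent_workers": persistent_workers and num_workers > 0,
    }
    if prefetch_factor is not None and num_workers > 0:
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs

print(loader_kwargs(64, 0, prefetch_factor=4))  # prefetch_factor omitted
```

This is why the same configuration works unchanged for quick CPU debugging runs with `num_workers=0`.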

With TransformConfig (Model-Specific Normalization)

Use TransformConfig with a backbone to get model-specific normalization:

```python
from autotimm import ImageDataModule, TransformConfig

# Create shared config
config = TransformConfig(
    preset="randaugment",
    image_size=384,
    use_timm_config=True,  # Use model's pretrained mean/std
)

data = ImageDataModule(
    data_dir="./dataset",
    transform_config=config,
    backbone="efficientnet_b4",  # Required for model-specific normalization
)
```

Shared Config Between Model and Data

```python
from autotimm import ImageClassifier, ImageDataModule, TransformConfig, MetricConfig

# Shared config ensures same preprocessing
config = TransformConfig(preset="randaugment", image_size=384)
backbone_name = "efficientnet_b4"

# DataModule uses model's normalization
data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    transform_config=config,
    backbone=backbone_name,
)
data.setup("fit")

# Model uses same config for inference preprocessing
metrics = [MetricConfig(
    name="accuracy",
    backend="torchmetrics",
    metric_class="Accuracy",
    params={"task": "multiclass"},
    stages=["val"],
)]

model = ImageClassifier(
    backbone=backbone_name,
    num_classes=data.num_classes,
    metrics=metrics,
    transform_config=config,
)
```

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `data_dir` | `str \| Path` | `"./data"` | Root directory |
| `dataset_name` | `str \| None` | `None` | Built-in dataset name |
| `image_size` | `int` | `224` | Target image size |
| `batch_size` | `int` | `32` | Batch size |
| `num_workers` | `int` | `min(cpu_count() or 4, 4)` | Data-loading workers |
| `val_split` | `float` | `0.1` | Validation split fraction |
| `train_transforms` | `Callable \| None` | `None` | Custom train transforms |
| `eval_transforms` | `Callable \| None` | `None` | Custom eval transforms |
| `augmentation_preset` | `str \| None` | `None` | Preset name |
| `transform_backend` | `str` | `"torchvision"` | `"torchvision"` or `"albumentations"` |
| `transform_config` | `TransformConfig \| None` | `None` | Unified transform configuration |
| `backbone` | `str \| nn.Module \| None` | `None` | Backbone for model-specific normalization |
| `pin_memory` | `bool` | `True` | Pin memory for GPU |
| `persistent_workers` | `bool` | `False` | Keep workers alive |
| `prefetch_factor` | `int \| None` | `None` | Batches prefetched per worker |
| `balanced_sampling` | `bool` | `False` | Weighted sampling |

Attributes

| Attribute | Type | Description |
| --- | --- | --- |
| `num_classes` | `int \| None` | Number of classes (after `setup`) |
| `class_names` | `list[str] \| None` | Class names (after `setup`) |
| `train_dataset` | `Dataset \| None` | Training dataset (after `setup`) |
| `val_dataset` | `Dataset \| None` | Validation dataset (after `setup`) |
| `test_dataset` | `Dataset \| None` | Test dataset (after `setup`) |

Built-in Datasets

| Name | Classes | Image Size |
| --- | --- | --- |
| CIFAR10 | 10 | 32x32 |
| CIFAR100 | 100 | 32x32 |
| MNIST | 10 | 28x28 |
| FashionMNIST | 10 | 28x28 |

Augmentation Presets

Torchvision

| Preset | Description |
| --- | --- |
| `default` | RandomResizedCrop, HorizontalFlip, ColorJitter |
| `autoaugment` | AutoAugment (ImageNet policy) |
| `randaugment` | RandAugment (2 ops, magnitude 9) |
| `trivialaugment` | TrivialAugmentWide |

Albumentations

| Preset | Description |
| --- | --- |
| `default` | RandomResizedCrop, HorizontalFlip, ColorJitter |
| `strong` | Affine, blur/noise, ColorJitter, CoarseDropout |

Folder Structure

```
dataset/
├── train/
│   ├── class_a/
│   │   ├── img1.jpg
│   │   └── img2.jpg
│   └── class_b/
│       └── img3.jpg
├── val/           # Optional (uses val_split if missing)
│   ├── class_a/
│   │   └── img4.jpg
│   └── class_b/
│       └── img5.jpg
└── test/          # Optional
    ├── class_a/
    │   └── img6.jpg
    └── class_b/
        └── img7.jpg
```
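Class discovery in folder mode follows the `ImageFolder` convention: each sub-directory of `train/` is one class, and classes are sorted alphabetically to fix their indices. A stdlib sketch against a temporary copy of this layout (the helper is illustrative, not the library's implementation):

```python
import tempfile
from pathlib import Path

def discover_classes(train_dir):
    """Each sub-directory of train/ is one class; sorting fixes the indices."""
    classes = sorted(p.name for p in Path(train_dir).iterdir() if p.is_dir())
    return {name: idx for idx, name in enumerate(classes)}

with tempfile.TemporaryDirectory() as root:
    # Create the directories out of order to show that sorting decides indices.
    for cls in ("class_b", "class_a"):
        (Path(root) / "train" / cls).mkdir(parents=True)
    print(discover_classes(Path(root) / "train"))
    # {'class_a': 0, 'class_b': 1}
```

Because indices come from sorted names, renaming or adding a class folder shifts the mapping, so keep `train/`, `val/`, and `test/` class folders consistent.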

CSV Data Loading

For CSV-based data loading, ImageDataModule supports train_csv, val_csv, and test_csv parameters for single-label classification.

For multi-label classification from CSV files, see MultiLabelImageDataModule.

For direct CSV dataset usage (without DataModules), see the CSV Data Loading API documentation.

CSV Classification Example

from autotimm import ImageDataModule

```python
data = ImageDataModule(
    train_csv="train.csv",
    val_csv="val.csv",
    image_dir="./images",
    image_size=224,
    batch_size=32,
)
```

CSV Format:

```csv
image_path,label
img001.jpg,cat
img002.jpg,dog
```
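The rows above map to (image path, class index) samples, with one index assigned per distinct string label. A stdlib sketch of that parsing (sorted-order label indexing is an assumption about `CSVImageDataset`, and the helper name is illustrative):

```python
import csv
import io

CSV_TEXT = """image_path,label
img001.jpg,cat
img002.jpg,dog
"""

def read_samples(fh, image_column="image_path", label_column="label"):
    # Parse the CSV, collect the distinct labels, and assign each a
    # stable integer index in sorted order.
    rows = list(csv.DictReader(fh))
    classes = sorted({row[label_column] for row in rows})
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = [(row[image_column], class_to_idx[row[label_column]])
               for row in rows]
    return samples, classes

samples, classes = read_samples(io.StringIO(CSV_TEXT))
print(samples)   # [('img001.jpg', 0), ('img002.jpg', 1)]
print(classes)   # ['cat', 'dog']
```

Relative paths in `image_path` are resolved against `image_dir` (or the parent directory of `train_csv` when `image_dir` is not given).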

See CSV Data API for detailed CSV format specification.

See Also