DetectionDataModule

Lightning data module for object detection datasets in COCO format.

Overview

DetectionDataModule is a PyTorch Lightning data module for object detection that supports:

  • COCO format datasets with automatic annotation loading
  • Torchvision and albumentations transform backends
  • Built-in augmentation presets optimized for detection
  • Efficient collation for variable-sized objects per image
  • Multi-worker data loading with prefetching

API Reference

autotimm.DetectionDataModule

Bases: LightningDataModule

Lightning data module for object detection.

Supports two modes:

  1. COCO mode (default) -- expects COCO-style directory structure::

    data_dir/
      train2017/           # Training images
      val2017/             # Validation images
      annotations/
        instances_train2017.json
        instances_val2017.json

  2. CSV mode -- provide train_csv pointing to a CSV file with columns image_path,x_min,y_min,x_max,y_max,label.
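In CSV mode each row describes one bounding box, so an image with several objects spans several rows. The sketch below shows the expected layout and groups rows per image with the stdlib csv module; the column names follow the documented defaults, and the grouping logic is illustrative, not the module's implementation:

```python
import csv
import io
from collections import defaultdict

# One row per box; an image with multiple objects repeats across rows.
csv_text = """image_path,x_min,y_min,x_max,y_max,label
images/001.jpg,10,20,110,220,person
images/001.jpg,150,30,300,200,dog
images/002.jpg,5,5,60,80,person
"""

# Group boxes by image, as a detection dataset would when building samples.
boxes_per_image = defaultdict(list)
for row in csv.DictReader(io.StringIO(csv_text)):
    bbox = [float(row[k]) for k in ("x_min", "y_min", "x_max", "y_max")]
    boxes_per_image[row["image_path"]].append((bbox, row["label"]))

print(len(boxes_per_image["images/001.jpg"]))  # 2 boxes for the first image
```

Pass such a file via `train_csv=` (and optionally `image_dir=` to resolve relative image paths) to select CSV mode.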

Parameters:

Name Type Description Default
data_dir str | Path

Root directory containing images and annotations.

'./coco'
train_images_dir str | Path | None

Path to training images. Defaults to data_dir/train2017.

None
val_images_dir str | Path | None

Path to validation images. Defaults to data_dir/val2017.

None
train_ann_file str | Path | None

Path to train annotations. Defaults to data_dir/annotations/instances_train2017.json.

None
val_ann_file str | Path | None

Path to val annotations. Defaults to data_dir/annotations/instances_val2017.json.

None
test_images_dir str | Path | None

Optional path to test images.

None
test_ann_file str | Path | None

Optional path to test annotations.

None
train_csv str | Path | None

Path to training CSV file (CSV mode).

None
val_csv str | Path | None

Path to validation CSV file (CSV mode).

None
test_csv str | Path | None

Path to test CSV file (CSV mode).

None
image_dir str | Path | None

Root directory for resolving image paths in CSV mode.

None
image_column str

CSV column name for image paths.

'image_path'
bbox_columns list[str] | None

CSV column names for bbox coordinates.

None
label_column str

CSV column name for class labels.

'label'
image_size int

Target image size (square).

640
batch_size int

Batch size for all dataloaders.

16
num_workers int

Number of data-loading workers. Defaults to min(os.cpu_count() or 4, 4), i.e. at most four workers.

min(cpu_count() or 4, 4)
train_transforms Callable | None

Custom training transforms. Must include bbox_params.

None
eval_transforms Callable | None

Custom eval transforms. Must include bbox_params.

None
augmentation_preset str

Preset name ("default", "strong"). Ignored when train_transforms is provided.

'default'
transform_config TransformConfig | None

Optional TransformConfig for unified transform configuration. When provided along with backbone, uses model-specific normalization from timm. Takes precedence over individual transform args.

None
backbone str | Module | None

Optional backbone name or module. Used with transform_config to resolve model-specific normalization (mean, std, input_size).

None
pin_memory bool

Pin memory for GPU transfer.

True
persistent_workers bool

Keep worker processes alive between epochs.

False
prefetch_factor int | None

Number of batches prefetched per worker.

None
min_bbox_area float

Minimum bbox area to include in training.

0.0
class_ids list[int] | None

Optional list of class IDs to filter.

None
Source code in src/autotimm/data/detection_datamodule.py
class DetectionDataModule(pl.LightningDataModule):
    """Lightning data module for object detection.

    Supports two modes:

    1. **COCO mode** (default) -- expects COCO-style directory structure::

        data_dir/
          train2017/           # Training images
          val2017/             # Validation images
          annotations/
            instances_train2017.json
            instances_val2017.json

    2. **CSV mode** -- provide ``train_csv`` pointing to a CSV file with
       columns ``image_path,x_min,y_min,x_max,y_max,label``.

    Parameters:
        data_dir: Root directory containing images and annotations.
        train_images_dir: Path to training images. Defaults to data_dir/train2017.
        val_images_dir: Path to validation images. Defaults to data_dir/val2017.
        train_ann_file: Path to train annotations. Defaults to
            data_dir/annotations/instances_train2017.json.
        val_ann_file: Path to val annotations. Defaults to
            data_dir/annotations/instances_val2017.json.
        test_images_dir: Optional path to test images.
        test_ann_file: Optional path to test annotations.
        train_csv: Path to training CSV file (CSV mode).
        val_csv: Path to validation CSV file (CSV mode).
        test_csv: Path to test CSV file (CSV mode).
        image_dir: Root directory for resolving image paths in CSV mode.
        image_column: CSV column name for image paths.
        bbox_columns: CSV column names for bbox coordinates.
        label_column: CSV column name for class labels.
        image_size: Target image size (square).
        batch_size: Batch size for all dataloaders.
        num_workers: Number of data-loading workers. Defaults to
            ``min(os.cpu_count() or 4, 4)``.
        train_transforms: Custom training transforms. Must include bbox_params.
        eval_transforms: Custom eval transforms. Must include bbox_params.
        augmentation_preset: Preset name (``"default"``, ``"strong"``).
            Ignored when train_transforms is provided.
        transform_config: Optional :class:`TransformConfig` for unified transform
            configuration. When provided along with ``backbone``, uses model-specific
            normalization from timm. Takes precedence over individual transform args.
        backbone: Optional backbone name or module. Used with ``transform_config``
            to resolve model-specific normalization (mean, std, input_size).
        pin_memory: Pin memory for GPU transfer.
        persistent_workers: Keep worker processes alive between epochs.
        prefetch_factor: Number of batches prefetched per worker.
        min_bbox_area: Minimum bbox area to include in training.
        class_ids: Optional list of class IDs to filter.
    """

    def __init__(
        self,
        data_dir: str | Path = "./coco",
        train_images_dir: str | Path | None = None,
        val_images_dir: str | Path | None = None,
        train_ann_file: str | Path | None = None,
        val_ann_file: str | Path | None = None,
        test_images_dir: str | Path | None = None,
        test_ann_file: str | Path | None = None,
        train_csv: str | Path | None = None,
        val_csv: str | Path | None = None,
        test_csv: str | Path | None = None,
        image_dir: str | Path | None = None,
        image_column: str = "image_path",
        bbox_columns: list[str] | None = None,
        label_column: str = "label",
        image_size: int = 640,
        batch_size: int = 16,
        num_workers: int = min(os.cpu_count() or 4, 4),
        train_transforms: Callable | None = None,
        eval_transforms: Callable | None = None,
        augmentation_preset: str = "default",
        transform_config: TransformConfig | None = None,
        backbone: str | nn.Module | None = None,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        prefetch_factor: int | None = None,
        min_bbox_area: float = 0.0,
        class_ids: list[int] | None = None,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["backbone"])

        self.data_dir = Path(data_dir)
        self.train_csv = Path(train_csv) if train_csv else None
        self.val_csv = Path(val_csv) if val_csv else None
        self.test_csv = Path(test_csv) if test_csv else None
        self.image_dir = Path(image_dir) if image_dir else None
        self.image_column = image_column
        self.bbox_columns = bbox_columns
        self.label_column = label_column
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers and num_workers > 0
        self.prefetch_factor = prefetch_factor
        self.min_bbox_area = min_bbox_area
        self.class_ids = class_ids

        # Set default paths
        self.train_images_dir = (
            Path(train_images_dir) if train_images_dir else self.data_dir / "train2017"
        )
        self.val_images_dir = (
            Path(val_images_dir) if val_images_dir else self.data_dir / "val2017"
        )
        self.train_ann_file = (
            Path(train_ann_file)
            if train_ann_file
            else self.data_dir / "annotations" / "instances_train2017.json"
        )
        self.val_ann_file = (
            Path(val_ann_file)
            if val_ann_file
            else self.data_dir / "annotations" / "instances_val2017.json"
        )
        self.test_images_dir = Path(test_images_dir) if test_images_dir else None
        self.test_ann_file = Path(test_ann_file) if test_ann_file else None
        self.transform_config = transform_config
        self.backbone = backbone

        # Resolve transforms - TransformConfig takes precedence
        if transform_config is not None and backbone is not None:
            from autotimm.data.timm_transforms import get_transforms_from_backbone

            self.train_transforms = get_transforms_from_backbone(
                backbone=backbone,
                transform_config=transform_config,
                is_train=True,
                task="detection",
            )
            self.eval_transforms = get_transforms_from_backbone(
                backbone=backbone,
                transform_config=transform_config,
                is_train=False,
                task="detection",
            )
        elif train_transforms is not None:
            self.train_transforms = train_transforms
        else:
            self.train_transforms = get_detection_transforms(
                preset=augmentation_preset,
                image_size=image_size,
                is_train=True,
            )

        if eval_transforms is not None:
            self.eval_transforms = eval_transforms
        else:
            self.eval_transforms = detection_eval_transforms(image_size=image_size)

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.num_classes: int | None = None
        self.class_names: list[str] | None = None

    def setup(self, stage: str | None = None) -> None:
        if self.train_csv is not None:
            self._setup_csv(stage)
        else:
            self._setup_coco(stage)

    def _setup_csv(self, stage: str | None) -> None:
        img_dir = self.image_dir or self.train_csv.parent

        if stage in ("fit", None):
            self.train_dataset = CSVDetectionDataset(
                csv_path=self.train_csv,
                image_dir=img_dir,
                image_column=self.image_column,
                bbox_columns=self.bbox_columns,
                label_column=self.label_column,
                transform=self.train_transforms,
                min_bbox_area=self.min_bbox_area,
            )
            self.num_classes = self.train_dataset.num_classes
            self.class_names = self.train_dataset.class_names

            if self.val_csv is not None:
                self.val_dataset = CSVDetectionDataset(
                    csv_path=self.val_csv,
                    image_dir=img_dir,
                    image_column=self.image_column,
                    bbox_columns=self.bbox_columns,
                    label_column=self.label_column,
                    transform=self.eval_transforms,
                    min_bbox_area=0.0,
                )

        if stage in ("test", None) and self.test_csv is not None:
            self.test_dataset = CSVDetectionDataset(
                csv_path=self.test_csv,
                image_dir=img_dir,
                image_column=self.image_column,
                bbox_columns=self.bbox_columns,
                label_column=self.label_column,
                transform=self.eval_transforms,
                min_bbox_area=0.0,
            )
            if self.num_classes is None:
                self.num_classes = self.test_dataset.num_classes
                self.class_names = self.test_dataset.class_names

    def _setup_coco(self, stage: str | None) -> None:
        if stage in ("fit", None):
            self.train_dataset = COCODetectionDataset(
                images_dir=self.train_images_dir,
                annotations_file=self.train_ann_file,
                transform=self.train_transforms,
                min_bbox_area=self.min_bbox_area,
                class_ids=self.class_ids,
            )
            self.val_dataset = COCODetectionDataset(
                images_dir=self.val_images_dir,
                annotations_file=self.val_ann_file,
                transform=self.eval_transforms,
                min_bbox_area=0.0,  # Keep all boxes for evaluation
                class_ids=self.class_ids,
            )
            self.num_classes = self.train_dataset.num_classes
            self.class_names = self.train_dataset.class_names

        if stage in ("test", None) and self.test_images_dir and self.test_ann_file:
            self.test_dataset = COCODetectionDataset(
                images_dir=self.test_images_dir,
                annotations_file=self.test_ann_file,
                transform=self.eval_transforms,
                min_bbox_area=0.0,
                class_ids=self.class_ids,
            )
            if self.num_classes is None:
                self.num_classes = self.test_dataset.num_classes
                self.class_names = self.test_dataset.class_names

        if stage == "validate" and self.val_dataset is None:
            self.val_dataset = COCODetectionDataset(
                images_dir=self.val_images_dir,
                annotations_file=self.val_ann_file,
                transform=self.eval_transforms,
                min_bbox_area=0.0,
                class_ids=self.class_ids,
            )
            self.num_classes = self.val_dataset.num_classes
            self.class_names = self.val_dataset.class_names

    def _loader_kwargs(self) -> dict:
        kwargs: dict = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": self.pin_memory,
            "persistent_workers": self.persistent_workers,
            "collate_fn": detection_collate_fn,
        }
        if self.prefetch_factor is not None and self.num_workers > 0:
            kwargs["prefetch_factor"] = self.prefetch_factor
        return kwargs

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            **self._loader_kwargs(),
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            shuffle=False,
            **self._loader_kwargs(),
        )

    def test_dataloader(self) -> DataLoader:
        if self.test_dataset is None:
            raise RuntimeError(
                "No test dataset configured. Provide test_images_dir and test_ann_file."
            )
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            **self._loader_kwargs(),
        )

__init__

__init__(data_dir: str | Path = './coco', train_images_dir: str | Path | None = None, val_images_dir: str | Path | None = None, train_ann_file: str | Path | None = None, val_ann_file: str | Path | None = None, test_images_dir: str | Path | None = None, test_ann_file: str | Path | None = None, train_csv: str | Path | None = None, val_csv: str | Path | None = None, test_csv: str | Path | None = None, image_dir: str | Path | None = None, image_column: str = 'image_path', bbox_columns: list[str] | None = None, label_column: str = 'label', image_size: int = 640, batch_size: int = 16, num_workers: int = min(os.cpu_count() or 4, 4), train_transforms: Callable | None = None, eval_transforms: Callable | None = None, augmentation_preset: str = 'default', transform_config: TransformConfig | None = None, backbone: str | Module | None = None, pin_memory: bool = True, persistent_workers: bool = False, prefetch_factor: int | None = None, min_bbox_area: float = 0.0, class_ids: list[int] | None = None)

setup

setup(stage: str | None = None) -> None

train_dataloader

train_dataloader() -> DataLoader

val_dataloader

val_dataloader() -> DataLoader

test_dataloader

test_dataloader() -> DataLoader

Usage Examples

Basic COCO Dataset

from autotimm import DetectionDataModule

data = DetectionDataModule(
    data_dir="./coco",
    image_size=640,
    batch_size=16,
)

With Augmentation Preset

data = DetectionDataModule(
    data_dir="./coco",
    image_size=640,
    batch_size=16,
    augmentation_preset="strong",  # Enhanced augmentation
)

With Custom Transforms (Albumentations)

Custom train/eval transforms must include bbox_params so boxes are transformed together with the image; this is albumentations-style, as required by the train_transforms/eval_transforms docs. The label_fields key below is illustrative and must match the key the dataset passes to the transform:

import albumentations as A
from albumentations.pytorch import ToTensorV2

custom_train = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, p=0.5),
        A.Resize(640, 640),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)

data = DetectionDataModule(
    data_dir="./coco",
    image_size=640,
    batch_size=16,
    train_transforms=custom_train,
)

Performance Optimization

data = DetectionDataModule(
    data_dir="./coco",
    image_size=640,
    batch_size=16,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)

With TransformConfig (Model-Specific Normalization)

Use TransformConfig with a backbone to get model-specific normalization:

from autotimm import DetectionDataModule, TransformConfig

# Create shared config
config = TransformConfig(
    preset="default",
    image_size=640,
    use_timm_config=True,  # Use model's pretrained mean/std
)

data = DetectionDataModule(
    data_dir="./coco",
    transform_config=config,
    backbone="resnet50",  # Required for model-specific normalization
)

Shared Config Between Model and Data

from autotimm import ObjectDetector, DetectionDataModule, TransformConfig, MetricConfig

# Shared config ensures same preprocessing
config = TransformConfig(preset="default", image_size=640)
backbone_name = "resnet50"

# DataModule uses model's normalization
data = DetectionDataModule(
    data_dir="./coco",
    transform_config=config,
    backbone=backbone_name,
)
data.setup("fit")

# Model uses same config for inference preprocessing
metrics = [MetricConfig(
    name="mAP",
    backend="torchmetrics",
    metric_class="MeanAveragePrecision",
    params={"box_format": "xyxy"},
    stages=["val"],
)]

model = ObjectDetector(
    backbone=backbone_name,
    num_classes=data.num_classes,
    metrics=metrics,
    transform_config=config,
)

Parameters

Parameter Type Default Description
data_dir str \| Path "./coco" Root directory
image_size int 640 Target image size
batch_size int 16 Batch size
num_workers int min(cpu_count() or 4, 4) Data loading workers
train_transforms Callable \| None None Custom train transforms
eval_transforms Callable \| None None Custom eval transforms
augmentation_preset str "default" Preset name ("default", "strong")
pin_memory bool True Pin memory for GPU
persistent_workers bool False Keep workers alive
prefetch_factor int \| None None Prefetch batches
transform_config TransformConfig \| None None Unified transform configuration
backbone str \| nn.Module \| None None Backbone for model-specific normalization

Attributes

Attribute Type Description
num_classes int \| None Number of object classes (after setup)
class_names list[str] \| None Class names from annotations (after setup)
train_dataset Dataset \| None Training dataset (after setup)
val_dataset Dataset \| None Validation dataset (after setup)
test_dataset Dataset \| None Test dataset (after setup)

COCO Format

The data directory should follow this structure:

coco/
├── annotations/
│   ├── instances_train2017.json
│   ├── instances_val2017.json
│   └── instances_test2017.json  # Optional
├── train2017/
│   ├── 000000000001.jpg
│   ├── 000000000002.jpg
│   └── ...
├── val2017/
│   ├── 000000000001.jpg
│   └── ...
└── test2017/          # Optional
    └── ...

Annotation Format

COCO annotations should have this structure:

{
  "images": [
    {
      "id": 1,
      "file_name": "000000000001.jpg",
      "height": 480,
      "width": 640
    }
  ],
  "annotations": [
    {
      "id": 1,
      "image_id": 1,
      "category_id": 1,
      "bbox": [x, y, width, height],
      "area": 12345,
      "iscrowd": 0
    }
  ],
  "categories": [
    {
      "id": 1,
      "name": "person",
      "supercategory": "person"
    }
  ]
}
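A minimal annotation file matching this schema can be produced and sanity-checked with the json stdlib before handing it to the datamodule. This is a sketch with one image, one box, and one category; real files carry one entry per image and per box:

```python
import json

coco = {
    "images": [
        {"id": 1, "file_name": "000000000001.jpg", "height": 480, "width": 640}
    ],
    "annotations": [
        {
            "id": 1,
            "image_id": 1,
            "category_id": 1,
            "bbox": [10.0, 20.0, 100.0, 200.0],  # [x, y, width, height]
            "area": 100.0 * 200.0,
            "iscrowd": 0,
        }
    ],
    "categories": [{"id": 1, "name": "person", "supercategory": "person"}],
}

parsed = json.loads(json.dumps(coco))

# Basic schema checks: required top-level keys and positive box sizes.
assert {"images", "annotations", "categories"} <= parsed.keys()
for ann in parsed["annotations"]:
    x, y, w, h = ann["bbox"]
    assert w > 0 and h > 0, "COCO bboxes are [x, y, width, height]"
```

Note that COCO stores boxes as [x, y, width, height], while the batches produced by the module use [x1, y1, x2, y2].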

Augmentation Presets

Torchvision

Preset Description
default RandomHorizontalFlip, ColorJitter, ToTensor
strong Default + RandomPhotometricDistort

Albumentations

Preset Description
default HorizontalFlip, ColorJitter
strong HorizontalFlip, RandomBrightnessContrast, HueSaturationValue, Blur, Noise

Data Output

Each batch contains:

batch = {
    "image": Tensor,      # Shape: (B, 3, H, W)
    "boxes": List[Tensor],   # List of (N, 4) tensors in [x1, y1, x2, y2] format
    "labels": List[Tensor],  # List of (N,) tensors with class indices
}
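Because each image carries a different number of boxes, boxes and labels cannot be stacked into a single tensor; only the fixed-size images are. The sketch below shows that collation pattern for the batch layout above. It is illustrative, not autotimm's detection_collate_fn, and assumes each sample is a dict with the keys shown:

```python
import torch

def collate_detection(samples: list[dict]) -> dict:
    """Stack fixed-size images; keep variable-length boxes/labels as lists."""
    return {
        "image": torch.stack([s["image"] for s in samples]),
        "boxes": [s["boxes"] for s in samples],
        "labels": [s["labels"] for s in samples],
    }

samples = [
    {
        "image": torch.zeros(3, 640, 640),
        "boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
        "labels": torch.tensor([1]),
    },
    {
        "image": torch.zeros(3, 640, 640),
        "boxes": torch.zeros((0, 4)),  # image with no objects
        "labels": torch.zeros((0,), dtype=torch.long),
    },
]

batch = collate_detection(samples)
print(batch["image"].shape)  # torch.Size([2, 3, 640, 640])
```

A function like this is what gets passed as collate_fn to the DataLoader; the module wires its own collate function in automatically.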

See Also