Segmentation API Reference

Complete API documentation for semantic and instance segmentation tasks.

Task Models

SemanticSegmentor

End-to-end semantic segmentation model with timm backbones, supporting multiple head architectures (DeepLabV3+, FCN) and loss functions (CE, Dice, Focal, Combined).

autotimm.SemanticSegmentor

Bases: PreprocessingMixin, LightningModule

End-to-end semantic segmentation model backed by a timm backbone.

Parameters:

Name Type Description Default
backbone str | FeatureBackboneConfig

A timm model name (str) or a :class:FeatureBackboneConfig.

required
num_classes int

Number of segmentation classes.

required
head_type str

Type of segmentation head ('deeplabv3plus' or 'fcn').

'deeplabv3plus'
loss_fn str | Module | None

Loss function to use. Can be:

- A string from the loss registry (e.g., 'dice', 'focal_pixelwise', 'combined_segmentation')
- An instance of nn.Module (custom loss)
- None (uses the loss_type parameter for backward compatibility)

None
loss_type str

[DEPRECATED] Loss function type ('ce', 'dice', 'focal', or 'combined'). Use loss_fn parameter instead for better flexibility.

'combined'
dice_weight float

Weight for Dice loss when using 'combined' loss (default: 1.0).

1.0
ce_weight float

Weight for cross-entropy loss when using 'combined' loss (default: 1.0).

1.0
ignore_index int

Index to ignore in loss computation (default: 255).

255
class_weights Tensor | None

Optional per-class weights for CE loss.

None
metrics MetricManager | list[MetricConfig] | None

A :class:MetricManager instance or list of :class:MetricConfig objects. Required - no default metrics are provided.

None
logging_config LoggingConfig | None

Optional :class:LoggingConfig for enhanced logging.

None
transform_config TransformConfig | None

Optional :class:TransformConfig for unified transform configuration. When provided, enables the preprocess() method for inference-time preprocessing using model-specific normalization.

None
lr float

Learning rate.

0.0001
weight_decay float

Weight decay for optimizer.

0.0001
optimizer str | dict[str, Any]

Optimizer name or dict with 'class' and 'params' keys.

'adamw'
optimizer_kwargs dict[str, Any] | None

Additional kwargs for the optimizer.

None
scheduler str | dict[str, Any] | None

Scheduler name, dict with 'class' and 'params' keys, or None.

'cosine'
scheduler_kwargs dict[str, Any] | None

Extra kwargs forwarded to the LR scheduler.

None
freeze_backbone bool

If True, backbone parameters are frozen.

False
compile_model bool

If True (default), apply torch.compile() to the backbone and head for faster inference and training. Requires PyTorch 2.0+.

True
compile_kwargs dict[str, Any] | None

Optional dict of kwargs to pass to torch.compile(). Common options: mode ("default", "reduce-overhead", "max-autotune"), fullgraph (True/False), dynamic (True/False).

None
seed int | None

Random seed for reproducibility. If None (the default), no seeding is performed.

None
deterministic bool

If True (default), enables deterministic algorithms in PyTorch for full reproducibility (may impact performance). Set to False for faster training.

True
Example

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    metrics=[
        MetricConfig(
            name="iou",
            backend="torchmetrics",
            metric_class="JaccardIndex",
            params={"task": "multiclass", "num_classes": 19, "average": "macro"},
            stages=["val", "test"],
            prog_bar=True,
        ),
    ],
    loss_type="combined",
    dice_weight=1.0,
)

__init__

__init__(
    backbone: str | FeatureBackboneConfig,
    num_classes: int,
    head_type: str = 'deeplabv3plus',
    loss_fn: str | Module | None = None,
    loss_type: str = 'combined',
    dice_weight: float = 1.0,
    ce_weight: float = 1.0,
    ignore_index: int = 255,
    class_weights: Tensor | None = None,
    metrics: MetricManager | list[MetricConfig] | None = None,
    logging_config: LoggingConfig | None = None,
    transform_config: TransformConfig | None = None,
    lr: float = 0.0001,
    weight_decay: float = 0.0001,
    optimizer: str | dict[str, Any] = 'adamw',
    optimizer_kwargs: dict[str, Any] | None = None,
    scheduler: str | dict[str, Any] | None = 'cosine',
    scheduler_kwargs: dict[str, Any] | None = None,
    freeze_backbone: bool = False,
    compile_model: bool = True,
    compile_kwargs: dict[str, Any] | None = None,
    seed: int | None = None,
    deterministic: bool = True
)

forward

forward(images: Tensor) -> torch.Tensor

Forward pass through backbone and head.

Parameters:

Name Type Description Default
images Tensor

Input images [B, 3, H, W]

required

Returns:

Type Description
Tensor

Segmentation logits [B, num_classes, H', W']

predict

predict(images: Tensor, return_logits: bool = False) -> torch.Tensor

Predict segmentation masks for input images.

Parameters:

Name Type Description Default
images Tensor

Input images [B, 3, H, W]

required
return_logits bool

If True, return logits instead of class predictions

False

Returns:

Type Description
Tensor

Predicted class indices [B, H, W] or logits [B, C, H, W]

training_step

training_step(batch: dict[str, Any], batch_idx: int) -> torch.Tensor

Training step.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dict with 'image' [B, 3, H, W] and 'mask' [B, H, W]

required
batch_idx int

Batch index

required

Returns:

Type Description
Tensor

Loss value

validation_step

validation_step(batch: dict[str, Any], batch_idx: int) -> None

Validation step.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dict with 'image' [B, 3, H, W] and 'mask' [B, H, W]

required
batch_idx int

Batch index

required

test_step

test_step(batch: dict[str, Any], batch_idx: int) -> None

Test step.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dict with 'image' [B, 3, H, W] and 'mask' [B, H, W]

required
batch_idx int

Batch index

required

predict_step

predict_step(batch: Any, batch_idx: int) -> torch.Tensor

Prediction step for Trainer.predict().

Parameters:

Name Type Description Default
batch Any

Input batch (dict or tensor)

required
batch_idx int

Batch index

required

Returns:

Type Description
Tensor

Predicted class indices [B, H, W]

configure_optimizers

configure_optimizers() -> dict

Configure optimizer and learning rate scheduler.

Supports torch.optim, timm optimizers, and custom optimizers/schedulers.
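The optimizer and scheduler parameters accept either a plain name string or a dict with 'class' and 'params' keys, as described above. A minimal sketch of the two shapes (the dotted-path strings below are illustrative assumptions, not verified autotimm API):

```python
# Sketch of the two config shapes accepted by `optimizer` / `scheduler`:
# a name string, or a dict with 'class' and 'params' keys (per the docs).
# The dotted-path values are illustrative assumptions.

def is_valid_optim_config(cfg) -> bool:
    """Check that a config is a name string or a {'class', 'params'} dict."""
    if isinstance(cfg, str):
        return True
    return (
        isinstance(cfg, dict)
        and set(cfg) == {"class", "params"}
        and isinstance(cfg["params"], dict)
    )

optimizer = "adamw"  # name form
optimizer_cfg = {
    "class": "torch.optim.SGD",  # assumed dotted path
    "params": {"momentum": 0.9, "nesterov": True},
}
scheduler_cfg = {
    "class": "torch.optim.lr_scheduler.StepLR",
    "params": {"step_size": 10, "gamma": 0.1},
}
```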

InstanceSegmentor

End-to-end instance segmentation model combining FCOS-style object detection with per-instance mask prediction. Integrates timm backbones with FPN and dual-head architecture for boxes and masks.

autotimm.InstanceSegmentor

Bases: PreprocessingMixin, LightningModule

End-to-end instance segmentation model.

Combines FCOS-style detection with per-instance mask prediction.

Architecture: timm backbone → FPN → Detection Head + Mask Head → NMS

Parameters:

Name Type Description Default
backbone str | FeatureBackboneConfig

A timm model name (str) or a :class:FeatureBackboneConfig.

required
num_classes int

Number of object classes (excluding background).

required
cls_loss_fn str | Module | None

Classification loss function. Can be:

- A string from the loss registry (e.g., 'focal')
- An instance of nn.Module (custom loss)
- None (uses FocalLoss with focal_alpha and focal_gamma)

None
reg_loss_fn str | Module | None

Regression loss function. Can be:

- A string from the loss registry (e.g., 'giou')
- An instance of nn.Module (custom loss)
- None (uses GIoULoss)

None
mask_loss_fn str | Module | None

Mask loss function. Can be:

- A string from the loss registry (e.g., 'mask')
- An instance of nn.Module (custom loss)
- None (uses MaskLoss)

None
metrics MetricManager | list[MetricConfig] | None

A :class:MetricManager instance or list of :class:MetricConfig objects. Optional - if not provided, uses MeanAveragePrecision with mask support.

None
logging_config LoggingConfig | None

Optional :class:LoggingConfig for enhanced logging.

None
transform_config TransformConfig | None

Optional :class:TransformConfig for unified transform configuration. When provided, enables the preprocess() method for inference-time preprocessing using model-specific normalization.

None
lr float

Learning rate.

0.0001
weight_decay float

Weight decay for optimizer.

0.0001
optimizer str | dict[str, Any]

Optimizer name or dict with 'class' and 'params' keys.

'adamw'
optimizer_kwargs dict[str, Any] | None

Additional kwargs for the optimizer.

None
scheduler str | dict[str, Any] | None

Scheduler name, dict config, or None for no scheduler.

'cosine'
scheduler_kwargs dict[str, Any] | None

Extra kwargs forwarded to the LR scheduler.

None
fpn_channels int

Number of channels in FPN layers.

256
head_num_convs int

Number of conv layers in detection head branches.

4
mask_size int

ROI mask resolution (default: 28).

28
mask_loss_weight float

Weight for mask loss.

1.0
focal_alpha float

Alpha parameter for focal loss.

0.25
focal_gamma float

Gamma parameter for focal loss.

2.0
cls_loss_weight float

Weight for classification loss.

1.0
reg_loss_weight float

Weight for regression loss.

1.0
centerness_loss_weight float

Weight for centerness loss.

1.0
score_thresh float

Score threshold for detections during inference.

0.05
nms_thresh float

IoU threshold for NMS.

0.5
max_detections_per_image int

Maximum detections to keep per image.

100
freeze_backbone bool

If True, backbone parameters are frozen.

False
roi_pool_size int

Size of ROI pooling output (default: 14).

14
mask_threshold float

Threshold for binarizing predicted masks (default: 0.5).

0.5
compile_model bool

If True (default), apply torch.compile() to the backbone, FPN, and heads for faster inference and training. Requires PyTorch 2.0+.

True
compile_kwargs dict[str, Any] | None

Optional dict of kwargs to pass to torch.compile(). Common options: mode ("default", "reduce-overhead", "max-autotune"), fullgraph (True/False), dynamic (True/False).

None
seed int | None

Random seed for reproducibility. If None (the default), no seeding is performed.

None
deterministic bool

If True (default), enables deterministic algorithms in PyTorch for full reproducibility (may impact performance). Set to False for faster training.

True
Example

model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    metrics=[
        MetricConfig(
            name="mask_mAP",
            backend="torchmetrics",
            metric_class="MeanAveragePrecision",
            params={"box_format": "xyxy", "iou_type": "segm"},
            stages=["val", "test"],
        ),
    ],
    lr=1e-4,
    mask_loss_weight=1.0,
)

__init__

__init__(
    backbone: str | FeatureBackboneConfig,
    num_classes: int,
    cls_loss_fn: str | Module | None = None,
    reg_loss_fn: str | Module | None = None,
    mask_loss_fn: str | Module | None = None,
    metrics: MetricManager | list[MetricConfig] | None = None,
    logging_config: LoggingConfig | None = None,
    transform_config: TransformConfig | None = None,
    lr: float = 0.0001,
    weight_decay: float = 0.0001,
    optimizer: str | dict[str, Any] = 'adamw',
    optimizer_kwargs: dict[str, Any] | None = None,
    scheduler: str | dict[str, Any] | None = 'cosine',
    scheduler_kwargs: dict[str, Any] | None = None,
    fpn_channels: int = 256,
    head_num_convs: int = 4,
    mask_size: int = 28,
    mask_loss_weight: float = 1.0,
    focal_alpha: float = 0.25,
    focal_gamma: float = 2.0,
    cls_loss_weight: float = 1.0,
    reg_loss_weight: float = 1.0,
    centerness_loss_weight: float = 1.0,
    score_thresh: float = 0.05,
    nms_thresh: float = 0.5,
    max_detections_per_image: int = 100,
    freeze_backbone: bool = False,
    roi_pool_size: int = 14,
    mask_threshold: float = 0.5,
    compile_model: bool = True,
    compile_kwargs: dict[str, Any] | None = None,
    seed: int | None = None,
    deterministic: bool = True
)

forward

forward(images: Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]

Forward pass through detector (without mask head).

Parameters:

Name Type Description Default
images Tensor

Input images [B, C, H, W]

required

Returns:

Type Description
tuple[list[Tensor], list[Tensor], list[Tensor]]

Tuple of (cls_outputs, reg_outputs, centerness_outputs) per FPN level

predict

predict(images: Tensor) -> list[dict[str, torch.Tensor]]

Predict instance segmentation for input images.

Parameters:

Name Type Description Default
images Tensor

Input images [B, 3, H, W]

required

Returns:

Type Description
list[dict[str, Tensor]]

List of dicts with 'boxes', 'labels', 'scores', 'masks' for each image
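predict already applies score_thresh and NMS internally, but downstream code often re-filters at a stricter threshold. A minimal sketch of that post-processing, with plain Python lists standing in for the tensors in the real prediction dicts:

```python
# Keep only the instances whose score clears a (stricter) threshold.
# Plain lists stand in for the tensors returned by predict().

def filter_prediction(pred: dict, score_thresh: float = 0.5) -> dict:
    keep = [i for i, s in enumerate(pred["scores"]) if s >= score_thresh]
    return {key: [values[i] for i in keep] for key, values in pred.items()}

pred = {
    "boxes": [[10, 20, 110, 220], [5, 5, 50, 50]],
    "labels": [1, 3],
    "scores": [0.92, 0.40],
    "masks": [[[0, 1], [1, 1]], [[1, 0], [0, 0]]],
}
filtered = filter_prediction(pred, score_thresh=0.5)  # keeps only the 0.92 instance
```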

training_step

training_step(batch: dict[str, Any], batch_idx: int) -> torch.Tensor

Training step.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dict with 'image', 'boxes', 'labels', 'masks'

required
batch_idx int

Batch index

required

Returns:

Type Description
Tensor

Total loss value

validation_step

validation_step(batch: dict[str, Any], batch_idx: int) -> None

Validation step.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dict with 'image', 'boxes', 'labels', 'masks'

required
batch_idx int

Batch index

required

test_step

test_step(batch: dict[str, Any], batch_idx: int) -> None

Test step.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dict with 'image', 'boxes', 'labels', 'masks'

required
batch_idx int

Batch index

required

predict_step

predict_step(batch: Any, batch_idx: int) -> list[dict[str, torch.Tensor]]

Prediction step for Trainer.predict().

Parameters:

Name Type Description Default
batch Any

Input batch

required
batch_idx int

Batch index

required

Returns:

Type Description
list[dict[str, Tensor]]

List of predictions

configure_optimizers

configure_optimizers() -> dict

Configure optimizer and learning rate scheduler.

Data Modules

SegmentationDataModule

PyTorch Lightning DataModule for semantic segmentation. Supports multiple dataset formats including PNG masks, Cityscapes, COCO, and Pascal VOC. Provides flexible augmentation presets and custom transform support.

autotimm.SegmentationDataModule

Bases: LightningDataModule

PyTorch Lightning DataModule for semantic segmentation.

Parameters:

Name Type Description Default
data_dir str | Path

Root directory of the dataset

required
format str

Dataset format ('png', 'coco', 'cityscapes', 'voc', 'csv')

'png'
image_size int

Target image size (square)

512
batch_size int

Batch size for dataloaders

8
num_workers int

Number of dataloader workers. Defaults to min(os.cpu_count() or 4, 4).

min(cpu_count() or 4, 4)
augmentation_preset str

Augmentation preset ('default', 'strong', 'light')

'default'
custom_train_transforms Any

Optional custom training transforms

None
custom_val_transforms Any

Optional custom validation transforms

None
class_mapping dict[int, int] | None

Optional mapping from dataset class IDs to contiguous IDs

None
ignore_index int

Index to use for ignored pixels (default: 255)

255
transform_config TransformConfig | None

Optional TransformConfig for unified transform configuration. When provided along with backbone, uses model-specific normalization.

None
backbone str | Module | None

Optional backbone name or module for model-specific normalization.

None
train_csv str | Path | None

Path to training CSV file (used when format='csv').

None
val_csv str | Path | None

Path to validation CSV file (used when format='csv').

None
test_csv str | Path | None

Path to test CSV file (used when format='csv').

None

__init__

__init__(
    data_dir: str | Path,
    format: str = 'png',
    image_size: int = 512,
    batch_size: int = 8,
    num_workers: int = min(os.cpu_count() or 4, 4),
    augmentation_preset: str = 'default',
    custom_train_transforms: Any = None,
    custom_val_transforms: Any = None,
    class_mapping: dict[int, int] | None = None,
    ignore_index: int = 255,
    transform_config: TransformConfig | None = None,
    backbone: str | Module | None = None,
    train_csv: str | Path | None = None,
    val_csv: str | Path | None = None,
    test_csv: str | Path | None = None
)

setup

setup(stage: str | None = None)

Setup datasets for each stage.

Parameters:

Name Type Description Default
stage str | None

'fit', 'validate', 'test', or 'predict'

None

train_dataloader

train_dataloader() -> DataLoader

Get training dataloader.

val_dataloader

val_dataloader() -> DataLoader | None

Get validation dataloader.

test_dataloader

test_dataloader() -> DataLoader

Get test dataloader.

InstanceSegmentationDataModule

PyTorch Lightning DataModule for instance segmentation using COCO format. Handles both detection boxes and instance masks with built-in augmentation support.

autotimm.InstanceSegmentationDataModule

Bases: LightningDataModule

PyTorch Lightning DataModule for instance segmentation.

Supports two modes:

  1. COCO mode (default) -- expects COCO-format annotations.
  2. CSV mode -- provide train_csv pointing to a CSV file with columns image_path,x_min,y_min,x_max,y_max,label,mask_path.
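A sketch of building such a CSV with the standard library, using the documented column set (the file paths are hypothetical; there is one row per instance, so images with several instances repeat their image_path):

```python
import csv
import io

# Columns as documented for CSV mode; paths are hypothetical examples.
fieldnames = ["image_path", "x_min", "y_min", "x_max", "y_max", "label", "mask_path"]
rows = [
    # Two instances in the same image -> image_path repeats.
    {"image_path": "images/0001.jpg", "x_min": 10, "y_min": 20,
     "x_max": 110, "y_max": 220, "label": 1, "mask_path": "masks/0001_0.png"},
    {"image_path": "images/0001.jpg", "x_min": 150, "y_min": 30,
     "x_max": 300, "y_max": 200, "label": 2, "mask_path": "masks/0001_1.png"},
]

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
csv_text = buf.getvalue()  # write this to train.csv and pass train_csv=...
```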

Parameters:

Name Type Description Default
data_dir str | Path

Root directory of COCO dataset

'.'
image_size int

Target image size (square)

640
batch_size int

Batch size for dataloaders

4
num_workers int

Number of dataloader workers. Defaults to min(os.cpu_count() or 4, 4).

min(cpu_count() or 4, 4)
augmentation_preset str

Augmentation strength ('default', 'strong', 'light')

'default'
custom_train_transforms Any

Optional custom training transforms

None
custom_val_transforms Any

Optional custom validation transforms

None
min_keypoints int

Minimum number of keypoints for an instance to be valid

0
min_area float

Minimum area for an instance to be valid

0.0
transform_config TransformConfig | None

Optional TransformConfig for unified transform configuration. When provided along with backbone, uses model-specific normalization.

None
backbone str | Module | None

Optional backbone name or module for model-specific normalization.

None
train_csv str | Path | None

Path to training CSV file (CSV mode).

None
val_csv str | Path | None

Path to validation CSV file (CSV mode).

None
test_csv str | Path | None

Path to test CSV file (CSV mode).

None
image_dir str | Path | None

Root directory for resolving image/mask paths in CSV mode.

None
image_column str

CSV column name for image paths.

'image_path'
bbox_columns list[str] | None

CSV column names for bbox coordinates.

None
label_column str

CSV column name for class labels.

'label'
mask_column str

CSV column name for mask file paths.

'mask_path'

__init__

__init__(
    data_dir: str | Path = '.',
    image_size: int = 640,
    batch_size: int = 4,
    num_workers: int = min(os.cpu_count() or 4, 4),
    augmentation_preset: str = 'default',
    custom_train_transforms: Any = None,
    custom_val_transforms: Any = None,
    min_keypoints: int = 0,
    min_area: float = 0.0,
    transform_config: TransformConfig | None = None,
    backbone: str | Module | None = None,
    train_csv: str | Path | None = None,
    val_csv: str | Path | None = None,
    test_csv: str | Path | None = None,
    image_dir: str | Path | None = None,
    image_column: str = 'image_path',
    bbox_columns: list[str] | None = None,
    label_column: str = 'label',
    mask_column: str = 'mask_path'
)

setup

setup(stage: str | None = None)

Setup datasets for each stage.

Parameters:

Name Type Description Default
stage str | None

'fit', 'validate', 'test', or 'predict'

None

train_dataloader

train_dataloader() -> DataLoader

Get training dataloader.

val_dataloader

val_dataloader() -> DataLoader | None

Get validation dataloader.

test_dataloader

test_dataloader() -> DataLoader

Get test dataloader.

Segmentation Heads

DeepLabV3PlusHead

DeepLabV3+ segmentation head with Atrous Spatial Pyramid Pooling (ASPP) and decoder with skip connections. Provides multi-scale context aggregation for accurate semantic segmentation.

autotimm.heads.DeepLabV3PlusHead

Bases: Module

DeepLabV3+ segmentation head with ASPP and decoder.

Combines high-level semantic features (C5) with low-level spatial features (C2) for accurate segmentation.

Parameters:

Name Type Description Default
in_channels_list Sequence[int]

List of input channel counts from backbone [C2, C3, C4, C5].

required
num_classes int

Number of segmentation classes.

required
aspp_out_channels int

Number of ASPP output channels.

256
decoder_channels int

Number of decoder output channels.

256
dilation_rates Sequence[int]

Dilation rates for ASPP.

(6, 12, 18)

__init__

__init__(in_channels_list: Sequence[int], num_classes: int, aspp_out_channels: int = 256, decoder_channels: int = 256, dilation_rates: Sequence[int] = (6, 12, 18))

forward

forward(features: list[Tensor]) -> torch.Tensor

Forward pass through DeepLabV3+ head.

Parameters:

Name Type Description Default
features list[Tensor]

List of backbone features [C2, C3, C4, C5]

required

Returns:

Type Description
Tensor

Segmentation logits [B, num_classes, H/4, W/4]

FCNHead

Fully Convolutional Network (FCN) head for semantic segmentation. Simple upsampling-based architecture suitable for baseline models.

autotimm.heads.FCNHead

Bases: Module

Simple Fully Convolutional Network head for segmentation.

A simpler baseline compared to DeepLabV3+.

Parameters:

Name Type Description Default
in_channels int

Number of input channels from backbone.

required
num_classes int

Number of segmentation classes.

required
intermediate_channels int

Number of intermediate channels.

512

__init__

__init__(in_channels: int, num_classes: int, intermediate_channels: int = 512)

forward

forward(features: list[Tensor]) -> torch.Tensor

Forward pass through FCN head.

Parameters:

Name Type Description Default
features list[Tensor]

List of backbone features (uses last one)

required

Returns:

Type Description
Tensor

Segmentation logits [B, num_classes, H, W]

MaskHead

ROI-based mask prediction head for instance segmentation. Predicts per-instance binary masks from ROI-aligned features.

autotimm.heads.MaskHead

Bases: Module

Mask prediction head for instance segmentation.

Takes ROI-aligned features and predicts per-instance binary masks.

Parameters:

Name Type Description Default
in_channels int

Number of input channels from ROI align.

256
num_classes int

Number of object classes.

80
hidden_channels int

Number of hidden layer channels.

256
num_convs int

Number of convolutional layers before deconv.

4
mask_size int

Output mask resolution.

28

__init__

__init__(in_channels: int = 256, num_classes: int = 80, hidden_channels: int = 256, num_convs: int = 4, mask_size: int = 28)

forward

forward(roi_features: Tensor) -> torch.Tensor

Forward pass through mask head.

Parameters:

Name Type Description Default
roi_features Tensor

ROI-aligned features [N, C, H, W]

required

Returns:

Type Description
Tensor

Mask logits [N, num_classes, mask_size, mask_size]

ASPP

Atrous Spatial Pyramid Pooling module for multi-scale feature extraction. Core component of DeepLabV3+ architecture.

autotimm.heads.ASPP

Bases: Module

Atrous Spatial Pyramid Pooling module for DeepLabV3+.

Applies parallel atrous convolutions with different dilation rates to capture multi-scale context.

Parameters:

Name Type Description Default
in_channels int

Number of input channels from backbone.

required
out_channels int

Number of output channels.

256
dilation_rates Sequence[int]

List of dilation rates for parallel branches.

(6, 12, 18)

__init__

__init__(in_channels: int, out_channels: int = 256, dilation_rates: Sequence[int] = (6, 12, 18))

forward

forward(x: Tensor) -> torch.Tensor

Forward pass through ASPP.

Parameters:

Name Type Description Default
x Tensor

Input feature map [B, C, H, W]

required

Returns:

Type Description
Tensor

Output feature map [B, out_channels, H, W]
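A dilated 3x3 convolution inserts d - 1 gaps between taps, so its effective kernel extent is k + (k - 1)(d - 1). A quick arithmetic check for the default ASPP rates (an illustrative sketch, not library code):

```python
def effective_kernel_size(k: int, d: int) -> int:
    """Effective extent of a k x k conv with dilation d: k + (k-1)(d-1)."""
    return k + (k - 1) * (d - 1)

# Default ASPP rates (6, 12, 18) applied to 3x3 convs give three
# progressively larger context windows from the same cheap kernels.
sizes = [effective_kernel_size(3, d) for d in (6, 12, 18)]  # [13, 25, 37]
```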

Segmentation Losses

DiceLoss

Dice loss for multi-class semantic segmentation. Directly optimizes an IoU-like overlap objective. Formula: 1 - (2 * |X ∩ Y|) / (|X| + |Y|). Effective for handling class imbalance.

autotimm.losses.DiceLoss

Bases: Module

Dice loss for multi-class segmentation.

Formula: 1 - (2 * |X ∩ Y|) / (|X| + |Y|)

Parameters:

Name Type Description Default
num_classes int

Number of segmentation classes

required
smooth float

Smoothing constant to avoid division by zero (default: 1.0)

1.0
ignore_index int

Index to ignore in loss computation (default: 255)

255
reduction str

Reduction method ('mean', 'sum', 'none') (default: 'mean')

'mean'

__init__

__init__(num_classes: int, smooth: float = 1.0, ignore_index: int = 255, reduction: str = 'mean')

forward

forward(logits: Tensor, targets: Tensor) -> torch.Tensor

Compute Dice loss.

Parameters:

Name Type Description Default
logits Tensor

Predicted logits [B, C, H, W]

required
targets Tensor

Ground truth masks [B, H, W] with class indices

required

Returns:

Type Description
Tensor

Dice loss value
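To make the formula concrete, here is a hard-mask, single-class sketch in plain Python. The library class computes a soft, per-class version over logits; this only illustrates the arithmetic:

```python
def binary_dice_loss(pred, target, smooth=1.0):
    """1 - (2|X ∩ Y| + smooth) / (|X| + |Y| + smooth) on flat 0/1 masks."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return 1.0 - (2.0 * intersection + smooth) / (sum(pred) + sum(target) + smooth)

perfect = binary_dice_loss([1, 1, 0, 0], [1, 1, 0, 0])   # -> 0.0
disjoint = binary_dice_loss([1, 0], [0, 1], smooth=0.0)  # -> 1.0
```

The smoothing constant keeps the ratio defined when both masks are empty and slightly softens the loss near perfect overlap.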

CombinedSegmentationLoss

Combines Cross-Entropy and Dice losses with configurable weights. Pairs pixel-wise classification (CE) with region-overlap optimization (Dice) for robust segmentation.

autotimm.losses.CombinedSegmentationLoss

Bases: Module

Combined cross-entropy and Dice loss for semantic segmentation.

Parameters:

Name Type Description Default
num_classes int

Number of segmentation classes

required
ce_weight float

Weight for cross-entropy loss (default: 1.0)

1.0
dice_weight float

Weight for Dice loss (default: 1.0)

1.0
ignore_index int

Index to ignore in loss computation (default: 255)

255
class_weights Tensor | None

Optional per-class weights for CE loss (default: None)

None

__init__

__init__(num_classes: int, ce_weight: float = 1.0, dice_weight: float = 1.0, ignore_index: int = 255, class_weights: Tensor | None = None)

forward

forward(logits: Tensor, targets: Tensor) -> torch.Tensor

Compute combined loss.

Parameters:

Name Type Description Default
logits Tensor

Predicted logits [B, C, H, W]

required
targets Tensor

Ground truth masks [B, H, W] with class indices

required

Returns:

Type Description
Tensor

Combined loss value

FocalLossPixelwise

Focal loss for dense pixel-wise prediction. Down-weights easy examples to focus on hard pixels. Particularly effective for handling severe class imbalance in segmentation tasks.

autotimm.losses.FocalLossPixelwise

Bases: Module

Focal loss for dense pixel-wise prediction.

Handles class imbalance by down-weighting easy examples.

Parameters:

Name Type Description Default
alpha float

Weighting factor in [0, 1] to balance positive/negative examples

0.25
gamma float

Focusing parameter for modulating loss (default: 2.0)

2.0
ignore_index int

Index to ignore in loss computation (default: 255)

255
reduction str

Reduction method ('mean', 'sum', 'none') (default: 'mean')

'mean'

__init__

__init__(alpha: float = 0.25, gamma: float = 2.0, ignore_index: int = 255, reduction: str = 'mean')

forward

forward(logits: Tensor, targets: Tensor) -> torch.Tensor

Compute focal loss.

Parameters:

Name Type Description Default
logits Tensor

Predicted logits [B, C, H, W]

required
targets Tensor

Ground truth masks [B, H, W] with class indices

required

Returns:

Type Description
Tensor

Focal loss value
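The down-weighting is easy to see numerically: focal loss scales plain cross-entropy by alpha * (1 - p)^gamma, so confidently correct (easy) pixels contribute almost nothing. A single-pixel sketch (the library version operates on dense logit maps):

```python
import math

def focal(p: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one pixel with predicted true-class probability p."""
    return -alpha * (1.0 - p) ** gamma * math.log(p)

easy = focal(0.9)  # confident pixel: modulated down by (1 - 0.9)^2 = 0.01
hard = focal(0.1)  # hard pixel: keeps most of its cross-entropy weight
```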

TverskyLoss

Generalized Dice loss with configurable false positive/negative trade-off via alpha and beta parameters. Useful for highly imbalanced segmentation tasks.

autotimm.losses.TverskyLoss

Bases: Module

Tversky loss for segmentation.

Generalization of Dice loss with separate control over false positives and false negatives.

Parameters:

Name Type Description Default
num_classes int

Number of segmentation classes

required
alpha float

Weight for false positives (default: 0.5)

0.5
beta float

Weight for false negatives (default: 0.5)

0.5
smooth float

Smoothing constant (default: 1.0)

1.0
ignore_index int

Index to ignore in loss computation (default: 255)

255
reduction str

Reduction method ('mean', 'sum', 'none') (default: 'mean')

'mean'

__init__

__init__(num_classes: int, alpha: float = 0.5, beta: float = 0.5, smooth: float = 1.0, ignore_index: int = 255, reduction: str = 'mean')

forward

forward(logits: Tensor, targets: Tensor) -> torch.Tensor

Compute Tversky loss.

Parameters:

Name Type Description Default
logits Tensor

Predicted logits [B, C, H, W]

required
targets Tensor

Ground truth masks [B, H, W] with class indices

required

Returns:

Type Description
Tensor

Tversky loss value
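With alpha = beta = 0.5 (and no smoothing) the Tversky index reduces exactly to Dice, while raising alpha penalizes false positives harder. A flat-mask sketch of 1 - (TP + smooth) / (TP + alpha*FP + beta*FN + smooth):

```python
def tversky_loss(pred, target, alpha=0.5, beta=0.5, smooth=1.0):
    """Tversky loss on flat 0/1 masks (illustrative, single class)."""
    tp = sum(p * t for p, t in zip(pred, target))
    fp = sum(p * (1 - t) for p, t in zip(pred, target))
    fn = sum((1 - p) * t for p, t in zip(pred, target))
    return 1.0 - (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)

pred, target = [1, 1, 1, 0], [1, 0, 0, 0]  # 1 TP, 2 FP, 0 FN
balanced = tversky_loss(pred, target, smooth=0.0)  # Dice-equivalent: 0.5
fp_averse = tversky_loss(pred, target, alpha=0.7, beta=0.3, smooth=0.0)
```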

MaskLoss

Binary cross-entropy loss for instance segmentation masks. Used in conjunction with detection losses for end-to-end instance segmentation training.

autotimm.losses.MaskLoss

Bases: Module

Binary cross-entropy loss for instance segmentation masks.

Parameters:

Name Type Description Default
reduction str

Reduction method ('mean', 'sum', 'none') (default: 'mean')

'mean'
pos_weight float | None

Weight for positive examples (default: None)

None

__init__

__init__(reduction: str = 'mean', pos_weight: float | None = None)

forward

forward(pred_masks: Tensor, target_masks: Tensor) -> torch.Tensor

Compute binary cross-entropy loss for masks.

Parameters:

Name Type Description Default
pred_masks Tensor

Predicted masks [N, H, W] or [N, 1, H, W] (logits)

required
target_masks Tensor

Ground truth binary masks [N, H, W] or [N, 1, H, W]

required

Returns:

Type Description
Tensor

Mask loss value
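Per pixel this is ordinary binary cross-entropy on logits. A scalar sketch of the per-pixel term (the library applies it over full mask tensors, optionally with pos_weight):

```python
import math

def bce_with_logit(logit: float, target: float) -> float:
    """-[z*log(sigmoid(x)) + (1-z)*log(1 - sigmoid(x))] for one pixel."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

uncertain = bce_with_logit(0.0, 1.0)  # logit 0 -> p = 0.5 -> loss = log(2)
confident = bce_with_logit(8.0, 1.0)  # strongly positive logit, near-zero loss
```

In practice frameworks use the numerically stable formulation (as in torch's binary_cross_entropy_with_logits) rather than computing sigmoid and log separately.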

Import Styles

AutoTimm supports multiple import styles for convenience:

Direct Imports

from autotimm import (
    SemanticSegmentor,
    InstanceSegmentor,
    SegmentationDataModule,
    DiceLoss,
    DeepLabV3PlusHead,
)

Submodule Aliases

# Use singular form aliases
from autotimm.task import SemanticSegmentor, InstanceSegmentor
from autotimm.loss import DiceLoss, CombinedSegmentationLoss
from autotimm.head import DeepLabV3PlusHead, MaskHead
from autotimm.metric import MetricConfig

Original Imports

# Original plural form still works
from autotimm.tasks import SemanticSegmentor
from autotimm.losses import DiceLoss
from autotimm.heads import DeepLabV3PlusHead
from autotimm.core.metrics import MetricConfig

Package Alias

import autotimm as at  # recommended alias

model = at.SemanticSegmentor(backbone="resnet50", num_classes=19)
loss = at.DiceLoss(num_classes=19)

Namespace Access

import autotimm as at  # recommended alias

# Access via submodule aliases
model = at.task.SemanticSegmentor(...)
loss = at.losses.DiceLoss(...)
head = at.heads.DeepLabV3PlusHead(...)

Parameters Reference

SemanticSegmentor Parameters

Parameter Type Default Description
backbone str | FeatureBackboneConfig Required Timm model name or config
num_classes int Required Number of segmentation classes
head_type str "deeplabv3plus" Head architecture ("deeplabv3plus" or "fcn")
loss_type str "combined" Loss function type (deprecated; use loss_fn)
dice_weight float 1.0 Weight for Dice loss in combined mode
ce_weight float 1.0 Weight for CE loss in combined mode
ignore_index int 255 Index to ignore in loss computation
metrics list[MetricConfig] None Metric configurations
lr float 1e-4 Learning rate
weight_decay float 1e-4 Weight decay
optimizer str | dict "adamw" Optimizer name or config
optimizer_kwargs dict | None None Extra optimizer kwargs
scheduler str | dict "cosine" Scheduler name or config
scheduler_kwargs dict | None None Extra scheduler kwargs
freeze_backbone bool False Whether to freeze backbone parameters
transform_config TransformConfig | None None Transform config for preprocessing
logging_config LoggingConfig | None None Enhanced logging configuration
loss_fn str | nn.Module | None None Custom loss function (string from registry, nn.Module, or None for default)
class_weights torch.Tensor | None None Class weights for loss computation
compile_model bool True Apply torch.compile() for faster training/inference
compile_kwargs dict | None None Kwargs for torch.compile()
seed int | None None Random seed for reproducibility (None to disable)
deterministic bool True Enable deterministic algorithms

InstanceSegmentor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `backbone` | `str \| FeatureBackboneConfig` | Required | timm model name or config |
| `num_classes` | `int` | Required | Number of object classes (excluding background) |
| `fpn_channels` | `int` | `256` | Number of FPN channels |
| `head_num_convs` | `int` | `4` | Number of conv layers in detection head |
| `mask_size` | `int` | `28` | ROI mask resolution |
| `roi_pool_size` | `int` | `14` | ROI pooling output size |
| `mask_loss_weight` | `float` | `1.0` | Weight for mask loss |
| `focal_alpha` | `float` | `0.25` | Alpha parameter for focal loss |
| `focal_gamma` | `float` | `2.0` | Gamma parameter for focal loss |
| `cls_loss_weight` | `float` | `1.0` | Weight for classification loss |
| `reg_loss_weight` | `float` | `1.0` | Weight for regression loss |
| `centerness_loss_weight` | `float` | `1.0` | Weight for centerness loss |
| `score_thresh` | `float` | `0.05` | Score threshold for detections |
| `nms_thresh` | `float` | `0.5` | IoU threshold for NMS |
| `max_detections_per_image` | `int` | `100` | Maximum detections per image |
| `mask_threshold` | `float` | `0.5` | Threshold for binarizing predicted masks |
| `metrics` | `list[MetricConfig]` | `None` | Metric configurations |
| `logging_config` | `LoggingConfig \| None` | `None` | Enhanced logging configuration |
| `lr` | `float` | `1e-4` | Learning rate |
| `weight_decay` | `float` | `1e-4` | Weight decay |
| `optimizer` | `str \| dict` | `"adamw"` | Optimizer name or config |
| `optimizer_kwargs` | `dict \| None` | `None` | Extra optimizer kwargs |
| `scheduler` | `str \| dict` | `"cosine"` | Scheduler name or config |
| `scheduler_kwargs` | `dict \| None` | `None` | Extra scheduler kwargs |
| `freeze_backbone` | `bool` | `False` | Whether to freeze backbone parameters |
| `transform_config` | `TransformConfig \| None` | `None` | Transform config for preprocessing |
| `cls_loss_fn` | `str \| nn.Module \| None` | `None` | Classification loss function |
| `reg_loss_fn` | `str \| nn.Module \| None` | `None` | Regression loss function |
| `mask_loss_fn` | `str \| nn.Module \| None` | `None` | Mask loss function |
| `compile_model` | `bool` | `True` | Apply `torch.compile()` for faster training/inference |
| `compile_kwargs` | `dict \| None` | `None` | Kwargs for `torch.compile()` |
| `seed` | `int \| None` | `None` | Random seed for reproducibility (`None` to disable) |
| `deterministic` | `bool` | `True` | Enable deterministic algorithms |

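`score_thresh`, `nms_thresh`, and `max_detections_per_image` control the standard detection post-processing pipeline: drop low-confidence boxes, then greedily suppress overlapping ones. A minimal plain-Python sketch of that step on xyxy boxes (illustrative, not autotimm's implementation):

```python
def iou(a, b):
    """IoU of two boxes in (x1, y1, x2, y2) format."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter) if inter else 0.0

def nms(boxes, scores, score_thresh=0.05, nms_thresh=0.5, max_det=100):
    """Greedy NMS: drop low scores, then suppress boxes overlapping a kept box."""
    order = sorted(
        (i for i, s in enumerate(scores) if s >= score_thresh),
        key=lambda i: scores[i], reverse=True,
    )
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= nms_thresh for j in keep):
            keep.append(i)
        if len(keep) == max_det:
            break
    return keep

boxes = [(0, 0, 10, 10), (1, 1, 10, 10), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores))  # → [0, 2]: the second box overlaps the first and is suppressed
```

Lowering `nms_thresh` suppresses more aggressively; raising `score_thresh` trades recall for fewer false positives.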
SegmentationDataModule Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `data_dir` | `str \| Path` | Required | Root directory of dataset |
| `format` | `str` | `"png"` | Dataset format (`"png"`, `"coco"`, `"cityscapes"`, `"voc"`) |
| `image_size` | `int` | `512` | Target image size (square) |
| `batch_size` | `int` | `8` | Batch size |
| `num_workers` | `int` | `4` | Number of dataloader workers |
| `augmentation_preset` | `str` | `"default"` | Augmentation preset (`"default"`, `"strong"`, `"light"`) |
| `custom_train_transforms` | `Any` | `None` | Custom training transforms (overrides preset) |
| `custom_val_transforms` | `Any` | `None` | Custom validation transforms |
| `class_mapping` | `dict` | `None` | Mapping from dataset class IDs to contiguous IDs |
| `ignore_index` | `int` | `255` | Index for ignored pixels (e.g., boundaries) |
| `transform_config` | `TransformConfig \| None` | `None` | Unified transform configuration |
| `backbone` | `str \| nn.Module \| None` | `None` | Backbone for model-specific normalization |

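`class_mapping` remaps raw dataset label IDs onto the contiguous range the model expects, with unmapped pixels falling back to `ignore_index`. A plain-Python sketch of the idea (illustrative; not the module's actual code):

```python
# Cityscapes-style example: raw IDs 7 and 26 become contiguous classes 0 and 1;
# anything unmapped is sent to ignore_index (255 by default).
class_mapping = {7: 0, 26: 1}

def remap(mask, mapping, ignore_index=255):
    """Apply class_mapping pixel-wise, sending unmapped IDs to ignore_index."""
    return [mapping.get(pixel, ignore_index) for pixel in mask]

raw_mask = [7, 7, 26, 3]  # raw ID 3 has no mapping
print(remap(raw_mask, class_mapping))  # → [0, 0, 1, 255]
```

Because unmapped pixels land on `ignore_index`, they are excluded from the loss automatically.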
InstanceSegmentationDataModule Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `data_dir` | `str \| Path` | Required | Root directory of COCO dataset |
| `image_size` | `int` | `640` | Target image size (square) |
| `batch_size` | `int` | `4` | Batch size |
| `num_workers` | `int` | `4` | Number of dataloader workers |
| `augmentation_preset` | `str` | `"default"` | Augmentation strength (`"default"`, `"strong"`, `"light"`) |
| `custom_train_transforms` | `Any` | `None` | Custom training transforms (overrides preset) |
| `custom_val_transforms` | `Any` | `None` | Custom validation transforms |
| `min_keypoints` | `int` | `0` | Minimum keypoints for valid instance |
| `min_area` | `float` | `0.0` | Minimum area for valid instance |
| `transform_config` | `TransformConfig \| None` | `None` | Unified transform configuration |
| `backbone` | `str \| nn.Module \| None` | `None` | Backbone for model-specific normalization |

Examples

Basic Semantic Segmentation

from autotimm import SemanticSegmentor, SegmentationDataModule, MetricConfig

# Setup data
data = SegmentationDataModule(
    data_dir="./data",
    format="png",
    image_size=512,
    batch_size=8,
)

# Create model
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=21,
    head_type="deeplabv3plus",
    loss_type="combined",
    ce_weight=1.0,
    dice_weight=1.0,
    metrics=[
        MetricConfig(
            name="iou",
            backend="torchmetrics",
            metric_class="JaccardIndex",
            params={"task": "multiclass", "num_classes": 21, "average": "macro"},
            stages=["val"],
            prog_bar=True,
        )
    ],
)

Instance Segmentation

from autotimm import InstanceSegmentor, InstanceSegmentationDataModule, MetricConfig

# Setup data
data = InstanceSegmentationDataModule(
    data_dir="./coco",
    image_size=640,
    batch_size=4,
)

# Create model
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    mask_loss_weight=1.0,
    mask_size=28,
    roi_pool_size=14,
    metrics=[
        MetricConfig(
            name="mask_mAP",
            backend="torchmetrics",
            metric_class="MeanAveragePrecision",
            params={"box_format": "xyxy", "iou_type": "segm"},
            stages=["val"],
            prog_bar=True,
        )
    ],
)

Using Import Aliases

from autotimm.task import SemanticSegmentor, InstanceSegmentor
from autotimm.loss import DiceLoss, CombinedSegmentationLoss, MaskLoss
from autotimm.head import DeepLabV3PlusHead, MaskHead
from autotimm.metric import MetricConfig

# Create model using aliases
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    loss_type="combined",
)

# Directly instantiate losses if needed
dice_loss = DiceLoss(num_classes=19, ignore_index=255)
combined_loss = CombinedSegmentationLoss(
    num_classes=19,
    ce_weight=1.0,
    dice_weight=1.0,
)
mask_loss = MaskLoss()

Custom Loss Configuration

from autotimm.loss import TverskyLoss

# Create custom Tversky loss for handling class imbalance
loss_fn = TverskyLoss(
    num_classes=19,
    alpha=0.3,  # Lower alpha emphasizes recall
    beta=0.7,   # Higher beta penalizes false negatives more
    ignore_index=255,
)

# Pass the custom loss to the model via the loss_fn parameter
# (loss_fn accepts a registry string or any nn.Module instance)
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_fn=loss_fn,
)

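The alpha/beta trade-off above follows the Tversky index, TI = TP / (TP + α·FP + β·FN), with the loss being 1 − TI. A quick plain-Python check of how β > α penalizes false negatives harder (illustrative arithmetic, not the library code):

```python
def tversky_loss(tp, fp, fn, alpha=0.3, beta=0.7):
    """1 - Tversky index: TI = TP / (TP + alpha*FP + beta*FN)."""
    return 1.0 - tp / (tp + alpha * fp + beta * fn)

# Same error count, but false negatives cost more when beta > alpha.
fp_heavy = tversky_loss(tp=80, fp=20, fn=0)   # 20 false positives
fn_heavy = tversky_loss(tp=80, fp=0, fn=20)   # 20 false negatives
print(fp_heavy < fn_heavy)  # → True: FN-heavy errors yield the larger loss
```

With alpha = beta = 0.5 the Tversky loss reduces to the Dice loss, so these two parameters are how you tilt Dice toward recall or precision.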
Advanced Configuration

from autotimm import SemanticSegmentor, LoggingConfig

# Model with enhanced logging
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    loss_type="combined",
    ce_weight=1.0,
    dice_weight=1.0,
    logging_config=LoggingConfig(
        log_learning_rate=True,
        log_gradient_norm=True,
        log_weight_norm=True,
    ),
    freeze_backbone=False,  # Set True to freeze backbone
)

With TransformConfig (Preprocessing)

Enable inference-time preprocessing with model-specific normalization:

from autotimm import SemanticSegmentor, TransformConfig

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    metrics=[...],
    transform_config=TransformConfig(image_size=512),  # Enable preprocess()
)

# Preprocess raw images for inference
import torch
from PIL import Image
image = Image.open("test.jpg")
tensor = model.preprocess(image)  # Returns preprocessed tensor

# Run inference
model.eval()
with torch.inference_mode():
    predictions = model(tensor)

Shared Config for DataModule and Model

from autotimm import SemanticSegmentor, SegmentationDataModule, TransformConfig

# Shared config ensures same preprocessing
config = TransformConfig(preset="default", image_size=512)
backbone_name = "resnet50"

# DataModule uses model's normalization
data = SegmentationDataModule(
    data_dir="./cityscapes",
    format="cityscapes",
    transform_config=config,
    backbone=backbone_name,
)

# Model uses same config for inference preprocessing
model = SemanticSegmentor(
    backbone=backbone_name,
    num_classes=19,
    metrics=[...],
    transform_config=config,
)

Instance Segmentation with Preprocessing

from autotimm import InstanceSegmentor, TransformConfig

model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    metrics=[...],
    transform_config=TransformConfig(image_size=640),
)

# Get model's normalization config
config = model.get_data_config()
print(f"Mean: {config['mean']}")
print(f"Std: {config['std']}")

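The mean/std returned above are per-channel normalization constants; preprocessing applies (x − mean) / std to each channel after scaling pixels to [0, 1]. A minimal plain-Python sketch of that step (illustrative; `model.preprocess()` handles this for you, and the ImageNet-style constants below are shown only as an example):

```python
# Example constants; in practice use the values from model.get_data_config().
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean, std):
    """Normalize one RGB pixel already scaled to [0, 1]: (x - mean) / std."""
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]

pixel = [0.5, 0.5, 0.5]
print([round(v, 3) for v in normalize_pixel(pixel, mean, std)])  # → [0.066, 0.196, 0.418]
```

Matching these constants between the DataModule and the model is exactly what the shared `TransformConfig` pattern above guarantees.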
See Also