Segmentation API Reference¶
Complete API documentation for semantic and instance segmentation tasks.
Task Models¶
SemanticSegmentor¶
End-to-end semantic segmentation model with timm backbones, supporting multiple head architectures (DeepLabV3+, FCN) and loss functions (CE, Dice, Focal, Combined).
autotimm.SemanticSegmentor ¶
Bases: PreprocessingMixin, LightningModule
End-to-end semantic segmentation model backed by a timm backbone.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `backbone` | `str \| FeatureBackboneConfig` | A timm model name (str) or a `FeatureBackboneConfig` instance. | *required* |
| `num_classes` | `int` | Number of segmentation classes. | *required* |
| `head_type` | `str` | Type of segmentation head (`'deeplabv3plus'` or `'fcn'`). | `'deeplabv3plus'` |
| `loss_fn` | `str \| Module \| None` | Loss function to use. Can be a string from the loss registry (e.g. `'dice'`, `'focal_pixelwise'`, `'combined_segmentation'`), an `nn.Module` instance (custom loss), or `None` (falls back to `loss_type` for backward compatibility). | `None` |
| `loss_type` | `str` | [DEPRECATED] Loss function type (`'ce'`, `'dice'`, `'focal'`, or `'combined'`). Use `loss_fn` instead for better flexibility. | `'combined'` |
| `dice_weight` | `float` | Weight for Dice loss when using `'combined'` loss. | `1.0` |
| `ce_weight` | `float` | Weight for cross-entropy loss when using `'combined'` loss. | `1.0` |
| `ignore_index` | `int` | Index to ignore in loss computation. | `255` |
| `class_weights` | `Tensor \| None` | Optional per-class weights for CE loss. | `None` |
| `metrics` | `MetricManager \| list[MetricConfig] \| None` | A `MetricManager`, a list of `MetricConfig`, or `None`. | `None` |
| `logging_config` | `LoggingConfig \| None` | Optional `LoggingConfig`. | `None` |
| `transform_config` | `TransformConfig \| None` | Optional `TransformConfig`. | `None` |
| `lr` | `float` | Learning rate. | `0.0001` |
| `weight_decay` | `float` | Weight decay for optimizer. | `0.0001` |
| `optimizer` | `str \| dict[str, Any]` | Optimizer name or dict with `'class'` and `'params'` keys. | `'adamw'` |
| `optimizer_kwargs` | `dict[str, Any] \| None` | Additional kwargs for the optimizer. | `None` |
| `scheduler` | `str \| dict[str, Any] \| None` | Scheduler name, dict with `'class'` and `'params'` keys, or `None`. | `'cosine'` |
| `scheduler_kwargs` | `dict[str, Any] \| None` | Extra kwargs forwarded to the LR scheduler. | `None` |
| `freeze_backbone` | `bool` | If `True`, backbone parameters are frozen. | `False` |
| `compile_model` | `bool` | If `True`, the model is compiled with `torch.compile()`. | `True` |
| `compile_kwargs` | `dict[str, Any] \| None` | Optional dict of kwargs to pass to `torch.compile()`. | `None` |
| `seed` | `int \| None` | Random seed for reproducibility. If `None`, no seeding is applied. | `None` |
| `deterministic` | `bool` | If `True`, deterministic algorithms are enabled. | `True` |
Example
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    metrics=[
        MetricConfig(
            name="iou",
            backend="torchmetrics",
            metric_class="JaccardIndex",
            params={"task": "multiclass", "num_classes": 19, "average": "macro"},
            stages=["val", "test"],
            prog_bar=True,
        ),
    ],
    loss_type="combined",
    dice_weight=1.0,
)
__init__ ¶
__init__(backbone: str | FeatureBackboneConfig, num_classes: int, head_type: str = 'deeplabv3plus', loss_fn: str | Module | None = None, loss_type: str = 'combined', dice_weight: float = 1.0, ce_weight: float = 1.0, ignore_index: int = 255, class_weights: Tensor | None = None, metrics: MetricManager | list[MetricConfig] | None = None, logging_config: LoggingConfig | None = None, transform_config: TransformConfig | None = None, lr: float = 0.0001, weight_decay: float = 0.0001, optimizer: str | dict[str, Any] = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: str | dict[str, Any] | None = 'cosine', scheduler_kwargs: dict[str, Any] | None = None, freeze_backbone: bool = False, compile_model: bool = True, compile_kwargs: dict[str, Any] | None = None, seed: int | None = None, deterministic: bool = True)
forward ¶
Forward pass through backbone and head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `images` | `Tensor` | Input images `[B, 3, H, W]`. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Segmentation logits `[B, num_classes, H', W']`. |
predict ¶
Predict segmentation masks for input images.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `images` | `Tensor` | Input images `[B, 3, H, W]`. | *required* |
| `return_logits` | `bool` | If `True`, return logits instead of class predictions. | `False` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Predicted class indices `[B, H, W]` or logits `[B, C, H, W]`. |
training_step ¶
Training step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `dict[str, Any]` | Dict with `'image'` `[B, 3, H, W]` and `'mask'` `[B, H, W]`. | *required* |
| `batch_idx` | `int` | Batch index. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Loss value. |
validation_step ¶
Validation step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `dict[str, Any]` | Dict with `'image'` `[B, 3, H, W]` and `'mask'` `[B, H, W]`. | *required* |
| `batch_idx` | `int` | Batch index. | *required* |
test_step ¶
Test step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `dict[str, Any]` | Dict with `'image'` `[B, 3, H, W]` and `'mask'` `[B, H, W]`. | *required* |
| `batch_idx` | `int` | Batch index. | *required* |
predict_step ¶
Prediction step for Trainer.predict().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Any` | Input batch (dict or tensor). | *required* |
| `batch_idx` | `int` | Batch index. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Predicted class indices `[B, H, W]`. |
configure_optimizers ¶
Configure optimizer and learning rate scheduler.
Supports torch.optim, timm optimizers, and custom optimizers/schedulers.
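The `optimizer` and `scheduler` parameters accept either a registry name or a dict with `'class'` and `'params'` keys. A sketch of the dict form — the `'class'` values and the parameter names inside `'params'` below are assumptions for illustration, not verified registry entries:

```python
# Hypothetical dict-form configs for `optimizer=` / `scheduler=`.
# Only the two documented keys ('class' and 'params') are grounded in the API;
# the concrete values are illustrative.
optimizer_cfg = {
    "class": "sgd",                                # assumed optimizer identifier
    "params": {"momentum": 0.9, "nesterov": True}, # assumed optimizer kwargs
}
scheduler_cfg = {
    "class": "cosine",                             # assumed scheduler identifier
    "params": {"eta_min": 1e-6},                   # assumed scheduler kwargs
}

print(sorted(optimizer_cfg), sorted(scheduler_cfg))
```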
InstanceSegmentor¶
End-to-end instance segmentation model combining FCOS-style object detection with per-instance mask prediction. Integrates timm backbones with FPN and dual-head architecture for boxes and masks.
autotimm.InstanceSegmentor ¶
Bases: PreprocessingMixin, LightningModule
End-to-end instance segmentation model.
Combines FCOS-style detection with per-instance mask prediction.
Architecture: timm backbone → FPN → Detection Head + Mask Head → NMS
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `backbone` | `str \| FeatureBackboneConfig` | A timm model name (str) or a `FeatureBackboneConfig` instance. | *required* |
| `num_classes` | `int` | Number of object classes (excluding background). | *required* |
| `cls_loss_fn` | `str \| Module \| None` | Classification loss function. Can be a string from the loss registry (e.g. `'focal'`), an `nn.Module` instance (custom loss), or `None` (uses `FocalLoss` with `focal_alpha` and `focal_gamma`). | `None` |
| `reg_loss_fn` | `str \| Module \| None` | Regression loss function. Can be a string from the loss registry (e.g. `'giou'`), an `nn.Module` instance (custom loss), or `None` (uses `GIoULoss`). | `None` |
| `mask_loss_fn` | `str \| Module \| None` | Mask loss function. Can be a string from the loss registry (e.g. `'mask'`), an `nn.Module` instance (custom loss), or `None` (uses `MaskLoss`). | `None` |
| `metrics` | `MetricManager \| list[MetricConfig] \| None` | A `MetricManager`, a list of `MetricConfig`, or `None`. | `None` |
| `logging_config` | `LoggingConfig \| None` | Optional `LoggingConfig`. | `None` |
| `transform_config` | `TransformConfig \| None` | Optional `TransformConfig`. | `None` |
| `lr` | `float` | Learning rate. | `0.0001` |
| `weight_decay` | `float` | Weight decay for optimizer. | `0.0001` |
| `optimizer` | `str \| dict[str, Any]` | Optimizer name or dict with `'class'` and `'params'` keys. | `'adamw'` |
| `optimizer_kwargs` | `dict[str, Any] \| None` | Additional kwargs for the optimizer. | `None` |
| `scheduler` | `str \| dict[str, Any] \| None` | Scheduler name, dict config, or `None` for no scheduler. | `'cosine'` |
| `scheduler_kwargs` | `dict[str, Any] \| None` | Extra kwargs forwarded to the LR scheduler. | `None` |
| `fpn_channels` | `int` | Number of channels in FPN layers. | `256` |
| `head_num_convs` | `int` | Number of conv layers in detection head branches. | `4` |
| `mask_size` | `int` | ROI mask resolution. | `28` |
| `mask_loss_weight` | `float` | Weight for mask loss. | `1.0` |
| `focal_alpha` | `float` | Alpha parameter for focal loss. | `0.25` |
| `focal_gamma` | `float` | Gamma parameter for focal loss. | `2.0` |
| `cls_loss_weight` | `float` | Weight for classification loss. | `1.0` |
| `reg_loss_weight` | `float` | Weight for regression loss. | `1.0` |
| `centerness_loss_weight` | `float` | Weight for centerness loss. | `1.0` |
| `score_thresh` | `float` | Score threshold for detections during inference. | `0.05` |
| `nms_thresh` | `float` | IoU threshold for NMS. | `0.5` |
| `max_detections_per_image` | `int` | Maximum detections to keep per image. | `100` |
| `freeze_backbone` | `bool` | If `True`, backbone parameters are frozen. | `False` |
| `roi_pool_size` | `int` | Size of ROI pooling output. | `14` |
| `mask_threshold` | `float` | Threshold for binarizing predicted masks. | `0.5` |
| `compile_model` | `bool` | If `True`, the model is compiled with `torch.compile()`. | `True` |
| `compile_kwargs` | `dict[str, Any] \| None` | Optional dict of kwargs to pass to `torch.compile()`. | `None` |
| `seed` | `int \| None` | Random seed for reproducibility. If `None`, no seeding is applied. | `None` |
| `deterministic` | `bool` | If `True`, deterministic algorithms are enabled. | `True` |
Example
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    metrics=[
        MetricConfig(
            name="mask_mAP",
            backend="torchmetrics",
            metric_class="MeanAveragePrecision",
            params={"box_format": "xyxy", "iou_type": "segm"},
            stages=["val", "test"],
        ),
    ],
    lr=1e-4,
    mask_loss_weight=1.0,
)
__init__ ¶
__init__(backbone: str | FeatureBackboneConfig, num_classes: int, cls_loss_fn: str | Module | None = None, reg_loss_fn: str | Module | None = None, mask_loss_fn: str | Module | None = None, metrics: MetricManager | list[MetricConfig] | None = None, logging_config: LoggingConfig | None = None, transform_config: TransformConfig | None = None, lr: float = 0.0001, weight_decay: float = 0.0001, optimizer: str | dict[str, Any] = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: str | dict[str, Any] | None = 'cosine', scheduler_kwargs: dict[str, Any] | None = None, fpn_channels: int = 256, head_num_convs: int = 4, mask_size: int = 28, mask_loss_weight: float = 1.0, focal_alpha: float = 0.25, focal_gamma: float = 2.0, cls_loss_weight: float = 1.0, reg_loss_weight: float = 1.0, centerness_loss_weight: float = 1.0, score_thresh: float = 0.05, nms_thresh: float = 0.5, max_detections_per_image: int = 100, freeze_backbone: bool = False, roi_pool_size: int = 14, mask_threshold: float = 0.5, compile_model: bool = True, compile_kwargs: dict[str, Any] | None = None, seed: int | None = None, deterministic: bool = True)
forward ¶
Forward pass through detector (without mask head).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `images` | `Tensor` | Input images `[B, C, H, W]`. | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[list[Tensor], list[Tensor], list[Tensor]]` | Tuple of (cls_outputs, reg_outputs, centerness_outputs) per FPN level. |
predict ¶
Predict instance segmentation for input images.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `images` | `Tensor` | Input images `[B, 3, H, W]`. | *required* |

Returns:

| Type | Description |
|---|---|
| `list[dict[str, Tensor]]` | List of dicts with `'boxes'`, `'labels'`, `'scores'`, `'masks'` for each image. |
training_step ¶
Training step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `dict[str, Any]` | Dict with `'image'`, `'boxes'`, `'labels'`, `'masks'`. | *required* |
| `batch_idx` | `int` | Batch index. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Total loss value. |
validation_step ¶
Validation step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `dict[str, Any]` | Dict with `'image'`, `'boxes'`, `'labels'`, `'masks'`. | *required* |
| `batch_idx` | `int` | Batch index. | *required* |
test_step ¶
Test step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `dict[str, Any]` | Dict with `'image'`, `'boxes'`, `'labels'`, `'masks'`. | *required* |
| `batch_idx` | `int` | Batch index. | *required* |
predict_step ¶
Prediction step for Trainer.predict().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Any` | Input batch. | *required* |
| `batch_idx` | `int` | Batch index. | *required* |

Returns:

| Type | Description |
|---|---|
| `list[dict[str, Tensor]]` | List of predictions. |
configure_optimizers ¶
Configure optimizer and learning rate scheduler.
Data Modules¶
SegmentationDataModule¶
PyTorch Lightning DataModule for semantic segmentation. Supports multiple dataset formats including PNG masks, Cityscapes, COCO, and Pascal VOC. Provides flexible augmentation presets and custom transform support.
autotimm.SegmentationDataModule ¶
Bases: LightningDataModule
PyTorch Lightning DataModule for semantic segmentation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data_dir` | `str \| Path` | Root directory of the dataset. | *required* |
| `format` | `str` | Dataset format (`'png'`, `'coco'`, `'cityscapes'`, `'voc'`, `'csv'`). | `'png'` |
| `image_size` | `int` | Target image size (square). | `512` |
| `batch_size` | `int` | Batch size for dataloaders. | `8` |
| `num_workers` | `int` | Number of dataloader workers. Defaults to `min(os.cpu_count() or 4, 4)`. | `min(cpu_count() or 4, 4)` |
| `augmentation_preset` | `str` | Augmentation preset (`'default'`, `'strong'`, `'light'`). | `'default'` |
| `custom_train_transforms` | `Any` | Optional custom training transforms. | `None` |
| `custom_val_transforms` | `Any` | Optional custom validation transforms. | `None` |
| `class_mapping` | `dict[int, int] \| None` | Optional mapping from dataset class IDs to contiguous IDs. | `None` |
| `ignore_index` | `int` | Index to use for ignored pixels. | `255` |
| `transform_config` | `TransformConfig \| None` | Optional TransformConfig for unified transform configuration. When provided along with `backbone`, uses model-specific normalization. | `None` |
| `backbone` | `str \| Module \| None` | Optional backbone name or module for model-specific normalization. | `None` |
| `train_csv` | `str \| Path \| None` | Path to training CSV file (used when `format='csv'`). | `None` |
| `val_csv` | `str \| Path \| None` | Path to validation CSV file (used when `format='csv'`). | `None` |
| `test_csv` | `str \| Path \| None` | Path to test CSV file (used when `format='csv'`). | `None` |
__init__ ¶
__init__(data_dir: str | Path, format: str = 'png', image_size: int = 512, batch_size: int = 8, num_workers: int = min(os.cpu_count() or 4, 4), augmentation_preset: str = 'default', custom_train_transforms: Any = None, custom_val_transforms: Any = None, class_mapping: dict[int, int] | None = None, ignore_index: int = 255, transform_config: TransformConfig | None = None, backbone: str | Module | None = None, train_csv: str | Path | None = None, val_csv: str | Path | None = None, test_csv: str | Path | None = None)
setup ¶
Setup datasets for each stage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `stage` | `str \| None` | `'fit'`, `'validate'`, `'test'`, or `'predict'`. | `None` |
InstanceSegmentationDataModule¶
PyTorch Lightning DataModule for instance segmentation using COCO format. Handles both detection boxes and instance masks with built-in augmentation support.
autotimm.InstanceSegmentationDataModule ¶
Bases: LightningDataModule
PyTorch Lightning DataModule for instance segmentation.
Supports two modes:
- COCO mode (default) -- expects COCO-format annotations.
- CSV mode -- provide `train_csv` pointing to a CSV file with columns `image_path`, `x_min`, `y_min`, `x_max`, `y_max`, `label`, `mask_path`.
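The expected CSV layout can be sketched with the standard library — one row per instance, so several rows may share an `image_path`. The file paths and label values below are illustrative only:

```python
import csv
import io

# Columns documented for the CSV mode of InstanceSegmentationDataModule.
COLUMNS = ["image_path", "x_min", "y_min", "x_max", "y_max", "label", "mask_path"]

# Illustrative rows (paths are hypothetical, not real files):
rows = [
    {"image_path": "images/0001.jpg", "x_min": 10, "y_min": 20,
     "x_max": 110, "y_max": 220, "label": 0, "mask_path": "masks/0001_0.png"},
    {"image_path": "images/0001.jpg", "x_min": 150, "y_min": 30,
     "x_max": 260, "y_max": 200, "label": 1, "mask_path": "masks/0001_1.png"},
]

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=COLUMNS)
writer.writeheader()
writer.writerows(rows)
header = buf.getvalue().splitlines()[0]
print(header)
```

The column names match the defaults of `image_column`, `bbox_columns`, `label_column`, and `mask_column`; non-default names can be remapped via those parameters.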
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data_dir` | `str \| Path` | Root directory of COCO dataset. | `'.'` |
| `image_size` | `int` | Target image size (square). | `640` |
| `batch_size` | `int` | Batch size for dataloaders. | `4` |
| `num_workers` | `int` | Number of dataloader workers. Defaults to `min(os.cpu_count() or 4, 4)`. | `min(cpu_count() or 4, 4)` |
| `augmentation_preset` | `str` | Augmentation strength (`'default'`, `'strong'`, `'light'`). | `'default'` |
| `custom_train_transforms` | `Any` | Optional custom training transforms. | `None` |
| `custom_val_transforms` | `Any` | Optional custom validation transforms. | `None` |
| `min_keypoints` | `int` | Minimum number of keypoints for an instance to be valid. | `0` |
| `min_area` | `float` | Minimum area for an instance to be valid. | `0.0` |
| `transform_config` | `TransformConfig \| None` | Optional TransformConfig for unified transform configuration. When provided along with `backbone`, uses model-specific normalization. | `None` |
| `backbone` | `str \| Module \| None` | Optional backbone name or module for model-specific normalization. | `None` |
| `train_csv` | `str \| Path \| None` | Path to training CSV file (CSV mode). | `None` |
| `val_csv` | `str \| Path \| None` | Path to validation CSV file (CSV mode). | `None` |
| `test_csv` | `str \| Path \| None` | Path to test CSV file (CSV mode). | `None` |
| `image_dir` | `str \| Path \| None` | Root directory for resolving image/mask paths in CSV mode. | `None` |
| `image_column` | `str` | CSV column name for image paths. | `'image_path'` |
| `bbox_columns` | `list[str] \| None` | CSV column names for bbox coordinates. | `None` |
| `label_column` | `str` | CSV column name for class labels. | `'label'` |
| `mask_column` | `str` | CSV column name for mask file paths. | `'mask_path'` |
__init__ ¶
__init__(data_dir: str | Path = '.', image_size: int = 640, batch_size: int = 4, num_workers: int = min(os.cpu_count() or 4, 4), augmentation_preset: str = 'default', custom_train_transforms: Any = None, custom_val_transforms: Any = None, min_keypoints: int = 0, min_area: float = 0.0, transform_config: TransformConfig | None = None, backbone: str | Module | None = None, train_csv: str | Path | None = None, val_csv: str | Path | None = None, test_csv: str | Path | None = None, image_dir: str | Path | None = None, image_column: str = 'image_path', bbox_columns: list[str] | None = None, label_column: str = 'label', mask_column: str = 'mask_path')
setup ¶
Setup datasets for each stage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `stage` | `str \| None` | `'fit'`, `'validate'`, `'test'`, or `'predict'`. | `None` |
Segmentation Heads¶
DeepLabV3PlusHead¶
DeepLabV3+ segmentation head with Atrous Spatial Pyramid Pooling (ASPP) and decoder with skip connections. Provides multi-scale context aggregation for accurate semantic segmentation.
autotimm.heads.DeepLabV3PlusHead ¶
Bases: Module
DeepLabV3+ segmentation head with ASPP and decoder.
Combines high-level semantic features (C5) with low-level spatial features (C2) for accurate segmentation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_channels_list` | `Sequence[int]` | List of input channel counts from backbone [C2, C3, C4, C5]. | *required* |
| `num_classes` | `int` | Number of segmentation classes. | *required* |
| `aspp_out_channels` | `int` | Number of ASPP output channels. | `256` |
| `decoder_channels` | `int` | Number of decoder output channels. | `256` |
| `dilation_rates` | `Sequence[int]` | Dilation rates for ASPP. | `(6, 12, 18)` |
__init__ ¶
__init__(in_channels_list: Sequence[int], num_classes: int, aspp_out_channels: int = 256, decoder_channels: int = 256, dilation_rates: Sequence[int] = (6, 12, 18))
forward ¶
Forward pass through DeepLabV3+ head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `features` | `list[Tensor]` | List of backbone features [C2, C3, C4, C5]. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Segmentation logits `[B, num_classes, H/4, W/4]`. |
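Per the return shape `[B, num_classes, H/4, W/4]`, the head predicts at the stride-4 (C2 skip) resolution, so a full-resolution mask requires a final 4× upsample (e.g. bilinear interpolation). A quick shape check, as a sketch:

```python
def deeplab_logits_shape(batch: int, num_classes: int, h: int, w: int) -> tuple:
    # DeepLabV3+ decodes at the C2 skip resolution, i.e. 1/4 of the input,
    # per the documented return shape [B, num_classes, H/4, W/4].
    return (batch, num_classes, h // 4, w // 4)

print(deeplab_logits_shape(2, 19, 512, 512))  # (2, 19, 128, 128)
```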
FCNHead¶
Fully Convolutional Network (FCN) head for semantic segmentation. Simple upsampling-based architecture suitable for baseline models.
autotimm.heads.FCNHead ¶
Bases: Module
Simple Fully Convolutional Network head for segmentation.
A simpler baseline compared to DeepLabV3+.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_channels` | `int` | Number of input channels from backbone. | *required* |
| `num_classes` | `int` | Number of segmentation classes. | *required* |
| `intermediate_channels` | `int` | Number of intermediate channels. | `512` |
forward ¶
Forward pass through FCN head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `features` | `list[Tensor]` | List of backbone features (uses last one). | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Segmentation logits `[B, num_classes, H, W]`. |
MaskHead¶
ROI-based mask prediction head for instance segmentation. Predicts per-instance binary masks from ROI-aligned features.
autotimm.heads.MaskHead ¶
Bases: Module
Mask prediction head for instance segmentation.
Takes ROI-aligned features and predicts per-instance binary masks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_channels` | `int` | Number of input channels from ROI align. | `256` |
| `num_classes` | `int` | Number of object classes. | `80` |
| `hidden_channels` | `int` | Number of hidden layer channels. | `256` |
| `num_convs` | `int` | Number of convolutional layers before deconv. | `4` |
| `mask_size` | `int` | Output mask resolution. | `28` |
__init__ ¶
__init__(in_channels: int = 256, num_classes: int = 80, hidden_channels: int = 256, num_convs: int = 4, mask_size: int = 28)
forward ¶
Forward pass through mask head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `roi_features` | `Tensor` | ROI-aligned features `[N, C, H, W]`. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Mask logits `[N, num_classes, mask_size, mask_size]`. |
ASPP¶
Atrous Spatial Pyramid Pooling module for multi-scale feature extraction. Core component of DeepLabV3+ architecture.
autotimm.heads.ASPP ¶
Bases: Module
Atrous Spatial Pyramid Pooling module for DeepLabV3+.
Applies parallel atrous convolutions with different dilation rates to capture multi-scale context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_channels` | `int` | Number of input channels from backbone. | *required* |
| `out_channels` | `int` | Number of output channels. | `256` |
| `dilation_rates` | `Sequence[int]` | List of dilation rates for parallel branches. | `(6, 12, 18)` |
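The dilation rates determine how much context each parallel branch sees: a dilated k×k convolution with dilation d covers the same spatial extent as a dense kernel of size k + (k − 1)(d − 1). A quick check for the default rates with 3×3 kernels:

```python
def effective_kernel(k: int, d: int) -> int:
    # Spatial extent covered by a k x k atrous convolution with dilation d:
    # the taps are spread d pixels apart, so the span is k + (k - 1) * (d - 1).
    return k + (k - 1) * (d - 1)

# Default ASPP dilation_rates (6, 12, 18) applied to 3x3 kernels:
for d in (6, 12, 18):
    print(d, effective_kernel(3, d))  # spans 13, 25, 37 pixels respectively
```

This is why a handful of parallel branches suffices to aggregate context at several scales without shrinking the feature map.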
Segmentation Losses¶
DiceLoss¶
Dice loss for multi-class semantic segmentation. Optimizes directly for IoU-like metric. Formula: 1 - (2 * |X ∩ Y|) / (|X| + |Y|). Effective for handling class imbalance.
autotimm.losses.DiceLoss ¶
Bases: Module
Dice loss for multi-class segmentation.
Formula: 1 - (2 * |X ∩ Y|) / (|X| + |Y|)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_classes` | `int` | Number of segmentation classes. | *required* |
| `smooth` | `float` | Smoothing constant to avoid division by zero. | `1.0` |
| `ignore_index` | `int` | Index to ignore in loss computation. | `255` |
| `reduction` | `str` | Reduction method (`'mean'`, `'sum'`, `'none'`). | `'mean'` |
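The formula `1 - (2 * |X ∩ Y|) / (|X| + |Y|)` can be sketched on binary masks represented as sets of foreground pixel indices. This is illustrative only — the actual module consumes logits `[B, C, H, W]` and handles multi-class inputs, `ignore_index`, and reduction:

```python
def dice_loss(pred: set, target: set, smooth: float = 1.0) -> float:
    # 1 - (2|X ∩ Y| + smooth) / (|X| + |Y| + smooth);
    # smooth avoids division by zero when both masks are empty.
    inter = len(pred & target)
    return 1.0 - (2.0 * inter + smooth) / (len(pred) + len(target) + smooth)

perfect = dice_loss({1, 2, 3}, {1, 2, 3})   # identical masks -> 0.0
disjoint = dice_loss({1, 2}, {3, 4})        # no overlap -> close to 1.0
print(perfect, disjoint)
```

Because the loss is driven by overlap rather than per-pixel counts, small foreground classes are not swamped by the background, which is what makes it effective under class imbalance.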
CombinedSegmentationLoss¶
Combines Cross-Entropy and Dice losses with configurable weights. Leverages pixel-wise classification (CE) and region overlap optimization (Dice) for robust segmentation.
autotimm.losses.CombinedSegmentationLoss ¶
Bases: Module
Combined cross-entropy and Dice loss for semantic segmentation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_classes` | `int` | Number of segmentation classes. | *required* |
| `ce_weight` | `float` | Weight for cross-entropy loss. | `1.0` |
| `dice_weight` | `float` | Weight for Dice loss. | `1.0` |
| `ignore_index` | `int` | Index to ignore in loss computation. | `255` |
| `class_weights` | `Tensor \| None` | Optional per-class weights for CE loss. | `None` |
__init__ ¶
__init__(num_classes: int, ce_weight: float = 1.0, dice_weight: float = 1.0, ignore_index: int = 255, class_weights: Tensor | None = None)
forward ¶
Compute combined loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Tensor` | Predicted logits `[B, C, H, W]`. | *required* |
| `targets` | `Tensor` | Ground truth masks `[B, H, W]` with class indices. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Combined loss value. |
FocalLossPixelwise¶
Focal loss for dense pixel-wise prediction. Down-weights easy examples to focus on hard pixels. Particularly effective for handling severe class imbalance in segmentation tasks.
autotimm.losses.FocalLossPixelwise ¶
Bases: Module
Focal loss for dense pixel-wise prediction.
Handles class imbalance by down-weighting easy examples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `alpha` | `float` | Weighting factor in [0, 1] to balance positive/negative examples. | `0.25` |
| `gamma` | `float` | Focusing parameter for modulating loss. | `2.0` |
| `ignore_index` | `int` | Index to ignore in loss computation. | `255` |
| `reduction` | `str` | Reduction method (`'mean'`, `'sum'`, `'none'`). | `'mean'` |
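The down-weighting comes from the factor α·(1 − p_t)^γ that focal loss applies to the per-pixel cross-entropy term, where p_t is the predicted probability of the true class. A sketch of that factor under the default α = 0.25, γ = 2.0 (illustrative only; the actual module operates on logit tensors):

```python
def focal_weight(p_t: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    # Scaling applied to the cross-entropy term -log(p_t) for one pixel:
    # confidently-correct (easy) pixels get a factor near zero.
    return alpha * (1.0 - p_t) ** gamma

easy = focal_weight(0.95)  # well-classified pixel: tiny weight
hard = focal_weight(0.10)  # misclassified pixel: much larger weight
print(easy, hard)
```

The ratio between hard and easy pixels here is over 300×, which is how the loss keeps rare-class pixels from being drowned out by the abundant easy background.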
TverskyLoss¶
Generalized Dice loss with configurable false positive/negative trade-off via alpha and beta parameters. Useful for highly imbalanced segmentation tasks.
autotimm.losses.TverskyLoss ¶
Bases: Module
Tversky loss for segmentation.
Generalization of Dice loss with separate control over false positives and false negatives.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_classes` | `int` | Number of segmentation classes. | *required* |
| `alpha` | `float` | Weight for false positives. | `0.5` |
| `beta` | `float` | Weight for false negatives. | `0.5` |
| `smooth` | `float` | Smoothing constant. | `1.0` |
| `ignore_index` | `int` | Index to ignore in loss computation. | `255` |
| `reduction` | `str` | Reduction method (`'mean'`, `'sum'`, `'none'`). | `'mean'` |
__init__ ¶
__init__(num_classes: int, alpha: float = 0.5, beta: float = 0.5, smooth: float = 1.0, ignore_index: int = 255, reduction: str = 'mean')
forward ¶
Compute Tversky loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Tensor` | Predicted logits `[B, C, H, W]`. | *required* |
| `targets` | `Tensor` | Ground truth masks `[B, H, W]` with class indices. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Tversky loss value. |
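The alpha/beta trade-off can be sketched on raw TP/FP/FN counts; with α = β = 0.5 the Tversky index reduces exactly to the Dice coefficient 2·TP / (2·TP + FP + FN). Illustrative only — the actual module consumes logits and target masks:

```python
def tversky_index(tp: float, fp: float, fn: float,
                  alpha: float = 0.5, beta: float = 0.5,
                  smooth: float = 0.0) -> float:
    # TI = (TP + s) / (TP + alpha*FP + beta*FN + s); the loss is 1 - TI.
    # Raising beta above alpha punishes false negatives harder, which
    # biases training toward higher recall on small foreground regions.
    return (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)

dice = 2 * 8 / (2 * 8 + 2 + 2)                 # Dice on TP=8, FP=2, FN=2
ti = tversky_index(8, 2, 2, alpha=0.5, beta=0.5)
print(dice, ti)                                 # both 0.8
```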
MaskLoss¶
Binary cross-entropy loss for instance segmentation masks. Used in conjunction with detection losses for end-to-end instance segmentation training.
autotimm.losses.MaskLoss ¶
Bases: Module
Binary cross-entropy loss for instance segmentation masks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `reduction` | `str` | Reduction method (`'mean'`, `'sum'`, `'none'`). | `'mean'` |
| `pos_weight` | `float \| None` | Weight for positive examples. | `None` |
forward ¶
Compute binary cross-entropy loss for masks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pred_masks` | `Tensor` | Predicted masks `[N, H, W]` or `[N, 1, H, W]` (logits). | *required* |
| `target_masks` | `Tensor` | Ground truth binary masks `[N, H, W]` or `[N, 1, H, W]`. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Mask loss value. |
Import Styles¶
AutoTimm supports multiple import styles for convenience:
Direct Imports¶
from autotimm import (
SemanticSegmentor,
InstanceSegmentor,
SegmentationDataModule,
DiceLoss,
DeepLabV3PlusHead,
)
Submodule Aliases¶
# Use singular form aliases
from autotimm.task import SemanticSegmentor, InstanceSegmentor
from autotimm.loss import DiceLoss, CombinedSegmentationLoss
from autotimm.head import DeepLabV3PlusHead, MaskHead
from autotimm.metric import MetricConfig
Original Imports¶
# Original plural form still works
from autotimm.tasks import SemanticSegmentor
from autotimm.losses import DiceLoss
from autotimm.heads import DeepLabV3PlusHead
from autotimm.core.metrics import MetricConfig
Package Alias¶
import autotimm as at # recommended alias
model = at.SemanticSegmentor(backbone="resnet50", num_classes=19)
loss = at.DiceLoss(num_classes=19)
Namespace Access¶
import autotimm as at # recommended alias
# Access via submodule aliases
model = at.task.SemanticSegmentor(...)
loss = at.losses.DiceLoss(...)
head = at.heads.DeepLabV3PlusHead(...)
Parameters Reference¶
SemanticSegmentor Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `backbone` | `str \| FeatureBackboneConfig` | Required | Timm model name or config |
| `num_classes` | `int` | Required | Number of segmentation classes |
| `head_type` | `str` | `"deeplabv3plus"` | Head architecture (`"deeplabv3plus"` or `"fcn"`) |
| `loss_type` | `str` | `"combined"` | Loss function type |
| `dice_weight` | `float` | `1.0` | Weight for Dice loss in combined mode |
| `ce_weight` | `float` | `1.0` | Weight for CE loss in combined mode |
| `ignore_index` | `int` | `255` | Index to ignore in loss computation |
| `metrics` | `list[MetricConfig]` | `None` | Metric configurations |
| `lr` | `float` | `1e-4` | Learning rate |
| `weight_decay` | `float` | `1e-4` | Weight decay |
| `optimizer` | `str \| dict` | `"adamw"` | Optimizer name or config |
| `optimizer_kwargs` | `dict \| None` | `None` | Extra optimizer kwargs |
| `scheduler` | `str \| dict` | `"cosine"` | Scheduler name or config |
| `scheduler_kwargs` | `dict \| None` | `None` | Extra scheduler kwargs |
| `freeze_backbone` | `bool` | `False` | Whether to freeze backbone parameters |
| `transform_config` | `TransformConfig \| None` | `None` | Transform config for preprocessing |
| `logging_config` | `LoggingConfig \| None` | `None` | Enhanced logging configuration |
| `loss_fn` | `str \| nn.Module \| None` | `None` | Custom loss function (string from registry, nn.Module, or None for default) |
| `class_weights` | `torch.Tensor \| None` | `None` | Class weights for loss computation |
| `compile_model` | `bool` | `True` | Apply `torch.compile()` for faster training/inference |
| `compile_kwargs` | `dict \| None` | `None` | Kwargs for `torch.compile()` |
| `seed` | `int \| None` | `None` | Random seed for reproducibility (None to disable) |
| `deterministic` | `bool` | `True` | Enable deterministic algorithms |
InstanceSegmentor Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `backbone` | `str \| FeatureBackboneConfig` | Required | Timm model name or config |
| `num_classes` | `int` | Required | Number of object classes (excluding background) |
| `fpn_channels` | `int` | `256` | Number of FPN channels |
| `head_num_convs` | `int` | `4` | Number of conv layers in detection head |
| `mask_size` | `int` | `28` | ROI mask resolution |
| `roi_pool_size` | `int` | `14` | ROI pooling output size |
| `mask_loss_weight` | `float` | `1.0` | Weight for mask loss |
| `focal_alpha` | `float` | `0.25` | Alpha parameter for focal loss |
| `focal_gamma` | `float` | `2.0` | Gamma parameter for focal loss |
| `cls_loss_weight` | `float` | `1.0` | Weight for classification loss |
| `reg_loss_weight` | `float` | `1.0` | Weight for regression loss |
| `centerness_loss_weight` | `float` | `1.0` | Weight for centerness loss |
| `score_thresh` | `float` | `0.05` | Score threshold for detections |
| `nms_thresh` | `float` | `0.5` | IoU threshold for NMS |
| `max_detections_per_image` | `int` | `100` | Maximum detections per image |
| `mask_threshold` | `float` | `0.5` | Threshold for binarizing predicted masks |
| `metrics` | `list[MetricConfig]` | `None` | Metric configurations |
| `logging_config` | `LoggingConfig` | `None` | Enhanced logging configuration |
| `lr` | `float` | `1e-4` | Learning rate |
| `weight_decay` | `float` | `1e-4` | Weight decay |
| `optimizer` | `str \| dict` | `"adamw"` | Optimizer name or config |
| `optimizer_kwargs` | `dict \| None` | `None` | Extra optimizer kwargs |
| `scheduler` | `str \| dict` | `"cosine"` | Scheduler name or config |
| `scheduler_kwargs` | `dict \| None` | `None` | Extra scheduler kwargs |
| `freeze_backbone` | `bool` | `False` | Whether to freeze backbone parameters |
| `transform_config` | `TransformConfig \| None` | `None` | Transform config for preprocessing |
| `cls_loss_fn` | `str \| nn.Module \| None` | `None` | Classification loss function |
| `reg_loss_fn` | `str \| nn.Module \| None` | `None` | Regression loss function |
| `mask_loss_fn` | `str \| nn.Module \| None` | `None` | Mask loss function |
| `compile_model` | `bool` | `True` | Apply `torch.compile()` for faster training/inference |
| `compile_kwargs` | `dict \| None` | `None` | Kwargs for `torch.compile()` |
| `seed` | `int \| None` | `None` | Random seed for reproducibility (None to disable) |
| `deterministic` | `bool` | `True` | Enable deterministic algorithms |
SegmentationDataModule Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| data_dir | str \| Path | Required | Root directory of dataset |
| format | str | "png" | Dataset format ("png", "coco", "cityscapes", "voc") |
| image_size | int | 512 | Target image size (square) |
| batch_size | int | 8 | Batch size |
| num_workers | int | 4 | Number of dataloader workers |
| augmentation_preset | str | "default" | Augmentation preset ("default", "strong", "light") |
| custom_train_transforms | Any | None | Custom training transforms (overrides preset) |
| custom_val_transforms | Any | None | Custom validation transforms |
| class_mapping | dict | None | Mapping from dataset class IDs to contiguous IDs |
| ignore_index | int | 255 | Index for ignored pixels (e.g., boundaries) |
| transform_config | TransformConfig \| None | None | Unified transform configuration |
| backbone | str \| nn.Module \| None | None | Backbone for model-specific normalization |
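The `class_mapping` parameter remaps sparse raw dataset IDs to the contiguous IDs the loss expects, with unmapped pixels falling back to `ignore_index`. A self-contained sketch of the idea (the `remap_mask` helper is hypothetical, not part of autotimm):

```python
def remap_mask(mask, class_mapping, ignore_index=255):
    """Remap raw dataset class IDs to contiguous training IDs.

    Pixels whose ID is not in the mapping become ignore_index, so they
    are excluded from the loss.
    """
    return [[class_mapping.get(pix, ignore_index) for pix in row] for row in mask]

# Cityscapes-style example: sparse raw IDs mapped to contiguous IDs 0..2
class_mapping = {7: 0, 8: 1, 11: 2}
raw = [[7, 8], [11, 99]]  # 99 is not a training class
print(remap_mask(raw, class_mapping))  # [[0, 1], [2, 255]]
```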
InstanceSegmentationDataModule Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| data_dir | str \| Path | Required | Root directory of COCO dataset |
| image_size | int | 640 | Target image size (square) |
| batch_size | int | 4 | Batch size |
| num_workers | int | 4 | Number of dataloader workers |
| augmentation_preset | str | "default" | Augmentation strength ("default", "strong", "light") |
| custom_train_transforms | Any | None | Custom training transforms (overrides preset) |
| custom_val_transforms | Any | None | Custom validation transforms |
| min_keypoints | int | 0 | Minimum keypoints for valid instance |
| min_area | float | 0.0 | Minimum area for valid instance |
| transform_config | TransformConfig \| None | None | Unified transform configuration |
| backbone | str \| nn.Module \| None | None | Backbone for model-specific normalization |
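`min_area` and `min_keypoints` drop degenerate instances before training. A sketch of that filtering over COCO-style annotation dicts (the `filter_instances` helper is hypothetical; autotimm's internal filtering may differ):

```python
def filter_instances(annotations, min_area=0.0, min_keypoints=0):
    """Keep only annotations meeting the area and keypoint thresholds."""
    keep = []
    for ann in annotations:
        if ann.get("area", 0.0) < min_area:
            continue
        # COCO stores keypoints flat as [x, y, visibility, ...];
        # count the visible ones (visibility > 0)
        kps = ann.get("keypoints", [])
        visible = sum(1 for v in kps[2::3] if v > 0)
        if visible < min_keypoints:
            continue
        keep.append(ann)
    return keep

anns = [
    {"id": 1, "area": 5.0, "keypoints": []},
    {"id": 2, "area": 120.0, "keypoints": [10, 10, 2, 0, 0, 0]},
]
print([a["id"] for a in filter_instances(anns, min_area=10.0)])  # [2]
```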
Examples¶
Basic Semantic Segmentation¶
```python
from autotimm import SemanticSegmentor, SegmentationDataModule, MetricConfig

# Setup data
data = SegmentationDataModule(
    data_dir="./data",
    format="png",
    image_size=512,
    batch_size=8,
)

# Create model
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=21,
    head_type="deeplabv3plus",
    loss_type="combined",
    ce_weight=1.0,
    dice_weight=1.0,
    metrics=[
        MetricConfig(
            name="iou",
            backend="torchmetrics",
            metric_class="JaccardIndex",
            params={"task": "multiclass", "num_classes": 21, "average": "macro"},
            stages=["val"],
            prog_bar=True,
        )
    ],
)
```
Instance Segmentation¶
```python
from autotimm import InstanceSegmentor, InstanceSegmentationDataModule, MetricConfig

# Setup data
data = InstanceSegmentationDataModule(
    data_dir="./coco",
    image_size=640,
    batch_size=4,
)

# Create model
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    mask_loss_weight=1.0,
    mask_size=28,
    roi_pool_size=14,
    metrics=[
        MetricConfig(
            name="mask_mAP",
            backend="torchmetrics",
            metric_class="MeanAveragePrecision",
            params={"box_format": "xyxy", "iou_type": "segm"},
            stages=["val"],
            prog_bar=True,
        )
    ],
)
```
Using Import Aliases¶
```python
from autotimm.task import SemanticSegmentor, InstanceSegmentor
from autotimm.loss import DiceLoss, CombinedSegmentationLoss, MaskLoss
from autotimm.head import DeepLabV3PlusHead, MaskHead
from autotimm.metric import MetricConfig

# Create a model using the aliases
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    loss_type="combined",
)

# Instantiate losses directly if needed
dice_loss = DiceLoss(num_classes=19, ignore_index=255)
combined_loss = CombinedSegmentationLoss(
    num_classes=19,
    ce_weight=1.0,
    dice_weight=1.0,
)
mask_loss = MaskLoss()
```
Custom Loss Configuration¶
```python
from autotimm import SemanticSegmentor
from autotimm.loss import TverskyLoss

# Create a custom Tversky loss for handling class imbalance
loss_fn = TverskyLoss(
    num_classes=19,
    alpha=0.3,  # lower alpha reduces the false-positive penalty
    beta=0.7,   # higher beta penalizes false negatives more (favors recall)
    ignore_index=255,
)

# Pass the custom loss (any nn.Module, or a registry name such as "dice")
# via the loss_fn parameter; loss_type is deprecated.
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_fn=loss_fn,
)
```
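For reference, the Tversky index underlying this loss is TI = TP / (TP + alpha·FP + beta·FN): alpha weights false positives and beta weights false negatives. A scalar sketch for a single binary class (illustrative only, not autotimm's implementation):

```python
def tversky_index(pred, target, alpha=0.3, beta=0.7, eps=1e-6):
    """Tversky index over flat binary prediction/target vectors."""
    tp = sum(p * t for p, t in zip(pred, target))        # true positives
    fp = sum(p * (1 - t) for p, t in zip(pred, target))  # false positives
    fn = sum((1 - p) * t for p, t in zip(pred, target))  # false negatives
    return (tp + eps) / (tp + alpha * fp + beta * fn + eps)

pred   = [1, 1, 0, 0]
target = [1, 0, 1, 0]
# tp=1, fp=1, fn=1 -> TI = 1 / (1 + 0.3 + 0.7) = 0.5
print(round(tversky_index(pred, target), 3))
```

With alpha = beta = 0.5 this reduces to the Dice coefficient; the Tversky loss itself is typically 1 − TI.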
Advanced Configuration¶
```python
from autotimm import SemanticSegmentor, LoggingConfig

# Model with enhanced logging
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    loss_type="combined",
    ce_weight=1.0,
    dice_weight=1.0,
    logging_config=LoggingConfig(
        log_learning_rate=True,
        log_gradient_norm=True,
        log_weight_norm=True,
    ),
    freeze_backbone=False,  # set True to freeze the backbone
)
```
With TransformConfig (Preprocessing)¶
Enable inference-time preprocessing with model-specific normalization:
```python
import torch
from PIL import Image

from autotimm import SemanticSegmentor, TransformConfig

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    metrics=[...],
    transform_config=TransformConfig(image_size=512),  # enable preprocess()
)

# Preprocess raw images for inference
image = Image.open("test.jpg")
tensor = model.preprocess(image)  # returns a preprocessed tensor

# Run inference
model.eval()
with torch.inference_mode():
    predictions = model(tensor)
```
Shared Config for DataModule and Model¶
```python
from autotimm import SemanticSegmentor, SegmentationDataModule, TransformConfig

# A shared config ensures identical preprocessing in both places
config = TransformConfig(preset="default", image_size=512)
backbone_name = "resnet50"

# DataModule uses the model's normalization
data = SegmentationDataModule(
    data_dir="./cityscapes",
    format="cityscapes",
    transform_config=config,
    backbone=backbone_name,
)

# Model uses the same config for inference preprocessing
model = SemanticSegmentor(
    backbone=backbone_name,
    num_classes=19,
    metrics=[...],
    transform_config=config,
)
```
Instance Segmentation with Preprocessing¶
```python
from autotimm import InstanceSegmentor, TransformConfig

model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    metrics=[...],
    transform_config=TransformConfig(image_size=640),
)

# Get the model's normalization config
config = model.get_data_config()
print(f"Mean: {config['mean']}")
print(f"Std: {config['std']}")
```
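The `mean` and `std` returned by `get_data_config()` are per-channel normalization constants, applied as `(x - mean) / std`. A minimal sketch using the standard ImageNet values (the actual constants come from the backbone's pretrained config, so treat these numbers as an example):

```python
# ImageNet normalization constants, common for timm backbones (assumed here)
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Normalize one RGB pixel already scaled to [0, 1]."""
    return tuple((c - m) / s for c, m, s in zip(rgb, mean, std))

# A pixel exactly at the dataset mean maps to all zeros
print(normalize_pixel((0.485, 0.456, 0.406)))
```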