Loss Function Comparison¶
AutoTimm provides specialized loss functions for classification, object detection, and segmentation tasks. This guide compares the available losses and helps you choose the right one for your use case.
Loss Function Registry¶
AutoTimm includes a centralized loss registry that makes it easy to discover, access, and use built-in loss functions across all tasks.
Quick Start¶
```python
from autotimm.losses import list_available_losses, get_loss_registry

# List all available losses
print(list_available_losses())
# Output: ['bce', 'bce_with_logits', 'centerness', 'combined_segmentation',
#          'cross_entropy', 'dice', 'fcos', 'focal', 'focal_pixelwise',
#          'giou', 'mask', 'mse', 'nll', 'tversky']

# List losses by task
print(list_available_losses(task="segmentation"))
# Output: ['combined_segmentation', 'dice', 'focal_pixelwise', 'mask', 'tversky']

# Get a loss from the registry
registry = get_loss_registry()
dice_loss = registry.get_loss("dice", num_classes=10)
```
Using Losses in Models¶
You can pass a loss function to a model in three ways:
1. By Name (from Registry)
```python
from autotimm import SemanticSegmentor

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_fn="dice",  # Use a registered loss by name
)
```
2. Custom Loss Instance
```python
import torch.nn as nn

class MyCustomLoss(nn.Module):
    def forward(self, input, target):
        # Your custom loss logic, e.g. a simple L2 penalty:
        return (input - target).pow(2).mean()

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_fn=MyCustomLoss(),  # Pass an instance
)
```
3. From Registry with Parameters
```python
from autotimm.losses import get_loss_registry

registry = get_loss_registry()
dice_loss = registry.get_loss("dice", num_classes=19, smooth=2.0)

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_fn=dice_loss,
)
```
Available Losses by Task¶
Classification:
- cross_entropy - Standard cross-entropy loss
- bce / bce_with_logits - Binary cross-entropy for multi-label
- nll - Negative log likelihood
- mse - Mean squared error
Detection:
- focal - Focal loss for classification
- giou - Generalized IoU for box regression
- centerness - Centerness loss (FCOS)
- fcos - Combined FCOS loss
Segmentation:
- dice - Dice loss for class overlap
- focal_pixelwise - Pixel-wise focal loss
- tversky - Tversky loss (generalized Dice)
- mask - Binary mask loss
- combined_segmentation - Combined CE + Dice
Registering Custom Losses¶
```python
from autotimm import ImageClassifier
from autotimm.losses import register_custom_loss
import torch.nn as nn

class MyWeightedLoss(nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, input, target):
        return (input - target).abs().mean() * self.weights

# Register globally
register_custom_loss("weighted_loss", MyWeightedLoss, alias="wl")

# Now you can use it by name
model = ImageClassifier(
    backbone="resnet18",
    num_classes=10,
    loss_fn="weighted_loss",  # or "wl"
)
```
Detection Losses¶
AutoTimm implements FCOS-style detection losses optimized for anchor-free object detection.
FocalLoss¶
Focal Loss addresses class imbalance by down-weighting well-classified examples and focusing on hard negatives.
```python
from autotimm import FocalLoss

loss_fn = FocalLoss(
    alpha=0.25,        # Weight for positive examples
    gamma=2.0,         # Focusing parameter (higher = more focus on hard examples)
    reduction="mean",  # "mean", "sum", or "none"
)

# Usage
loss = loss_fn(predictions, targets)
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `alpha` | float | 0.25 | Weighting factor for positive examples. Higher values give more weight to positives. |
| `gamma` | float | 2.0 | Focusing parameter. Higher values increase focus on hard examples. |
| `reduction` | str | `"mean"` | Reduction method: `"none"`, `"mean"`, or `"sum"` |
When to Use:
- Class imbalanced detection datasets
- When background overwhelms foreground examples
- Standard choice for anchor-free detectors
Tuning Guidelines:
- `gamma=2.0`: Standard for most cases
- `gamma=1.5`: Less aggressive focusing (for balanced datasets)
- `gamma=3.0`: More aggressive (for severe imbalance)
- `alpha=0.25`: Standard for detection
- `alpha=0.5`: Equal weighting
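To make the focusing behavior concrete, here is a minimal, framework-free sketch of the standard focal loss formula for a single prediction (illustrative only, not AutoTimm's implementation):

```python
import math

def binary_focal_loss(p, target, alpha=0.25, gamma=2.0):
    """Focal loss for one predicted probability p and a binary target:
    FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
    """
    p_t = p if target == 1 else 1.0 - p
    alpha_t = alpha if target == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# A well-classified positive (p=0.9) is heavily down-weighted by (1 - p_t)^gamma ...
easy = binary_focal_loss(0.9, 1)
# ... while a hard positive (p=0.1) keeps most of its cross-entropy magnitude.
hard = binary_focal_loss(0.1, 1)
```

With `alpha=1.0` and `gamma=0.0` the formula reduces to plain cross-entropy, which is a quick sanity check when experimenting with these knobs.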
GIoULoss¶
Generalized IoU Loss for bounding box regression. Provides better gradients than standard IoU loss when boxes don't overlap.
```python
from autotimm import GIoULoss

loss_fn = GIoULoss(
    reduction="mean",
    eps=1e-7,
)

# Usage: boxes in (x1, y1, x2, y2) format
loss = loss_fn(pred_boxes, target_boxes)
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `reduction` | str | `"mean"` | Reduction method: `"none"`, `"mean"`, or `"sum"` |
| `eps` | float | 1e-7 | Small value for numerical stability |
When to Use:
- Bounding box regression in detection tasks
- When boxes may not overlap (early training)
- Better convergence than L1/L2 losses for boxes
Comparison with Other Box Losses:
| Loss | Handles Non-Overlap | Scale Invariant | Typical Use |
|---|---|---|---|
| L1/L2 Loss | Yes | No | Fast, simple |
| IoU Loss | No | Yes | Overlapping boxes |
| GIoU Loss | Yes | Yes | General detection |
| DIoU Loss | Yes | Yes | Center-aware |
| CIoU Loss | Yes | Yes | Aspect ratio-aware |
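To see why GIoU still provides a training signal when boxes do not overlap, here is a minimal sketch of the standard GIoU formula for two axis-aligned `(x1, y1, x2, y2)` boxes (illustrative only, not AutoTimm's implementation):

```python
def giou(box_a, box_b):
    """Generalized IoU: IoU minus the fraction of the smallest
    enclosing box that covers neither input box."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    # Intersection and union
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = iw * ih
    union = area_a + area_b - inter
    # Smallest enclosing (hull) box
    cw = max(ax2, bx2) - min(ax1, bx1)
    ch = max(ay2, by2) - min(ay1, by1)
    hull = cw * ch
    return inter / union - (hull - union) / hull

# Non-overlapping boxes: IoU is 0 for both pairs, but GIoU still
# decreases with distance, so 1 - GIoU gives a usable gradient.
near = giou((0, 0, 1, 1), (1.5, 0, 2.5, 1))
far = giou((0, 0, 1, 1), (3.0, 0, 4.0, 1))
```

Perfectly matching boxes score GIoU = 1, and the loss used in training is `1 - GIoU`.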
CenternessLoss¶
Binary cross-entropy loss for FCOS centerness prediction. Centerness helps suppress low-quality predictions far from object centers.
```python
from autotimm import CenternessLoss

loss_fn = CenternessLoss(reduction="mean")

# Usage: centerness values in [0, 1]
loss = loss_fn(pred_centerness, target_centerness)
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `reduction` | str | `"mean"` | Reduction method: `"none"`, `"mean"`, or `"sum"` |
When to Use:
- FCOS-style anchor-free detection
- Quality prediction for bounding boxes
- Combined with classification and regression losses
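For intuition, FCOS derives the centerness target from a location's distances `(l, t, r, b)` to the four sides of its assigned box. A minimal sketch of that standard formula (illustrative only; target generation details in AutoTimm may differ):

```python
import math

def centerness_target(l, t, r, b):
    """FCOS centerness: sqrt(min(l,r)/max(l,r) * min(t,b)/max(t,b)).
    Equals 1 at the box center and decays toward 0 near the edges."""
    return math.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))

# A location at the exact center of a box scores 1.0 ...
center = centerness_target(5, 5, 5, 5)
# ... while a location near the left edge scores much lower,
# so its low-quality box prediction is suppressed at NMS time.
edge = centerness_target(1, 5, 9, 5)
```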
FCOSLoss¶
Combined FCOS loss that integrates classification, regression, and centerness losses.
```python
from autotimm import FCOSLoss

loss_fn = FCOSLoss(
    num_classes=80,
    focal_alpha=0.25,
    focal_gamma=2.0,
    cls_weight=1.0,
    reg_weight=1.0,
    centerness_weight=1.0,
)

# Usage
losses = loss_fn(
    cls_preds, reg_preds, centerness_preds,
    cls_targets, reg_targets, centerness_targets,
)
# Returns: {"cls_loss", "reg_loss", "centerness_loss", "total_loss"}
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_classes` | int | Required | Number of object classes (excluding background) |
| `focal_alpha` | float | 0.25 | Alpha for focal loss |
| `focal_gamma` | float | 2.0 | Gamma for focal loss |
| `cls_weight` | float | 1.0 | Weight for classification loss |
| `reg_weight` | float | 1.0 | Weight for regression loss |
| `centerness_weight` | float | 1.0 | Weight for centerness loss |
When to Use:
- Complete FCOS training pipeline
- Anchor-free object detection
- Multi-scale detection
Tuning Guidelines:
- Default weights (1:1:1) work well for most cases
- Increase `reg_weight` if localization is poor
- Increase `cls_weight` if classification accuracy is low
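Since the total is just a weighted sum of the three terms, reweighting is easy to reason about. A minimal sketch of how the weights combine (illustrative only, not AutoTimm's code):

```python
def fcos_total(cls_loss, reg_loss, centerness_loss,
               cls_weight=1.0, reg_weight=1.0, centerness_weight=1.0):
    """Weighted total in the style of FCOSLoss's returned "total_loss"."""
    return (cls_weight * cls_loss
            + reg_weight * reg_loss
            + centerness_weight * centerness_loss)

# Poor localization? Upweighting the regression term shifts training
# effort toward box quality without touching the other two terms.
default = fcos_total(0.5, 0.25, 0.125)
reg_boosted = fcos_total(0.5, 0.25, 0.125, reg_weight=2.0)
```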
Detection Loss Comparison¶
| Loss | Purpose | Pros | Cons | Best For |
|---|---|---|---|---|
| FocalLoss | Classification | Handles imbalance, proven performance | Sensitive to alpha/gamma | Class-imbalanced datasets |
| GIoULoss | Box regression | Scale invariant, handles non-overlap | Slightly slower than L1 | General detection |
| CenternessLoss | Quality prediction | Improves NMS, filters poor boxes | Requires centerness targets | FCOS-style detectors |
| FCOSLoss | Combined | All-in-one, balanced training | Fixed architecture | Complete FCOS training |
Segmentation Losses¶
AutoTimm provides losses for both semantic and instance segmentation tasks.
DiceLoss¶
Dice Loss measures overlap between predicted and ground truth masks. Effective for imbalanced segmentation.
```python
from autotimm import DiceLoss

loss_fn = DiceLoss(
    num_classes=21,
    smooth=1.0,
    ignore_index=255,
    reduction="mean",
)

# Usage
loss = loss_fn(logits, targets)  # [B, C, H, W], [B, H, W]
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_classes` | int | Required | Number of segmentation classes |
| `smooth` | float | 1.0 | Smoothing constant to avoid division by zero |
| `ignore_index` | int | 255 | Index to ignore in loss computation |
| `reduction` | str | `"mean"` | Reduction method: `"none"`, `"mean"`, or `"sum"` |
When to Use:
- Class-imbalanced segmentation
- Small object segmentation
- Medical image segmentation
Tuning Guidelines:
- `smooth=1.0`: Standard, prevents NaN for empty masks
- `smooth=0.01`: Less smoothing, sharper gradients
- `ignore_index=255`: Standard for Pascal VOC/Cityscapes
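The role of `smooth` is easiest to see in the scalar Dice formula. A minimal sketch on flat binary masks (illustrative only, not AutoTimm's implementation):

```python
def dice_loss(pred, target, smooth=1.0):
    """1 - Dice coefficient over flat binary masks (lists of 0/1 values):
    1 - (2*|A ∩ B| + smooth) / (|A| + |B| + smooth)
    """
    inter = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return 1.0 - (2.0 * inter + smooth) / (total + smooth)

# Perfect overlap -> loss is exactly 0.0
perfect = dice_loss([1, 1, 0, 0], [1, 1, 0, 0])
# Both masks empty: without smoothing this would be 0/0;
# smooth=1.0 keeps it defined (and exactly 0.0, i.e. no penalty).
empty = dice_loss([0, 0, 0, 0], [0, 0, 0, 0])
```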
FocalLossPixelwise¶
Pixel-wise focal loss for dense prediction. Handles class imbalance at the pixel level.
```python
from autotimm import FocalLossPixelwise

loss_fn = FocalLossPixelwise(
    alpha=0.25,
    gamma=2.0,
    ignore_index=255,
    reduction="mean",
)

# Usage
loss = loss_fn(logits, targets)  # [B, C, H, W], [B, H, W]
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `alpha` | float | 0.25 | Weighting factor for positive/rare classes |
| `gamma` | float | 2.0 | Focusing parameter |
| `ignore_index` | int | 255 | Index to ignore in loss computation |
| `reduction` | str | `"mean"` | Reduction method |
When to Use:
- Severely imbalanced pixel distributions
- When some classes are much rarer than others
- Scene parsing with many small objects
TverskyLoss¶
Generalization of Dice loss with separate control over false positives and false negatives.
```python
from autotimm import TverskyLoss

loss_fn = TverskyLoss(
    num_classes=21,
    alpha=0.5,   # Weight for false positives
    beta=0.5,    # Weight for false negatives
    smooth=1.0,
    ignore_index=255,
    reduction="mean",
)

# Usage
loss = loss_fn(logits, targets)
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_classes` | int | Required | Number of segmentation classes |
| `alpha` | float | 0.5 | Weight for false positives |
| `beta` | float | 0.5 | Weight for false negatives |
| `smooth` | float | 1.0 | Smoothing constant |
| `ignore_index` | int | 255 | Index to ignore |
| `reduction` | str | `"mean"` | Reduction method |
When to Use:
- Control precision vs recall trade-off
- Medical imaging (minimize false negatives)
- Small object detection (minimize false positives)
Tuning Guidelines:
| Setting | alpha | beta | Effect |
|---|---|---|---|
| Dice Loss (equal) | 0.5 | 0.5 | Balanced |
| Precision focus | 0.7 | 0.3 | Penalize false positives more |
| Recall focus | 0.3 | 0.7 | Penalize false negatives more |
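The `alpha`/`beta` trade-off can be checked with a toy example. A minimal sketch of the Tversky index on flat binary masks (illustrative only, not AutoTimm's implementation):

```python
def tversky_loss(pred, target, alpha=0.5, beta=0.5, smooth=1.0):
    """1 - Tversky index: TI = (TP + s) / (TP + alpha*FP + beta*FN + s)."""
    tp = sum(p * t for p, t in zip(pred, target))
    fp = sum(p * (1 - t) for p, t in zip(pred, target))
    fn = sum((1 - p) * t for p, t in zip(pred, target))
    return 1.0 - (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)

# This prediction has one false positive and no false negatives:
pred, target = [1, 1, 1, 0], [1, 1, 0, 0]
# alpha=0.7 (precision focus) penalizes the false positive more
# than alpha=0.3 (recall focus) does, so its loss comes out higher.
precision_focused = tversky_loss(pred, target, alpha=0.7, beta=0.3)
recall_focused = tversky_loss(pred, target, alpha=0.3, beta=0.7)
```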
MaskLoss¶
Binary cross-entropy loss for instance segmentation masks.
```python
from autotimm import MaskLoss

loss_fn = MaskLoss(
    reduction="mean",
    pos_weight=1.0,  # Weight for positive pixels
)

# Usage: per-instance binary masks
loss = loss_fn(pred_masks, target_masks)  # [N, H, W], [N, H, W]
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `reduction` | str | `"mean"` | Reduction method |
| `pos_weight` | float | None | Weight for positive (foreground) pixels |
When to Use:
- Instance segmentation tasks
- Per-object binary mask prediction
- Mask R-CNN style architectures
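The effect of `pos_weight` is just an extra multiplier on the foreground term of the binary cross-entropy. A minimal scalar sketch (illustrative only, not AutoTimm's implementation):

```python
import math

def weighted_bce(p, target, pos_weight=1.0):
    """Binary cross-entropy for one pixel, with the positive
    (foreground) term scaled by pos_weight."""
    return -(pos_weight * target * math.log(p)
             + (1 - target) * math.log(1 - p))

# pos_weight > 1 makes under-confident foreground pixels more costly,
# which helps when each object's mask covers only a small part of its crop.
plain = weighted_bce(0.3, 1)
boosted = weighted_bce(0.3, 1, pos_weight=2.0)  # same pixel, doubled penalty
```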
CombinedSegmentationLoss¶
Combines cross-entropy and Dice loss for robust semantic segmentation training.
```python
from autotimm import CombinedSegmentationLoss
import torch

loss_fn = CombinedSegmentationLoss(
    num_classes=21,
    ce_weight=1.0,
    dice_weight=1.0,
    ignore_index=255,
    class_weights=None,  # Optional per-class weights
)

# With class weights for imbalanced data
class_weights = torch.tensor([1.0, 2.0, 2.0, ...])  # One per class
loss_fn = CombinedSegmentationLoss(
    num_classes=21,
    ce_weight=1.0,
    dice_weight=1.0,
    class_weights=class_weights,
)

# Usage
loss = loss_fn(logits, targets)
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_classes` | int | Required | Number of segmentation classes |
| `ce_weight` | float | 1.0 | Weight for cross-entropy loss |
| `dice_weight` | float | 1.0 | Weight for Dice loss |
| `ignore_index` | int | 255 | Index to ignore |
| `class_weights` | Tensor | None | Per-class weights for CE loss |
When to Use:
- Default choice for semantic segmentation
- Combines pixel-wise accuracy (CE) with region overlap (Dice)
- Robust to class imbalance
Tuning Guidelines:
- Start with equal weights (1:1)
- Increase `dice_weight` for better IoU scores
- Increase `ce_weight` for better pixel accuracy
- Use `class_weights` for severely imbalanced datasets
Segmentation Loss Comparison¶
| Loss | Handles Imbalance | Best For | Typical Use Case |
|---|---|---|---|
| CrossEntropy | No (without weights) | Balanced datasets | General segmentation |
| DiceLoss | Yes | Region overlap, small objects | Medical imaging |
| FocalLossPixelwise | Yes | Severe pixel imbalance | Scene parsing |
| TverskyLoss | Yes | Precision/recall control | Domain-specific needs |
| MaskLoss | Optional (pos_weight) | Binary masks | Instance segmentation |
| CombinedSegmentationLoss | Yes | General semantic segmentation | Default choice |
Task-Specific Recommendations¶
Image Classification¶
Use CrossEntropyLoss (built into task classes):
```python
from autotimm import ImageClassifier

# CrossEntropy is used automatically
model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
)
```
Object Detection¶
```python
from autotimm import ObjectDetector

# FCOSLoss is used automatically with these defaults
model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    # Loss weights can be adjusted via model parameters if needed
)
```
Semantic Segmentation¶
```python
from autotimm import SemanticSegmentor, CombinedSegmentationLoss

# Default: Combined CE + Dice
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=21,
    metrics=metrics,
)

# Custom loss configuration
loss_fn = CombinedSegmentationLoss(
    num_classes=21,
    ce_weight=0.5,
    dice_weight=1.5,  # Emphasize Dice for better IoU
)
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=21,
    metrics=metrics,
    loss_fn=loss_fn,
)
```
Instance Segmentation¶
```python
from autotimm import InstanceSegmentor

# Uses FCOSLoss for detection + MaskLoss for masks
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
)
```
Complete Example: Custom Loss Configuration¶
```python
from autotimm import (
    AutoTrainer,
    CombinedSegmentationLoss,
    LoggerConfig,
    MetricConfig,
    SegmentationDataModule,
    SemanticSegmentor,
)
import torch


def main():
    # Data
    data = SegmentationDataModule(
        data_dir="./cityscapes",
        image_size=512,
        batch_size=8,
    )

    # Metrics
    metrics = [
        MetricConfig(
            name="iou",
            backend="torchmetrics",
            metric_class="JaccardIndex",
            params={"task": "multiclass", "num_classes": 19},
            stages=["val", "test"],
            prog_bar=True,
        ),
    ]

    # Class weights for Cityscapes (example: common classes down-weighted,
    # rare classes up-weighted)
    class_weights = torch.ones(19)
    class_weights[0] = 0.5   # road (common)
    class_weights[1] = 0.5   # sidewalk (common)
    class_weights[11] = 2.0  # person (less common)
    class_weights[12] = 2.0  # rider (rare)

    # Custom loss
    loss_fn = CombinedSegmentationLoss(
        num_classes=19,
        ce_weight=1.0,
        dice_weight=1.5,
        class_weights=class_weights,
    )

    # Model
    model = SemanticSegmentor(
        backbone="resnet50",
        num_classes=19,
        metrics=metrics,
        loss_fn=loss_fn,
        lr=1e-4,
    )

    # Trainer
    trainer = AutoTrainer(
        max_epochs=50,
        logger=[LoggerConfig(backend="tensorboard", params={"save_dir": "logs"})],
        checkpoint_monitor="val/iou",
        checkpoint_mode="max",
        gradient_clip_val=1.0,
    )
    trainer.fit(model, datamodule=data)


if __name__ == "__main__":
    main()
```
See Also¶
- Training Guide - Complete training documentation
- Metric Selection - Choosing metrics for your task
- API Reference: Losses - Full API documentation