Loss Functions

Loss functions for object detection tasks.

FocalLoss

Focal Loss for addressing class imbalance in one-stage object detectors.

Overview

Focal Loss reshapes the standard cross-entropy loss to down-weight easy examples and focus training on hard negatives. This is crucial for one-stage detectors where there's extreme imbalance between foreground and background classes.

Formula:

FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

where:
  p_t = p if y = 1, else 1 - p
  α_t = α if y = 1, else 1 - α
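A quick numeric check of the formula in plain Python (illustrative only, independent of the autotimm implementation documented below):

```python
import math

def focal_loss(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Per-example focal loss, computed directly from the formula above."""
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

easy = focal_loss(0.95, 1)  # confident, correct positive
hard = focal_loss(0.10, 1)  # badly misclassified positive
print(easy, hard)  # the easy example contributes far less loss
```

With γ = 0 the focal term vanishes and the function reduces to α-weighted cross-entropy.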

API Reference

autotimm.FocalLoss

Bases: Module

Focal Loss for addressing class imbalance in object detection.

Focal Loss reduces the relative loss for well-classified examples, focusing training on hard negatives.

FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

Parameters:

Name Type Description Default
alpha float

Weighting factor for positive examples. Default 0.25.

0.25
gamma float

Focusing parameter. Higher values give more weight to hard examples. Default 2.0.

2.0
reduction str

Reduction method: 'none', 'mean', or 'sum'.

'mean'
Source code in src/autotimm/losses/detection.py
class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance in object detection.

    Focal Loss reduces the relative loss for well-classified examples,
    focusing training on hard negatives.

    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    Parameters:
        alpha: Weighting factor for positive examples. Default 0.25.
        gamma: Focusing parameter. Higher values give more weight to hard
            examples. Default 2.0.
        reduction: Reduction method: 'none', 'mean', or 'sum'.
    """

    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        reduction: str = "mean",
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """Compute focal loss.

        Args:
            inputs: Predicted logits of shape [N, C] or [N, C, H, W].
            targets: Ground truth class indices of shape [N] or [N, H, W].
                Use -1 to ignore samples.

        Returns:
            Focal loss value.
        """
        # Handle spatial dimensions
        if inputs.dim() == 4:
            # [N, C, H, W] -> [N, H, W, C] -> [N*H*W, C]
            n, c, h, w = inputs.shape
            inputs = inputs.permute(0, 2, 3, 1).reshape(-1, c)
            targets = targets.reshape(-1)

        # Filter out ignored samples (targets == -1)
        valid_mask = targets >= 0
        if not valid_mask.any():
            return inputs.sum() * 0.0

        inputs = inputs[valid_mask]
        targets = targets[valid_mask]

        # Compute probabilities
        p = torch.sigmoid(inputs)
        num_classes = inputs.shape[-1]

        # Create one-hot encoded targets
        targets_one_hot = F.one_hot(targets, num_classes).float()

        # Compute focal weights
        pt = p * targets_one_hot + (1 - p) * (1 - targets_one_hot)
        focal_weight = (1 - pt) ** self.gamma

        # Compute binary cross entropy
        bce = F.binary_cross_entropy_with_logits(
            inputs, targets_one_hot, reduction="none"
        )

        # Apply focal weight and alpha
        alpha_t = self.alpha * targets_one_hot + (1 - self.alpha) * (
            1 - targets_one_hot
        )
        focal_loss = alpha_t * focal_weight * bce

        # Sum over classes, then reduce
        focal_loss = focal_loss.sum(dim=-1)

        if self.reduction == "mean":
            return focal_loss.mean()
        elif self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss

__init__

__init__(alpha: float = 0.25, gamma: float = 2.0, reduction: str = 'mean')
Source code in src/autotimm/losses/detection.py
def __init__(
    self,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
):
    super().__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.reduction = reduction

forward

forward(inputs: Tensor, targets: Tensor) -> torch.Tensor

Compute focal loss.

Parameters:

Name Type Description Default
inputs Tensor

Predicted logits of shape [N, C] or [N, C, H, W].

required
targets Tensor

Ground truth class indices of shape [N] or [N, H, W]. Use -1 to ignore samples.

required

Returns:

Type Description
Tensor

Focal loss value.

Source code in src/autotimm/losses/detection.py
def forward(
    self,
    inputs: torch.Tensor,
    targets: torch.Tensor,
) -> torch.Tensor:
    """Compute focal loss.

    Args:
        inputs: Predicted logits of shape [N, C] or [N, C, H, W].
        targets: Ground truth class indices of shape [N] or [N, H, W].
            Use -1 to ignore samples.

    Returns:
        Focal loss value.
    """
    # Handle spatial dimensions
    if inputs.dim() == 4:
        # [N, C, H, W] -> [N, H, W, C] -> [N*H*W, C]
        n, c, h, w = inputs.shape
        inputs = inputs.permute(0, 2, 3, 1).reshape(-1, c)
        targets = targets.reshape(-1)

    # Filter out ignored samples (targets == -1)
    valid_mask = targets >= 0
    if not valid_mask.any():
        return inputs.sum() * 0.0

    inputs = inputs[valid_mask]
    targets = targets[valid_mask]

    # Compute probabilities
    p = torch.sigmoid(inputs)
    num_classes = inputs.shape[-1]

    # Create one-hot encoded targets
    targets_one_hot = F.one_hot(targets, num_classes).float()

    # Compute focal weights
    pt = p * targets_one_hot + (1 - p) * (1 - targets_one_hot)
    focal_weight = (1 - pt) ** self.gamma

    # Compute binary cross entropy
    bce = F.binary_cross_entropy_with_logits(
        inputs, targets_one_hot, reduction="none"
    )

    # Apply focal weight and alpha
    alpha_t = self.alpha * targets_one_hot + (1 - self.alpha) * (
        1 - targets_one_hot
    )
    focal_loss = alpha_t * focal_weight * bce

    # Sum over classes, then reduce
    focal_loss = focal_loss.sum(dim=-1)

    if self.reduction == "mean":
        return focal_loss.mean()
    elif self.reduction == "sum":
        return focal_loss.sum()
    return focal_loss

Usage Example

from autotimm import FocalLoss
import torch

loss_fn = FocalLoss(
    alpha=0.25,
    gamma=2.0,
)

# Predictions and targets
pred = torch.randn(32, 80, 100, 100)  # (B, C, H, W)
target = torch.randint(0, 80, (32, 100, 100))  # (B, H, W)

loss = loss_fn(pred, target)

Parameters

Parameter Type Default Description
alpha float 0.25 Weighting factor for class 1 (foreground)
gamma float 2.0 Focusing parameter for hard examples
reduction str "mean" Loss reduction: "mean", "sum", or "none"

How It Works

Alpha (α):

  • Controls the relative importance of positive vs negative examples
  • α = 0.25 means positive examples get 25% weight, negatives get 75%
  • Higher α → focus more on positives (foreground)
  • Standard value: 0.25

Gamma (γ):

  • Controls how much to down-weight easy examples
  • γ = 0: Equivalent to standard cross-entropy
  • γ = 2: Standard focal loss (down-weights easy examples significantly)
  • Higher γ → focuses more on hard examples
  • Standard value: 2.0

Effect:

  • Easy examples (high confidence, correct predictions): Very low loss
  • Hard examples (low confidence): High loss
  • Result: Model focuses on learning hard negatives
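The down-weighting comes entirely from the modulating factor (1 - p_t)^γ. A short sketch of how it scales for a well-classified example (p_t = 0.9):

```python
# Modulating factor (1 - p_t)^gamma for a well-classified example, p_t = 0.9
p_t = 0.9
for gamma in (0.0, 0.5, 1.0, 2.0, 5.0):
    print(f"gamma={gamma}: weight={(1 - p_t) ** gamma:.5f}")
# gamma=0 keeps full weight (plain cross-entropy); gamma=2 cuts it ~100x
```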

When to Use

Use Focal Loss when:

  • Extreme class imbalance (e.g., 1:1000 positive:negative ratio in detection)
  • One-stage detectors (FCOS, RetinaNet, YOLO)
  • Many easy negatives dominate the loss

Alternatives:

  • Balanced Cross Entropy: Simpler, use when imbalance is moderate
  • Weighted BCE: Manual class weights

Tuning Guidelines

Scenario Alpha Gamma Reason
Standard detection 0.25 2.0 Default, works for most cases
More positives 0.5 2.0 Less imbalance
Fewer positives 0.1 2.0 Extreme imbalance
Too many false positives 0.25 3.0 Focus even more on hard negatives
Missing detections 0.5 1.0 Easier on hard examples

GIoULoss

Generalized Intersection over Union (GIoU) Loss for bounding box regression.

Overview

GIoU Loss extends IoU to handle non-overlapping boxes and provides better gradient flow. Unlike L1 or L2 loss, GIoU is scale-invariant and directly optimizes the detection metric.

Formula:

IoU = |A ∩ B| / |A ∪ B|
GIoU = IoU - |C \ (A ∪ B)| / |C|

where:
  A, B are predicted and target boxes
  C is the smallest enclosing box

GIoU Loss = 1 - GIoU
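The formula can be sanity-checked with a small framework-free helper (illustrative only; the library implementation below operates on batched tensors):

```python
def giou(box_a, box_b):
    """GIoU for two axis-aligned boxes in (x1, y1, x2, y2) format."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    # Intersection and union
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    # Smallest enclosing box C
    enclose = (max(ax2, bx2) - min(ax1, bx1)) * (max(ay2, by2) - min(ay1, by1))
    return inter / union - (enclose - union) / enclose

print(giou((0, 0, 2, 2), (0, 0, 2, 2)))  # 1.0: perfect match
print(giou((0, 0, 1, 1), (3, 0, 4, 1)))  # -0.5: disjoint boxes still get a signal
```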

API Reference

autotimm.GIoULoss

Bases: Module

Generalized IoU Loss for bounding box regression.

GIoU provides better gradients than standard IoU loss when boxes don't overlap, making it more suitable for training.

GIoU = IoU - (C - U) / C

where C is the smallest enclosing box and U is the union.

Parameters:

Name Type Description Default
reduction str

Reduction method: 'none', 'mean', or 'sum'.

'mean'
eps float

Small value for numerical stability.

1e-07
Source code in src/autotimm/losses/detection.py
class GIoULoss(nn.Module):
    """Generalized IoU Loss for bounding box regression.

    GIoU provides better gradients than standard IoU loss when boxes
    don't overlap, making it more suitable for training.

    GIoU = IoU - (C - U) / C

    where C is the smallest enclosing box and U is the union.

    Parameters:
        reduction: Reduction method: 'none', 'mean', or 'sum'.
        eps: Small value for numerical stability.
    """

    def __init__(self, reduction: str = "mean", eps: float = 1e-7):
        super().__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(
        self,
        pred_boxes: torch.Tensor,
        target_boxes: torch.Tensor,
        weights: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute GIoU loss.

        Args:
            pred_boxes: Predicted boxes [N, 4] in (x1, y1, x2, y2) format.
            target_boxes: Target boxes [N, 4] in (x1, y1, x2, y2) format.
            weights: Optional per-box weights [N].

        Returns:
            GIoU loss value.
        """
        if pred_boxes.numel() == 0:
            return pred_boxes.sum() * 0.0

        giou = self._compute_giou(pred_boxes, target_boxes)
        loss = 1 - giou

        if weights is not None:
            loss = loss * weights

        if self.reduction == "mean":
            return (
                loss.mean()
                if weights is None
                else loss.sum() / weights.sum().clamp(min=self.eps)
            )
        elif self.reduction == "sum":
            return loss.sum()
        return loss

    def _compute_giou(self, boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute GIoU between two sets of boxes."""
        # Intersection
        inter_x1 = torch.max(boxes1[:, 0], boxes2[:, 0])
        inter_y1 = torch.max(boxes1[:, 1], boxes2[:, 1])
        inter_x2 = torch.min(boxes1[:, 2], boxes2[:, 2])
        inter_y2 = torch.min(boxes1[:, 3], boxes2[:, 3])

        inter_w = (inter_x2 - inter_x1).clamp(min=0)
        inter_h = (inter_y2 - inter_y1).clamp(min=0)
        inter_area = inter_w * inter_h

        # Areas
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        # Union
        union_area = area1 + area2 - inter_area

        # IoU
        iou = inter_area / union_area.clamp(min=self.eps)

        # Enclosing box
        enclose_x1 = torch.min(boxes1[:, 0], boxes2[:, 0])
        enclose_y1 = torch.min(boxes1[:, 1], boxes2[:, 1])
        enclose_x2 = torch.max(boxes1[:, 2], boxes2[:, 2])
        enclose_y2 = torch.max(boxes1[:, 3], boxes2[:, 3])

        enclose_area = (enclose_x2 - enclose_x1) * (enclose_y2 - enclose_y1)

        # GIoU
        giou = iou - (enclose_area - union_area) / enclose_area.clamp(min=self.eps)

        return giou

__init__

__init__(reduction: str = 'mean', eps: float = 1e-07)
Source code in src/autotimm/losses/detection.py
def __init__(self, reduction: str = "mean", eps: float = 1e-7):
    super().__init__()
    self.reduction = reduction
    self.eps = eps

forward

forward(pred_boxes: Tensor, target_boxes: Tensor, weights: Tensor | None = None) -> torch.Tensor

Compute GIoU loss.

Parameters:

Name Type Description Default
pred_boxes Tensor

Predicted boxes [N, 4] in (x1, y1, x2, y2) format.

required
target_boxes Tensor

Target boxes [N, 4] in (x1, y1, x2, y2) format.

required
weights Tensor | None

Optional per-box weights [N].

None

Returns:

Type Description
Tensor

GIoU loss value.

Source code in src/autotimm/losses/detection.py
def forward(
    self,
    pred_boxes: torch.Tensor,
    target_boxes: torch.Tensor,
    weights: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute GIoU loss.

    Args:
        pred_boxes: Predicted boxes [N, 4] in (x1, y1, x2, y2) format.
        target_boxes: Target boxes [N, 4] in (x1, y1, x2, y2) format.
        weights: Optional per-box weights [N].

    Returns:
        GIoU loss value.
    """
    if pred_boxes.numel() == 0:
        return pred_boxes.sum() * 0.0

    giou = self._compute_giou(pred_boxes, target_boxes)
    loss = 1 - giou

    if weights is not None:
        loss = loss * weights

    if self.reduction == "mean":
        return (
            loss.mean()
            if weights is None
            else loss.sum() / weights.sum().clamp(min=self.eps)
        )
    elif self.reduction == "sum":
        return loss.sum()
    return loss

Usage Example

from autotimm import GIoULoss
import torch

loss_fn = GIoULoss()

# Boxes in (x1, y1, x2, y2) format
pred_boxes = torch.tensor([[10, 10, 50, 50], [20, 20, 60, 60]], dtype=torch.float32)
target_boxes = torch.tensor([[12, 12, 48, 48], [25, 25, 65, 65]], dtype=torch.float32)

loss = loss_fn(pred_boxes, target_boxes)

Parameters

Parameter Type Default Description
reduction str "mean" Loss reduction: "mean", "sum", or "none"
eps float 1e-7 Small value to avoid division by zero

How It Works

Standard IoU:

  • Range: [0, 1]
  • 1 = perfect overlap
  • 0 = no overlap
  • Problem: Zero gradient when boxes don't overlap

GIoU Improvement:

  • Range: [-1, 1]
  • Considers the area of the smallest enclosing box
  • Provides gradient signal even for non-overlapping boxes
  • Penalizes both misalignment and poor aspect ratios

Advantages over L1/L2:

  1. Scale-invariant: Works equally well for small and large boxes
  2. Direct optimization: Optimizes IoU directly, not coordinates
  3. Better gradients: Non-zero gradients for non-overlapping boxes
  4. Aspect ratio aware: Penalizes incorrect aspect ratios

GIoU Values

GIoU IoU Interpretation
1.0 1.0 Perfect match
0.5 0.5 50% overlap, no wasted space
0.0 0.0 No overlap, but touching
-0.5 0.0 No overlap, some distance
-1.0 0.0 No overlap, maximum distance

When to Use

Use GIoU Loss when:

  • Training object detectors
  • Need scale-invariant regression
  • Boxes can be non-overlapping during training
  • Want to optimize IoU metric directly

Alternatives:

  • IoU Loss: Simpler, but undefined for non-overlapping boxes
  • DIoU Loss: Faster convergence, considers center distance
  • CIoU Loss: Best for accurate localization, considers aspect ratio explicitly
  • L1 Loss: Simple, but not scale-invariant

Input Format

Boxes must be in (x1, y1, x2, y2) format where:

  • x1, y1: Top-left corner coordinates
  • x2, y2: Bottom-right corner coordinates
  • x2 > x1 and y2 > y1

# Correct format
boxes = torch.tensor([[10, 20, 50, 60]])  # (x1=10, y1=20, x2=50, y2=60)

# Convert from (x, y, w, h) if needed
def xywh_to_xyxy(boxes):
    x, y, w, h = boxes.unbind(-1)
    return torch.stack([x, y, x + w, y + h], dim=-1)

CenternessLoss

Binary cross-entropy loss for centerness prediction in FCOS.

Overview

Centerness predicts how "centered" a location is within its assigned object. It's used to down-weight low-quality bounding boxes that are far from object centers, improving detection quality without NMS.

Formula:

centerness = sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))

where l, t, r, b are distances from a location to the left, top, right, bottom
of its target bounding box.
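The formula is easy to verify in plain Python (illustrative only, separate from the library code below):

```python
import math

def centerness(l: float, t: float, r: float, b: float) -> float:
    """Centerness target from the LTRB distances defined above."""
    return math.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))

# For a 40x40 box, a location at the exact center:
print(centerness(20, 20, 20, 20))  # 1.0
# A location on the vertical midline, closer to the top edge:
print(centerness(20, 4, 20, 36))   # sqrt(4/36) ~= 0.33
```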

API Reference

autotimm.CenternessLoss

Bases: Module

Binary cross-entropy loss for FCOS centerness prediction.

Parameters:

Name Type Description Default
reduction str

Reduction method: 'none', 'mean', or 'sum'.

'mean'
Source code in src/autotimm/losses/detection.py
class CenternessLoss(nn.Module):
    """Binary cross-entropy loss for FCOS centerness prediction.

    Parameters:
        reduction: Reduction method: 'none', 'mean', or 'sum'.
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weights: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute centerness loss.

        Args:
            pred: Predicted centerness logits [N] or [N, 1].
            target: Target centerness values [N] in [0, 1].
            weights: Optional per-sample weights [N].

        Returns:
            BCE loss value.
        """
        if pred.numel() == 0:
            return pred.sum() * 0.0

        pred = pred.view(-1)
        target = target.view(-1)

        loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")

        if weights is not None:
            weights = weights.view(-1)
            loss = loss * weights

        if self.reduction == "mean":
            if weights is not None:
                return loss.sum() / weights.sum().clamp(min=1e-7)
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss

__init__

__init__(reduction: str = 'mean')
Source code in src/autotimm/losses/detection.py
def __init__(self, reduction: str = "mean"):
    super().__init__()
    self.reduction = reduction

forward

forward(pred: Tensor, target: Tensor, weights: Tensor | None = None) -> torch.Tensor

Compute centerness loss.

Parameters:

Name Type Description Default
pred Tensor

Predicted centerness logits [N] or [N, 1].

required
target Tensor

Target centerness values [N] in [0, 1].

required
weights Tensor | None

Optional per-sample weights [N].

None

Returns:

Type Description
Tensor

BCE loss value.

Source code in src/autotimm/losses/detection.py
def forward(
    self,
    pred: torch.Tensor,
    target: torch.Tensor,
    weights: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute centerness loss.

    Args:
        pred: Predicted centerness logits [N] or [N, 1].
        target: Target centerness values [N] in [0, 1].
        weights: Optional per-sample weights [N].

    Returns:
        BCE loss value.
    """
    if pred.numel() == 0:
        return pred.sum() * 0.0

    pred = pred.view(-1)
    target = target.view(-1)

    loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")

    if weights is not None:
        weights = weights.view(-1)
        loss = loss * weights

    if self.reduction == "mean":
        if weights is not None:
            return loss.sum() / weights.sum().clamp(min=1e-7)
        return loss.mean()
    elif self.reduction == "sum":
        return loss.sum()
    return loss

Usage Example

from autotimm import CenternessLoss
import torch

loss_fn = CenternessLoss()

# Centerness predictions (logits)
pred = torch.randn(32, 1, 100, 100)  # (B, 1, H, W)
target = torch.rand(32, 1, 100, 100)  # (B, 1, H, W), values in [0, 1]

loss = loss_fn(pred, target)

Parameters

Parameter Type Default Description
reduction str "mean" Loss reduction: "mean", "sum", or "none"

How It Works

Centerness Computation:

  1. For each location, compute distances to box edges: l, t, r, b
  2. Compute geometric mean of horizontal and vertical ratios
  3. Value is 1.0 at the center, decreases towards edges
  4. Multiply with classification score during inference

Purpose:

  • Suppress low-quality detections far from object centers
  • Alternative to NMS (though FCOS still uses NMS)
  • Improve localization quality by favoring central predictions

Example:

Object bounding box: [10, 10, 50, 50]
Location at (30, 30): Center → centerness = 1.0
Location at (10, 10): Corner → centerness = 0.0
Location at (30, 10): Edge midpoint → centerness = 0.0 (min(t, b) = 0)
Location at (30, 20): Midway between center and top edge → centerness ≈ 0.58

When to Use

Use Centerness Loss when:

  • Implementing FCOS or similar anchor-free detectors
  • Want to suppress low-quality detections
  • Training point-based detectors

Not needed for:

  • Anchor-based detectors (Faster R-CNN, etc.)
  • Methods with different quality measures (e.g., IoU prediction)

FCOSLoss

Combined loss function for FCOS object detection.

Overview

FCOSLoss combines Focal Loss (classification), GIoU Loss (bbox regression), and Centerness Loss into a single differentiable loss function for end-to-end training.

Formula:

Total Loss = λ_cls * Focal Loss
           + λ_reg * GIoU Loss
           + λ_centerness * Centerness Loss

API Reference

autotimm.FCOSLoss

Bases: Module

Combined FCOS loss: Focal Loss + GIoU Loss + Centerness Loss.

This wrapper computes and combines the three loss components used in FCOS-style object detection.

Parameters:

Name Type Description Default
num_classes int

Number of object classes.

required
focal_alpha float

Alpha parameter for focal loss.

0.25
focal_gamma float

Gamma parameter for focal loss.

2.0
cls_weight float

Weight for classification loss.

1.0
reg_weight float

Weight for regression loss.

1.0
centerness_weight float

Weight for centerness loss.

1.0
Source code in src/autotimm/losses/detection.py
class FCOSLoss(nn.Module):
    """Combined FCOS loss: Focal Loss + GIoU Loss + Centerness Loss.

    This wrapper computes and combines the three loss components used
    in FCOS-style object detection.

    Parameters:
        num_classes: Number of object classes.
        focal_alpha: Alpha parameter for focal loss.
        focal_gamma: Gamma parameter for focal loss.
        cls_weight: Weight for classification loss.
        reg_weight: Weight for regression loss.
        centerness_weight: Weight for centerness loss.
    """

    def __init__(
        self,
        num_classes: int,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        cls_weight: float = 1.0,
        reg_weight: float = 1.0,
        centerness_weight: float = 1.0,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.cls_weight = cls_weight
        self.reg_weight = reg_weight
        self.centerness_weight = centerness_weight

        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        self.giou_loss = GIoULoss()
        self.centerness_loss = CenternessLoss()

    def forward(
        self,
        cls_preds: list[torch.Tensor],
        reg_preds: list[torch.Tensor],
        centerness_preds: list[torch.Tensor],
        cls_targets: list[torch.Tensor],
        reg_targets: list[torch.Tensor],
        centerness_targets: list[torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Compute combined FCOS loss.

        Args:
            cls_preds: List of classification predictions per level [B, C, H, W].
            reg_preds: List of regression predictions per level [B, 4, H, W].
            centerness_preds: List of centerness predictions per level [B, 1, H, W].
            cls_targets: List of classification targets per level [B, H, W].
            reg_targets: List of regression targets per level [B, 4, H, W].
            centerness_targets: List of centerness targets per level [B, H, W].

        Returns:
            Dict with 'cls_loss', 'reg_loss', 'centerness_loss', and 'total_loss'.
        """
        total_cls_loss = 0.0
        total_reg_loss = 0.0
        total_centerness_loss = 0.0
        num_pos = 0

        for level_idx in range(len(cls_preds)):
            cls_pred = cls_preds[level_idx]
            reg_pred = reg_preds[level_idx]
            centerness_pred = centerness_preds[level_idx]
            cls_target = cls_targets[level_idx]
            reg_target = reg_targets[level_idx]
            centerness_target = centerness_targets[level_idx]

            # Classification loss (all locations)
            total_cls_loss = total_cls_loss + self.focal_loss(cls_pred, cls_target)

            # Find positive samples (cls_target >= 0 and not background)
            # Background is typically the last class or handled via -1
            pos_mask = cls_target >= 0

            if pos_mask.any():
                # Regression loss (positive samples only)
                pos_reg_pred = reg_pred.permute(0, 2, 3, 1)[pos_mask]  # [N_pos, 4]
                pos_reg_target = reg_target.permute(0, 2, 3, 1)[pos_mask]  # [N_pos, 4]

                # Convert LTRB distances to boxes for GIoU
                # This requires knowing the grid locations - simplified here
                # In practice, you'd compute boxes from distances + grid centers
                total_reg_loss = total_reg_loss + self._compute_reg_loss(
                    pos_reg_pred, pos_reg_target
                )

                # Centerness loss (positive samples only)
                pos_centerness_pred = centerness_pred.permute(0, 2, 3, 1)[pos_mask]
                pos_centerness_target = centerness_target[pos_mask]
                total_centerness_loss = total_centerness_loss + self.centerness_loss(
                    pos_centerness_pred, pos_centerness_target
                )

                num_pos += pos_mask.sum()

        # Normalize by number of positive samples
        num_pos = max(num_pos, 1)

        cls_loss = self.cls_weight * total_cls_loss
        reg_loss = self.reg_weight * total_reg_loss / num_pos
        centerness_loss = self.centerness_weight * total_centerness_loss / num_pos
        total_loss = cls_loss + reg_loss + centerness_loss

        return {
            "cls_loss": cls_loss,
            "reg_loss": reg_loss,
            "centerness_loss": centerness_loss,
            "total_loss": total_loss,
        }

    def _compute_reg_loss(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Compute regression loss using IoU-based loss.

        For FCOS, predictions are (left, top, right, bottom) distances.
        We can compute IoU loss directly from these.
        """
        # Compute areas from LTRB distances
        pred_area = (pred[:, 0] + pred[:, 2]) * (pred[:, 1] + pred[:, 3])
        target_area = (target[:, 0] + target[:, 2]) * (target[:, 1] + target[:, 3])

        # Intersection
        inter_w = torch.min(pred[:, 0], target[:, 0]) + torch.min(
            pred[:, 2], target[:, 2]
        )
        inter_h = torch.min(pred[:, 1], target[:, 1]) + torch.min(
            pred[:, 3], target[:, 3]
        )
        inter_area = inter_w * inter_h

        # Union
        union_area = pred_area + target_area - inter_area

        # IoU loss
        iou = inter_area / union_area.clamp(min=1e-7)
        loss = -torch.log(iou.clamp(min=1e-7))

        return loss.sum()

__init__

__init__(num_classes: int, focal_alpha: float = 0.25, focal_gamma: float = 2.0, cls_weight: float = 1.0, reg_weight: float = 1.0, centerness_weight: float = 1.0)
Source code in src/autotimm/losses/detection.py
def __init__(
    self,
    num_classes: int,
    focal_alpha: float = 0.25,
    focal_gamma: float = 2.0,
    cls_weight: float = 1.0,
    reg_weight: float = 1.0,
    centerness_weight: float = 1.0,
):
    super().__init__()
    self.num_classes = num_classes
    self.cls_weight = cls_weight
    self.reg_weight = reg_weight
    self.centerness_weight = centerness_weight

    self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
    self.giou_loss = GIoULoss()
    self.centerness_loss = CenternessLoss()

forward

forward(cls_preds: list[Tensor], reg_preds: list[Tensor], centerness_preds: list[Tensor], cls_targets: list[Tensor], reg_targets: list[Tensor], centerness_targets: list[Tensor]) -> dict[str, torch.Tensor]

Compute combined FCOS loss.

Parameters:

Name Type Description Default
cls_preds list[Tensor]

List of classification predictions per level [B, C, H, W].

required
reg_preds list[Tensor]

List of regression predictions per level [B, 4, H, W].

required
centerness_preds list[Tensor]

List of centerness predictions per level [B, 1, H, W].

required
cls_targets list[Tensor]

List of classification targets per level [B, H, W].

required
reg_targets list[Tensor]

List of regression targets per level [B, 4, H, W].

required
centerness_targets list[Tensor]

List of centerness targets per level [B, H, W].

required

Returns:

Type Description
dict[str, Tensor]

Dict with 'cls_loss', 'reg_loss', 'centerness_loss', and 'total_loss'.

Source code in src/autotimm/losses/detection.py
def forward(
    self,
    cls_preds: list[torch.Tensor],
    reg_preds: list[torch.Tensor],
    centerness_preds: list[torch.Tensor],
    cls_targets: list[torch.Tensor],
    reg_targets: list[torch.Tensor],
    centerness_targets: list[torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Compute combined FCOS loss.

    Args:
        cls_preds: List of classification predictions per level [B, C, H, W].
        reg_preds: List of regression predictions per level [B, 4, H, W].
        centerness_preds: List of centerness predictions per level [B, 1, H, W].
        cls_targets: List of classification targets per level [B, H, W].
        reg_targets: List of regression targets per level [B, 4, H, W].
        centerness_targets: List of centerness targets per level [B, H, W].

    Returns:
        Dict with 'cls_loss', 'reg_loss', 'centerness_loss', and 'total_loss'.
    """
    total_cls_loss = 0.0
    total_reg_loss = 0.0
    total_centerness_loss = 0.0
    num_pos = 0

    for level_idx in range(len(cls_preds)):
        cls_pred = cls_preds[level_idx]
        reg_pred = reg_preds[level_idx]
        centerness_pred = centerness_preds[level_idx]
        cls_target = cls_targets[level_idx]
        reg_target = reg_targets[level_idx]
        centerness_target = centerness_targets[level_idx]

        # Classification loss (all locations)
        total_cls_loss = total_cls_loss + self.focal_loss(cls_pred, cls_target)

        # Find positive samples (cls_target >= 0 and not background)
        # Background is typically the last class or handled via -1
        pos_mask = cls_target >= 0

        if pos_mask.any():
            # Regression loss (positive samples only)
            pos_reg_pred = reg_pred.permute(0, 2, 3, 1)[pos_mask]  # [N_pos, 4]
            pos_reg_target = reg_target.permute(0, 2, 3, 1)[pos_mask]  # [N_pos, 4]

            # Convert LTRB distances to boxes for GIoU
            # This requires knowing the grid locations - simplified here
            # In practice, you'd compute boxes from distances + grid centers
            total_reg_loss = total_reg_loss + self._compute_reg_loss(
                pos_reg_pred, pos_reg_target
            )

            # Centerness loss (positive samples only)
            pos_centerness_pred = centerness_pred.permute(0, 2, 3, 1)[pos_mask]
            pos_centerness_target = centerness_target[pos_mask]
            total_centerness_loss = total_centerness_loss + self.centerness_loss(
                pos_centerness_pred, pos_centerness_target
            )

            num_pos += pos_mask.sum()

    # Normalize by number of positive samples
    num_pos = max(num_pos, 1)

    cls_loss = self.cls_weight * total_cls_loss
    reg_loss = self.reg_weight * total_reg_loss / num_pos
    centerness_loss = self.centerness_weight * total_centerness_loss / num_pos
    total_loss = cls_loss + reg_loss + centerness_loss

    return {
        "cls_loss": cls_loss,
        "reg_loss": reg_loss,
        "centerness_loss": centerness_loss,
        "total_loss": total_loss,
    }

Usage Example

from autotimm import FCOSLoss
import torch

loss_fn = FCOSLoss(
    num_classes=80,
    focal_alpha=0.25,
    focal_gamma=2.0,
    cls_weight=1.0,
    reg_weight=1.0,
    centerness_weight=1.0,
)

# Per-level predictions from the detection head (one list entry per FPN level)
cls_preds = [torch.randn(2, 80, 100, 100)]        # Classification logits [B, C, H, W]
reg_preds = [torch.rand(2, 4, 100, 100) * 50]     # LTRB distance predictions [B, 4, H, W]
centerness_preds = [torch.randn(2, 1, 100, 100)]  # Centerness logits [B, 1, H, W]

# Per-level targets from the target assigner (-1 marks ignored locations)
cls_targets = [torch.randint(-1, 80, (2, 100, 100))]  # Class indices [B, H, W]
reg_targets = [torch.rand(2, 4, 100, 100) * 50]       # LTRB distance targets [B, 4, H, W]
centerness_targets = [torch.rand(2, 100, 100)]        # Values in [0, 1], [B, H, W]

loss_dict = loss_fn(
    cls_preds,
    reg_preds,
    centerness_preds,
    cls_targets,
    reg_targets,
    centerness_targets,
)
total_loss = loss_dict["total_loss"]

Parameters

Parameter Type Default Description
num_classes int Required Number of object classes
focal_alpha float 0.25 Focal loss alpha
focal_gamma float 2.0 Focal loss gamma
cls_weight float 1.0 Classification loss weight (λ_cls)
reg_weight float 1.0 Regression loss weight (λ_reg)
centerness_weight float 1.0 Centerness loss weight (λ_centerness)

Returns

Dictionary with:

  • "total_loss": Sum of the three weighted loss components
  • "cls_loss": Weighted classification loss
  • "reg_loss": Weighted regression loss, normalized by the number of positive samples
  • "centerness_loss": Weighted centerness loss, normalized by the number of positive samples

How It Works

  1. Target Assignment:
     • Assigns ground truth boxes to FPN levels based on size
     • Computes target classification labels
     • Computes target bbox offsets (l, t, r, b)
     • Computes target centerness values

  2. Loss Computation:
     • Classification: Focal loss on all locations (positive + negative)
     • Regression: GIoU loss on positive locations only
     • Centerness: BCE loss on positive locations only

  3. Normalization:
     • Regression and centerness losses are normalized by the number of positive samples
     • Prevents loss explosion when there are few objects

Loss Weight Tuning

Scenario cls_weight reg_weight centerness_weight
Standard (recommended) 1.0 1.0 1.0
Poor localization 1.0 2.0 1.0
Missing detections 2.0 1.0 1.0
False positives 1.0 1.0 2.0
Small objects 1.0 2.0 0.5

Usage in Training

from autotimm import ObjectDetector

model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    # Loss configuration
    focal_alpha=0.25,
    focal_gamma=2.0,
    cls_loss_weight=1.0,
    reg_loss_weight=1.0,
    centerness_loss_weight=1.0,
)

# FCOSLoss is automatically created and used during training

Target Assignment Strategy

Regression Range Assignment:

FPN Level Stride Default Range Object Size
P3 8 (-1, 64) 0-64px
P4 16 (64, 128) 64-128px
P5 32 (128, 256) 128-256px
P6 64 (256, 512) 256-512px
P7 128 (512, ∞) >512px

Objects are assigned to the FPN level whose regression range best matches the object's max dimension.
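The range check itself can be sketched with a hypothetical helper (names here are illustrative; the actual assigner lives inside autotimm's detection pipeline, not in FCOSLoss):

```python
# Illustrative only: pick the FPN level whose default regression range
# contains a location's maximum LTRB regression distance.
REGRESS_RANGES = [(-1, 64), (64, 128), (128, 256), (256, 512), (512, float("inf"))]
LEVELS = ("P3", "P4", "P5", "P6", "P7")

def assign_level(max_ltrb: float) -> str:
    """Return the FPN level whose range (lo, hi] contains max_ltrb."""
    for level, (lo, hi) in zip(LEVELS, REGRESS_RANGES):
        if lo < max_ltrb <= hi:
            return level
    return LEVELS[-1]

print(assign_level(40))   # P3: small object
print(assign_level(200))  # P5
print(assign_level(900))  # P7
```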

Custom Regression Ranges

For datasets with specific object size distributions, custom ranges can be configured. Note that strides and regress_ranges do not appear in the FCOSLoss __init__ shown above; they are consumed during target assignment, so treat this configuration as illustrative:

loss_fn = FCOSLoss(
    num_classes=80,
    strides=(8, 16, 32, 64, 128),
    regress_ranges=(
        (-1, 32),          # P3: very small objects
        (32, 64),          # P4
        (64, 128),         # P5
        (128, 256),        # P6
        (256, float("inf")),  # P7
    ),
)