Heads

Neural network heads for classification and object detection tasks.

ClassificationHead

Simple classification head with optional dropout.

API Reference

autotimm.ClassificationHead

Bases: Module

Linear classification head with optional dropout.

Parameters:

Name Type Description Default
in_features int Dimensionality of the backbone output. required
num_classes int Number of target classes. required
dropout float Dropout probability before the final linear layer. 0.0
Source code in src/autotimm/heads/_heads.py
class ClassificationHead(nn.Module):
    """Linear classification head with optional dropout.

    Parameters:
        in_features: Dimensionality of the backbone output.
        num_classes: Number of target classes.
        dropout: Dropout probability before the final linear layer.
    """

    def __init__(self, in_features: int, num_classes: int, dropout: float = 0.0):
        super().__init__()
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.drop(x))

__init__

__init__(in_features: int, num_classes: int, dropout: float = 0.0)
Source code in src/autotimm/heads/_heads.py
def __init__(self, in_features: int, num_classes: int, dropout: float = 0.0):
    super().__init__()
    self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
    self.fc = nn.Linear(in_features, num_classes)

forward

forward(x: Tensor) -> Tensor
Source code in src/autotimm/heads/_heads.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.fc(self.drop(x))

Usage Example

from autotimm import ClassificationHead
import torch

head = ClassificationHead(
    in_features=2048,
    num_classes=10,
    dropout=0.5,
)

features = torch.randn(32, 2048)
logits = head(features)  # (32, 10)

Parameters

Parameter Type Default Description
in_features int Required Input feature dimension
num_classes int Required Number of output classes
dropout float 0.0 Dropout rate before classifier

Architecture

ClassificationHead
├── dropout (if dropout > 0)
└── Linear(in_features, num_classes)
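The head is a thin wrapper around `nn.Dropout` and `nn.Linear`. The sketch below is a standalone copy of the class shown above, used to check a practical detail: in eval mode the dropout is a no-op, so inference is deterministic even with `dropout=0.5`.

```python
import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """Standalone copy of the head shown above (dropout + linear)."""

    def __init__(self, in_features: int, num_classes: int, dropout: float = 0.0):
        super().__init__()
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.drop(x))


head = ClassificationHead(in_features=2048, num_classes=10, dropout=0.5)
head.eval()  # dropout layers become no-ops in eval mode

x = torch.randn(32, 2048)
logits_a = head(x)
logits_b = head(x)
assert logits_a.shape == (32, 10)
assert torch.equal(logits_a, logits_b)  # deterministic at inference time
```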

DetectionHead

FCOS-style detection head with classification, bbox regression, and centerness branches.

API Reference

autotimm.DetectionHead

Bases: Module

FCOS-style detection head with classification, regression, and centerness branches.

This is an anchor-free detection head that predicts:

  • Class scores for each spatial location
  • Bounding box regression (distances to left, top, right, bottom)
  • Centerness scores to downweight predictions far from object centers

Parameters:

Name Type Description Default
in_channels int Number of input channels from FPN. 256
num_classes int Number of object classes (excluding background). 80
num_convs int Number of conv layers in each branch before prediction. 4
prior_prob float Prior probability for focal loss initialization. 0.01
use_group_norm bool Whether to use GroupNorm after conv layers. True
num_groups int Number of groups for GroupNorm. 32
Source code in src/autotimm/heads/_heads.py
class DetectionHead(nn.Module):
    """FCOS-style detection head with classification, regression, and centerness branches.

    This is an anchor-free detection head that predicts:
    - Class scores for each spatial location
    - Bounding box regression (distances to left, top, right, bottom)
    - Centerness scores to downweight predictions far from object centers

    Parameters:
        in_channels: Number of input channels from FPN.
        num_classes: Number of object classes (excluding background).
        num_convs: Number of conv layers in each branch before prediction.
        prior_prob: Prior probability for focal loss initialization.
        use_group_norm: Whether to use GroupNorm after conv layers.
        num_groups: Number of groups for GroupNorm.
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_classes: int = 80,
        num_convs: int = 4,
        prior_prob: float = 0.01,
        use_group_norm: bool = True,
        num_groups: int = 32,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels

        # Shared conv layers for classification branch
        cls_convs = []
        for _ in range(num_convs):
            cls_convs.append(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
            )
            if use_group_norm:
                cls_convs.append(nn.GroupNorm(num_groups, in_channels))
            cls_convs.append(nn.ReLU(inplace=True))
        self.cls_convs = nn.Sequential(*cls_convs)

        # Shared conv layers for regression branch
        reg_convs = []
        for _ in range(num_convs):
            reg_convs.append(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
            )
            if use_group_norm:
                reg_convs.append(nn.GroupNorm(num_groups, in_channels))
            reg_convs.append(nn.ReLU(inplace=True))
        self.reg_convs = nn.Sequential(*reg_convs)

        # Prediction layers
        self.cls_pred = nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1)
        self.reg_pred = nn.Conv2d(in_channels, 4, kernel_size=3, padding=1)
        self.centerness_pred = nn.Conv2d(in_channels, 1, kernel_size=3, padding=1)

        # Per-level learnable scales for regression
        self.scales = nn.ModuleList([ScaleModule(1.0) for _ in range(5)])

        self._init_weights(prior_prob)

    def _init_weights(self, prior_prob: float):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Initialize classification bias for focal loss
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        nn.init.constant_(self.cls_pred.bias, bias_value)

    def forward(
        self, features: list[torch.Tensor]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Forward pass through detection head.

        Args:
            features: List of FPN features [P3, P4, P5, P6, P7].

        Returns:
            Tuple of (cls_outputs, reg_outputs, centerness_outputs), each a list
            of tensors with shapes:
            - cls: [B, num_classes, H, W]
            - reg: [B, 4, H, W] (left, top, right, bottom distances)
            - centerness: [B, 1, H, W]
        """
        cls_outputs = []
        reg_outputs = []
        centerness_outputs = []

        for i, feat in enumerate(features):
            cls_feat = self.cls_convs(feat)
            reg_feat = self.reg_convs(feat)

            # Classification prediction
            cls_out = self.cls_pred(cls_feat)

            # Regression prediction (with per-level scaling)
            scale_idx = min(i, len(self.scales) - 1)
            reg_out = self.scales[scale_idx](self.reg_pred(reg_feat))
            reg_out = F.relu(reg_out)  # Distances must be positive

            # Centerness prediction (from regression branch)
            centerness_out = self.centerness_pred(reg_feat)

            cls_outputs.append(cls_out)
            reg_outputs.append(reg_out)
            centerness_outputs.append(centerness_out)

        return cls_outputs, reg_outputs, centerness_outputs

    def forward_single(
        self, feat: torch.Tensor, scale_idx: int = 0
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass for a single feature level.

        Useful for inference or when processing levels independently.
        """
        cls_feat = self.cls_convs(feat)
        reg_feat = self.reg_convs(feat)

        cls_out = self.cls_pred(cls_feat)
        scale_idx = min(scale_idx, len(self.scales) - 1)
        reg_out = F.relu(self.scales[scale_idx](self.reg_pred(reg_feat)))
        centerness_out = self.centerness_pred(reg_feat)

        return cls_out, reg_out, centerness_out

__init__

__init__(in_channels: int = 256, num_classes: int = 80, num_convs: int = 4, prior_prob: float = 0.01, use_group_norm: bool = True, num_groups: int = 32)
Source code in src/autotimm/heads/_heads.py
def __init__(
    self,
    in_channels: int = 256,
    num_classes: int = 80,
    num_convs: int = 4,
    prior_prob: float = 0.01,
    use_group_norm: bool = True,
    num_groups: int = 32,
):
    super().__init__()
    self.num_classes = num_classes
    self.in_channels = in_channels

    # Shared conv layers for classification branch
    cls_convs = []
    for _ in range(num_convs):
        cls_convs.append(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        )
        if use_group_norm:
            cls_convs.append(nn.GroupNorm(num_groups, in_channels))
        cls_convs.append(nn.ReLU(inplace=True))
    self.cls_convs = nn.Sequential(*cls_convs)

    # Shared conv layers for regression branch
    reg_convs = []
    for _ in range(num_convs):
        reg_convs.append(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        )
        if use_group_norm:
            reg_convs.append(nn.GroupNorm(num_groups, in_channels))
        reg_convs.append(nn.ReLU(inplace=True))
    self.reg_convs = nn.Sequential(*reg_convs)

    # Prediction layers
    self.cls_pred = nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1)
    self.reg_pred = nn.Conv2d(in_channels, 4, kernel_size=3, padding=1)
    self.centerness_pred = nn.Conv2d(in_channels, 1, kernel_size=3, padding=1)

    # Per-level learnable scales for regression
    self.scales = nn.ModuleList([ScaleModule(1.0) for _ in range(5)])

    self._init_weights(prior_prob)

forward

forward(features: list[Tensor]) -> tuple[list[Tensor], list[Tensor], list[Tensor]]

Forward pass through detection head.

Parameters:

Name Type Description Default
features list[Tensor] List of FPN features [P3, P4, P5, P6, P7]. required

Returns:

Type Description
tuple[list[Tensor], list[Tensor], list[Tensor]] Tuple of (cls_outputs, reg_outputs, centerness_outputs), each a list of tensors with shapes:
  • cls: [B, num_classes, H, W]
  • reg: [B, 4, H, W] (left, top, right, bottom distances)
  • centerness: [B, 1, H, W]
Source code in src/autotimm/heads/_heads.py
def forward(
    self, features: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Forward pass through detection head.

    Args:
        features: List of FPN features [P3, P4, P5, P6, P7].

    Returns:
        Tuple of (cls_outputs, reg_outputs, centerness_outputs), each a list
        of tensors with shapes:
        - cls: [B, num_classes, H, W]
        - reg: [B, 4, H, W] (left, top, right, bottom distances)
        - centerness: [B, 1, H, W]
    """
    cls_outputs = []
    reg_outputs = []
    centerness_outputs = []

    for i, feat in enumerate(features):
        cls_feat = self.cls_convs(feat)
        reg_feat = self.reg_convs(feat)

        # Classification prediction
        cls_out = self.cls_pred(cls_feat)

        # Regression prediction (with per-level scaling)
        scale_idx = min(i, len(self.scales) - 1)
        reg_out = self.scales[scale_idx](self.reg_pred(reg_feat))
        reg_out = F.relu(reg_out)  # Distances must be positive

        # Centerness prediction (from regression branch)
        centerness_out = self.centerness_pred(reg_feat)

        cls_outputs.append(cls_out)
        reg_outputs.append(reg_out)
        centerness_outputs.append(centerness_out)

    return cls_outputs, reg_outputs, centerness_outputs

Usage Example

from autotimm import DetectionHead
import torch

head = DetectionHead(
    in_channels=256,
    num_classes=80,
    num_convs=4,
)

features = [
    torch.randn(2, 256, 80, 80),   # P3
    torch.randn(2, 256, 40, 40),   # P4
    torch.randn(2, 256, 20, 20),   # P5
]

cls_scores, bbox_preds, centernesses = head(features)
# cls_scores: List of (B, num_classes, H, W) tensors
# bbox_preds: List of (B, 4, H, W) tensors (l, t, r, b)
# centernesses: List of (B, 1, H, W) tensors

Parameters

Parameter Type Default Description
in_channels int 256 Number of input channels from FPN
num_classes int 80 Number of object classes (excluding background)
num_convs int 4 Number of conv layers per branch
prior_prob float 0.01 Prior probability for focal loss initialization
use_group_norm bool True Whether to use GroupNorm after conv layers
num_groups int 32 Number of groups for GroupNorm
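The prior_prob default of 0.01 sets the classification bias via bias = -log((1 - p) / p), the formula used in `_init_weights` above. This makes sigmoid(bias) = p at initialization, so every location starts with a ~1% foreground probability, which stabilizes focal-loss training on the heavily background-dominated location grid. A quick check:

```python
import math


def focal_bias(prior_prob: float) -> float:
    # Same formula as in DetectionHead._init_weights
    return -math.log((1 - prior_prob) / prior_prob)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


b = focal_bias(0.01)
print(round(b, 4))           # -4.5951
print(round(sigmoid(b), 4))  # 0.01 — the initial foreground probability
```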

Architecture

DetectionHead (shared across all FPN levels)
├── cls_branch
│   ├── Conv3x3 + GroupNorm + ReLU (×num_convs)
│   └── Conv3x3 → (num_classes, H, W)
└── reg_branch
    ├── Conv3x3 + GroupNorm + ReLU (×num_convs)
    ├── Conv3x3 → (4, H, W)  # l, t, r, b distances (per-level scale + ReLU)
    └── Conv3x3 → (1, H, W)  # centerness, predicted from the same reg features

Output Format

Classification Scores:

  • Shape: (B, num_classes, H, W) for each FPN level
  • Values: Raw logits (apply sigmoid for probabilities)

Bounding Box Predictions:

  • Shape: (B, 4, H, W) for each FPN level
  • Format: (left, top, right, bottom) offsets from each location
  • Values: Non-negative distances (the head applies a per-level learnable scale followed by ReLU, so no exp() is needed)

Centerness Scores:

  • Shape: (B, 1, H, W) for each FPN level
  • Values: Raw logits (apply sigmoid for probabilities)
  • Purpose: Suppresses low-quality detections far from object centers
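Decoding these raw outputs into boxes follows the standard FCOS recipe; the sketch below is not an autotimm API, just an illustration of the convention. For one pyramid level, build the grid of location centers at that level's stride, subtract/add the (l, t, r, b) distances to get corners, and score each box with sigmoid(cls) × sigmoid(centerness).

```python
import torch


def decode_level(cls_out, reg_out, ctr_out, stride):
    """Decode one FPN level of FCOS-style outputs into (boxes, scores).

    cls_out: (B, C, H, W) class logits
    reg_out: (B, 4, H, W) l, t, r, b distances (already non-negative)
    ctr_out: (B, 1, H, W) centerness logits
    stride:  feature-map stride of this level
    """
    B, C, H, W = cls_out.shape
    # Centers of each feature cell in image coordinates
    ys = (torch.arange(H, dtype=torch.float32) + 0.5) * stride
    xs = (torch.arange(W, dtype=torch.float32) + 0.5) * stride
    cy, cx = torch.meshgrid(ys, xs, indexing="ij")  # each (H, W)

    l, t, r, b = reg_out.unbind(dim=1)  # each (B, H, W)
    boxes = torch.stack([cx - l, cy - t, cx + r, cy + b], dim=-1)  # (B, H, W, 4)

    # Class probability down-weighted by centerness
    scores = cls_out.sigmoid() * ctr_out.sigmoid()  # (B, C, H, W)
    return boxes, scores


# Tiny smoke test: one location at stride 8 with known distances
cls = torch.zeros(1, 80, 1, 1)
reg = torch.tensor([[[[4.0]], [[4.0]], [[4.0]], [[4.0]]]])  # l = t = r = b = 4
ctr = torch.zeros(1, 1, 1, 1)
boxes, scores = decode_level(cls, reg, ctr, stride=8)
print(boxes[0, 0, 0])  # tensor([0., 0., 8., 8.]) — center (4, 4) ± 4
```

In a full pipeline this per-level decode is followed by score thresholding and NMS across levels.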

FPN

Feature Pyramid Network for multi-scale feature extraction.

API Reference

autotimm.FPN

Bases: Module

Feature Pyramid Network for multi-scale object detection.

Takes feature maps from a backbone at multiple scales (C2-C5) and produces a feature pyramid (P3-P7) with top-down pathway and lateral connections.

Parameters:

Name Type Description Default
in_channels_list Sequence[int] List of input channel counts from backbone features. Typically [256, 512, 1024, 2048] for ResNet-50. required
out_channels int Number of output channels for all pyramid levels. 256
num_extra_levels int Number of extra levels to add beyond the backbone features. Default 2 adds P6 and P7 from P5. 2
use_depthwise bool Whether to use depthwise separable convolutions. False
Source code in src/autotimm/heads/_heads.py
class FPN(nn.Module):
    """Feature Pyramid Network for multi-scale object detection.

    Takes feature maps from a backbone at multiple scales (C2-C5) and produces
    a feature pyramid (P3-P7) with top-down pathway and lateral connections.

    Parameters:
        in_channels_list: List of input channel counts from backbone features.
            Typically [256, 512, 1024, 2048] for ResNet-50.
        out_channels: Number of output channels for all pyramid levels.
        num_extra_levels: Number of extra levels to add beyond the backbone
            features. Default 2 adds P6 and P7 from P5.
        use_depthwise: Whether to use depthwise separable convolutions.
    """

    def __init__(
        self,
        in_channels_list: Sequence[int],
        out_channels: int = 256,
        num_extra_levels: int = 2,
        use_depthwise: bool = False,
    ):
        super().__init__()
        self.in_channels_list = list(in_channels_list)
        self.out_channels = out_channels
        self.num_extra_levels = num_extra_levels

        # Lateral (1x1) connections
        self.lateral_convs = nn.ModuleList()
        for in_channels in self.in_channels_list:
            self.lateral_convs.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=1)
            )

        # Top-down pathway (3x3) convolutions
        self.output_convs = nn.ModuleList()
        conv_fn = _depthwise_conv if use_depthwise else _standard_conv
        for _ in range(len(self.in_channels_list)):
            self.output_convs.append(conv_fn(out_channels, out_channels))

        # Extra levels (P6, P7) from P5
        self.extra_convs = nn.ModuleList()
        for _ in range(num_extra_levels):
            self.extra_convs.append(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
            )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        """Forward pass through FPN.

        Args:
            features: List of feature maps from backbone [C2, C3, C4, C5] or similar.

        Returns:
            List of pyramid features [P3, P4, P5, P6, P7] with uniform channels.
        """
        assert len(features) == len(
            self.in_channels_list
        ), f"Expected {len(self.in_channels_list)} features, got {len(features)}"

        # Build top-down pathway
        laterals = [
            lateral_conv(f) for lateral_conv, f in zip(self.lateral_convs, features)
        ]

        # Top-down fusion (from highest level to lowest)
        for i in range(len(laterals) - 1, 0, -1):
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=laterals[i - 1].shape[-2:], mode="nearest"
            )

        # Apply 3x3 convs to remove aliasing
        outputs = [
            output_conv(lateral)
            for output_conv, lateral in zip(self.output_convs, laterals)
        ]

        # Add extra levels from the last output
        last_feat = outputs[-1]
        for extra_conv in self.extra_convs:
            last_feat = F.relu(extra_conv(last_feat))
            outputs.append(last_feat)

        return outputs

__init__

__init__(in_channels_list: Sequence[int], out_channels: int = 256, num_extra_levels: int = 2, use_depthwise: bool = False)
Source code in src/autotimm/heads/_heads.py
def __init__(
    self,
    in_channels_list: Sequence[int],
    out_channels: int = 256,
    num_extra_levels: int = 2,
    use_depthwise: bool = False,
):
    super().__init__()
    self.in_channels_list = list(in_channels_list)
    self.out_channels = out_channels
    self.num_extra_levels = num_extra_levels

    # Lateral (1x1) connections
    self.lateral_convs = nn.ModuleList()
    for in_channels in self.in_channels_list:
        self.lateral_convs.append(
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    # Top-down pathway (3x3) convolutions
    self.output_convs = nn.ModuleList()
    conv_fn = _depthwise_conv if use_depthwise else _standard_conv
    for _ in range(len(self.in_channels_list)):
        self.output_convs.append(conv_fn(out_channels, out_channels))

    # Extra levels (P6, P7) from P5
    self.extra_convs = nn.ModuleList()
    for _ in range(num_extra_levels):
        self.extra_convs.append(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        )

    self._init_weights()

forward

forward(features: list[Tensor]) -> list[Tensor]

Forward pass through FPN.

Parameters:

Name Type Description Default
features list[Tensor] List of feature maps from backbone [C2, C3, C4, C5] or similar. required

Returns:

Type Description
list[Tensor]

List of pyramid features [P3, P4, P5, P6, P7] with uniform channels.

Source code in src/autotimm/heads/_heads.py
def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
    """Forward pass through FPN.

    Args:
        features: List of feature maps from backbone [C2, C3, C4, C5] or similar.

    Returns:
        List of pyramid features [P3, P4, P5, P6, P7] with uniform channels.
    """
    assert len(features) == len(
        self.in_channels_list
    ), f"Expected {len(self.in_channels_list)} features, got {len(features)}"

    # Build top-down pathway
    laterals = [
        lateral_conv(f) for lateral_conv, f in zip(self.lateral_convs, features)
    ]

    # Top-down fusion (from highest level to lowest)
    for i in range(len(laterals) - 1, 0, -1):
        laterals[i - 1] = laterals[i - 1] + F.interpolate(
            laterals[i], size=laterals[i - 1].shape[-2:], mode="nearest"
        )

    # Apply 3x3 convs to remove aliasing
    outputs = [
        output_conv(lateral)
        for output_conv, lateral in zip(self.output_convs, laterals)
    ]

    # Add extra levels from the last output
    last_feat = outputs[-1]
    for extra_conv in self.extra_convs:
        last_feat = F.relu(extra_conv(last_feat))
        outputs.append(last_feat)

    return outputs

Usage Example

from autotimm import FPN, create_feature_backbone
import torch

# Create backbone
backbone = create_feature_backbone("resnet50")
in_channels = [512, 1024, 2048]  # C3, C4, C5

# Create FPN
fpn = FPN(
    in_channels_list=in_channels,
    out_channels=256,
)

# Extract features
images = torch.randn(2, 3, 640, 640)
features = backbone(images)  # [C3, C4, C5]
pyramid = fpn(features)      # [P3, P4, P5, P6, P7]

# Pyramid features all have 256 channels
for i, feat in enumerate(pyramid):
    print(f"P{i+3}: {feat.shape}")
# P3: torch.Size([2, 256, 80, 80])
# P4: torch.Size([2, 256, 40, 40])
# P5: torch.Size([2, 256, 20, 20])
# P6: torch.Size([2, 256, 10, 10])
# P7: torch.Size([2, 256, 5, 5])

Parameters

Parameter Type Default Description
in_channels_list Sequence[int] Required Input channels for each backbone level (e.g. C3, C4, C5)
out_channels int 256 Output channels for all pyramid levels
num_extra_levels int 2 Extra levels (P6, P7) added beyond the backbone features
use_depthwise bool False Whether to use depthwise separable convolutions

Architecture

FPN
├── Lateral connections (1x1 conv)
│   ├── C5 → lateral_conv5
│   ├── C4 → lateral_conv4
│   └── C3 → lateral_conv3
├── Top-down pathway (upsample + add)
│   ├── P5 = lateral5
│   ├── P4 = lateral4 + upsample(P5)
│   └── P3 = lateral3 + upsample(P4)
├── Output convolutions (3x3 conv)
│   ├── P3 → fpn_conv3
│   ├── P4 → fpn_conv4
│   └── P5 → fpn_conv5
└── Additional levels (stride-2 3×3 conv + ReLU)
    ├── P6 = ReLU(Conv3x3_s2(P5))
    └── P7 = ReLU(Conv3x3_s2(P6))

Feature Pyramid Levels

Level Stride Resolution (640px) Typical Object Size
P3 8 80×80 Very small (8-64px)
P4 16 40×40 Small (64-128px)
P5 32 20×20 Medium (128-256px)
P6 64 10×10 Large (256-512px)
P7 128 5×5 Very large (>512px)
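The resolutions in the table are simply the input size divided by each level's stride; for a 640 px input:

```python
strides = [8, 16, 32, 64, 128]  # P3..P7
image_size = 640

for level, s in zip(range(3, 8), strides):
    side = image_size // s  # spatial side length at this level
    print(f"P{level}: {side}x{side}")
# P3: 80x80, P4: 40x40, P5: 20x20, P6: 10x10, P7: 5x5
```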

How FPN Works

  1. Bottom-up pathway: Backbone extracts features at multiple scales (C3, C4, C5)
  2. Lateral connections: 1×1 convolutions reduce channels to out_channels
  3. Top-down pathway: Higher-level features are upsampled and added to lateral connections
  4. Output smoothing: 3×3 convolutions smooth the merged features
  5. Additional levels: P6 and P7 are created via stride-2 3×3 convolutions for detecting larger objects
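The top-down merge in steps 2-3 can be reproduced in a few lines with plain torch ops; this is a sketch of the fusion loop, not the autotimm implementation:

```python
import torch
import torch.nn.functional as F

# Lateral features after the 1x1 convs, all 256 channels (C3, C4, C5 scales)
laterals = [
    torch.randn(1, 256, 80, 80),
    torch.randn(1, 256, 40, 40),
    torch.randn(1, 256, 20, 20),
]

# Top-down fusion: upsample the coarser level and add it to the finer one
for i in range(len(laterals) - 1, 0, -1):
    laterals[i - 1] = laterals[i - 1] + F.interpolate(
        laterals[i], size=laterals[i - 1].shape[-2:], mode="nearest"
    )

print([tuple(l.shape) for l in laterals])
# [(1, 256, 80, 80), (1, 256, 40, 40), (1, 256, 20, 20)]
```

Note that the spatial shapes are unchanged; only the content of the finer levels now carries upsampled high-level semantics.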

Design Choices

Why out_channels=256?

  • Standard choice in FPN literature (from original paper)
  • Good balance between capacity and efficiency
  • Can use 128 for faster inference or 512 for higher capacity

Why 5 pyramid levels?

  • Covers object scales from very small to very large
  • P3-P7 handles objects from ~8px to >512px
  • More levels = better multi-scale detection but slower inference

Why top-down + lateral?

  • Top-down: High-level semantic information flows to lower levels
  • Lateral: Preserves precise localization from lower levels
  • Combination: Semantically strong features at all scales

Usage in Detection

from autotimm import ObjectDetector

# FPN is automatically created inside ObjectDetector
model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    fpn_channels=256,  # Controls FPN out_channels
)

# FPN construction:
# 1. Backbone extracts [C3, C4, C5] with channels [512, 1024, 2048]
# 2. FPN converts to [P3, P4, P5, P6, P7] with uniform 256 channels
# 3. DetectionHead processes all pyramid levels