Heads¶
Neural network heads for classification and object detection tasks.
ClassificationHead¶
Simple classification head with optional dropout.
API Reference¶
autotimm.ClassificationHead ¶
Bases: Module
Linear classification head with optional dropout.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Dimensionality of the backbone output. | *required* |
| `num_classes` | `int` | Number of target classes. | *required* |
| `dropout` | `float` | Dropout probability before the final linear layer. | `0.0` |
Source code in src/autotimm/heads/_heads.py
__init__ ¶
__init__(in_features: int, num_classes: int, dropout: float = 0.0)
Usage Example¶
```python
from autotimm import ClassificationHead
import torch

head = ClassificationHead(
    in_features=2048,
    num_classes=10,
    dropout=0.5,
)

features = torch.randn(32, 2048)
logits = head(features)  # (32, 10)
```
Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `in_features` | `int` | Required | Input feature dimension |
| `num_classes` | `int` | Required | Number of output classes |
| `dropout` | `float` | `0.0` | Dropout rate before classifier |
Architecture¶
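Per the description above, the head is a dropout layer followed by a single linear classifier. A minimal equivalent sketch (an assumption based on the documented behavior; the authoritative implementation is in src/autotimm/heads/_heads.py):

```python
import torch
from torch import nn

# Sketch of an equivalent module: Dropout followed by a Linear classifier,
# using the same hyperparameters as the usage example above.
equivalent = nn.Sequential(
    nn.Dropout(p=0.5),    # dropout before the final linear layer
    nn.Linear(2048, 10),  # in_features -> num_classes
)

logits = equivalent(torch.randn(32, 2048))
print(logits.shape)  # torch.Size([32, 10])
```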
DetectionHead¶
FCOS-style detection head with classification, bbox regression, and centerness branches.
API Reference¶
autotimm.DetectionHead ¶
Bases: Module
FCOS-style detection head with classification, regression, and centerness branches.
This is an anchor-free detection head that predicts:

- Class scores for each spatial location
- Bounding box regression (distances to left, top, right, bottom)
- Centerness scores to downweight predictions far from object centers
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_channels` | `int` | Number of input channels from FPN. | `256` |
| `num_classes` | `int` | Number of object classes (excluding background). | `80` |
| `num_convs` | `int` | Number of conv layers in each branch before prediction. | `4` |
| `prior_prob` | `float` | Prior probability for focal loss initialization. | `0.01` |
| `use_group_norm` | `bool` | Whether to use GroupNorm after conv layers. | `True` |
| `num_groups` | `int` | Number of groups for GroupNorm. | `32` |
Source code in src/autotimm/heads/_heads.py
__init__ ¶
__init__(in_channels: int = 256, num_classes: int = 80, num_convs: int = 4, prior_prob: float = 0.01, use_group_norm: bool = True, num_groups: int = 32)
Source code in src/autotimm/heads/_heads.py
forward ¶
forward(features: list[Tensor]) -> tuple[list[Tensor], list[Tensor], list[Tensor]]
Forward pass through detection head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `features` | `list[Tensor]` | List of FPN features [P3, P4, P5, P6, P7]. | *required* |
Returns:
| Type | Description |
|---|---|
| `tuple[list[Tensor], list[Tensor], list[Tensor]]` | Tuple of `(cls_outputs, reg_outputs, centerness_outputs)`, each a list with one tensor per FPN level, with shapes `(B, num_classes, H, W)`, `(B, 4, H, W)`, and `(B, 1, H, W)` respectively. |
Source code in src/autotimm/heads/_heads.py
Usage Example¶
```python
from autotimm import DetectionHead
import torch

head = DetectionHead(
    in_channels=256,
    num_classes=80,
    num_convs=4,
)

features = [
    torch.randn(2, 256, 80, 80),  # P3
    torch.randn(2, 256, 40, 40),  # P4
    torch.randn(2, 256, 20, 20),  # P5
]
cls_scores, bbox_preds, centernesses = head(features)
# cls_scores: list of (B, num_classes, H, W) tensors
# bbox_preds: list of (B, 4, H, W) tensors (l, t, r, b)
# centernesses: list of (B, 1, H, W) tensors
```
Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `in_channels` | `int` | `256` | Number of input channels from FPN |
| `num_classes` | `int` | `80` | Number of object classes |
| `num_convs` | `int` | `4` | Number of conv layers per branch |
| `prior_prob` | `float` | `0.01` | Prior probability for focal loss initialization |
Architecture¶
```
DetectionHead (shared across all FPN levels)
├── cls_subnet
│   ├── Conv3x3 + GroupNorm + ReLU (×num_convs)
│   └── Conv3x3 → (num_classes, H, W)
├── bbox_subnet
│   ├── Conv3x3 + GroupNorm + ReLU (×num_convs)
│   └── Conv3x3 → (4, H, W)  # l, t, r, b offsets
└── centerness_subnet
    ├── Conv3x3 + GroupNorm + ReLU (×num_convs)
    └── Conv3x3 → (1, H, W)
```
Output Format¶
Classification Scores:

- Shape: `(B, num_classes, H, W)` for each FPN level
- Values: Raw logits (apply sigmoid for probabilities)

Bounding Box Predictions:

- Shape: `(B, 4, H, W)` for each FPN level
- Format: `(left, top, right, bottom)` offsets from each location
- Values: Raw predictions (apply `exp()` to get distances)

Centerness Scores:

- Shape: `(B, 1, H, W)` for each FPN level
- Values: Raw logits (apply sigmoid for probabilities)
- Purpose: Suppresses low-quality detections far from object centers
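The raw outputs described above can be converted into usable scores and distances as sketched below. Random tensors stand in for one FPN level's outputs; combining class probability with centerness (`cls_prob * centerness`) follows the usual FCOS scoring convention and is an assumption about downstream usage, not part of the head itself.

```python
import torch

# Stand-ins for one FPN level's raw outputs (B=2, 80 classes, 80x80 grid)
cls_logits = torch.randn(2, 80, 80, 80)
reg_raw = torch.randn(2, 4, 80, 80)
ctr_logits = torch.randn(2, 1, 80, 80)

cls_prob = cls_logits.sigmoid()    # class probabilities in [0, 1]
distances = reg_raw.exp()          # positive (l, t, r, b) distances
centerness = ctr_logits.sigmoid()  # centerness in [0, 1]

# Final detection score: class probability scaled by centerness,
# which suppresses predictions far from object centers.
score = cls_prob * centerness
```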
FPN¶
Feature Pyramid Network for multi-scale feature extraction.
API Reference¶
autotimm.FPN ¶
Bases: Module
Feature Pyramid Network for multi-scale object detection.
Takes feature maps from a backbone at multiple scales (C2-C5) and produces a feature pyramid (P3-P7) with top-down pathway and lateral connections.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_channels_list` | `Sequence[int]` | List of input channel counts from backbone features. Typically `[256, 512, 1024, 2048]` for ResNet-50. | *required* |
| `out_channels` | `int` | Number of output channels for all pyramid levels. | `256` |
| `num_extra_levels` | `int` | Number of extra levels to add beyond the backbone features. Default 2 adds P6 and P7 from P5. | `2` |
| `use_depthwise` | `bool` | Whether to use depthwise separable convolutions. | `False` |
Source code in src/autotimm/heads/_heads.py
__init__ ¶
__init__(in_channels_list: Sequence[int], out_channels: int = 256, num_extra_levels: int = 2, use_depthwise: bool = False)
Source code in src/autotimm/heads/_heads.py
forward ¶
forward(features: list[Tensor]) -> list[Tensor]
Forward pass through FPN.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `features` | `list[Tensor]` | List of feature maps from backbone [C2, C3, C4, C5] or similar. | *required* |
Returns:
| Type | Description |
|---|---|
| `list[Tensor]` | List of pyramid features [P3, P4, P5, P6, P7] with uniform channels. |
Source code in src/autotimm/heads/_heads.py
Usage Example¶
```python
from autotimm import FPN, create_feature_backbone
import torch

# Create backbone
backbone = create_feature_backbone("resnet50")
in_channels = [512, 1024, 2048]  # C3, C4, C5

# Create FPN
fpn = FPN(
    in_channels_list=in_channels,
    out_channels=256,
)

# Extract features
images = torch.randn(2, 3, 640, 640)
features = backbone(images)  # [C3, C4, C5]
pyramid = fpn(features)  # [P3, P4, P5, P6, P7]

# Pyramid features all have 256 channels
for i, feat in enumerate(pyramid):
    print(f"P{i + 3}: {feat.shape}")
# P3: torch.Size([2, 256, 80, 80])
# P4: torch.Size([2, 256, 40, 40])
# P5: torch.Size([2, 256, 20, 20])
# P6: torch.Size([2, 256, 10, 10])
# P7: torch.Size([2, 256, 5, 5])
```
Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `in_channels_list` | `Sequence[int]` | Required | Input channels for each backbone level (C3, C4, C5) |
| `out_channels` | `int` | `256` | Output channels for all pyramid levels |
Architecture¶
```
FPN
├── Lateral connections (1x1 conv)
│   ├── C5 → lateral_conv5
│   ├── C4 → lateral_conv4
│   └── C3 → lateral_conv3
├── Top-down pathway (upsample + add)
│   ├── P5 = lateral5
│   ├── P4 = lateral4 + upsample(P5)
│   └── P3 = lateral3 + upsample(P4)
├── Output convolutions (3x3 conv)
│   ├── P3 → fpn_conv3
│   ├── P4 → fpn_conv4
│   └── P5 → fpn_conv5
└── Additional levels (downsampling)
    ├── P6 = MaxPool(P5)
    └── P7 = MaxPool(P6)
```
Feature Pyramid Levels¶
| Level | Stride | Resolution (640px) | Typical Object Size |
|---|---|---|---|
| P3 | 8 | 80×80 | Very small (8-64px) |
| P4 | 16 | 40×40 | Small (64-128px) |
| P5 | 32 | 20×20 | Medium (128-256px) |
| P6 | 64 | 10×10 | Large (256-512px) |
| P7 | 128 | 5×5 | Very large (>512px) |
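The resolution column follows directly from the stride: each level's feature map is `image_size // stride` on a side. A quick check, assuming a square 640 px input:

```python
# Each pyramid level's spatial resolution is image_size // stride.
image_size = 640
strides = {"P3": 8, "P4": 16, "P5": 32, "P6": 64, "P7": 128}
resolutions = {level: image_size // s for level, s in strides.items()}
print(resolutions)  # {'P3': 80, 'P4': 40, 'P5': 20, 'P6': 10, 'P7': 5}
```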
How FPN Works¶
- Bottom-up pathway: Backbone extracts features at multiple scales (C3, C4, C5)
- Lateral connections: 1×1 convolutions reduce channels to `out_channels`
- Top-down pathway: Higher-level features are upsampled and added to lateral connections
- Output smoothing: 3×3 convolutions smooth the merged features
- Additional levels: P6 and P7 are created via max pooling for detecting larger objects
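The steps above can be sketched in a few lines. This is a simplified stand-alone version, not the autotimm implementation; channel counts and spatial sizes follow the ResNet-50 example, and nearest-neighbor upsampling plus stride-2 max pooling are assumptions matching the description.

```python
import torch
import torch.nn.functional as F
from torch import nn

in_channels_list, out_channels = [512, 1024, 2048], 256

# Lateral 1x1 convs reduce each backbone level to out_channels
laterals = nn.ModuleList(nn.Conv2d(c, out_channels, 1) for c in in_channels_list)
# Output 3x3 convs smooth the merged features
smooths = nn.ModuleList(
    nn.Conv2d(out_channels, out_channels, 3, padding=1) for _ in in_channels_list
)

# Fake backbone features C3, C4, C5 for a 640px input (strides 8, 16, 32)
c3, c4, c5 = (torch.randn(2, c, s, s) for c, s in zip(in_channels_list, (80, 40, 20)))

# Top-down pathway: upsample and add to the lateral connection below
p5 = laterals[2](c5)
p4 = laterals[1](c4) + F.interpolate(p5, scale_factor=2, mode="nearest")
p3 = laterals[0](c3) + F.interpolate(p4, scale_factor=2, mode="nearest")
p3, p4, p5 = smooths[0](p3), smooths[1](p4), smooths[2](p5)

# Extra levels via max pooling
p6 = F.max_pool2d(p5, kernel_size=1, stride=2)
p7 = F.max_pool2d(p6, kernel_size=1, stride=2)
```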
Design Choices¶
Why out_channels=256?
- Standard choice in FPN literature (from original paper)
- Good balance between capacity and efficiency
- Can use 128 for faster inference or 512 for higher capacity
Why 5 pyramid levels?
- Covers object scales from very small to very large
- P3-P7 handles objects from ~8px to >512px
- More levels = better multi-scale detection but slower inference
Why top-down + lateral?
- Top-down: High-level semantic information flows to lower levels
- Lateral: Preserves precise localization from lower levels
- Combination: Semantically strong features at all scales
Usage in Detection¶
```python
from autotimm import ObjectDetector

# FPN is automatically created inside ObjectDetector
model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    fpn_channels=256,  # Controls FPN out_channels
)

# FPN construction:
# 1. Backbone extracts [C3, C4, C5] with channels [512, 1024, 2048]
# 2. FPN converts to [P3, P4, P5, P6, P7] with uniform 256 channels
# 3. DetectionHead processes all pyramid levels
```