ObjectDetector

Anchor-free object detector with timm backbones and Feature Pyramid Networks, supporting FCOS and YOLOX architectures.

Overview

ObjectDetector is a PyTorch Lightning module for object detection that combines:

  • Any timm backbone for feature extraction
  • Feature Pyramid Network (FPN) for multi-scale features
  • FCOS or YOLOX detection head (configurable via detection_arch)
  • Focal Loss, GIoU Loss, and Centerness Loss (FCOS) or YOLOX losses
  • NMS post-processing for inference
  • Configurable optimizer and scheduler

API Reference

autotimm.ObjectDetector

Bases: PreprocessingMixin, LightningModule

End-to-end object detector supporting FCOS and YOLOX architectures.

Architecture: timm backbone → FPN → Detection Head (FCOS/YOLOX) → NMS

Parameters:

Name Type Description Default
backbone str | FeatureBackboneConfig

A timm model name (str) or a FeatureBackboneConfig instance.

required
num_classes int

Number of object classes (excluding background).

required
detection_arch str

Detection architecture to use. Options: "fcos" or "yolox". Default is "fcos". YOLOX uses a decoupled head and no centerness prediction.

'fcos'
cls_loss_fn str | Module | None

Classification loss function. Can be: - A string from the loss registry (e.g., 'focal') - An instance of nn.Module (custom loss) - None (uses FocalLoss with focal_alpha and focal_gamma)

None
reg_loss_fn str | Module | None

Regression loss function. Can be: - A string from the loss registry (e.g., 'giou') - An instance of nn.Module (custom loss) - None (uses GIoULoss)

None
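The default GIoULoss penalizes both lack of overlap and slack in the smallest box enclosing the prediction and target. A minimal pure-Python sketch of the generalized IoU loss for a single box pair (illustrative only; the library's GIoULoss operates on batched tensors):

```python
def giou_loss(a: tuple, b: tuple) -> float:
    """Generalized IoU loss (1 - GIoU) for two boxes in (x1, y1, x2, y2) format."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    iou = inter / union
    # Smallest axis-aligned box enclosing both a and b.
    c_area = (max(ax2, bx2) - min(ax1, bx1)) * (max(ay2, by2) - min(ay1, by1))
    return 1.0 - (iou - (c_area - union) / c_area)

print(giou_loss((0, 0, 2, 2), (0, 0, 2, 2)))  # 0.0 for identical boxes
```

Unlike plain IoU loss, this still yields a gradient signal for disjoint boxes (the loss exceeds 1 and grows with separation).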
metrics MetricManager | list[MetricConfig] | None

A MetricManager instance or list of MetricConfig objects. Optional: if not provided, uses MeanAveragePrecision.

None
logging_config LoggingConfig | None

Optional LoggingConfig for enhanced logging.

None
transform_config TransformConfig | None

Optional TransformConfig for unified transform configuration. When provided, enables the preprocess() method for inference-time preprocessing using model-specific normalization.

None
lr float

Learning rate.

0.0001
weight_decay float

Weight decay for optimizer.

0.0001
optimizer str | dict[str, Any]

Optimizer name ("adamw", "adam", "sgd", etc.) or dict with "class" and "params" keys.

'adamw'
optimizer_kwargs dict[str, Any] | None

Additional kwargs for the optimizer.

None
scheduler str | dict[str, Any] | None

Scheduler name ("cosine", "step", "onecycle", etc.), dict config, or None for no scheduler.

'cosine'
scheduler_kwargs dict[str, Any] | None

Extra kwargs forwarded to the LR scheduler.

None
fpn_channels int

Number of channels in FPN layers.

256
head_num_convs int

Number of conv layers in detection head branches.

4
focal_alpha float

Alpha parameter for focal loss.

0.25
focal_gamma float

Gamma parameter for focal loss.

2.0
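Focal loss down-weights well-classified samples so training concentrates on hard ones; focal_alpha balances positives against negatives and focal_gamma controls the down-weighting. A pure-Python sketch of the per-sample binary focal loss (illustrative; not the library's batched implementation):

```python
import math

def focal_loss(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Binary focal loss for one predicted probability p and label y in {0, 1}."""
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma shrinks the loss for well-classified samples.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# A confident correct prediction contributes far less than an uncertain one:
print(focal_loss(0.95, 1) < focal_loss(0.55, 1))  # True
```

With gamma=0 and alpha=1 this reduces to ordinary cross-entropy.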
cls_loss_weight float

Weight for classification loss.

1.0
reg_loss_weight float

Weight for regression loss.

1.0
centerness_loss_weight float

Weight for centerness loss (FCOS only).

1.0
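For FCOS, the centerness target measures how close a grid location is to the centre of its assigned box, computed from the (l, t, r, b) regression distances. A sketch of the standard FCOS formula (assumed here to match the library's target computation in `_compute_targets_per_level`, whose full body is not shown):

```python
import math

def centerness(l: float, t: float, r: float, b: float) -> float:
    """FCOS centerness target: 1.0 at the box centre, approaching 0 near the edges."""
    return math.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))

# A point 40px from the left edge, 80px from the right, 20px from the top,
# 40px from the bottom of its assigned box:
print(centerness(40, 20, 80, 40))  # 0.5
```

At inference, predicted centerness scales the classification score, suppressing low-quality boxes from off-centre locations.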
score_thresh float

Score threshold for detections during inference.

0.05
nms_thresh float

IoU threshold for NMS.

0.5
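Greedy NMS keeps the highest-scoring box and drops any remaining box whose IoU with it exceeds this threshold. A pure-Python sketch of the procedure (the model itself uses a batched NMS in its post-processing):

```python
def box_iou(a, b):
    """IoU of two boxes in (x1, y1, x2, y2) format."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union

def nms(boxes, scores, iou_thresh=0.5):
    """Greedy NMS: keep boxes in score order, dropping overlaps above iou_thresh."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        order = [j for j in order if box_iou(boxes[best], boxes[j]) <= iou_thresh]
    return keep

boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores))  # [0, 2] -- box 1 overlaps box 0 with IoU ~0.68
```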
max_detections_per_image int

Maximum detections to keep per image.

100
freeze_backbone bool

If True, backbone parameters are frozen.

False
strides tuple[int, ...]

FPN output strides. Default (8, 16, 32, 64, 128) for P3-P7.

(8, 16, 32, 64, 128)
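Each FPN level predicts at grid locations placed at the centre of each feature-map cell in input-image coordinates, i.e. (x + 0.5) * stride, matching the point generation in training_step. A small sketch:

```python
def grid_points(feat_h: int, feat_w: int, stride: int):
    """Cell-centre coordinates of a feature map, in input-image pixels."""
    return [((x + 0.5) * stride, (y + 0.5) * stride)
            for y in range(feat_h)
            for x in range(feat_w)]

# A 2x2 feature map at stride 8 covers a 16x16 input with points at cell centres:
print(grid_points(2, 2, 8))  # [(4.0, 4.0), (12.0, 4.0), (4.0, 12.0), (12.0, 12.0)]
```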
regress_ranges tuple[tuple[int, int], ...] | None

Regression ranges for each FPN level (FCOS only).

None
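In FCOS, a location at a given FPN level is only assigned to boxes whose maximum regression distance max(l, t, r, b) falls inside that level's range, so each pyramid level specializes in a size band. A sketch using the documented defaults (the exact boundary handling here is illustrative and may differ from the implementation):

```python
DEFAULT_RANGES = [(-1, 64), (64, 128), (128, 256), (256, 512), (512, float("inf"))]

def assign_level(l, t, r, b, ranges=DEFAULT_RANGES):
    """FPN level index (0 -> P3 ... 4 -> P7) responsible for this box at this point."""
    m = max(l, t, r, b)
    for idx, (lo, hi) in enumerate(ranges):
        if lo < m <= hi:
            return idx
    return None

print(assign_level(10, 10, 10, 10))    # 0: small box handled by P3
print(assign_level(300, 40, 200, 40))  # 3: max distance 300 falls in (256, 512]
```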
compile_model bool

If True (default), apply torch.compile() to the backbone, FPN, and head for faster inference and training. Requires PyTorch 2.0+.

True
compile_kwargs dict[str, Any] | None

Optional dict of kwargs to pass to torch.compile(). Common options: mode ("default", "reduce-overhead", "max-autotune"), fullgraph (True/False), dynamic (True/False).

None
seed int | None

Random seed for reproducibility. If None (default), no seeding is performed.

None
deterministic bool

If True (default), enables deterministic algorithms in PyTorch for full reproducibility (may impact performance). Set to False for faster training.

True
Example

>>> model = ObjectDetector(
...     backbone="resnet50",
...     num_classes=80,
...     metrics=[
...         MetricConfig(
...             name="mAP",
...             backend="torchmetrics",
...             metric_class="MeanAveragePrecision",
...             params={},
...             stages=["val", "test"],
...         ),
...     ],
...     lr=1e-4,
... )

Source code in src/autotimm/tasks/object_detection.py
class ObjectDetector(PreprocessingMixin, pl.LightningModule):
    """End-to-end object detector supporting FCOS and YOLOX architectures.

    Architecture: timm backbone → FPN → Detection Head (FCOS/YOLOX) → NMS

    Parameters:
        backbone: A timm model name (str) or a :class:`FeatureBackboneConfig`.
        num_classes: Number of object classes (excluding background).
        detection_arch: Detection architecture to use. Options: ``"fcos"`` or ``"yolox"``.
            Default is ``"fcos"``. YOLOX uses a decoupled head and no centerness prediction.
        cls_loss_fn: Classification loss function. Can be:
            - A string from the loss registry (e.g., 'focal')
            - An instance of nn.Module (custom loss)
            - None (uses FocalLoss with focal_alpha and focal_gamma)
        reg_loss_fn: Regression loss function. Can be:
            - A string from the loss registry (e.g., 'giou')
            - An instance of nn.Module (custom loss)
            - None (uses GIoULoss)
        metrics: A :class:`MetricManager` instance or list of :class:`MetricConfig`
            objects. Optional - if not provided, uses MeanAveragePrecision.
        logging_config: Optional :class:`LoggingConfig` for enhanced logging.
        transform_config: Optional :class:`TransformConfig` for unified transform
            configuration. When provided, enables the ``preprocess()`` method
            for inference-time preprocessing using model-specific normalization.
        lr: Learning rate.
        weight_decay: Weight decay for optimizer.
        optimizer: Optimizer name (``"adamw"``, ``"adam"``, ``"sgd"``, etc.) or dict
            with ``"class"`` and ``"params"`` keys.
        optimizer_kwargs: Additional kwargs for the optimizer.
        scheduler: Scheduler name (``"cosine"``, ``"step"``, ``"onecycle"``, etc.),
            dict config, or ``None`` for no scheduler.
        scheduler_kwargs: Extra kwargs forwarded to the LR scheduler.
        fpn_channels: Number of channels in FPN layers.
        head_num_convs: Number of conv layers in detection head branches.
        focal_alpha: Alpha parameter for focal loss.
        focal_gamma: Gamma parameter for focal loss.
        cls_loss_weight: Weight for classification loss.
        reg_loss_weight: Weight for regression loss.
        centerness_loss_weight: Weight for centerness loss (FCOS only).
        score_thresh: Score threshold for detections during inference.
        nms_thresh: IoU threshold for NMS.
        max_detections_per_image: Maximum detections to keep per image.
        freeze_backbone: If ``True``, backbone parameters are frozen.
        strides: FPN output strides. Default (8, 16, 32, 64, 128) for P3-P7.
        regress_ranges: Regression ranges for each FPN level (FCOS only).
        compile_model: If ``True`` (default), apply ``torch.compile()`` to the backbone, FPN, and head
            for faster inference and training. Requires PyTorch 2.0+.
        compile_kwargs: Optional dict of kwargs to pass to ``torch.compile()``.
            Common options: ``mode`` (``"default"``, ``"reduce-overhead"``, ``"max-autotune"``),
            ``fullgraph`` (``True``/``False``), ``dynamic`` (``True``/``False``).
        seed: Random seed for reproducibility. If ``None`` (default), no seeding
            is performed.
        deterministic: If ``True`` (default), enables deterministic algorithms in PyTorch
            for full reproducibility (may impact performance). Set to ``False`` for
            faster training.

    Example:
        >>> model = ObjectDetector(
        ...     backbone="resnet50",
        ...     num_classes=80,
        ...     metrics=[
        ...         MetricConfig(
        ...             name="mAP",
        ...             backend="torchmetrics",
        ...             metric_class="MeanAveragePrecision",
        ...             params={},
        ...             stages=["val", "test"],
        ...         ),
        ...     ],
        ...     lr=1e-4,
        ... )
    """

    def __init__(
        self,
        backbone: str | FeatureBackboneConfig,
        num_classes: int,
        detection_arch: str = "fcos",
        cls_loss_fn: str | nn.Module | None = None,
        reg_loss_fn: str | nn.Module | None = None,
        metrics: MetricManager | list[MetricConfig] | None = None,
        logging_config: LoggingConfig | None = None,
        transform_config: TransformConfig | None = None,
        lr: float = 1e-4,
        weight_decay: float = 1e-4,
        optimizer: str | dict[str, Any] = "adamw",
        optimizer_kwargs: dict[str, Any] | None = None,
        scheduler: str | dict[str, Any] | None = "cosine",
        scheduler_kwargs: dict[str, Any] | None = None,
        fpn_channels: int = 256,
        head_num_convs: int = 4,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        cls_loss_weight: float = 1.0,
        reg_loss_weight: float = 1.0,
        centerness_loss_weight: float = 1.0,
        score_thresh: float = 0.05,
        nms_thresh: float = 0.5,
        max_detections_per_image: int = 100,
        freeze_backbone: bool = False,
        strides: tuple[int, ...] = (8, 16, 32, 64, 128),
        regress_ranges: tuple[tuple[int, int], ...] | None = None,
        compile_model: bool = True,
        compile_kwargs: dict[str, Any] | None = None,
        seed: int | None = None,
        deterministic: bool = True,
    ):
        # Seed for reproducibility
        if seed is not None:
            seed_everything(seed, deterministic=deterministic)
        elif deterministic:
            import warnings

            warnings.warn(
                "deterministic=True has no effect when seed=None. "
                "Set a seed value to enable deterministic behavior.",
                stacklevel=2,
            )

        # Normalize backbone to a plain string so it survives checkpoint
        # round-trip (FeatureBackboneConfig is not serialised by save_hyperparameters).
        backbone = backbone.model_name if hasattr(backbone, "model_name") else str(backbone)

        super().__init__()
        self.save_hyperparameters(
            ignore=["metrics", "logging_config", "transform_config", "cls_loss_fn", "reg_loss_fn"]
        )
        self.hparams.update({
            "backbone_name": backbone,
            "username": getpass.getuser(),
            "timestamp": _dt.datetime.now().isoformat(timespec="seconds"),
        })

        # Validate detection architecture
        if detection_arch not in ["fcos", "yolox"]:
            raise ValueError(
                f"detection_arch must be 'fcos' or 'yolox', got '{detection_arch}'"
            )

        self.detection_arch = detection_arch
        self.num_classes = num_classes
        self._lr = lr
        self._weight_decay = weight_decay
        self._optimizer = optimizer
        self._optimizer_kwargs = optimizer_kwargs or {}
        self._scheduler = scheduler
        self._scheduler_kwargs = scheduler_kwargs or {}

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.max_detections_per_image = max_detections_per_image
        self.strides = strides

        # Default regression ranges for FCOS (P3-P7)
        # Note: YOLOX doesn't use regression ranges (handles all scales dynamically)
        if regress_ranges is None:
            self.regress_ranges = (
                (-1, 64),
                (64, 128),
                (128, 256),
                (256, 512),
                (512, float("inf")),
            )
        else:
            self.regress_ranges = regress_ranges

        # Build model
        # For object detection, we typically use 3 backbone features (C3, C4, C5)
        # which combined with 2 extra FPN levels gives us P3-P7 (5 levels total)
        if isinstance(backbone, str):
            from autotimm.core.backbone import FeatureBackboneConfig

            backbone = FeatureBackboneConfig(model_name=backbone, out_indices=(2, 3, 4))
        elif not hasattr(backbone, "out_indices"):
            # If it's a FeatureBackboneConfig without out_indices set, use (2, 3, 4)
            if hasattr(backbone, "model_name"):
                from autotimm.core.backbone import FeatureBackboneConfig

                backbone = FeatureBackboneConfig(
                    model_name=backbone.model_name,
                    pretrained=getattr(backbone, "pretrained", True),
                    out_indices=(2, 3, 4),
                    drop_rate=getattr(backbone, "drop_rate", 0.0),
                    drop_path_rate=getattr(backbone, "drop_path_rate", 0.0),
                    extra_kwargs=getattr(backbone, "extra_kwargs", {}),
                )

        self.backbone = create_feature_backbone(backbone)
        in_channels = get_feature_channels(self.backbone)

        self.fpn = FPN(
            in_channels_list=in_channels,
            out_channels=fpn_channels,
            num_extra_levels=2,  # P6, P7 from P5
        )

        # Create detection head based on architecture
        if detection_arch == "fcos":
            self.head = DetectionHead(
                in_channels=fpn_channels,
                num_classes=num_classes,
                num_convs=head_num_convs,
                prior_prob=0.01,
            )
        elif detection_arch == "yolox":
            self.head = YOLOXHead(
                in_channels=fpn_channels,
                num_classes=num_classes,
                num_convs=(
                    head_num_convs if head_num_convs <= 2 else 2
                ),  # YOLOX typically uses 2 convs
                prior_prob=0.01,
                activation="silu",
            )

        # Losses
        self.cls_loss_weight = cls_loss_weight
        self.reg_loss_weight = reg_loss_weight
        self.centerness_loss_weight = centerness_loss_weight

        # Setup classification loss
        if cls_loss_fn is not None:
            if isinstance(cls_loss_fn, str):
                registry = get_loss_registry()
                loss_kwargs = {"reduction": "sum"}
                if cls_loss_fn == "focal":
                    loss_kwargs["alpha"] = focal_alpha
                    loss_kwargs["gamma"] = focal_gamma
                self.focal_loss = registry.get_loss(cls_loss_fn, **loss_kwargs)
            elif isinstance(cls_loss_fn, nn.Module):
                self.focal_loss = cls_loss_fn
            else:
                raise TypeError(
                    f"cls_loss_fn must be a string or nn.Module instance, got {type(cls_loss_fn)}"
                )
        else:
            # Default: FocalLoss
            self.focal_loss = FocalLoss(
                alpha=focal_alpha, gamma=focal_gamma, reduction="sum"
            )

        # Setup regression loss
        if reg_loss_fn is not None:
            if isinstance(reg_loss_fn, str):
                registry = get_loss_registry()
                self.giou_loss = registry.get_loss(reg_loss_fn, reduction="sum")
            elif isinstance(reg_loss_fn, nn.Module):
                self.giou_loss = reg_loss_fn
            else:
                raise TypeError(
                    f"reg_loss_fn must be a string or nn.Module instance, got {type(reg_loss_fn)}"
                )
        else:
            # Default: GIoULoss
            self.giou_loss = GIoULoss(reduction="sum")

        # Initialize metrics
        if metrics is None:
            # Default: use mAP for validation and test
            metrics = [
                MetricConfig(
                    name="mAP",
                    backend="torchmetrics",
                    metric_class="MeanAveragePrecision",
                    params={"box_format": "xyxy", "iou_type": "bbox"},
                    stages=["val", "test"],
                    prog_bar=True,
                ),
            ]

        if isinstance(metrics, list):
            # For detection metrics, num_classes isn't typically passed to constructor
            # We pass it but MetricManager will handle it appropriately
            self._metric_configs = metrics
            self._use_metric_manager = False
        else:
            self._metric_manager = metrics
            self._use_metric_manager = True

        # Register metrics as ModuleDicts
        self._register_detection_metrics()

        # Logging configuration
        self._logging_config = logging_config or LoggingConfig(
            log_learning_rate=False,
            log_gradient_norm=False,
        )

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Apply torch.compile for optimization (PyTorch 2.0+)
        # Skip on MPS — the inductor backend generates invalid Metal shaders.
        # Skip on Windows — the inductor backend requires cl.exe (MSVC) which
        # is typically unavailable; compilation is deferred so the try/except
        # at init time cannot catch the error.
        import sys as _sys
        _mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        _skip_compile = _mps_available or _sys.platform == "win32"
        if compile_model and not _skip_compile:
            try:
                compile_opts = compile_kwargs or {}
                self.backbone = torch.compile(self.backbone, **compile_opts)
                self.fpn = torch.compile(self.fpn, **compile_opts)
                self.head = torch.compile(self.head, **compile_opts)
            except Exception as e:
                import warnings

                warnings.warn(
                    f"torch.compile failed: {e}. Continuing without compilation. "
                    f"Ensure you have PyTorch 2.0+ for compile support.",
                    stacklevel=2,
                )
        elif compile_kwargs is not None:
            import warnings

            warnings.warn(
                "compile_kwargs is ignored when compile_model=False.",
                stacklevel=2,
            )

        # Setup transforms from config (PreprocessingMixin)
        self._setup_transforms(transform_config, task="detection")

    def _register_detection_metrics(self):
        """Register detection-specific metrics."""
        import torchmetrics.detection

        self.val_metrics = nn.ModuleDict()
        self.test_metrics = nn.ModuleDict()

        if self._use_metric_manager:
            # Use the manager's metrics
            self.val_metrics = self._metric_manager.get_val_metrics()
            self.test_metrics = self._metric_manager.get_test_metrics()
        else:
            # Create metrics from configs
            for config in self._metric_configs:
                if config.backend == "torchmetrics":
                    if hasattr(torchmetrics.detection, config.metric_class):
                        metric_cls = getattr(
                            torchmetrics.detection, config.metric_class
                        )
                    elif hasattr(torchmetrics, config.metric_class):
                        metric_cls = getattr(torchmetrics, config.metric_class)
                    else:
                        raise ValueError(f"Unknown metric: {config.metric_class}")

                    if "val" in config.stages:
                        self.val_metrics[config.name] = metric_cls(**config.params)
                    if "test" in config.stages:
                        self.test_metrics[config.name] = metric_cls(**config.params)

    def on_fit_start(self) -> None:
        """Capture batch_size from the datamodule when training begins."""
        if (
            self.trainer is not None
            and self.trainer.datamodule is not None
            and hasattr(self.trainer.datamodule, "batch_size")
        ):
            self.hparams["batch_size"] = self.trainer.datamodule.batch_size

    def forward(
        self, images: torch.Tensor
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor] | None]:
        """Forward pass through the detector.

        Args:
            images: Input images [B, C, H, W].

        Returns:
            Tuple of (cls_outputs, reg_outputs, centerness_outputs) per FPN level.
            For YOLOX, centerness_outputs is None.
        """
        features = self.backbone(images)
        fpn_features = self.fpn(features)

        if self.detection_arch == "fcos":
            cls_outputs, reg_outputs, centerness_outputs = self.head(fpn_features)
        elif self.detection_arch == "yolox":
            cls_outputs, reg_outputs = self.head(fpn_features)
            centerness_outputs = None
        else:
            raise ValueError(f"Unknown detection_arch: {self.detection_arch}")

        return cls_outputs, reg_outputs, centerness_outputs

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        images = batch["images"]
        target_boxes = batch["boxes"]  # List of [N_i, 4] tensors
        target_labels = batch["labels"]  # List of [N_i] tensors

        # Forward
        cls_outputs, reg_outputs, centerness_outputs = self(images)

        # Compute targets for each FPN level
        device = images.device
        batch_size = images.shape[0]
        img_h, img_w = images.shape[-2:]

        # Compute loss
        total_cls_loss = torch.tensor(0.0, device=device)
        total_reg_loss = torch.tensor(0.0, device=device)
        total_centerness_loss = torch.tensor(0.0, device=device)
        num_pos = 0

        # Prepare centerness outputs for iteration
        if centerness_outputs is None:
            # YOLOX doesn't use centerness
            centerness_iter = [None] * len(cls_outputs)
        else:
            centerness_iter = centerness_outputs

        for level_idx, (cls_out, reg_out, cent_out) in enumerate(
            zip(cls_outputs, reg_outputs, centerness_iter)
        ):
            stride = self.strides[level_idx]
            feat_h, feat_w = cls_out.shape[-2:]

            # Generate grid points for this level
            grid_y, grid_x = torch.meshgrid(
                torch.arange(feat_h, device=device, dtype=torch.float32),
                torch.arange(feat_w, device=device, dtype=torch.float32),
                indexing="ij",
            )
            # Points are at the center of each cell
            points_x = (grid_x + 0.5) * stride
            points_y = (grid_y + 0.5) * stride
            points = torch.stack([points_x, points_y], dim=-1)  # [H, W, 2]

            # Compute targets for this level
            level_cls_targets = []
            level_reg_targets = []
            level_centerness_targets = [] if self.detection_arch == "fcos" else None

            for b in range(batch_size):
                boxes = target_boxes[b]  # [N, 4] in xyxy format
                labels = target_labels[b]  # [N]

                if len(boxes) == 0:
                    # No objects - all background
                    cls_target = torch.full(
                        (feat_h, feat_w), -1, dtype=torch.long, device=device
                    )
                    reg_target = torch.zeros(feat_h, feat_w, 4, device=device)
                    cent_target = (
                        torch.zeros(feat_h, feat_w, device=device)
                        if self.detection_arch == "fcos"
                        else None
                    )
                else:
                    cls_target, reg_target, cent_target = (
                        self._compute_targets_per_level(
                            points, boxes, labels, stride, level_idx, (img_h, img_w)
                        )
                    )

                level_cls_targets.append(cls_target)
                level_reg_targets.append(reg_target)
                if self.detection_arch == "fcos":
                    level_centerness_targets.append(cent_target)

            # Stack batch
            cls_targets = torch.stack(level_cls_targets)  # [B, H, W]
            reg_targets = torch.stack(level_reg_targets)  # [B, H, W, 4]
            cent_targets = (
                torch.stack(level_centerness_targets)
                if self.detection_arch == "fcos"
                else None
            )  # [B, H, W]

            # Compute classification loss (all locations except ignored=-1)
            cls_out_flat = cls_out.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            cls_targets_flat = cls_targets.reshape(-1)
            valid_mask = cls_targets_flat >= 0

            if valid_mask.any():
                # For focal loss, background class is handled implicitly
                # Positive samples have class labels, negatives have -1 (ignored in loss)
                total_cls_loss = total_cls_loss + self.focal_loss(
                    cls_out_flat, cls_targets_flat
                )

            # Compute regression and centerness loss (positive samples only)
            pos_mask = cls_targets >= 0  # [B, H, W]

            if pos_mask.any():
                pos_reg_pred = reg_out.permute(0, 2, 3, 1)[pos_mask]  # [N_pos, 4]
                pos_reg_target = reg_targets[pos_mask]  # [N_pos, 4]

                # IoU-based regression loss
                reg_loss = self._compute_iou_loss(pos_reg_pred, pos_reg_target)
                total_reg_loss = total_reg_loss + reg_loss

                # Centerness BCE loss (FCOS only)
                if self.detection_arch == "fcos" and cent_out is not None:
                    pos_cent_pred = cent_out.squeeze(1)[pos_mask]  # [N_pos]
                    pos_cent_target = cent_targets[pos_mask]  # [N_pos]
                    cent_loss = F.binary_cross_entropy_with_logits(
                        pos_cent_pred, pos_cent_target, reduction="sum"
                    )
                    total_centerness_loss = total_centerness_loss + cent_loss

                num_pos += pos_mask.sum().item()

        # Normalize by number of positive samples
        num_pos = max(num_pos, 1)

        cls_loss = self.cls_loss_weight * total_cls_loss / num_pos
        reg_loss = self.reg_loss_weight * total_reg_loss / num_pos

        # Centerness loss only for FCOS
        if self.detection_arch == "fcos":
            centerness_loss = (
                self.centerness_loss_weight * total_centerness_loss / num_pos
            )
            total_loss = cls_loss + reg_loss + centerness_loss
        else:
            centerness_loss = torch.tensor(0.0, device=device)
            total_loss = cls_loss + reg_loss

        # Log losses
        self.log("train/loss", total_loss, prog_bar=True)
        self.log("train/cls_loss", cls_loss)
        self.log("train/reg_loss", reg_loss)
        if self.detection_arch == "fcos":
            self.log("train/centerness_loss", centerness_loss)
        self.log("train/num_pos", float(num_pos))

        # Enhanced logging
        if self._logging_config.log_learning_rate:
            opt = self.optimizers()
            if opt is not None and hasattr(opt, "param_groups"):
                lr = opt.param_groups[0]["lr"]
                self.log("train/lr", lr, on_step=True, on_epoch=False)

        return total_loss

    def _compute_targets_per_level(
        self,
        points: torch.Tensor,
        boxes: torch.Tensor,
        labels: torch.Tensor,
        stride: int,
        level_idx: int,
        img_size: tuple[int, int],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute FCOS targets for a single image at one FPN level.

        Args:
            points: Grid points [H, W, 2].
            boxes: Target boxes [N, 4] in xyxy format.
            labels: Target labels [N].
            stride: Stride for this FPN level.
            level_idx: Index of FPN level.
            img_size: (H, W) of input image.

        Returns:
            cls_target: [H, W] with class labels or -1 for ignore.
            reg_target: [H, W, 4] with (l, t, r, b) distances.
            centerness_target: [H, W] with centerness values.
        """
        device = points.device
        feat_h, feat_w = points.shape[:2]

        # Expand points and boxes for broadcasting
        points_flat = points.reshape(-1, 2)  # [H*W, 2]
        num_points = points_flat.shape[0]

        # Compute distances from each point to each box
        # boxes: [N, 4] -> [1, N, 4] for broadcasting
        boxes_exp = boxes.unsqueeze(0)  # [1, N, 4]
        points_exp = points_flat.unsqueeze(1)  # [H*W, 1, 2]

        # Left, top, right, bottom distances
        left = points_exp[..., 0] - boxes_exp[..., 0]  # [H*W, N]
        top = points_exp[..., 1] - boxes_exp[..., 1]
        right = boxes_exp[..., 2] - points_exp[..., 0]
        bottom = boxes_exp[..., 3] - points_exp[..., 1]

        reg_targets_per_box = torch.stack(
            [left, top, right, bottom], dim=-1
        )  # [H*W, N, 4]

        # Check if point is inside box
        inside_box = (left > 0) & (top > 0) & (right > 0) & (bottom > 0)  # [H*W, N]

        # Check regression range constraint
        max_reg = reg_targets_per_box.max(dim=-1)[0]  # [H*W, N]
        min_range, max_range = self.regress_ranges[level_idx]
        in_range = (max_reg >= min_range) & (max_reg < max_range)

        # Valid assignments: inside box AND in regression range
        valid = inside_box & in_range  # [H*W, N]

        # For each point, find the box with minimum area (most specific)
        box_areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])  # [N]
        box_areas_exp = box_areas.unsqueeze(0).expand(num_points, -1)  # [H*W, N]

        # Set invalid assignments to inf area
        box_areas_masked = torch.where(
            valid, box_areas_exp, torch.tensor(float("inf"), device=device)
        )

        # Find best box for each point
        min_areas, best_box_idx = box_areas_masked.min(dim=1)  # [H*W]
        has_assignment = min_areas < float("inf")

        # Create targets
        cls_target = torch.full((num_points,), -1, dtype=torch.long, device=device)
        reg_target = torch.zeros(num_points, 4, device=device)
        centerness_target = torch.zeros(num_points, device=device)

        if has_assignment.any():
            assigned_idx = best_box_idx[has_assignment]
            cls_target[has_assignment] = labels[assigned_idx]

            # Get regression targets for assigned points
            point_indices = torch.arange(num_points, device=device)[has_assignment]
            reg_target[has_assignment] = reg_targets_per_box[
                point_indices, assigned_idx
            ]

            # Compute centerness
            ltrb = reg_target[has_assignment]
            left_right_min = torch.min(ltrb[:, 0], ltrb[:, 2])
            left_right_max = torch.max(ltrb[:, 0], ltrb[:, 2])
            top_bottom_min = torch.min(ltrb[:, 1], ltrb[:, 3])
            top_bottom_max = torch.max(ltrb[:, 1], ltrb[:, 3])

            centerness = torch.sqrt(
                (left_right_min / left_right_max.clamp(min=1e-7))
                * (top_bottom_min / top_bottom_max.clamp(min=1e-7))
            )
            centerness_target[has_assignment] = centerness

        # Reshape to spatial dimensions
        cls_target = cls_target.reshape(feat_h, feat_w)
        reg_target = reg_target.reshape(feat_h, feat_w, 4)
        centerness_target = centerness_target.reshape(feat_h, feat_w)

        return cls_target, reg_target, centerness_target
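The centerness formula used above, `sqrt(min(l, r)/max(l, r) * min(t, b)/max(t, b))`, can be checked with a torch-free sketch (`fcos_centerness` is a hypothetical helper for illustration, not part of the API):

```python
import math

def fcos_centerness(left: float, top: float, right: float, bottom: float) -> float:
    """Centerness of a point from its (l, t, r, b) distances to the box edges."""
    lr = min(left, right) / max(max(left, right), 1e-7)
    tb = min(top, bottom) / max(max(top, bottom), 1e-7)
    return math.sqrt(lr * tb)

# A point at the exact center of its box scores 1.0 ...
print(fcos_centerness(10, 10, 10, 10))  # -> 1.0
# ... while an off-center point is down-weighted toward 0.
print(round(fcos_centerness(2, 10, 18, 10), 3))  # -> 0.333
```

This down-weighting suppresses low-quality boxes predicted far from object centers when scores are multiplied by centerness at inference time.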

    def _compute_iou_loss(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Compute IoU-based regression loss for LTRB predictions."""
        # Compute areas from LTRB distances
        pred_area = (pred[:, 0] + pred[:, 2]) * (pred[:, 1] + pred[:, 3])
        target_area = (target[:, 0] + target[:, 2]) * (target[:, 1] + target[:, 3])

        # Intersection
        inter_w = torch.min(pred[:, 0], target[:, 0]) + torch.min(
            pred[:, 2], target[:, 2]
        )
        inter_h = torch.min(pred[:, 1], target[:, 1]) + torch.min(
            pred[:, 3], target[:, 3]
        )
        inter_area = inter_w * inter_h

        # Union
        union_area = pred_area + target_area - inter_area

        # IoU loss (negative log)
        iou = inter_area / union_area.clamp(min=1e-7)
        loss = -torch.log(iou.clamp(min=1e-7))

        return loss.sum()
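Because prediction and target share the same anchor point, the intersection extents are just element-wise minima of the LTRB distances. A single-sample, torch-free sketch of the same computation (`ltrb_iou_loss` is a hypothetical helper, not part of the API):

```python
import math

def ltrb_iou_loss(pred, target):
    """-log(IoU) for one location's (l, t, r, b) prediction vs. target."""
    pl, pt, pr, pb = pred
    tl, tt, tr, tb = target
    pred_area = (pl + pr) * (pt + pb)
    target_area = (tl + tr) * (tt + tb)
    # Shared anchor point: intersection width/height come from the
    # element-wise minima of the distances.
    inter_w = min(pl, tl) + min(pr, tr)
    inter_h = min(pt, tt) + min(pb, tb)
    inter = inter_w * inter_h
    union = pred_area + target_area - inter
    iou = inter / max(union, 1e-7)
    return -math.log(max(iou, 1e-7))

# Perfect overlap gives zero loss; partial overlap gives -log(IoU) > 0.
print(round(ltrb_iou_loss((2, 5, 2, 5), (5, 5, 5, 5)), 3))  # -> 0.916 (IoU = 0.4)
```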

    def validation_step(self, batch: dict[str, Any], batch_idx: int) -> None:
        images = batch["images"]
        target_boxes = batch["boxes"]
        target_labels = batch["labels"]

        # Get predictions
        detections = self.predict(images)

        # Convert to format expected by torchmetrics.detection.MeanAveragePrecision
        preds = []
        targets = []

        for i in range(len(detections)):
            preds.append(
                {
                    "boxes": detections[i]["boxes"],
                    "scores": detections[i]["scores"],
                    "labels": detections[i]["labels"],
                }
            )
            targets.append(
                {
                    "boxes": target_boxes[i],
                    "labels": target_labels[i],
                }
            )

        # Update metrics
        for name, metric in self.val_metrics.items():
            metric.update(preds, targets)

    def on_validation_epoch_end(self) -> None:
        is_sanity = getattr(self.trainer, "sanity_checking", False)
        prefix = "sanity_val" if is_sanity else "val"

        for name, metric in self.val_metrics.items():
            result = metric.compute()
            # MeanAveragePrecision returns a dict
            if isinstance(result, dict):
                for key, value in result.items():
                    if isinstance(value, torch.Tensor) and value.numel() == 1:
                        self.log(f"{prefix}/{key}", value, prog_bar=(key == "map"))
            else:
                self.log(f"{prefix}/{name}", result)
            metric.reset()

    def test_step(self, batch: dict[str, Any], batch_idx: int) -> None:
        images = batch["images"]
        target_boxes = batch["boxes"]
        target_labels = batch["labels"]

        detections = self.predict(images)

        preds = []
        targets = []

        for i in range(len(detections)):
            preds.append(
                {
                    "boxes": detections[i]["boxes"],
                    "scores": detections[i]["scores"],
                    "labels": detections[i]["labels"],
                }
            )
            targets.append(
                {
                    "boxes": target_boxes[i],
                    "labels": target_labels[i],
                }
            )

        for name, metric in self.test_metrics.items():
            metric.update(preds, targets)

    def on_test_epoch_end(self) -> None:
        for name, metric in self.test_metrics.items():
            result = metric.compute()
            if isinstance(result, dict):
                for key, value in result.items():
                    if isinstance(value, torch.Tensor) and value.numel() == 1:
                        self.log(f"test/{key}", value)
            else:
                self.log(f"test/{name}", result)
            metric.reset()

    def predict(self, images: torch.Tensor) -> list[dict[str, torch.Tensor]]:
        """Run inference and return detections with NMS.

        Args:
            images: Input images [B, C, H, W].

        Returns:
            List of dicts per image with 'boxes', 'scores', 'labels'.
        """
        cls_outputs, reg_outputs, centerness_outputs = self(images)

        batch_size = images.shape[0]
        img_h, img_w = images.shape[-2:]
        device = images.device

        all_detections = []

        # Prepare centerness outputs for iteration
        if centerness_outputs is None:
            centerness_iter = [None] * len(cls_outputs)
        else:
            centerness_iter = centerness_outputs

        for b in range(batch_size):
            all_boxes = []
            all_scores = []
            all_labels = []

            for level_idx, (cls_out, reg_out, cent_out) in enumerate(
                zip(cls_outputs, reg_outputs, centerness_iter)
            ):
                stride = self.strides[level_idx]
                feat_h, feat_w = cls_out.shape[-2:]

                # Get predictions for this image
                cls_logits = cls_out[b]  # [C, H, W]
                reg_pred = reg_out[b]  # [4, H, W]

                # Generate grid points
                grid_y, grid_x = torch.meshgrid(
                    torch.arange(feat_h, device=device, dtype=torch.float32),
                    torch.arange(feat_w, device=device, dtype=torch.float32),
                    indexing="ij",
                )
                points_x = (grid_x + 0.5) * stride
                points_y = (grid_y + 0.5) * stride

                # Flatten spatial dimensions
                cls_logits = cls_logits.permute(1, 2, 0).reshape(
                    -1, self.num_classes
                )  # [H*W, C]
                reg_pred = reg_pred.permute(1, 2, 0).reshape(-1, 4)  # [H*W, 4]
                points_x = points_x.reshape(-1)
                points_y = points_y.reshape(-1)

                # Compute scores
                cls_scores = cls_logits.sigmoid()
                if cent_out is not None:
                    # FCOS: classification * centerness
                    cent_pred = cent_out[b, 0]  # [H, W]
                    cent_pred = cent_pred.reshape(-1)  # [H*W]
                    centerness = cent_pred.sigmoid()
                    scores = cls_scores * centerness.unsqueeze(-1)  # [H*W, C]
                else:
                    # YOLOX: just use classification scores
                    scores = cls_scores  # [H*W, C]

                # Get max score per location
                max_scores, class_ids = scores.max(dim=-1)  # [H*W]

                # Filter by score threshold
                keep = max_scores > self.score_thresh
                if not keep.any():
                    continue

                max_scores = max_scores[keep]
                class_ids = class_ids[keep]
                reg_pred = reg_pred[keep]
                points_x = points_x[keep]
                points_y = points_y[keep]

                # Convert LTRB to xyxy boxes
                left = reg_pred[:, 0]
                top = reg_pred[:, 1]
                right = reg_pred[:, 2]
                bottom = reg_pred[:, 3]

                x1 = points_x - left
                y1 = points_y - top
                x2 = points_x + right
                y2 = points_y + bottom

                # Clamp to image bounds
                x1 = x1.clamp(min=0, max=img_w)
                y1 = y1.clamp(min=0, max=img_h)
                x2 = x2.clamp(min=0, max=img_w)
                y2 = y2.clamp(min=0, max=img_h)

                boxes = torch.stack([x1, y1, x2, y2], dim=-1)

                all_boxes.append(boxes)
                all_scores.append(max_scores)
                all_labels.append(class_ids)

            # Concatenate all levels
            if len(all_boxes) > 0:
                boxes = torch.cat(all_boxes, dim=0)
                scores = torch.cat(all_scores, dim=0)
                labels = torch.cat(all_labels, dim=0)

                # Apply NMS per class
                keep_indices = ops.batched_nms(boxes, scores, labels, self.nms_thresh)

                # Limit number of detections
                keep_indices = keep_indices[: self.max_detections_per_image]

                boxes = boxes[keep_indices]
                scores = scores[keep_indices]
                labels = labels[keep_indices]
            else:
                boxes = torch.zeros((0, 4), device=device)
                scores = torch.zeros((0,), device=device)
                labels = torch.zeros((0,), dtype=torch.long, device=device)

            all_detections.append(
                {
                    "boxes": boxes,
                    "scores": scores,
                    "labels": labels,
                }
            )

        return all_detections
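The per-level decode step above maps each grid cell to its image-space center `(x + 0.5) * stride`, offsets by the LTRB prediction, and clamps to the image. A torch-free sketch for one location (`decode_ltrb` is a hypothetical helper, not part of the API):

```python
def decode_ltrb(px, py, ltrb, img_w, img_h):
    """Decode an (l, t, r, b) prediction at point (px, py) into a clamped xyxy box."""
    l, t, r, b = ltrb
    x1 = min(max(px - l, 0), img_w)
    y1 = min(max(py - t, 0), img_h)
    x2 = min(max(px + r, 0), img_w)
    y2 = min(max(py + b, 0), img_h)
    return (x1, y1, x2, y2)

# Grid cell (4, 2) at stride 8 -> point center (36.0, 20.0);
# the left and right extents spill outside a 64x64 image and get clamped.
px, py = (4 + 0.5) * 8, (2 + 0.5) * 8
print(decode_ltrb(px, py, (50, 10, 30, 15), img_w=64, img_h=64))  # -> (0, 10.0, 64, 35.0)
```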

    def predict_step(self, batch: Any, batch_idx: int) -> list[dict[str, torch.Tensor]]:
        images = batch["images"] if isinstance(batch, dict) else batch
        return self.predict(images)

    def to_onnx(
        self,
        save_path: str | None = None,
        example_input: torch.Tensor | None = None,
        opset_version: int = 17,
        dynamic_axes: dict[str, dict[int, str]] | None = None,
        **kwargs: Any,
    ) -> str:
        """Export model to ONNX format.

        Detection models flatten their list outputs into named tensors for ONNX
        compatibility (e.g., cls_l0..cls_l4, reg_l0..reg_l4, ctr_l0..ctr_l4).

        Args:
            save_path: Path to save the ONNX model. If None, uses a temp file.
            example_input: Example input tensor. If None, uses default shape (1, 3, 224, 224).
            opset_version: ONNX opset version. Default is 17.
            dynamic_axes: Dynamic axes specification. If None, batch dimension is dynamic.
            **kwargs: Additional arguments passed to export_to_onnx.

        Returns:
            Path to the saved ONNX model.

        Example:
            >>> model = ObjectDetector(backbone="resnet50", num_classes=80)
            >>> path = model.to_onnx("detector.onnx")
        """
        from autotimm.export import export_to_onnx

        if example_input is None:
            example_input = torch.randn(1, 3, 224, 224)

        if save_path is None:
            import tempfile

            save_path = tempfile.mktemp(suffix=".onnx")

        return export_to_onnx(
            self,
            save_path,
            example_input,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes,
            **kwargs,
        )

    def configure_optimizers(self) -> dict:
        """Configure optimizer and learning rate scheduler."""
        params = filter(lambda p: p.requires_grad, self.parameters())
        optimizer = self._create_optimizer(params)

        if self._scheduler is None or self._scheduler == "none":
            return {"optimizer": optimizer}

        scheduler_config = self._create_scheduler(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler_config,
        }

    def _create_optimizer(self, params) -> torch.optim.Optimizer:
        """Create optimizer from config."""
        opt_kwargs = {"lr": self._lr, "weight_decay": self._weight_decay}
        opt_kwargs.update(self._optimizer_kwargs)

        if isinstance(self._optimizer, dict):
            opt_class_path = self._optimizer["class"]
            opt_params = self._optimizer.get("params", {})
            opt_kwargs.update(opt_params)
            optimizer_cls = self._import_class(opt_class_path)
            return optimizer_cls(params, **opt_kwargs)

        optimizer_name = self._optimizer.lower()

        torch_optimizers = {
            "adamw": torch.optim.AdamW,
            "adam": torch.optim.Adam,
            "sgd": torch.optim.SGD,
            "rmsprop": torch.optim.RMSprop,
        }

        if optimizer_name in torch_optimizers:
            return torch_optimizers[optimizer_name](params, **opt_kwargs)

        try:
            import timm.optim as timm_optim

            timm_optimizers = {
                "adamp": timm_optim.AdamP,
                "sgdp": timm_optim.SGDP,
                "lamb": timm_optim.Lamb,
            }
            if optimizer_name in timm_optimizers:
                return timm_optimizers[optimizer_name](params, **opt_kwargs)
        except ImportError:
            pass

        raise ValueError(f"Unknown optimizer: {self._optimizer}")

    def _create_scheduler(self, optimizer: torch.optim.Optimizer) -> dict:
        """Create scheduler config."""
        sched_kwargs = self._scheduler_kwargs.copy()
        interval = sched_kwargs.pop("interval", "step")
        frequency = sched_kwargs.pop("frequency", 1)

        if isinstance(self._scheduler, dict):
            sched_class_path = self._scheduler["class"]
            sched_params = self._scheduler.get("params", {})
            sched_kwargs.update(sched_params)
            scheduler_cls = self._import_class(sched_class_path)
            scheduler = scheduler_cls(optimizer, **sched_kwargs)
            return {
                "scheduler": scheduler,
                "interval": interval,
                "frequency": frequency,
            }

        scheduler_name = self._scheduler.lower()

        if scheduler_name == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=sched_kwargs.pop(
                    "T_max", self.trainer.estimated_stepping_batches
                ),
                **sched_kwargs,
            )
        elif scheduler_name == "step":
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=sched_kwargs.pop("step_size", 8),
                gamma=sched_kwargs.pop("gamma", 0.1),
                **sched_kwargs,
            )
        elif scheduler_name == "multistep":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=sched_kwargs.pop("milestones", [8, 11]),
                gamma=sched_kwargs.pop("gamma", 0.1),
                **sched_kwargs,
            )
        elif scheduler_name == "onecycle":
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=sched_kwargs.pop("max_lr", self._lr * 10),
                total_steps=sched_kwargs.pop(
                    "total_steps", self.trainer.estimated_stepping_batches
                ),
                **sched_kwargs,
            )
        else:
            raise ValueError(f"Unknown scheduler: {self._scheduler}")

        return {"scheduler": scheduler, "interval": interval, "frequency": frequency}

    def _import_class(self, class_path: str):
        """Import a class from a fully qualified path."""
        import importlib

        if "." not in class_path:
            raise ValueError(f"Class path must be fully qualified, got: {class_path}")

        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
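This resolution mechanism is what lets `optimizer` and `scheduler` accept dict configs like `{"class": "torch.optim.AdamW", "params": {...}}`. A minimal standalone sketch of the same pattern (`import_class` here is a hypothetical stand-in, shown with a stdlib class):

```python
import importlib

def import_class(class_path: str):
    """Resolve 'pkg.module.ClassName' into the class object."""
    module_path, class_name = class_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

# Stand-in config; in practice "class" would point at an optimizer
# or scheduler such as "torch.optim.AdamW".
cfg = {"class": "collections.OrderedDict", "params": {}}
print(import_class(cfg["class"]).__name__)  # -> OrderedDict
```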

    def on_before_optimizer_step(self, optimizer) -> None:
        """Hook for gradient norm logging."""
        if self._logging_config.log_gradient_norm:
            grad_norm = self._compute_gradient_norm()
            self.log("train/grad_norm", grad_norm, on_step=True, on_epoch=False)

        if self._logging_config.log_weight_norm:
            weight_norm = self._compute_weight_norm()
            self.log("train/weight_norm", weight_norm, on_step=True, on_epoch=False)

    def _compute_gradient_norm(self) -> torch.Tensor:
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        return torch.tensor(total_norm**0.5, device=self.device)

    def _compute_weight_norm(self) -> torch.Tensor:
        total_norm = 0.0
        for p in self.parameters():
            param_norm = p.data.norm(2)
            total_norm += param_norm.item() ** 2
        return torch.tensor(total_norm**0.5, device=self.device)
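Both norm helpers reduce per-parameter L2 norms to one global norm via sqrt-of-sum-of-squares. A torch-free sketch (`global_l2_norm` is a hypothetical helper for illustration):

```python
import math

def global_l2_norm(per_param_norms):
    """Combine per-parameter L2 norms into one global L2 norm."""
    return math.sqrt(sum(n * n for n in per_param_norms))

print(global_l2_norm([3.0, 4.0]))  # -> 5.0
```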

__init__

__init__(backbone: str | FeatureBackboneConfig, num_classes: int, detection_arch: str = 'fcos', cls_loss_fn: str | Module | None = None, reg_loss_fn: str | Module | None = None, metrics: MetricManager | list[MetricConfig] | None = None, logging_config: LoggingConfig | None = None, transform_config: TransformConfig | None = None, lr: float = 0.0001, weight_decay: float = 0.0001, optimizer: str | dict[str, Any] = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: str | dict[str, Any] | None = 'cosine', scheduler_kwargs: dict[str, Any] | None = None, fpn_channels: int = 256, head_num_convs: int = 4, focal_alpha: float = 0.25, focal_gamma: float = 2.0, cls_loss_weight: float = 1.0, reg_loss_weight: float = 1.0, centerness_loss_weight: float = 1.0, score_thresh: float = 0.05, nms_thresh: float = 0.5, max_detections_per_image: int = 100, freeze_backbone: bool = False, strides: tuple[int, ...] = (8, 16, 32, 64, 128), regress_ranges: tuple[tuple[int, int], ...] | None = None, compile_model: bool = True, compile_kwargs: dict[str, Any] | None = None, seed: int | None = None, deterministic: bool = True)
Source code in src/autotimm/tasks/object_detection.py
def __init__(
    self,
    backbone: str | FeatureBackboneConfig,
    num_classes: int,
    detection_arch: str = "fcos",
    cls_loss_fn: str | nn.Module | None = None,
    reg_loss_fn: str | nn.Module | None = None,
    metrics: MetricManager | list[MetricConfig] | None = None,
    logging_config: LoggingConfig | None = None,
    transform_config: TransformConfig | None = None,
    lr: float = 1e-4,
    weight_decay: float = 1e-4,
    optimizer: str | dict[str, Any] = "adamw",
    optimizer_kwargs: dict[str, Any] | None = None,
    scheduler: str | dict[str, Any] | None = "cosine",
    scheduler_kwargs: dict[str, Any] | None = None,
    fpn_channels: int = 256,
    head_num_convs: int = 4,
    focal_alpha: float = 0.25,
    focal_gamma: float = 2.0,
    cls_loss_weight: float = 1.0,
    reg_loss_weight: float = 1.0,
    centerness_loss_weight: float = 1.0,
    score_thresh: float = 0.05,
    nms_thresh: float = 0.5,
    max_detections_per_image: int = 100,
    freeze_backbone: bool = False,
    strides: tuple[int, ...] = (8, 16, 32, 64, 128),
    regress_ranges: tuple[tuple[int, int], ...] | None = None,
    compile_model: bool = True,
    compile_kwargs: dict[str, Any] | None = None,
    seed: int | None = None,
    deterministic: bool = True,
):
    # Seed for reproducibility
    if seed is not None:
        seed_everything(seed, deterministic=deterministic)
    elif deterministic:
        import warnings

        warnings.warn(
            "deterministic=True has no effect when seed=None. "
            "Set a seed value to enable deterministic behavior.",
            stacklevel=2,
        )

    # Normalize backbone to a plain string so it survives checkpoint
    # round-trip (FeatureBackboneConfig is not serialised by save_hyperparameters).
    backbone = backbone.model_name if hasattr(backbone, "model_name") else str(backbone)

    super().__init__()
    self.save_hyperparameters(
        ignore=["metrics", "logging_config", "transform_config", "cls_loss_fn", "reg_loss_fn"]
    )
    self.hparams.update({
        "backbone_name": backbone,
        "username": getpass.getuser(),
        "timestamp": _dt.datetime.now().isoformat(timespec="seconds"),
    })

    # Validate detection architecture
    if detection_arch not in ["fcos", "yolox"]:
        raise ValueError(
            f"detection_arch must be 'fcos' or 'yolox', got '{detection_arch}'"
        )

    self.detection_arch = detection_arch
    self.num_classes = num_classes
    self._lr = lr
    self._weight_decay = weight_decay
    self._optimizer = optimizer
    self._optimizer_kwargs = optimizer_kwargs or {}
    self._scheduler = scheduler
    self._scheduler_kwargs = scheduler_kwargs or {}

    self.score_thresh = score_thresh
    self.nms_thresh = nms_thresh
    self.max_detections_per_image = max_detections_per_image
    self.strides = strides

    # Default regression ranges for FCOS (P3-P7)
    # Note: YOLOX doesn't use regression ranges (handles all scales dynamically)
    if regress_ranges is None:
        self.regress_ranges = (
            (-1, 64),
            (64, 128),
            (128, 256),
            (256, 512),
            (512, float("inf")),
        )
    else:
        self.regress_ranges = regress_ranges

    # Build model
    # For object detection, we typically use 3 backbone features (C3, C4, C5)
    # which combined with 2 extra FPN levels gives us P3-P7 (5 levels total)
    if isinstance(backbone, str):
        from autotimm.core.backbone import FeatureBackboneConfig

        backbone = FeatureBackboneConfig(model_name=backbone, out_indices=(2, 3, 4))
    elif not hasattr(backbone, "out_indices"):
        # If it's a FeatureBackboneConfig without out_indices set, use (2, 3, 4)
        if hasattr(backbone, "model_name"):
            from autotimm.core.backbone import FeatureBackboneConfig

            backbone = FeatureBackboneConfig(
                model_name=backbone.model_name,
                pretrained=getattr(backbone, "pretrained", True),
                out_indices=(2, 3, 4),
                drop_rate=getattr(backbone, "drop_rate", 0.0),
                drop_path_rate=getattr(backbone, "drop_path_rate", 0.0),
                extra_kwargs=getattr(backbone, "extra_kwargs", {}),
            )

    self.backbone = create_feature_backbone(backbone)
    in_channels = get_feature_channels(self.backbone)

    self.fpn = FPN(
        in_channels_list=in_channels,
        out_channels=fpn_channels,
        num_extra_levels=2,  # P6, P7 from P5
    )

    # Create detection head based on architecture
    if detection_arch == "fcos":
        self.head = DetectionHead(
            in_channels=fpn_channels,
            num_classes=num_classes,
            num_convs=head_num_convs,
            prior_prob=0.01,
        )
    elif detection_arch == "yolox":
        self.head = YOLOXHead(
            in_channels=fpn_channels,
            num_classes=num_classes,
            num_convs=min(head_num_convs, 2),  # YOLOX typically uses 2 convs
            prior_prob=0.01,
            activation="silu",
        )

    # Losses
    self.cls_loss_weight = cls_loss_weight
    self.reg_loss_weight = reg_loss_weight
    self.centerness_loss_weight = centerness_loss_weight

    # Setup classification loss
    if cls_loss_fn is not None:
        if isinstance(cls_loss_fn, str):
            registry = get_loss_registry()
            loss_kwargs = {"reduction": "sum"}
            if cls_loss_fn == "focal":
                loss_kwargs["alpha"] = focal_alpha
                loss_kwargs["gamma"] = focal_gamma
            self.focal_loss = registry.get_loss(cls_loss_fn, **loss_kwargs)
        elif isinstance(cls_loss_fn, nn.Module):
            self.focal_loss = cls_loss_fn
        else:
            raise TypeError(
                f"cls_loss_fn must be a string or nn.Module instance, got {type(cls_loss_fn)}"
            )
    else:
        # Default: FocalLoss
        self.focal_loss = FocalLoss(
            alpha=focal_alpha, gamma=focal_gamma, reduction="sum"
        )

    # Setup regression loss
    if reg_loss_fn is not None:
        if isinstance(reg_loss_fn, str):
            registry = get_loss_registry()
            self.giou_loss = registry.get_loss(reg_loss_fn, reduction="sum")
        elif isinstance(reg_loss_fn, nn.Module):
            self.giou_loss = reg_loss_fn
        else:
            raise TypeError(
                f"reg_loss_fn must be a string or nn.Module instance, got {type(reg_loss_fn)}"
            )
    else:
        # Default: GIoULoss
        self.giou_loss = GIoULoss(reduction="sum")

    # Initialize metrics
    if metrics is None:
        # Default: use mAP for validation and test
        metrics = [
            MetricConfig(
                name="mAP",
                backend="torchmetrics",
                metric_class="MeanAveragePrecision",
                params={"box_format": "xyxy", "iou_type": "bbox"},
                stages=["val", "test"],
                prog_bar=True,
            ),
        ]

    if isinstance(metrics, list):
        # Detection metrics typically don't take num_classes in their
        # constructor; MetricManager handles that case appropriately
        self._metric_configs = metrics
        self._use_metric_manager = False
    else:
        self._metric_manager = metrics
        self._use_metric_manager = True

    # Register metrics as ModuleDicts
    self._register_detection_metrics()

    # Logging configuration
    self._logging_config = logging_config or LoggingConfig(
        log_learning_rate=False,
        log_gradient_norm=False,
    )

    if freeze_backbone:
        for param in self.backbone.parameters():
            param.requires_grad = False

    # Apply torch.compile for optimization (PyTorch 2.0+)
    # Skip on MPS — the inductor backend generates invalid Metal shaders.
    # Skip on Windows — the inductor backend requires cl.exe (MSVC) which
    # is typically unavailable; compilation is deferred so the try/except
    # at init time cannot catch the error.
    import sys as _sys
    _mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    _skip_compile = _mps_available or _sys.platform == "win32"
    if compile_model and not _skip_compile:
        try:
            compile_opts = compile_kwargs or {}
            self.backbone = torch.compile(self.backbone, **compile_opts)
            self.fpn = torch.compile(self.fpn, **compile_opts)
            self.head = torch.compile(self.head, **compile_opts)
        except Exception as e:
            import warnings

            warnings.warn(
                f"torch.compile failed: {e}. Continuing without compilation. "
                f"Ensure you have PyTorch 2.0+ for compile support.",
                stacklevel=2,
            )
    elif compile_kwargs is not None:
        import warnings

        warnings.warn(
            "compile_kwargs is ignored when compile_model=False.",
            stacklevel=2,
        )

    # Setup transforms from config (PreprocessingMixin)
    self._setup_transforms(transform_config, task="detection")
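The default `regress_ranges` above partition ground-truth boxes across FPN levels by their largest LTRB distance, using the same `min_range <= max_reg < max_range` test as the target assignment. A torch-free sketch (`level_for` is a hypothetical helper, not part of the API):

```python
# Default FCOS regression ranges for P3-P7, matching the tuple above.
RANGES = [(-1, 64), (64, 128), (128, 256), (256, 512), (512, float("inf"))]
STRIDES = (8, 16, 32, 64, 128)

def level_for(max_reg: float) -> int:
    """Return the FPN level whose range contains max_reg."""
    for i, (lo, hi) in enumerate(RANGES):
        if lo <= max_reg < hi:
            return i
    raise ValueError(f"no level for {max_reg}")

print(level_for(40))   # -> 0 (P3, stride 8)
print(level_for(300))  # -> 3 (P6, stride 64)
```

Small boxes are thus regressed on fine, high-resolution levels and large boxes on coarse ones, which keeps the regression targets of each level in a bounded range.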

forward

forward(images: Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor] | None]

Forward pass through the detector.

Parameters:

Name Type Description Default
images Tensor

Input images [B, C, H, W].

required

Returns:

Type Description
tuple[list[Tensor], list[Tensor], list[Tensor] | None]

Tuple of (cls_outputs, reg_outputs, centerness_outputs) per FPN level.
For YOLOX, centerness_outputs is None.

Source code in src/autotimm/tasks/object_detection.py
def forward(
    self, images: torch.Tensor
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor] | None]:
    """Forward pass through the detector.

    Args:
        images: Input images [B, C, H, W].

    Returns:
        Tuple of (cls_outputs, reg_outputs, centerness_outputs) per FPN level.
        For YOLOX, centerness_outputs is None.
    """
    features = self.backbone(images)
    fpn_features = self.fpn(features)

    if self.detection_arch == "fcos":
        cls_outputs, reg_outputs, centerness_outputs = self.head(fpn_features)
    elif self.detection_arch == "yolox":
        cls_outputs, reg_outputs = self.head(fpn_features)
        centerness_outputs = None
    else:
        raise ValueError(f"Unknown detection_arch: {self.detection_arch}")

    return cls_outputs, reg_outputs, centerness_outputs
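
The spatial size of each per-level output follows directly from the FPN strides: each level's feature map is `input_size // stride` on a side. A quick sanity check, assuming the default strides and a 256×256 input (values chosen for illustration):

```python
# Per-level output sizes implied by the default FPN strides.
strides = (8, 16, 32, 64, 128)  # P3..P7
input_size = 256

level_sizes = [(input_size // s, input_size // s) for s in strides]
print(level_sizes)  # [(32, 32), (16, 16), (8, 8), (4, 4), (2, 2)]
```

Each entry in `cls_outputs` / `reg_outputs` has one of these spatial shapes, with `num_classes` and 4 channels respectively.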

training_step

training_step(batch: dict[str, Any], batch_idx: int) -> torch.Tensor
Source code in src/autotimm/tasks/object_detection.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    images = batch["images"]
    target_boxes = batch["boxes"]  # List of [N_i, 4] tensors
    target_labels = batch["labels"]  # List of [N_i] tensors

    # Forward
    cls_outputs, reg_outputs, centerness_outputs = self(images)

    # Compute targets for each FPN level
    device = images.device
    batch_size = images.shape[0]
    img_h, img_w = images.shape[-2:]

    # Compute loss
    total_cls_loss = torch.tensor(0.0, device=device)
    total_reg_loss = torch.tensor(0.0, device=device)
    total_centerness_loss = torch.tensor(0.0, device=device)
    num_pos = 0

    # Prepare centerness outputs for iteration
    if centerness_outputs is None:
        # YOLOX doesn't use centerness
        centerness_iter = [None] * len(cls_outputs)
    else:
        centerness_iter = centerness_outputs

    for level_idx, (cls_out, reg_out, cent_out) in enumerate(
        zip(cls_outputs, reg_outputs, centerness_iter)
    ):
        stride = self.strides[level_idx]
        feat_h, feat_w = cls_out.shape[-2:]

        # Generate grid points for this level
        grid_y, grid_x = torch.meshgrid(
            torch.arange(feat_h, device=device, dtype=torch.float32),
            torch.arange(feat_w, device=device, dtype=torch.float32),
            indexing="ij",
        )
        # Points are at the center of each cell
        points_x = (grid_x + 0.5) * stride
        points_y = (grid_y + 0.5) * stride
        points = torch.stack([points_x, points_y], dim=-1)  # [H, W, 2]

        # Compute targets for this level
        level_cls_targets = []
        level_reg_targets = []
        level_centerness_targets = [] if self.detection_arch == "fcos" else None

        for b in range(batch_size):
            boxes = target_boxes[b]  # [N, 4] in xyxy format
            labels = target_labels[b]  # [N]

            if len(boxes) == 0:
                # No objects - all background
                cls_target = torch.full(
                    (feat_h, feat_w), -1, dtype=torch.long, device=device
                )
                reg_target = torch.zeros(feat_h, feat_w, 4, device=device)
                cent_target = (
                    torch.zeros(feat_h, feat_w, device=device)
                    if self.detection_arch == "fcos"
                    else None
                )
            else:
                cls_target, reg_target, cent_target = (
                    self._compute_targets_per_level(
                        points, boxes, labels, stride, level_idx, (img_h, img_w)
                    )
                )

            level_cls_targets.append(cls_target)
            level_reg_targets.append(reg_target)
            if self.detection_arch == "fcos":
                level_centerness_targets.append(cent_target)

        # Stack batch
        cls_targets = torch.stack(level_cls_targets)  # [B, H, W]
        reg_targets = torch.stack(level_reg_targets)  # [B, H, W, 4]
        cent_targets = (
            torch.stack(level_centerness_targets)
            if self.detection_arch == "fcos"
            else None
        )  # [B, H, W]

        # Compute classification loss (all locations except ignored=-1)
        cls_out_flat = cls_out.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        cls_targets_flat = cls_targets.reshape(-1)
        valid_mask = cls_targets_flat >= 0

        if valid_mask.any():
            # Positive locations carry a class index >= 0; background
            # locations are marked -1 and excluded via valid_mask above.
            total_cls_loss = total_cls_loss + self.focal_loss(
                cls_out_flat, cls_targets_flat
            )

        # Compute regression and centerness loss (positive samples only)
        pos_mask = cls_targets >= 0  # [B, H, W]

        if pos_mask.any():
            pos_reg_pred = reg_out.permute(0, 2, 3, 1)[pos_mask]  # [N_pos, 4]
            pos_reg_target = reg_targets[pos_mask]  # [N_pos, 4]

            # IoU-based regression loss
            reg_loss = self._compute_iou_loss(pos_reg_pred, pos_reg_target)
            total_reg_loss = total_reg_loss + reg_loss

            # Centerness BCE loss (FCOS only)
            if self.detection_arch == "fcos" and cent_out is not None:
                pos_cent_pred = cent_out.squeeze(1)[pos_mask]  # [N_pos]
                pos_cent_target = cent_targets[pos_mask]  # [N_pos]
                cent_loss = F.binary_cross_entropy_with_logits(
                    pos_cent_pred, pos_cent_target, reduction="sum"
                )
                total_centerness_loss = total_centerness_loss + cent_loss

            num_pos += pos_mask.sum().item()

    # Normalize by number of positive samples
    num_pos = max(num_pos, 1)

    cls_loss = self.cls_loss_weight * total_cls_loss / num_pos
    reg_loss = self.reg_loss_weight * total_reg_loss / num_pos

    # Centerness loss only for FCOS
    if self.detection_arch == "fcos":
        centerness_loss = (
            self.centerness_loss_weight * total_centerness_loss / num_pos
        )
        total_loss = cls_loss + reg_loss + centerness_loss
    else:
        centerness_loss = torch.tensor(0.0, device=device)
        total_loss = cls_loss + reg_loss

    # Log losses
    self.log("train/loss", total_loss, prog_bar=True)
    self.log("train/cls_loss", cls_loss)
    self.log("train/reg_loss", reg_loss)
    if self.detection_arch == "fcos":
        self.log("train/centerness_loss", centerness_loss)
    self.log("train/num_pos", float(num_pos))

    # Enhanced logging
    if self._logging_config.log_learning_rate:
        opt = self.optimizers()
        if opt is not None and hasattr(opt, "param_groups"):
            lr = opt.param_groups[0]["lr"]
            self.log("train/lr", lr, on_step=True, on_epoch=False)

    return total_loss
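
The centerness targets produced by `_compute_targets_per_level` follow the standard FCOS definition: the geometric mean of the left/right and top/bottom balance of a location's regression targets (l, t, r, b). A minimal scalar sketch of that formula, not the library's exact helper:

```python
import math

def centerness_target(l: float, t: float, r: float, b: float) -> float:
    """FCOS centerness: 1.0 at the box center, decaying toward the edges."""
    return math.sqrt(
        (min(l, r) / max(l, r)) * (min(t, b) / max(t, b))
    )

print(centerness_target(10, 10, 10, 10))  # exact center -> 1.0
print(centerness_target(2, 10, 10, 10))   # off-center -> < 1.0
```

This is why the BCE loss above is applied only at positive locations: background locations have no meaningful (l, t, r, b) targets.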

validation_step

validation_step(batch: dict[str, Any], batch_idx: int) -> None
Source code in src/autotimm/tasks/object_detection.py
def validation_step(self, batch: dict[str, Any], batch_idx: int) -> None:
    images = batch["images"]
    target_boxes = batch["boxes"]
    target_labels = batch["labels"]

    # Get predictions
    detections = self.predict(images)

    # Convert to format expected by torchmetrics.detection.MeanAveragePrecision
    preds = []
    targets = []

    for i in range(len(detections)):
        preds.append(
            {
                "boxes": detections[i]["boxes"],
                "scores": detections[i]["scores"],
                "labels": detections[i]["labels"],
            }
        )
        targets.append(
            {
                "boxes": target_boxes[i],
                "labels": target_labels[i],
            }
        )

    # Update metrics
    for name, metric in self.val_metrics.items():
        metric.update(preds, targets)

test_step

test_step(batch: dict[str, Any], batch_idx: int) -> None
Source code in src/autotimm/tasks/object_detection.py
def test_step(self, batch: dict[str, Any], batch_idx: int) -> None:
    images = batch["images"]
    target_boxes = batch["boxes"]
    target_labels = batch["labels"]

    detections = self.predict(images)

    preds = []
    targets = []

    for i in range(len(detections)):
        preds.append(
            {
                "boxes": detections[i]["boxes"],
                "scores": detections[i]["scores"],
                "labels": detections[i]["labels"],
            }
        )
        targets.append(
            {
                "boxes": target_boxes[i],
                "labels": target_labels[i],
            }
        )

    for name, metric in self.test_metrics.items():
        metric.update(preds, targets)

predict_step

predict_step(batch: Any, batch_idx: int) -> list[dict[str, torch.Tensor]]
Source code in src/autotimm/tasks/object_detection.py
def predict_step(self, batch: Any, batch_idx: int) -> list[dict[str, torch.Tensor]]:
    images = batch["images"] if isinstance(batch, dict) else batch
    return self.predict(images)

configure_optimizers

configure_optimizers() -> dict

Configure optimizer and learning rate scheduler.

Source code in src/autotimm/tasks/object_detection.py
def configure_optimizers(self) -> dict:
    """Configure optimizer and learning rate scheduler."""
    params = filter(lambda p: p.requires_grad, self.parameters())
    optimizer = self._create_optimizer(params)

    if self._scheduler is None or self._scheduler == "none":
        return {"optimizer": optimizer}

    scheduler_config = self._create_scheduler(optimizer)
    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler_config,
    }

to_onnx

to_onnx(save_path: str | None = None, example_input: Tensor | None = None, opset_version: int = 17, dynamic_axes: dict[str, dict[int, str]] | None = None, **kwargs: Any) -> str

Export model to ONNX format.

Detection models flatten their list outputs into named tensors for ONNX compatibility (e.g., cls_l0..cls_l4, reg_l0..reg_l4, ctr_l0..ctr_l4).

Parameters:

Name Type Description Default
save_path str | None

Path to save the ONNX model. If None, uses a temp file.

None
example_input Tensor | None

Example input tensor. If None, uses default shape (1, 3, 224, 224).

None
opset_version int

ONNX opset version. Default is 17.

17
dynamic_axes dict[str, dict[int, str]] | None

Dynamic axes specification. If None, batch dimension is dynamic.

None
**kwargs Any

Additional arguments passed to export_to_onnx.

{}

Returns:

Type Description
str

Path to the saved ONNX model.

Example

model = ObjectDetector(backbone="resnet50", num_classes=80)
path = model.to_onnx("detector.onnx")

Source code in src/autotimm/tasks/object_detection.py
def to_onnx(
    self,
    save_path: str | None = None,
    example_input: torch.Tensor | None = None,
    opset_version: int = 17,
    dynamic_axes: dict[str, dict[int, str]] | None = None,
    **kwargs: Any,
) -> str:
    """Export model to ONNX format.

    Detection models flatten their list outputs into named tensors for ONNX
    compatibility (e.g., cls_l0..cls_l4, reg_l0..reg_l4, ctr_l0..ctr_l4).

    Args:
        save_path: Path to save the ONNX model. If None, uses a temp file.
        example_input: Example input tensor. If None, uses default shape (1, 3, 224, 224).
        opset_version: ONNX opset version. Default is 17.
        dynamic_axes: Dynamic axes specification. If None, batch dimension is dynamic.
        **kwargs: Additional arguments passed to export_to_onnx.

    Returns:
        Path to the saved ONNX model.

    Example:
        >>> model = ObjectDetector(backbone="resnet50", num_classes=80)
        >>> path = model.to_onnx("detector.onnx")
    """
    from autotimm.export import export_to_onnx

    if example_input is None:
        example_input = torch.randn(1, 3, 224, 224)

    if save_path is None:
        import tempfile

        save_path = tempfile.mktemp(suffix=".onnx")

    return export_to_onnx(
        self,
        save_path,
        example_input,
        opset_version=opset_version,
        dynamic_axes=dynamic_axes,
        **kwargs,
    )

Usage Examples

Basic Usage

from autotimm import ObjectDetector, MetricConfig

metrics = [
    MetricConfig(
        name="mAP",
        backend="torchmetrics",
        metric_class="MeanAveragePrecision",
        params={"box_format": "xyxy", "iou_type": "bbox"},
        stages=["val", "test"],
        prog_bar=True,
    ),
]

model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    lr=1e-4,
)

With FeatureBackboneConfig

from autotimm import FeatureBackboneConfig, ObjectDetector

cfg = FeatureBackboneConfig(
    model_name="resnet50",
    pretrained=True,
    out_indices=(2, 3, 4),  # C3, C4, C5
)

model = ObjectDetector(
    backbone=cfg,
    num_classes=80,
    metrics=metrics,
)

With Transformer Backbone

model = ObjectDetector(
    backbone="swin_tiny_patch4_window7_224",
    num_classes=80,
    metrics=metrics,
    lr=1e-5,  # Lower LR for transformers
    fpn_channels=256,
)

Custom FPN and Head

model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    fpn_channels=256,      # FPN channels (128, 256, or 512)
    head_num_convs=4,      # Number of conv layers in head
)

Custom Loss Configuration

model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    focal_alpha=0.25,
    focal_gamma=2.0,
    cls_loss_weight=1.0,
    reg_loss_weight=1.0,
    centerness_loss_weight=1.0,
)
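
The effect of `focal_alpha` and `focal_gamma` is easiest to see in the scalar form of the loss for a positive sample: well-classified points (high p) are down-weighted by `(1 - p) ** gamma`, so training focuses on hard examples. A rough sketch, not the library's `FocalLoss` implementation:

```python
import math

def focal_loss_pos(p: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for a positive sample with predicted probability p."""
    return -alpha * (1.0 - p) ** gamma * math.log(p)

# An easy example (p=0.9) contributes far less than a hard one (p=0.1)
print(focal_loss_pos(0.9))
print(focal_loss_pos(0.1))
```

With `gamma=0` the modulating factor disappears and the loss reduces to alpha-weighted cross-entropy.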

Custom Inference Settings

model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    score_thresh=0.05,              # Confidence threshold
    nms_thresh=0.5,                 # NMS IoU threshold
    max_detections_per_image=100,   # Max detections to keep
)
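
How these three settings interact: candidates below `score_thresh` are dropped first, NMS then suppresses any box whose IoU with a higher-scoring kept box exceeds `nms_thresh`, and finally at most `max_detections_per_image` survivors are returned. A pure-Python sketch of that pipeline (illustrative only, not the library's implementation):

```python
def iou(a, b):
    """IoU of two xyxy boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

def postprocess(boxes, scores, score_thresh=0.05, nms_thresh=0.5, max_det=100):
    # 1) confidence filter, 2) greedy NMS, 3) cap the number of detections
    cands = sorted(
        (bs for bs in zip(boxes, scores) if bs[1] >= score_thresh),
        key=lambda bs: bs[1], reverse=True,
    )
    kept = []
    for box, score in cands:
        if all(iou(box, k[0]) <= nms_thresh for k in kept):
            kept.append((box, score))
    return kept[:max_det]

boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.03]
kept = postprocess(boxes, scores)
print(kept)  # only the highest-scoring box survives
```

The near-duplicate box is suppressed by NMS (IoU ≈ 0.68 > 0.5), and the low-confidence box never passes the score filter.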

With TransformConfig (Preprocessing)

Enable inference-time preprocessing with model-specific normalization:

from autotimm import ObjectDetector, TransformConfig

model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    transform_config=TransformConfig(),  # Enable preprocess()
)

# Now you can preprocess raw images
from PIL import Image
image = Image.open("test.jpg")
tensor = model.preprocess(image)  # Returns preprocessed tensor
output = model(tensor)

Get Model's Data Config

model = ObjectDetector(
    backbone="swin_tiny_patch4_window7_224",
    num_classes=80,
    metrics=metrics,
    transform_config=TransformConfig(),
)

# Get normalization config
config = model.get_data_config()
print(f"Mean: {config['mean']}")
print(f"Std: {config['std']}")
print(f"Input size: {config['input_size']}")

Frozen Backbone

model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    freeze_backbone=True,  # Only train FPN and head
    lr=1e-3,               # Higher LR when backbone frozen
)

With MultiStep Scheduler

model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    lr=1e-4,
    scheduler="multistep",
    scheduler_kwargs={"milestones": [8, 11], "gamma": 0.1},
)

Parameters

Parameter Type Default Description
backbone str \| FeatureBackboneConfig Required Model name or config
num_classes int Required Number of object classes
detection_arch str "fcos" Detection architecture ("fcos" or "yolox")
cls_loss_fn str \| nn.Module \| None None Classification loss function (string from registry, nn.Module, or None for default)
reg_loss_fn str \| nn.Module \| None None Regression loss function (string from registry, nn.Module, or None for default)
metrics MetricManager \| list[MetricConfig] None Metrics configuration
logging_config LoggingConfig \| None None Enhanced logging options
transform_config TransformConfig \| None None Transform config for preprocessing
lr float 1e-4 Learning rate
weight_decay float 1e-4 Weight decay
optimizer str \| dict "adamw" Optimizer name or config
optimizer_kwargs dict \| None None Extra optimizer kwargs
scheduler str \| dict \| None "cosine" Scheduler name or config
scheduler_kwargs dict \| None None Extra scheduler kwargs
fpn_channels int 256 Number of FPN channels
head_num_convs int 4 Conv layers in detection head
focal_alpha float 0.25 Focal loss alpha
focal_gamma float 2.0 Focal loss gamma
cls_loss_weight float 1.0 Classification loss weight
reg_loss_weight float 1.0 Regression loss weight
centerness_loss_weight float 1.0 Centerness loss weight
score_thresh float 0.05 Score threshold for detections
nms_thresh float 0.5 NMS IoU threshold
max_detections_per_image int 100 Max detections per image
freeze_backbone bool False Freeze backbone weights
strides tuple[int, ...] (8, 16, 32, 64, 128) FPN strides
regress_ranges tuple \| None None Custom regression ranges (FCOS only)
compile_model bool True Apply torch.compile() for faster training/inference (PyTorch 2.0+)
compile_kwargs dict \| None None Kwargs for torch.compile() (e.g., mode, fullgraph, dynamic)
seed int \| None None Random seed for reproducibility (None to disable seeding)
deterministic bool True Enable deterministic algorithms (may impact performance)

Model Architecture

ObjectDetector
├── backbone (timm feature extractor)
│   └── Multi-scale features: C3, C4, C5
├── fpn (Feature Pyramid Network)
│   └── Pyramid levels: P3, P4, P5, P6, P7
├── detection_head (DetectionHead)
│   ├── cls_subnet → classification logits
│   ├── bbox_subnet → bbox offsets (l, t, r, b)
│   └── centerness_subnet → centerness scores
└── loss_fn (FCOSLoss)
    ├── FocalLoss (classification)
    ├── GIoULoss (bbox regression)
    └── CenternessLoss (center-ness)

FCOS Architecture

Feature Pyramid Network (FPN):

  • Takes C3, C4, C5 features from backbone
  • Builds pyramid levels P3-P7 via top-down and lateral connections
  • Each pyramid level detects objects at different scales

Regression Ranges: Objects are assigned to FPN levels based on their size:

Level Stride Default Range Object Size
P3 8 (-1, 64) Very small
P4 16 (64, 128) Small
P5 32 (128, 256) Medium
P6 64 (256, 512) Large
P7 128 (512, ∞) Very large
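
In practice the assignment is made per location: a location is treated as positive for a level only when its largest regression target, max(l, t, r, b), falls inside that level's range. A pure-Python sketch of the rule, using the default ranges above:

```python
# Default FCOS regression ranges, mirroring the table above.
DEFAULT_RANGES = [(-1, 64), (64, 128), (128, 256), (256, 512), (512, float("inf"))]
LEVELS = ["P3", "P4", "P5", "P6", "P7"]

def assign_level(max_reg_target: float) -> str:
    """Pick the FPN level whose regression range contains the target."""
    for level, (lo, hi) in zip(LEVELS, DEFAULT_RANGES):
        if lo < max_reg_target <= hi:
            return level
    raise ValueError("target outside all ranges")

print(assign_level(40))    # small object -> P3
print(assign_level(300))   # large object -> P6
```

This is the mechanism the `regress_ranges` parameter overrides, e.g. to push more objects onto the high-resolution levels for small-object detection.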

Detection Head:

  • Shared across all FPN levels
  • 3 branches: classification, bbox regression, centerness
  • Each branch has 4 conv layers (configurable via head_num_convs)

Loss Functions:

  • Focal Loss: Handles class imbalance in one-stage detectors
  • GIoU Loss: IoU-based metric for bbox regression
  • Centerness Loss: Suppresses low-quality detections far from object centers
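
GIoU extends plain IoU with a penalty based on the smallest enclosing box, so it still provides a useful gradient when predicted and target boxes do not overlap at all. A minimal scalar sketch (not the library's `GIoULoss`):

```python
def giou(a, b):
    """Generalized IoU of two xyxy boxes, in [-1, 1]; the loss is 1 - giou."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    # Smallest box enclosing both
    cx1, cy1 = min(a[0], b[0]), min(a[1], b[1])
    cx2, cy2 = max(a[2], b[2]), max(a[3], b[3])
    c_area = (cx2 - cx1) * (cy2 - cy1)
    return inter / union - (c_area - union) / c_area

print(giou((0, 0, 10, 10), (0, 0, 10, 10)))   # identical -> 1.0
print(giou((0, 0, 10, 10), (20, 0, 30, 10)))  # disjoint -> negative
```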

Backbone Selection

CNN Backbones

Backbone Speed Accuracy Use Case
resnet18 Fast Good Quick experiments
resnet50 Medium Better Standard baseline
efficientnet_b3 Medium Better Efficiency
convnext_tiny Medium Best Modern CNN
resnet101 Slow Best High accuracy

Transformer Backbones

Backbone Speed Memory Use Case
swin_tiny_patch4_window7_224 Fast Medium Balanced
swin_small_patch4_window7_224 Medium Medium Production
swin_base_patch4_window7_224 Slow High Maximum accuracy
vit_base_patch16_224 Slow High Research

Notes:

  • Swin Transformers work best for detection (hierarchical features)
  • Use smaller batch sizes (8-16) with transformers
  • Use lower learning rates (1e-5) with transformer backbones

Logged Metrics

Metric Stage Condition
{stage}/loss train, val, test Always
{stage}/cls_loss train, val, test Always
{stage}/reg_loss train, val, test Always
{stage}/centerness_loss train, val, test Always
{stage}/{metric_name} As configured Per MetricConfig
train/lr train log_learning_rate=True
train/grad_norm train log_gradient_norm=True

Training Tips

Standard COCO Training

model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    lr=1e-4,
    scheduler="multistep",
    scheduler_kwargs={"milestones": [8, 11], "gamma": 0.1},
)

trainer = AutoTrainer(
    max_epochs=12,
    gradient_clip_val=1.0,
)
trainer.fit(model, datamodule=data)

Progressive Fine-Tuning

# Phase 1: Train FPN and head only
model = ObjectDetector(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    freeze_backbone=True,
    lr=1e-3,
)
trainer = AutoTrainer(max_epochs=3)
trainer.fit(model, datamodule=data)

# Phase 2: Fine-tune entire model
for param in model.backbone.parameters():
    param.requires_grad = True
model._lr = 1e-4
trainer = AutoTrainer(max_epochs=12, gradient_clip_val=1.0)
trainer.fit(model, datamodule=data)

Small Object Detection

For better small object detection:

model = ObjectDetector(
    backbone="resnet50",
    num_classes=num_classes,
    metrics=metrics,
    fpn_channels=256,  # More capacity
    head_num_convs=4,  # Deeper head
    # Adjust regression ranges to emphasize smaller levels
    regress_ranges=(
        (-1, 32),      # P3: extra small
        (32, 64),      # P4: very small
        (64, 128),     # P5: small
        (128, 256),    # P6: medium
        (256, float("inf")),  # P7: large
    ),
)