
ImageClassifier

End-to-end image classifier backed by a timm backbone.

Overview

ImageClassifier is a PyTorch Lightning module that combines:

  • A timm backbone (feature extractor)
  • A classification head
  • Configurable metrics
  • Optimizer and scheduler setup
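Conceptually, the forward path is just head(backbone(x)) followed by an argmax over the logits. A torch-free toy sketch of that composition (every name here is an illustrative stand-in, not the autotimm API):

```python
# Toy stand-ins for the real components -- illustrative only.
def backbone(x):
    """Feature extractor: maps an input to a feature vector."""
    return [xi * 0.5 for xi in x]

def head(features):
    """Linear classification head: one logit per class."""
    weights = [0.1, 0.9, 0.4]  # one toy weight per class
    s = sum(features)
    return [s * w for w in weights]

def classify(x):
    logits = head(backbone(x))
    # Multi-class prediction is an argmax over the logits.
    return max(range(len(logits)), key=lambda c: logits[c])
```

In the real module, backbone is a timm model, head is a dropout + linear layer, and the loss/metrics/optimizer wiring around them is configured via the parameters below.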

API Reference

autotimm.ImageClassifier

Bases: PreprocessingMixin, LightningModule

End-to-end image classifier backed by a timm backbone.

Parameters:

backbone (str | BackboneConfig, required)
    A timm model name (str) or a BackboneConfig.

num_classes (int, required)
    Number of target classes (or labels for multi-label).

multi_label (bool, default False)
    If True, use BCEWithLogitsLoss and sigmoid-based predictions for
    multi-label classification.

threshold (float, default 0.5)
    Prediction threshold for multi-label mode. Sigmoid outputs above this
    value are predicted as positive.

loss_fn (str | Module | None, default None)
    Loss function to use. Can be a string from the loss registry (e.g.
    'cross_entropy', 'bce', 'mse'), an nn.Module instance (custom loss), or
    None, which selects the default: CrossEntropyLoss for multi-class,
    BCEWithLogitsLoss for multi-label.

metrics (MetricManager | list[MetricConfig] | None, default None)
    A MetricManager instance or a list of MetricConfig objects. If None, no
    metrics are computed during training, which is useful for inference-only
    scenarios.

logging_config (LoggingConfig | None, default None)
    Optional LoggingConfig for enhanced logging.

transform_config (TransformConfig | None, default None)
    Optional TransformConfig for unified transform configuration. When
    provided, enables the preprocess() method for inference-time
    preprocessing using model-specific normalization.

lr (float, default 0.001)
    Learning rate.

weight_decay (float, default 0.0001)
    Weight decay for the optimizer.

optimizer (str | dict[str, Any], default 'adamw')
    Optimizer name ("adamw", "adam", "sgd", etc.) or a dict with "class"
    (fully qualified class path) and "params" keys. Supports both torch.optim
    and timm optimizers.

optimizer_kwargs (dict[str, Any] | None, default None)
    Additional kwargs for the optimizer (merged with lr and weight_decay).

scheduler (str | dict[str, Any] | None, default 'cosine')
    Scheduler name ("cosine", "step", "onecycle", "cosine_with_restarts",
    etc.), a dict with "class" and "params" keys, or None for no scheduler.
    Supports both torch.optim.lr_scheduler and timm schedulers.

scheduler_kwargs (dict[str, Any] | None, default None)
    Extra kwargs forwarded to the LR scheduler.

head_dropout (float, default 0.0)
    Dropout before the classification linear layer.

label_smoothing (float, default 0.0)
    Label smoothing factor for cross-entropy. Not supported with
    multi_label=True (raises ValueError).

freeze_backbone (bool, default False)
    If True, backbone parameters are frozen (useful for linear probing).

mixup_alpha (float, default 0.0)
    If > 0, apply Mixup augmentation with this alpha.

compile_model (bool, default True)
    If True, apply torch.compile() to the backbone and head for faster
    inference and training. Requires PyTorch 2.0+.

compile_kwargs (dict[str, Any] | None, default None)
    Optional dict of kwargs to pass to torch.compile(). Common options:
    mode ("default", "reduce-overhead", "max-autotune"), fullgraph
    (True/False), dynamic (True/False).

seed (int | None, default None)
    Random seed for reproducibility. If None, no seeding is performed.

deterministic (bool, default True)
    If True, enables deterministic algorithms in PyTorch for full
    reproducibility (may impact performance). Set to False for faster
    training.
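The multi_label and threshold parameters describe a sigmoid-then-threshold decision rule: each label gets an independent probability, and a label is predicted positive when that probability exceeds the threshold. A minimal pure-Python sketch of the rule (independent of autotimm):

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def multi_label_predict(logits, threshold=0.5):
    # Each logit is squashed independently; a label is positive when its
    # sigmoid probability exceeds the threshold.
    return [1 if sigmoid(z) > threshold else 0 for z in logits]

preds = multi_label_predict([2.0, -1.0, 0.3], threshold=0.5)
```

Raising the threshold trades recall for precision: fewer labels clear the bar, so positives become more confident but rarer.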
Example

For training with metrics

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=[
        MetricConfig(
            name="accuracy",
            backend="torchmetrics",
            metric_class="Accuracy",
            params={"task": "multiclass"},
            stages=["train", "val", "test"],
            prog_bar=True,
        ),
    ],
    logging_config=LoggingConfig(
        log_learning_rate=True,
        log_gradient_norm=False,
    ),
    transform_config=TransformConfig(use_timm_config=True),
)

For inference only (no metrics needed)

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    transform_config=TransformConfig(use_timm_config=True),
)
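The mixup_alpha parameter enables Mixup during training: a pair of inputs is blended with a coefficient lam drawn from Beta(alpha, alpha), and the loss is blended with the same coefficient. A pure-Python sketch of the arithmetic (the function names are illustrative, not the autotimm API):

```python
def mixup(x_a, x_b, lam):
    """Convex combination of two inputs, elementwise."""
    return [lam * a + (1 - lam) * b for a, b in zip(x_a, x_b)]

def mixup_loss(loss_fn, logits, y_a, y_b, lam):
    """The loss is mixed with the same coefficient as the inputs."""
    return lam * loss_fn(logits, y_a) + (1 - lam) * loss_fn(logits, y_b)
```

With lam = 0.7, the mixed input is 70% of the first sample and 30% of the second, and the loss weights the two targets in the same proportion.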

With transform_config, you can preprocess raw images

from PIL import Image

img = Image.open("test.jpg")
tensor = model.preprocess(img)  # Returns (1, 3, 224, 224)
output = model(tensor)
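For multi-class models, prediction probabilities come from a softmax over the output logits (this is what predict_step does when multi_label is False). The computation itself, as a minimal numerically stable sketch:

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])  # probabilities sum to 1
```

The largest logit always maps to the largest probability, so argmax over probabilities equals argmax over raw logits.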

Source code in src/autotimm/tasks/classification.py
class ImageClassifier(PreprocessingMixin, pl.LightningModule):
    """End-to-end image classifier backed by a timm backbone.

    Parameters:
        backbone: A timm model name (str) or a :class:`BackboneConfig`.
        num_classes: Number of target classes (or labels for multi-label).
        multi_label: If ``True``, use ``BCEWithLogitsLoss`` and sigmoid-based
            predictions for multi-label classification. Default is ``False``.
        threshold: Prediction threshold for multi-label mode. Sigmoid outputs
            above this value are predicted as positive. Default is ``0.5``.
        loss_fn: Loss function to use. Can be:
            - A string from the loss registry (e.g., 'cross_entropy', 'bce', 'mse')
            - An instance of nn.Module (custom loss)
            - None (uses default: CrossEntropyLoss for multi-class, BCEWithLogitsLoss for multi-label)
        metrics: A :class:`MetricManager` instance or list of :class:`MetricConfig`
            objects. Optional - if ``None``, no metrics will be computed during training.
            This is useful for inference-only scenarios.
        logging_config: Optional :class:`LoggingConfig` for enhanced logging.
        transform_config: Optional :class:`TransformConfig` for unified transform
            configuration. When provided, enables the ``preprocess()`` method
            for inference-time preprocessing using model-specific normalization.
        lr: Learning rate.
        weight_decay: Weight decay for optimizer.
        optimizer: Optimizer name (``"adamw"``, ``"adam"``, ``"sgd"``, etc.) or dict
            with ``"class"`` (fully qualified class path) and ``"params"`` keys.
            Supports both torch.optim and timm optimizers.
        optimizer_kwargs: Additional kwargs for the optimizer (merged with lr and weight_decay).
        scheduler: Scheduler name (``"cosine"``, ``"step"``, ``"onecycle"``, ``"cosine_with_restarts"``, etc.),
            dict with ``"class"`` and ``"params"`` keys, or ``None`` for no scheduler.
            Supports both torch.optim.lr_scheduler and timm schedulers.
        scheduler_kwargs: Extra kwargs forwarded to the LR scheduler.
        head_dropout: Dropout before the classification linear layer.
        label_smoothing: Label smoothing factor for cross-entropy.
            Not supported with ``multi_label=True`` (raises ``ValueError``).
        freeze_backbone: If ``True``, backbone parameters are frozen
            (useful for linear probing).
        mixup_alpha: If > 0, apply Mixup augmentation with this alpha.
        compile_model: If ``True`` (default), apply ``torch.compile()`` to the backbone and head
            for faster inference and training. Requires PyTorch 2.0+.
        compile_kwargs: Optional dict of kwargs to pass to ``torch.compile()``.
            Common options: ``mode`` (``"default"``, ``"reduce-overhead"``, ``"max-autotune"``),
            ``fullgraph`` (``True``/``False``), ``dynamic`` (``True``/``False``).
        seed: Random seed for reproducibility. If ``None`` (the default), no
            seeding is performed.
        deterministic: If ``True`` (default), enables deterministic algorithms in PyTorch for full
            reproducibility (may impact performance). Set to ``False`` for faster training.

    Example:
        >>> # For training with metrics
        >>> model = ImageClassifier(
        ...     backbone="resnet50",
        ...     num_classes=10,
        ...     metrics=[
        ...         MetricConfig(
        ...             name="accuracy",
        ...             backend="torchmetrics",
        ...             metric_class="Accuracy",
        ...             params={"task": "multiclass"},
        ...             stages=["train", "val", "test"],
        ...             prog_bar=True,
        ...         ),
        ...     ],
        ...     logging_config=LoggingConfig(
        ...         log_learning_rate=True,
        ...         log_gradient_norm=False,
        ...     ),
        ...     transform_config=TransformConfig(use_timm_config=True),
        ... )
        >>>
        >>> # For inference only (no metrics needed)
        >>> model = ImageClassifier(
        ...     backbone="resnet50",
        ...     num_classes=10,
        ...     transform_config=TransformConfig(use_timm_config=True),
        ... )
        >>> # With transform_config, you can preprocess raw images
        >>> from PIL import Image
        >>> img = Image.open("test.jpg")
        >>> tensor = model.preprocess(img)  # Returns (1, 3, 224, 224)
        >>> output = model(tensor)
    """

    def __init__(
        self,
        backbone: str | BackboneConfig,
        num_classes: int,
        multi_label: bool = False,
        threshold: float = 0.5,
        loss_fn: str | nn.Module | None = None,
        metrics: MetricManager | list[MetricConfig] | None = None,
        logging_config: LoggingConfig | None = None,
        transform_config: TransformConfig | None = None,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        optimizer: str | dict[str, Any] = "adamw",
        optimizer_kwargs: dict[str, Any] | None = None,
        scheduler: str | dict[str, Any] | None = "cosine",
        scheduler_kwargs: dict[str, Any] | None = None,
        head_dropout: float = 0.0,
        label_smoothing: float = 0.0,
        freeze_backbone: bool = False,
        mixup_alpha: float = 0.0,
        compile_model: bool = True,
        compile_kwargs: dict[str, Any] | None = None,
        seed: int | None = None,
        deterministic: bool = True,
    ):
        # Validate multi-label + label_smoothing combination
        if multi_label and label_smoothing > 0:
            raise ValueError(
                "label_smoothing is not supported with multi_label=True. "
                "BCEWithLogitsLoss does not support label smoothing."
            )

        # Seed for reproducibility
        if seed is not None:
            seed_everything(seed, deterministic=deterministic)
        elif deterministic:
            import warnings

            warnings.warn(
                "deterministic=True has no effect when seed=None. "
                "Set a seed value to enable deterministic behavior.",
                stacklevel=2,
            )

        # Normalize backbone to a plain string so it survives checkpoint
        # round-trip (BackboneConfig is not serialised by save_hyperparameters).
        backbone = backbone.model_name if hasattr(backbone, "model_name") else str(backbone)

        super().__init__()
        self.save_hyperparameters(
            ignore=["metrics", "logging_config", "transform_config", "loss_fn"]
        )
        self.hparams.update({
            "backbone_name": backbone,
            "username": getpass.getuser(),
            "timestamp": _dt.datetime.now().isoformat(timespec="seconds"),
        })

        # Backbone and head
        self.backbone = create_backbone(backbone)
        in_features = get_backbone_out_features(self.backbone)
        self.head = ClassificationHead(in_features, num_classes, dropout=head_dropout)

        # Multi-label mode
        self._multi_label = multi_label
        self._threshold = threshold

        # Setup loss function
        if loss_fn is not None:
            # Use provided loss function
            if isinstance(loss_fn, str):
                # Get from registry
                registry = get_loss_registry()
                loss_kwargs = {}

                # Handle label_smoothing for cross_entropy
                if loss_fn in ["cross_entropy", "ce"] and label_smoothing > 0:
                    loss_kwargs["label_smoothing"] = label_smoothing

                self.criterion = registry.get_loss(loss_fn, **loss_kwargs)
            elif isinstance(loss_fn, nn.Module):
                # Use custom loss instance
                self.criterion = loss_fn
            else:
                raise TypeError(
                    f"loss_fn must be a string or nn.Module instance, got {type(loss_fn)}"
                )
        else:
            # Default behavior based on multi_label
            if multi_label:
                self.criterion = nn.BCEWithLogitsLoss()
            else:
                self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

        # Initialize metrics from config
        if metrics is not None:
            if isinstance(metrics, list):
                metrics = MetricManager(configs=metrics, num_classes=num_classes)
            self._metric_manager = metrics
            # Register metrics as ModuleDicts for proper device handling
            self.train_metrics = metrics.get_train_metrics()
            self.val_metrics = metrics.get_val_metrics()
            self.test_metrics = metrics.get_test_metrics()
        else:
            self._metric_manager = None
            # Create empty ModuleDicts when no metrics are provided
            self.train_metrics = nn.ModuleDict()
            self.val_metrics = nn.ModuleDict()
            self.test_metrics = nn.ModuleDict()

        # Logging configuration
        self._logging_config = logging_config or LoggingConfig(
            log_learning_rate=False,
            log_gradient_norm=False,
        )

        # For confusion matrix logging (not meaningful for multi-label)
        if self._logging_config.log_confusion_matrix and not multi_label:
            self._val_confusion = torchmetrics.ConfusionMatrix(
                task="multiclass", num_classes=num_classes
            )

        # Store hyperparameters
        self._lr = lr
        self._weight_decay = weight_decay
        self._optimizer = optimizer
        self._optimizer_kwargs = optimizer_kwargs or {}
        self._scheduler = scheduler
        self._scheduler_kwargs = scheduler_kwargs or {}
        self._mixup_alpha = mixup_alpha
        self._num_classes = num_classes

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Apply torch.compile for optimization (PyTorch 2.0+)
        # Skip on MPS — the inductor backend generates invalid Metal shaders.
        # Skip on Windows — the inductor backend requires cl.exe (MSVC) which
        # is typically unavailable; compilation is deferred so the try/except
        # at init time cannot catch the error.
        import sys as _sys
        _mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        _skip_compile = _mps_available or _sys.platform == "win32"
        if compile_model and not _skip_compile:
            try:
                compile_opts = compile_kwargs or {}
                self.backbone = torch.compile(self.backbone, **compile_opts)
                self.head = torch.compile(self.head, **compile_opts)
            except Exception as e:
                import warnings

                warnings.warn(
                    f"torch.compile failed: {e}. Continuing without compilation. "
                    f"Ensure you have PyTorch 2.0+ for compile support.",
                    stacklevel=2,
                )
        elif compile_kwargs is not None:
            import warnings

            warnings.warn(
                "compile_kwargs is ignored when compile_model=False.",
                stacklevel=2,
            )

        # Setup transforms from config (PreprocessingMixin)
        self._setup_transforms(transform_config, task="classification")

    def on_fit_start(self) -> None:
        """Capture batch_size from the datamodule when training begins."""
        if (
            self.trainer is not None
            and self.trainer.datamodule is not None
            and hasattr(self.trainer.datamodule, "batch_size")
        ):
            self.hparams["batch_size"] = self.trainer.datamodule.batch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.head(features)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        x, y = batch

        # BCEWithLogitsLoss requires float targets
        if self._multi_label:
            y = y.float()

        if self._mixup_alpha > 0 and self.training:
            lam = (
                torch.distributions.Beta(self._mixup_alpha, self._mixup_alpha)
                .sample()
                .to(x.device)
            )
            idx = torch.randperm(x.size(0), device=x.device)
            x = lam * x + (1 - lam) * x[idx]
            y_a, y_b = y, y[idx]
            logits = self(x)
            loss = lam * self.criterion(logits, y_a) + (1 - lam) * self.criterion(
                logits, y_b
            )
        else:
            logits = self(x)
            loss = self.criterion(logits, y)

        if self._multi_label:
            preds = (logits.sigmoid() > self._threshold).int()
        else:
            preds = logits.argmax(dim=-1)

        # Log loss
        self.log("train/loss", loss, prog_bar=True)

        # Update and log all train metrics
        if self._metric_manager is not None:
            for name, metric in self.train_metrics.items():
                config = self._metric_manager.get_metric_config("train", name)
                metric.update(preds, y)
                if config:
                    self.log(
                        f"train/{name}",
                        metric,
                        on_step=config.log_on_step,
                        on_epoch=config.log_on_epoch,
                        prog_bar=config.prog_bar,
                    )

        # Enhanced logging: learning rate
        if self._logging_config.log_learning_rate:
            opt = self.optimizers()
            if opt is not None and hasattr(opt, "param_groups"):
                lr = opt.param_groups[0]["lr"]
                self.log("train/lr", lr, on_step=True, on_epoch=False)

        return loss

    def on_before_optimizer_step(self, optimizer) -> None:
        """Hook for gradient norm logging."""
        if self._logging_config.log_gradient_norm:
            grad_norm = self._compute_gradient_norm()
            self.log("train/grad_norm", grad_norm, on_step=True, on_epoch=False)

        if self._logging_config.log_weight_norm:
            weight_norm = self._compute_weight_norm()
            self.log("train/weight_norm", weight_norm, on_step=True, on_epoch=False)

    def _compute_gradient_norm(self) -> torch.Tensor:
        """Compute the total gradient norm across all parameters."""
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        return torch.tensor(total_norm**0.5, device=self.device)

    def _compute_weight_norm(self) -> torch.Tensor:
        """Compute the total weight norm across all parameters."""
        total_norm = 0.0
        for p in self.parameters():
            param_norm = p.data.norm(2)
            total_norm += param_norm.item() ** 2
        return torch.tensor(total_norm**0.5, device=self.device)

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        x, y = batch

        if self._multi_label:
            y = y.float()

        logits = self(x)
        loss = self.criterion(logits, y)

        if self._multi_label:
            preds = (logits.sigmoid() > self._threshold).int()
        else:
            preds = logits.argmax(dim=-1)

        try:
            is_sanity = getattr(self.trainer, "sanity_checking", False)
        except RuntimeError:
            is_sanity = False
        prefix = "sanity_val" if is_sanity else "val"

        # Log loss
        self.log(f"{prefix}/loss", loss, prog_bar=True)

        # Update and log all val metrics
        if self._metric_manager is not None:
            for name, metric in self.val_metrics.items():
                config = self._metric_manager.get_metric_config("val", name)
                metric.update(preds, y)
                if config:
                    self.log(
                        f"{prefix}/{name}",
                        metric,
                        on_step=config.log_on_step,
                        on_epoch=config.log_on_epoch,
                        prog_bar=config.prog_bar,
                    )

        # Update confusion matrix if enabled (not meaningful for multi-label)
        if self._logging_config.log_confusion_matrix and not self._multi_label:
            self._val_confusion.update(preds, y)

    def on_validation_epoch_end(self) -> None:
        """Log confusion matrix at the end of validation epoch."""
        if not self._logging_config.log_confusion_matrix or self._multi_label:
            return

        if self.logger is None:
            return

        cm = self._val_confusion.compute()

        # Log as image if matplotlib is available
        try:
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            fig = self._create_confusion_matrix_figure(cm.cpu().numpy())

            # Try to log to tensorboard or other loggers
            if hasattr(self.logger, "experiment") and hasattr(
                self.logger.experiment, "add_figure"
            ):
                self.logger.experiment.add_figure(
                    "val/confusion_matrix", fig, self.current_epoch
                )

            plt.close(fig)
        except ImportError:
            # matplotlib not available, skip confusion matrix visualization
            pass

        self._val_confusion.reset()

    def _create_confusion_matrix_figure(self, cm_np: Any) -> Any:
        """Create a professional confusion matrix visualization.

        Parameters:
            cm_np: Confusion matrix as numpy array.

        Returns:
            matplotlib Figure object.
        """
        import matplotlib.pyplot as plt
        import numpy as np

        # Calculate normalized confusion matrix for percentages
        cm_normalized = cm_np.astype("float") / (
            cm_np.sum(axis=1)[:, np.newaxis] + 1e-10
        )

        # Create figure with appropriate size
        fig_size = max(10, min(20, self._num_classes * 0.8))
        fig, ax = plt.subplots(figsize=(fig_size, fig_size))

        # Use a professional color scheme
        cmap = plt.cm.Blues
        im = ax.imshow(
            cm_normalized, interpolation="nearest", cmap=cmap, vmin=0, vmax=1
        )

        # Add colorbar with label
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label(
            "Normalized Count (Recall)", rotation=270, labelpad=25, fontsize=12
        )

        # Set ticks and labels
        num_classes = cm_np.shape[0]
        tick_marks = np.arange(num_classes)

        # Create class labels (use class indices if no names available)
        class_labels = [str(i) for i in range(num_classes)]

        ax.set_xticks(tick_marks)
        ax.set_yticks(tick_marks)
        ax.set_xticklabels(class_labels, fontsize=10, rotation=45, ha="right")
        ax.set_yticklabels(class_labels, fontsize=10)

        # Add text annotations to each cell
        thresh = cm_normalized.max() / 2.0
        for i in range(num_classes):
            for j in range(num_classes):
                count = cm_np[i, j]
                percentage = cm_normalized[i, j] * 100

                # Choose text color based on background
                color = "white" if cm_normalized[i, j] > thresh else "black"

                # Display count and percentage
                text_str = f"{int(count)}\n({percentage:.1f}%)"
                ax.text(
                    j,
                    i,
                    text_str,
                    ha="center",
                    va="center",
                    color=color,
                    fontsize=max(8, min(12, 100 // num_classes)),
                    weight="bold" if i == j else "normal",
                )

        # Calculate overall accuracy
        accuracy = np.trace(cm_np) / (np.sum(cm_np) + 1e-10) * 100

        # Set labels and title with metrics
        ax.set_xlabel("Predicted Label", fontsize=14, weight="bold")
        ax.set_ylabel("True Label", fontsize=14, weight="bold")
        ax.set_title(
            f"Confusion Matrix - Epoch {self.current_epoch}\n"
            f"Validation Accuracy: {accuracy:.2f}%",
            fontsize=16,
            weight="bold",
            pad=20,
        )

        # Add grid for better readability
        ax.set_xticks(tick_marks - 0.5, minor=True)
        ax.set_yticks(tick_marks - 0.5, minor=True)
        ax.grid(which="minor", color="gray", linestyle="-", linewidth=0.5, alpha=0.3)

        # Adjust layout to prevent label cutoff
        plt.tight_layout()

        return fig

    def test_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> dict[str, torch.Tensor]:
        x, y = batch

        if self._multi_label:
            y = y.float()

        logits = self(x)
        loss = self.criterion(logits, y)

        if self._multi_label:
            preds = (logits.sigmoid() > self._threshold).int()
        else:
            preds = logits.argmax(dim=-1)

        # Log loss
        self.log("test/loss", loss)

        # Update and log all test metrics
        if self._metric_manager is not None:
            for name, metric in self.test_metrics.items():
                config = self._metric_manager.get_metric_config("test", name)
                metric.update(preds, y)
                if config:
                    self.log(
                        f"test/{name}",
                        metric,
                        on_step=config.log_on_step,
                        on_epoch=config.log_on_epoch,
                        prog_bar=config.prog_bar,
                    )

        return {"preds": preds.detach(), "targets": y.detach()}

    def predict_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        logits = self(x)
        if self._multi_label:
            return logits.sigmoid()
        return logits.softmax(dim=-1)

    def to_torchscript(
        self,
        save_path: str | None = None,
        example_input: torch.Tensor | None = None,
        method: str = "trace",
        **kwargs: Any,
    ) -> torch.jit.ScriptModule:
        """Export model to TorchScript format.

        Args:
            save_path: Optional path to save the TorchScript model. If None, returns compiled model without saving.
            example_input: Example input tensor for tracing. If None, uses default shape (1, 3, 224, 224).
            method: Export method ("trace" or "script"). Default is "trace".
            **kwargs: Additional arguments passed to export_to_torchscript.

        Returns:
            Compiled TorchScript module.

        Example:
            >>> model = ImageClassifier(backbone="resnet50", num_classes=10)
            >>> scripted = model.to_torchscript("model.pt")
        """
        from autotimm.export import export_to_torchscript

        if example_input is None:
            example_input = torch.randn(1, 3, 224, 224)

        if save_path is None:
            # Return scripted model without saving
            import tempfile

            with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
                scripted = export_to_torchscript(
                    self, tmp.name, example_input, method, **kwargs
                )
                import os

                os.unlink(tmp.name)
                return scripted
        else:
            return export_to_torchscript(
                self, save_path, example_input, method, **kwargs
            )

    def to_onnx(
        self,
        save_path: str | None = None,
        example_input: torch.Tensor | None = None,
        opset_version: int = 17,
        dynamic_axes: dict[str, dict[int, str]] | None = None,
        **kwargs: Any,
    ) -> str:
        """Export model to ONNX format.

        Args:
            save_path: Path to save the ONNX model. If None, uses a temp file.
            example_input: Example input tensor. If None, uses default shape (1, 3, 224, 224).
            opset_version: ONNX opset version. Default is 17.
            dynamic_axes: Dynamic axes specification. If None, batch dimension is dynamic.
            **kwargs: Additional arguments passed to export_to_onnx.

        Returns:
            Path to the saved ONNX model.

        Example:
            >>> model = ImageClassifier(backbone="resnet50", num_classes=10)
            >>> path = model.to_onnx("model.onnx")
        """
        from autotimm.export import export_to_onnx

        if example_input is None:
            example_input = torch.randn(1, 3, 224, 224)

        if save_path is None:
            import os
            import tempfile

            # mkstemp avoids the race condition of the deprecated tempfile.mktemp
            fd, save_path = tempfile.mkstemp(suffix=".onnx")
            os.close(fd)

        return export_to_onnx(
            self,
            save_path,
            example_input,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes,
            **kwargs,
        )

    def configure_optimizers(self) -> dict:
        """Configure optimizer and learning rate scheduler.

        Supports torch.optim, timm optimizers, and custom optimizers/schedulers.
        """
        params = filter(lambda p: p.requires_grad, self.parameters())

        # Create optimizer
        optimizer = self._create_optimizer(params)

        # Return early if no scheduler
        if self._scheduler is None or self._scheduler == "none":
            return {"optimizer": optimizer}

        # Create scheduler
        scheduler_config = self._create_scheduler(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler_config,
        }

    def _create_optimizer(self, params) -> torch.optim.Optimizer:
        """Create optimizer from config."""
        # Prepare base kwargs
        opt_kwargs = {"lr": self._lr, "weight_decay": self._weight_decay}
        opt_kwargs.update(self._optimizer_kwargs)

        # Dict config: {"class": "path.to.Optimizer", "params": {...}}
        if isinstance(self._optimizer, dict):
            opt_class_path = self._optimizer["class"]
            opt_params = self._optimizer.get("params", {})
            opt_kwargs.update(opt_params)

            # Import and instantiate
            optimizer_cls = self._import_class(opt_class_path)
            return optimizer_cls(params, **opt_kwargs)

        # String name: try torch.optim first, then timm
        optimizer_name = self._optimizer.lower()

        # Torch optimizers
        torch_optimizers = {
            "adamw": torch.optim.AdamW,
            "adam": torch.optim.Adam,
            "sgd": torch.optim.SGD,
            "rmsprop": torch.optim.RMSprop,
            "adagrad": torch.optim.Adagrad,
        }

        if optimizer_name in torch_optimizers:
            return torch_optimizers[optimizer_name](params, **opt_kwargs)

        # Try timm optimizers
        try:
            import timm.optim as timm_optim

            timm_optimizers = {
                "adamp": timm_optim.AdamP,
                "sgdp": timm_optim.SGDP,
                "adabelief": timm_optim.AdaBelief,
                "radam": timm_optim.RAdam,
                "adahessian": timm_optim.Adahessian,
                "lamb": timm_optim.Lamb,
                "lars": timm_optim.Lars,
                "madgrad": timm_optim.MADGRAD,
                "novograd": timm_optim.NvNovoGrad,
            }

            if optimizer_name in timm_optimizers:
                return timm_optimizers[optimizer_name](params, **opt_kwargs)
        except ImportError:
            pass

        raise ValueError(
            f"Unknown optimizer: {self._optimizer}. "
            f"Use a torch.optim optimizer name, timm optimizer name, "
            f"or provide a dict with 'class' and 'params' keys."
        )

    def _create_scheduler(self, optimizer: torch.optim.Optimizer) -> dict:
        """Create scheduler config from optimizer."""
        sched_kwargs = self._scheduler_kwargs.copy()
        interval = sched_kwargs.pop("interval", "step")
        frequency = sched_kwargs.pop("frequency", 1)

        # Dict config: {"class": "path.to.Scheduler", "params": {...}}
        if isinstance(self._scheduler, dict):
            sched_class_path = self._scheduler["class"]
            sched_params = self._scheduler.get("params", {})
            sched_kwargs.update(sched_params)

            scheduler_cls = self._import_class(sched_class_path)
            scheduler = scheduler_cls(optimizer, **sched_kwargs)

            return {
                "scheduler": scheduler,
                "interval": interval,
                "frequency": frequency,
            }

        # String name: try torch schedulers first, then timm
        scheduler_name = self._scheduler.lower()

        # Torch schedulers
        if scheduler_name == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=sched_kwargs.pop(
                    "T_max", self.trainer.estimated_stepping_batches
                ),
                **sched_kwargs,
            )
        elif scheduler_name == "step":
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=sched_kwargs.pop("step_size", 30),
                gamma=sched_kwargs.pop("gamma", 0.1),
                **sched_kwargs,
            )
        elif scheduler_name == "multistep":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=sched_kwargs.pop("milestones", [30, 60, 90]),
                gamma=sched_kwargs.pop("gamma", 0.1),
                **sched_kwargs,
            )
        elif scheduler_name == "exponential":
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer,
                gamma=sched_kwargs.pop("gamma", 0.95),
                **sched_kwargs,
            )
        elif scheduler_name == "onecycle":
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=sched_kwargs.pop("max_lr", self._lr * 10),
                total_steps=sched_kwargs.pop(
                    "total_steps", self.trainer.estimated_stepping_batches
                ),
                **sched_kwargs,
            )
        elif scheduler_name == "plateau":
            # Pop "monitor" before constructing the scheduler so it is not
            # forwarded to ReduceLROnPlateau, which does not accept it.
            monitor = sched_kwargs.pop("monitor", "val/loss")
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode=sched_kwargs.pop("mode", "min"),
                factor=sched_kwargs.pop("factor", 0.1),
                patience=sched_kwargs.pop("patience", 10),
                **sched_kwargs,
            )
            return {
                "scheduler": scheduler,
                "monitor": monitor,
                "interval": "epoch",
                "frequency": frequency,
            }
        else:
            # Try timm schedulers
            try:
                import timm.scheduler as timm_scheduler
            except ImportError:
                timm_scheduler = None

            if timm_scheduler is None or scheduler_name != "cosine_with_restarts":
                raise ValueError(
                    f"Unknown scheduler: {self._scheduler}. "
                    f"Use a torch.optim.lr_scheduler name, timm scheduler name, "
                    f"or provide a dict with 'class' and 'params' keys."
                )

            scheduler = timm_scheduler.CosineLRScheduler(
                optimizer,
                t_initial=sched_kwargs.pop("t_initial", self.trainer.max_epochs),
                cycle_limit=sched_kwargs.pop("cycle_limit", 1),
                **sched_kwargs,
            )
            interval = "epoch"

        return {
            "scheduler": scheduler,
            "interval": interval,
            "frequency": frequency,
        }

    def _import_class(self, class_path: str):
        """Import a class from a fully qualified path."""
        import importlib

        if "." not in class_path:
            raise ValueError(
                f"Class path must be fully qualified (e.g., 'torch.optim.Adam'), "
                f"got: {class_path}"
            )

        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

__init__

__init__(backbone: str | BackboneConfig, num_classes: int, multi_label: bool = False, threshold: float = 0.5, loss_fn: str | Module | None = None, metrics: MetricManager | list[MetricConfig] | None = None, logging_config: LoggingConfig | None = None, transform_config: TransformConfig | None = None, lr: float = 0.001, weight_decay: float = 0.0001, optimizer: str | dict[str, Any] = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: str | dict[str, Any] | None = 'cosine', scheduler_kwargs: dict[str, Any] | None = None, head_dropout: float = 0.0, label_smoothing: float = 0.0, freeze_backbone: bool = False, mixup_alpha: float = 0.0, compile_model: bool = True, compile_kwargs: dict[str, Any] | None = None, seed: int | None = None, deterministic: bool = True)
Source code in src/autotimm/tasks/classification.py
def __init__(
    self,
    backbone: str | BackboneConfig,
    num_classes: int,
    multi_label: bool = False,
    threshold: float = 0.5,
    loss_fn: str | nn.Module | None = None,
    metrics: MetricManager | list[MetricConfig] | None = None,
    logging_config: LoggingConfig | None = None,
    transform_config: TransformConfig | None = None,
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    optimizer: str | dict[str, Any] = "adamw",
    optimizer_kwargs: dict[str, Any] | None = None,
    scheduler: str | dict[str, Any] | None = "cosine",
    scheduler_kwargs: dict[str, Any] | None = None,
    head_dropout: float = 0.0,
    label_smoothing: float = 0.0,
    freeze_backbone: bool = False,
    mixup_alpha: float = 0.0,
    compile_model: bool = True,
    compile_kwargs: dict[str, Any] | None = None,
    seed: int | None = None,
    deterministic: bool = True,
):
    # Validate multi-label + label_smoothing combination
    if multi_label and label_smoothing > 0:
        raise ValueError(
            "label_smoothing is not supported with multi_label=True. "
            "BCEWithLogitsLoss does not support label smoothing."
        )

    # Seed for reproducibility
    if seed is not None:
        seed_everything(seed, deterministic=deterministic)
    elif deterministic:
        import warnings

        warnings.warn(
            "deterministic=True has no effect when seed=None. "
            "Set a seed value to enable deterministic behavior.",
            stacklevel=2,
        )

    # Normalize backbone to a plain string so it survives checkpoint
    # round-trip (BackboneConfig is not serialised by save_hyperparameters).
    backbone = backbone.model_name if hasattr(backbone, "model_name") else str(backbone)

    super().__init__()
    self.save_hyperparameters(
        ignore=["metrics", "logging_config", "transform_config", "loss_fn"]
    )
    self.hparams.update({
        "backbone_name": backbone,
        "username": getpass.getuser(),
        "timestamp": _dt.datetime.now().isoformat(timespec="seconds"),
    })

    # Backbone and head
    self.backbone = create_backbone(backbone)
    in_features = get_backbone_out_features(self.backbone)
    self.head = ClassificationHead(in_features, num_classes, dropout=head_dropout)

    # Multi-label mode
    self._multi_label = multi_label
    self._threshold = threshold

    # Setup loss function
    if loss_fn is not None:
        # Use provided loss function
        if isinstance(loss_fn, str):
            # Get from registry
            registry = get_loss_registry()
            loss_kwargs = {}

            # Handle label_smoothing for cross_entropy
            if loss_fn in ["cross_entropy", "ce"] and label_smoothing > 0:
                loss_kwargs["label_smoothing"] = label_smoothing

            self.criterion = registry.get_loss(loss_fn, **loss_kwargs)
        elif isinstance(loss_fn, nn.Module):
            # Use custom loss instance
            self.criterion = loss_fn
        else:
            raise TypeError(
                f"loss_fn must be a string or nn.Module instance, got {type(loss_fn)}"
            )
    else:
        # Default behavior based on multi_label
        if multi_label:
            self.criterion = nn.BCEWithLogitsLoss()
        else:
            self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    # Initialize metrics from config
    if metrics is not None:
        if isinstance(metrics, list):
            metrics = MetricManager(configs=metrics, num_classes=num_classes)
        self._metric_manager = metrics
        # Register metrics as ModuleDicts for proper device handling
        self.train_metrics = metrics.get_train_metrics()
        self.val_metrics = metrics.get_val_metrics()
        self.test_metrics = metrics.get_test_metrics()
    else:
        self._metric_manager = None
        # Create empty ModuleDicts when no metrics are provided
        self.train_metrics = nn.ModuleDict()
        self.val_metrics = nn.ModuleDict()
        self.test_metrics = nn.ModuleDict()

    # Logging configuration
    self._logging_config = logging_config or LoggingConfig(
        log_learning_rate=False,
        log_gradient_norm=False,
    )

    # For confusion matrix logging (not meaningful for multi-label)
    if self._logging_config.log_confusion_matrix and not multi_label:
        self._val_confusion = torchmetrics.ConfusionMatrix(
            task="multiclass", num_classes=num_classes
        )

    # Store hyperparameters
    self._lr = lr
    self._weight_decay = weight_decay
    self._optimizer = optimizer
    self._optimizer_kwargs = optimizer_kwargs or {}
    self._scheduler = scheduler
    self._scheduler_kwargs = scheduler_kwargs or {}
    self._mixup_alpha = mixup_alpha
    self._num_classes = num_classes

    if freeze_backbone:
        for param in self.backbone.parameters():
            param.requires_grad = False

    # Apply torch.compile for optimization (PyTorch 2.0+)
    # Skip on MPS — the inductor backend generates invalid Metal shaders.
    # Skip on Windows — the inductor backend requires cl.exe (MSVC) which
    # is typically unavailable; compilation is deferred so the try/except
    # at init time cannot catch the error.
    import sys as _sys
    _mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    _skip_compile = _mps_available or _sys.platform == "win32"
    if compile_model and not _skip_compile:
        try:
            compile_opts = compile_kwargs or {}
            self.backbone = torch.compile(self.backbone, **compile_opts)
            self.head = torch.compile(self.head, **compile_opts)
        except Exception as e:
            import warnings

            warnings.warn(
                f"torch.compile failed: {e}. Continuing without compilation. "
                f"Ensure you have PyTorch 2.0+ for compile support.",
                stacklevel=2,
            )
    elif compile_kwargs is not None:
        import warnings

        warnings.warn(
            "compile_kwargs is ignored when compile_model=False.",
            stacklevel=2,
        )

    # Setup transforms from config (PreprocessingMixin)
    self._setup_transforms(transform_config, task="classification")

forward

forward(x: Tensor) -> torch.Tensor
Source code in src/autotimm/tasks/classification.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    features = self.backbone(x)
    return self.head(features)

training_step

training_step(batch: tuple[Tensor, Tensor], batch_idx: int) -> torch.Tensor
Source code in src/autotimm/tasks/classification.py
def training_step(
    self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> torch.Tensor:
    x, y = batch

    # BCEWithLogitsLoss requires float targets
    if self._multi_label:
        y = y.float()

    if self._mixup_alpha > 0 and self.training:
        lam = (
            torch.distributions.Beta(self._mixup_alpha, self._mixup_alpha)
            .sample()
            .to(x.device)
        )
        idx = torch.randperm(x.size(0), device=x.device)
        x = lam * x + (1 - lam) * x[idx]
        y_a, y_b = y, y[idx]
        logits = self(x)
        loss = lam * self.criterion(logits, y_a) + (1 - lam) * self.criterion(
            logits, y_b
        )
    else:
        logits = self(x)
        loss = self.criterion(logits, y)

    if self._multi_label:
        preds = (logits.sigmoid() > self._threshold).int()
    else:
        preds = logits.argmax(dim=-1)

    # Log loss
    self.log("train/loss", loss, prog_bar=True)

    # Update and log all train metrics
    if self._metric_manager is not None:
        for name, metric in self.train_metrics.items():
            config = self._metric_manager.get_metric_config("train", name)
            metric.update(preds, y)
            if config:
                self.log(
                    f"train/{name}",
                    metric,
                    on_step=config.log_on_step,
                    on_epoch=config.log_on_epoch,
                    prog_bar=config.prog_bar,
                )

    # Enhanced logging: learning rate
    if self._logging_config.log_learning_rate:
        opt = self.optimizers()
        if opt is not None and hasattr(opt, "param_groups"):
            lr = opt.param_groups[0]["lr"]
            self.log("train/lr", lr, on_step=True, on_epoch=False)

    return loss

validation_step

validation_step(batch: tuple[Tensor, Tensor], batch_idx: int) -> None
Source code in src/autotimm/tasks/classification.py
def validation_step(
    self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> None:
    x, y = batch

    if self._multi_label:
        y = y.float()

    logits = self(x)
    loss = self.criterion(logits, y)

    if self._multi_label:
        preds = (logits.sigmoid() > self._threshold).int()
    else:
        preds = logits.argmax(dim=-1)

    try:
        is_sanity = getattr(self.trainer, "sanity_checking", False)
    except RuntimeError:
        is_sanity = False
    prefix = "sanity_val" if is_sanity else "val"

    # Log loss
    self.log(f"{prefix}/loss", loss, prog_bar=True)

    # Update and log all val metrics
    if self._metric_manager is not None:
        for name, metric in self.val_metrics.items():
            config = self._metric_manager.get_metric_config("val", name)
            metric.update(preds, y)
            if config:
                self.log(
                    f"{prefix}/{name}",
                    metric,
                    on_step=config.log_on_step,
                    on_epoch=config.log_on_epoch,
                    prog_bar=config.prog_bar,
                )

    # Update confusion matrix if enabled (not meaningful for multi-label)
    if self._logging_config.log_confusion_matrix and not self._multi_label:
        self._val_confusion.update(preds, y)

test_step

test_step(batch: tuple[Tensor, Tensor], batch_idx: int) -> dict[str, torch.Tensor]
Source code in src/autotimm/tasks/classification.py
def test_step(
    self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> dict[str, torch.Tensor]:
    x, y = batch

    if self._multi_label:
        y = y.float()

    logits = self(x)
    loss = self.criterion(logits, y)

    if self._multi_label:
        preds = (logits.sigmoid() > self._threshold).int()
    else:
        preds = logits.argmax(dim=-1)

    # Log loss
    self.log("test/loss", loss)

    # Update and log all test metrics
    if self._metric_manager is not None:
        for name, metric in self.test_metrics.items():
            config = self._metric_manager.get_metric_config("test", name)
            metric.update(preds, y)
            if config:
                self.log(
                    f"test/{name}",
                    metric,
                    on_step=config.log_on_step,
                    on_epoch=config.log_on_epoch,
                    prog_bar=config.prog_bar,
                )

    return {"preds": preds.detach(), "targets": y.detach()}

predict_step

predict_step(batch: Any, batch_idx: int) -> torch.Tensor
Source code in src/autotimm/tasks/classification.py
def predict_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
    x = batch[0] if isinstance(batch, (tuple, list)) else batch
    logits = self(x)
    if self._multi_label:
        return logits.sigmoid()
    return logits.softmax(dim=-1)

configure_optimizers

configure_optimizers() -> dict

Configure optimizer and learning rate scheduler.

Supports torch.optim, timm optimizers, and custom optimizers/schedulers.

Source code in src/autotimm/tasks/classification.py
def configure_optimizers(self) -> dict:
    """Configure optimizer and learning rate scheduler.

    Supports torch.optim, timm optimizers, and custom optimizers/schedulers.
    """
    params = filter(lambda p: p.requires_grad, self.parameters())

    # Create optimizer
    optimizer = self._create_optimizer(params)

    # Return early if no scheduler
    if self._scheduler is None or self._scheduler == "none":
        return {"optimizer": optimizer}

    # Create scheduler
    scheduler_config = self._create_scheduler(optimizer)

    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler_config,
    }

to_torchscript

to_torchscript(save_path: str | None = None, example_input: Tensor | None = None, method: str = 'trace', **kwargs: Any) -> torch.jit.ScriptModule

Export model to TorchScript format.

Parameters:

Name Type Description Default
save_path str | None

Optional path to save the TorchScript model. If None, returns compiled model without saving.

None
example_input Tensor | None

Example input tensor for tracing. If None, uses default shape (1, 3, 224, 224).

None
method str

Export method ("trace" or "script"). Default is "trace".

'trace'
**kwargs Any

Additional arguments passed to export_to_torchscript.

{}

Returns:

Type Description
ScriptModule

Compiled TorchScript module.

Example

model = ImageClassifier(backbone="resnet50", num_classes=10)
scripted = model.to_torchscript("model.pt")

Source code in src/autotimm/tasks/classification.py
def to_torchscript(
    self,
    save_path: str | None = None,
    example_input: torch.Tensor | None = None,
    method: str = "trace",
    **kwargs: Any,
) -> torch.jit.ScriptModule:
    """Export model to TorchScript format.

    Args:
        save_path: Optional path to save the TorchScript model. If None, returns compiled model without saving.
        example_input: Example input tensor for tracing. If None, uses default shape (1, 3, 224, 224).
        method: Export method ("trace" or "script"). Default is "trace".
        **kwargs: Additional arguments passed to export_to_torchscript.

    Returns:
        Compiled TorchScript module.

    Example:
        >>> model = ImageClassifier(backbone="resnet50", num_classes=10)
        >>> scripted = model.to_torchscript("model.pt")
    """
    from autotimm.export import export_to_torchscript

    if example_input is None:
        example_input = torch.randn(1, 3, 224, 224)

    if save_path is None:
        # Return the scripted model without persisting it: export to a
        # closed temporary file, then remove it. (Writing to a still-open
        # NamedTemporaryFile fails on Windows.)
        import os
        import tempfile

        fd, tmp_path = tempfile.mkstemp(suffix=".pt")
        os.close(fd)
        try:
            return export_to_torchscript(
                self, tmp_path, example_input, method, **kwargs
            )
        finally:
            os.unlink(tmp_path)
    else:
        return export_to_torchscript(
            self, save_path, example_input, method, **kwargs
        )

to_onnx

to_onnx(save_path: str | None = None, example_input: Tensor | None = None, opset_version: int = 17, dynamic_axes: dict[str, dict[int, str]] | None = None, **kwargs: Any) -> str

Export model to ONNX format.

Parameters:

Name Type Description Default
save_path str | None

Path to save the ONNX model. If None, uses a temp file.

None
example_input Tensor | None

Example input tensor. If None, uses default shape (1, 3, 224, 224).

None
opset_version int

ONNX opset version. Default is 17.

17
dynamic_axes dict[str, dict[int, str]] | None

Dynamic axes specification. If None, batch dimension is dynamic.

None
**kwargs Any

Additional arguments passed to export_to_onnx.

{}

Returns:

Type Description
str

Path to the saved ONNX model.

Example

model = ImageClassifier(backbone="resnet50", num_classes=10)
path = model.to_onnx("model.onnx")

Source code in src/autotimm/tasks/classification.py
def to_onnx(
    self,
    save_path: str | None = None,
    example_input: torch.Tensor | None = None,
    opset_version: int = 17,
    dynamic_axes: dict[str, dict[int, str]] | None = None,
    **kwargs: Any,
) -> str:
    """Export model to ONNX format.

    Args:
        save_path: Path to save the ONNX model. If None, uses a temp file.
        example_input: Example input tensor. If None, uses default shape (1, 3, 224, 224).
        opset_version: ONNX opset version. Default is 17.
        dynamic_axes: Dynamic axes specification. If None, batch dimension is dynamic.
        **kwargs: Additional arguments passed to export_to_onnx.

    Returns:
        Path to the saved ONNX model.

    Example:
        >>> model = ImageClassifier(backbone="resnet50", num_classes=10)
        >>> path = model.to_onnx("model.onnx")
    """
    from autotimm.export import export_to_onnx

    if example_input is None:
        example_input = torch.randn(1, 3, 224, 224)

    if save_path is None:
        import os
        import tempfile

        # mkstemp avoids the race condition of the deprecated tempfile.mktemp
        fd, save_path = tempfile.mkstemp(suffix=".onnx")
        os.close(fd)

    return export_to_onnx(
        self,
        save_path,
        example_input,
        opset_version=opset_version,
        dynamic_axes=dynamic_axes,
        **kwargs,
    )
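When `dynamic_axes` is None, only the batch dimension is exported as dynamic. An explicit spec equivalent to that default might look like the following; the tensor names `input` and `output` are assumptions about what `export_to_onnx` names the graph inputs and outputs:

```python
# Hypothetical explicit spec: batch dimension dynamic on both tensors,
# spatial dimensions fixed. Keys must match the exported tensor names.
dynamic_axes = {
    "input": {0: "batch_size"},
    "output": {0: "batch_size"},
}
```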

Usage Examples

Basic Usage

from autotimm import ImageClassifier, MetricConfig

metrics = [
    MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["train", "val", "test"],
        prog_bar=True,
    ),
]

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
)

With BackboneConfig

from autotimm import BackboneConfig, ImageClassifier

cfg = BackboneConfig(
    model_name="vit_base_patch16_224",
    pretrained=True,
    drop_path_rate=0.1,
)

model = ImageClassifier(
    backbone=cfg,
    num_classes=100,
    metrics=metrics,
)

With Enhanced Logging

from autotimm import LoggingConfig

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
    logging_config=LoggingConfig(
        log_learning_rate=True,
        log_gradient_norm=True,
        log_confusion_matrix=True,
    ),
)

With Custom Optimizer

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
    optimizer="adamw",
    lr=1e-3,
    weight_decay=1e-4,
    optimizer_kwargs={"betas": (0.9, 0.999)},
)

With Custom Scheduler

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
    scheduler="onecycle",
    scheduler_kwargs={"max_lr": 1e-2},
)

Frozen Backbone (Linear Probing)

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
    freeze_backbone=True,
    lr=1e-2,
)
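`freeze_backbone=True` simply sets `requires_grad=False` on every backbone parameter; `configure_optimizers` then filters frozen parameters out before building the optimizer. The mechanism, sketched with a stand-in two-layer model:

```python
import torch.nn as nn

# Stand-in for backbone + head: freeze the first module, train the second.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
for param in model[0].parameters():
    param.requires_grad = False

# Same filter configure_optimizers applies before building the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
print(len(trainable))  # 2: the second layer's weight and bias
```

Because frozen parameters never reach the optimizer, linear probing typically tolerates the higher learning rate shown above.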

With Regularization

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
    label_smoothing=0.1,
    mixup_alpha=0.2,
    head_dropout=0.5,
)
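With `mixup_alpha > 0`, each training batch is blended with a shuffled copy of itself using a Beta-sampled coefficient, and the loss is mixed the same way (see `training_step`). The input-mixing step in isolation:

```python
import torch

torch.manual_seed(0)
alpha = 0.2  # same role as mixup_alpha
lam = torch.distributions.Beta(alpha, alpha).sample()

x = torch.randn(4, 3, 8, 8)          # a batch of images
idx = torch.randperm(x.size(0))      # shuffled pairing within the batch
x_mixed = lam * x + (1 - lam) * x[idx]
print(x_mixed.shape)  # torch.Size([4, 3, 8, 8])
```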

Multi-Label Classification

from autotimm import ImageClassifier, MetricConfig

model = ImageClassifier(
    backbone="resnet50",
    num_classes=4,           # number of labels
    multi_label=True,        # BCEWithLogitsLoss + sigmoid
    threshold=0.5,           # prediction threshold
    metrics=[
        MetricConfig(
            name="accuracy",
            backend="torchmetrics",
            metric_class="MultilabelAccuracy",
            params={"num_labels": 4},
            stages=["train", "val"],
            prog_bar=True,
        ),
    ],
)
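Under the hood, multi-label predictions apply a per-label sigmoid and compare each probability against `threshold`, so every label is decided independently. The decision rule in isolation:

```python
import torch

# One sample with num_classes=4 labels: sigmoid each logit, then
# threshold each label independently (default threshold=0.5).
logits = torch.tensor([[2.0, -1.0, 0.3, -3.0]])
probs = logits.sigmoid()
preds = (probs > 0.5).int()
print(preds.tolist())  # [[1, 0, 1, 0]]
```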

With TransformConfig (Preprocessing)

Enable inference-time preprocessing with model-specific normalization:

from autotimm import TransformConfig

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
    transform_config=TransformConfig(),  # Use model's pretrained normalization
)

# Now you can preprocess raw images
from PIL import Image
image = Image.open("test.jpg")
tensor = model.preprocess(image)  # Returns (1, 3, 224, 224)
output = model(tensor)

Preprocessing Multiple Images

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
    transform_config=TransformConfig(image_size=384),
)

# Preprocess a batch of images
images = [Image.open(f"img{i}.jpg") for i in range(4)]
tensor = model.preprocess(images)  # Returns (4, 3, 384, 384)

# Forward pass
model.eval()
with torch.inference_mode():
    predictions = model(tensor).softmax(dim=1)

Get Model's Data Config

model = ImageClassifier(
    backbone="vit_base_patch16_224",
    num_classes=100,
    metrics=metrics,
    transform_config=TransformConfig(),
)

# Get normalization config
config = model.get_data_config()
print(f"Mean: {config['mean']}")      # (0.5, 0.5, 0.5) for ViT
print(f"Std: {config['std']}")        # (0.5, 0.5, 0.5) for ViT
print(f"Input size: {config['input_size']}")

Parameters

Parameter Type Default Description
backbone str \| BackboneConfig Required Model name or config
num_classes int Required Number of target classes (or labels)
multi_label bool False Enable multi-label classification
threshold float 0.5 Prediction threshold (multi-label only)
metrics MetricManager \| list[MetricConfig] None Metrics configuration
logging_config LoggingConfig \| None None Enhanced logging options
transform_config TransformConfig \| None None Transform config for preprocessing
lr float 1e-3 Learning rate
weight_decay float 1e-4 Weight decay
optimizer str \| dict "adamw" Optimizer name or config
optimizer_kwargs dict \| None None Extra optimizer kwargs
scheduler str \| dict \| None "cosine" Scheduler name or config
scheduler_kwargs dict \| None None Extra scheduler kwargs
head_dropout float 0.0 Dropout before classifier
label_smoothing float 0.0 Label smoothing (not for multi-label)
freeze_backbone bool False Freeze backbone weights
mixup_alpha float 0.0 Mixup augmentation alpha
loss_fn str \| nn.Module \| None None Loss function (string from registry, nn.Module, or None for default)
compile_model bool True Apply torch.compile() for faster training/inference (PyTorch 2.0+)
compile_kwargs dict \| None None Kwargs for torch.compile() (e.g., mode, fullgraph, dynamic)
seed int \| None None Random seed for reproducibility (None to disable seeding)
deterministic bool True Enable deterministic algorithms (may impact performance)

Supported Optimizers

Torch: adamw, adam, sgd, rmsprop, adagrad

Timm: adamp, sgdp, adabelief, radam, adahessian, lamb, lars, madgrad, novograd

Custom:

optimizer={
    "class": "torch.optim.AdamW",
    "params": {"betas": (0.9, 0.999)},
}

Supported Schedulers

Torch: cosine, step, multistep, exponential, onecycle, plateau

Timm: cosine_with_restarts

Custom:

scheduler={
    "class": "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts",
    "params": {"T_0": 10},
}
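Dict configs are resolved by splitting the fully qualified path and importing the class dynamically, as the model's internal `_import_class` helper does:

```python
import importlib

def import_class(class_path: str):
    """Resolve 'pkg.module.ClassName' to the class object."""
    module_path, class_name = class_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

cls = import_class("torch.optim.lr_scheduler.CosineAnnealingWarmRestarts")
print(cls.__name__)  # CosineAnnealingWarmRestarts
```

The resolved class is then instantiated with the optimizer (or parameters) plus the merged `params` kwargs.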

Model Architecture

ImageClassifier
├── backbone (timm model, headless)
│   └── num_features outputs
├── head (ClassificationHead)
│   ├── dropout (if head_dropout > 0)
│   └── Linear(num_features, num_classes)
└── criterion
    ├── CrossEntropyLoss  (multi_label=False, default)
    └── BCEWithLogitsLoss (multi_label=True)
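The same backbone-plus-head composition, sketched with a toy feature extractor standing in for the timm backbone (`TinyBackbone` is illustrative, not part of autotimm, and `ClassificationHead` is approximated as dropout plus a linear layer):

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Toy stand-in for a headless timm model: images -> feature vectors."""
    def __init__(self, num_features: int = 64):
        super().__init__()
        self.num_features = num_features
        self.net = nn.Sequential(
            nn.Conv2d(3, num_features, kernel_size=3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.net(x)

backbone = TinyBackbone()
head = nn.Sequential(nn.Dropout(0.5), nn.Linear(backbone.num_features, 10))
head.eval()  # disable dropout for a deterministic shape check

logits = head(backbone(torch.randn(2, 3, 224, 224)))
print(logits.shape)  # torch.Size([2, 10])
```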

Logged Metrics

Metric Stage Condition
{stage}/loss train, val, test Always
{stage}/{metric_name} As configured Per MetricConfig
train/lr train log_learning_rate=True
train/grad_norm train log_gradient_norm=True
train/weight_norm train log_weight_norm=True
val/confusion_matrix val log_confusion_matrix=True