
AutoTrainer & TunerConfig

Configured PyTorch Lightning Trainer with auto-tuning support. For YAML-config-driven training from the command line, see AutoTimmCLI.

AutoTrainer

A convenience wrapper around pl.Trainer with sensible defaults for AutoTimm.

API Reference

autotimm.AutoTrainer

Bases: Trainer

A configured pl.Trainer with sensible defaults for autotimm.

This is a convenience class that wires up the logger, checkpointing, and automatic hyperparameter tuning. All **trainer_kwargs are forwarded to pl.Trainer, so any Lightning Trainer argument works.

Auto-tuning is enabled by default: the learning rate finder always runs, and the batch size finder runs when a CUDA GPU is available, unless auto-tuning is explicitly disabled via tuner_config.
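The resolution rules above can be sketched in plain Python. This is a simplified illustration, not the library's actual code: the real implementation builds a `TunerConfig` instead of a dict, but the branching is the same.

```python
def resolve_tuner_config(tuner_config, fast_dev_run, has_gpu):
    """Sketch of AutoTrainer's tuner_config resolution (illustration only)."""
    if fast_dev_run or tuner_config is False:
        return None  # auto-tuning fully disabled
    if tuner_config is None or tuner_config is True:
        # Default: LR finder on; batch size finder only where VRAM probing is meaningful
        return {"auto_lr": True, "auto_batch_size": has_gpu}
    return tuner_config  # a TunerConfig instance is passed through as-is
```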

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `max_epochs` | `int` | Number of training epochs. | `10` |
| `accelerator` | `str` | `"auto"`, `"gpu"`, `"cpu"`, `"tpu"`, etc. | `'auto'` |
| `devices` | `int \| str` | Number of devices, or `"auto"`. | `'auto'` |
| `precision` | `str \| int` | Training precision (`32`, `16`, `"bf16-mixed"`, etc.). | `32` |
| `logger` | `LoggerManager \| list[LoggerConfig] \| Logger \| list[Logger] \| bool` | A `LoggerManager` instance, a list of `LoggerConfig` objects, a pre-built Logger instance or list, or `False` to disable logging. | `False` |
| `tuner_config` | `TunerConfig \| None \| bool` | A `TunerConfig` instance configuring automatic learning-rate and/or batch-size finding. If `None` or `True`, a default `TunerConfig()` is created with `auto_lr` enabled; `auto_batch_size` is enabled only when a CUDA GPU is available. Pass `False` to disable auto-tuning completely. | `None` |
| `checkpoint_monitor` | `str \| None` | Metric to monitor for checkpointing (e.g., `"val/accuracy"`). If `None`, no automatic checkpoint callback is added. | `'val/loss'` |
| `checkpoint_mode` | `str` | One of `"min"` or `"max"` for checkpoint monitoring. | `'min'` |
| `callbacks` | `list[Callback] \| None` | List of Lightning callbacks. | `None` |
| `default_root_dir` | `str` | Root directory for logs and checkpoints. | `'.'` |
| `gradient_clip_val` | `float \| None` | Gradient clipping value. | `None` |
| `accumulate_grad_batches` | `int` | Gradient accumulation steps. | `1` |
| `val_check_interval` | `float \| int` | How often to run validation (`1.0` = every epoch). | `1.0` |
| `enable_checkpointing` | `bool` | Whether to save model checkpoints. | `True` |
| `fast_dev_run` | `bool \| int` | Runs a small number of batches through train, val, and test to find bugs quickly: `True` (1 batch), `False` (disabled), or an integer batch count. Also disables auto-tuning. | `False` |
| `seed` | `int \| None` | Random seed for reproducibility. If `None` (the default), no seeding is performed. | `None` |
| `deterministic` | `bool` | If `True`, enables deterministic algorithms in PyTorch for full reproducibility (may impact performance). Has no effect when `seed` is `None`. | `False` |
| `use_autotimm_seeding` | `bool` | If `True`, uses AutoTimm's `seed_everything()`, which provides comprehensive seeding with deterministic-mode support. If `False`, uses PyTorch Lightning's built-in seeding. | `False` |
| `json_progress` | `bool` | If `True`, appends a `JsonProgressCallback` that emits NDJSON progress events to stdout for consumption by a Tauri/Preact frontend. | `False` |
| `json_progress_every_n_steps` | `int` | How often to emit `batch_end` events when `json_progress` is enabled. | `10` |
| `json_progress_log_file` | `str \| None` | Optional path to an NDJSON file where all progress events are also appended, allowing the frontend to replay missed events after a disconnect. | `None` |
| `**trainer_kwargs` | `Any` | Any other `pl.Trainer` argument. | `{}` |
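When `json_progress_log_file` is set, each progress event is appended as one JSON object per line (NDJSON). A minimal stdlib reader for replaying such a file might look like the sketch below; beyond the `"event"` names that appear in the trainer source (`tuning_started`, `tuning_complete`, `batch_end`), the event schema is treated as opaque here.

```python
import json

def read_progress_events(path: str) -> list[dict]:
    """Read NDJSON progress events: one JSON object per line, skipping blanks."""
    events = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                events.append(json.loads(line))
    return events
```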
Example

Default behavior, auto-tuning enabled (LR finder; batch size finder when a GPU is available):

```python
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)
```

Disable all auto-tuning:

```python
trainer = AutoTrainer(max_epochs=10, tuner_config=False)
trainer.fit(model, datamodule=data)
```

Custom tuning configuration:

```python
trainer = AutoTrainer(
    max_epochs=10,
    tuner_config=TunerConfig(
        auto_lr=True,
        auto_batch_size=False,
        lr_find_kwargs={"min_lr": 1e-6, "max_lr": 1.0},
    ),
)
trainer.fit(model, datamodule=data)
```

Quick debugging with `fast_dev_run` (auto-tuning disabled automatically):

```python
trainer = AutoTrainer(fast_dev_run=True)
trainer.fit(model, datamodule=data)  # Runs 1 batch only
```

Custom seeding for reproducibility:

```python
trainer = AutoTrainer(max_epochs=10, seed=123, deterministic=True)
trainer.fit(model, datamodule=data)
```

Use Lightning's built-in seeding instead of AutoTimm's:

```python
trainer = AutoTrainer(max_epochs=10, seed=42, use_autotimm_seeding=False)
trainer.fit(model, datamodule=data)
```

Disable seeding completely:

```python
trainer = AutoTrainer(max_epochs=10, seed=None)
trainer.fit(model, datamodule=data)
```
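The default checkpoint filename is derived from `checkpoint_monitor` by replacing `"/"` with `"_"`, so that metric names like `val/loss` do not create subdirectories on disk. The template construction from the source below can be reproduced in isolation:

```python
checkpoint_monitor = "val/loss"
# Replace '/' in the monitor name so it is safe to embed in a filename
safe_monitor = checkpoint_monitor.replace("/", "_")
# Doubled braces escape to literal {epoch} / {val_loss:.4f} placeholders,
# which Lightning fills in at checkpoint-save time
filename = f"best-{{epoch}}-{{{safe_monitor}:.4f}}"
```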

Source code in src/autotimm/training/trainer.py
class AutoTrainer(pl.Trainer):
    """A configured ``pl.Trainer`` with sensible defaults for autotimm.

    This is a convenience class that wires up the logger, checkpointing,
    and automatic hyperparameter tuning. All ``**trainer_kwargs`` are
    forwarded to ``pl.Trainer``, so any Lightning Trainer argument works.

    **Auto-tuning is enabled by default** - the learning rate finder always runs,
    and the batch size finder runs when a CUDA GPU is available, unless explicitly
    disabled via ``tuner_config``.

    Parameters:
        max_epochs: Number of training epochs.
        accelerator: ``"auto"``, ``"gpu"``, ``"cpu"``, ``"tpu"``, etc.
        devices: Number of devices or ``"auto"``.
        precision: Training precision (``32``, ``16``, ``"bf16-mixed"``, etc.).
        logger: A ``LoggerManager`` instance, a list of ``LoggerConfig``
            objects, a pre-built Logger instance/list, or ``False`` to
            disable logging.
        tuner_config: A ``TunerConfig`` instance to configure automatic learning
            rate and/or batch size finding. If ``None`` or ``True``, a default
            ``TunerConfig()`` is created with ``auto_lr`` enabled and
            ``auto_batch_size`` enabled only when a CUDA GPU is available.
            To disable auto-tuning completely, pass ``False``.
        checkpoint_monitor: Metric to monitor for checkpointing (e.g.,
            ``"val/accuracy"``). If ``None``, no automatic checkpoint
            callback is added.
        checkpoint_mode: One of ``"min"`` or ``"max"`` for checkpoint
            monitoring.
        callbacks: List of Lightning callbacks.
        default_root_dir: Root directory for logs and checkpoints.
        gradient_clip_val: Gradient clipping value.
        accumulate_grad_batches: Gradient accumulation steps.
        val_check_interval: How often to run validation (1.0 = every epoch).
        enable_checkpointing: Whether to save model checkpoints.
        fast_dev_run: Runs a single batch through train, val, and test to
            find bugs quickly. Can be ``True`` (1 batch), ``False`` (disabled),
            or an integer (number of batches). Useful for debugging.
        seed: Random seed for reproducibility. If ``None`` (the default),
            no seeding is performed.
        deterministic: If ``True``, enables deterministic algorithms in PyTorch
            for full reproducibility (may impact performance). Default is ``False``.
        use_autotimm_seeding: If ``True``, uses AutoTimm's custom ``seed_everything()``
            function which provides comprehensive seeding with deterministic mode support.
            If ``False`` (default), uses PyTorch Lightning's built-in seeding.
        json_progress: If ``True``, appends a ``JsonProgressCallback`` that
            emits NDJSON progress events to stdout for consumption by a
            Tauri/Preact frontend.  Default ``False``.
        json_progress_every_n_steps: How often to emit ``batch_end`` events
            when ``json_progress`` is enabled.  Default ``10``.
        json_progress_log_file: Optional path to an NDJSON file where all
            progress events are also appended.  Allows the frontend to
            replay missed events after a disconnect.  Default ``None``.
        **trainer_kwargs: Any other ``pl.Trainer`` argument.

    Example:
        >>> # Default: auto-tuning enabled (LR finder; batch size finder on GPU)
        >>> trainer = AutoTrainer(max_epochs=10)
        >>> trainer.fit(model, datamodule=data)

        >>> # Disable all auto-tuning
        >>> trainer = AutoTrainer(max_epochs=10, tuner_config=False)
        >>> trainer.fit(model, datamodule=data)

        >>> # Custom tuning configuration
        >>> trainer = AutoTrainer(
        ...     max_epochs=10,
        ...     tuner_config=TunerConfig(
        ...         auto_lr=True,
        ...         auto_batch_size=False,
        ...         lr_find_kwargs={"min_lr": 1e-6, "max_lr": 1.0},
        ...     ),
        ... )
        >>> trainer.fit(model, datamodule=data)

        >>> # Quick debugging with fast_dev_run (auto-tuning disabled automatically)
        >>> trainer = AutoTrainer(fast_dev_run=True)
        >>> trainer.fit(model, datamodule=data)  # Runs 1 batch only

        >>> # Custom seeding for reproducibility
        >>> trainer = AutoTrainer(max_epochs=10, seed=123, deterministic=True)
        >>> trainer.fit(model, datamodule=data)

        >>> # Use Lightning's built-in seeding instead of AutoTimm's
        >>> trainer = AutoTrainer(max_epochs=10, seed=42, use_autotimm_seeding=False)
        >>> trainer.fit(model, datamodule=data)

        >>> # Disable seeding completely
        >>> trainer = AutoTrainer(max_epochs=10, seed=None)
        >>> trainer.fit(model, datamodule=data)
    """

    def __init__(
        self,
        max_epochs: int = 10,
        accelerator: str = "auto",
        devices: int | str = "auto",
        precision: str | int = 32,
        logger: (
            LoggerManager
            | list[LoggerConfig]
            | pl.loggers.Logger
            | list[pl.loggers.Logger]
            | bool
        ) = False,
        tuner_config: TunerConfig | None | bool = None,
        checkpoint_monitor: str | None = "val/loss",
        checkpoint_mode: str = "min",
        callbacks: list[pl.Callback] | None = None,
        default_root_dir: str = ".",
        gradient_clip_val: float | None = None,
        accumulate_grad_batches: int = 1,
        val_check_interval: float | int = 1.0,
        enable_checkpointing: bool = True,
        fast_dev_run: bool | int = False,
        seed: int | None = None,
        deterministic: bool = False,
        use_autotimm_seeding: bool = False,
        json_progress: bool = False,
        json_progress_every_n_steps: int = 10,
        json_progress_log_file: str | None = None,
        **trainer_kwargs: Any,
    ) -> None:
        # Print environment watermark on first AutoTrainer instantiation
        _print_watermark()

        # Store seeding params — actual seeding is deferred to fit()
        self._seed = seed
        self._deterministic = deterministic
        self._use_autotimm_seeding = use_autotimm_seeding

        if isinstance(logger, LoggerManager):
            resolved_logger = logger.loggers
        elif (
            isinstance(logger, list) and logger and isinstance(logger[0], LoggerConfig)
        ):
            manager = LoggerManager(configs=logger)
            resolved_logger = manager.loggers
        else:
            resolved_logger = logger

        if callbacks is None:
            callbacks = []

        if enable_checkpointing and checkpoint_monitor:
            has_checkpoint_cb = any(
                isinstance(cb, pl.callbacks.ModelCheckpoint) for cb in callbacks
            )
            if not has_checkpoint_cb:
                # Replace '/' in monitor name to avoid creating subdirectories
                safe_monitor = checkpoint_monitor.replace("/", "_")
                callbacks.append(
                    pl.callbacks.ModelCheckpoint(
                        monitor=checkpoint_monitor,
                        mode=checkpoint_mode,
                        save_top_k=1,
                        filename=f"best-{{epoch}}-{{{safe_monitor}:.4f}}",
                        auto_insert_metric_name=False,
                    )
                )

        # Store json_progress flag and log file for tuning event emission
        self._json_progress = json_progress
        self._json_progress_log_file = json_progress_log_file

        if json_progress:
            callbacks.append(
                JsonProgressCallback(
                    emit_every_n_steps=json_progress_every_n_steps,
                    log_file=json_progress_log_file,
                )
            )

        # Configure auto-tuning behavior
        # Disable auto-tuning during fast_dev_run or if explicitly set to False
        if fast_dev_run or tuner_config is False:
            self._tuner_config = None
        elif tuner_config is None or tuner_config is True:
            # Enable auto-tuning by default with sensible defaults.
            # Batch size finder is only useful on GPU instances where VRAM
            # probing is meaningful; disable it on CPU / MPS to avoid
            # unnecessary overhead or failures.
            import torch

            has_gpu = torch.cuda.is_available()
            self._tuner_config = TunerConfig(auto_batch_size=has_gpu)
        elif isinstance(tuner_config, TunerConfig):
            self._tuner_config = tuner_config
        else:
            raise TypeError(
                f"tuner_config must be a TunerConfig, None, True, or False, "
                f"got {type(tuner_config).__name__}"
            )

        # Guard flag: Lightning's Tuner internally calls trainer.fit(), which would
        # re-enter our fit() override and trigger a recursive tuning loop. This flag
        # breaks the cycle so that the inner fit() call skips _run_tuning().
        self._is_tuning = False

        super().__init__(
            max_epochs=max_epochs,
            accelerator=accelerator,
            devices=devices,
            precision=precision,
            logger=resolved_logger,
            callbacks=callbacks,
            default_root_dir=default_root_dir,
            gradient_clip_val=gradient_clip_val,
            accumulate_grad_batches=accumulate_grad_batches,
            val_check_interval=val_check_interval,
            enable_checkpointing=enable_checkpointing,
            fast_dev_run=fast_dev_run,
            **trainer_kwargs,
        )

    def fit(
        self,
        model: pl.LightningModule,
        train_dataloaders: Any = None,
        val_dataloaders: Any = None,
        datamodule: pl.LightningDataModule | None = None,
        ckpt_path: str | None = None,
    ) -> None:
        """Fit the model, optionally running LR/batch size tuning first.

        If ``tuner_config`` was provided with ``auto_lr=True`` or
        ``auto_batch_size=True``, the respective tuning will run before
        training begins.

        Parameters:
            model: The LightningModule to train.
            train_dataloaders: Train dataloaders (if not using datamodule).
            val_dataloaders: Validation dataloaders (if not using datamodule).
            datamodule: A LightningDataModule instance.
            ckpt_path: Path to checkpoint for resuming training.
        """
        # Ensure fork-based multiprocessing on macOS to avoid spawn guard issues
        _ensure_safe_multiprocessing()

        # Apply seeding at the start of fit() so trainer seed is authoritative
        self._apply_seeding()

        if self._tuner_config is not None and not self._is_tuning:
            self._is_tuning = True
            try:
                if self._json_progress:
                    _emit({"event": "tuning_started"}, log_file=self._json_progress_log_file)
                self._run_tuning(model, train_dataloaders, val_dataloaders, datamodule)
                if self._json_progress:
                    _emit({"event": "tuning_complete"}, log_file=self._json_progress_log_file)
            finally:
                self._is_tuning = False

        super().fit(
            model,
            train_dataloaders=train_dataloaders,
            val_dataloaders=val_dataloaders,
            datamodule=datamodule,
            ckpt_path=ckpt_path,
        )

    def test(
        self,
        model: pl.LightningModule | None = None,
        dataloaders: Any = None,
        datamodule: pl.LightningDataModule | None = None,
        ckpt_path: str | None = "best",
        verbose: bool = True,
    ) -> list[dict[str, float]]:
        """Test the model, logging results to configured loggers (CSV, JSONL, etc.).

        Ensures multiprocessing safety and seeding are applied before
        running the test loop. All configured loggers (CSVLogger,
        TensorBoard, etc.) and callbacks (``JsonProgressCallback``)
        remain active during testing.

        Parameters:
            model: The LightningModule to test. If ``None``, uses the
                model from the last ``fit()`` call.
            dataloaders: Test dataloaders (if not using datamodule).
            datamodule: A LightningDataModule instance.
            ckpt_path: Path to checkpoint to load for testing.
                Defaults to ``"best"`` (best checkpoint from training).
            verbose: Whether to print test results.

        Returns:
            List of dictionaries containing test metrics.
        """
        _ensure_safe_multiprocessing()
        self._apply_seeding()

        # Resolve ckpt_path="best" robustly — handles both in-process
        # (fit then test) and separate-process (CLI test after fit) scenarios.
        if ckpt_path == "best":
            ckpt_path = self._resolve_best_ckpt_path()

        return super().test(
            model=model,
            dataloaders=dataloaders,
            datamodule=datamodule,
            ckpt_path=ckpt_path,
            verbose=verbose,
        )

    def _resolve_best_ckpt_path(self) -> str | None:
        """Resolve ``ckpt_path="best"`` to an actual file path.

        Handles three scenarios:
        1. A ``ModelCheckpoint`` callback has ``best_model_path`` populated
           (in-process fit→test) — returns that path directly.
        2. No ``ModelCheckpoint`` with a monitor is configured — falls back
           to ``None`` (current model weights).
        3. A ``ModelCheckpoint`` is configured but ``best_model_path`` is empty
           (separate-process test after fit, e.g. NightFlow CLI) — searches
           the checkpoint directory on disk for ``best-*.ckpt`` files and
           returns the most recent one.
        """
        from pathlib import Path

        ckpt_cb = None
        for cb in (self.callbacks or []):
            if isinstance(cb, pl.callbacks.ModelCheckpoint) and cb.monitor is not None:
                ckpt_cb = cb
                break

        if ckpt_cb is None:
            logger.warning(
                'No ModelCheckpoint configured with a monitored metric. '
                'Falling back to ckpt_path=None (current model weights) '
                'instead of ckpt_path="best".'
            )
            return None

        # In-process: fit() already set best_model_path
        if ckpt_cb.best_model_path:
            return "best"

        # Separate-process: search for checkpoint files on disk
        ckpt_dir = Path(ckpt_cb.dirpath) if ckpt_cb.dirpath else None

        # If dirpath is not set, try the default Lightning checkpoint dir
        if ckpt_dir is None:
            root = Path(self.default_root_dir)
            # Check common checkpoint locations
            for candidate in [
                root / "checkpoints",
                root,
            ]:
                if candidate.is_dir():
                    ckpts = sorted(candidate.glob("best-*.ckpt"), key=lambda p: p.stat().st_mtime)
                    if ckpts:
                        ckpt_dir = candidate
                        break

            # Also check logger-specific directories
            if ckpt_dir is None and self.loggers:
                for lg in self.loggers:
                    log_dir = getattr(lg, "log_dir", None)
                    if log_dir:
                        candidate = Path(log_dir) / "checkpoints"
                        if candidate.is_dir():
                            ckpts = sorted(candidate.glob("best-*.ckpt"), key=lambda p: p.stat().st_mtime)
                            if ckpts:
                                ckpt_dir = candidate
                                break

        if ckpt_dir and ckpt_dir.is_dir():
            ckpts = sorted(ckpt_dir.glob("best-*.ckpt"), key=lambda p: p.stat().st_mtime)
            if ckpts:
                best = str(ckpts[-1])
                logger.info(f"Found best checkpoint on disk: {best}")
                return best

        logger.warning(
            'ckpt_path="best" requested but no checkpoint file found on disk. '
            'Falling back to ckpt_path=None (current model weights).'
        )
        return None

    def _apply_seeding(self) -> None:
        """Apply seeding based on stored parameters."""
        if self._seed is not None and self._use_autotimm_seeding:
            seed_everything(self._seed, deterministic=self._deterministic)
        elif self._seed is not None:
            import pytorch_lightning as pl_seed

            pl_seed.seed_everything(self._seed, workers=True)
            if self._deterministic:
                import torch

                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False
        elif self._seed is None and self._deterministic:
            import warnings

            warnings.warn(
                "deterministic=True has no effect when seed=None. "
                "Set a seed value to enable deterministic behavior.",
                stacklevel=2,
            )

    def _remove_tuner_callbacks(self) -> None:
        """Remove BatchSizeFinder/LearningRateFinder callbacks injected by Tuner.

        In Lightning >= 2.x, Tuner methods permanently attach their callbacks
        to the trainer. Calling both scale_batch_size and lr_find sequentially
        (or calling fit() more than once) triggers a conflict. This method
        cleans them up after each tuning step.
        """
        try:
            from pytorch_lightning.callbacks import BatchSizeFinder, LearningRateFinder

            self.callbacks[:] = [
                cb
                for cb in self.callbacks
                if not isinstance(cb, (BatchSizeFinder, LearningRateFinder))
            ]
        except ImportError:
            pass

    def _run_tuning(
        self,
        model: pl.LightningModule,
        train_dataloaders: Any,
        val_dataloaders: Any,
        datamodule: pl.LightningDataModule | None,
    ) -> None:
        """Run automatic hyperparameter tuning."""
        tuner = Tuner(self)
        config = self._tuner_config

        # Run batch size finder first (if enabled)
        # This should run before LR finder since LR depends on batch size
        if config.auto_batch_size:
            logger.info("Running batch size finder...")
            try:
                result = tuner.scale_batch_size(
                    model,
                    train_dataloaders=train_dataloaders,
                    val_dataloaders=val_dataloaders,
                    datamodule=datamodule,
                    **config.batch_size_kwargs,
                )
                logger.info(f"Optimal batch size found: {result}")
            except Exception as e:
                logger.error(f"Batch size finder failed: {e}")
                logger.info("Continuing with user-specified batch size.")
            finally:
                # Remove the BatchSizeFinder callback Tuner injected so the
                # subsequent lr_find call (and any future fit() call) won't
                # see a duplicate and raise an error.
                self._remove_tuner_callbacks()

        # Run LR finder (if enabled)
        if config.auto_lr:
            logger.info("Running learning rate finder...")
            try:
                lr_finder = tuner.lr_find(
                    model,
                    train_dataloaders=train_dataloaders,
                    val_dataloaders=val_dataloaders,
                    datamodule=datamodule,
                    **config.lr_find_kwargs,
                )
                if lr_finder is not None:
                    suggested_lr = lr_finder.suggestion()
                    if suggested_lr is not None:
                        logger.info(f"Suggested learning rate: {suggested_lr:.2e}")
                        # Update model's learning rate
                        if hasattr(model, "_lr"):
                            model._lr = suggested_lr
                        elif hasattr(model, "lr"):
                            model.lr = suggested_lr
                        elif hasattr(model, "hparams") and "lr" in model.hparams:
                            model.hparams.lr = suggested_lr
                    else:
                        logger.warning("LR finder could not suggest a learning rate.")
            except Exception as e:
                logger.error(f"LR finder failed: {e}")
                logger.info("Continuing with user-specified learning rate.")
            finally:
                # Remove the LearningRateFinder callback so fit() starts clean.
                self._remove_tuner_callbacks()

    @property
    def tuner_config(self) -> TunerConfig | None:
        """Return the tuner configuration."""
        return self._tuner_config

tuner_config property

tuner_config: TunerConfig | None

Return the tuner configuration.

__init__

__init__(max_epochs: int = 10, accelerator: str = 'auto', devices: int | str = 'auto', precision: str | int = 32, logger: LoggerManager | list[LoggerConfig] | Logger | list[Logger] | bool = False, tuner_config: TunerConfig | None | bool = None, checkpoint_monitor: str | None = 'val/loss', checkpoint_mode: str = 'min', callbacks: list[Callback] | None = None, default_root_dir: str = '.', gradient_clip_val: float | None = None, accumulate_grad_batches: int = 1, val_check_interval: float | int = 1.0, enable_checkpointing: bool = True, fast_dev_run: bool | int = False, seed: int | None = None, deterministic: bool = False, use_autotimm_seeding: bool = False, json_progress: bool = False, json_progress_every_n_steps: int = 10, json_progress_log_file: str | None = None, **trainer_kwargs: Any) -> None

fit

fit(model: LightningModule, train_dataloaders: Any = None, val_dataloaders: Any = None, datamodule: LightningDataModule | None = None, ckpt_path: str | None = None) -> None

Fit the model, optionally running LR/batch size tuning first.

If tuner_config was provided with auto_lr=True or auto_batch_size=True, the respective tuning will run before training begins.

Parameters:

Name Type Description Default
model LightningModule

The LightningModule to train.

required
train_dataloaders Any

Train dataloaders (if not using datamodule).

None
val_dataloaders Any

Validation dataloaders (if not using datamodule).

None
datamodule LightningDataModule | None

A LightningDataModule instance.

None
ckpt_path str | None

Path to checkpoint for resuming training.

None
Source code in src/autotimm/training/trainer.py
def fit(
    self,
    model: pl.LightningModule,
    train_dataloaders: Any = None,
    val_dataloaders: Any = None,
    datamodule: pl.LightningDataModule | None = None,
    ckpt_path: str | None = None,
) -> None:
    """Fit the model, optionally running LR/batch size tuning first.

    If ``tuner_config`` was provided with ``auto_lr=True`` or
    ``auto_batch_size=True``, the respective tuning will run before
    training begins.

    Parameters:
        model: The LightningModule to train.
        train_dataloaders: Train dataloaders (if not using datamodule).
        val_dataloaders: Validation dataloaders (if not using datamodule).
        datamodule: A LightningDataModule instance.
        ckpt_path: Path to checkpoint for resuming training.
    """
    # Ensure fork-based multiprocessing on macOS to avoid spawn guard issues
    _ensure_safe_multiprocessing()

    # Apply seeding at the start of fit() so trainer seed is authoritative
    self._apply_seeding()

    if self._tuner_config is not None and not self._is_tuning:
        self._is_tuning = True
        try:
            if self._json_progress:
                _emit({"event": "tuning_started"}, log_file=self._json_progress_log_file)
            self._run_tuning(model, train_dataloaders, val_dataloaders, datamodule)
            if self._json_progress:
                _emit({"event": "tuning_complete"}, log_file=self._json_progress_log_file)
        finally:
            self._is_tuning = False

    super().fit(
        model,
        train_dataloaders=train_dataloaders,
        val_dataloaders=val_dataloaders,
        datamodule=datamodule,
        ckpt_path=ckpt_path,
    )

Usage Examples

Basic Training

from autotimm import AutoTrainer

trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)
trainer.test(model, datamodule=data)

With Logging and Checkpointing

from autotimm import AutoTrainer, LoggerConfig

trainer = AutoTrainer(
    max_epochs=10,
    logger=[LoggerConfig(backend="tensorboard", params={"save_dir": "logs"})],
    checkpoint_monitor="val/accuracy",
    checkpoint_mode="max",
)

GPU Training with Mixed Precision

trainer = AutoTrainer(
    max_epochs=10,
    accelerator="gpu",
    devices=1,
    precision="bf16-mixed",
)

Multi-GPU Training

trainer = AutoTrainer(
    max_epochs=10,
    accelerator="gpu",
    devices=2,
    strategy="ddp",
)

With Auto-Tuning

from autotimm import AutoTrainer, TunerConfig

trainer = AutoTrainer(
    max_epochs=10,
    tuner_config=TunerConfig(
        auto_lr=True,
        auto_batch_size=True,
    ),
)
trainer.fit(model, datamodule=data)  # Runs tuning first
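
Disabling Auto-Tuning

Auto-tuning is enabled by default; as noted in the tuner_config parameter docs, passing False switches it off entirely:

```python
from autotimm import AutoTrainer

# tuner_config=False skips both LR and batch size finding;
# training uses the user-specified learning rate and batch size.
trainer = AutoTrainer(max_epochs=10, tuner_config=False)
trainer.fit(model, datamodule=data)
```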

With Reproducibility Settings

# Default: seed=None, deterministic=True
trainer = AutoTrainer(max_epochs=10)

# Custom seed
trainer = AutoTrainer(max_epochs=10, seed=123)

# Faster training (disable deterministic mode)
trainer = AutoTrainer(max_epochs=10, deterministic=False)

# Use AutoTimm's custom seeding instead of Lightning's
trainer = AutoTrainer(max_epochs=10, use_autotimm_seeding=True)

# Disable seeding completely (set deterministic=False to avoid warning)
trainer = AutoTrainer(max_epochs=10, seed=None, deterministic=False)

With Gradient Accumulation

trainer = AutoTrainer(
    max_epochs=10,
    accumulate_grad_batches=4,
    gradient_clip_val=1.0,
)

Fast Development Run

# Run 1 batch for quick debugging
trainer = AutoTrainer(fast_dev_run=True)
trainer.fit(model, datamodule=data)

# Run 5 batches for testing
trainer = AutoTrainer(fast_dev_run=5)
trainer.fit(model, datamodule=data)
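
Reading JSON Progress Logs

With json_progress=True and a json_progress_log_file, the trainer writes newline-delimited JSON events. The fit() source above emits "tuning_started" and "tuning_complete"; any further event fields are assumptions here. A minimal stdlib-only consumer for such a log might look like:

```python
import json
import tempfile
from pathlib import Path


def read_progress_events(log_file: str) -> list[dict]:
    """Parse a newline-delimited JSON progress log into a list of event dicts."""
    events = []
    for line in Path(log_file).read_text().splitlines():
        if line.strip():
            events.append(json.loads(line))
    return events


# Simulate a log as fit() above would write it during tuning
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    f.write('{"event": "tuning_started"}\n')
    f.write('{"event": "tuning_complete"}\n')

print([e["event"] for e in read_progress_events(f.name)])
# ['tuning_started', 'tuning_complete']
```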

Parameters

Parameter Type Default Description
max_epochs int 10 Training epochs
accelerator str "auto" "auto", "gpu", "cpu", "tpu"
devices int | str "auto" Number of devices
precision str | int 32 32, 16, "bf16-mixed", "16-mixed"
logger Various False Logger configuration
tuner_config TunerConfig | None | bool None Auto-tuning config. None/True creates default config, False disables
seed int | None None Random seed for reproducibility. Set to None to disable seeding
deterministic bool True Enable deterministic algorithms for reproducibility. May impact performance
use_autotimm_seeding bool False Use AutoTimm's seed_everything() instead of Lightning's built-in seeding
checkpoint_monitor str | None "val/loss" Metric for checkpointing
checkpoint_mode str "min" "min" or "max"
callbacks list | None None Lightning callbacks
default_root_dir str "." Root directory
gradient_clip_val float | None None Gradient clipping
accumulate_grad_batches int 1 Gradient accumulation
val_check_interval float | int 1.0 Validation frequency
enable_checkpointing bool True Save checkpoints
fast_dev_run bool | int False Run N batches for debugging
json_progress bool False Enable JSON progress reporting
json_progress_every_n_steps int 10 JSON progress reporting frequency
json_progress_log_file str | None None File path for JSON progress logs

TunerConfig

Configuration for automatic hyperparameter tuning (LR and batch size finding).

API Reference

autotimm.TunerConfig dataclass

Configuration for automatic hyperparameter tuning.

Parameters:

Name Type Description Default
auto_lr bool

Whether to use the learning rate finder before training. If True, the optimal learning rate will be found and applied. If False, the user-specified learning rate is used. Default: True.

True
auto_batch_size bool

Whether to use the batch size finder before training. If True, the optimal batch size will be found and applied. If False, the user-specified batch size is used. Default: True.

True
lr_find_kwargs dict[str, Any] | None

Additional kwargs passed to Tuner.lr_find(). Common options: min_lr, max_lr, num_training, mode ("exponential" or "linear"), early_stop_threshold. Default values are set if not provided.

None
batch_size_kwargs dict[str, Any] | None

Additional kwargs passed to Tuner.scale_batch_size(). Common options: mode ("power" or "binsearch"), steps_per_trial, init_val, max_trials. Default values are set if not provided.

None
Example

Use defaults (both auto_lr and auto_batch_size enabled)

config = TunerConfig()

Disable auto-tuning

config = TunerConfig(auto_lr=False, auto_batch_size=False)

Custom configuration

config = TunerConfig(
    auto_lr=True,
    auto_batch_size=True,
    lr_find_kwargs={"min_lr": 1e-6, "max_lr": 1e-1, "num_training": 100},
    batch_size_kwargs={"mode": "power", "init_val": 16},
)

Source code in src/autotimm/training/trainer.py
@dataclass
class TunerConfig:
    """Configuration for automatic hyperparameter tuning.

    Parameters:
        auto_lr: Whether to use the learning rate finder before training.
            If ``True``, the optimal learning rate will be found and applied.
            If ``False``, the user-specified learning rate is used.
            Default: ``True``.
        auto_batch_size: Whether to use the batch size finder before training.
            If ``True``, the optimal batch size will be found and applied.
            If ``False``, the user-specified batch size is used.
            Default: ``True``.
        lr_find_kwargs: Additional kwargs passed to ``Tuner.lr_find()``.
            Common options: ``min_lr``, ``max_lr``, ``num_training``,
            ``mode`` ("exponential" or "linear"), ``early_stop_threshold``.
            Default values are set if not provided.
        batch_size_kwargs: Additional kwargs passed to ``Tuner.scale_batch_size()``.
            Common options: ``mode`` ("power" or "binsearch"), ``steps_per_trial``,
            ``init_val``, ``max_trials``.
            Default values are set if not provided.

    Example:
        >>> # Use defaults (both auto_lr and auto_batch_size enabled)
        >>> config = TunerConfig()

        >>> # Disable auto-tuning
        >>> config = TunerConfig(auto_lr=False, auto_batch_size=False)

        >>> # Custom configuration
        >>> config = TunerConfig(
        ...     auto_lr=True,
        ...     auto_batch_size=True,
        ...     lr_find_kwargs={"min_lr": 1e-6, "max_lr": 1e-1, "num_training": 100},
        ...     batch_size_kwargs={"mode": "power", "init_val": 16},
        ... )
    """

    auto_lr: bool = True
    auto_batch_size: bool = True
    lr_find_kwargs: dict[str, Any] | None = None
    batch_size_kwargs: dict[str, Any] | None = None

    def __post_init__(self) -> None:
        # Set sensible defaults for LR finder if not provided
        if self.lr_find_kwargs is None:
            self.lr_find_kwargs = {
                "min_lr": 1e-7,
                "max_lr": 1.0,
                "num_training": 100,
                "mode": "exponential",
                "early_stop_threshold": 4.0,
            }

        # Set sensible defaults for batch size finder if not provided
        if self.batch_size_kwargs is None:
            self.batch_size_kwargs = {
                "mode": "power",
                "steps_per_trial": 3,
                "init_val": 16,
                "max_trials": 25,
            }

Usage Examples

LR Finding Only

from autotimm import TunerConfig

config = TunerConfig(
    auto_lr=True,
    auto_batch_size=False,
)

Full Auto-Tuning

config = TunerConfig(
    auto_lr=True,
    auto_batch_size=True,
    lr_find_kwargs={
        "min_lr": 1e-6,
        "max_lr": 1.0,
        "num_training": 100,
    },
    batch_size_kwargs={
        "mode": "power",
        "init_val": 16,
    },
)

Custom LR Finder Settings

config = TunerConfig(
    auto_lr=True,
    auto_batch_size=False,
    lr_find_kwargs={
        "min_lr": 1e-7,
        "max_lr": 10.0,
        "num_training": 200,
        "mode": "exponential",
        "early_stop_threshold": 4.0,
    },
)

Custom Batch Size Finder Settings

config = TunerConfig(
    auto_lr=False,
    auto_batch_size=True,
    batch_size_kwargs={
        "mode": "binsearch",      # or "power"
        "steps_per_trial": 3,
        "init_val": 32,
        "max_trials": 25,
    },
)

Parameters

Parameter Type Default Description
auto_lr bool True Enable LR finder
auto_batch_size bool True Enable batch size finder
lr_find_kwargs dict | None None LR finder options
batch_size_kwargs dict | None None Batch size finder options

LR Finder Options (lr_find_kwargs)

Option Default Description
min_lr 1e-7 Minimum learning rate
max_lr 1.0 Maximum learning rate
num_training 100 Training steps
mode "exponential" "exponential" or "linear"
early_stop_threshold 4.0 Stop if loss exceeds factor

Batch Size Finder Options (batch_size_kwargs)

Option Default Description
mode "power" "power" (2x) or "binsearch"
steps_per_trial 3 Steps to run per trial
init_val 16 Initial batch size
max_trials 25 Maximum trials
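
To make the "power" mode concrete: it doubles the batch size each trial starting from init_val, stopping after max_trials or at the first size that fails. The schedule below is a simplified illustration of that doubling, not Lightning's actual implementation (which also handles OOM recovery and keeps the last size that fit):

```python
def power_mode_schedule(init_val: int, max_trials: int) -> list[int]:
    """Batch sizes a 'power' search would try: double each trial.

    Simplified sketch; the real finder stops early at the first size
    that runs out of memory and keeps the last size that fit.
    """
    sizes = []
    size = init_val
    for _ in range(max_trials):
        sizes.append(size)
        size *= 2
    return sizes


print(power_mode_schedule(init_val=16, max_trials=5))
# [16, 32, 64, 128, 256]
```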

Complete Example

from autotimm import (
    AutoTrainer,
    ImageClassifier,
    ImageDataModule,
    LoggerConfig,
    MetricConfig,
    TunerConfig,
)
from pytorch_lightning.callbacks import EarlyStopping

# Data
data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    image_size=224,
    batch_size=64,
)

# Metrics
metrics = [
    MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["train", "val", "test"],
        prog_bar=True,
    ),
]

# Model
model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,
)

# Trainer with all features
trainer = AutoTrainer(
    max_epochs=50,
    accelerator="gpu",
    devices=1,
    precision="bf16-mixed",
    logger=[
        LoggerConfig(backend="tensorboard", params={"save_dir": "logs/tb"}),
        LoggerConfig(backend="csv", params={"save_dir": "logs/csv"}),
    ],
    tuner_config=TunerConfig(
        auto_lr=True,
        auto_batch_size=False,
        lr_find_kwargs={"min_lr": 1e-6, "max_lr": 1.0},
    ),
    checkpoint_monitor="val/accuracy",
    checkpoint_mode="max",
    gradient_clip_val=1.0,
    callbacks=[
        EarlyStopping(monitor="val/loss", patience=10, mode="min"),
    ],
)

# Train
trainer.fit(model, datamodule=data)
trainer.test(model, datamodule=data)