Skip to content

Metrics

Metric configuration and management for training.

MetricConfig

Configuration for a single metric.

API Reference

autotimm.MetricConfig dataclass

Configuration for a single metric.

All parameters are required - no defaults are provided to ensure explicit configuration.

Parameters:

Name Type Description Default
name str

Unique identifier for this metric (used in logging).

required
backend str

Metric backend type. One of "torchmetrics" or "custom".

required
metric_class str

The metric class name (for torchmetrics/custom).

required
params dict[str, Any]

Parameters passed to the metric constructor/function.

required
stages list[str]

List of stages where this metric applies: "train", "val", "test".

required
log_on_step bool

Whether to log on each step.

False
log_on_epoch bool

Whether to log on epoch end.

True
prog_bar bool

Whether to show in progress bar.

False
Example

>>> config = MetricConfig(
...     name="accuracy",
...     backend="torchmetrics",
...     metric_class="Accuracy",
...     params={"task": "multiclass"},
...     stages=["train", "val", "test"],
...     prog_bar=True,
... )

Source code in src/autotimm/core/metrics.py
@dataclass
class MetricConfig:
    """Declarative description of a single metric used during training.

    Every field without a default must be supplied explicitly; the class
    deliberately ships no metric defaults so configurations stay
    self-documenting.

    Parameters:
        name: Unique identifier for the metric; used as the logging key.
        backend: Which factory creates the metric. One of
            ``"torchmetrics"`` or ``"custom"`` (case-insensitive; stored
            lower-cased).
        metric_class: Class name (torchmetrics) or fully qualified class
            path (custom).
        params: Keyword arguments forwarded to the metric constructor.
        stages: Stages the metric runs in; any of ``"train"``, ``"val"``,
            ``"test"`` (case-insensitive; stored lower-cased).
        log_on_step: Log the metric on every step.
        log_on_epoch: Log the metric at epoch end.
        prog_bar: Display the metric in the progress bar.

    Example:
        >>> config = MetricConfig(
        ...     name="accuracy",
        ...     backend="torchmetrics",
        ...     metric_class="Accuracy",
        ...     params={"task": "multiclass"},
        ...     stages=["train", "val", "test"],
        ...     prog_bar=True,
        ... )
    """

    name: str
    backend: str
    metric_class: str
    params: dict[str, Any]
    stages: list[str]
    log_on_step: bool = False
    log_on_epoch: bool = True
    prog_bar: bool = False

    def __post_init__(self) -> None:
        # Required-field checks in declaration order, so the first missing
        # field is the one reported.
        required = (
            (self.name, "name is required"),
            (self.backend, "backend is required"),
            (self.metric_class, "metric_class is required"),
            (self.stages, "stages is required (e.g., ['train', 'val', 'test'])"),
        )
        for value, message in required:
            if not value:
                raise ValueError(message)

        # Backends are matched case-insensitively; keep the canonical form.
        self.backend = self.backend.lower()
        allowed_backends = ("custom", "torchmetrics")  # sorted for the message
        if self.backend not in allowed_backends:
            raise ValueError(
                f"Unknown backend '{self.backend}'. "
                f"Valid backends: {', '.join(allowed_backends)}"
            )

        # Validate each stage (reporting the caller's original spelling),
        # collecting the lower-cased form; only overwrite self.stages once
        # every entry has passed.
        allowed_stages = ("test", "train", "val")  # sorted for the message
        normalized_stages = []
        for stage in self.stages:
            lowered = stage.lower()
            if lowered not in allowed_stages:
                raise ValueError(
                    f"Unknown stage '{stage}'. "
                    f"Valid stages: {', '.join(allowed_stages)}"
                )
            normalized_stages.append(lowered)
        self.stages = normalized_stages

Usage Examples

Basic Accuracy

from autotimm import MetricConfig

accuracy = MetricConfig(
    name="accuracy",
    backend="torchmetrics",
    metric_class="Accuracy",
    params={"task": "multiclass"},
    stages=["train", "val", "test"],
    prog_bar=True,
)

F1 Score

f1 = MetricConfig(
    name="f1",
    backend="torchmetrics",
    metric_class="F1Score",
    params={"task": "multiclass", "average": "macro"},
    stages=["val", "test"],
)

Top-K Accuracy

top5 = MetricConfig(
    name="top5_accuracy",
    backend="torchmetrics",
    metric_class="Accuracy",
    params={"task": "multiclass", "top_k": 5},
    stages=["val", "test"],
)

Custom Metric

custom = MetricConfig(
    name="custom",
    backend="custom",
    metric_class="mypackage.metrics.CustomMetric",
    params={"threshold": 0.5},
    stages=["val"],
)

Parameters

Parameter Type Default Description
name str Required Unique identifier
backend str Required "torchmetrics" or "custom"
metric_class str Required Class name or full path
params dict Required Constructor parameters
stages list[str] Required ["train", "val", "test"]
log_on_step bool False Log each step
log_on_epoch bool True Log each epoch
prog_bar bool False Show in progress bar

MetricManager

Manages multiple metrics across training stages.

API Reference

autotimm.MetricManager

Manages multiple metrics for training/validation/testing.

This class creates and manages metric instances from explicit configurations. No default values are provided - all configuration must be specified.

Parameters:

Name Type Description Default
configs list[MetricConfig]

List of MetricConfig objects defining each metric.

required
num_classes int

Number of classes (required for classification metrics).

required

Attributes:

Name Type Description
train_metrics

Dict of metrics for training stage.

val_metrics

Dict of metrics for validation stage.

test_metrics

Dict of metrics for test stage.

Example

>>> manager = MetricManager(
...     configs=[
...         MetricConfig(
...             name="accuracy",
...             backend="torchmetrics",
...             metric_class="Accuracy",
...             params={"task": "multiclass"},
...             stages=["train", "val", "test"],
...             prog_bar=True,
...         ),
...         MetricConfig(
...             name="f1",
...             backend="torchmetrics",
...             metric_class="F1Score",
...             params={"task": "multiclass", "average": "macro"},
...             stages=["val"],
...         ),
...     ],
...     num_classes=10,
... )
>>> train_metrics = manager.get_train_metrics()

Source code in src/autotimm/core/metrics.py
class MetricManager:
    """Manages multiple metrics for training/validation/testing.

    This class creates and manages metric instances from explicit configurations.
    No default values are provided - all configuration must be specified.

    ``num_classes`` and ``num_labels`` are auto-injected into each metric's
    constructor parameters when the user did not supply them.  Injected
    parameters that a constructor rejects are filtered out and construction
    is retried; this applies to both the torchmetrics and custom backends.

    Parameters:
        configs: List of ``MetricConfig`` objects defining each metric.
        num_classes: Number of classes (required for classification metrics).

    Attributes:
        train_metrics: Dict of metrics for training stage.
        val_metrics: Dict of metrics for validation stage.
        test_metrics: Dict of metrics for test stage.

    Example:
        >>> manager = MetricManager(
        ...     configs=[
        ...         MetricConfig(
        ...             name="accuracy",
        ...             backend="torchmetrics",
        ...             metric_class="Accuracy",
        ...             params={"task": "multiclass"},
        ...             stages=["train", "val", "test"],
        ...             prog_bar=True,
        ...         ),
        ...         MetricConfig(
        ...             name="f1",
        ...             backend="torchmetrics",
        ...             metric_class="F1Score",
        ...             params={"task": "multiclass", "average": "macro"},
        ...             stages=["val"],
        ...         ),
        ...     ],
        ...     num_classes=10,
        ... )
        >>> train_metrics = manager.get_train_metrics()
    """

    def __init__(self, configs: list[MetricConfig], num_classes: int) -> None:
        if not configs:
            raise ValueError("At least one MetricConfig is required")
        if num_classes <= 0:
            raise ValueError("num_classes must be positive")

        self._configs = configs
        self._num_classes = num_classes

        # Store (metric instance, config) pairs per stage so the logging
        # options from each config remain available next to its instance.
        self._train_metrics: dict[str, tuple[torch.nn.Module, MetricConfig]] = {}
        self._val_metrics: dict[str, tuple[torch.nn.Module, MetricConfig]] = {}
        self._test_metrics: dict[str, tuple[torch.nn.Module, MetricConfig]] = {}

        self._initialize_metrics()

    def _initialize_metrics(self) -> None:
        """Initialize all metric instances from configs.

        Each stage gets its own metric instance, so state accumulated in
        one stage never leaks into another.
        """
        for config in self._configs:
            if "train" in config.stages:
                self._train_metrics[config.name] = (self._create_metric(config), config)
            if "val" in config.stages:
                self._val_metrics[config.name] = (self._create_metric(config), config)
            if "test" in config.stages:
                self._test_metrics[config.name] = (self._create_metric(config), config)

    def _create_metric(self, config: MetricConfig) -> torch.nn.Module:
        """Create a single metric instance from config."""
        backend = config.backend
        params = config.params.copy()

        # Auto-inject num_classes / num_labels if the user didn't provide
        # them.  Injections are tracked so _instantiate_metric can filter
        # them out when the metric constructor doesn't accept them.
        auto_injected: set[str] = set()
        for injected_key in ("num_classes", "num_labels"):
            if injected_key not in params:
                params[injected_key] = self._num_classes
                auto_injected.add(injected_key)

        if backend == "torchmetrics":
            return self._create_torchmetrics_metric(
                config.metric_class, params, auto_injected
            )
        elif backend == "custom":
            # Pass auto_injected so custom metrics that don't accept
            # num_classes/num_labels get the same filtering as torchmetrics.
            return self._create_custom_metric(
                config.metric_class, params, auto_injected
            )
        else:
            raise ValueError(f"Unknown backend: {backend}")

    @staticmethod
    def _instantiate_metric(
        metric_cls: type,
        params: dict[str, Any],
        auto_injected: set[str] | None,
    ) -> torch.nn.Module:
        """Instantiate ``metric_cls``, retrying without rejected injections.

        If construction fails with ``TypeError``/``ValueError`` and some
        params were auto-injected (rather than user-supplied), remove them
        and retry once; a second failure then surfaces the real
        configuration error to the caller.
        """
        try:
            return metric_cls(**params)
        except (TypeError, ValueError):
            if not auto_injected:
                raise
            filtered = {k: v for k, v in params.items() if k not in auto_injected}
            return metric_cls(**filtered)

    def _create_torchmetrics_metric(
        self,
        metric_class: str,
        params: dict[str, Any],
        auto_injected: set[str] | None = None,
    ) -> torchmetrics.Metric:
        """Create a torchmetrics Metric instance.

        Raises:
            ValueError: If ``metric_class`` is not found in ``torchmetrics``
                or ``torchmetrics.classification``.
        """
        # Try top-level torchmetrics first, then torchmetrics.classification
        # (where e.g. the Multilabel* metrics live).
        if hasattr(torchmetrics, metric_class):
            metric_cls = getattr(torchmetrics, metric_class)
        elif hasattr(torchmetrics.classification, metric_class):
            metric_cls = getattr(torchmetrics.classification, metric_class)
        else:
            raise ValueError(
                f"Unknown torchmetrics metric: {metric_class}. "
                f"Check torchmetrics documentation for available metrics."
            )
        return self._instantiate_metric(metric_cls, params, auto_injected)

    def _create_custom_metric(
        self,
        metric_class: str,
        params: dict[str, Any],
        auto_injected: set[str] | None = None,
    ) -> torch.nn.Module:
        """Create a custom metric from a fully qualified class path.

        Raises:
            ValueError: If ``metric_class`` is not a dotted module path.
        """
        import importlib

        if "." not in metric_class:
            raise ValueError(
                f"custom metric_class must be a fully qualified path "
                f"(e.g., 'mypackage.metrics.CustomMetric'), got: {metric_class}"
            )

        module_path, class_name = metric_class.rsplit(".", 1)
        module = importlib.import_module(module_path)
        metric_cls = getattr(module, class_name)
        return self._instantiate_metric(metric_cls, params, auto_injected)

    def get_train_metrics(self) -> torch.nn.ModuleDict:
        """Return ModuleDict of train metrics for Lightning module."""
        return torch.nn.ModuleDict(
            {name: metric for name, (metric, _) in self._train_metrics.items()}
        )

    def get_val_metrics(self) -> torch.nn.ModuleDict:
        """Return ModuleDict of validation metrics for Lightning module."""
        return torch.nn.ModuleDict(
            {name: metric for name, (metric, _) in self._val_metrics.items()}
        )

    def get_test_metrics(self) -> torch.nn.ModuleDict:
        """Return ModuleDict of test metrics for Lightning module."""
        return torch.nn.ModuleDict(
            {name: metric for name, (metric, _) in self._test_metrics.items()}
        )

    def get_metric_config(self, stage: str, name: str) -> MetricConfig | None:
        """Get the config for a specific metric.

        Parameters:
            stage: One of "train", "val", "test".
            name: The metric name.

        Returns:
            The MetricConfig if found, None otherwise.
        """
        # Unknown stage names fall back to an empty dict, yielding None.
        metrics_dict = {
            "train": self._train_metrics,
            "val": self._val_metrics,
            "test": self._test_metrics,
        }.get(stage, {})

        if name in metrics_dict:
            return metrics_dict[name][1]
        return None

    @property
    def configs(self) -> list[MetricConfig]:
        """Return the configurations used to create the metrics."""
        return self._configs

    @property
    def num_classes(self) -> int:
        """Return the number of classes."""
        return self._num_classes

    def __len__(self) -> int:
        """Return total number of unique metric configs."""
        return len(self._configs)

    def __iter__(self):
        """Iterate over the metric configs."""
        return iter(self._configs)

    def __getitem__(self, index: int) -> MetricConfig:
        """Get a metric config by index."""
        return self._configs[index]

    def get_metric_by_name(
        self, name: str, stage: str | None = None
    ) -> torch.nn.Module | None:
        """Get a metric instance by name.

        Parameters:
            name: The metric name to search for.
            stage: Optional stage to search in ("train", "val", "test").
                If None, searches in order: val, train, test.

        Returns:
            The first matching metric instance, or None if not found.
        """
        if stage is not None:
            metrics_dict = {
                "train": self._train_metrics,
                "val": self._val_metrics,
                "test": self._test_metrics,
            }.get(stage, {})
            if name in metrics_dict:
                return metrics_dict[name][0]
            return None

        # Search in order: val, train, test
        for metrics_dict in [
            self._val_metrics,
            self._train_metrics,
            self._test_metrics,
        ]:
            if name in metrics_dict:
                return metrics_dict[name][0]
        return None

    def get_config_by_name(self, name: str) -> MetricConfig | None:
        """Get a metric config by name.

        Parameters:
            name: The metric name to search for.

        Returns:
            The matching MetricConfig, or None if not found.
        """
        # Linear scan is fine: config lists are small.
        for config in self._configs:
            if config.name == name:
                return config
        return None

configs property

configs: list[MetricConfig]

Return the configurations used to create the metrics.

num_classes property

num_classes: int

Return the number of classes.

__init__

__init__(configs: list[MetricConfig], num_classes: int) -> None
Source code in src/autotimm/core/metrics.py
def __init__(self, configs: list[MetricConfig], num_classes: int) -> None:
    # Fail fast: metric setup is fully explicit, so an empty config list or
    # a non-positive class count is always a caller error.
    if not configs:
        raise ValueError("At least one MetricConfig is required")
    if num_classes <= 0:
        raise ValueError("num_classes must be positive")

    self._configs = configs
    self._num_classes = num_classes

    # Store metrics with their configs by stage
    # Each entry maps metric name -> (instantiated metric, its MetricConfig).
    self._train_metrics: dict[str, tuple[torch.nn.Module, MetricConfig]] = {}
    self._val_metrics: dict[str, tuple[torch.nn.Module, MetricConfig]] = {}
    self._test_metrics: dict[str, tuple[torch.nn.Module, MetricConfig]] = {}

    self._initialize_metrics()

get_train_metrics

get_train_metrics() -> torch.nn.ModuleDict

Return ModuleDict of train metrics for Lightning module.

Source code in src/autotimm/core/metrics.py
def get_train_metrics(self) -> torch.nn.ModuleDict:
    """Return ModuleDict of train metrics for Lightning module."""
    # Drop the stored MetricConfig; only the metric modules are registered.
    return torch.nn.ModuleDict(
        {name: metric for name, (metric, _) in self._train_metrics.items()}
    )

get_val_metrics

get_val_metrics() -> torch.nn.ModuleDict

Return ModuleDict of validation metrics for Lightning module.

Source code in src/autotimm/core/metrics.py
def get_val_metrics(self) -> torch.nn.ModuleDict:
    """Return ModuleDict of validation metrics for Lightning module."""
    # Drop the stored MetricConfig; only the metric modules are registered.
    return torch.nn.ModuleDict(
        {name: metric for name, (metric, _) in self._val_metrics.items()}
    )

get_test_metrics

get_test_metrics() -> torch.nn.ModuleDict

Return ModuleDict of test metrics for Lightning module.

Source code in src/autotimm/core/metrics.py
def get_test_metrics(self) -> torch.nn.ModuleDict:
    """Return ModuleDict of test metrics for Lightning module."""
    # Drop the stored MetricConfig; only the metric modules are registered.
    return torch.nn.ModuleDict(
        {name: metric for name, (metric, _) in self._test_metrics.items()}
    )

get_metric_config

get_metric_config(stage: str, name: str) -> MetricConfig | None

Get the config for a specific metric.

Parameters:

Name Type Description Default
stage str

One of "train", "val", "test".

required
name str

The metric name.

required

Returns:

Type Description
MetricConfig | None

The MetricConfig if found, None otherwise.

Source code in src/autotimm/core/metrics.py
def get_metric_config(self, stage: str, name: str) -> MetricConfig | None:
    """Get the config for a specific metric.

    Parameters:
        stage: One of "train", "val", "test".
        name: The metric name.

    Returns:
        The MetricConfig if found, None otherwise.
    """
    # Unknown stage names fall back to an empty dict, so the lookup below
    # simply returns None instead of raising.
    metrics_dict = {
        "train": self._train_metrics,
        "val": self._val_metrics,
        "test": self._test_metrics,
    }.get(stage, {})

    if name in metrics_dict:
        return metrics_dict[name][1]  # index 1 holds the MetricConfig
    return None

get_metric_by_name

get_metric_by_name(name: str, stage: str | None = None) -> torch.nn.Module | None

Get a metric instance by name.

Parameters:

Name Type Description Default
name str

The metric name to search for.

required
stage str | None

Optional stage to search in ("train", "val", "test"). If None, searches in order: val, train, test.

None

Returns:

Type Description
Module | None

The first matching metric instance, or None if not found.

Source code in src/autotimm/core/metrics.py
def get_metric_by_name(
    self, name: str, stage: str | None = None
) -> torch.nn.Module | None:
    """Get a metric instance by name.

    Parameters:
        name: The metric name to search for.
        stage: Optional stage to search in ("train", "val", "test").
            If None, searches in order: val, train, test.

    Returns:
        The first matching metric instance, or None if not found.
    """
    if stage is not None:
        # Unknown stage names fall back to an empty dict -> returns None.
        metrics_dict = {
            "train": self._train_metrics,
            "val": self._val_metrics,
            "test": self._test_metrics,
        }.get(stage, {})
        if name in metrics_dict:
            return metrics_dict[name][0]  # index 0 holds the metric module
        return None

    # Search in order: val, train, test
    for metrics_dict in [
        self._val_metrics,
        self._train_metrics,
        self._test_metrics,
    ]:
        if name in metrics_dict:
            return metrics_dict[name][0]
    return None

get_config_by_name

get_config_by_name(name: str) -> MetricConfig | None

Get a metric config by name.

Parameters:

Name Type Description Default
name str

The metric name to search for.

required

Returns:

Type Description
MetricConfig | None

The matching MetricConfig, or None if not found.

Source code in src/autotimm/core/metrics.py
def get_config_by_name(self, name: str) -> MetricConfig | None:
    """Get a metric config by name.

    Parameters:
        name: The metric name to search for.

    Returns:
        The matching MetricConfig, or None if not found.
    """
    # Linear scan is fine: config lists are small.
    for config in self._configs:
        if config.name == name:
            return config
    return None

Usage Examples

Basic Usage

from autotimm import MetricConfig, MetricManager


def main():
    # Two metrics: accuracy on every stage, macro-F1 on eval stages only.
    # num_classes is omitted from params; MetricManager auto-injects it.
    metric_configs = [
        MetricConfig(
            name="accuracy",
            backend="torchmetrics",
            metric_class="Accuracy",
            params={"task": "multiclass"},
            stages=["train", "val", "test"],
        ),
        MetricConfig(
            name="f1",
            backend="torchmetrics",
            metric_class="F1Score",
            params={"task": "multiclass", "average": "macro"},
            stages=["val", "test"],
        ),
    ]

    manager = MetricManager(configs=metric_configs, num_classes=10)

    # len(manager) counts configs, not per-stage instances.
    print(f"Number of metrics: {len(manager)}")
    print(f"Number of classes: {manager.num_classes}")


if __name__ == "__main__":
    main()

Access Stage Metrics

def main():
    # ... create manager ...

    # Each getter returns a fresh torch.nn.ModuleDict for one stage.
    train_metrics = manager.get_train_metrics()  # ModuleDict
    val_metrics = manager.get_val_metrics()
    test_metrics = manager.get_test_metrics()

    # Use in training loop
    # NOTE(review): `preds`/`targets` are placeholders — supply real batch
    # tensors when adapting this snippet.
    for name, metric in train_metrics.items():
        metric.update(preds, targets)
        value = metric.compute()


if __name__ == "__main__":
    main()

Get Metric Config

def main():
    # ... create manager ...

    # Returns None when the stage/name pair is unknown, so check before use.
    config = manager.get_metric_config("val", "accuracy")
    print(config.prog_bar)  # True


if __name__ == "__main__":
    main()

Access Metrics by Name

def main():
    # ... create manager ...

    # Get metric instance by name
    # Without a stage, lookup order is val -> train -> test.
    accuracy_metric = manager.get_metric_by_name("accuracy")
    accuracy_metric = manager.get_metric_by_name("accuracy", stage="val")

    # Get config by name
    config = manager.get_config_by_name("accuracy")
    print(config.stages)  # ["train", "val", "test"]


if __name__ == "__main__":
    main()

Iterate Over Configs

def main():
    # ... create manager ...

    # Iterate over all configs
    # MetricManager supports the sequence protocol over its configs.
    for config in manager:
        print(f"{config.name}: {config.stages}")

    # Access by index
    first_config = manager[0]
    print(f"First metric: {first_config.name}")

    # Length
    print(f"Number of metrics: {len(manager)}")


if __name__ == "__main__":
    main()

Parameters

Parameter Type Description
configs list[MetricConfig] List of metric configs
num_classes int Number of classes

Methods

Method Returns Description
get_train_metrics() ModuleDict Train stage metrics
get_val_metrics() ModuleDict Validation stage metrics
get_test_metrics() ModuleDict Test stage metrics
get_metric_config(stage, name) MetricConfig \| None Get config by stage/name
get_metric_by_name(name, stage) Module \| None Get metric instance by name
get_config_by_name(name) MetricConfig \| None Get config by name
len(manager) int Number of metric configs
iter(manager) Iterator Iterate over configs
manager[i] MetricConfig Get config by index

LoggingConfig

Configuration for enhanced logging during training.

API Reference

autotimm.LoggingConfig dataclass

Configuration for enhanced logging during training.

Parameters:

Name Type Description Default
log_learning_rate bool

Whether to log the current learning rate.

required
log_gradient_norm bool

Whether to log gradient norms.

required
log_weight_norm bool

Whether to log weight norms.

False
log_confusion_matrix bool

Whether to log confusion matrix at epoch end.

False
log_predictions bool

Whether to log sample predictions/images.

False
predictions_per_epoch int

Number of sample predictions to log per epoch.

8
verbosity int

Logging verbosity level (0=minimal, 1=normal, 2=verbose).

1
Example

>>> config = LoggingConfig(
...     log_learning_rate=True,
...     log_gradient_norm=True,
...     log_confusion_matrix=True,
... )

Source code in src/autotimm/core/metrics.py
@dataclass
class LoggingConfig:
    """Options controlling extra logging emitted during training.

    Parameters:
        log_learning_rate: Log the current learning rate.
        log_gradient_norm: Log gradient norms.
        log_weight_norm: Log weight norms.
        log_confusion_matrix: Log a confusion matrix at epoch end.
        log_predictions: Log sample predictions/images.
        predictions_per_epoch: How many sample predictions to log per epoch.
        verbosity: Verbosity level: 0 = minimal, 1 = normal, 2 = verbose.

    Example:
        >>> config = LoggingConfig(
        ...     log_learning_rate=True,
        ...     log_gradient_norm=True,
        ...     log_confusion_matrix=True,
        ... )
    """

    log_learning_rate: bool
    log_gradient_norm: bool
    log_weight_norm: bool = False
    log_confusion_matrix: bool = False
    log_predictions: bool = False
    predictions_per_epoch: int = 8
    verbosity: int = 1

    def __post_init__(self) -> None:
        # Membership test rather than a range comparison, so non-integer
        # levels (e.g. 1.5) are rejected as well.
        if self.verbosity not in {0, 1, 2}:
            raise ValueError("verbosity must be 0, 1, or 2")
        if self.predictions_per_epoch < 0:
            raise ValueError("predictions_per_epoch must be non-negative")

Usage Examples

Basic Logging

from autotimm import LoggingConfig

config = LoggingConfig(
    log_learning_rate=True,
    log_gradient_norm=True,
)

Full Logging

config = LoggingConfig(
    log_learning_rate=True,
    log_gradient_norm=True,
    log_weight_norm=True,
    log_confusion_matrix=True,
    log_predictions=False,
    predictions_per_epoch=8,
    verbosity=2,
)

Parameters

Parameter Type Default Description
log_learning_rate bool Required Log LR each step
log_gradient_norm bool Required Log gradient norms
log_weight_norm bool False Log weight norms
log_confusion_matrix bool False Log confusion matrix
log_predictions bool False Log sample predictions
predictions_per_epoch int 8 Predictions to log
verbosity int 1 0=minimal, 1=normal, 2=verbose

Logged Values

Metric Key Condition
Learning rate train/lr log_learning_rate=True
Gradient norm train/grad_norm log_gradient_norm=True
Weight norm train/weight_norm log_weight_norm=True
Confusion matrix val/confusion_matrix log_confusion_matrix=True

Common Torchmetrics

Classification

Metric Class Common Params
Accuracy task="multiclass", top_k=1
F1Score task="multiclass", average="macro"
Precision task="multiclass", average="macro"
Recall task="multiclass", average="macro"
AUROC task="multiclass"
ConfusionMatrix task="multiclass"

Binary Classification

Metric Class Common Params
Accuracy task="binary"
F1Score task="binary"
AUROC task="binary"
Precision task="binary"
Recall task="binary"

Multi-Label Classification

Use torchmetrics.classification.Multilabel* metrics with ImageClassifier(multi_label=True). These are resolved automatically from the torchmetrics.classification submodule.

Metric Class Common Params
MultilabelAccuracy num_labels=N
MultilabelF1Score num_labels=N, average="macro"
MultilabelPrecision num_labels=N, average="macro"
MultilabelRecall num_labels=N, average="macro"
MultilabelAUROC num_labels=N
MultilabelHammingDistance num_labels=N
MetricConfig(
    name="accuracy",
    backend="torchmetrics",
    metric_class="MultilabelAccuracy",
    params={"num_labels": 4},
    stages=["train", "val"],
    prog_bar=True,
)

Note: num_classes and num_labels are auto-injected by MetricManager from the model's num_classes value. Auto-injected parameters that a metric doesn't accept are automatically filtered out.

Average Options

Value Description
"micro" Global average
"macro" Unweighted class average
"weighted" Weighted by class support
"none" Per-class values

Full Example

from autotimm import (
    AutoTrainer,
    ImageClassifier,
    ImageDataModule,
    LoggerConfig,
    LoggingConfig,
    MetricConfig,
    MetricManager,
)


def main():
    # Define metrics
    # num_classes/num_labels are auto-injected by MetricManager, so each
    # params dict only carries task-specific settings.
    metric_configs = [
        MetricConfig(
            name="accuracy",
            backend="torchmetrics",
            metric_class="Accuracy",
            params={"task": "multiclass"},
            stages=["train", "val", "test"],
            prog_bar=True,
        ),
        MetricConfig(
            name="top5_accuracy",
            backend="torchmetrics",
            metric_class="Accuracy",
            params={"task": "multiclass", "top_k": 5},
            stages=["val", "test"],
        ),
        MetricConfig(
            name="f1",
            backend="torchmetrics",
            metric_class="F1Score",
            params={"task": "multiclass", "average": "macro"},
            stages=["val", "test"],
            prog_bar=True,
        ),
        MetricConfig(
            name="precision",
            backend="torchmetrics",
            metric_class="Precision",
            params={"task": "multiclass", "average": "macro"},
            stages=["test"],
        ),
        MetricConfig(
            name="recall",
            backend="torchmetrics",
            metric_class="Recall",
            params={"task": "multiclass", "average": "macro"},
            stages=["test"],
        ),
    ]

    # Create MetricManager
    metric_manager = MetricManager(configs=metric_configs, num_classes=10)

    # Data
    data = ImageDataModule(
        data_dir="./data",
        dataset_name="CIFAR10",
        image_size=224,
        batch_size=64,
    )

    # Create model
    # The manager is handed to the model; per-stage ModuleDicts are pulled
    # from it internally.
    model = ImageClassifier(
        backbone="resnet50",
        num_classes=10,
        metrics=metric_manager,
        logging_config=LoggingConfig(
            log_learning_rate=True,
            log_gradient_norm=True,
            log_confusion_matrix=True,
        ),
    )

    # Trainer
    trainer = AutoTrainer(
        max_epochs=10,
        logger=[LoggerConfig(backend="tensorboard", params={"save_dir": "logs"})],
    )

    # Fit, then run test to evaluate the test-only metrics (precision/recall).
    trainer.fit(model, datamodule=data)
    trainer.test(model, datamodule=data)


if __name__ == "__main__":
    main()