Advanced Customization¶
This guide covers advanced customization options in AutoTimm, including custom heads, loss functions, metrics, backbone modifications, and Lightning callbacks.
Custom Classification Heads¶
Basic Custom Head¶
import torch
import torch.nn as nn
from autotimm import ImageClassifier, MetricConfig
class CustomClassificationHead(nn.Module):
"""Custom classification head with multiple layers."""
def __init__(self, in_features: int, num_classes: int, hidden_dim: int = 512):
super().__init__()
self.head = nn.Sequential(
nn.Linear(in_features, hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(inplace=True),
nn.Dropout(0.3),
nn.Linear(hidden_dim // 2, num_classes),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.head(x)
# Using custom head with ImageClassifier
class CustomImageClassifier(ImageClassifier):
"""ImageClassifier with custom head."""
def __init__(self, *args, hidden_dim: int = 512, **kwargs):
super().__init__(*args, **kwargs)
# Replace the default head
self.head = CustomClassificationHead(
in_features=self.backbone.num_features,
num_classes=self.num_classes,
hidden_dim=hidden_dim,
)
# Usage
metrics = [
MetricConfig(
name="accuracy",
backend="torchmetrics",
metric_class="Accuracy",
params={"task": "multiclass"},
stages=["train", "val"],
)
]
model = CustomImageClassifier(
backbone="resnet50",
num_classes=10,
metrics=metrics,
hidden_dim=1024,
)
Attention-Based Head¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionHead(nn.Module):
"""Classification head with self-attention."""
def __init__(self, in_features: int, num_classes: int, num_heads: int = 8):
super().__init__()
self.attention = nn.MultiheadAttention(
embed_dim=in_features,
num_heads=num_heads,
batch_first=True,
)
self.norm = nn.LayerNorm(in_features)
self.fc = nn.Linear(in_features, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, features] -> [B, 1, features] for attention
x = x.unsqueeze(1)
attn_out, _ = self.attention(x, x, x)
x = self.norm(x + attn_out)
x = x.squeeze(1)
return self.fc(x)
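The attention head can be swapped in the same way as the MLP head above; a minimal sketch, assuming the same head-replacement pattern and that in_features is divisible by num_heads (2048 / 8 for resnet50):
from autotimm import ImageClassifier

class AttentionImageClassifier(ImageClassifier):
    """Classifier using the AttentionHead defined above."""

    def __init__(self, *args, num_heads: int = 8, **kwargs):
        super().__init__(*args, **kwargs)
        self.head = AttentionHead(
            in_features=self.backbone.num_features,  # must be divisible by num_heads
            num_classes=self.num_classes,
            num_heads=num_heads,
        )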
Custom Detection Heads¶
Modified FCOS Head¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from autotimm import DetectionHead
class CustomDetectionHead(DetectionHead):
"""Detection head with additional scale prediction."""
def __init__(
self,
in_channels: int = 256,
num_classes: int = 80,
num_convs: int = 4,
predict_scale: bool = True,
):
super().__init__(
in_channels=in_channels,
num_classes=num_classes,
num_convs=num_convs,
)
self.predict_scale = predict_scale
if predict_scale:
# Additional scale prediction branch
self.scale_pred = nn.Conv2d(in_channels, 1, kernel_size=3, padding=1)
def forward(self, features):
cls_outputs, reg_outputs, centerness_outputs = super().forward(features)
if self.predict_scale:
scale_outputs = []
for feat in features:
reg_feat = self.reg_convs(feat)
scale_out = self.scale_pred(reg_feat)
scale_outputs.append(scale_out)
return cls_outputs, reg_outputs, centerness_outputs, scale_outputs
return cls_outputs, reg_outputs, centerness_outputs
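A quick shape check with dummy features; this assumes the parent DetectionHead consumes a list of per-level FPN feature maps, as the forward() override above does:
# Dummy FPN pyramid (spatial sizes are illustrative)
head = CustomDetectionHead(in_channels=256, num_classes=80, predict_scale=True)
features = [torch.randn(2, 256, size, size) for size in (64, 32, 16, 8, 4)]
cls_outs, reg_outs, ctr_outs, scale_outs = head(features)
print([tuple(s.shape) for s in scale_outs])  # one [2, 1, H, W] scale map per level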
Custom Segmentation Heads¶
UNet-Style Decoder¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNetDecoder(nn.Module):
"""UNet-style decoder for segmentation."""
def __init__(self, encoder_channels: list, num_classes: int, decoder_channels: int = 256):
super().__init__()
self.num_classes = num_classes
# Decoder blocks (bottom-up)
self.decoder_blocks = nn.ModuleList()
in_channels = encoder_channels[-1]
for enc_ch in reversed(encoder_channels[:-1]):
self.decoder_blocks.append(
self._make_decoder_block(in_channels + enc_ch, decoder_channels)
)
in_channels = decoder_channels
# Final classifier
self.classifier = nn.Conv2d(decoder_channels, num_classes, kernel_size=1)
def _make_decoder_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self, features: list) -> torch.Tensor:
# features: [C1, C2, C3, C4, C5] from encoder
x = features[-1]
for i, decoder_block in enumerate(self.decoder_blocks):
# Upsample
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
# Skip connection
skip = features[-(i + 2)]
if x.shape[2:] != skip.shape[2:]:
x = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
# Concatenate and decode
x = torch.cat([x, skip], dim=1)
x = decoder_block(x)
return self.classifier(x)
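The decoder pairs naturally with a timm feature extractor; a sketch wiring the two together, with channel counts taken from feature_info so any features_only backbone should work:
import timm

encoder = timm.create_model("resnet50", pretrained=False, features_only=True)
decoder = UNetDecoder(
    encoder_channels=encoder.feature_info.channels(),  # e.g. [64, 256, 512, 1024, 2048]
    num_classes=21,
)
features = encoder(torch.randn(1, 3, 224, 224))
logits = decoder(features)  # stride-2 output, e.g. [1, 21, 112, 112] for a 224x224 input
# Upsample to the input resolution before computing a pixel-wise loss
logits = F.interpolate(logits, size=(224, 224), mode="bilinear", align_corners=False)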
Custom Loss Functions¶
Focal Tversky Loss¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalTverskyLoss(nn.Module):
"""Focal Tversky loss for highly imbalanced segmentation."""
def __init__(
self,
num_classes: int,
alpha: float = 0.7,
beta: float = 0.3,
gamma: float = 0.75,
smooth: float = 1.0,
ignore_index: int = 255,
):
super().__init__()
self.num_classes = num_classes
self.alpha = alpha # Weight for false negatives
self.beta = beta # Weight for false positives
self.gamma = gamma # Focal parameter
self.smooth = smooth
self.ignore_index = ignore_index
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
probs = F.softmax(logits, dim=1)
# Create valid mask
valid_mask = targets != self.ignore_index
# One-hot encode targets
targets_clamped = targets.clamp(0, self.num_classes - 1)
targets_one_hot = F.one_hot(targets_clamped, self.num_classes)
targets_one_hot = targets_one_hot.permute(0, 3, 1, 2).float()
# Apply valid mask
valid_mask_expanded = valid_mask.unsqueeze(1).float()
probs = probs * valid_mask_expanded
targets_one_hot = targets_one_hot * valid_mask_expanded
# Flatten
probs = probs.flatten(2)
targets_one_hot = targets_one_hot.flatten(2)
# Tversky components
tp = (probs * targets_one_hot).sum(dim=2)
fp = (probs * (1 - targets_one_hot)).sum(dim=2)
fn = ((1 - probs) * targets_one_hot).sum(dim=2)
# Tversky index
tversky = (tp + self.smooth) / (
tp + self.alpha * fn + self.beta * fp + self.smooth
)
# Focal Tversky loss
focal_tversky = (1 - tversky) ** self.gamma
return focal_tversky.mean()
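A quick sanity check on random tensors (logits are [B, C, H, W], targets are [B, H, W] class indices):
criterion = FocalTverskyLoss(num_classes=3)
logits = torch.randn(2, 3, 64, 64, requires_grad=True)
targets = torch.randint(0, 3, (2, 64, 64))
loss = criterion(logits, targets)
loss.backward()  # gradients flow through the soft Tversky terms
print(loss.item())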
Label Smoothing Cross Entropy¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module):
"""Cross entropy with label smoothing."""
def __init__(self, smoothing: float = 0.1, reduction: str = "mean"):
super().__init__()
self.smoothing = smoothing
self.reduction = reduction
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
n_classes = logits.size(-1)
# Create smoothed targets
with torch.no_grad():  # no_grad (not inference_mode): these targets are reused in the autograd graph below
smooth_targets = torch.zeros_like(logits)
smooth_targets.fill_(self.smoothing / (n_classes - 1))
smooth_targets.scatter_(1, targets.unsqueeze(1), 1 - self.smoothing)
# Compute loss
log_probs = F.log_softmax(logits, dim=-1)
loss = -(smooth_targets * log_probs).sum(dim=-1)
if self.reduction == "mean":
return loss.mean()
elif self.reduction == "sum":
return loss.sum()
return loss
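The training-step overrides later in this guide access the task loss as self.criterion; assuming that attribute, the smoothed loss can be dropped in after construction:
from autotimm import ImageClassifier

model = ImageClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,  # defined as in the earlier examples
)
model.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)  # assumes `criterion` holds the task loss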
Boundary Loss for Segmentation¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt
class BoundaryLoss(nn.Module):
"""Boundary loss for segmentation tasks."""
def __init__(self, num_classes: int):
super().__init__()
self.num_classes = num_classes
def _compute_distance_map(self, mask: torch.Tensor) -> torch.Tensor:
"""Compute distance transform for a binary mask."""
mask_np = mask.cpu().numpy()
dist = distance_transform_edt(mask_np) + distance_transform_edt(1 - mask_np)
return torch.from_numpy(dist).to(mask.device).float()
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
probs = F.softmax(logits, dim=1)
total_loss = 0.0
for c in range(self.num_classes):
# Get class probability and target
class_prob = probs[:, c]
class_target = (targets == c).float()
# Compute distance map for target
dist_maps = []
for b in range(class_target.size(0)):
dist_map = self._compute_distance_map(class_target[b])
dist_maps.append(dist_map)
dist_map = torch.stack(dist_maps)
# Boundary loss
total_loss += (dist_map * class_prob).mean()
return total_loss / self.num_classes
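Boundary loss is usually used as a regularizer alongside a region-based loss rather than on its own; a hedged sketch of a weighted combination (the weighting is a starting point, not a recipe):
class CombinedSegLoss(nn.Module):
    """Cross entropy plus a small boundary term."""

    def __init__(self, num_classes: int, boundary_weight: float = 0.1, ignore_index: int = 255):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.boundary = BoundaryLoss(num_classes)
        self.boundary_weight = boundary_weight

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        return self.ce(logits, targets) + self.boundary_weight * self.boundary(logits, targets)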
Custom Metrics¶
Creating Torchmetrics Subclass¶
import torch
import torchmetrics
from autotimm import MetricConfig
class WeightedAccuracy(torchmetrics.Metric):
"""Accuracy weighted by class frequency."""
def __init__(self, num_classes: int, class_weights: torch.Tensor = None):
super().__init__()
self.num_classes = num_classes
if class_weights is None:
class_weights = torch.ones(num_classes)
self.register_buffer("class_weights", class_weights)
self.add_state("correct", default=torch.zeros(num_classes), dist_reduce_fx="sum")
self.add_state("total", default=torch.zeros(num_classes), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
preds = preds.argmax(dim=1)
for c in range(self.num_classes):
mask = target == c
self.correct[c] += (preds[mask] == target[mask]).sum()
self.total[c] += mask.sum()
def compute(self) -> torch.Tensor:
per_class_acc = self.correct / self.total.clamp(min=1)
weighted_acc = (per_class_acc * self.class_weights).sum() / self.class_weights.sum()
return weighted_acc
# Usage in MetricConfig
weighted_accuracy = MetricConfig(
name="weighted_accuracy",
backend="custom",
metric_class="mymodule.WeightedAccuracy", # Full path to your class
params={"num_classes": 10},
stages=["val", "test"],
)
F-Beta Score¶
import torch
import torchmetrics
class FBetaScore(torchmetrics.Metric):
"""F-beta score with configurable beta."""
def __init__(self, num_classes: int, beta: float = 1.0, average: str = "macro"):
super().__init__()
self.num_classes = num_classes
self.beta = beta
self.average = average
self.add_state("tp", default=torch.zeros(num_classes), dist_reduce_fx="sum")
self.add_state("fp", default=torch.zeros(num_classes), dist_reduce_fx="sum")
self.add_state("fn", default=torch.zeros(num_classes), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
preds = preds.argmax(dim=1)
for c in range(self.num_classes):
pred_c = preds == c
target_c = target == c
self.tp[c] += (pred_c & target_c).sum()
self.fp[c] += (pred_c & ~target_c).sum()
self.fn[c] += (~pred_c & target_c).sum()
def compute(self) -> torch.Tensor:
beta_sq = self.beta ** 2
precision = self.tp / (self.tp + self.fp).clamp(min=1e-7)
recall = self.tp / (self.tp + self.fn).clamp(min=1e-7)
fbeta = (1 + beta_sq) * precision * recall / (beta_sq * precision + recall).clamp(min=1e-7)
if self.average == "macro":
return fbeta.mean()
elif self.average == "micro":
tp_sum = self.tp.sum()
fp_sum = self.fp.sum()
fn_sum = self.fn.sum()
precision = tp_sum / (tp_sum + fp_sum).clamp(min=1e-7)
recall = tp_sum / (tp_sum + fn_sum).clamp(min=1e-7)
return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall).clamp(min=1e-7)
else:
return fbeta
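Registering it mirrors the WeightedAccuracy example above; the import path is simply wherever your class lives:
from autotimm import MetricConfig

f2_score = MetricConfig(
    name="f2",
    backend="custom",
    metric_class="mymodule.FBetaScore",  # full path to your class ("mymodule" is a placeholder)
    params={"num_classes": 10, "beta": 2.0, "average": "macro"},
    stages=["val", "test"],
)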
Backbone Modifications¶
Freezing Backbone Layers¶
from autotimm import ImageClassifier, MetricConfig
def freeze_backbone_layers(model, num_layers_to_freeze: int):
"""Freeze the first N layers of the backbone."""
# Get all backbone modules
backbone_modules = list(model.backbone.children())
for i, module in enumerate(backbone_modules):
if i < num_layers_to_freeze:
for param in module.parameters():
param.requires_grad = False
print(f"Froze layer {i}: {module.__class__.__name__}")
# Usage
metrics = [
MetricConfig(
name="accuracy",
backend="torchmetrics",
metric_class="Accuracy",
params={"task": "multiclass"},
stages=["train", "val"],
)
]
model = ImageClassifier(
backbone="resnet50",
num_classes=10,
metrics=metrics,
)
# Freeze first 6 layers (up to and including layer2)
freeze_backbone_layers(model, 6)
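An optional sanity check confirms how much of the model remains trainable after freezing:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")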
Layer-Wise Learning Rate Decay¶
import torch

from autotimm import ImageClassifier
class LayerWiseLRClassifier(ImageClassifier):
"""Classifier with layer-wise learning rate decay."""
def __init__(self, *args, lr_decay: float = 0.9, **kwargs):
super().__init__(*args, **kwargs)
self.lr_decay = lr_decay
def configure_optimizers(self):
# Get backbone layers
backbone_layers = list(self.backbone.children())
num_layers = len(backbone_layers)
# Create parameter groups with decaying LR
param_groups = []
for i, layer in enumerate(backbone_layers):
layer_lr = self._lr * (self.lr_decay ** (num_layers - i - 1))
param_groups.append(
{"params": layer.parameters(), "lr": layer_lr}
)
# Head with base LR
param_groups.append({"params": self.head.parameters(), "lr": self._lr})
optimizer = torch.optim.AdamW(param_groups, weight_decay=self._weight_decay)
# Add scheduler if configured
if self._scheduler:
scheduler = self._create_scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": scheduler}
return optimizer
Custom Feature Extraction¶
import timm
import torch
import torch.nn as nn
class MultiScaleBackbone(nn.Module):
"""Backbone that returns features at multiple scales."""
def __init__(self, backbone_name: str = "resnet50", pretrained: bool = True):
super().__init__()
self.backbone = timm.create_model(
backbone_name,
pretrained=pretrained,
features_only=True,
out_indices=[1, 2, 3, 4], # C2, C3, C4, C5
)
self.feature_channels = self.backbone.feature_info.channels()
def forward(self, x: torch.Tensor) -> list:
return self.backbone(x)
# Usage
backbone = MultiScaleBackbone("resnet50")
features = backbone(torch.randn(1, 3, 224, 224))
for i, feat in enumerate(features):
print(f"Feature {i}: {feat.shape}")
Extending Task Classes¶
Custom Training Step¶
import torch
from autotimm import ImageClassifier
class MixupClassifier(ImageClassifier):
"""Classifier with Mixup augmentation during training."""
def __init__(self, *args, mixup_alpha: float = 0.2, **kwargs):
super().__init__(*args, **kwargs)
self.mixup_alpha = mixup_alpha
def training_step(self, batch, batch_idx):
images, targets = batch
# Apply mixup
if self.training and self.mixup_alpha > 0:
images, targets_a, targets_b, lam = self._mixup_data(images, targets)
# Forward pass
logits = self(images)
# Mixed loss
loss = lam * self.criterion(logits, targets_a) + (1 - lam) * self.criterion(
logits, targets_b
)
else:
logits = self(images)
loss = self.criterion(logits, targets)
self.log("train/loss", loss, prog_bar=True)
return loss
def _mixup_data(self, x, y):
lam = torch.distributions.Beta(self.mixup_alpha, self.mixup_alpha).sample()
batch_size = x.size(0)
index = torch.randperm(batch_size).to(x.device)
mixed_x = lam * x + (1 - lam) * x[index]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
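Usage follows the standard classifier constructor, with the extra mixup_alpha keyword:
model = MixupClassifier(
    backbone="resnet50",
    num_classes=10,
    metrics=metrics,  # defined as in the earlier examples
    mixup_alpha=0.2,
)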
Custom Validation with Visualization¶
import torch
from autotimm import ImageClassifier
class VisualizingClassifier(ImageClassifier):
"""Classifier that logs sample predictions during validation."""
def validation_step(self, batch, batch_idx):
images, targets = batch
logits = self(images)
loss = self.criterion(logits, targets)
# Log metrics
self.log("val/loss", loss, prog_bar=True)
# Update metrics
for name, metric in self.val_metrics.items():
metric.update(logits, targets)
# Log sample predictions (first batch only)
if batch_idx == 0 and self.logger:
self._log_predictions(images, logits, targets)
return loss
def _log_predictions(self, images, logits, targets, num_samples=8):
preds = logits.argmax(dim=1)
# Create visualization
for i in range(min(num_samples, images.size(0))):
img = images[i].cpu()
pred = preds[i].item()
target = targets[i].item()
# Denormalize image
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
img = img * std + mean
img = img.clamp(0, 1)
# Log to tensorboard
self.logger.experiment.add_image(
f"val/sample_{i}_pred{pred}_true{target}",
img,
self.current_epoch,
)
Custom Augmentation Pipelines¶
Using Albumentations¶
import albumentations as A
from albumentations.pytorch import ToTensorV2
from autotimm import ImageDataModule
def get_strong_augmentations(image_size: int = 224):
"""Create strong augmentation pipeline."""
return A.Compose(
[
A.RandomResizedCrop(height=image_size, width=image_size, scale=(0.6, 1.0)),
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(
shift_limit=0.1,
scale_limit=0.2,
rotate_limit=30,
p=0.5,
),
A.OneOf(
[
A.GaussianBlur(blur_limit=(3, 7), p=1.0),
A.MotionBlur(blur_limit=(3, 7), p=1.0),
A.MedianBlur(blur_limit=5, p=1.0),
],
p=0.3,
),
A.OneOf(
[
A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1.0),
],
p=0.5,
),
A.CoarseDropout(
max_holes=8, max_height=32, max_width=32, fill_value=0, p=0.3
),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
]
)
# Usage with ImageDataModule
data = ImageDataModule(
data_dir="./data",
image_size=224,
batch_size=32,
train_transforms=get_strong_augmentations(224),
)
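Validation and test pipelines should stay deterministic; a minimal sketch, assuming the data module also accepts a val_transforms argument (check your AutoTimm version for the exact parameter name):
def get_eval_transforms(image_size: int = 224):
    """Deterministic resize + normalize for evaluation."""
    return A.Compose(
        [
            A.Resize(image_size, image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
    )

data = ImageDataModule(
    data_dir="./data",
    image_size=224,
    batch_size=32,
    train_transforms=get_strong_augmentations(224),
    val_transforms=get_eval_transforms(224),  # assumed parameter name
)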
Lightning Callbacks¶
Gradient Monitoring Callback¶
import pytorch_lightning as pl
import torch
class GradientMonitor(pl.Callback):
"""Monitor gradient statistics during training."""
def __init__(self, log_every_n_steps: int = 100):
super().__init__()
self.log_every_n_steps = log_every_n_steps
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
if trainer.global_step % self.log_every_n_steps != 0:
return
grad_norms = {}
for name, param in pl_module.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
# Simplify name for logging
simple_name = name.replace(".", "/")
grad_norms[f"grad_norm/{simple_name}"] = grad_norm
# Log statistics
if grad_norms:
values = list(grad_norms.values())
pl_module.log("grad_norm/mean", sum(values) / len(values))
pl_module.log("grad_norm/max", max(values))
pl_module.log("grad_norm/min", min(values))
Learning Rate Warmup Callback¶
import pytorch_lightning as pl
class LearningRateWarmup(pl.Callback):
"""Linear learning rate warmup."""
def __init__(self, warmup_steps: int = 1000, start_lr: float = 1e-7):
super().__init__()
self.warmup_steps = warmup_steps
self.start_lr = start_lr
self.target_lr = None
def on_train_start(self, trainer, pl_module):
# Store target LR
optimizer = trainer.optimizers[0]
self.target_lr = optimizer.param_groups[0]["lr"]
# Set initial LR
for param_group in optimizer.param_groups:
param_group["lr"] = self.start_lr
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
if trainer.global_step >= self.warmup_steps:
return
# Linear warmup
progress = trainer.global_step / self.warmup_steps
current_lr = self.start_lr + progress * (self.target_lr - self.start_lr)
optimizer = trainer.optimizers[0]
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
pl_module.log("warmup/lr", current_lr)
Model EMA Callback¶
import copy
import pytorch_lightning as pl
import torch
class ModelEMA(pl.Callback):
"""Exponential Moving Average of model weights."""
def __init__(self, decay: float = 0.999):
super().__init__()
self.decay = decay
self.ema_model = None
def on_fit_start(self, trainer, pl_module):
# Create EMA model
self.ema_model = copy.deepcopy(pl_module)
for param in self.ema_model.parameters():
param.requires_grad_(False)
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# Update EMA weights
with torch.inference_mode():
for ema_param, param in zip(
self.ema_model.parameters(), pl_module.parameters()
):
ema_param.data.mul_(self.decay).add_(param.data, alpha=1 - self.decay)
def on_validation_epoch_start(self, trainer, pl_module):
# Swap to EMA weights for validation
self._swap_weights(pl_module)
def on_validation_epoch_end(self, trainer, pl_module):
# Swap back to training weights
self._swap_weights(pl_module)
def _swap_weights(self, pl_module):
for ema_param, param in zip(
self.ema_model.parameters(), pl_module.parameters()
):
ema_param.data, param.data = param.data.clone(), ema_param.data.clone()
Usage with AutoTrainer¶
from autotimm import AutoTrainer
trainer = AutoTrainer(
max_epochs=50,
callbacks=[
GradientMonitor(log_every_n_steps=100),
LearningRateWarmup(warmup_steps=500),
ModelEMA(decay=0.999),
],
)
trainer.fit(model, datamodule=data)
Performance Optimization¶
torch.compile (PyTorch 2.0+)¶
Enabled by default in all AutoTimm tasks for automatic optimization:
from autotimm import ImageClassifier, MetricConfig
# Default: torch.compile enabled
model = ImageClassifier(
backbone="resnet50",
num_classes=10,
metrics=metrics,
)
Disable or customize:
# Disable compilation
model = ImageClassifier(
backbone="resnet50",
num_classes=10,
compile_model=False, # Disable torch.compile
)
# Custom compile options
model = ImageClassifier(
backbone="resnet50",
num_classes=10,
compile_kwargs={
"mode": "reduce-overhead", # "default", "reduce-overhead", "max-autotune"
"fullgraph": True, # Attempt full graph compilation
"dynamic": False, # Static vs dynamic shapes
},
)
Compile Modes:
| Mode | Description | Use Case |
|---|---|---|
| default | Balanced optimization | General purpose (recommended) |
| reduce-overhead | Lower latency | Small batches, inference |
| max-autotune | Maximum speed | Longer compile time, production |
What Gets Compiled:
- ImageClassifier: Backbone + Head
- ObjectDetector: Backbone + FPN + Detection Head
- SemanticSegmentor: Backbone + Segmentation Head
- InstanceSegmentor: Backbone + FPN + Detection Head + Mask Head
- YOLOXDetector: Backbone + Neck + Head
Notes:
- First run will be slower due to compilation
- Falls back gracefully on PyTorch < 2.0
- Compatible with all custom heads and modifications
Reproducibility¶
Seeding is opt-in: both the model and the trainer default to seed=None, so pass an explicit seed when you need reproducible experiments:
from autotimm import ImageClassifier, AutoTrainer, seed_everything
# Default: seed=None, deterministic=True (model) / False (trainer)
# Seeding is opt-in — pass seed=42 explicitly for reproducibility
model = ImageClassifier(
backbone="resnet50",
num_classes=10,
seed=42, # opt-in for reproducibility
deterministic=True, # default
)
trainer = AutoTrainer(
max_epochs=10,
seed=42, # opt-in for reproducibility
)
Custom seeding:
# Use a different seed
model = ImageClassifier(
backbone="resnet50",
num_classes=10,
seed=123,
deterministic=True,
)
trainer = AutoTrainer(max_epochs=10, seed=123)
Faster training (disable deterministic mode):
# Disable deterministic for speed
model = ImageClassifier(
backbone="resnet50",
num_classes=10,
seed=42,
deterministic=False, # Enables cuDNN benchmark
)
trainer = AutoTrainer(
max_epochs=10,
deterministic=False, # Faster but not fully deterministic
)
Manual seeding:
# Seed everything manually
seed_everything(42, deterministic=True)
# Now create models without automatic seeding
model = ImageClassifier(
backbone="resnet50",
num_classes=10,
seed=None, # Don't seed again
)
Trainer seeding options:
# Use PyTorch Lightning's seeding (default)
trainer = AutoTrainer(
max_epochs=10,
seed=42,
use_autotimm_seeding=False, # Default
)
# Use AutoTimm's custom seeding
trainer = AutoTrainer(
max_epochs=10,
seed=42,
use_autotimm_seeding=True,
)
What's seeded (a rough manual equivalent is sketched after this list):
- Python's random module
- NumPy's random number generator
- PyTorch (CPU & CUDA)
- Environment variables (PYTHONHASHSEED)
- cuDNN deterministic algorithms (when deterministic=True)
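For reference, a rough manual equivalent of the above; the actual autotimm.seed_everything implementation may differ:
import os
import random

import numpy as np
import torch

def manual_seed_everything(seed: int, deterministic: bool = True) -> None:
    random.seed(seed)                         # Python's random module
    np.random.seed(seed)                      # NumPy RNG
    torch.manual_seed(seed)                   # PyTorch CPU (also seeds current CUDA device)
    torch.cuda.manual_seed_all(seed)          # all CUDA devices
    os.environ["PYTHONHASHSEED"] = str(seed)  # hash randomization
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False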
Trade-offs:
| Setting | Speed | Reproducibility | Use Case |
|---|---|---|---|
| deterministic=True | Slower | Full | Research, debugging |
| deterministic=False | Faster | Partial | Production training |
| seed=None | Fastest | None | Exploring variance |
Best practices:
- Research papers: Set seed=42, deterministic=True explicitly
- Production training: Set seed=42, deterministic=False explicitly
- Debugging: Set seed=42, deterministic=True explicitly
- Exploration: Use the default seed=None
Complete Example¶
import torch
import torch.nn as nn
import torchmetrics
from autotimm import (
AutoTrainer,
ImageClassifier,
ImageDataModule,
LoggerConfig,
MetricConfig,
)
# Custom head
class MLPHead(nn.Module):
def __init__(self, in_features, num_classes, hidden_dims=[512, 256]):
super().__init__()
layers = []
prev_dim = in_features
for dim in hidden_dims:
layers.extend([
nn.Linear(prev_dim, dim),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True),
nn.Dropout(0.3),
])
prev_dim = dim
layers.append(nn.Linear(prev_dim, num_classes))
self.mlp = nn.Sequential(*layers)
def forward(self, x):
return self.mlp(x)
# Custom classifier
class CustomClassifier(ImageClassifier):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.head = MLPHead(
self.backbone.num_features,
self.num_classes,
hidden_dims=[512, 256],
)
def main():
metrics = [
MetricConfig(
name="accuracy",
backend="torchmetrics",
metric_class="Accuracy",
params={"task": "multiclass"},
stages=["train", "val", "test"],
prog_bar=True,
),
MetricConfig(
name="f1",
backend="torchmetrics",
metric_class="F1Score",
params={"task": "multiclass", "average": "macro"},
stages=["val", "test"],
),
]
data = ImageDataModule(
data_dir="./data",
dataset_name="CIFAR10",
image_size=224,
batch_size=64,
)
model = CustomClassifier(
backbone="resnet50",
num_classes=10,
metrics=metrics,
lr=1e-4,
)
trainer = AutoTrainer(
max_epochs=50,
logger=[LoggerConfig(backend="tensorboard", params={"save_dir": "logs"})],
checkpoint_monitor="val/accuracy",
gradient_clip_val=1.0,
)
trainer.fit(model, datamodule=data)
if __name__ == "__main__":
main()
See Also¶
- Training Guide - Training configuration
- Loss Comparison - Built-in loss functions
- Metric Selection - Built-in metrics
- API Reference: Heads - Head module API