Training Callbacks¶
PyTorch Lightning callbacks for automatic interpretation and feature monitoring during training.
Overview¶
AutoTimm provides two callbacks for integrating interpretation into your training workflow:
- `InterpretationCallback`: Automatically generate and log explanations during training
- `FeatureMonitorCallback`: Track feature statistics across layers
Both callbacks integrate seamlessly with PyTorch Lightning loggers (TensorBoard, Weights & Biases, MLflow).
InterpretationCallback¶
Automatically generate explanations for sample images during training and log them to your tracking platform.
Class: InterpretationCallback¶
```python
import autotimm as at  # recommended alias
from autotimm.interpretation import InterpretationCallback

callback = InterpretationCallback(
    sample_images: Union[torch.Tensor, List[torch.Tensor], List[str]],
    sample_labels: Optional[List[int]] = None,
    method: Literal['gradcam', 'gradcam++', 'integrated_gradients'] = 'gradcam',
    target_layer: Optional[Union[str, torch.nn.Module]] = None,
    log_every_n_epochs: int = 5,
    log_every_n_steps: Optional[int] = None,
    num_samples: int = 8,
    colormap: str = "viridis",
    alpha: float = 0.4,
    prefix: str = "interpretation"
)
```
Parameters:
- `sample_images`: Images to explain during training. Can be torch tensors, a list of tensors, or a list of file paths. Sampled down to `num_samples` if more are provided.
- `sample_labels` (Optional[List[int]]): Ground-truth labels for the sample images.
- `method` (str): Interpretation method to use. Options: `'gradcam'` (default, fast), `'gradcam++'` (better for multiple objects), `'integrated_gradients'` (pixel-level attributions).
- `target_layer` (Optional): Layer to use for interpretation (`None` = auto-detect).
- `log_every_n_epochs` (int): Generate explanations every N epochs (default: 5).
- `log_every_n_steps` (Optional[int]): Alternative: log every N steps (overrides epochs).
- `num_samples` (int): Number of images to explain (default: 8).
- `colormap` (str): Matplotlib colormap for heatmaps (default: `"viridis"`).
- `alpha` (float): Overlay transparency (0-1, default: 0.4).
- `prefix` (str): Prefix for logged images (default: `"interpretation"`).
Basic Usage¶
```python
from PIL import Image

from autotimm import AutoTrainer, ImageClassifier, ImageDataModule
from autotimm.interpretation import InterpretationCallback

# Prepare sample images for monitoring
sample_images = [Image.open(f"samples/img_{i}.jpg") for i in range(8)]

# Create callback
interp_callback = InterpretationCallback(
    sample_images=sample_images,
    method="gradcam",
    log_every_n_epochs=5,
    num_samples=8,
)

# Create model and trainer
model = ImageClassifier(backbone="resnet50", num_classes=10)
data = ImageDataModule(train_dir="data/train", val_dir="data/val")

# Train with automatic interpretation
trainer = AutoTrainer(
    max_epochs=100,
    callbacks=[interp_callback],
    logger="tensorboard",  # or "wandb", "mlflow"
)
trainer.fit(model, datamodule=data)
```
With TensorBoard¶
```python
from autotimm.core.loggers import LoggerConfig

# Configure TensorBoard logger
logger_config = LoggerConfig(
    backend="tensorboard",
    save_dir="logs/",
    name="my_experiment",
)

# Create callback
interp_callback = InterpretationCallback(
    sample_images=sample_images,
    method="gradcam",
    log_every_n_epochs=5,
)

# Train
trainer = AutoTrainer(
    max_epochs=100,
    callbacks=[interp_callback],
    logger=logger_config.create_logger(),
)
trainer.fit(model, datamodule=data)
```
View in TensorBoard:
Navigate to the "IMAGES" tab to see interpretations logged during training.
With Weights & Biases¶
```python
import wandb
from pytorch_lightning.loggers import WandbLogger

# Initialize W&B
wandb.init(project="my_project", name="experiment_1")

# Create callback
interp_callback = InterpretationCallback(
    sample_images=sample_images,
    method="gradcam++",
    log_every_n_epochs=5,
)

# Train
trainer = AutoTrainer(
    max_epochs=100,
    callbacks=[interp_callback],
    logger=WandbLogger(),
)
trainer.fit(model, datamodule=data)
```
With MLflow¶
```python
from pytorch_lightning.loggers import MLFlowLogger

# Create callback
interp_callback = InterpretationCallback(
    sample_images=sample_images,
    method="gradcam",
    log_every_n_epochs=10,
)

# Train
trainer = AutoTrainer(
    max_epochs=100,
    callbacks=[interp_callback],
    logger=MLFlowLogger(experiment_name="my_experiment"),
)
trainer.fit(model, datamodule=data)
```
Step-Based Logging¶
Log explanations based on training steps instead of epochs:
```python
interp_callback = InterpretationCallback(
    sample_images=sample_images,
    method="gradcam",
    log_every_n_steps=1000,  # Log every 1000 steps
)
```
Multiple Methods¶
Compare different interpretation methods during training:
```python
callbacks = [
    InterpretationCallback(
        sample_images=sample_images,
        method="gradcam",
        log_every_n_epochs=5,
        prefix="interp_gradcam",
    ),
    InterpretationCallback(
        sample_images=sample_images,
        method="gradcam++",
        log_every_n_epochs=5,
        prefix="interp_gradcampp",
    ),
]

trainer = AutoTrainer(max_epochs=100, callbacks=callbacks)
Custom Visualization¶
Customize colormap and overlay settings:
```python
interp_callback = InterpretationCallback(
    sample_images=sample_images,
    method="gradcam",
    colormap="hot",  # Use 'hot' colormap
    alpha=0.5,       # Slightly more opaque overlay
    log_every_n_epochs=5,
)
```
FeatureMonitorCallback¶
Monitor feature statistics during training to understand how features evolve.
Class: FeatureMonitorCallback¶
```python
from autotimm.interpretation import FeatureMonitorCallback

callback = FeatureMonitorCallback(
    layer_names: List[str],
    log_every_n_epochs: int = 1,
    num_batches: int = 10
)
```
Parameters:
- `layer_names` (List[str]): Names of layers to monitor (e.g., `["backbone.layer2", "backbone.layer4"]`).
- `log_every_n_epochs` (int): Log statistics every N epochs (default: 1).
- `num_batches` (int): Number of batches to accumulate statistics over (default: 10).
Basic Usage¶
```python
from autotimm import AutoTrainer, ImageClassifier, ImageDataModule
from autotimm.interpretation import FeatureMonitorCallback

# Create callback
feature_callback = FeatureMonitorCallback(
    layer_names=["backbone.layer2", "backbone.layer3", "backbone.layer4"],
    log_every_n_epochs=1,
    num_batches=10,
)

# Create model and data
model = ImageClassifier(backbone="resnet50", num_classes=10)
data = ImageDataModule(train_dir="data/train", val_dir="data/val")

# Train with feature monitoring
trainer = AutoTrainer(
    max_epochs=100,
    callbacks=[feature_callback],
    logger="tensorboard",
)
trainer.fit(model, datamodule=data)
```
Logged Metrics¶
The callback logs the following metrics for each monitored layer:
- `features/{layer_name}/mean`: Mean activation
- `features/{layer_name}/std`: Standard deviation
- `features/{layer_name}/sparsity`: Fraction of zero activations
- `features/{layer_name}/max`: Maximum activation
Example in TensorBoard:

```
features/backbone.layer2/mean: 0.234
features/backbone.layer2/std: 0.156
features/backbone.layer2/sparsity: 0.345
features/backbone.layer2/max: 2.456
```
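For reference, the four statistics can be reproduced by hand. A minimal sketch in plain Python over a toy list of activations (inside the callback these are computed over accumulated activation tensors):

```python
from statistics import mean, stdev

# Toy activations standing in for one monitored layer's output (flattened)
acts = [0.0, 0.5, 0.0, 2.0]

stats = {
    "mean": mean(acts),                                 # mean activation
    "std": stdev(acts),                                 # standard deviation
    "sparsity": sum(a == 0 for a in acts) / len(acts),  # fraction of zeros
    "max": max(acts),                                   # maximum activation
}
print(stats)  # sparsity is 0.5: two of the four activations are zero
```

A rising `sparsity` value therefore means a growing share of the layer's units output exactly zero.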
Monitoring Specific Layers¶
Monitor only the final layers for efficiency:
```python
feature_callback = FeatureMonitorCallback(
    layer_names=["backbone.layer4"],  # Only final layer
    log_every_n_epochs=1,
)
```
Fine-Grained Monitoring¶
Monitor every layer for detailed analysis:
```python
import torch

# Get all conv layer names
layer_names = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        layer_names.append(name)

feature_callback = FeatureMonitorCallback(
    layer_names=layer_names,
    log_every_n_epochs=1,
)
```
Adjusting Batch Count¶
Control how many batches to use for statistics:
```python
feature_callback = FeatureMonitorCallback(
    layer_names=["backbone.layer2", "backbone.layer3", "backbone.layer4"],
    log_every_n_epochs=1,
    num_batches=20,  # More batches = more accurate but slower
)
```
Combined Usage¶
Use both callbacks together for comprehensive monitoring:
```python
from PIL import Image

from autotimm import AutoTrainer, ImageClassifier
from autotimm.interpretation import InterpretationCallback, FeatureMonitorCallback

# Sample images for interpretation
sample_images = [Image.open(f"samples/img_{i}.jpg") for i in range(8)]

# Interpretation callback
interp_callback = InterpretationCallback(
    sample_images=sample_images,
    method="gradcam",
    log_every_n_epochs=5,
    num_samples=8,
)

# Feature monitoring callback
feature_callback = FeatureMonitorCallback(
    layer_names=["backbone.layer2", "backbone.layer3", "backbone.layer4"],
    log_every_n_epochs=1,
    num_batches=10,
)

# Create model and trainer
model = ImageClassifier(backbone="resnet50", num_classes=10)
trainer = AutoTrainer(
    max_epochs=100,
    callbacks=[interp_callback, feature_callback],
    logger="tensorboard",
)
trainer.fit(model, datamodule=data)
```
Use Cases¶
1. Debugging Model Learning¶
Check if model is learning meaningful patterns:
```python
# Monitor early in training
interp_callback = InterpretationCallback(
    sample_images=sample_images,
    method="gradcam",
    log_every_n_epochs=1,  # Log every epoch
)
```
What to look for:

- Epochs 1-10: random/noisy heatmaps
- Epochs 10-30: heatmaps start focusing on relevant regions
- Epoch 30+: clear, focused heatmaps on discriminative features
2. Detecting Overfitting¶
Monitor feature sparsity to detect overfitting:
Warning signs:

- Rapidly increasing sparsity (>70%): model becoming too specialized
- Decreasing mean activation: features dying out
- High variance in statistics: unstable training
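These checks can be automated against the logged values. A minimal sketch in plain Python (`flag_overfitting`, the 0.7 sparsity threshold, and the halved-mean heuristic are illustrative assumptions, not part of AutoTimm):

```python
def flag_overfitting(sparsity_history, mean_history, sparsity_threshold=0.7):
    """Check per-epoch feature statistics for the warning signs above."""
    warnings = []
    # Rapidly increasing sparsity: model becoming too specialized
    if sparsity_history and sparsity_history[-1] > sparsity_threshold:
        warnings.append("high sparsity: model may be over-specializing")
    # Decreasing mean activation: features dying out
    if len(mean_history) >= 2 and mean_history[-1] < 0.5 * mean_history[0]:
        warnings.append("mean activation falling: features may be dying out")
    return warnings

# Sparsity climbing past 70% while mean activation more than halves
warnings = flag_overfitting([0.30, 0.55, 0.75], [0.40, 0.25, 0.15])
print(warnings)  # both warning signs fire
```

The same idea could be wired into a Lightning callback that reads the metrics logged by `FeatureMonitorCallback`.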
3. Comparing Architectures¶
Compare how different models learn:
```python
from pytorch_lightning.loggers import WandbLogger

models_to_test = ["resnet18", "resnet50", "efficientnet_b0"]

for backbone in models_to_test:
    model = ImageClassifier(backbone=backbone, num_classes=10)
    feature_callback = FeatureMonitorCallback(
        layer_names=["backbone.layer4"],
        log_every_n_epochs=1,
    )
    trainer = AutoTrainer(
        max_epochs=50,
        callbacks=[feature_callback],
        logger=WandbLogger(project="architecture_comparison", name=backbone),
    )
    trainer.fit(model, datamodule=data)
```
4. Transfer Learning Analysis¶
Monitor how features adapt during fine-tuning:
```python
# Start with a pretrained model
model = ImageClassifier(backbone="resnet50", num_classes=10, pretrained=True)

# Monitor features during fine-tuning
callbacks = [
    InterpretationCallback(sample_images, log_every_n_epochs=1),
    FeatureMonitorCallback(["backbone.layer4"], log_every_n_epochs=1),
]

trainer = AutoTrainer(max_epochs=30, callbacks=callbacks)
trainer.fit(model, datamodule=data)
```
What to look for:

- Early epochs: gradual shift in heatmap focus
- Feature statistics: small changes (good); large changes (may need a lower learning rate)
5. Curriculum Learning¶
Adjust sample images as training progresses:
```python
class DynamicInterpretationCallback(InterpretationCallback):
    def on_train_epoch_start(self, trainer, pl_module):
        # Update sample images based on current performance
        if trainer.current_epoch % 10 == 0:
            # Load harder examples (user-defined loader)
            self.sample_images = load_challenging_samples()
        super().on_train_epoch_start(trainer, pl_module)
```
Performance Considerations¶
Callback Overhead¶
InterpretationCallback:
- Forward pass per image: ~10-50ms (depending on model size)
- Total overhead per log: num_samples * forward_time + visualization_time
- Recommendation: Log every 5-10 epochs for balance
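To make the formula concrete, a back-of-the-envelope estimate with assumed (illustrative) timings:

```python
num_samples = 8         # images explained per log
forward_ms = 30         # assumed per-image forward-pass time (mid-range model)
visualization_ms = 200  # assumed cost of rendering the heatmap grid

# Total overhead per log: num_samples * forward_time + visualization_time
overhead_ms = num_samples * forward_ms + visualization_ms
print(overhead_ms, "ms per log")

# Logging every 5 epochs over a 100-epoch run -> 20 logs
total_s = (100 // 5) * overhead_ms / 1000
print(round(total_s, 1), "s of overhead across training")
```

Under these assumptions the callback adds well under ten seconds to a 100-epoch run, which is why logging every 5-10 epochs is usually negligible.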
FeatureMonitorCallback:
- Hook overhead: Minimal (<1% of training time)
- Accumulation: Proportional to num_batches
- Recommendation: Monitor 3-5 key layers, use 10-20 batches
Optimizing Logging Frequency¶
```python
# For quick experiments (low overhead)
interp_callback = InterpretationCallback(
    sample_images=sample_images[:4],  # Fewer samples
    log_every_n_epochs=10,            # Less frequent
)

# For detailed analysis (higher overhead)
interp_callback = InterpretationCallback(
    sample_images=sample_images,
    log_every_n_epochs=1,
)
```
Memory Management¶
```python
# Reduce memory usage
feature_callback = FeatureMonitorCallback(
    layer_names=["backbone.layer4"],  # Monitor fewer layers
    num_batches=5,                    # Use fewer batches
)
```
Troubleshooting¶
For interpretation callback issues, see Troubleshooting - Interpretation, which covers:
- Interpretations not logged
- Feature monitoring not working
- High memory usage
- Slow training with callbacks
Advanced Customization¶
Custom Interpretation Method¶
```python
from autotimm.interpretation import InterpretationCallback, IntegratedGradients

class CustomInterpretationCallback(InterpretationCallback):
    def on_train_start(self, trainer, pl_module):
        # Use a custom explainer instead of the default
        self.explainer = IntegratedGradients(
            pl_module,
            baseline='blur',
            steps=30,
        )
```
Custom Logging Logic¶
```python
import torch

class CustomFeatureCallback(FeatureMonitorCallback):
    def _compute_and_log_statistics(self, trainer):
        super()._compute_and_log_statistics(trainer)
        # Add custom metrics
        for name, acts in self.activations.items():
            if len(acts) > 0:
                all_acts = torch.cat(acts, dim=0)
                # Custom metric: L1 norm
                l1_norm = all_acts.abs().mean().item()
                trainer.logger.log_metrics(
                    {f"features/{name}/l1_norm": l1_norm},
                    step=trainer.global_step,
                )
```
See Also¶
- Feature Visualization - Analyze features after training
- Interpretation Methods - Available interpretation methods
- Main Guide - Overview and quick start