Task-Specific Interpretation¶
AutoTimm provides specialized interpretation functions for object detection and semantic segmentation tasks.
Overview¶
While classification interpretation is straightforward, detection and segmentation tasks require specialized handling:
- Object Detection: Explain individual detections with bbox-aware visualization
- Semantic Segmentation: Explain class predictions with uncertainty quantification
Object Detection Interpretation¶
Function: explain_detection()¶
from autotimm.interpretation import explain_detection
explain_detection(
    model: nn.Module,
    image: Union[str, Image.Image, np.ndarray, torch.Tensor],
    method: str = 'gradcam',
    detection_threshold: float = 0.5,
    target_layer: Optional[Union[str, nn.Module]] = None,
    bbox: Optional[List[float]] = None,
    colormap: str = 'viridis',
    alpha: float = 0.4,
    save_path: Optional[str] = None,
    max_detections: int = 10
) -> Dict
Parameters:
- model (nn.Module): Object detection model (ObjectDetector, YOLOXDetector, etc.)
- image: Input image
- method (str): Interpretation method ('gradcam', 'gradcam++', 'integrated_gradients')
- detection_threshold (float): Minimum confidence score (default: 0.5)
- target_layer (Optional): Layer for interpretation (None = auto-detect)
- bbox (Optional[List[float]]): Specific bbox to explain as [x1, y1, x2, y2] (None = all)
- colormap (str): Colormap for heatmaps
- alpha (float): Overlay transparency
- save_path (Optional[str]): Path to save visualization
- max_detections (int): Maximum number of detections to explain (default: 10)
Returns:
- Dict: Contains:
- 'detections': List of detection dictionaries
- 'heatmaps': Heatmap for each detection
- 'visualization': Combined visualization (if save_path provided)
Basic Usage¶
from autotimm import ObjectDetector
from autotimm.interpretation import explain_detection
from PIL import Image
# Load model
model = ObjectDetector.load_from_checkpoint("detector.ckpt", compile_model=False)
# Load image
image = Image.open("street.jpg")
# Explain all detections
results = explain_detection(
model,
image,
method='gradcam',
detection_threshold=0.5,
save_path='detection_explanation.png'
)
# Print detections
for i, det in enumerate(results['detections']):
    print(f"Detection {i}:")
    print(f"  Class: {det['class_id']} ({det['class_name']})")
    print(f"  Confidence: {det['score']:.3f}")
    print(f"  BBox: {det['bbox']}")
Explain Specific Detection¶
Explain only a specific bounding box:
# Define bbox of interest [x1, y1, x2, y2]
bbox = [100, 150, 300, 400]
results = explain_detection(
model,
image,
bbox=bbox,
method='gradcam',
save_path='single_detection.png'
)
High-Confidence Detections Only¶
Filter to only explain high-confidence detections:
results = explain_detection(
model,
image,
detection_threshold=0.8, # Only explain confident detections
max_detections=5, # Limit to top 5
save_path='confident_detections.png'
)
With GradCAM++¶
Use GradCAM++ for better localization of multiple objects:
results = explain_detection(
model,
image,
method='gradcam++',
detection_threshold=0.5,
save_path='detection_gradcampp.png'
)
Understanding the Output¶
The returned dictionary contains:
results = {
'detections': [
{
'class_id': 2,
'class_name': 'car',
'score': 0.89,
'bbox': [120.5, 180.2, 340.8, 420.6], # [x1, y1, x2, y2]
},
# ... more detections
],
'heatmaps': [
np.ndarray, # Heatmap for first detection
# ... more heatmaps
],
'visualization': np.ndarray, # Combined visualization (if save_path)
}
Visualization Format¶
The saved visualization includes:
- Original image with bounding boxes
- Class labels and confidence scores
- Heatmap overlays for each detection (masked to its bbox)
- Color-coded bboxes per detection
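If you need a custom layout, the returned 'detections' and 'heatmaps' entries are enough to rebuild a similar overlay yourself. A minimal matplotlib sketch, assuming each heatmap is a full-image array aligned with the input (this layout is an illustration, not the library's exact rendering):

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(np.asarray(image))

for det, heatmap in zip(results['detections'], results['heatmaps']):
    x1, y1, x2, y2 = det['bbox']
    # Draw the bbox and its label
    ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   fill=False, edgecolor='lime', linewidth=2))
    ax.text(x1, y1 - 4, f"{det['class_name']} {det['score']:.2f}", color='lime')
    # Overlay the heatmap, masked to the bbox (NaNs render as transparent)
    mask = np.zeros(heatmap.shape, dtype=bool)
    mask[int(y1):int(y2), int(x1):int(x2)] = True
    ax.imshow(np.where(mask, heatmap, np.nan), cmap='viridis', alpha=0.4)

ax.axis('off')
plt.savefig('custom_detection_overlay.png', bbox_inches='tight')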
Semantic Segmentation Interpretation¶
Function: explain_segmentation()¶
from autotimm.interpretation import explain_segmentation
explain_segmentation(
    model: nn.Module,
    image: Union[str, Image.Image, np.ndarray, torch.Tensor],
    target_class: Optional[int] = None,
    method: str = 'gradcam',
    target_layer: Optional[Union[str, nn.Module]] = None,
    show_uncertainty: bool = False,
    uncertainty_method: str = 'entropy',
    colormap: str = 'viridis',
    alpha: float = 0.4,
    save_path: Optional[str] = None
) -> Dict
Parameters:
- model (nn.Module): Semantic segmentation model (SemanticSegmentor, etc.)
- image: Input image
- target_class (Optional[int]): Specific class to explain (None = all classes)
- method (str): Interpretation method
- target_layer (Optional): Layer for interpretation
- show_uncertainty (bool): Whether to visualize prediction uncertainty
- uncertainty_method (str): Uncertainty quantification method. Options: 'entropy' (Shannon entropy of class probabilities), 'margin' (difference between the top-2 probabilities)
- colormap (str): Colormap for heatmaps
- alpha (float): Overlay transparency
- save_path (Optional[str]): Path to save visualization
Returns:
- Dict: Contains:
- 'prediction': Predicted segmentation mask
- 'heatmap': Attribution heatmap (if target_class specified)
- 'uncertainty': Uncertainty map (if show_uncertainty=True)
- 'visualization': Combined visualization
Basic Usage¶
from autotimm import SemanticSegmentor
from autotimm.interpretation import explain_segmentation
from PIL import Image
import numpy as np
# Load model
model = SemanticSegmentor.load_from_checkpoint("segmentor.ckpt", compile_model=False)
# Load image
image = Image.open("cityscape.jpg")
# Explain segmentation
results = explain_segmentation(
model,
image,
save_path='segmentation_explanation.png'
)
# Check predictions
mask = results['prediction']
unique_classes = np.unique(mask)
print(f"Detected classes: {unique_classes}")
Explain Specific Class¶
Focus on a particular class (e.g., "road" or "car"):
# Explain class 0 (e.g., "road")
results = explain_segmentation(
model,
image,
target_class=0,
method='gradcam',
save_path='road_explanation.png'
)
# The heatmap shows what influenced the "road" predictions
heatmap = results['heatmap']
With Uncertainty Quantification¶
Visualize model uncertainty to identify ambiguous regions:
results = explain_segmentation(
model,
image,
show_uncertainty=True,
uncertainty_method='entropy',
save_path='segmentation_with_uncertainty.png'
)
# High uncertainty regions (entropy-based)
uncertainty = results['uncertainty']
high_uncertainty = uncertainty > 0.8 # Threshold
print(f"High uncertainty pixels: {high_uncertainty.sum()}")
Entropy vs. Margin Uncertainty¶
Entropy measures overall prediction confidence across all classes:
- High entropy = model is uncertain across many classes
- Low entropy = model is confident in its prediction

Margin measures the gap between the top-2 predictions:
- Low margin = close call between two classes
- High margin = clear winner
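As an illustrative sketch of the underlying math (explain_segmentation computes these internally; the per-pixel softmax tensor probs with shape (C, H, W) is an assumption here, as is expressing margin as an inverted uncertainty so that larger always means more uncertain):

import numpy as np

def entropy_uncertainty(probs: np.ndarray) -> np.ndarray:
    # Shannon entropy per pixel, normalized by log(C) to land in [0, 1]
    eps = 1e-12
    entropy = -(probs * np.log(probs + eps)).sum(axis=0)
    return entropy / np.log(probs.shape[0])

def margin_uncertainty(probs: np.ndarray) -> np.ndarray:
    # Gap between the two largest probabilities per pixel, inverted so a
    # close call (small gap) yields high uncertainty (an assumption; the
    # library may report the raw margin instead)
    top2 = np.sort(probs, axis=0)[-2:]
    return 1.0 - (top2[1] - top2[0])

The examples below request each method through the public API instead: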
# Entropy uncertainty (overall confidence)
results_entropy = explain_segmentation(
model,
image,
show_uncertainty=True,
uncertainty_method='entropy',
save_path='uncertainty_entropy.png'
)
# Margin uncertainty (binary confusion)
results_margin = explain_segmentation(
model,
image,
show_uncertainty=True,
uncertainty_method='margin',
save_path='uncertainty_margin.png'
)
Multi-Class Analysis¶
Analyze multiple classes of interest:
import matplotlib.pyplot as plt
classes_of_interest = [0, 1, 2] # e.g., road, sidewalk, building
class_names = ['road', 'sidewalk', 'building']
fig, axes = plt.subplots(1, len(classes_of_interest), figsize=(15, 5))
for idx, (class_id, class_name) in enumerate(zip(classes_of_interest, class_names)):
    results = explain_segmentation(
        model,
        image,
        target_class=class_id,
        method='gradcam',
    )
    axes[idx].imshow(results['heatmap'], cmap='hot')
    axes[idx].set_title(f'{class_name} (class {class_id})')
    axes[idx].axis('off')

plt.savefig('multi_class_analysis.png')
Understanding the Output¶
results = {
'prediction': np.ndarray, # Shape (H, W), dtype int (class IDs)
'heatmap': np.ndarray, # Shape (H, W), dtype float [0, 1] (if target_class)
'uncertainty': np.ndarray, # Shape (H, W), dtype float [0, 1] (if show_uncertainty)
'visualization': np.ndarray, # Shape (H, W, 3), dtype uint8 (if save_path)
}
Visualization Format¶
The saved visualization includes:
- Top left: Original image
- Top right: Predicted segmentation mask (colored by class)
- Bottom left: Interpretation heatmap (if target_class specified)
- Bottom right: Uncertainty map (if show_uncertainty=True)
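A comparable 2x2 panel can be assembled from the returned dictionary directly. A minimal sketch, assuming target_class and show_uncertainty were both set so all four keys are present:

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes[0, 0].imshow(np.asarray(image))
axes[0, 0].set_title('Original')
axes[0, 1].imshow(results['prediction'], cmap='tab20')
axes[0, 1].set_title('Predicted mask')
axes[1, 0].imshow(results['heatmap'], cmap='viridis')
axes[1, 0].set_title('Attribution heatmap')
axes[1, 1].imshow(results['uncertainty'], cmap='magma')
axes[1, 1].set_title('Uncertainty')
for ax in axes.flat:
    ax.axis('off')
plt.savefig('custom_segmentation_panel.png', bbox_inches='tight')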
Advanced Use Cases¶
1. False Positive Analysis (Detection)¶
Investigate why model detected a false positive:
# Get all detections
results = explain_detection(model, image, detection_threshold=0.3)
# Analyze low-confidence detections (potential false positives)
for i, det in enumerate(results['detections']):
    if det['score'] < 0.5:
        print(f"Low-confidence detection {i}:")
        print(f"  Class: {det['class_name']}")
        print(f"  Score: {det['score']:.3f}")
        # Check results['heatmaps'][i] to see what triggered the detection
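One way to follow up on that last comment is to plot the heatmap matching a suspect detection. A minimal sketch (suspect_idx is a hypothetical index noted from the loop above):

import matplotlib.pyplot as plt

suspect_idx = 0  # hypothetical: index of a low-confidence detection from the loop above
det = results['detections'][suspect_idx]
plt.imshow(results['heatmaps'][suspect_idx], cmap='viridis')
plt.title(f"{det['class_name']} ({det['score']:.2f})")
plt.axis('off')
plt.savefig('false_positive_heatmap.png')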
2. Boundary Refinement Analysis (Segmentation)¶
Identify where segmentation boundaries are uncertain:
results = explain_segmentation(
model,
image,
show_uncertainty=True,
uncertainty_method='margin',
)
uncertainty = results['uncertainty']

# Find uncertain boundaries: edge pixels of the predicted mask with high uncertainty
import numpy as np
from scipy import ndimage

edges = np.abs(ndimage.sobel(results['prediction'].astype(float)))
uncertain_edges = (edges > 0) & (uncertainty > 0.7)
print(f"Uncertain boundary pixels: {uncertain_edges.sum()}")
3. Class Confusion Analysis (Segmentation)¶
Understand which classes the model confuses:
# Get predictions and uncertainty
results = explain_segmentation(
model,
image,
show_uncertainty=True,
uncertainty_method='entropy',
)
prediction = results['prediction']
uncertainty = results['uncertainty']
# Find high-uncertainty regions for each class
num_classes = int(prediction.max()) + 1
for class_id in range(num_classes):
    class_mask = (prediction == class_id)
    class_uncertainty = uncertainty[class_mask]
    if len(class_uncertainty) > 0:
        print(f"Class {class_id}:")
        print(f"  Mean uncertainty: {class_uncertainty.mean():.3f}")
        print(f"  Highly uncertain pixels: {(class_uncertainty > 0.8).sum()}")
4. Multi-Scale Detection Analysis¶
Analyze detections at different scales:
from torchvision import transforms
scales = [0.5, 1.0, 1.5]
for scale in scales:
    # Resize image
    resize = transforms.Resize((int(image.height * scale), int(image.width * scale)))
    scaled_image = resize(image)

    # Explain detections
    results = explain_detection(
        model,
        scaled_image,
        detection_threshold=0.5,
        save_path=f'detection_scale_{scale}.png'
    )
    print(f"Scale {scale}: {len(results['detections'])} detections")
5. Temporal Consistency (Video)¶
Check interpretation consistency across video frames:
import cv2
import numpy as np

video = cv2.VideoCapture("video.mp4")
frame_heatmaps = []
for frame_idx in range(0, 100, 10):  # Sample every 10 frames
    video.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = video.read()
    if not ret:
        break
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_pil = Image.fromarray(frame_rgb)
    results = explain_detection(
        model,
        frame_pil,
        detection_threshold=0.5,
    )
    if results['heatmaps']:
        frame_heatmaps.append(results['heatmaps'][0])  # First detection
# Analyze temporal stability
heatmap_diffs = []
for i in range(len(frame_heatmaps) - 1):
    diff = np.abs(frame_heatmaps[i + 1] - frame_heatmaps[i]).mean()
    heatmap_diffs.append(diff)
print(f"Mean heatmap change: {np.mean(heatmap_diffs):.3f}")
Performance Considerations¶
Object Detection¶
Factors affecting speed:
- Number of detections (more = slower)
- Image resolution (larger = slower)
- Interpretation method (GradCAM fastest, Integrated Gradients slowest)
Optimization tips:
# Process only top detections
results = explain_detection(
model,
image,
max_detections=5, # Limit explanations
detection_threshold=0.7, # Higher threshold
)
# Use faster method
results = explain_detection(
model,
image,
method='gradcam', # Faster than integrated_gradients
)
Semantic Segmentation¶
Factors affecting speed:
- Image resolution (the major factor)
- Target class (None = faster, specific class = slower)
- Uncertainty calculation (adds overhead)
Optimization tips:
# Disable uncertainty for speed
results = explain_segmentation(
model,
image,
show_uncertainty=False, # Faster
)
# Resize image for faster processing
from torchvision import transforms
resize = transforms.Resize((512, 512))
small_image = resize(image)
results = explain_segmentation(model, small_image)
Troubleshooting¶
For task-specific interpretation issues, see Troubleshooting - Interpretation, which covers:
- Detection: No explanations generated
- Detection: Poor localization
- Segmentation: Blank heatmaps
- High memory usage

For high memory usage, quick mitigations are to process fewer detections, disable uncertainty calculation, and resize the input image:

# Resize image
from torchvision import transforms
resize = transforms.Resize((512, 512))
small_image = resize(image)

# Explain with reduced memory
results = explain_detection(
    model,
    small_image,
    max_detections=3,  # Fewer detections
)
Comparison with Classification¶
| Aspect | Classification | Detection | Segmentation |
|---|---|---|---|
| Output | Single class | Multiple bboxes | Per-pixel class |
| Target | Whole image | Per-detection | Per-class or all |
| Heatmap | Single | Per-bbox | Single or per-class |
| Uncertainty | N/A | Per-detection confidence | Per-pixel entropy/margin |
| Speed | Fast | Medium (× num_detections) | Slower (depends on resolution) |
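To put numbers on the Speed column for your own model and hardware, a simple timing loop is enough (a rough sketch reusing the model and image from the examples above; absolute times will vary):

import time

for method in ['gradcam', 'gradcam++', 'integrated_gradients']:
    start = time.perf_counter()
    explain_detection(model, image, method=method, max_detections=5)
    print(f"{method}: {time.perf_counter() - start:.2f}s")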
See Also¶
- Interpretation Methods - Available interpretation methods
- Feature Visualization - Analyze learned features
- Main Guide - Overview and quick start