Instance Segmentation¶
AutoTimm provides Mask R-CNN-style instance segmentation that combines FCOS object detection with per-instance mask prediction.
Overview¶
Instance segmentation detects objects and predicts pixel-precise masks for each instance. AutoTimm supports:
- Architecture: FCOS detection + Mask Head (Mask R-CNN style)
- Backbones: Any timm model with FPN features
- Losses: Detection loss (classification + bbox + centerness) + Binary mask loss
- Datasets: COCO instance segmentation format
- Metrics: Mask mAP, bbox mAP (torchmetrics)
Quick Start¶
from autotimm import InstanceSegmentor, InstanceSegmentationDataModule, MetricConfig, AutoTrainer
# Data
data = InstanceSegmentationDataModule(
    data_dir="./coco",
    image_size=640,
    batch_size=4,
)

# Metrics
metrics = [
    MetricConfig(
        name="mask_mAP",
        backend="torchmetrics",
        metric_class="MeanAveragePrecision",
        params={"box_format": "xyxy", "iou_type": "segm"},
        stages=["val", "test"],
        prog_bar=True,
    ),
]

# Model
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    metrics=metrics,
    lr=1e-4,
    mask_loss_weight=1.0,
)

# Train
trainer = AutoTrainer(max_epochs=12)
trainer.fit(model, datamodule=data)
Architecture¶
Detection Branch¶
InstanceSegmentor uses the same FCOS detection head as ObjectDetector:
- Classification head (80 classes for COCO)
- Bounding box regression head
- Centerness head
Mask Branch¶
The mask head predicts a pixel-precise mask for each detected instance:
- Takes ROI-aligned features from the FPN
- Applies 4 conv layers → deconv → 1x1 conv
- Outputs [N, num_classes, mask_size, mask_size]
- Default mask_size: 28x28 (upsampled to the bbox size during inference)
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    fpn_channels=256,
    mask_size=28,  # ROI mask resolution
)
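As an illustrative sketch (not AutoTimm's actual module, whose internals may differ), the conv → deconv → 1x1 structure described above maps 14x14 ROI-aligned features to 28x28 per-class mask logits:

```python
import torch
import torch.nn as nn


class MaskHead(nn.Module):
    """Hypothetical Mask R-CNN style head: 4 convs -> deconv (2x) -> 1x1 conv."""

    def __init__(self, in_channels: int = 256, num_classes: int = 80):
        super().__init__()
        layers = []
        for _ in range(4):
            layers += [nn.Conv2d(in_channels, 256, 3, padding=1), nn.ReLU(inplace=True)]
            in_channels = 256
        self.convs = nn.Sequential(*layers)
        self.deconv = nn.ConvTranspose2d(256, 256, 2, stride=2)  # 14x14 -> 28x28
        self.predictor = nn.Conv2d(256, num_classes, 1)          # per-class logits

    def forward(self, roi_feats: torch.Tensor) -> torch.Tensor:
        # roi_feats: [N, C, 14, 14] ROI-aligned FPN features
        x = self.convs(roi_feats)
        x = torch.relu(self.deconv(x))
        return self.predictor(x)  # [N, num_classes, 28, 28]
```

At inference each instance's 28x28 logit map is resized to its box, which is why `mask_size` trades mask detail against memory.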
Loss Functions¶
Detection Loss¶
Same as ObjectDetector:
- Focal Loss: classification with class-imbalance handling
- GIoU Loss: bounding box regression
- Centerness Loss: quality estimation
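For reference, the focal-loss term can be sketched in plain PyTorch. This is the standard formulation; AutoTimm's exact reduction (FCOS typically normalizes by the number of positive locations) and hyperparameters may differ:

```python
import torch
import torch.nn.functional as F


def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Focal loss: down-weights easy examples via (1 - p_t)^gamma."""
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)          # prob of the true class
    loss = ce * (1 - p_t) ** gamma                       # focus on hard examples
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * loss).mean()
```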
Mask Loss¶
Binary cross-entropy for per-instance masks:
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    mask_loss_weight=1.0,  # Weight for mask loss
)
The mask loss is computed only for positive instances (predictions matched to a ground-truth object); background locations contribute nothing to it.
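A minimal sketch of that positive-only computation (illustrative; the actual matching logic lives inside the detector): each positive ROI's logit map for its ground-truth class is selected and scored with binary cross-entropy:

```python
import torch
import torch.nn.functional as F


def mask_loss(mask_logits, gt_classes, gt_masks):
    """BCE mask loss over positive ROIs only.

    mask_logits: [N, num_classes, S, S] predictions for N positive ROIs
    gt_classes:  [N] ground-truth class index per ROI
    gt_masks:    [N, S, S] binary target masks (float 0/1)
    """
    n = mask_logits.shape[0]
    if n == 0:
        return mask_logits.sum() * 0.0  # no positives -> zero loss
    # Pick the logit map corresponding to each ROI's ground-truth class.
    per_class = mask_logits[torch.arange(n), gt_classes]  # [N, S, S]
    return F.binary_cross_entropy_with_logits(per_class, gt_masks)
```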
Dataset Format¶
COCO Instance Format¶
AutoTimm uses COCO JSON format with segmentation annotations:
coco/
    train2017/
        000000000001.jpg
        000000000002.jpg
    val2017/
    annotations/
        instances_train2017.json
        instances_val2017.json
The JSON annotations include:
- Bounding boxes (COCO format: [x, y, width, height])
- Segmentation masks (RLE or polygon format)
- Category IDs
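For reference, a single annotation entry looks roughly like this (shown as a Python dict; all IDs and coordinates here are illustrative):

```python
# One COCO instance annotation (polygon segmentation variant).
annotation = {
    "id": 1,
    "image_id": 1,
    "category_id": 18,
    "bbox": [258.0, 41.0, 100.0, 150.0],  # [x, y, width, height]
    "area": 15000.0,
    "iscrowd": 0,
    # Polygon: flat [x1, y1, x2, y2, ...] lists, one per connected region.
    "segmentation": [[265.0, 45.0, 350.0, 45.0, 350.0, 185.0, 265.0, 185.0]],
}
```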
Mask Formats¶
COCO supports two mask formats:
1. RLE (Run-Length Encoding): compressed binary masks
2. Polygon: a list of [x, y] coordinates
AutoTimm automatically decodes both formats using pycocotools.
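To build intuition for what pycocotools does, uncompressed RLE stores alternating run lengths of background and foreground pixels in column-major (Fortran) order. A minimal decoder sketch (use pycocotools' optimized implementation in practice):

```python
import numpy as np


def decode_uncompressed_rle(rle):
    """Decode COCO uncompressed RLE ({"size": [h, w], "counts": [...]}) to a binary mask."""
    h, w = rle["size"]
    flat = np.zeros(h * w, dtype=np.uint8)
    pos, val = 0, 0  # runs alternate, starting with background (0)
    for count in rle["counts"]:
        flat[pos:pos + count] = val
        pos += count
        val = 1 - val
    return flat.reshape((h, w), order="F")  # COCO RLE is column-major
```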
Data Augmentation¶
Presets¶
# Light augmentation
data = InstanceSegmentationDataModule(
    data_dir="./coco",
    image_size=640,
    augmentation_preset="light",
)

# Default augmentation
data = InstanceSegmentationDataModule(
    data_dir="./coco",
    augmentation_preset="default",
)

# Strong augmentation
data = InstanceSegmentationDataModule(
    data_dir="./coco",
    augmentation_preset="strong",
)
Custom Transforms¶
import albumentations as A
from albumentations.pytorch import ToTensorV2
transforms = A.Compose([
    A.RandomScale(scale_limit=0.1),
    A.HorizontalFlip(p=0.5),
    A.ColorJitter(brightness=0.2, contrast=0.2),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
], bbox_params=A.BboxParams(format='coco', label_fields=['labels']))

data = InstanceSegmentationDataModule(
    data_dir="./coco",
    custom_train_transforms=transforms,
)
Note: Masks are automatically transformed alongside boxes.
Metrics¶
Mask mAP¶
Primary metric for instance segmentation:
MetricConfig(
    name="mask_mAP",
    backend="torchmetrics",
    metric_class="MeanAveragePrecision",
    params={"box_format": "xyxy", "iou_type": "segm"},
    stages=["val", "test"],
    prog_bar=True,
)
Bbox mAP¶
Also track detection performance:
MetricConfig(
    name="bbox_mAP",
    backend="torchmetrics",
    metric_class="MeanAveragePrecision",
    params={"box_format": "xyxy", "iou_type": "bbox"},
    stages=["val", "test"],
)
Inference¶
Predict on Images¶
import torch
from PIL import Image
from torchvision import transforms as T
# Load model
model = InstanceSegmentor.load_from_checkpoint("best_model.ckpt", compile_model=False)
model.eval()
# Preprocess image
image = Image.open("test.jpg").convert("RGB")
transform = T.Compose([
    T.Resize((640, 640)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
image_tensor = transform(image).unsqueeze(0)

# Predict
with torch.inference_mode():
    predictions = model.predict(image_tensor)
# predictions is a list of dicts, one per image:
# {
# 'boxes': [N, 4], # xyxy format
# 'labels': [N], # class indices
# 'scores': [N], # confidence scores
# 'masks': [N, H, W] # binary masks (0 or 1)
# }
for i, pred in enumerate(predictions):
    print(f"Image {i}: {len(pred['boxes'])} instances")
    for box, label, score, mask in zip(pred['boxes'], pred['labels'], pred['scores'], pred['masks']):
        print(f"  Class {label}, score {score:.3f}, mask size {mask.shape}")
Visualize Results¶
import matplotlib.pyplot as plt
import numpy as np
def visualize_instance_segmentation(image, prediction, threshold=0.5):
    """Visualize instance segmentation results."""
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    ax.imshow(image)

    boxes = prediction['boxes']
    labels = prediction['labels']
    scores = prediction['scores']
    masks = prediction['masks']

    # Filter by score threshold
    keep = scores > threshold
    boxes = boxes[keep]
    labels = labels[keep]
    scores = scores[keep]
    masks = masks[keep]

    # Draw boxes and masks
    for box, label, score, mask in zip(boxes, labels, scores, masks):
        x1, y1, x2, y2 = box

        # Draw box
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             fill=False, color='red', linewidth=2)
        ax.add_patch(rect)

        # Draw label
        ax.text(x1, y1 - 5, f"Class {label}: {score:.2f}",
                color='white', fontsize=10,
                bbox=dict(facecolor='red', alpha=0.5))

        # Overlay mask
        mask_overlay = np.zeros((*mask.shape, 4))
        mask_overlay[mask > 0.5] = [1, 0, 0, 0.4]  # Red with alpha
        ax.imshow(mask_overlay)

    plt.axis('off')
    plt.tight_layout()
    plt.show()

# Use it
with torch.inference_mode():
    predictions = model.predict(image_tensor)
visualize_instance_segmentation(image, predictions[0])
Advanced Options¶
Freeze Backbone¶
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    freeze_backbone=True,  # Only train detection + mask heads
)
Custom Optimizer¶
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    optimizer="adamw",
    lr=1e-4,
    weight_decay=1e-5,
)
Detection Parameters¶
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    score_thresh=0.05,              # Minimum confidence
    nms_thresh=0.5,                 # NMS IoU threshold
    max_detections_per_image=100,   # Max instances per image
)
torch.compile (PyTorch 2.0+)¶
Enabled by default for faster training and inference:
# Default: torch.compile enabled
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
)

# Disable if needed
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    compile_model=False,
)

# Custom compile options
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    compile_kwargs={"mode": "reduce-overhead"},
)
What gets compiled: Backbone + FPN + Detection Head + Mask Head
See ImageClassifier for compile mode details.
Best Practices¶
1. Choose Appropriate Image Size¶
Larger images capture more detail but require more memory:
# Balanced (default)
data = InstanceSegmentationDataModule(data_dir="./coco", image_size=640)
# High quality (more memory)
data = InstanceSegmentationDataModule(data_dir="./coco", image_size=800)
# Fast training (less accurate)
data = InstanceSegmentationDataModule(data_dir="./coco", image_size=512)
2. Adjust Mask Loss Weight¶
Balance detection and mask quality:
# Emphasize mask quality
model = InstanceSegmentor(backbone="resnet50", num_classes=80, mask_loss_weight=2.0)
# Emphasize detection
model = InstanceSegmentor(backbone="resnet50", num_classes=80, mask_loss_weight=0.5)
3. Use Appropriate Backbones¶
- ResNet-50/101: Good balance
- Swin Transformer: Best accuracy
- EfficientNet: Memory efficient
- ResNet-18: Fast training/inference
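Switching backbones is a one-line change: any timm model name that exposes multi-scale features for the FPN works (the Swin model name below is just an example choice):

```python
from autotimm import InstanceSegmentor

# Trade speed for accuracy by swapping the timm backbone.
model = InstanceSegmentor(
    backbone="swin_tiny_patch4_window7_224",  # any FPN-compatible timm model
    num_classes=80,
)
```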
4. Gradient Clipping¶
Instance segmentation training can suffer from unstable gradients, especially early on; enable gradient clipping in the trainer to keep updates bounded.
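A minimal sketch of enabling clipping via AutoTrainer, which caps the global gradient norm each optimization step:

```python
from autotimm import AutoTrainer

# Clip the global gradient norm to stabilize training.
trainer = AutoTrainer(
    max_epochs=12,
    gradient_clip_val=1.0,  # common starting point; lower it if losses still spike
)
```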
Example: COCO Training¶
Complete example for COCO instance segmentation:
from autotimm import InstanceSegmentor, InstanceSegmentationDataModule, MetricConfig, AutoTrainer
# Data
data = InstanceSegmentationDataModule(
    data_dir="./coco",
    image_size=640,
    batch_size=4,
    num_workers=4,
    augmentation_preset="default",
)

# Metrics
metrics = [
    MetricConfig(
        name="mask_mAP",
        backend="torchmetrics",
        metric_class="MeanAveragePrecision",
        params={"box_format": "xyxy", "iou_type": "segm"},
        stages=["val", "test"],
        prog_bar=True,
    ),
    MetricConfig(
        name="bbox_mAP",
        backend="torchmetrics",
        metric_class="MeanAveragePrecision",
        params={"box_format": "xyxy", "iou_type": "bbox"},
        stages=["val", "test"],
    ),
]

# Model
model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    fpn_channels=256,
    mask_size=28,
    mask_loss_weight=1.0,
    score_thresh=0.05,
    nms_thresh=0.5,
    metrics=metrics,
    lr=1e-4,
    weight_decay=1e-5,
    optimizer="adamw",
    scheduler="cosine",
)

# Trainer
trainer = AutoTrainer(
    max_epochs=12,
    accelerator="auto",
    devices=1,
    precision="16-mixed",
    gradient_clip_val=1.0,
)
# Train
trainer.fit(model, datamodule=data)
trainer.test(model, datamodule=data)
Troubleshooting¶
For instance segmentation issues, see the Troubleshooting Overview including:
- Out of memory
- Slow training
- Poor mask quality
- NaN loss issues
API Reference¶
See InstanceSegmentor API for complete parameter documentation.