Semantic Segmentation

AutoTimm supports semantic segmentation with DeepLabV3+ and FCN heads on top of timm backbones.

Overview

Semantic segmentation assigns a class label to every pixel in an image. AutoTimm supports:

  • Architectures: DeepLabV3+ (ASPP + decoder), FCN
  • Backbones: Any timm model with multi-scale features (ResNet, EfficientNet, etc.)
  • Losses: CrossEntropy, Dice, Focal, Combined (CE + Dice), Tversky
  • Datasets: PNG masks, COCO stuff, Cityscapes, Pascal VOC

Quick Start

from autotimm import SemanticSegmentor, SegmentationDataModule, MetricConfig, AutoTrainer

# Data
data = SegmentationDataModule(
    data_dir="./cityscapes",
    format="cityscapes",  # or "png", "coco", "voc"
    image_size=512,
    batch_size=8,
    augmentation_preset="default",
)

# Metrics
metrics = [
    MetricConfig(
        name="iou",
        backend="torchmetrics",
        metric_class="JaccardIndex",
        params={
            "task": "multiclass",
            "num_classes": 19,
            "average": "macro",
            "ignore_index": 255,
        },
        stages=["val", "test"],
        prog_bar=True,
    ),
]

# Model
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",  # or "fcn"
    loss_type="combined",        # CE + Dice
    dice_weight=1.0,
    metrics=metrics,
)

# Train
trainer = AutoTrainer(max_epochs=100)
trainer.fit(model, datamodule=data)

Architectures

DeepLabV3+

DeepLabV3+ combines ASPP (Atrous Spatial Pyramid Pooling) with a decoder for high-quality segmentation.

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=21,
    head_type="deeplabv3plus",
    # DeepLabV3+ uses both high-level (C5) and low-level (C2) features
)

Features:

  • ASPP module with multiple dilation rates (6, 12, 18)
  • Low-level feature fusion (C2 + ASPP output)
  • Decoder with 3x3 convolutions
  • Output stride: 4 (¼ of input resolution)
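
To make the dilation rates and output stride above concrete, here is a minimal arithmetic sketch. It assumes the ASPP branches use 3x3 kernels (the usual DeepLabV3+ choice, not confirmed for AutoTimm here); `effective_kernel_size` is a helper written for this illustration, not part of the library.

```python
def effective_kernel_size(kernel_size: int, dilation: int) -> int:
    """Span in pixels covered by a single dilated convolution kernel."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)

# Dilation rates listed above; 3x3 kernels are an assumption of this sketch.
aspp_rates = (6, 12, 18)
spans = {rate: effective_kernel_size(3, rate) for rate in aspp_rates}
print(spans)  # {6: 13, 12: 25, 18: 37}

# At output stride 4, a 512x512 input yields a 128x128 map.
input_size = 512
output_stride = 4
decoder_size = input_size // output_stride
print(decoder_size)  # 128
```

Each branch therefore samples context at a different scale on the same feature map without adding parameters beyond a regular 3x3 convolution.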

FCN (Fully Convolutional Network)

A simpler baseline architecture for comparison.

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=21,
    head_type="fcn",
    # FCN uses only the highest-level feature (C5)
)

Features:

  • Single high-level feature processing
  • Lightweight and fast
  • Good baseline for simple datasets

Loss Functions

Cross-Entropy

Standard pixel-wise classification loss.

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_type="ce",
    ignore_index=255,  # Ignore unlabeled pixels
)

Dice Loss

Overlap-based loss that handles class imbalance well.

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_type="dice",
)

Formula: 1 - (2 * |X ∩ Y|) / (|X| + |Y|), where X is the set of predicted pixels and Y the set of ground-truth pixels for a class.
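
The formula can be checked by hand on a tiny binary mask. The sketch below is a plain-Python illustration, not AutoTimm's implementation; the `smooth` term is an assumption added to avoid division by zero when both masks are empty.

```python
def dice_loss(pred, target, smooth=1e-6):
    """1 - Dice coefficient for two flat binary masks (lists of 0/1)."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)

# 8 pixels: 3 predicted foreground, 3 true foreground, 2 in common.
pred = [1, 1, 1, 0, 0, 0, 0, 0]
target = [1, 1, 0, 1, 0, 0, 0, 0]

loss = dice_loss(pred, target)
print(round(loss, 3))  # 0.333  (Dice = 2*2 / (3+3) = 0.667)
```

Because the score depends only on the foreground overlap, the five correctly predicted background pixels do not dilute the loss, which is why Dice copes better with small classes than a per-pixel average.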

Combined Loss

Combines CE for pixel-wise accuracy with Dice for region overlap; ce_weight and dice_weight scale the two terms.

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_type="combined",
    ce_weight=1.0,
    dice_weight=1.0,
)
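
As a rough sketch of how the two weights interact, the example below assumes the combined loss is a plain weighted sum of the two terms (the exact formulation inside AutoTimm is not shown on this page). It uses a binary foreground/background case in plain Python; all function names are illustrative.

```python
import math

def cross_entropy(probs_true):
    """Mean negative log-likelihood of the correct class, one entry per pixel."""
    return -sum(math.log(p) for p in probs_true) / len(probs_true)

def soft_dice_loss(fg_probs, target, eps=1e-6):
    """Dice loss computed on foreground probabilities instead of hard masks."""
    intersection = sum(p * t for p, t in zip(fg_probs, target))
    total = sum(fg_probs) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (total + eps)

def combined_loss(fg_probs, target, ce_weight=1.0, dice_weight=1.0):
    """Assumed form: ce_weight * CE + dice_weight * Dice."""
    probs_true = [p if t == 1 else 1.0 - p for p, t in zip(fg_probs, target)]
    ce = cross_entropy(probs_true)
    dice = soft_dice_loss(fg_probs, target)
    return ce_weight * ce + dice_weight * dice

fg_probs = [0.9, 0.8, 0.3, 0.1]  # predicted foreground probability per pixel
target = [1, 1, 0, 0]            # ground-truth labels

loss = combined_loss(fg_probs, target)
ce_only = combined_loss(fg_probs, target, dice_weight=0.0)
print(f"combined={loss:.3f}, ce_only={ce_only:.3f}")
```

Setting one weight to zero recovers the other loss on its own, so the weights can be tuned without switching loss_type.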

Focal Loss

Handles severe class imbalance by down-weighting easy examples.

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_type="focal",
)
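
The down-weighting can be seen from the standard focal term, FL(p_t) = -(1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The sketch below assumes gamma = 2, a common choice; the defaults AutoTimm uses are not stated on this page.

```python
import math

def focal_term(p_t, gamma=2.0):
    """Focal loss for one pixel, given the probability of its true class."""
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

# Compare focal loss with plain cross-entropy for easy and hard pixels.
for p_t in (0.9, 0.5, 0.1):
    ce = -math.log(p_t)
    fl = focal_term(p_t)
    print(f"p_t={p_t}: CE={ce:.3f}, focal={fl:.3f}, focal/CE={fl / ce:.2f}")
```

The ratio focal/CE equals (1 - p_t)^gamma, so a confidently correct pixel (p_t = 0.9) keeps about 1% of its cross-entropy while a badly misclassified one (p_t = 0.1) keeps about 81%.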

Tversky Loss

Generalization of Dice with separate control over false positives and false negatives.

from autotimm.loss import TverskyLoss

# Use directly (not via loss_type parameter)
loss_fn = TverskyLoss(
    num_classes=19,
    alpha=0.3,  # Weight for false positives
    beta=0.7,   # Weight for false negatives
)
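
To illustrate what alpha and beta control, the sketch below evaluates the standard Tversky index, TP / (TP + alpha * FP + beta * FN), on hypothetical pixel counts. It is a plain-Python illustration and does not call `TverskyLoss`; the counts are made up for the example.

```python
def tversky_index(tp, fp, fn, alpha=0.3, beta=0.7):
    """Tversky index from true-positive, false-positive, false-negative counts."""
    return tp / (tp + alpha * fp + beta * fn)

# Two predictions with the same number of errors (20) but of different kinds.
fp_heavy = tversky_index(tp=80, fp=20, fn=0)
fn_heavy = tversky_index(tp=80, fp=0, fn=20)
print(f"loss with 20 FP: {1 - fp_heavy:.3f}")  # 0.070
print(f"loss with 20 FN: {1 - fn_heavy:.3f}")  # 0.149

# With alpha = beta = 0.5 the index reduces to the Dice coefficient.
dice_equivalent = tversky_index(tp=80, fp=20, fn=20, alpha=0.5, beta=0.5)
print(round(dice_equivalent, 3))  # 0.8
```

With beta > alpha, missed pixels cost more than spurious ones, which pushes the model toward higher recall on small structures.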

Datasets

PNG Format

Simple image + mask pairs.

data/
  train/
    images/
      img001.jpg
      img002.jpg
    masks/
      img001.png
      img002.png
  val/
    images/
    masks/

data = SegmentationDataModule(
    data_dir="./data",
    format="png",
    image_size=512,
    batch_size=8,
)
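
Before training on this layout, it can help to verify that every image has a mask. The sketch below assumes images and masks are matched by file stem (img001.jpg with img001.png), as the tree above suggests; `find_unpaired` is a helper written for this example, not an AutoTimm function, and the demo builds a throwaway directory so it runs anywhere.

```python
import tempfile
from pathlib import Path

def find_unpaired(split_dir):
    """Return (stems missing a mask, stems missing an image) for one split."""
    split = Path(split_dir)
    image_stems = {p.stem for p in (split / "images").iterdir() if p.is_file()}
    mask_stems = {p.stem for p in (split / "masks").glob("*.png")}
    return sorted(image_stems - mask_stems), sorted(mask_stems - image_stems)

# Demo on a temporary directory that mimics data/train/ with one mask missing.
with tempfile.TemporaryDirectory() as tmp:
    train = Path(tmp) / "train"
    (train / "images").mkdir(parents=True)
    (train / "masks").mkdir()
    for name in ("img001.jpg", "img002.jpg"):
        (train / "images" / name).touch()
    (train / "masks" / "img001.png").touch()

    missing_masks, missing_images = find_unpaired(train)

print(missing_masks, missing_images)  # ['img002'] []
```

For a real dataset, call `find_unpaired("./data/train")` and `find_unpaired("./data/val")` and fix any reported stems before training.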

Cityscapes

data = SegmentationDataModule(
    data_dir="./cityscapes",
    format="cityscapes",
    image_size=512,
    batch_size=8,
)

Expected structure:

cityscapes/
  leftImg8bit/
    train/
    val/
  gtFine/
    train/
    val/

Pascal VOC

data = SegmentationDataModule(
    data_dir="./VOC2012",
    format="voc",
    image_size=512,
    batch_size=8,
)

COCO Stuff

data = SegmentationDataModule(
    data_dir="./coco",
    format="coco",
    image_size=512,
    batch_size=8,
)

Data Augmentation

Presets

# Light augmentation
data = SegmentationDataModule(
    data_dir="./data",
    format="png",
    image_size=512,
    augmentation_preset="light",
)

# Default augmentation
data = SegmentationDataModule(
    data_dir="./data",
    format="png",
    augmentation_preset="default",
)

# Strong augmentation
data = SegmentationDataModule(
    data_dir="./data",
    format="png",
    augmentation_preset="strong",
)

Custom Transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

transforms = A.Compose([
    A.RandomScale(scale_limit=0.5),
    A.RandomCrop(height=512, width=512),
    A.HorizontalFlip(p=0.5),
    A.ColorJitter(brightness=0.4, contrast=0.4),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

data = SegmentationDataModule(
    data_dir="./data",
    format="png",
    custom_train_transforms=transforms,
)

Metrics

IoU (Jaccard Index)

MetricConfig(
    name="iou",
    backend="torchmetrics",
    metric_class="JaccardIndex",
    params={
        "task": "multiclass",
        "num_classes": 19,
        "average": "macro",  # or "micro", "weighted", None
        "ignore_index": 255,
    },
    stages=["val", "test"],
    prog_bar=True,
)

Per-Class IoU

MetricConfig(
    name="iou_per_class",
    backend="torchmetrics",
    metric_class="JaccardIndex",
    params={
        "task": "multiclass",
        "num_classes": 19,
        "average": None,  # Returns per-class scores
    },
    stages=["val"],
)
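
To show what the "macro" and per-class settings compute, here is a plain-Python sketch of IoU with an ignore index. It is an illustration of the definition, not the torchmetrics implementation; in particular, skipping classes that never appear when averaging is an assumption of this sketch.

```python
def per_class_iou(pred, target, num_classes, ignore_index=255):
    """IoU per class over flat label lists; None for classes that never occur."""
    ious = []
    for c in range(num_classes):
        intersection = union = 0
        for p, t in zip(pred, target):
            if t == ignore_index:
                continue  # unlabeled pixel: contributes to nothing
            if p == c and t == c:
                intersection += 1
            if p == c or t == c:
                union += 1
        ious.append(intersection / union if union else None)
    return ious

def macro_iou(ious):
    """Unweighted mean over the classes that have a defined IoU."""
    present = [v for v in ious if v is not None]
    return sum(present) / len(present)

pred = [0, 0, 1, 1, 2, 2, 0, 1]
target = [0, 0, 1, 2, 2, 2, 255, 255]  # last two pixels are unlabeled

ious = per_class_iou(pred, target, num_classes=3)
print([round(v, 3) for v in ious])  # [1.0, 0.5, 0.667]
print(round(macro_iou(ious), 3))    # 0.722
```

Per-class scores show which class drags the mean down (class 1 here), which a single macro number hides.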

Pixel Accuracy

MetricConfig(
    name="accuracy",
    backend="torchmetrics",
    metric_class="Accuracy",
    params={
        "task": "multiclass",
        "num_classes": 19,
        "ignore_index": 255,
    },
    stages=["val", "test"],
)
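
Pixel accuracy is best read alongside IoU, because it can look good on imbalanced data. The sketch below uses a hypothetical 100-pixel image where the model predicts only background; both functions are plain-Python illustrations of the definitions, not the torchmetrics classes.

```python
def pixel_accuracy(pred, target, ignore_index=255):
    """Fraction of labeled pixels whose predicted class is correct."""
    valid = [(p, t) for p, t in zip(pred, target) if t != ignore_index]
    return sum(p == t for p, t in valid) / len(valid)

def mean_iou(pred, target, num_classes, ignore_index=255):
    """Macro-averaged IoU over classes that occur in pred or target."""
    valid = [(p, t) for p, t in zip(pred, target) if t != ignore_index]
    ious = []
    for c in range(num_classes):
        intersection = sum(p == c and t == c for p, t in valid)
        union = sum(p == c or t == c for p, t in valid)
        if union:
            ious.append(intersection / union)
    return sum(ious) / len(ious)

# 95 background pixels, 5 foreground pixels; the model predicts all background.
target = [0] * 95 + [1] * 5
pred = [0] * 100

acc = pixel_accuracy(pred, target)
miou = mean_iou(pred, target, num_classes=2)
print(f"pixel accuracy={acc:.3f}, mIoU={miou:.3f}")  # pixel accuracy=0.950, mIoU=0.475
```

The model misses the foreground class entirely yet scores 95% accuracy, while mIoU drops to 0.475 because the foreground IoU is zero.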

Inference

Predict on Images

import torch
from PIL import Image
from torchvision import transforms

from autotimm import SemanticSegmentor

# Load model
model = SemanticSegmentor.load_from_checkpoint("best_model.ckpt", compile_model=False)
model.eval()

# Load and preprocess image
image = Image.open("test.jpg").convert("RGB")  # Ensure 3 channels for Normalize
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
image_tensor = transform(image).unsqueeze(0)

# Predict
with torch.inference_mode():
    predictions = model.predict(image_tensor)

# predictions shape: [1, H, W] with class indices

Batch Prediction

images = torch.randn(4, 3, 512, 512)  # Batch of 4 images

with torch.inference_mode():
    predictions = model.predict(images)
    # Shape: [4, H, W]

Get Probabilities

with torch.inference_mode():
    logits = model.predict(images, return_logits=True)
    probs = torch.softmax(logits, dim=1)
    # Shape: [B, num_classes, H, W]
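
What the softmax over the class dimension does for a single pixel can be sketched in plain Python. The logit values below are made up for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of class logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

pixel_logits = [2.0, 0.5, -1.0]  # hypothetical scores for 3 classes at one pixel
probs = softmax(pixel_logits)
predicted_class = max(range(len(probs)), key=probs.__getitem__)

print([round(p, 3) for p in probs])  # [0.786, 0.175, 0.039]
print(predicted_class)               # 0
```

Softmax preserves the ordering of the logits, so the most likely class is the same as the argmax of the raw logits; probabilities are only needed when you want a confidence value per pixel.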

Advanced Options

Freeze Backbone

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    freeze_backbone=True,  # Only train the head
)

Custom Optimizer

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    optimizer="adamw",
    lr=1e-4,
    weight_decay=1e-5,
)

Learning Rate Scheduler

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    scheduler="cosine",
    scheduler_kwargs={
        "T_max": 100,  # epochs
    },
)
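
For intuition about T_max, the sketch below evaluates the closed-form cosine annealing curve. It assumes scheduler="cosine" behaves like PyTorch's CosineAnnealingLR with a minimum learning rate of 0 and one step per epoch; how AutoTimm actually wires the scheduler is not shown on this page.

```python
import math

def cosine_lr(epoch, base_lr=1e-4, t_max=100, eta_min=0.0):
    """Learning rate at a given epoch under cosine annealing."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)

for epoch in (0, 25, 50, 75, 100):
    print(f"epoch {epoch:3d}: lr = {cosine_lr(epoch):.2e}")
```

The rate starts at the base value, passes half of it at T_max / 2, and reaches the minimum at T_max, so T_max is usually set to the total number of training epochs.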

Class Weights

For imbalanced datasets:

import torch

class_weights = torch.tensor([1.0, 2.5, 1.5, ...])  # One weight per class; "..." stands for the remaining entries

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    loss_type="ce",
    # Note: Pass via loss initialization, not directly to model
)
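
One common way to choose the weights is inverse class frequency. The sketch below is a generic recipe in plain Python, not an AutoTimm utility; the pixel counts are hypothetical, and the function assumes every class has at least one pixel.

```python
def inverse_frequency_weights(pixel_counts):
    """Weight each class by total / (num_classes * count)."""
    total = sum(pixel_counts)
    num_classes = len(pixel_counts)
    return [total / (num_classes * count) for count in pixel_counts]

# Hypothetical pixel counts for a 3-class dataset.
pixel_counts = [700_000, 200_000, 100_000]
weights = inverse_frequency_weights(pixel_counts)
print([round(w, 3) for w in weights])  # [0.476, 1.667, 3.333]
```

The resulting list can be wrapped in `torch.tensor(weights)` to obtain a tensor like `class_weights` above. Very rare classes can receive extreme weights, so clipping or taking a square root is sometimes applied.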

torch.compile (PyTorch 2.0+)

Enabled by default for faster training and inference:

# Default: torch.compile enabled
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
)

# Disable if needed
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    compile_model=False,
)

# Custom compile options
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    compile_kwargs={"mode": "reduce-overhead"},
)

What gets compiled: Backbone + Segmentation Head

See ImageClassifier for compile mode details.

Best Practices

1. Choose the Right Architecture

  • DeepLabV3+: Best quality, slower training
  • FCN: Faster, good for simple datasets or baselines

2. Select Appropriate Loss

  • Combined (CE + Dice): Best overall performance
  • Dice only: Good for class imbalance
  • CE only: Fast, works well when classes are balanced
  • Focal: Severe class imbalance

3. Handle Unlabeled Pixels

Set ignore_index to the label value that marks unlabeled pixels (255 in Cityscapes and Pascal VOC):

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    ignore_index=255,
)
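
The effect of ignore_index on a cross-entropy loss can be sketched as follows: ignored pixels are dropped before averaging, so whatever the model predicts there has no influence. This plain-Python version mirrors the usual mean-over-valid-pixels behavior; it is an illustration, not AutoTimm's code.

```python
import math

def masked_cross_entropy(probs, target, ignore_index=255):
    """Mean CE over labeled pixels; probs[i][c] is P(class c) at pixel i."""
    losses = [
        -math.log(pixel_probs[label])
        for pixel_probs, label in zip(probs, target)
        if label != ignore_index
    ]
    return sum(losses) / len(losses)

probs = [
    [0.8, 0.2],  # pixel 0, labeled class 0
    [0.3, 0.7],  # pixel 1, labeled class 1
    [0.5, 0.5],  # pixel 2, unlabeled
]
target = [0, 1, 255]

with_ignored = masked_cross_entropy(probs, target)
labeled_only = masked_cross_entropy(probs[:2], target[:2])
print(with_ignored == labeled_only)  # True
```

Without the mask, the label value 255 would be treated as a class index, which is out of range for a 19-class model.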

4. Use Appropriate Backbones

  • ResNet-50/101: Good balance of speed and accuracy
  • EfficientNet: More efficient for similar accuracy
  • MobileNet: Fastest, lower accuracy
  • ConvNeXt/Swin: Highest accuracy, slower

5. Image Size Considerations

  • Larger images generally improve quality but train more slowly and use more memory
  • Typical sizes: 512x512 (balanced), 768x768 (high quality), 1024x1024 (best quality)

data = SegmentationDataModule(
    data_dir="./data",
    format="png",
    image_size=512,  # Adjust based on your needs
)
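
As a rough guide to the cost of a larger image_size, the pixel count grows with the square of the side length. The sketch below assumes activation memory and compute scale roughly with pixel count, which is a reasonable first approximation for convolutional backbones but not a measured figure.

```python
def relative_pixels(size, baseline=512):
    """Pixel count of a square image relative to a baseline side length."""
    return (size / baseline) ** 2

for size in (384, 512, 768, 1024):
    print(f"{size}x{size}: {relative_pixels(size):.2f}x the pixels of 512x512")
```

Going from 512 to 1024 quadruples the pixel count, so the batch size often has to drop by about the same factor to fit in the same memory.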

Example: Cityscapes Training

Complete example for Cityscapes dataset:

from autotimm import SemanticSegmentor, SegmentationDataModule, MetricConfig, AutoTrainer

# Data
data = SegmentationDataModule(
    data_dir="./cityscapes",
    format="cityscapes",
    image_size=512,
    batch_size=8,
    num_workers=4,
    augmentation_preset="default",
)

# Metrics
metrics = [
    MetricConfig(
        name="mIoU",
        backend="torchmetrics",
        metric_class="JaccardIndex",
        params={
            "task": "multiclass",
            "num_classes": 19,
            "average": "macro",
            "ignore_index": 255,
        },
        stages=["val", "test"],
        prog_bar=True,
    ),
    MetricConfig(
        name="pixel_acc",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={
            "task": "multiclass",
            "num_classes": 19,
            "ignore_index": 255,
        },
        stages=["val", "test"],
    ),
]

# Model
model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    loss_type="combined",
    ce_weight=1.0,
    dice_weight=1.0,
    ignore_index=255,
    metrics=metrics,
    lr=1e-4,
    weight_decay=1e-5,
    optimizer="adamw",
    scheduler="cosine",
)

# Trainer
trainer = AutoTrainer(
    max_epochs=200,
    accelerator="auto",
    devices=1,
    precision="16-mixed",  # Mixed precision training
)

# Train
trainer.fit(model, datamodule=data)

Troubleshooting

For semantic segmentation training issues, see the Troubleshooting Overview, which covers:

  • Out of memory
  • Poor segmentation results
  • Loss function selection
  • Performance optimization

Out of Memory

Reduce batch size or image size:

data = SegmentationDataModule(
    data_dir="./data",
    format="png",
    image_size=384,  # Smaller
    batch_size=4,    # Smaller
)

Slow Training

  • Use smaller backbone (e.g., ResNet-18, MobileNet)
  • Use FCN instead of DeepLabV3+
  • Enable mixed precision training
  • Reduce image size

Poor Accuracy

  • Try combined loss instead of CE only
  • Increase training epochs
  • Use stronger augmentation
  • Try larger backbone
  • Check for class imbalance

Class Imbalance

  • Use Dice or Focal loss
  • Add class weights
  • Use stronger augmentation for minority classes

API Reference

See SemanticSegmentor API for complete parameter documentation.