Albumentations vs TorchVision Transforms: Complete Code Guide
Overview
This guide compares two popular image augmentation libraries for PyTorch:
- TorchVision Transforms: Built-in PyTorch library for basic image transformations
- Albumentations: Fast, flexible library with advanced augmentation techniques
Installation
# TorchVision (comes with PyTorch)
pip install torch torchvision
# Albumentations
pip install albumentations
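Albumentations has renamed or deprecated a number of arguments across major releases, and the snippets in this guide follow the long-stable 1.x argument names, so it helps to know exactly which versions you are running. A quick check:

import albumentations
import torch
import torchvision

# Print installed versions to know which API variants apply
print(torch.__version__, torchvision.__version__, albumentations.__version__)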
Basic Setup and Imports

import torch
import torchvision.transforms as T
from torchvision.transforms import functional as TF
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
Key Differences at a Glance
| Feature | TorchVision | Albumentations |
| --- | --- | --- |
| Input Format | PIL Image, Tensor | NumPy array (OpenCV format) |
| Performance | Moderate | Fast (optimized) |
| Augmentation Variety | Basic to intermediate | Extensive advanced options |
| Bounding Box Support | Limited | Excellent |
| Segmentation Masks | Basic | Advanced |
| Keypoint Support | No | Yes |
| Probability Control | Limited | Fine-grained |
Basic Transformations Comparison
Image Loading and Format Differences
# TorchVision approach
def load_image_torchvision(path):
    return Image.open(path).convert('RGB')

# Albumentations approach
def load_image_albumentations(path):
    image = cv2.imread(path)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Example usage
img_path = "cat.jpg"
torch_img = load_image_torchvision(img_path)
albu_img = load_image_albumentations(img_path)
Basic Augmentations
# TorchVision transforms
torchvision_transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=15),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Albumentations equivalent
albumentations_transform = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=1.0),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1.0),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Apply transforms
torch_result = torchvision_transform(torch_img)
albu_result = albumentations_transform(image=albu_img)['image']
Advanced Augmentations
Albumentations Exclusive Features
# Advanced geometric transformations
advanced_geometric = A.Compose([
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.8),
    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
    A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
    A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=0.3)
])

# Weather and lighting effects
weather_effects = A.Compose([
    A.RandomRain(slant_lower=-10, slant_upper=10, drop_length=20, p=0.3),
    A.RandomSnow(snow_point_lower=0.1, snow_point_upper=0.3, p=0.3),
    A.RandomFog(fog_coef_lower=0.3, fog_coef_upper=1, p=0.3),
    A.RandomSunFlare(flare_roi=(0, 0, 1, 0.5), angle_lower=0, p=0.3)
])

# Noise and blur effects
noise_blur = A.Compose([
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
    A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.3),
    A.MotionBlur(blur_limit=7, p=0.3),
    A.MedianBlur(blur_limit=7, p=0.3),
    A.Blur(blur_limit=7, p=0.3)
])
TorchVision v2 Enhanced Features
import torchvision.transforms.v2 as T2

# TorchVision v2 with improved functionality
torchvision_v2_transform = T2.Compose([
    T2.Resize((224, 224)),
    T2.RandomHorizontalFlip(p=0.5),
    T2.RandomChoice([
        T2.ColorJitter(brightness=0.3),
        T2.ColorJitter(contrast=0.3),
        T2.ColorJitter(saturation=0.3)
    ]),
    T2.RandomApply([T2.GaussianBlur(kernel_size=3)], p=0.3),
    T2.ToTensor(),
    T2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
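Note that recent torchvision releases deprecate T2.ToTensor; if you are on torchvision >= 0.16, the v2 API prefers an explicit ToImage + ToDtype pair for the conversion step. A minimal sketch:

# Modern replacement for T2.ToTensor() (assumes torchvision >= 0.16)
modern_conversion = T2.Compose([
    T2.ToImage(),                          # PIL image / ndarray -> tv_tensors.Image
    T2.ToDtype(torch.float32, scale=True)  # uint8 [0, 255] -> float32 [0, 1]
])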
Working with Bounding Boxes
Albumentations (Excellent Support)
# Define bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max)
bboxes = [[50, 50, 150, 150], [200, 100, 300, 200]]
class_labels = ['person', 'car']

bbox_transform = A.Compose([
    A.Resize(416, 416),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

# Apply transform (labels travel separately because of label_fields)
transformed = bbox_transform(
    image=image,
    bboxes=bboxes,
    class_labels=class_labels
)

transformed_image = transformed['image']
transformed_bboxes = transformed['bboxes']
transformed_labels = transformed['class_labels']
TorchVision (Limited Support)
# TorchVision v2 has some bbox support
import torchvision.transforms.v2 as T2

bbox_torchvision = T2.Compose([
    T2.Resize((416, 416)),
    T2.RandomHorizontalFlip(p=0.5),
    T2.ToTensor()
])

# Requires manual handling of bounding boxes
# Less intuitive than Albumentations
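That said, torchvision 0.16+ can propagate boxes through v2 transforms automatically if you wrap them as tv_tensors.BoundingBoxes. A minimal sketch, assuming torchvision >= 0.16 (the canvas_size values are placeholders for your image's height and width):

from torchvision import tv_tensors

v2_bbox_transform = T2.Compose([
    T2.Resize((416, 416)),
    T2.RandomHorizontalFlip(p=0.5),
])

# Wrap boxes so the v2 transforms know to move them with the image
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[50, 50, 150, 150], [200, 100, 300, 200]]),
    format='XYXY',
    canvas_size=(480, 640)  # (height, width) of the source image -- placeholder
)
out_image, out_boxes = v2_bbox_transform(image, boxes)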
Working with Segmentation Masks
Albumentations
# Segmentation mask handling
segmentation_transform = A.Compose([
    A.Resize(512, 512),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
    A.Normalize(),
    ToTensorV2()
])

# Apply to image and mask simultaneously
result = segmentation_transform(image=image, mask=mask)
transformed_image = result['image']
transformed_mask = result['mask']
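Several masks can also ride along in a single call via the masks argument; they come back under the 'masks' key. A short sketch (the mask names here are hypothetical):

result = segmentation_transform(image=image, masks=[semantic_mask, instance_mask])
transformed_masks = result['masks']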
TorchVision
# TorchVision requires separate handling
def apply_transform_to_mask(transform, image, mask):
    # Manual synchronization needed: re-seeding only works if both
    # pipelines draw random numbers in exactly the same order
    seed = torch.randint(0, 2**32, size=(1,)).item()

    torch.manual_seed(seed)
    transformed_image = transform(image)

    torch.manual_seed(seed)
    # Apply only geometric transforms to the mask
    mask_transform = T.Compose([
        T.Resize((512, 512)),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor()
    ])
    transformed_mask = mask_transform(mask)

    return transformed_image, transformed_mask
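The seeding trick above is fragile: it silently breaks if the two pipelines consume random numbers in a different order. A sturdier pattern is to draw each random decision once and apply the corresponding functional ops to image and mask together; a minimal sketch:

def paired_transform(image, mask):
    # Resize both; nearest-neighbor keeps mask labels intact
    image = TF.resize(image, [512, 512])
    mask = TF.resize(mask, [512, 512], interpolation=T.InterpolationMode.NEAREST)

    # Draw the flip decision once, apply it to both
    if torch.rand(1) < 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)

    return TF.to_tensor(image), TF.to_tensor(mask)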
Performance Comparison
import time

def benchmark_transforms(image, iterations=1000):
    # TorchVision timing
    start_time = time.time()
    for _ in range(iterations):
        _ = torchvision_transform(image.copy())
    torch_time = time.time() - start_time

    # Convert to numpy for Albumentations
    np_image = np.array(image)

    # Albumentations timing
    start_time = time.time()
    for _ in range(iterations):
        _ = albumentations_transform(image=np_image.copy())
    albu_time = time.time() - start_time

    print(f"TorchVision: {torch_time:.3f}s ({iterations} iterations)")
    print(f"Albumentations: {albu_time:.3f}s ({iterations} iterations)")
    print(f"Speedup: {torch_time/albu_time:.2f}x")

# Run benchmark
benchmark_transforms(torch_img)

Example output (timings are hardware-dependent):

TorchVision: 38.434s (1000 iterations)
Albumentations: 3.152s (1000 iterations)
Speedup: 12.19x

Much of the gap comes from Albumentations running OpenCV-backed operations directly on NumPy arrays, while the classic TorchVision pipeline round-trips every image through PIL.
Custom Pipeline Examples
Data Science Pipeline with Albumentations
def create_training_pipeline():
    return A.Compose([
        # Geometric transformations
        A.OneOf([
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30),
            A.ElasticTransform(alpha=1, sigma=50),
            A.GridDistortion(num_steps=5, distort_limit=0.3),
        ], p=0.5),

        # Color augmentations
        A.OneOf([
            A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        ], p=0.8),

        # Noise and blur
        A.OneOf([
            A.GaussNoise(var_limit=(10, 50)),
            A.ISONoise(),
            A.MultiplicativeNoise(),
        ], p=0.3),
        A.OneOf([
            A.MotionBlur(blur_limit=5),
            A.MedianBlur(blur_limit=5),
            A.GaussianBlur(blur_limit=5),
        ], p=0.3),

        # Final processing
        A.Resize(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def create_validation_pipeline():
    return A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])
TorchVision Pipeline for Simple Cases
def create_simple_training_pipeline():
    return T.Compose([
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(degrees=15),
        T.RandomApply([T.ColorJitter(0.3, 0.3, 0.3, 0.1)], p=0.8),
        T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.3),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

def create_simple_validation_pipeline():
    return T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
Dataset Integration
PyTorch Dataset with Albumentations
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image for Albumentations (OpenCV format)
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']

        return image, self.labels[idx]

# Usage
train_dataset = CustomDataset(
    image_paths=train_paths,
    labels=train_labels,
    transform=create_training_pipeline()
)
PyTorch Dataset with TorchVision
class TorchVisionDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image for TorchVision (PIL format)
        image = Image.open(self.image_paths[idx]).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx]

# Usage
train_dataset = TorchVisionDataset(
    image_paths=train_paths,
    labels=train_labels,
    transform=create_simple_training_pipeline()
)
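Either dataset drops straight into a standard DataLoader; the batch size and worker count below are illustrative, not recommendations:

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

for images, labels in train_loader:
    pass  # training step goes here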
When to Use Which Library
Choose Albumentations When:
- Working with object detection or segmentation tasks
- Need advanced augmentation techniques (weather effects, distortions)
- Performance is critical (processing large datasets)
- Working with bounding boxes or keypoints (see the keypoint sketch after this list)
- Need fine-grained control over augmentation probabilities
- Dealing with medical or satellite imagery
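On the keypoint point: Albumentations handles keypoints through the same Compose mechanism as bounding boxes. A minimal sketch (the coordinates are placeholders):

keypoint_transform = A.Compose([
    A.Resize(416, 416),
    A.HorizontalFlip(p=0.5),
], keypoint_params=A.KeypointParams(format='xy'))

result = keypoint_transform(image=image, keypoints=[(120, 80), (200, 150)])
transformed_keypoints = result['keypoints']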
Choose TorchVision When:
- Building simple image classification models
- Working within pure PyTorch ecosystem
- Need basic augmentations only
- Prototyping quickly
- Following PyTorch tutorials or established workflows
- Working with pre-trained models that expect specific preprocessing (see the sketch below)
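On that last point: since torchvision 0.13, each pretrained weight enum exposes the exact preprocessing it was trained with, which removes the guesswork (ResNet-50 here is just an example):

from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()  # the preprocessing these weights expect

model_input = preprocess(torch_img)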
Best Practices
Albumentations Best Practices
# Use ReplayCompose for debugging
replay_transform = A.ReplayCompose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
])

result = replay_transform(image=image)
transformed_image = result['image']
replay_data = result['replay']

# Apply the same recorded transforms to another image
result2 = A.ReplayCompose.replay(replay_data, image=another_image)

# Efficient bbox handling
bbox_params = A.BboxParams(
    format='pascal_voc',
    min_area=1024,       # Filter out small boxes
    min_visibility=0.3,  # Filter out mostly occluded boxes
    label_fields=['class_labels']
)
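Albumentations pipelines can also be serialized and restored with A.save and A.load, which helps keep training runs reproducible (the file path here is arbitrary):

A.save(replay_transform, '/tmp/transform.json')
restored_transform = A.load('/tmp/transform.json')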
TorchVision Best Practices
# Use functional API for custom control
def custom_transform(image):
    if torch.rand(1) < 0.5:
        image = TF.hflip(image)

    # Apply rotation with custom logic
    angle = torch.randint(-30, 30, (1,)).item()
    image = TF.rotate(image, angle)

    return image

# Combine with standard transforms
combined_transform = T.Compose([
    T.Lambda(custom_transform),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
Conclusion
Both libraries have their strengths:
Albumentations excels in:
- Advanced augmentation techniques
- Performance optimization
- Computer vision tasks beyond classification
- Professional production environments
TorchVision is ideal for:
- Simple classification tasks
- Learning and prototyping
- Tight PyTorch integration
- Basic augmentation needs
Choose based on your specific requirements, with Albumentations being the go-to choice for advanced computer vision projects and TorchVision for simpler classification tasks.