# Albumentations vs TorchVision Transforms: Complete Code Guide

## Overview
This guide compares two popular image augmentation libraries for PyTorch:
- TorchVision Transforms: Built-in PyTorch library for basic image transformations
- Albumentations: Fast, flexible library with advanced augmentation techniques
## Installation

```bash
# TorchVision (ships with PyTorch)
pip install torch torchvision

# Albumentations
pip install albumentations
```

## Basic Setup and Imports

```python
import torch
import torchvision.transforms as T
from torchvision.transforms import functional as TF
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
```
## Key Differences at a Glance
| Feature | TorchVision | Albumentations |
|---|---|---|
| Input Format | PIL Image, Tensor | NumPy array (OpenCV format) |
| Performance | Moderate | Fast (optimized) |
| Augmentation Variety | Basic to intermediate | Extensive advanced options |
| Bounding Box Support | Limited | Excellent |
| Segmentation Masks | Basic | Advanced |
| Keypoint Support | No | Yes |
| Probability Control | Limited | Fine-grained |
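The input-format difference also shows up in the calling conventions. A minimal sketch, using the imports from the setup section and a dummy image (the image values are illustrative, not from this guide):

```python
# Dummy image in both formats
img_np = np.zeros((224, 224, 3), dtype=np.uint8)  # NumPy/OpenCV-style HxWxC array
img_pil = Image.fromarray(img_np)                 # PIL image

# TorchVision: called positionally, returns the transformed image directly
tv_out = T.Compose([T.ToTensor()])(img_pil)

# Albumentations: called with keyword arguments, returns a dict of results
alb_out = A.Compose([ToTensorV2()])(image=img_np)['image']
```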
## Basic Transformations Comparison

### Image Loading and Format Differences

```python
# TorchVision approach
def load_image_torchvision(path):
    return Image.open(path).convert('RGB')

# Albumentations approach
def load_image_albumentations(path):
    image = cv2.imread(path)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Example usage
img_path = "cat.jpg"
torch_img = load_image_torchvision(img_path)
albu_img = load_image_albumentations(img_path)
```
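A quick check of the two in-memory formats (assuming `cat.jpg` exists on disk):

```python
print(type(torch_img))                 # <class 'PIL.Image.Image'>
print(type(albu_img), albu_img.shape)  # <class 'numpy.ndarray'> (H, W, 3)
print(albu_img.dtype)                  # uint8, RGB after the BGR->RGB conversion
```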
### Basic Augmentations

```python
# TorchVision transforms
torchvision_transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=15),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Albumentations equivalent
albumentations_transform = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=1.0),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1.0),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Apply transforms
torch_result = torchvision_transform(torch_img)
albu_result = albumentations_transform(image=albu_img)['image']
```
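Both pipelines should now produce equivalent outputs: a normalized float tensor of shape (3, 224, 224). A quick sanity check:

```python
print(torch_result.shape, torch_result.dtype)  # torch.Size([3, 224, 224]) torch.float32
print(albu_result.shape, albu_result.dtype)    # torch.Size([3, 224, 224]) torch.float32
```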
## Advanced Augmentations

### Albumentations Exclusive Features

The parameter names below follow the Albumentations 1.x API; several of them (e.g. `alpha_affine` and the `*_lower`/`*_upper` pairs) have been renamed or removed in more recent releases, so check against the version you have installed.

```python
# Advanced geometric transformations
advanced_geometric = A.Compose([
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.8),
    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
    A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
    A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=0.3)
])

# Weather and lighting effects
weather_effects = A.Compose([
    A.RandomRain(slant_lower=-10, slant_upper=10, drop_length=20, p=0.3),
    A.RandomSnow(snow_point_lower=0.1, snow_point_upper=0.3, p=0.3),
    A.RandomFog(fog_coef_lower=0.3, fog_coef_upper=1, p=0.3),
    A.RandomSunFlare(flare_roi=(0, 0, 1, 0.5), angle_lower=0, p=0.3)
])

# Noise and blur effects
noise_blur = A.Compose([
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
    A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.3),
    A.MotionBlur(blur_limit=7, p=0.3),
    A.MedianBlur(blur_limit=7, p=0.3),
    A.Blur(blur_limit=7, p=0.3)
])
```
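Applying any of these pipelines follows the same keyword-argument convention as before, e.g. with the image loaded earlier:

```python
stormy = weather_effects(image=albu_img)['image']
distorted = advanced_geometric(image=albu_img)['image']
```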
### TorchVision v2 Enhanced Features

```python
import torchvision.transforms.v2 as T2

# TorchVision v2 with improved functionality
torchvision_v2_transform = T2.Compose([
    T2.Resize((224, 224)),
    T2.RandomHorizontalFlip(p=0.5),
    T2.RandomChoice([
        T2.ColorJitter(brightness=0.3),
        T2.ColorJitter(contrast=0.3),
        T2.ColorJitter(saturation=0.3)
    ]),
    T2.RandomApply([T2.GaussianBlur(kernel_size=3)], p=0.3),
    T2.ToImage(),                            # replaces the deprecated T2.ToTensor()
    T2.ToDtype(torch.float32, scale=True),
    T2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```
## Working with Bounding Boxes

### Albumentations (Excellent Support)

```python
# Bounding boxes in Pascal VOC format: (x_min, y_min, x_max, y_max)
bboxes = [[50, 50, 150, 150], [200, 100, 300, 200]]
class_labels = ['person', 'car']

bbox_transform = A.Compose([
    A.Resize(416, 416),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

# Apply transform; boxes and labels are updated together with the image
transformed = bbox_transform(
    image=albu_img,
    bboxes=bboxes,
    class_labels=class_labels
)
transformed_image = transformed['image']
transformed_bboxes = transformed['bboxes']
transformed_labels = transformed['class_labels']
```
### TorchVision (Limited Support)

```python
# TorchVision v2 has some bbox support
import torchvision.transforms.v2 as T2

bbox_torchvision = T2.Compose([
    T2.Resize((416, 416)),
    T2.RandomHorizontalFlip(p=0.5),
    T2.ToImage(),
    T2.ToDtype(torch.float32, scale=True)
])
# Requires manual handling of bounding boxes (see the sketch below);
# less intuitive than Albumentations
```
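A minimal sketch of what that manual handling looks like with torchvision's `tv_tensors` wrappers (available in torchvision >= 0.16; the image and box values here are illustrative):

```python
import torch
import torchvision.transforms.v2 as T2
from torchvision import tv_tensors

# Dummy 3x500x500 image tensor and one box in XYXY format
img = torch.zeros(3, 500, 500, dtype=torch.uint8)
boxes = tv_tensors.BoundingBoxes(
    [[50, 50, 150, 150]],
    format="XYXY",
    canvas_size=(500, 500),
)

transform = T2.Compose([
    T2.Resize((416, 416)),
    T2.RandomHorizontalFlip(p=0.5),
])

# v2 transforms accept (and update) boxes alongside the image
out_img, out_boxes = transform(img, boxes)
```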
## Working with Segmentation Masks

### Albumentations

```python
# Segmentation mask handling; image and mask are NumPy arrays of matching size
segmentation_transform = A.Compose([
    A.Resize(512, 512),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
    A.Normalize(),
    ToTensorV2()
])

# Apply to image and mask simultaneously; only spatial transforms
# (resize, flip, shift/scale/rotate) are applied to the mask
result = segmentation_transform(image=image, mask=mask)
transformed_image = result['image']
transformed_mask = result['mask']
```
### TorchVision

```python
# TorchVision requires manual synchronization; the functional API makes
# each random decision once and applies it to both image and mask
def apply_transform_to_mask(image, mask):
    # Assumes PIL inputs
    image = TF.resize(image, [512, 512])
    # Nearest-neighbor interpolation keeps mask labels intact
    mask = TF.resize(mask, [512, 512], interpolation=T.InterpolationMode.NEAREST)
    if torch.rand(1) < 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
    # pil_to_tensor preserves integer mask values (no scaling)
    return TF.to_tensor(image), TF.pil_to_tensor(mask)
```
## Performance Comparison

```python
import time

def benchmark_transforms(image, iterations=1000):
    # TorchVision timing
    start_time = time.time()
    for _ in range(iterations):
        _ = torchvision_transform(image.copy())
    torch_time = time.time() - start_time

    # Convert to NumPy for Albumentations
    np_image = np.array(image)

    # Albumentations timing
    start_time = time.time()
    for _ in range(iterations):
        _ = albumentations_transform(image=np_image.copy())
    albu_time = time.time() - start_time

    print(f"TorchVision: {torch_time:.3f}s ({iterations} iterations)")
    print(f"Albumentations: {albu_time:.3f}s ({iterations} iterations)")
    print(f"Speedup: {torch_time/albu_time:.2f}x")

# Run benchmark
benchmark_transforms(torch_img)
```

Example output:

```
TorchVision: 37.706s (1000 iterations)
Albumentations: 2.810s (1000 iterations)
Speedup: 13.42x
```
## Custom Pipeline Examples

### Data Science Pipeline with Albumentations

```python
def create_training_pipeline():
    return A.Compose([
        # Geometric transformations
        A.OneOf([
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30),
            A.ElasticTransform(alpha=1, sigma=50),
            A.GridDistortion(num_steps=5, distort_limit=0.3),
        ], p=0.5),
        # Color augmentations
        A.OneOf([
            A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        ], p=0.8),
        # Noise and blur
        A.OneOf([
            A.GaussNoise(var_limit=(10, 50)),
            A.ISONoise(),
            A.MultiplicativeNoise(),
        ], p=0.3),
        A.OneOf([
            A.MotionBlur(blur_limit=5),
            A.MedianBlur(blur_limit=5),
            A.GaussianBlur(blur_limit=5),
        ], p=0.3),
        # Final processing
        A.Resize(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def create_validation_pipeline():
    return A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])
```
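Both pipelines are invoked the same way; only the training one applies random augmentations:

```python
train_tf = create_training_pipeline()
val_tf = create_validation_pipeline()

train_tensor = train_tf(image=albu_img)['image']  # randomly augmented
val_tensor = val_tf(image=albu_img)['image']      # deterministic preprocessing
```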
### TorchVision Pipeline for Simple Cases

```python
def create_simple_training_pipeline():
    return T.Compose([
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(degrees=15),
        T.RandomApply([T.ColorJitter(0.3, 0.3, 0.3, 0.1)], p=0.8),
        T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.3),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

def create_simple_validation_pipeline():
    return T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
```
## Dataset Integration

### PyTorch Dataset with Albumentations

```python
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image for Albumentations (OpenCV format)
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']
        return image, self.labels[idx]

# Usage
train_dataset = CustomDataset(
    image_paths=train_paths,
    labels=train_labels,
    transform=create_training_pipeline()
)
```
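Either dataset then plugs into a standard `DataLoader` (the batch size and worker count below are illustrative):

```python
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

for images, labels in train_loader:
    # images: float tensor batch of shape (batch_size, 3, 224, 224)
    break
```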
### PyTorch Dataset with TorchVision

```python
class TorchVisionDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image for TorchVision (PIL format)
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

# Usage
train_dataset = TorchVisionDataset(
    image_paths=train_paths,
    labels=train_labels,
    transform=create_simple_training_pipeline()
)
```
## When to Use Which Library

### Choose Albumentations When:
- Working with object detection or segmentation tasks
- Need advanced augmentation techniques (weather effects, distortions)
- Performance is critical (processing large datasets)
- Working with bounding boxes or keypoints
- Need fine-grained control over augmentation probabilities
- Dealing with medical or satellite imagery
### Choose TorchVision When:
- Building simple image classification models
- Working within pure PyTorch ecosystem
- Need basic augmentations only
- Prototyping quickly
- Following PyTorch tutorials or established workflows
- Working with pre-trained models that expect specific preprocessing
## Best Practices

### Albumentations Best Practices

```python
# Use ReplayCompose for debugging: it records the random parameters used
replay_transform = A.ReplayCompose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
])

result = replay_transform(image=image)
transformed_image = result['image']
replay_data = result['replay']

# Apply the same transforms to another image
result2 = A.ReplayCompose.replay(replay_data, image=another_image)

# Efficient bbox handling
bbox_params = A.BboxParams(
    format='pascal_voc',
    min_area=1024,       # Filter out small boxes
    min_visibility=0.3,  # Filter out mostly occluded boxes
    label_fields=['class_labels']
)
```
### TorchVision Best Practices

```python
# Use the functional API for custom control
def custom_transform(image):
    if torch.rand(1) < 0.5:
        image = TF.hflip(image)
    # Apply rotation with custom logic
    angle = torch.randint(-30, 30, (1,)).item()
    image = TF.rotate(image, angle)
    return image

# Combine with standard transforms
combined_transform = T.Compose([
    T.Lambda(custom_transform),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```
## Conclusion

Both libraries have their strengths:
**Albumentations** excels in:
- Advanced augmentation techniques
- Performance optimization
- Computer vision tasks beyond classification
- Professional production environments
**TorchVision** is ideal for:
- Simple classification tasks
- Learning and prototyping
- Tight PyTorch integration
- Basic augmentation needs
Choose based on your specific requirements, with Albumentations being the go-to choice for advanced computer vision projects and TorchVision for simpler classification tasks.