PyTorch Collate Function Speed-Up Guide

Categories: code, tutorial, advanced

Author: Krishnatheja Vanka

Published: June 1, 2025

Introduction

The collate function in PyTorch is a crucial component for optimizing data loading performance. It determines how individual samples are combined into batches, and custom implementations can significantly speed up training by reducing data preprocessing overhead and memory operations.
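
At its simplest, a collate function receives a list of samples (whatever the dataset's __getitem__ returns) and must return a batch. A minimal sketch of that contract, roughly what the default collate does for (tensor, label) pairs:

import torch

def simple_collate(batch):
    """batch is a list of (tensor, label) samples from the dataset."""
    data, labels = zip(*batch)                      # split the list of pairs
    return torch.stack(data), torch.tensor(labels)  # combine into batched tensors

# Passed to a DataLoader via its collate_fn argument:
# loader = DataLoader(dataset, batch_size=32, collate_fn=simple_collate)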

Default vs Custom Collate Functions

Default Behavior

import torch
from torch.utils.data import DataLoader, Dataset
import time
import numpy as np

class SimpleDataset(Dataset):
    def __init__(self, size=1000):
        self.data = [torch.randn(3, 224, 224) for _ in range(size)]
        self.labels = [torch.randint(0, 10, (1,)).item() for _ in range(size)]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Using default collate function
dataset = SimpleDataset(1000)
default_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Timing default collate
start_time = time.time()
for batch_idx, (data, labels) in enumerate(default_loader):
    if batch_idx >= 10:  # Test first 10 batches
        break
default_time = time.time() - start_time
print(f"Default collate time: {default_time:.4f} seconds")
Default collate time: 0.0129 seconds

Custom Collate Function - Basic Optimization

def fast_collate(batch):
    """Optimized collate function for image data"""
    # Separate data and labels
    data, labels = zip(*batch)
    
    # Stack tensors directly, avoiding default_collate's recursive type-dispatch overhead
    data_tensor = torch.stack(data, dim=0)
    labels_tensor = torch.tensor(labels, dtype=torch.long)
    
    return data_tensor, labels_tensor

# Using custom collate function
custom_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=fast_collate)

# Timing custom collate
start_time = time.time()
for batch_idx, (data, labels) in enumerate(custom_loader):
    if batch_idx >= 10:
        break
custom_time = time.time() - start_time
print(f"Custom collate time: {custom_time:.4f} seconds")
print(f"Speed improvement: {(default_time/custom_time - 1) * 100:.1f}%")
Custom collate time: 0.0107 seconds
Speed improvement: 20.6%

Advanced Optimizations

Memory-Efficient Collate for Variable-Length Sequences

import torch.nn.utils.rnn as rnn_utils

class VariableLengthDataset(Dataset):
    def __init__(self, size=1000):
        # Simulate variable-length sequences
        self.data = [torch.randn(np.random.randint(10, 100), 128) for _ in range(size)]
        self.labels = [torch.randint(0, 5, (1,)).item() for _ in range(size)]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

def efficient_variable_collate(batch):
    """Efficient collate for variable-length sequences"""
    data, labels = zip(*batch)
    
    # Get sequence lengths for efficient packing
    lengths = torch.tensor([len(seq) for seq in data])
    
    # Pad sequences efficiently
    padded_data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
    labels_tensor = torch.tensor(labels, dtype=torch.long)
    
    return padded_data, labels_tensor, lengths

# Performance comparison
var_dataset = VariableLengthDataset(500)

# The default collate fails on variable-length sequences, so we compare against a naive manual-padding baseline
def naive_variable_collate(batch):
    data, labels = zip(*batch)
    max_len = max(len(seq) for seq in data)
    
    # Inefficient padding
    padded_data = []
    for seq in data:
        if len(seq) < max_len:
            padded_seq = torch.cat([seq, torch.zeros(max_len - len(seq), seq.size(1))])
        else:
            padded_seq = seq
        padded_data.append(padded_seq)
    
    return torch.stack(padded_data), torch.tensor(labels, dtype=torch.long)

# Timing comparison
naive_loader = DataLoader(var_dataset, batch_size=16, collate_fn=naive_variable_collate)
efficient_loader = DataLoader(var_dataset, batch_size=16, collate_fn=efficient_variable_collate)

# Naive approach timing
start_time = time.time()
for batch_idx, batch in enumerate(naive_loader):
    if batch_idx >= 10:
        break
naive_time = time.time() - start_time

# Efficient approach timing
start_time = time.time()
for batch_idx, batch in enumerate(efficient_loader):
    if batch_idx >= 10:
        break
efficient_time = time.time() - start_time

print(f"Naive variable collate time: {naive_time:.4f} seconds")
print(f"Efficient variable collate time: {efficient_time:.4f} seconds")
print(f"Speed improvement: {(naive_time/efficient_time - 1) * 100:.1f}%")
Naive variable collate time: 0.0015 seconds
Efficient variable collate time: 0.0013 seconds
Speed improvement: 14.5%
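
The lengths tensor returned by efficient_variable_collate is what makes the padding cheap to undo downstream. A brief sketch of how it would typically be consumed with pack_padded_sequence ahead of an RNN; the LSTM here is only illustrative:

lstm = torch.nn.LSTM(input_size=128, hidden_size=64, batch_first=True)

padded_data, labels, lengths = next(iter(efficient_loader))

# pack_padded_sequence lets the RNN skip the padded timesteps entirely
packed = rnn_utils.pack_padded_sequence(padded_data, lengths, batch_first=True, enforce_sorted=False)
packed_output, (h_n, c_n) = lstm(packed)

# Unpack to a padded tensor again if later layers expect one
output, output_lengths = rnn_utils.pad_packed_sequence(packed_output, batch_first=True)
print(output.shape)  # (batch, longest sequence in batch, 64)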

GPU-Accelerated Collate Function

def gpu_accelerated_collate(batch, device='cuda'):
    """Collate function that moves data to GPU during batching"""
    if not torch.cuda.is_available():
        device = 'cpu'
    
    data, labels = zip(*batch)
    
    # Stack on the CPU, then copy to the target device
    # (non_blocking=True only overlaps the copy when the source tensor is in pinned memory)
    data_tensor = torch.stack(data, dim=0).to(device, non_blocking=True)
    labels_tensor = torch.tensor(labels, dtype=torch.long).to(device, non_blocking=True)
    
    return data_tensor, labels_tensor

# Performance comparison with GPU transfer
if torch.cuda.is_available():
    device = 'cuda'
    
    # Standard approach: CPU collate + GPU transfer
    standard_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    
    # GPU-accelerated collate
    gpu_loader = DataLoader(dataset, batch_size=32, shuffle=True, 
                           collate_fn=lambda batch: gpu_accelerated_collate(batch, device))
    
    # Timing standard approach
    start_time = time.time()
    for batch_idx, (data, labels) in enumerate(standard_loader):
        data, labels = data.to(device), labels.to(device)
        if batch_idx >= 10:
            break
    standard_gpu_time = time.time() - start_time
    
    # Timing GPU-accelerated collate
    start_time = time.time()
    for batch_idx, (data, labels) in enumerate(gpu_loader):
        if batch_idx >= 10:
            break
    gpu_collate_time = time.time() - start_time
    
    print(f"Standard CPU->GPU time: {standard_gpu_time:.4f} seconds")
    print(f"GPU-accelerated collate time: {gpu_collate_time:.4f} seconds")
    print(f"Speed improvement: {(standard_gpu_time/gpu_collate_time - 1) * 100:.1f}%")

Memory-Mapped File Collate for Large Datasets

class MemoryMappedDataset(Dataset):
    """Dataset over array-like storage (plain numpy arrays here, np.memmap in practice)"""
    def __init__(self, data_array, labels_array):
        self.data = data_array
        self.labels = labels_array
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # Copy the slice so the returned tensor owns its memory instead of
        # referencing the backing array (important when that array is a memmap)
        return torch.from_numpy(self.data[idx].copy()), self.labels[idx]

def zero_copy_collate(batch):
    """Zero-copy collate function for numpy arrays"""
    data, labels = zip(*batch)
    
    # Use torch.from_numpy for zero-copy conversion when possible
    try:
        # Stack numpy arrays first, then convert to tensor
        data_array = np.stack([d.numpy() if isinstance(d, torch.Tensor) else d for d in data])
        data_tensor = torch.from_numpy(data_array)
    except Exception:
        # Fall back to regular stacking
        data_tensor = torch.stack(data)
    
    labels_tensor = torch.tensor(labels, dtype=torch.long)
    return data_tensor, labels_tensor

# Create sample data for demonstration
sample_data = np.random.randn(1000, 3, 224, 224).astype(np.float32)
sample_labels = np.random.randint(0, 10, 1000)

mmap_dataset = MemoryMappedDataset(sample_data, sample_labels)
mmap_loader = DataLoader(mmap_dataset, batch_size=32, collate_fn=zero_copy_collate)
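
The demo above keeps the arrays in RAM for simplicity. For data that genuinely does not fit in memory, the same Dataset and collate function can sit on top of numpy memory-mapped files; a sketch, with the file names being illustrative:

# Write the demo arrays to disk once so they can be memory-mapped
np.save('features.npy', sample_data)
np.save('labels.npy', sample_labels)

# mmap_mode='r' maps the files instead of loading them; pages are read lazily on access
mmap_data = np.load('features.npy', mmap_mode='r')
mmap_labels = np.load('labels.npy', mmap_mode='r')

disk_dataset = MemoryMappedDataset(mmap_data, mmap_labels)
disk_loader = DataLoader(disk_dataset, batch_size=32, collate_fn=zero_copy_collate)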

Specialized Collate Functions

Multi-Modal Data Collate

class MultiModalDataset(Dataset):
    def __init__(self, size=1000):
        self.images = [torch.randn(3, 224, 224) for _ in range(size)]
        self.text = [torch.randint(0, 1000, (np.random.randint(5, 50),)) for _ in range(size)]
        self.labels = [torch.randint(0, 10, (1,)).item() for _ in range(size)]
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            'image': self.images[idx],
            'text': self.text[idx],
            'label': self.labels[idx]
        }

def multimodal_collate(batch):
    """Efficient collate for multi-modal data"""
    # Separate different modalities
    images = [item['image'] for item in batch]
    texts = [item['text'] for item in batch]
    labels = [item['label'] for item in batch]
    
    # Batch images
    image_batch = torch.stack(images)
    
    # Batch variable-length text with padding
    text_lengths = torch.tensor([len(text) for text in texts])
    text_batch = rnn_utils.pad_sequence(texts, batch_first=True, padding_value=0)
    
    # Batch labels
    label_batch = torch.tensor(labels, dtype=torch.long)
    
    return {
        'images': image_batch,
        'texts': text_batch,
        'text_lengths': text_lengths,
        'labels': label_batch
    }

multimodal_dataset = MultiModalDataset(500)
multimodal_loader = DataLoader(multimodal_dataset, batch_size=16, collate_fn=multimodal_collate)

# Test the multimodal loader
sample_batch = next(iter(multimodal_loader))
print("Multimodal batch shapes:")
for key, value in sample_batch.items():
    print(f"  {key}: {value.shape}")

Augmentation-Aware Collate

import torchvision.transforms as transforms

def augmentation_collate(batch, transform=None):
    """Collate function that applies augmentations during batching"""
    data, labels = zip(*batch)
    
    if transform:
        # Apply augmentations during collation
        augmented_data = [transform(img) for img in data]
        data_tensor = torch.stack(augmented_data)
    else:
        data_tensor = torch.stack(data)
    
    labels_tensor = torch.tensor(labels, dtype=torch.long)
    return data_tensor, labels_tensor

# Define augmentation pipeline
augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2)
])

# Create collate function with augmentation
aug_collate_fn = lambda batch: augmentation_collate(batch, augment_transform)

aug_loader = DataLoader(dataset, batch_size=32, collate_fn=aug_collate_fn)
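
One practical note: the lambda above cannot be pickled, so it breaks as soon as the DataLoader uses num_workers > 0 with the spawn start method (the default on Windows and macOS). functools.partial is the usual picklable alternative; the worker count below is arbitrary:

from functools import partial

# partial objects pickle cleanly as long as the wrapped function is module-level
aug_collate_fn = partial(augmentation_collate, transform=augment_transform)

aug_loader = DataLoader(dataset, batch_size=32, collate_fn=aug_collate_fn,
                        num_workers=2)  # augmentation now runs in the worker processes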

Performance Tips and Best Practices

1. Minimize Data Copying

def efficient_collate_tips(batch):
    """Demonstrates efficient collate practices"""
    data, labels = zip(*batch)
    
    # TIP 1: Use torch.stack instead of torch.cat when possible
    # torch.stack is faster for same-sized tensors
    data_tensor = torch.stack(data, dim=0)  # Faster
    # data_tensor = torch.cat([d.unsqueeze(0) for d in data], dim=0)  # Slower
    
    # TIP 2: Use appropriate dtypes to save memory
    labels_tensor = torch.tensor(labels, dtype=torch.long)  # Use long for indices
    
    # TIP 3: Pre-allocate tensors when size is known
    # This is more relevant for complex batching scenarios
    
    return data_tensor, labels_tensor

2. Memory Usage Optimization

def memory_efficient_collate(batch):
    """Memory-efficient collate function"""
    data, labels = zip(*batch)
    
    # Pre-allocate output tensor to avoid multiple allocations
    batch_size = len(data)
    data_shape = data[0].shape
    
    # Allocate output tensor once
    output_tensor = torch.empty((batch_size,) + data_shape, dtype=data[0].dtype)
    
    # Fill the tensor in-place
    for i, tensor in enumerate(data):
        output_tensor[i] = tensor
    
    labels_tensor = torch.tensor(labels, dtype=torch.long)
    return output_tensor, labels_tensor

3. Benchmarking Your Collate Functions

def benchmark_collate_functions():
    """Comprehensive benchmarking of different collate approaches"""
    dataset = SimpleDataset(1000)
    batch_size = 32
    num_batches = 20
    
    collate_functions = {
        'default': None,
        'fast_collate': fast_collate,
        'efficient_tips': efficient_collate_tips,
        'memory_efficient': memory_efficient_collate
    }
    
    results = {}
    
    for name, collate_fn in collate_functions.items():
        loader = DataLoader(dataset, batch_size=batch_size, 
                          collate_fn=collate_fn, shuffle=True)
        
        start_time = time.time()
        for batch_idx, (data, labels) in enumerate(loader):
            if batch_idx >= num_batches:
                break
        
        elapsed_time = time.time() - start_time
        results[name] = elapsed_time
        print(f"{name}: {elapsed_time:.4f} seconds")
    
    # Calculate improvements
    baseline = results['default']
    for name, time_taken in results.items():
        if name != 'default':
            improvement = (baseline / time_taken - 1) * 100
            print(f"{name} improvement: {improvement:.1f}%")

# Run the benchmark
benchmark_collate_functions()
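
The wall-clock loop above is fine for rough comparisons, but it also times dataset indexing and shuffling. To isolate the collate call itself, torch.utils.benchmark gives more stable numbers than raw time.time(); a sketch measuring fast_collate on a fixed list of samples:

import torch.utils.benchmark as benchmark

# A fixed batch of raw samples so that only the collate call is measured
samples = [dataset[i] for i in range(32)]

timer = benchmark.Timer(
    stmt="fast_collate(samples)",
    globals={"fast_collate": fast_collate, "samples": samples},
)
measurement = timer.timeit(100)  # 100 timed runs
print(f"fast_collate: {measurement.mean * 1e6:.1f} us per call")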

Key Takeaways

  1. Use torch.stack() instead of torch.cat() for same-sized tensors
  2. Minimize data copying by working with tensor views when possible
  3. Pre-allocate tensors when batch sizes and shapes are known
  4. Consider GPU transfer during collation for better pipeline efficiency
  5. Use appropriate data types to optimize memory usage
  6. Profile your specific use case as optimal strategies vary by data type and size
  7. Leverage specialized functions like pad_sequence for variable-length data

Custom collate functions can provide significant performance improvements, especially for large datasets or complex data structures. The key is to minimize unnecessary data operations and memory allocations while taking advantage of PyTorch’s optimized tensor operations.