import time

import numpy as np
import torch
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import Dataset, DataLoader

class VariableLengthDataset(Dataset):
    def __init__(self, size=1000):
        # Simulate variable-length sequences: each sample is (seq_len, 128) with seq_len in [10, 100)
        self.data = [torch.randn(np.random.randint(10, 100), 128) for _ in range(size)]
        self.labels = [torch.randint(0, 5, (1,)).item() for _ in range(size)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
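
# Quick sanity check (an illustrative sketch, not part of the original listing): each sample has a
# different first dimension, which is exactly what breaks DataLoader's default collate_fn.
# `_demo_ds` is a throwaway name used only for this check.
_demo_ds = VariableLengthDataset(4)
for i in range(4):
    seq, label = _demo_ds[i]
    print(f"sample {i}: shape={tuple(seq.shape)}, label={label}")
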
def efficient_variable_collate(batch):
    """Efficient collate for variable-length sequences."""
    data, labels = zip(*batch)
    # Record sequence lengths so downstream code can mask or pack the padded batch
    lengths = torch.tensor([len(seq) for seq in data])
    # Pad all sequences to the batch maximum in a single call
    padded_data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
    labels_tensor = torch.tensor(labels, dtype=torch.long)
    return padded_data, labels_tensor, lengths
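
# Illustrative sketch (not from the original listing): the `lengths` tensor returned above can be
# passed to pack_padded_sequence so an RNN skips the padded timesteps. The LSTM below is a
# hypothetical downstream model, assuming 128-dim inputs as in the dataset above.
demo_lstm = torch.nn.LSTM(input_size=128, hidden_size=64, batch_first=True)
demo_batch = [(torch.randn(n, 128), 0) for n in (12, 30, 7)]  # toy variable-length batch
padded, _, lens = efficient_variable_collate(demo_batch)
packed = rnn_utils.pack_padded_sequence(padded, lens, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = demo_lstm(packed)
unpacked, out_lens = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)
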
# Performance comparison
var_dataset = VariableLengthDataset(500)

# The default collate_fn errors out on sequences of different lengths, so the baseline here is a
# naive collate that pads each sequence individually in a Python loop
def naive_variable_collate(batch):
    data, labels = zip(*batch)
    max_len = max(len(seq) for seq in data)
    # Inefficient padding: build each padded sequence one at a time
    padded_data = []
    for seq in data:
        if len(seq) < max_len:
            padded_seq = torch.cat([seq, torch.zeros(max_len - len(seq), seq.size(1))])
        else:
            padded_seq = seq
        padded_data.append(padded_seq)
    return torch.stack(padded_data), torch.tensor(labels, dtype=torch.long)
# Timing comparison
naive_loader = DataLoader(var_dataset, batch_size=16, collate_fn=naive_variable_collate)
efficient_loader = DataLoader(var_dataset, batch_size=16, collate_fn=efficient_variable_collate)
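
# Sanity check (an illustrative addition, not in the original): both collate functions should produce
# identical padded tensors for the same batch, since each pads with zeros up to the batch maximum.
_check_batch = [var_dataset[i] for i in range(4)]
assert torch.equal(naive_variable_collate(_check_batch)[0],
                   efficient_variable_collate(_check_batch)[0])
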
# Naive approach timing
start_time = time.time()
for batch_idx, batch in enumerate(naive_loader):
    if batch_idx >= 10:
        break
naive_time = time.time() - start_time

# Efficient approach timing
start_time = time.time()
for batch_idx, batch in enumerate(efficient_loader):
    if batch_idx >= 10:
        break
efficient_time = time.time() - start_time

print(f"Naive variable collate time: {naive_time:.4f} seconds")
print(f"Efficient variable collate time: {efficient_time:.4f} seconds")
print(f"Speed improvement: {(naive_time/efficient_time - 1) * 100:.1f}%")