Akshay’s Gradient
ML Coding · intermediate · 45 min

Efficient Data Loader

Exercise: Build an Efficient Data Loader

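Implement a PyTorch-style data loader with batching, shuffling, and a custom collate function. Data loading matters for training efficiency: if the next batch is not ready when the GPU finishes the previous step, the GPU sits idle, so the input pipeline rather than GPU compute can become the bottleneck.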

Problem Statement

Implement:

  1. SimpleDataset -- a map-style dataset class with __len__ and __getitem__
  2. DataLoader -- an iterable that yields batches, with optional shuffling and a custom collate function
  3. collate_fn -- a function that pads variable-length sequences to the same length within a batch, builds the attention mask, and converts the labels to a tensor

Inputs: A list of (text_tokens, label) pairs where text_tokens are variable-length lists of integers.

Outputs: Batches of (padded_tokens, attention_mask, labels) tensors.

Key Concept

A data loader is responsible for: (1) sampling indices (sequential or shuffled), (2) grouping indices into batches, (3) fetching data for each index, and (4) collating individual samples into a batch tensor. The collate step is where variable-length sequences get padded, enabling batched GPU computation.
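The four stages can be written as separate, composable steps. The following is a minimal pure-Python sketch, assuming a map-style dataset that supports len() and integer indexing. The function names (sample_indices, batch_indices, load_batches) are hypothetical; they only loosely mirror the Sampler / BatchSampler split in PyTorch.

```python
import random
from typing import Callable, Iterator, List, Optional, Sequence


def sample_indices(n: int, shuffle: bool, rng: random.Random) -> List[int]:
    """Stage 1: choose the visiting order (sequential or shuffled)."""
    indices = list(range(n))
    if shuffle:
        rng.shuffle(indices)
    return indices


def batch_indices(indices: List[int], batch_size: int, drop_last: bool) -> Iterator[List[int]]:
    """Stage 2: cut the index order into consecutive chunks of batch_size."""
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        if drop_last and len(chunk) < batch_size:
            return  # only the final chunk can be short
        yield chunk


def load_batches(dataset: Sequence, batch_size: int, collate: Callable,
                 shuffle: bool = False, drop_last: bool = False,
                 rng: Optional[random.Random] = None) -> Iterator:
    """Stages 3 and 4: fetch each sample by index, then collate the list of samples."""
    if rng is None:
        rng = random.Random()
    order = sample_indices(len(dataset), shuffle, rng)
    for chunk in batch_indices(order, batch_size, drop_last):
        samples = [dataset[i] for i in chunk]  # stage 3: fetch
        yield collate(samples)                 # stage 4: collate


# Usage with a placeholder collate (list) that leaves the samples unpadded.
data = [([1, 5, 3], 0), ([4, 7], 1), ([9], 0), ([2, 8, 6, 1], 1), ([3], 0)]
batches = list(load_batches(data, batch_size=2, collate=list))
print([len(b) for b in batches])  # → [2, 2, 1]
```

Keeping the stages separate makes it easy to swap one without touching the others, for example replacing stage 1 with a length-aware sampler while reusing the same collate.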

Data Loader Pipeline: Shuffle, Batch, Pad, Collate

  Dataset (10 samples, variable length):
      0: [1,5,3,8,2]   label=0
      1: [4,7,2]       label=1
      2: [9,1,6,3]     label=0
      3: [2,8]         label=1
      ...
          │
          ▼ Shuffle indices
  indices = [7, 2, 5, 0, 9, 3, 1, 8, 4, 6]
          │
          ▼ Batch (batch_size=3)
  batch 0: [7, 2, 5]
  batch 1: [0, 9, 3]
  batch 2: [1, 8, 4]
  batch 3: [6]  (last batch, smaller)
          │
          ▼ Fetch + Collate (dynamic padding)
  Batch 0 samples (indices 7, 2, 5):
      [8,5,2,1]         len=4
      [9,1,6,3]         len=4
      [1,2,3,4,5,6,7]   len=7

  Padded (max_len=7):
      [8,5,2,1,0,0,0]   mask=[1,1,1,1,0,0,0]
      [9,1,6,3,0,0,0]   mask=[1,1,1,1,0,0,0]
      [1,2,3,4,5,6,7]   mask=[1,1,1,1,1,1,1]
Warning

A common mistake in custom collate functions is padding to the global maximum sequence length instead of the batch maximum. If the longest sequence in the entire dataset is 512 tokens but the current batch only has sequences of length 10-20, padding to 512 processes about 25x as many positions per sequence as padding to the batch maximum of 20 (512 / 20 ≈ 25.6), and nearly all of those extra positions are padding. Always compute max_len per batch (dynamic padding). Another pitfall: forgetting to create the attention mask, which lets the model attend to padding tokens.
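To make the cost concrete, the sketch below counts the positions processed for one batch under both strategies. The lengths [10, 12, 15, 20] and the 512-token global maximum are illustrative values matching the example above, not measurements, and the helper padded_positions is hypothetical.

```python
from typing import List


def padded_positions(lengths: List[int], pad_to: int) -> int:
    """Total token positions processed when every sequence is padded to pad_to."""
    return len(lengths) * pad_to


lengths = [10, 12, 15, 20]   # hypothetical batch with sequences of length 10-20
real_tokens = sum(lengths)   # tokens that carry information

static = padded_positions(lengths, pad_to=512)            # global max: 4 * 512
dynamic = padded_positions(lengths, pad_to=max(lengths))  # batch max:  4 * 20

print(static, dynamic, round(static / dynamic, 1))  # → 2048 80 25.6
print(f"padding fraction, static:  {1 - real_tokens / static:.1%}")
print(f"padding fraction, dynamic: {1 - real_tokens / dynamic:.1%}")
```

Dynamic padding still wastes some positions when lengths within a batch differ a lot; sorting or bucketing by length reduces that remainder.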

Hints

  1. SimpleDataset.__getitem__ should return a single (tokens, label) pair.
  2. The DataLoader yields batches by slicing shuffled/sequential indices into chunks of batch_size.
  3. The collate function receives a list of samples and must: find the max sequence length, pad all sequences to that length, create an attention mask (1 for real tokens, 0 for padding), and stack labels.
  4. Use torch.nn.utils.rnn.pad_sequence or implement padding manually.
  5. Make DataLoader iterable by implementing __iter__ as a generator (with yield); each for loop over the loader then starts a fresh pass with a new shuffle.
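For hint 4, torch.nn.utils.rnn.pad_sequence can replace the manual padding loop. The sketch below is an alternative collate, assuming right padding with pad_value; the name collate_with_pad_sequence is hypothetical. The mask is built from the true lengths rather than by comparing tokens against pad_value, because a real token ID may happen to equal the pad value.

```python
from typing import List, Tuple

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence


def collate_with_pad_sequence(
    batch: List[Tuple[List[int], int]], pad_value: int = 0
) -> Tuple[Tensor, Tensor, Tensor]:
    token_lists, labels = zip(*batch)
    seqs = [torch.tensor(tokens, dtype=torch.long) for tokens in token_lists]
    # (B, L_max): right-pads every sequence up to the longest one in this batch
    input_ids = pad_sequence(seqs, batch_first=True, padding_value=pad_value)
    # Position j of row b is real iff j < len(seqs[b])
    lengths = torch.tensor([len(s) for s in seqs])
    positions = torch.arange(input_ids.size(1))
    attention_mask = (positions[None, :] < lengths[:, None]).long()
    return input_ids, attention_mask, torch.tensor(labels, dtype=torch.long)


ids, mask, labels = collate_with_pad_sequence([([8, 5, 2, 1], 1), ([1, 2, 3, 4, 5, 6, 7], 1)])
print(ids.shape)          # → torch.Size([2, 7])
print(mask[0].tolist())   # → [1, 1, 1, 1, 0, 0, 0]
```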

Solution

import torch
from torch import Tensor
from typing import List, Tuple, Iterator, Callable, Optional
import random


class SimpleDataset:
    """Map-style dataset for (token_ids, label) pairs."""

    def __init__(self, data: List[Tuple[List[int], int]]) -> None:
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[List[int], int]:
        return self.data[idx]


def collate_fn(
    batch: List[Tuple[List[int], int]], pad_value: int = 0
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Collate variable-length sequences into padded batch tensors.

    Returns:
        input_ids:      (batch_size, max_seq_len) padded token IDs
        attention_mask:  (batch_size, max_seq_len) 1=real token, 0=padding
        labels:          (batch_size,) integer labels
    """
    token_lists, labels = zip(*batch)

    # Find max length in this batch (dynamic padding)
    max_len = max(len(tokens) for tokens in token_lists)

    input_ids = []
    attention_mask = []
    for tokens in token_lists:
        seq_len = len(tokens)
        padding_len = max_len - seq_len
        input_ids.append(tokens + [pad_value] * padding_len)
        attention_mask.append([1] * seq_len + [0] * padding_len)

    return (
        torch.tensor(input_ids, dtype=torch.long),
        torch.tensor(attention_mask, dtype=torch.long),
        torch.tensor(labels, dtype=torch.long),
    )


class DataLoader:
    """Simple data loader with batching and shuffling."""

    def __init__(
        self,
        dataset: SimpleDataset,
        batch_size: int = 32,
        shuffle: bool = True,
        collate: Optional[Callable] = None,
        drop_last: bool = False,
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.collate = collate or self._default_collate
        self.drop_last = drop_last

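    @staticmethod
    def _default_collate(batch: List) -> Tuple:
        """Default collate: zip samples field-wise and convert each field to a tensor.

        Only works when every field has the same shape across samples;
        variable-length sequences need a padding collate such as collate_fn.
        """
        return tuple(torch.tensor(items) for items in zip(*batch))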

    def __len__(self) -> int:
        n = len(self.dataset)
        if self.drop_last:
            return n // self.batch_size
        return (n + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator:
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(indices)

        # Group indices into batches
        for start in range(0, len(indices), self.batch_size):
            batch_indices = indices[start : start + self.batch_size]
            if self.drop_last and len(batch_indices) < self.batch_size:
                break

            # Fetch individual samples
            batch = [self.dataset[i] for i in batch_indices]

            # Collate into tensors
            yield self.collate(batch)


# ---------- demo ----------
if __name__ == "__main__":
    random.seed(42)

    # Create synthetic dataset: variable-length token sequences + labels
    data = [
        ([1, 5, 3, 8, 2], 0),
        ([4, 7, 2], 1),
        ([9, 1, 6, 3], 0),
        ([2, 8], 1),
        ([5, 3, 7, 1, 4, 6], 0),
        ([1, 2, 3, 4, 5, 6, 7], 1),
        ([3, 3], 0),
        ([8, 5, 2, 1], 1),
        ([1, 9, 4, 7, 2, 8], 0),
        ([6, 3, 1], 1),
    ]

    dataset = SimpleDataset(data)
    loader = DataLoader(dataset, batch_size=3, shuffle=True, collate=collate_fn)

    print(f"Dataset size: {len(dataset)}")
    print(f"Number of batches: {len(loader)}")
    print()

    for batch_idx, (input_ids, attn_mask, labels) in enumerate(loader):
        print(f"Batch {batch_idx}:")
        print(f"  input_ids shape:  {input_ids.shape}")
        print(f"  attn_mask shape:  {attn_mask.shape}")
        print(f"  labels:           {labels.tolist()}")
        print(f"  input_ids:\n{input_ids}")
        print(f"  attn_mask:\n{attn_mask}")
        print()

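    # Verify shuffling produces different orderings across epochs.
    # Compare the unpadded sequences rather than the labels: labels are only
    # 0/1, so two different orders can still produce identical label lists.
    orders = []
    for _epoch in range(3):
        epoch_seqs = []
        for input_ids, attn_mask, _ in loader:  # each new loop reshuffles
            for row, mask in zip(input_ids, attn_mask):
                epoch_seqs.append(row[mask.bool()].tolist())
        orders.append(epoch_seqs)
    print("Epoch orderings differ:", orders[0] != orders[1] or orders[1] != orders[2])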

Walkthrough

  1. SimpleDataset -- A minimal map-style dataset. __getitem__ returns one sample by index. This interface lets the data loader control the access pattern (sequential vs. shuffled).

  2. collate_fn -- The critical piece. It receives a list of individual samples and must produce batch tensors. For variable-length sequences, this means: (a) find the maximum length in this batch (dynamic padding -- more efficient than padding to a global max), (b) pad shorter sequences, (c) create an attention mask so the model can distinguish real tokens from padding.

  3. DataLoader.__iter__ -- Each call (one epoch) reshuffles the indices, groups them into batches, fetches the samples, and collates them. Because __iter__ uses yield, it is a generator: batches are produced on demand rather than all materialized in memory up front.

  4. Dynamic padding -- Padding to the batch maximum rather than the dataset maximum saves computation. If one batch has sequences of length [5, 3, 4], we pad to 5, not to the global max of 100.

  5. drop_last -- When the dataset size is not divisible by batch_size, the last batch is smaller. A very small final batch gives noisy batch-normalization statistics, and in distributed training it can leave processes with unequal batch sizes. drop_last=True discards the incomplete batch.
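The attention mask from collate_fn only helps if the model consumes it. A minimal sketch of the common pattern in scaled dot-product attention, assuming raw scores of shape (B, L, L) and a (B, L) mask with 1 for real tokens: padded key positions are filled with -inf before the softmax, so they receive exactly zero weight. The function name masked_attention_weights is hypothetical.

```python
import torch
from torch import Tensor


def masked_attention_weights(scores: Tensor, attention_mask: Tensor) -> Tensor:
    """scores: (B, L, L) query-key scores; attention_mask: (B, L), 1=real, 0=pad."""
    # Broadcast the key mask over the query dimension: (B, 1, L)
    key_is_pad = attention_mask[:, None, :] == 0
    scores = scores.masked_fill(key_is_pad, float("-inf"))
    # exp(-inf) = 0, so padded keys get zero weight after the softmax.
    # Assumes each sequence has at least one real token (an all -inf row gives NaN).
    return torch.softmax(scores, dim=-1)


torch.manual_seed(0)
mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1]])
weights = masked_attention_weights(torch.randn(2, 7, 7), mask)
print(weights[0, :, 4:].abs().sum().item())  # → 0.0
```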

Complexity Analysis

  • Shuffling: O(n) for Fisher-Yates shuffle.
  • Iteration: O(n) total to iterate through all samples.
  • Collation: O(B * L_max) per batch for padding, where B = batch size and L_max = max sequence length in the batch.
  • Memory: O(B * L_max) for the batch tensor. Dynamic padding minimizes this compared to global max padding.
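Python's random.shuffle, used in the solution, performs an in-place Fisher-Yates shuffle. A standalone sketch of the same algorithm shows where the O(n) bound comes from: each position is visited once and swapped once. The function name fisher_yates is hypothetical.

```python
import random
from typing import MutableSequence, Optional


def fisher_yates(items: MutableSequence, rng: Optional[random.Random] = None) -> None:
    """Shuffle items in place so that every permutation is equally likely."""
    if rng is None:
        rng = random.Random()
    # Walk from the last position down to 1, swapping each with a uniformly
    # chosen position in the not-yet-fixed prefix [0, i].
    for i in range(len(items) - 1, 0, -1):
        j = rng.randint(0, i)  # randint is inclusive on both ends
        items[i], items[j] = items[j], items[i]


indices = list(range(10))
fisher_yates(indices, random.Random(42))
print(sorted(indices) == list(range(10)))  # → True (still a permutation)
```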

The real PyTorch DataLoader additionally supports: multi-process workers (parallel data fetching), pin_memory (faster CPU-to-GPU transfer), and sampler customization (for distributed training, curriculum learning, etc.).
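For comparison, the same padding collate can be plugged into torch.utils.data.DataLoader, which accepts any object with __len__ and __getitem__ as a map-style dataset. Note that PyTorch names the argument collate_fn, not collate. In this sketch the collate is redefined as pad_collate so the block runs standalone; num_workers=0 keeps loading in the main process (a positive value fetches batches in worker processes), and pin_memory is only turned on when CUDA is available.

```python
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader


def pad_collate(batch: List[Tuple[List[int], int]], pad_value: int = 0):
    """Same dynamic-padding logic as collate_fn in the solution."""
    token_lists, labels = zip(*batch)
    max_len = max(len(t) for t in token_lists)
    ids = [list(t) + [pad_value] * (max_len - len(t)) for t in token_lists]
    mask = [[1] * len(t) + [0] * (max_len - len(t)) for t in token_lists]
    return torch.tensor(ids), torch.tensor(mask), torch.tensor(labels)


data = [([1, 5, 3, 8, 2], 0), ([4, 7, 2], 1), ([9, 1, 6, 3], 0), ([2, 8], 1)]

loader = DataLoader(
    data,                    # a plain list works: it has __len__ and __getitem__
    batch_size=3,
    shuffle=True,
    collate_fn=pad_collate,
    drop_last=False,
    num_workers=0,           # >0 fetches batches in parallel worker processes
    pin_memory=torch.cuda.is_available(),  # page-locked memory speeds host-to-GPU copies
)

for input_ids, attention_mask, labels in loader:
    print(input_ids.shape, labels.tolist())
```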

Interview Tips

Interviewers focus on: (1) Why padding is necessary (GPU operations require regular tensor shapes) and the attention mask pattern. (2) Dynamic vs. static padding -- dynamic padding per batch is more efficient. (3) The performance impact of data loading -- use multiple workers, pin memory, and prefetch to hide I/O latency. (4) Bucketing/sorting by length to minimize padding waste. (5) IterableDataset vs. map-style dataset -- IterableDataset for streaming data that does not fit in memory.
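A sketch of point (4), length bucketing: shuffle the indices, sort by length within large chunks, cut each chunk into batches, then shuffle the batch order so training keeps some randomness. This is one common variant, chosen here for illustration; bucketed_batches, chunk_size, and total_padded_positions are hypothetical names, not a PyTorch API. The lengths are those of the 10 demo sequences in the solution.

```python
import random
from typing import List, Optional


def bucketed_batches(lengths: List[int], batch_size: int, chunk_size: int,
                     rng: Optional[random.Random] = None) -> List[List[int]]:
    """Return batches of indices whose sequences have similar lengths."""
    if rng is None:
        rng = random.Random()
    indices = list(range(len(lengths)))
    rng.shuffle(indices)
    batches = []
    # Sort only within chunks, so the batching still changes from epoch to epoch
    for start in range(0, len(indices), chunk_size):
        chunk = sorted(indices[start:start + chunk_size], key=lambda i: lengths[i])
        batches.extend(chunk[k:k + batch_size] for k in range(0, len(chunk), batch_size))
    rng.shuffle(batches)  # avoid always feeding the shortest batches first
    return batches


def total_padded_positions(batches: List[List[int]], lengths: List[int]) -> int:
    """Positions processed with dynamic padding (each batch padded to its own max)."""
    return sum(len(b) * max(lengths[i] for i in b) for b in batches)


lengths = [5, 3, 4, 2, 6, 7, 2, 4, 6, 3]
unsorted = [list(range(10))[k:k + 3] for k in range(0, 10, 3)]  # sequential batches
bucketed = bucketed_batches(lengths, batch_size=3, chunk_size=10, rng=random.Random(0))
print(total_padded_positions(unsorted, lengths), total_padded_positions(bucketed, lengths))  # → 57 46
```

Both versions contain 42 real tokens; bucketing cuts the padded total from 57 to 46 positions here. A smaller chunk_size keeps more randomness at the cost of more padding.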

Quiz (3 Questions)

Why is dynamic padding (padding to the batch max) preferred over static padding (padding to the global max)?

Why would you sort or bucket sequences by length before batching?

What is the purpose of the drop_last parameter in a DataLoader?
