Akshay’s Gradient
ML Coding · Advanced · 55 min

Distributed Training Script

Exercise: Distributed Data Parallel (DDP) Training

Set up distributed training using PyTorch's DistributedDataParallel. This is how models are trained across multiple GPUs -- understanding the setup is essential for any ML engineer.

Problem Statement

Implement:

  1. setup_ddp(rank, world_size) -- initialize the distributed process group
  2. cleanup_ddp() -- destroy the process group
  3. train_ddp(rank, world_size, model_class, data, epochs) -- a complete DDP training function
  4. launch_distributed(...) -- spawn multiple processes for multi-GPU training

Each process manages one GPU, gets a shard of the data, and gradients are synchronized via all-reduce after each backward pass.

Inputs: Model class, training data, number of GPUs (world_size).

Outputs: A trained model with parameters synchronized across all GPUs.

Key Concept

DDP wraps a model so that after each backward pass, gradients are all-reduced (averaged) across all processes. Each process computes gradients on its own data shard, and after all-reduce, every process has the same averaged gradient. Since they start with the same weights and apply the same update, weights stay synchronized without explicit parameter broadcasting.
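This synchronization argument can be checked in a few lines of plain PyTorch, with no process group involved. The following single-process sketch uses two model replicas to stand in for two ranks and simulates the all-reduce by averaging gradients by hand:

```python
import copy
import torch
import torch.nn as nn

# Two replicas with identical initial weights (DDP guarantees this at construction).
torch.manual_seed(0)
model_a = nn.Linear(4, 1)
model_b = copy.deepcopy(model_a)

# Each replica computes gradients on its own data shard.
shard_a, shard_b = torch.randn(8, 4), torch.randn(8, 4)
model_a(shard_a).sum().backward()
model_b(shard_b).sum().backward()

# Simulated all-reduce: replace each gradient with the average across replicas.
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    avg = (p_a.grad + p_b.grad) / 2
    p_a.grad, p_b.grad = avg, avg.clone()

# Identical update rule + identical averaged gradients -> identical weights.
for m in (model_a, model_b):
    with torch.no_grad():
        for p in m.parameters():
            p -= 0.01 * p.grad

assert all(
    torch.allclose(pa, pb)
    for pa, pb in zip(model_a.parameters(), model_b.parameters())
)
print("replicas stay synchronized")
```

No parameter broadcast happens after the initial setup; synchronization follows purely from averaging gradients before each identical update.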

DDP All-Reduce Gradient Synchronization
┌─────────────────────────────────────────────────────────────────────┐
│             Distributed Data Parallel (4 GPUs)                      │
│                                                                     │
│  GPU 0          GPU 1          GPU 2          GPU 3                 │
│  ┌────────┐    ┌────────┐    ┌────────┐    ┌────────┐              │
│  │Data 0  │    │Data 1  │    │Data 2  │    │Data 3  │              │
│  │shard   │    │shard   │    │shard   │    │shard   │              │
│  └───┬────┘    └───┬────┘    └───┬────┘    └───┬────┘              │
│      ▼             ▼             ▼             ▼                    │
│  ┌────────┐    ┌────────┐    ┌────────┐    ┌────────┐              │
│  │forward │    │forward │    │forward │    │forward │              │
│  │backward│    │backward│    │backward│    │backward│              │
│  │ grad_0 │    │ grad_1 │    │ grad_2 │    │ grad_3 │              │
│  └───┬────┘    └───┬────┘    └───┬────┘    └───┬────┘              │
│      │             │             │             │                    │
│      └─────────────┼─────────────┼─────────────┘                   │
│                    ▼                                                │
│            ┌──────────────┐                                         │
│            │  All-Reduce   │                                        │
│            │  avg(g0..g3)  │  ← Ring algorithm: O(P) per GPU       │
│            └──────┬───────┘                                         │
│                   │                                                 │
│      ┌────────────┼────────────┬────────────┐                      │
│      ▼            ▼            ▼            ▼                      │
│   avg_grad     avg_grad     avg_grad     avg_grad                  │
│   step()       step()       step()       step()                    │
│                                                                     │
│   All GPUs now have IDENTICAL parameters                            │
└─────────────────────────────────────────────────────────────────────┘
Warning

Always access the underlying model via ddp_model.module when saving checkpoints. Saving ddp_model.state_dict() directly includes the "module." prefix on all parameter keys, which causes errors when loading into a non-DDP model. Similarly, only save from rank 0 -- if all ranks try to save simultaneously, you risk file system corruption.
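A quick single-process illustration of the checkpoint pitfall. No DDP is needed here: the "module."-prefixed keys are built by hand to mimic what ddp_model.state_dict() would produce.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Simulate a checkpoint saved via ddp_model.state_dict(): every key gains a
# "module." prefix because DDP wraps the model one level deep.
ddp_style_ckpt = {f"module.{k}": v for k, v in model.state_dict().items()}

# Loading it directly into a plain model fails on the mismatched keys.
try:
    model.load_state_dict(ddp_style_ckpt)
except RuntimeError as e:
    print("load failed:", type(e).__name__)

# Recovery: strip the prefix (or better, save ddp_model.module.state_dict()
# from rank 0 in the first place).
clean_ckpt = {k.removeprefix("module."): v for k, v in ddp_style_ckpt.items()}
model.load_state_dict(clean_ckpt)
print("clean checkpoint loads fine")
```

Note that `str.removeprefix` requires Python 3.9+; on older versions, slice with `k[len("module."):]`.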

Interview Tip

When asked about DDP in interviews, emphasize that the all-reduce communication is overlapped with backward computation. DDP uses bucketed gradient synchronization: as soon as gradients for a group of parameters are ready, the all-reduce for that bucket starts while the backward pass continues computing gradients for earlier layers. This pipelining hides most of the communication latency.
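The bucketing idea can be sketched with plain tensors. This is a toy simulation, not DDP's internals: real DDP groups gradients into buckets by byte size (the bucket_cap_mb constructor argument) and launches each bucket's all-reduce asynchronously while the backward pass continues.

```python
import torch

# Per-"rank" gradients for a 4-layer model (2 simulated ranks).
world_grads = [
    [torch.full((3,), float(rank + layer)) for layer in range(4)]
    for rank in range(2)
]

bucket_cap = 2   # layers per bucket (real DDP buckets by bytes, not layers)
bucket = []
for layer in reversed(range(4)):  # gradients become ready last-layer-first
    bucket.append(layer)
    if len(bucket) == bucket_cap:
        # Bucket is full: its all-reduce can start now, while backward
        # would still be computing gradients for earlier layers.
        for l in bucket:
            avg = torch.stack([g[l] for g in world_grads]).mean(dim=0)
            for g in world_grads:
                g[l] = avg.clone()
        print(f"all-reduce launched for bucket {sorted(bucket)}")
        bucket = []

# After backward finishes, every rank holds identical averaged gradients.
assert all(
    torch.equal(world_grads[0][l], world_grads[1][l]) for l in range(4)
)
```

The key point for interviews is the ordering: the last layers' gradients are ready first, so their communication overlaps with the remaining backward computation.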

Hints

  1. Use torch.distributed.init_process_group(backend="nccl", ...) for GPU training.
  2. Set MASTER_ADDR and MASTER_PORT environment variables.
  3. Use DistributedSampler to partition data across processes -- each process gets a unique shard.
  4. Wrap the model with DDP(model, device_ids=[rank]).
  5. Set sampler.set_epoch(epoch) at the start of each epoch to ensure different shuffling.
  6. Only save the model from rank 0 to avoid file conflicts.

Solution

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler
from typing import List, Tuple


def setup_ddp(rank: int, world_size: int) -> None:
    """Initialize the distributed process group."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(
        backend="nccl",  # Use "gloo" for CPU-only
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)


def cleanup_ddp() -> None:
    """Destroy the distributed process group."""
    dist.destroy_process_group()


class SimpleModel(nn.Module):
    def __init__(self, d_in: int, d_hidden: int, d_out: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def train_ddp(
    rank: int,
    world_size: int,
    d_in: int,
    d_hidden: int,
    d_out: int,
    epochs: int = 10,
) -> None:
    """Training function that runs on each process."""
    setup_ddp(rank, world_size)

    # Create model on this rank's GPU
    model = SimpleModel(d_in, d_hidden, d_out).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # Create synthetic dataset (same on all ranks for demo)
    torch.manual_seed(0)  # Same data everywhere
    X = torch.randn(1000, d_in)
    y = torch.randint(0, d_out, (1000,))
    dataset = TensorDataset(X, y)

    # DistributedSampler partitions data across processes
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Ensure different shuffling each epoch
        ddp_model.train()
        epoch_loss = 0.0

        for x_batch, y_batch in loader:
            x_batch = x_batch.to(rank)
            y_batch = y_batch.to(rank)

            optimizer.zero_grad()
            logits = ddp_model(x_batch)
            loss = criterion(logits, y_batch)
            loss.backward()  # DDP automatically all-reduces gradients here
            optimizer.step()
            epoch_loss += loss.item()

        # Only print from rank 0
        if rank == 0:
            avg_loss = epoch_loss / len(loader)
            print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")

    # Only save model from rank 0
    if rank == 0:
        torch.save(ddp_model.module.state_dict(), "/tmp/ddp_model.pt")
        print("Model saved from rank 0.")

    cleanup_ddp()


def launch_distributed(world_size: int, **kwargs) -> None:
    """Spawn processes for distributed training."""
    mp.spawn(
        train_ddp,
        args=(world_size, kwargs["d_in"], kwargs["d_hidden"], kwargs["d_out"]),
        nprocs=world_size,
        join=True,
    )


# ---------- CPU-compatible demo (no multiple GPUs needed) ----------
def demo_ddp_concepts() -> None:
    """Demonstrate DDP concepts without requiring multiple GPUs."""
    print("=== DDP Concepts Demo ===\n")

    # 1. DistributedSampler behavior
    dataset = TensorDataset(torch.arange(20))
    print("Dataset indices: 0-19\n")

    for rank in range(4):
        sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
        indices = list(sampler)
        print(f"Rank {rank} sees indices: {indices}")

    # 2. All-reduce simulation
    print("\n--- Simulated All-Reduce ---")
    gradients = [torch.randn(3) for _ in range(4)]
    for i, g in enumerate(gradients):
        print(f"Rank {i} gradient: {g.tolist()}")

    # All-reduce: average across ranks
    avg_grad = torch.stack(gradients).mean(dim=0)
    print(f"After all-reduce (average): {avg_grad.tolist()}")
    print("All ranks now have this same averaged gradient.\n")

    # 3. Parameter synchronization
    print("--- Parameter Sync ---")
    model = SimpleModel(4, 8, 2)
    initial_params = {k: v.clone() for k, v in model.named_parameters()}

    # Simulate: same gradient applied to same weights -> same result
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    x = torch.randn(4, 4)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()

    changed = any(
        not torch.equal(initial_params[k], p) for k, p in model.named_parameters()
    )
    print(f"Parameters changed after the update: {changed}")
    print("After identical updates, all ranks would have identical parameters.")
    print("DDP guarantees this by all-reducing gradients before optimizer.step().")


if __name__ == "__main__":
    demo_ddp_concepts()
    print("\n--- To run actual DDP training ---")
    print("Requires multiple GPUs. Launch with:")
    print("  torchrun --nproc_per_node=NUM_GPUS script.py")
    print("Or use launch_distributed(world_size=NUM_GPUS, d_in=64, d_hidden=128, d_out=10)")

Walkthrough

  1. Process group setup -- init_process_group establishes communication between processes. Each process is identified by its rank (0 to world_size-1). NCCL is the fastest backend for GPU-to-GPU communication.

  2. DistributedSampler -- Partitions the dataset so each process sees a unique 1/world_size shard. With 4 GPUs and 1000 samples, each GPU trains on 250 samples per epoch. set_epoch(epoch) is crucial: it changes the random seed for shuffling so shards differ between epochs.

  3. DDP wrapper -- DDP(model, device_ids=[rank]) registers hooks on the model's backward pass. When any gradient is computed, DDP schedules an all-reduce operation. By the time optimizer.step() runs, all gradients are synchronized.

  4. All-reduce -- The key operation: each GPU starts with its local gradient, and after all-reduce, every GPU has the average gradient across all GPUs. This is mathematically equivalent to computing the gradient on the combined batch.

  5. Model saving -- Only rank 0 saves to avoid file system conflicts. Access the underlying model via ddp_model.module (DDP wraps the model in an extra layer).
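The equivalence claimed in point 4 can be verified numerically: with a mean-reduced loss and equal shard sizes, averaging per-shard gradients reproduces the full-batch gradient exactly. A single-process check:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
criterion = nn.MSELoss()  # mean-reduced loss
X, y = torch.randn(8, 4), torch.randn(8, 2)

# Gradient on the full batch.
model.zero_grad()
criterion(model(X), y).backward()
full = [p.grad.clone() for p in model.parameters()]

# Per-shard gradients, then averaged (what all-reduce computes).
shard_grads = []
for shard in (slice(0, 4), slice(4, 8)):
    model.zero_grad()
    criterion(model(X[shard]), y[shard]).backward()
    shard_grads.append([p.grad.clone() for p in model.parameters()])

averaged = [(a + b) / 2 for a, b in zip(*shard_grads)]
for f, a in zip(full, averaged):
    assert torch.allclose(f, a, atol=1e-6)
print("averaged shard gradients == full-batch gradient")
```

The equal-shard-size assumption matters: with sum-reduced losses or uneven shards, the simple average is no longer exactly the full-batch gradient.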

Complexity Analysis

  • Communication overhead: O(P) per all-reduce, where P = number of parameters. This is overlapped with backward computation (DDP pipelines communication with gradient computation for earlier layers).
  • Linear scaling: Ideal speedup is world_size x -- world_size times the data processed in the same wall-clock time. In practice, communication overhead reduces this slightly.
  • Memory per GPU: Each GPU stores a full copy of the model + optimizer states. For very large models, this does not fit and you need model parallelism (FSDP, tensor parallelism).
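The O(P) claim can be made concrete with a toy ring all-reduce over plain Python lists. This is a hypothetical sketch for counting traffic, not the NCCL implementation: each of N nodes sends one P/N-sized chunk per step across 2(N-1) steps, so per-node traffic is 2(N-1)P/N ≈ 2P, independent of N.

```python
N, P = 4, 8                      # 4 "GPUs", 8 parameters each
chunk = P // N
data = [[float(n * 10 + i) for i in range(P)] for n in range(N)]
sent = [0] * N                   # numbers sent per node

# Phase 1, reduce-scatter: after N-1 steps, node n owns the full sum
# of chunk (n + 1) % N.
for step in range(N - 1):
    for n in range(N):
        c, dst = (n - step) % N, (n + 1) % N
        for i in range(c * chunk, (c + 1) * chunk):
            data[dst][i] += data[n][i]
        sent[n] += chunk

# Phase 2, all-gather: circulate the completed chunks around the ring.
for step in range(N - 1):
    for n in range(N):
        c, dst = (n + 1 - step) % N, (n + 1) % N
        for i in range(c * chunk, (c + 1) * chunk):
            data[dst][i] = data[n][i]
        sent[n] += chunk

# Every node ends with the full sum; element i sums to 60 + 4*i here.
assert all(row == [60.0 + 4 * i for i in range(P)] for row in data)
print(f"sent per node: {sent} (2*(N-1)*P/N = {2 * (N - 1) * P // N})")
```

With N = 4 and P = 8, each node sends exactly 12 numbers; doubling N leaves per-node traffic near 2P, which is why ring all-reduce scales well with GPU count.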

Interview Tips

Critical knowledge: (1) DDP vs. DataParallel -- DDP uses one process per GPU (no GIL bottleneck), while DataParallel uses threads (slower). Always use DDP. (2) FSDP vs. DDP -- FSDP shards parameters across GPUs (for models that do not fit on one GPU), DDP replicates them. (3) Communication: all-reduce via ring algorithm is O(P) per GPU regardless of world_size. (4) Gradient accumulation with DDP: use model.no_sync() context manager for micro-batches to skip redundant all-reduce. (5) torchrun vs. mp.spawn: torchrun is the modern launcher with better fault tolerance.

Quiz

Why must you call sampler.set_epoch(epoch) at the start of each training epoch?

What is the communication complexity of all-reduce using the ring algorithm?

When should you use FSDP (Fully Sharded Data Parallel) instead of DDP?
