Exercise: Distributed Data Parallel (DDP) Training
Set up distributed training using PyTorch's DistributedDataParallel. This is how models are trained across multiple GPUs -- understanding the setup is essential for any ML engineer.
Problem Statement
Implement:
- setup_ddp(rank, world_size) -- initialize the distributed process group
- cleanup_ddp() -- destroy the process group
- train_ddp(rank, world_size, model_class, data, epochs) -- a complete DDP training function
- launch_distributed(...) -- spawn multiple processes for multi-GPU training
Each process manages one GPU, gets a shard of the data, and gradients are synchronized via all-reduce after each backward pass.
Inputs: Model class, training data, number of GPUs (world_size).
Outputs: A trained model with parameters synchronized across all GPUs.
DDP wraps a model so that after each backward pass, gradients are all-reduced (averaged) across all processes. Each process computes gradients on its own data shard, and after all-reduce, every process has the same averaged gradient. Since they start with the same weights and apply the same update, weights stay synchronized without explicit parameter broadcasting.
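This equivalence can be checked by hand. The sketch below, with made-up toy numbers, shows that for a loss averaged over examples, averaging equal-size per-shard gradients reproduces the full-batch gradient:

```python
# Toy check: for a loss that averages over examples, averaging equal-size
# per-shard mean gradients equals the full-batch mean gradient.
# Model: scalar w; per-example loss 0.5*(w*x - y)^2; gradient (w*x - y)*x.

def grad(w, shard):
    """Mean gradient of 0.5*(w*x - y)^2 over one data shard."""
    return sum((w * x - y) * x for x, y in shard) / len(shard)

w = 0.5
data = [(1.0, 2.0), (2.0, 1.0), (3.0, 4.0), (4.0, 3.0)]

# Single-GPU baseline: gradient on the full batch
full = grad(w, data)

# DDP-style: two "ranks" each compute a gradient on their own shard,
# then the gradients are averaged (the all-reduce step)
shards = [data[:2], data[2:]]
per_rank = [grad(w, s) for s in shards]
all_reduced = sum(per_rank) / len(per_rank)

assert abs(full - all_reduced) < 1e-12  # identical gradients
```

Note this holds exactly only when shards are the same size; DistributedSampler pads the dataset to guarantee this.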
┌─────────────────────────────────────────────────────────────────────┐
│ Distributed Data Parallel (4 GPUs) │
│ │
│ GPU 0 GPU 1 GPU 2 GPU 3 │
│ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │Data 0 │ │Data 1 │ │Data 2 │ │Data 3 │ │
│ │shard │ │shard │ │shard │ │shard │ │
│ └───┬────┘ └───┬────┘ └───┬────┘ └───┬────┘ │
│ ▼ ▼ ▼ ▼ │
│ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │forward │ │forward │ │forward │ │forward │ │
│ │backward│ │backward│ │backward│ │backward│ │
│ │ grad_0 │ │ grad_1 │ │ grad_2 │ │ grad_3 │ │
│ └───┬────┘ └───┬────┘ └───┬────┘ └───┬────┘ │
│ │ │ │ │ │
│ └─────────────┼─────────────┼─────────────┘ │
│ ▼ │
│ ┌──────────────┐ │
│ │ All-Reduce │ │
│ │ avg(g0..g3) │ ← Ring algorithm: O(P) per GPU │
│ └──────┬───────┘ │
│ │ │
│ ┌────────────┼────────────┬────────────┐ │
│ ▼ ▼ ▼ ▼ │
│ avg_grad avg_grad avg_grad avg_grad │
│ step() step() step() step() │
│ │
│ All GPUs now have IDENTICAL parameters │
└─────────────────────────────────────────────────────────────────────┘
Always access the underlying model via ddp_model.module when saving checkpoints. Saving ddp_model.state_dict() directly includes the "module." prefix on all parameter keys, which causes errors when loading into a non-DDP model. Similarly, only save from rank 0 -- if all ranks try to save simultaneously, you risk file system corruption.
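A sketch of what the "module." prefix problem looks like, using plain dicts as stand-ins for real state_dicts (the keys here are hypothetical):

```python
# A non-DDP model's state_dict has keys like this:
plain_sd = {"net.0.weight": 1, "net.0.bias": 2}

# Saving ddp_model.state_dict() directly produces prefixed keys,
# because DDP stores the wrapped model in an attribute named "module":
ddp_sd = {f"module.{k}": v for k, v in plain_sd.items()}
print(list(ddp_sd))  # ['module.net.0.weight', 'module.net.0.bias']

# Recovery, if you already saved the wrapped version: strip the prefix
# before loading into a non-DDP model
fixed = {k.removeprefix("module."): v for k, v in ddp_sd.items()}
assert fixed == plain_sd
```

Saving `ddp_model.module.state_dict()` from rank 0, as in the solution below, avoids the problem entirely.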
When asked about DDP in interviews, emphasize that the all-reduce communication is overlapped with backward computation. DDP uses bucketed gradient synchronization: as soon as gradients for a group of parameters are ready, the all-reduce for that bucket starts while the backward pass continues computing gradients for earlier layers. This pipelining hides most of the communication latency.
Hints
- Use torch.distributed.init_process_group(backend="nccl", ...) for GPU training.
- Set the MASTER_ADDR and MASTER_PORT environment variables.
- Use DistributedSampler to partition data across processes -- each process gets a unique shard.
- Wrap the model with DDP(model, device_ids=[rank]).
- Call sampler.set_epoch(epoch) at the start of each epoch to ensure different shuffling.
- Only save the model from rank 0 to avoid file conflicts.
Solution
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler


def setup_ddp(rank: int, world_size: int) -> None:
    """Initialize the distributed process group."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(
        backend="nccl",  # Use "gloo" for CPU-only training
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)


def cleanup_ddp() -> None:
    """Destroy the distributed process group."""
    dist.destroy_process_group()


class SimpleModel(nn.Module):
    def __init__(self, d_in: int, d_hidden: int, d_out: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def train_ddp(
    rank: int,
    world_size: int,
    d_in: int,
    d_hidden: int,
    d_out: int,
    epochs: int = 10,
) -> None:
    """Training function that runs on each process."""
    setup_ddp(rank, world_size)

    # Create the model on this rank's GPU
    model = SimpleModel(d_in, d_hidden, d_out).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # Create a synthetic dataset (identical on all ranks for the demo)
    torch.manual_seed(0)
    X = torch.randn(1000, d_in)
    y = torch.randint(0, d_out, (1000,))
    dataset = TensorDataset(X, y)

    # DistributedSampler partitions the data across processes
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Ensure different shuffling each epoch
        ddp_model.train()
        epoch_loss = 0.0
        for x_batch, y_batch in loader:
            x_batch = x_batch.to(rank)
            y_batch = y_batch.to(rank)
            optimizer.zero_grad()
            logits = ddp_model(x_batch)
            loss = criterion(logits, y_batch)
            loss.backward()  # DDP all-reduces gradients during backward
            optimizer.step()
            epoch_loss += loss.item()
        # Only print from rank 0
        if rank == 0:
            avg_loss = epoch_loss / len(loader)
            print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")

    # Only save from rank 0; unwrap with .module so parameter keys
    # don't carry the "module." prefix
    if rank == 0:
        torch.save(ddp_model.module.state_dict(), "/tmp/ddp_model.pt")
        print("Model saved from rank 0.")

    cleanup_ddp()


def launch_distributed(world_size: int, **kwargs) -> None:
    """Spawn one process per GPU for distributed training."""
    mp.spawn(
        train_ddp,  # mp.spawn passes the rank as the first argument
        args=(world_size, kwargs["d_in"], kwargs["d_hidden"], kwargs["d_out"]),
        nprocs=world_size,
        join=True,
    )


# ---------- CPU-compatible demo (no multiple GPUs needed) ----------
def demo_ddp_concepts() -> None:
    """Demonstrate DDP concepts without requiring multiple GPUs."""
    print("=== DDP Concepts Demo ===\n")

    # 1. DistributedSampler behavior
    dataset = TensorDataset(torch.arange(20))
    print("Dataset indices: 0-19\n")
    for rank in range(4):
        sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
        indices = list(sampler)
        print(f"Rank {rank} sees indices: {indices}")

    # 2. All-reduce simulation
    print("\n--- Simulated All-Reduce ---")
    gradients = [torch.randn(3) for _ in range(4)]
    for i, g in enumerate(gradients):
        print(f"Rank {i} gradient: {g.tolist()}")
    # All-reduce: average across ranks
    avg_grad = torch.stack(gradients).mean(dim=0)
    print(f"After all-reduce (average): {avg_grad.tolist()}")
    print("All ranks now have this same averaged gradient.\n")

    # 3. Parameter synchronization: same weights + same averaged gradient
    # + same update rule -> identical parameters on every rank
    print("--- Parameter Sync ---")
    model = SimpleModel(4, 8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    x = torch.randn(4, 4)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    print("After identical updates, all ranks would have identical parameters.")
    print("DDP guarantees this by all-reducing gradients before optimizer.step().")


if __name__ == "__main__":
    demo_ddp_concepts()
    print("\n--- To run actual DDP training ---")
    print("Requires multiple GPUs. Launch with:")
    print("  torchrun --nproc_per_node=NUM_GPUS script.py")
    print("Or use launch_distributed(world_size=NUM_GPUS, d_in=64, d_hidden=128, d_out=10)")
Walkthrough
- Process group setup -- init_process_group establishes communication between processes. Each process is identified by its rank (0 to world_size - 1). NCCL is the fastest backend for GPU-to-GPU communication.
- DistributedSampler -- partitions the dataset so each process sees a unique 1/world_size shard. With 4 GPUs and 1000 samples, each GPU trains on 250 samples per epoch. set_epoch(epoch) is crucial: it changes the random seed for shuffling so shards differ between epochs.
- DDP wrapper -- DDP(model, device_ids=[rank]) registers hooks on the model's backward pass. When any gradient is computed, DDP schedules an all-reduce operation. By the time optimizer.step() runs, all gradients are synchronized.
- All-reduce -- the key operation: each GPU starts with its local gradient, and after all-reduce every GPU has the average gradient across all GPUs. This is mathematically equivalent to computing the gradient on the combined batch.
- Model saving -- only rank 0 saves, to avoid file system conflicts. Access the underlying model via ddp_model.module (DDP wraps the model in an extra layer).
Complexity Analysis
- Communication overhead: O(P) per all-reduce, where P = number of parameters. This is overlapped with backward computation (DDP pipelines communication with gradient computation for earlier layers).
- Linear scaling: ideal speedup is world_size x -- each step processes world_size times as much data in the same wall-clock time. Communication overhead reduces this slightly.
- Memory per GPU: Each GPU stores a full copy of the model + optimizer states. For very large models, this does not fit and you need model parallelism (FSDP, tensor parallelism).
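The O(P)-per-GPU claim comes from the structure of ring all-reduce: a reduce-scatter phase followed by an all-gather phase, each moving (N-1)/N * P values per rank. A pure-Python simulation of the ring on toy gradient lists (an educational sketch, not how NCCL is implemented):

```python
def ring_all_reduce(grads):
    """Average `grads` (one list of floats per rank) via a simulated ring."""
    n = len(grads)              # number of ranks
    p = len(grads[0])           # parameters per rank
    assert p % n == 0, "in practice P is padded to a multiple of n"
    chunk = p // n
    bufs = [list(g) for g in grads]

    def seg(c):                 # slice bounds of chunk c
        return c * chunk, (c + 1) * chunk

    # Phase 1: reduce-scatter. After n-1 steps, rank r holds the full
    # sum for chunk (r + 1) % n.
    for t in range(n - 1):
        msgs = []
        for r in range(n):
            c = (r - t) % n                   # chunk rank r sends this step
            lo, hi = seg(c)
            msgs.append(((r + 1) % n, c, bufs[r][lo:hi]))
        for dst, c, payload in msgs:          # apply sends "simultaneously"
            lo, _ = seg(c)
            for i, v in enumerate(payload):
                bufs[dst][lo + i] += v

    # Phase 2: all-gather. Circulate each finished chunk around the ring,
    # overwriting partial sums with the complete ones.
    for t in range(n - 1):
        msgs = []
        for r in range(n):
            c = (r + 1 - t) % n               # finished chunk rank r forwards
            lo, hi = seg(c)
            msgs.append(((r + 1) % n, c, bufs[r][lo:hi]))
        for dst, c, payload in msgs:
            lo, hi = seg(c)
            bufs[dst][lo:hi] = payload

    return [[v / n for v in b] for b in bufs]  # average, not just sum

grads = [[float(r + 1)] * 4 for r in range(4)]  # rank r holds [r+1, ...]
result = ring_all_reduce(grads)
print(result[0])  # [2.5, 2.5, 2.5, 2.5] -- the mean of 1, 2, 3, 4
```

Each rank sends 2(n-1) chunks of P/n values, i.e. 2(n-1)/n * P < 2P values total, independent of the number of ranks.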
Interview Tips
Critical knowledge:

1. DDP vs. DataParallel -- DDP uses one process per GPU (no GIL bottleneck), while DataParallel uses threads within a single process (slower). Always use DDP.
2. FSDP vs. DDP -- FSDP shards parameters across GPUs (for models that do not fit on one GPU); DDP replicates them.
3. Communication -- all-reduce via the ring algorithm is O(P) per GPU regardless of world_size.
4. Gradient accumulation with DDP -- use the model.no_sync() context manager for intermediate micro-batches to skip redundant all-reduces.
5. torchrun vs. mp.spawn -- torchrun is the modern launcher with better fault tolerance.
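The no_sync() gradient-accumulation trick mentioned above can be sketched as follows. To stay runnable without GPUs, this uses a single-process "gloo" group on CPU (the port number is arbitrary); in real training world_size > 1 and no_sync() is what actually saves the per-micro-batch all-reduce:

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(nn.Linear(8, 2))  # CPU model: no device_ids
opt = torch.optim.SGD(model.parameters(), lr=0.1)
micro_batches = [torch.randn(4, 8) for _ in range(4)]

opt.zero_grad()
with model.no_sync():                      # suppress all-reduce for these
    for mb in micro_batches[:-1]:
        model(mb).sum().backward()         # grads accumulate locally
model(micro_batches[-1]).sum().backward()  # final backward syncs everything
opt.step()                                 # one step per 4 micro-batches

dist.destroy_process_group()
```

Without no_sync(), every backward() would trigger a full all-reduce even though only the accumulated gradient is needed, multiplying communication cost by the number of micro-batches.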
Quiz
1. Why must you call sampler.set_epoch(epoch) at the start of each training epoch?
2. What is the communication complexity of all-reduce using the ring algorithm?
3. When should you use FSDP (Fully Sharded Data Parallel) instead of DDP?