Exercise: Gradient Accumulation
Implement gradient accumulation to simulate large batch training on limited GPU memory. This is how practitioners train with effective batch sizes of thousands on a single GPU.
Problem Statement
Implement a train_with_accumulation function that:
- Processes micro-batches sequentially, accumulating gradients
- Only calls optimizer.step() after every accumulation_steps micro-batches
- Correctly scales the loss so the effective gradient is averaged over the full accumulated batch
- Works with mixed precision (AMP) as well
Compare the results with standard full-batch training to verify equivalence.
Inputs: Model, data loader, optimizer, number of accumulation steps.
Outputs: Training loss history, equivalent to training with batch_size * accumulation_steps effective batch size.
Gradient accumulation exploits the linearity of gradients: the gradient of a sum equals the sum of gradients. By calling .backward() multiple times before .step(), gradients from each micro-batch add up in the .grad attributes. Dividing the loss by accumulation_steps ensures the accumulated gradient equals the gradient of the full batch.
┌─────────────────────────────────────────────────────────────────────┐
│ Gradient Accumulation (accumulation_steps = 4) │
│ │
│ Micro-batch 1 Micro-batch 2 Micro-batch 3 │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ loss/4 │ │ loss/4 │ │ loss/4 │ │
│ │ .backward│ │ .backward│ │ .backward│ │
│ └────┬─────┘ └────┬─────┘ └────┬─────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ .grad = g1 .grad = g1+g2 .grad = g1+g2+g3 │
│ │
│ Micro-batch 4 │
│ ┌──────────┐ │
│ │ loss/4 │ ┌────────────────────┐ │
│ │ .backward│ │ optimizer.step() │ ← apply accumulated │
│ └────┬─────┘ │ zero_grad() │ gradient │
│ │ └────────────────────┘ │
│ ▼ │
│ .grad = g1+g2+g3+g4 = avg gradient over effective batch │
│ ≡ gradient from batch_size * 4 │
│ │
│ GPU Memory: O(micro_batch) instead of O(effective_batch) │
│ Result: Mathematically equivalent to large-batch training │
└─────────────────────────────────────────────────────────────────────┘
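The additivity that the diagram relies on is easy to verify directly. A minimal sketch with a single toy parameter: calling .backward() twice without zeroing lets the two gradients sum in .grad.

```python
import torch

# A single parameter so the gradient is easy to inspect.
w = torch.ones(3, requires_grad=True)
x1 = torch.tensor([1.0, 2.0, 3.0])
x2 = torch.tensor([4.0, 5.0, 6.0])

# Two backward calls without zero_grad in between: gradients add up in w.grad.
(w * x1).sum().backward()   # contributes x1
(w * x2).sum().backward()   # adds x2 on top

# w.grad is now x1 + x2 -- the gradient of the summed loss.
print(w.grad)  # tensor([5., 7., 9.])
```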
Do NOT apply gradient clipping after each micro-batch. Clipping must happen after all accumulation steps, right before optimizer.step(). Clipping per micro-batch changes the effective gradient direction and breaks the equivalence with large-batch training.
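A sketch of the correct clipping placement, with a toy model and synthetic data standing in for real inputs (the max_norm value is an arbitrary example):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
accumulation_steps = 4
data = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(8)]

optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        # Clip ONCE, on the fully accumulated gradient, right before the step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```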
Hints
- Do NOT call optimizer.zero_grad() at every micro-batch -- only every accumulation_steps steps.
- Divide the loss by accumulation_steps before calling .backward(). This ensures the accumulated gradient is a proper average, not a sum.
- Call optimizer.step() and optimizer.zero_grad() only when (step + 1) % accumulation_steps == 0.
- Handle the last incomplete accumulation window (if the dataset size is not divisible by accumulation_steps).
- With AMP, use scaler.scale(loss / accumulation_steps).backward() and only call scaler.step() at accumulation boundaries.
Solution
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from typing import List, Tuple, Iterator
def create_data(
num_samples: int, d_in: int, d_out: int, micro_batch_size: int
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Create synthetic data as a list of micro-batches."""
data = []
for i in range(0, num_samples, micro_batch_size):
batch_size = min(micro_batch_size, num_samples - i)
x = torch.randn(batch_size, d_in)
y = torch.randint(0, d_out, (batch_size,))
data.append((x, y))
return data
def train_standard(
model: nn.Module,
data: List[Tuple[torch.Tensor, torch.Tensor]],
lr: float,
epochs: int,
) -> List[float]:
"""Standard training (one step per batch)."""
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
losses = []
for epoch in range(epochs):
epoch_loss = 0.0
for x, y in data:
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
losses.append(epoch_loss / len(data))
return losses
def train_with_accumulation(
model: nn.Module,
data: List[Tuple[torch.Tensor, torch.Tensor]],
lr: float,
epochs: int,
accumulation_steps: int,
) -> List[float]:
"""Training with gradient accumulation."""
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
losses = []
for epoch in range(epochs):
epoch_loss = 0.0
optimizer.zero_grad() # Zero at the start of epoch
for step, (x, y) in enumerate(data):
# Forward pass
logits = model(x)
loss = criterion(logits, y)
# Scale loss by accumulation steps to get correct average
scaled_loss = loss / accumulation_steps
scaled_loss.backward() # Gradients accumulate in .grad
epoch_loss += loss.item()
# Step only at accumulation boundaries
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# Handle remaining micro-batches (if any)
remaining = len(data) % accumulation_steps
if remaining != 0:
optimizer.step()
optimizer.zero_grad()
losses.append(epoch_loss / len(data))
return losses
def train_with_accumulation_and_amp(
model: nn.Module,
data: List[Tuple[torch.Tensor, torch.Tensor]],
lr: float,
epochs: int,
accumulation_steps: int,
) -> List[float]:
"""Gradient accumulation combined with mixed precision."""
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()
losses = []
for epoch in range(epochs):
epoch_loss = 0.0
optimizer.zero_grad()
for step, (x, y) in enumerate(data):
with autocast():
logits = model(x)
loss = criterion(logits, y)
# Scale for accumulation, then scale for AMP
scaler.scale(loss / accumulation_steps).backward()
epoch_loss += loss.item()
if (step + 1) % accumulation_steps == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
remaining = len(data) % accumulation_steps
if remaining != 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
losses.append(epoch_loss / len(data))
return losses
# ---------- demo ----------
if __name__ == "__main__":
torch.manual_seed(42)
d_in, d_out = 64, 10
micro_batch_size = 16
accumulation_steps = 4
effective_batch_size = micro_batch_size * accumulation_steps # 64
# Create large-batch data for reference
large_data = create_data(256, d_in, d_out, effective_batch_size)
# Create micro-batch data
micro_data = create_data(256, d_in, d_out, micro_batch_size)
# Simple linear model for exact comparison
model_large = nn.Linear(d_in, d_out)
model_accum = nn.Linear(d_in, d_out)
# Use same initial weights
model_accum.load_state_dict(model_large.state_dict())
# Train both
losses_large = train_standard(model_large, large_data, lr=0.01, epochs=10)
losses_accum = train_with_accumulation(
model_accum, micro_data, lr=0.01, epochs=10, accumulation_steps=accumulation_steps
)
print("Large-batch losses:", [f"{l:.4f}" for l in losses_large])
print("Accum losses: ", [f"{l:.4f}" for l in losses_accum])
    # Verify weight similarity (here each accumulation window aligns exactly with
    # a large batch, so differences come only from floating-point rounding)
w_diff = (model_large.weight - model_accum.weight).abs().max().item()
print(f"\nMax weight difference: {w_diff:.6f}")
print("Weights are approximately equal." if w_diff < 0.1 else "Weights diverged.")
Walkthrough
- Loss scaling -- We divide the loss by accumulation_steps before .backward(). Since gradients are additive, accumulating K gradients each scaled by 1/K is equivalent to one gradient from a K-times-larger batch. Without this scaling, gradients would be K times too large.
- Delayed optimizer step -- optimizer.zero_grad() and optimizer.step() are called only at accumulation boundaries. Between these calls, .backward() adds to the existing .grad tensors.
- Boundary handling -- If the number of micro-batches is not divisible by accumulation_steps, the remaining gradients must still be applied. In production, you might rescale the remaining gradients to account for the smaller accumulation window.
- AMP integration -- GradScaler wraps the already-scaled loss. The order is: divide by accumulation_steps, then let scaler.scale() multiply by the AMP scale factor. Both scalings compose linearly.
- Memory savings -- The GPU only ever holds one micro-batch's activations. Peak memory is O(micro_batch_size) instead of O(effective_batch_size), enabling much larger effective batch sizes.
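The equivalence between accumulated micro-batch gradients and the full-batch gradient can be checked numerically on a toy model (a sketch; the shapes and data are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(6, 3)
criterion = nn.CrossEntropyLoss()
x = torch.randn(8, 6)
y = torch.randint(0, 3, (8,))

# Gradient from one full batch of 8.
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated gradient from two micro-batches of 4, each loss scaled by 1/2.
model.zero_grad()
for xb, yb in ((x[:4], y[:4]), (x[4:], y[4:])):
    (criterion(model(xb), yb) / 2).backward()

# The two gradients agree up to floating-point error.
print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True
```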
Complexity Analysis
- Compute: Identical to training with the full batch. Each sample is processed once per epoch regardless.
- Memory: O(micro_batch_size * model_activations) peak, not O(effective_batch_size * model_activations). With accumulation_steps=8, peak activation memory is roughly 1/8 of what materializing the full effective batch would require.
- Speed: Slightly slower than true large-batch training because micro-batches are processed sequentially (no parallelism within one accumulation window). But memory savings often allow larger models or longer sequences.
Interview Tips
Key points to cover: (1) Why divide loss by accumulation_steps, not multiply learning rate (both work mathematically, but loss scaling keeps gradient norms comparable, which matters for gradient clipping and logging). (2) Interaction with learning rate warmup: the effective batch size determines the appropriate peak learning rate (linear scaling rule). (3) Interaction with batch normalization: BN statistics are computed per micro-batch, not per effective batch, which can change behavior. (4) Interaction with gradient clipping: clip after accumulation, not after each micro-batch. (5) This is how every large model (GPT-4, LLaMA) is trained -- understanding it is essential.
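On the DDP question below: the usual answer is the no_sync() context manager, which suppresses the gradient all-reduce on non-boundary micro-batches so ranks synchronize only once per accumulation window. A runnable single-process sketch (a real setup would launch one process per GPU; the gloo backend and port are illustrative assumptions):

```python
import os
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "cluster" so the pattern runs on one CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
accumulation_steps = 4
data = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(4)]

optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    is_boundary = (step + 1) % accumulation_steps == 0
    # no_sync() skips the all-reduce for this forward/backward pass;
    # the boundary micro-batch runs normally and triggers one sync.
    ctx = nullcontext() if is_boundary else model.no_sync()
    with ctx:
        loss = criterion(model(x), y) / accumulation_steps
        loss.backward()
    if is_boundary:
        optimizer.step()
        optimizer.zero_grad()

dist.destroy_process_group()
```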
Quiz — 3 Questions
Why must the loss be divided by accumulation_steps before calling backward()?
When using gradient accumulation with DDP (DistributedDataParallel), how do you avoid redundant gradient synchronization?
How does gradient accumulation interact with batch normalization?