Akshay’s Gradient
ML Coding · Intermediate · 35 min

Gradient Accumulation

Exercise: Gradient Accumulation

Implement gradient accumulation to simulate large batch training on limited GPU memory. This is how practitioners train with effective batch sizes of thousands on a single GPU.

Problem Statement

Implement a train_with_accumulation function that:

  1. Processes micro-batches sequentially, accumulating gradients
  2. Only calls optimizer.step() after every accumulation_steps micro-batches
  3. Correctly scales the loss so the effective gradient is averaged over the full accumulated batch
  4. Works with mixed precision (AMP) as well

Compare the results with standard full-batch training to verify equivalence.

Inputs: Model, data loader, optimizer, number of accumulation steps.

Outputs: Training loss history, equivalent to training with batch_size * accumulation_steps effective batch size.

Key Concept

Gradient accumulation exploits the linearity of gradients: the gradient of a sum equals the sum of gradients. By calling .backward() multiple times before .step(), gradients from each micro-batch add up in the .grad attributes. Dividing the loss by accumulation_steps ensures the accumulated gradient equals the gradient of the full batch.
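This linearity can be checked directly in a few lines. The snippet below is a standalone sketch (the toy nn.Linear model and synthetic data are illustrative assumptions, not part of the exercise): two backward passes over half-batches, each with the loss divided by 2, accumulate to the same .grad as one backward pass over the full batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

# One backward pass over the full batch of 8
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Two backward passes over micro-batches of 4, each loss divided by 2
model.zero_grad()
for xb, yb in ((x[:4], y[:4]), (x[4:], y[4:])):
    (criterion(model(xb), yb) / 2).backward()

# The accumulated .grad equals the full-batch gradient
print(torch.allclose(model.weight.grad, full_grad, atol=1e-6))  # True
```

The division by 2 is what turns the sum of two micro-batch means into the mean over all 8 samples; without it the accumulated gradient would be twice as large.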

Gradient Accumulation Flow
┌─────────────────────────────────────────────────────────────────────┐
│        Gradient Accumulation (accumulation_steps = 4)               │
│                                                                     │
│   Micro-batch 1        Micro-batch 2       Micro-batch 3           │
│   ┌──────────┐         ┌──────────┐        ┌──────────┐            │
│   │ loss/4   │         │ loss/4   │        │ loss/4   │            │
│   │ .backward│         │ .backward│        │ .backward│            │
│   └────┬─────┘         └────┬─────┘        └────┬─────┘            │
│        │                    │                    │                   │
│        ▼                    ▼                    ▼                   │
│   .grad = g1          .grad = g1+g2       .grad = g1+g2+g3        │
│                                                                     │
│   Micro-batch 4                                                     │
│   ┌──────────┐                                                      │
│   │ loss/4   │        ┌────────────────────┐                        │
│   │ .backward│        │  optimizer.step()  │  ← apply accumulated  │
│   └────┬─────┘        │  zero_grad()       │    gradient            │
│        │              └────────────────────┘                        │
│        ▼                                                            │
│   .grad = g1+g2+g3+g4 = avg gradient over effective batch          │
│                        ≡ gradient from batch_size * 4               │
│                                                                     │
│   GPU Memory: O(micro_batch) instead of O(effective_batch)          │
│   Result:     Mathematically equivalent to large-batch training     │
└─────────────────────────────────────────────────────────────────────┘
Warning

Do NOT apply gradient clipping after each micro-batch. Clipping must happen after all accumulation steps, right before optimizer.step(). Clipping per micro-batch changes the effective gradient direction and breaks the equivalence with large-batch training.
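A minimal sketch of the correct placement (toy model, synthetic data, and max_norm=1.0 are assumptions for illustration): clipping runs once per accumulation window, on the fully accumulated gradient, immediately before the step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
accumulation_steps = 4
data = [(torch.randn(2, 4), torch.randint(0, 2, (2,))) for _ in range(8)]

clipped_norms = []
optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    (criterion(model(x), y) / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        # Clip the fully ACCUMULATED gradient once per window,
        # immediately before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        total = torch.cat([p.grad.flatten() for p in model.parameters()]).norm()
        clipped_norms.append(total.item())
        optimizer.step()
        optimizer.zero_grad()

print(clipped_norms)  # every post-clip norm is <= 1.0
```

Moving the clip_grad_norm_ call inside every iteration would clip each micro-batch gradient separately, which changes the direction of their sum.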

Hints

Info
  1. Do NOT call optimizer.zero_grad() at every micro-batch -- only every accumulation_steps steps.
  2. Divide the loss by accumulation_steps before calling .backward(). This ensures the accumulated gradient is a proper average, not a sum.
  3. Call optimizer.step() and optimizer.zero_grad() only when (step + 1) % accumulation_steps == 0.
  4. Handle the last incomplete accumulation window (if dataset size is not divisible).
  5. With AMP, use scaler.scale(loss / accumulation_steps).backward() and only call scaler.step() at accumulation boundaries.

Solution

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from typing import List, Tuple


def create_data(
    num_samples: int, d_in: int, d_out: int, micro_batch_size: int
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Create synthetic data as a list of micro-batches."""
    data = []
    for i in range(0, num_samples, micro_batch_size):
        batch_size = min(micro_batch_size, num_samples - i)
        x = torch.randn(batch_size, d_in)
        y = torch.randint(0, d_out, (batch_size,))
        data.append((x, y))
    return data


def train_standard(
    model: nn.Module,
    data: List[Tuple[torch.Tensor, torch.Tensor]],
    lr: float,
    epochs: int,
) -> List[float]:
    """Standard training (one step per batch)."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0.0
        for x, y in data:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        losses.append(epoch_loss / len(data))
    return losses


def train_with_accumulation(
    model: nn.Module,
    data: List[Tuple[torch.Tensor, torch.Tensor]],
    lr: float,
    epochs: int,
    accumulation_steps: int,
) -> List[float]:
    """Training with gradient accumulation."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0.0
        optimizer.zero_grad()  # Zero at the start of epoch

        for step, (x, y) in enumerate(data):
            # Forward pass
            logits = model(x)
            loss = criterion(logits, y)

            # Scale loss by accumulation steps to get correct average
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()  # Gradients accumulate in .grad

            epoch_loss += loss.item()

            # Step only at accumulation boundaries
            if (step + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Handle remaining micro-batches (if any)
        remaining = len(data) % accumulation_steps
        if remaining != 0:
            optimizer.step()
            optimizer.zero_grad()

        losses.append(epoch_loss / len(data))
    return losses


def train_with_accumulation_and_amp(
    model: nn.Module,
    data: List[Tuple[torch.Tensor, torch.Tensor]],
    lr: float,
    epochs: int,
    accumulation_steps: int,
) -> List[float]:
    """Gradient accumulation combined with mixed precision."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0.0
        optimizer.zero_grad()

        for step, (x, y) in enumerate(data):
            with autocast():
                logits = model(x)
                loss = criterion(logits, y)

            # Scale for accumulation, then scale for AMP
            scaler.scale(loss / accumulation_steps).backward()

            epoch_loss += loss.item()

            if (step + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        remaining = len(data) % accumulation_steps
        if remaining != 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        losses.append(epoch_loss / len(data))
    return losses


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    d_in, d_out = 64, 10
    micro_batch_size = 16
    accumulation_steps = 4
    effective_batch_size = micro_batch_size * accumulation_steps  # 64

    # Create large-batch data, then split each large batch into micro-batches
    # so both runs see the SAME samples in the SAME order
    large_data = create_data(256, d_in, d_out, effective_batch_size)
    micro_data = [
        (x[i : i + micro_batch_size], y[i : i + micro_batch_size])
        for (x, y) in large_data
        for i in range(0, effective_batch_size, micro_batch_size)
    ]

    # Simple linear model for exact comparison
    model_large = nn.Linear(d_in, d_out)
    model_accum = nn.Linear(d_in, d_out)
    # Use same initial weights
    model_accum.load_state_dict(model_large.state_dict())

    # Train both
    losses_large = train_standard(model_large, large_data, lr=0.01, epochs=10)
    losses_accum = train_with_accumulation(
        model_accum, micro_data, lr=0.01, epochs=10, accumulation_steps=accumulation_steps
    )

    print("Large-batch losses:", [f"{l:.4f}" for l in losses_large])
    print("Accum losses:      ", [f"{l:.4f}" for l in losses_accum])

    # With plain SGD and correct loss scaling, the two runs are mathematically
    # identical; any residual difference is floating-point summation order.
    w_diff = (model_large.weight - model_accum.weight).abs().max().item()
    print(f"\nMax weight difference: {w_diff:.6f}")
    print("Weights are approximately equal." if w_diff < 1e-4 else "Weights diverged.")

Walkthrough

  1. Loss scaling -- We divide the loss by accumulation_steps before .backward(). Since gradients are additive, accumulating K scaled-by-1/K gradients is equivalent to one gradient from a K-times-larger batch. Without this scaling, gradients would be K times too large.

  2. Delayed optimizer step -- optimizer.zero_grad() and optimizer.step() are called only at accumulation boundaries. Between these calls, .backward() adds to the existing .grad tensors.

  3. Boundary handling -- If the number of micro-batches is not divisible by accumulation_steps, the remaining gradients must still be applied. In production, you might rescale the remaining gradients to account for the smaller accumulation window.
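That rescaling can be sketched as follows (a standalone illustration with a toy model; the factor accumulation_steps / remaining is one reasonable choice, turning the under-weighted leftover gradient into a proper average over the trailing window):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
accumulation_steps = 4
# 6 micro-batches: one full window of 4, then a trailing window of 2
data = [(torch.randn(2, 4), torch.randint(0, 2, (2,))) for _ in range(6)]

optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    (criterion(model(x), y) / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

remaining = len(data) % accumulation_steps
if remaining:
    # Each leftover loss was divided by accumulation_steps (4), so the
    # accumulated grad is (sum of 2 micro-grads) / 4; multiplying by
    # 4 / 2 turns it into an average over the 2 leftover micro-batches.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(accumulation_steps / remaining)
    optimizer.step()
    optimizer.zero_grad()
```

Without the rescale, the final step uses a gradient that is remaining / accumulation_steps times too small, which is usually harmless for one step but easy to fix.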

  4. AMP integration -- GradScaler wraps the already-scaled loss. The order is: divide by accumulation_steps, then let scaler.scale() multiply by the AMP scale factor. Both scalings compose linearly.

  5. Memory savings -- The GPU only ever holds one micro-batch's activations. Peak memory is O(micro_batch_size) instead of O(effective_batch_size), enabling much larger effective batch sizes.

Complexity Analysis

  • Compute: Identical to training with the full batch. Each sample is processed once per epoch regardless.
  • Memory: O(micro_batch_size * model_activations) peak, not O(effective_batch_size * model_activations). For a model with 100M activation values and accumulation_steps=8, this saves 7/8 of activation memory.
  • Speed: Slightly slower than true large-batch training because micro-batches are processed sequentially (no parallelism within one accumulation window). But memory savings often allow larger models or longer sequences.

Interview Tips

Interview Tip

Key points to cover:

  1. Why divide the loss by accumulation_steps rather than multiplying the learning rate: the two are equivalent for plain SGD, but loss scaling keeps gradient norms comparable, which matters for gradient clipping and logging.
  2. Interaction with learning rate warmup: the effective batch size determines the appropriate peak learning rate (linear scaling rule).
  3. Interaction with batch normalization: BN statistics are computed per micro-batch, not per effective batch, which can change behavior.
  4. Interaction with gradient clipping: clip after accumulation, not after each micro-batch.
  5. Gradient accumulation is standard practice for training large language models (e.g., the LLaMA family) -- understanding it is essential.

Quiz — 3 Questions

Why must the loss be divided by accumulation_steps before calling backward()?

When using gradient accumulation with DDP (DistributedDataParallel), how do you avoid redundant gradient synchronization?

How does gradient accumulation interact with batch normalization?
