Akshay’s Gradient
ML Coding · Advanced · 50 min

Mixed Precision Training Loop

Exercise: Mixed-Precision Training with AMP

Implement a training loop using Automatic Mixed Precision (AMP) -- the standard technique for cutting training memory and time nearly in half on modern GPUs.

Problem Statement

Implement a training function train_with_amp that:

  1. Uses torch.cuda.amp.autocast to run forward passes in float16 (or bfloat16)
  2. Uses torch.cuda.amp.GradScaler to prevent underflow in float16 gradients
  3. Properly handles the scale-unscale-step-update cycle
  4. Compares training with and without AMP to verify correctness

Inputs: A model, data loader, optimizer, loss function, and number of epochs.

Outputs: Training loss history, demonstrating that AMP achieves similar loss to float32 training.

Key Concept

Mixed precision uses float16 for most operations (faster, less memory) while keeping a float32 master copy of weights for accumulation. The challenge is that float16 has a narrow dynamic range (smallest positive subnormal ≈ 6e-8), so small gradients underflow to zero. GradScaler multiplies the loss by a large factor before backward, inflating gradients into a representable range, then unscales them before the optimizer step.
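The underflow problem is easy to see directly. A minimal sketch (runs on CPU, no GPU needed):

```python
import torch

# A gradient-sized value below fp16's subnormal floor (~6e-8) vanishes...
tiny = torch.tensor(1e-8)
print(tiny.half().item())            # 0.0 -- underflowed to zero

# ...but after multiplying by the default scale factor (65536) it survives
print((tiny * 65536).half().item())  # ~6.55e-4 -- representable in fp16
```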

Interactive · Mixed-Precision Training: The Scale-Unscale Cycle
┌──────────────────────────────────────────────────────────────┐
│                Mixed-Precision Training Cycle                │
│                                                              │
│  ┌─────────────────────────────────────┐                     │
│  │  FORWARD (autocast context)         │                     │
│  │  ┌────────────┐    ┌────────────┐   │                     │
│  │  │  Linear    │    │  Softmax   │   │                     │
│  │  │  (fp16)    │    │  (fp32)    │   │                     │
│  │  │  fast!     │    │  precise!  │   │                     │
│  │  └────────────┘    └────────────┘   │                     │
│  └────────────────────┬────────────────┘                     │
│                       │                                      │
│                       ▼ loss (fp32)                          │
│  ┌─────────────────────────────────────┐                     │
│  │  SCALE: loss × 65536                │                     │
│  │  (inflate to prevent fp16 underflow)│                     │
│  └────────────────────┬────────────────┘                     │
│                       │                                      │
│                       ▼                                      │
│  ┌─────────────────────────────────────┐                     │
│  │  BACKWARD: compute scaled gradients │                     │
│  │  grads are 65536× larger than true  │                     │
│  └────────────────────┬────────────────┘                     │
│                       │                                      │
│                       ▼                                      │
│  ┌─────────────────────────────────────┐                     │
│  │  UNSCALE: grads / 65536             │  ──▶ inf/nan?       │
│  │  CHECK: any overflow?               │      skip step,     │
│  └────────────────────┬────────────────┘      halve scale    │
│                       │ (no overflow)                        │
│                       ▼                                      │
│  ┌─────────────────────────────────────┐                     │
│  │  OPTIMIZER STEP (fp32 master wts)   │                     │
│  │  UPDATE SCALE (increase if stable)  │                     │
│  └─────────────────────────────────────┘                     │
│                                                              │
│  Memory: fp16 activations ≈ half the memory of fp32          │
│  Speed:  tensor cores operate on fp16 matrices ≈ 2× faster   │
└──────────────────────────────────────────────────────────────┘
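The scale-backward-unscale steps in the diagram can be mimicked by hand without GradScaler. A small sketch in plain float32 (so the arithmetic is exact and runs on CPU):

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
loss = (w * 1e-8).sum()      # true gradient w.r.t. w is 1e-8
scale = 65536.0

(loss * scale).backward()    # SCALE + BACKWARD: grad is 65536x the true value
print(w.grad.item())         # ~6.5536e-04
w.grad /= scale              # UNSCALE before the optimizer would step
print(w.grad.item())         # ~1e-08, the true gradient
```

In real AMP the same arithmetic happens in fp16, which is exactly why the inflation step matters: without it, the 1e-8 gradient would round to zero.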
Warning

Do NOT wrap the backward pass in autocast(). Only the forward pass should run under autocast. The backward pass should compute gradients in the original dtypes. Wrapping backward in autocast can cause incorrect gradient types and numerical issues. Also, never call optimizer.step() directly with AMP -- always use scaler.step(optimizer), which handles gradient unscaling and overflow detection.
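A related gotcha: gradient clipping must also operate on unscaled gradients, via scaler.unscale_. A hedged sketch of the ordering (runs even without a GPU, where the scaler simply disables itself and passes everything through):

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=torch.cuda.is_available())  # no-op on CPU

x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)   # bring grads back to true magnitude FIRST
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)       # remembers the explicit unscale_; won't divide twice
scaler.update()
```

Clipping scaled gradients would compare an inflated norm against max_norm and clip far too aggressively.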

Interview Tip

Know which operations stay in fp32 under autocast: softmax, layer normalization, loss functions, and any operation that involves large reductions. These are kept in fp32 because they are numerically sensitive (small errors in summing many values accumulate). Matrix multiplications and convolutions run in fp16 because tensor cores handle them natively and the per-element error is acceptable.
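You can observe the casting policy without a GPU by using CPU autocast with bfloat16 (used here as a stand-in for CUDA fp16 autocast; the exact op lists differ slightly between devices):

```python
import torch
import torch.nn as nn

lin = nn.Linear(8, 8)
x = torch.randn(4, 8)

with torch.autocast("cpu", dtype=torch.bfloat16):
    h = lin(x)               # linear/matmul: autocast runs it in low precision

print(h.dtype)           # torch.bfloat16 -- the activation was cast down
print(lin.weight.dtype)  # torch.float32  -- parameters themselves stay fp32
```

Note that the parameters never change dtype: autocast casts inputs per-op, it does not convert the model.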

Hints

Info
  1. Wrap the forward pass and loss computation in with torch.cuda.amp.autocast():.
  2. Call scaler.scale(loss).backward() instead of loss.backward().
  3. Call scaler.step(optimizer) instead of optimizer.step().
  4. Call scaler.update() after each step to adjust the scale factor.
  5. Do NOT call autocast on the backward pass -- only the forward.
  6. GradScaler automatically handles inf/nan detection: if gradients overflow, it skips the step and reduces the scale factor.

Solution

import torch
import torch.nn as nn
# Note: newer PyTorch (>= 2.4) prefers torch.amp.autocast("cuda") and
# torch.amp.GradScaler("cuda"); the torch.cuda.amp spellings still work.
from torch.cuda.amp import autocast, GradScaler
from typing import List, Tuple


class SimpleModel(nn.Module):
    """A small MLP for demonstration."""

    def __init__(self, d_in: int, d_hidden: int, d_out: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def train_fp32(
    model: nn.Module,
    data: List[Tuple[torch.Tensor, torch.Tensor]],
    lr: float = 1e-3,
    epochs: int = 10,
) -> List[float]:
    """Standard float32 training loop."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses: List[float] = []

    for epoch in range(epochs):
        epoch_loss = 0.0
        for x_batch, y_batch in data:
            optimizer.zero_grad()
            logits = model(x_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        losses.append(epoch_loss / len(data))

    return losses


def train_with_amp(
    model: nn.Module,
    data: List[Tuple[torch.Tensor, torch.Tensor]],
    lr: float = 1e-3,
    epochs: int = 10,
    use_bfloat16: bool = False,
) -> List[float]:
    """Training loop with Automatic Mixed Precision."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # GradScaler prevents gradient underflow in float16
    # Not needed for bfloat16 (it has the same exponent range as float32)
    scaler = GradScaler(enabled=not use_bfloat16)
    amp_dtype = torch.bfloat16 if use_bfloat16 else torch.float16

    losses: List[float] = []

    for epoch in range(epochs):
        epoch_loss = 0.0
        for x_batch, y_batch in data:
            optimizer.zero_grad()

            # Forward pass in mixed precision
            with autocast(dtype=amp_dtype):
                logits = model(x_batch)
                loss = criterion(logits, y_batch)
            # Note: loss is float32 (autocast promotes the loss)

            # Backward pass: scale loss to prevent fp16 underflow
            scaler.scale(loss).backward()

            # Unscale gradients, then step (skips if inf/nan detected)
            scaler.step(optimizer)

            # Adjust the scale factor for the next iteration
            scaler.update()

            epoch_loss += loss.item()
        losses.append(epoch_loss / len(data))

    return losses


def compare_dtypes_in_autocast(model: nn.Module, x: torch.Tensor) -> None:
    """Show which operations run in fp16 vs fp32 under autocast."""
    print("\n--- Dtype inspection under autocast ---")
    with autocast():
        # Linear layers run in fp16 (fast on tensor cores)
        h = model.net[0](x)
        print(f"After Linear:  {h.dtype}")  # float16

        # ReLU preserves dtype
        h = model.net[1](h)
        print(f"After ReLU:    {h.dtype}")  # float16

        # Loss computation is promoted to float32 for numerical stability
        loss = nn.CrossEntropyLoss()(model(x), torch.zeros(x.size(0), dtype=torch.long, device=x.device))
        print(f"Loss dtype:    {loss.dtype}")  # float32


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create synthetic data
    d_in, d_hidden, d_out = 128, 256, 10
    num_batches, batch_size = 20, 64
    data = [
        (torch.randn(batch_size, d_in, device=device),
         torch.randint(0, d_out, (batch_size,), device=device))
        for _ in range(num_batches)
    ]

    # Train with fp32, saving the initial weights so both runs start identically
    model_fp32 = SimpleModel(d_in, d_hidden, d_out).to(device)
    initial_state = {k: v.clone() for k, v in model_fp32.state_dict().items()}
    losses_fp32 = train_fp32(model_fp32, data, epochs=20)

    # Train with AMP from the same initial weights
    model_amp = SimpleModel(d_in, d_hidden, d_out).to(device)
    model_amp.load_state_dict(initial_state)
    losses_amp = train_with_amp(model_amp, data, epochs=20)

    print("FP32 losses:", [f"{l:.4f}" for l in losses_fp32[:5]], "...")
    print("AMP  losses:", [f"{l:.4f}" for l in losses_amp[:5]], "...")
    print(f"FP32 final loss: {losses_fp32[-1]:.4f}")
    print(f"AMP  final loss: {losses_amp[-1]:.4f}")

    if device == "cuda":
        compare_dtypes_in_autocast(model_amp, data[0][0])

Walkthrough

  1. autocast context -- Inside with autocast(), PyTorch automatically casts eligible operations (matmul, convolution) to float16 while keeping numerically sensitive operations (softmax, loss, layer norm) in float32. You do not need to manually cast anything.

  2. GradScaler.scale() -- Multiplies the loss by a large factor (initially 65536) before .backward(). This inflates all gradients proportionally, pushing them out of the float16 underflow zone.

  3. GradScaler.step() -- Internally calls scaler.unscale_(optimizer) to divide gradients by the scale factor, checks for inf/nan, and then calls optimizer.step() if gradients are valid. If any gradient is inf/nan, the step is skipped entirely.

  4. GradScaler.update() -- Adjusts the scale factor dynamically. After a run of successful steps (2000 by default), it doubles the scale to use more of fp16's range; whenever a step is skipped due to overflow, it halves the scale immediately.

  5. bfloat16 option -- bfloat16 has the same exponent range as float32, so gradient underflow is not an issue and GradScaler is not needed. However, bfloat16 has less mantissa precision than float16.
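The range/precision trade-off from point 5 can be checked directly:

```python
import torch

big = torch.tensor(1e5)
print(torch.isinf(big.half()).item())        # True  -- above fp16 max (~65504)
print(torch.isfinite(big.bfloat16()).item()) # True  -- bf16 shares fp32's exponent range

x = 1.01
err_fp16 = abs(torch.tensor(x).half().item() - x)
err_bf16 = abs(torch.tensor(x).bfloat16().item() - x)
print(err_fp16 < err_bf16)  # True -- fp16's 10 mantissa bits beat bf16's 7 near 1.0
```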

Complexity Analysis

  • Memory reduction: ~2x for activations stored for backward (float16 uses half the bytes). Weight master copies remain in float32.
  • Speed improvement: ~1.5-2x on GPUs with tensor cores (A100, H100). Tensor cores multiply small fp16 matrix tiles (e.g., 16x16 fragments) natively.
  • Accuracy cost: Typically negligible. AMP is standard practice for training models of all sizes.
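A quick back-of-the-envelope for the activation savings, using assumed (hypothetical) layer dimensions:

```python
# One stored activation tensor of shape (batch, seq, hidden)
batch, seq, hidden = 32, 2048, 4096
elems = batch * seq * hidden            # 2**28 elements

print(elems * 4 // 2**20, "MiB fp32")   # 1024 MiB
print(elems * 2 // 2**20, "MiB fp16")   # 512 MiB -- half per saved tensor
```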

Interview Tips

Interview Tip

Key points: (1) Explain the cycle: scale -> backward -> unscale/step -> update. (2) Why GradScaler is needed for fp16 but not bfloat16 (exponent range difference). (3) Which operations stay in fp32 under autocast (softmax, layer norm, loss -- all reduction operations). (4) The difference between mixed precision and pure fp16 training (mixed keeps fp32 master weights). (5) Memory savings: activations are ~half the size, but weights remain fp32 (the master copy). (6) In practice, AMP training is nearly always used -- not using it wastes GPU capacity.

Quiz

Quiz — 3 Questions

Why does GradScaler multiply the loss by a large factor before calling backward()?

Why is GradScaler not needed when using bfloat16 instead of float16?

In mixed-precision training, why are model weights kept in fp32 (the 'master copy')?
