Exercise: Mixed-Precision Training with AMP
Implement a training loop using Automatic Mixed Precision (AMP) -- the standard technique for cutting training memory and time nearly in half on modern GPUs.
Problem Statement
Implement a training function train_with_amp that:
- Uses `torch.cuda.amp.autocast` to run forward passes in float16 (or bfloat16)
- Uses `torch.cuda.amp.GradScaler` to prevent underflow in float16 gradients
- Properly handles the scale-unscale-step-update cycle
- Compares training with and without AMP to verify correctness
Inputs: A model, data loader, optimizer, loss function, and number of epochs.
Outputs: Training loss history, demonstrating that AMP achieves similar loss to float32 training.
Mixed precision uses float16 for most operations (faster, less memory) while keeping a float32 master copy of the weights for accumulation. The challenge is that float16 has a narrow dynamic range: its smallest representable positive value is about 6e-8, so small gradients underflow to zero. GradScaler multiplies the loss by a large factor before backward, inflating the gradients into representable range, then unscales them before the optimizer step.
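To see the underflow problem concretely, here is a minimal sketch (the 65536 factor matches GradScaler's default initial scale):

```python
import torch

g = torch.tensor(1e-8)  # a tiny fp32 gradient value

# Cast to fp16: 1e-8 is below fp16's smallest representable value,
# so it rounds to zero and the gradient signal is lost.
print(g.to(torch.float16).item())            # 0.0

# Scale by 65536 first: the value lands inside fp16's range.
print((g * 65536).to(torch.float16).item())  # ~6.55e-4
```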
┌───────────────────────────────────────────────────────────────┐
│ Mixed-Precision Training Cycle │
│ │
│ ┌─────────────────────────────────────┐ │
│ │ FORWARD (autocast context) │ │
│ │ ┌────────────┐ ┌────────────┐ │ │
│ │ │ Linear │ │ Softmax │ │ │
│ │ │ (fp16) │ │ (fp32) │ │ │
│ │ │ fast! │ │ precise! │ │ │
│ │ └────────────┘ └────────────┘ │ │
│ └────────────────────┬────────────────┘ │
│ │ │
│ ▼ loss (fp32) │
│ ┌─────────────────────────────────────┐ │
│ │ SCALE: loss × 65536 │ │
│ │ (inflate to prevent fp16 underflow)│ │
│ └────────────────────┬────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────┐ │
│ │ BACKWARD: compute scaled gradients │ │
│ │ grads are 65536× larger than true │ │
│ └────────────────────┬────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────┐ │
│ │ UNSCALE: grads / 65536 │ ──▶ inf/nan? │
│ │ CHECK: any overflow? │ skip step, │
│ └────────────────────┬────────────────┘ halve scale │
│ │ (no overflow) │
│ ▼ │
│ ┌─────────────────────────────────────┐ │
│ │ OPTIMIZER STEP (fp32 master weights)│ │
│ │ UPDATE SCALE (increase if stable) │ │
│ └─────────────────────────────────────┘ │
│ │
│ Memory: fp16 activations ≈ half the memory of fp32 │
│ Speed: tensor cores operate on fp16 matrices ≈ 2× faster │
└───────────────────────────────────────────────────────────────┘
Do NOT wrap the backward pass in autocast(). Only the forward pass should run under autocast. The backward pass should compute gradients in the original dtypes. Wrapping backward in autocast can cause incorrect gradient types and numerical issues. Also, never call optimizer.step() directly with AMP -- always use scaler.step(optimizer), which handles gradient unscaling and overflow detection.
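For intuition, the unscale-check-step logic that scaler.step(optimizer) performs internally can be sketched roughly as below. This is a simplified illustration, not PyTorch's actual implementation (which also tracks per-optimizer state and uses fused inf checks); `manual_scaler_step` is a hypothetical helper name.

```python
import torch

def manual_scaler_step(optimizer: torch.optim.Optimizer, scale: float) -> bool:
    """Simplified sketch of scaler.step(): unscale, check, maybe step.

    Returns True if the step was skipped due to inf/nan gradients.
    """
    found_inf = False
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is not None:
                p.grad.div_(scale)  # unscale the gradient in place
                if not torch.isfinite(p.grad).all():
                    found_inf = True
    if not found_inf:
        optimizer.step()  # gradients are valid: apply the update
    return found_inf
```

If an inf/nan is found, the real GradScaler skips the step entirely, and the subsequent `update()` halves the scale for the next iteration.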
Know which operations stay in fp32 under autocast: softmax, layer normalization, loss functions, and any operation that involves large reductions. These are kept in fp32 because they are numerically sensitive (small errors in summing many values accumulate). Matrix multiplications and convolutions run in fp16 because tensor cores handle them natively and the per-element error is acceptable.
Hints
- Wrap the forward pass and loss computation in `with torch.cuda.amp.autocast():`.
- Call `scaler.scale(loss).backward()` instead of `loss.backward()`.
- Call `scaler.step(optimizer)` instead of `optimizer.step()`.
- Call `scaler.update()` after each step to adjust the scale factor.
- Do NOT apply `autocast` to the backward pass -- only the forward.
- `GradScaler` automatically handles inf/nan detection: if gradients overflow, it skips the step and reduces the scale factor.
Solution
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from typing import List, Tuple


class SimpleModel(nn.Module):
    """A small MLP for demonstration."""

    def __init__(self, d_in: int, d_hidden: int, d_out: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
def train_fp32(
    model: nn.Module,
    data: List[Tuple[torch.Tensor, torch.Tensor]],
    lr: float = 1e-3,
    epochs: int = 10,
) -> List[float]:
    """Standard float32 training loop."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses: List[float] = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        for x_batch, y_batch in data:
            optimizer.zero_grad()
            logits = model(x_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        losses.append(epoch_loss / len(data))
    return losses
def train_with_amp(
    model: nn.Module,
    data: List[Tuple[torch.Tensor, torch.Tensor]],
    lr: float = 1e-3,
    epochs: int = 10,
    use_bfloat16: bool = False,
) -> List[float]:
    """Training loop with Automatic Mixed Precision."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # GradScaler prevents gradient underflow in float16.
    # Not needed for bfloat16 (same exponent range as float32), and
    # torch.cuda.amp requires CUDA tensors, so disable it on CPU too.
    use_cuda = torch.cuda.is_available()
    scaler = GradScaler(enabled=use_cuda and not use_bfloat16)
    amp_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
    losses: List[float] = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        for x_batch, y_batch in data:
            optimizer.zero_grad()
            # Forward pass in mixed precision
            with autocast(dtype=amp_dtype, enabled=use_cuda):
                logits = model(x_batch)
                loss = criterion(logits, y_batch)
                # Note: loss is float32 (autocast promotes the loss)
            # Backward pass: scale loss to prevent fp16 underflow
            scaler.scale(loss).backward()
            # Unscale gradients, then step (skips if inf/nan detected)
            scaler.step(optimizer)
            # Adjust the scale factor for the next iteration
            scaler.update()
            epoch_loss += loss.item()
        losses.append(epoch_loss / len(data))
    return losses
def compare_dtypes_in_autocast(model: nn.Module, x: torch.Tensor) -> None:
    """Show which operations run in fp16 vs fp32 under autocast."""
    print("\n--- Dtype inspection under autocast ---")
    with autocast():
        # Linear layers run in fp16 (fast on tensor cores)
        h = model.net[0](x)
        print(f"After Linear: {h.dtype}")  # float16
        # ReLU preserves dtype
        h = model.net[1](h)
        print(f"After ReLU: {h.dtype}")  # float16
        # Loss computation is promoted to float32 for numerical stability
        targets = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        loss = nn.CrossEntropyLoss()(model(x), targets)
        print(f"Loss dtype: {loss.dtype}")  # float32
# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create synthetic data
    d_in, d_hidden, d_out = 128, 256, 10
    num_batches, batch_size = 20, 64
    data = [
        (torch.randn(batch_size, d_in, device=device),
         torch.randint(0, d_out, (batch_size,), device=device))
        for _ in range(num_batches)
    ]

    # Train with fp32, saving the initial weights first
    model_fp32 = SimpleModel(d_in, d_hidden, d_out).to(device)
    init_state = {k: v.clone() for k, v in model_fp32.state_dict().items()}
    losses_fp32 = train_fp32(model_fp32, data, epochs=20)

    # Train with AMP from the same initial weights
    model_amp = SimpleModel(d_in, d_hidden, d_out).to(device)
    model_amp.load_state_dict(init_state)
    losses_amp = train_with_amp(model_amp, data, epochs=20)

    print("FP32 losses:", [f"{l:.4f}" for l in losses_fp32[:5]], "...")
    print("AMP  losses:", [f"{l:.4f}" for l in losses_amp[:5]], "...")
    print(f"FP32 final loss: {losses_fp32[-1]:.4f}")
    print(f"AMP  final loss: {losses_amp[-1]:.4f}")

    if device == "cuda":
        compare_dtypes_in_autocast(model_amp, data[0][0])
Walkthrough
- autocast context -- Inside `with autocast():`, PyTorch automatically casts eligible operations (matmul, convolution) to float16 while keeping numerically sensitive operations (softmax, loss, layer norm) in float32. You do not need to manually cast anything.
- GradScaler.scale() -- Multiplies the loss by a large factor (initially 65536) before `.backward()`. This inflates all gradients proportionally, pushing them out of the float16 underflow zone.
- GradScaler.step() -- Internally calls `scaler.unscale_(optimizer)` to divide gradients by the scale factor, checks for inf/nan, and then calls `optimizer.step()` if the gradients are valid. If any gradient is inf/nan, the step is skipped entirely.
- GradScaler.update() -- Adjusts the scale factor dynamically. If recent steps succeeded, it increases the scale (more aggressive); if a step was skipped (overflow detected), it halves the scale.
- bfloat16 option -- bfloat16 has the same exponent range as float32, so gradient underflow is not an issue and `GradScaler` is not needed. However, bfloat16 has less mantissa precision than float16.
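The range-versus-precision trade-off between the three dtypes can be inspected directly with `torch.finfo` (`tiny` is the smallest positive normal value; `eps` reflects mantissa precision):

```python
import torch

for dt in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dt)
    # bfloat16 matches fp32's range (tiny) but has the coarsest precision (eps)
    print(dt, "tiny:", info.tiny, "eps:", info.eps)
```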
Complexity Analysis
- Memory reduction: ~2x for activations stored for backward (float16 uses half the bytes). Weight master copies remain in float32.
- Speed improvement: ~1.5-2x on GPUs with tensor cores (e.g., A100, H100). Tensor cores multiply small fp16 matrix tiles natively.
- Accuracy cost: Typically negligible. AMP is standard practice for training models of all sizes.
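The activation-memory claim can be sanity-checked from element sizes alone (a minimal sketch; the realized savings depend on which activations autocast keeps in fp32):

```python
import torch

x32 = torch.randn(1024, 1024)        # an fp32 activation tensor
x16 = x32.to(torch.float16)

bytes32 = x32.element_size() * x32.numel()  # 4 bytes per element
bytes16 = x16.element_size() * x16.numel()  # 2 bytes per element
print(bytes16 / bytes32)                    # 0.5 -- half the activation memory
```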
Interview Tips
Key points: (1) Explain the cycle: scale -> backward -> step (unscale + overflow check) -> update. (2) Why GradScaler is needed for fp16 but not bfloat16 (exponent range difference). (3) Which operations stay in fp32 under autocast (softmax, layer norm, loss -- all reduction operations). (4) The difference between mixed precision and pure fp16 training (mixed keeps fp32 master weights). (5) Memory savings: activations are ~half the size, but weights remain fp32 (the master copy). (6) In practice, AMP training is nearly always used -- not using it wastes GPU capacity.
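The "fp32 master weights" point can be verified directly. The sketch below uses CPU autocast with bfloat16 so it runs without a GPU; the Linear-runs-in-low-precision behavior is an assumption about PyTorch's CPU autocast op policy, but the parameters themselves always stay float32:

```python
import torch

model = torch.nn.Linear(8, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(torch.randn(2, 8))

# Only selected ops run in low precision; the weights are never converted.
print(next(model.parameters()).dtype)  # torch.float32
print(out.dtype)                       # torch.bfloat16
```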
Quiz
Why does GradScaler multiply the loss by a large factor before calling backward()?
Why is GradScaler not needed when using bfloat16 instead of float16?
In mixed-precision training, why are model weights kept in fp32 (the 'master copy')?