Exercise: Learning Rate Scheduler with Cosine Warmup
Implement a learning rate scheduler with linear warmup followed by cosine decay -- the standard schedule used to train transformers, from BERT to modern GPT-style LLMs.
Problem Statement
Implement a CosineWarmupScheduler class that:
- Linearly increases the learning rate from 0 to max_lr during the warmup phase
- Decays the learning rate following a cosine curve from max_lr to min_lr during the remaining steps
- Works as a PyTorch LRScheduler subclass with a proper step() method
- Also implement a standalone function version for clarity
Inputs: max_lr, min_lr, warmup_steps, total_steps.
Outputs: The learning rate at each step.
Warmup prevents early training instability: at initialization, model outputs are nearly random, so gradients are large and noisy. A small learning rate during warmup lets the model find a reasonable region before applying the full learning rate. Cosine decay smoothly reduces the learning rate toward min_lr (often zero), and in practice tends to outperform step decay and linear decay.
┌──────────────────────────────────────────────────────────────────┐
│ Cosine Warmup Learning Rate Schedule │
│ │
│ LR │
│ max_lr ─ ─ ─ ─ ─ ─ ╱╲ │
│ ╱ ╲ │
│ ╱ ╲ │
│ ╱ ╲ │
│ ╱ ╲ │
│ ╱ ╲ │
│ ╱ ╲ │
│ ╱ ╲ │
│ ╱ warmup ╲ cosine decay │
│ ╱ (linear) ╲ (smooth) │
│ ╱ ╲ │
│ ╱ ╲ │
│ ╱ ╲ │
│ min_lr ╱─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ╲─ ─ ─ ─ │
│ │ │ │
│ 0 ─────┼──────────────────────────┼───────── steps │
│ step 0 warmup total │
│ steps steps │
│ │
│ Key: Warmup prevents early instability from large gradients. │
│ Cosine decay gives a "soft landing" for better convergence.│
└──────────────────────────────────────────────────────────────────┘
A common mistake is calling scheduler.step() inside the batch loop instead of once per optimizer step. With gradient accumulation, the scheduler should step once per optimizer step, not once per micro-batch. Stepping too frequently causes the LR to decay faster than intended, leading to underfitting.
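To make the accumulation pitfall concrete, here is a minimal sketch; the model, batch shapes, and the toy warmup lambda are illustrative, not part of the exercise:

```python
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(10, 2)
optimizer = Adam(model.parameters(), lr=3e-4)
# Toy warmup-only lambda, just to have a scheduler to step
scheduler = LambdaLR(optimizer, lambda s: min(1.0, s / 100))

accum_steps = 4
micro_batches = torch.randn(32, 4, 10)  # 32 micro-batches of size 4

optimizer.zero_grad()
for i, batch in enumerate(micro_batches):
    loss = model(batch).sum() / accum_steps  # scale loss for accumulation
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()  # once per optimizer step, NOT once per micro-batch

# 32 micro-batches / 4 accumulation steps = 8 scheduler steps
```

Stepping the scheduler inside the micro-batch loop instead would advance it 32 times here, burning through the schedule four times faster than intended.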
Hints
- During warmup (step < warmup_steps): lr = max_lr * step / warmup_steps.
- During decay: normalize the progress to [0, 1] as progress = (step - warmup_steps) / (total_steps - warmup_steps).
- Cosine decay: lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * progress)).
- At progress=0, cos(0)=1, so lr=max_lr. At progress=1, cos(pi)=-1, so lr=min_lr.
- To build on torch.optim.lr_scheduler.LambdaLR, provide a function that returns a multiplier.
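The endpoint claims in the hints can be checked numerically before writing the full solution (the max_lr and min_lr values below are illustrative):

```python
import math

max_lr, min_lr = 3e-4, 1e-5  # illustrative values

def decay_lr(progress: float) -> float:
    # Cosine decay formula from the hints, for progress in [0, 1]
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

# progress=0 -> cos(0)=1 -> max_lr; progress=1 -> cos(pi)=-1 -> min_lr
assert abs(decay_lr(0.0) - max_lr) < 1e-12
assert abs(decay_lr(1.0) - min_lr) < 1e-12
# Halfway through decay, the LR sits at the midpoint of max_lr and min_lr
assert abs(decay_lr(0.5) - (min_lr + 0.5 * (max_lr - min_lr))) < 1e-12
```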
Solution
import math

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR


def cosine_warmup_lr(
    step: int,
    max_lr: float,
    min_lr: float,
    warmup_steps: int,
    total_steps: int,
) -> float:
    """Compute the learning rate at a given step."""
    if step < warmup_steps:
        # Linear warmup: 0 -> max_lr
        return max_lr * step / warmup_steps
    elif step >= total_steps:
        return min_lr
    else:
        # Cosine decay: max_lr -> min_lr
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def get_cosine_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.0,
) -> LambdaLR:
    """
    Create a PyTorch LR scheduler with linear warmup + cosine decay.

    Uses LambdaLR, which multiplies the base_lr by the returned factor.
    """
    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # Linear warmup from 0 to 1
            return current_step / max(1, warmup_steps)
        # Cosine decay from 1 to min_lr_ratio
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        progress = min(progress, 1.0)  # Clamp at the end
        return min_lr_ratio + 0.5 * (1.0 - min_lr_ratio) * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)


class CosineWarmupScheduler:
    """Standalone scheduler (not a PyTorch subclass) for full control."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_lr: float,
        min_lr: float,
        warmup_steps: int,
        total_steps: int,
    ) -> None:
        self.optimizer = optimizer
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.current_step = 0

    def get_lr(self) -> float:
        return cosine_warmup_lr(
            self.current_step, self.max_lr, self.min_lr,
            self.warmup_steps, self.total_steps,
        )

    def step(self) -> None:
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
        self.current_step += 1


# ---------- demo ----------
if __name__ == "__main__":
    # Visualize the schedule
    max_lr = 3e-4
    min_lr = 1e-5
    warmup_steps = 100
    total_steps = 1000

    # Compute LR at each step
    lrs = [
        cosine_warmup_lr(s, max_lr, min_lr, warmup_steps, total_steps)
        for s in range(total_steps)
    ]

    # Print key points
    print(f"Step   0: lr = {lrs[0]:.2e}")
    print(f"Step  50: lr = {lrs[50]:.2e} (mid-warmup)")
    print(f"Step 100: lr = {lrs[100]:.2e} (peak)")
    print(f"Step 500: lr = {lrs[500]:.2e} (mid-decay)")
    print(f"Step 999: lr = {lrs[999]:.2e} (near end)")

    # Verify with PyTorch scheduler
    model = nn.Linear(10, 2)
    optimizer = Adam(model.parameters(), lr=max_lr)
    scheduler = get_cosine_warmup_scheduler(
        optimizer, warmup_steps, total_steps, min_lr_ratio=min_lr / max_lr
    )
    pytorch_lrs = []
    for step in range(total_steps):
        pytorch_lrs.append(optimizer.param_groups[0]["lr"])
        # Simulate a training step
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()

    # Compare
    max_diff = max(abs(a - b) for a, b in zip(lrs, pytorch_lrs))
    print(f"\nMax diff (standalone vs PyTorch): {max_diff:.2e}")

    # Test standalone scheduler class
    optimizer2 = Adam(nn.Linear(10, 2).parameters(), lr=max_lr)
    sched2 = CosineWarmupScheduler(optimizer2, max_lr, min_lr, warmup_steps, total_steps)
    custom_lrs = []
    for _ in range(total_steps):
        custom_lrs.append(sched2.get_lr())
        sched2.step()
    max_diff2 = max(abs(a - b) for a, b in zip(lrs, custom_lrs))
    print(f"Max diff (standalone func vs class): {max_diff2:.2e}")

    # ASCII visualization
    print("\n--- Learning Rate Schedule ---")
    width = 60
    for i in range(0, total_steps, total_steps // 20):
        bar_len = int(lrs[i] / max_lr * width)
        bar = "#" * bar_len
        print(f"Step {i:4d} | {bar:<{width}} | {lrs[i]:.2e}")
Walkthrough
- Linear warmup -- For steps 0 to warmup_steps, the LR increases linearly from 0 to max_lr. The formula lr = max_lr * step / warmup_steps gives a straight ramp. At step 0, lr=0 (no update); at step=warmup_steps, lr=max_lr.
- Cosine decay -- After warmup, we normalize the remaining steps to a progress value in [0, 1]. The cosine formula 0.5 * (1 + cos(pi * progress)) maps [0, 1] to [1, 0] smoothly. Scaling by (max_lr - min_lr) and adding min_lr gives us decay from max_lr to min_lr.
- LambdaLR integration -- PyTorch's LambdaLR expects a function that returns a multiplier applied to the base learning rate. So we return values in [0, 1] (or [min_lr_ratio, 1]) rather than absolute LR values.
- Standalone class -- Directly sets param_group["lr"] for full control. This is simpler to understand and debug than subclassing PyTorch schedulers.
- Why cosine -- Compared to step decay (sudden drops) or linear decay (constant rate of decrease), cosine decay spends more time near the peak LR (where most learning happens) and gradually slows down, giving the model a "soft landing" that improves final convergence.
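A subtlety of the LambdaLR multiplier semantics is worth verifying directly: the lambda's return value multiplies the base LR, and the multiplier for step 0 is applied already at construction. A minimal check (the optimizer and constant multiplier are illustrative):

```python
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

opt = SGD(nn.Linear(2, 2).parameters(), lr=0.1)
sched = LambdaLR(opt, lambda step: 0.5)  # constant multiplier, for illustration

# LambdaLR evaluates the lambda at step 0 during construction,
# so the effective LR is base_lr * 0.5 immediately:
assert abs(opt.param_groups[0]["lr"] - 0.05) < 1e-12

sched.step()  # lambda(1) is still 0.5 -> LR unchanged
assert abs(opt.param_groups[0]["lr"] - 0.05) < 1e-12
```

This is why the warmup lambda returning 0 at step 0 yields an initial LR of exactly 0, matching the standalone function.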
Complexity Analysis
- Computation: O(1) per step -- just one cosine evaluation. Negligible compared to the training step.
- Hyperparameters: warmup_steps is typically 1-10% of total_steps. min_lr is usually 0 or max_lr/10. These are set once and rarely tuned.
Common schedule variants: (1) cosine with restarts (warm restarts), (2) linear decay, (3) inverse square root (used in the original transformer paper), (4) WSD (warmup-stable-decay), used in several recent large-scale LLM training runs.
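For comparison, the inverse-square-root schedule from the original transformer paper ("Attention Is All You Need") can be sketched as follows; the d_model and warmup defaults are the paper's:

```python
def transformer_inv_sqrt_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Inverse-square-root schedule from Vaswani et al. (2017).

    Linear warmup for warmup_steps, then decay proportional to 1/sqrt(step).
    """
    step = max(step, 1)  # avoid division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The two branches meet at step == warmup_steps, where the LR peaks
# (roughly 7e-4 with these defaults), then decays without a fixed endpoint.
peak = transformer_inv_sqrt_lr(4000)
assert peak >= transformer_inv_sqrt_lr(3999)
assert peak >= transformer_inv_sqrt_lr(4001)
```

Unlike cosine decay, this schedule never reaches a terminal min_lr, so total_steps need not be known in advance.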
Interview Tips
Key points: (1) Why warmup is needed -- it prevents early instability from large gradients on random weights, and is especially important for Adam, whose adaptive learning-rate estimates are unreliable in early steps. (2) Why cosine over linear decay -- cosine spends more time at high LR, which empirically gives better convergence. (3) The linear scaling rule: when increasing the batch size by a factor of K, increase the LR by K (and scale up the warmup steps accordingly). (4) Learning rate is arguably the single most important hyperparameter -- getting the schedule right often matters more than architectural tweaks. (5) Modern alternatives: the WSD (warmup-stable-decay) schedule, and schedules with explicit cooldown phases used in recent LLM training.
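The linear scaling rule in point (3) is simple arithmetic; a small sketch, with illustrative base values (the helper name is hypothetical):

```python
def scale_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    # Linear scaling rule: LR scales with the batch-size ratio.
    # (In practice, warmup length is usually scaled up as well.)
    return base_lr * new_batch / base_batch

# Doubling the batch size doubles the LR:
assert abs(scale_lr(3e-4, 256, 512) - 6e-4) < 1e-12
```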
Quiz — 3 Questions
Why is learning rate warmup important when training transformers?
According to the linear scaling rule, what should you do to the learning rate when doubling the batch size?
Why does cosine decay spend more training time at higher learning rates compared to linear decay?