Akshay’s Gradient
ML Coding · Intermediate · 35 min

Custom LR Scheduler

Exercise: Learning Rate Scheduler with Cosine Warmup

Implement a learning rate scheduler with linear warmup followed by cosine decay -- the standard schedule for training transformers, from BERT to modern GPT-style models.

Problem Statement

Implement a CosineWarmupScheduler class that:

  1. Linearly increases the learning rate from 0 to max_lr during the warmup phase
  2. Decays the learning rate following a cosine curve from max_lr to min_lr during the remaining steps
  3. Integrates with PyTorch's scheduler API (e.g., via LambdaLR) so scheduler.step() works as usual
  4. Also implement a standalone function version for clarity

Inputs: max_lr, min_lr, warmup_steps, total_steps.

Outputs: The learning rate at each step.

Key Concept

Warmup prevents early training instability: at initialization, model outputs are nearly random, so gradients are large and noisy. A small learning rate during warmup lets the model find a reasonable region before applying the full learning rate. Cosine decay smoothly reduces the learning rate to zero, which has been shown to outperform step decay and linear decay in practice.

Interactive · Cosine Warmup Learning Rate Schedule

  LR
  max_lr ─ ─ ─ ─ ─ ─ ─ ╱╲
                      ╱   ╲_
                     ╱      ╲
                    ╱         ╲
                   ╱  warmup    ╲   cosine decay
                  ╱  (linear)     ╲    (smooth)
                 ╱                  ╲
                ╱                     ╲_
  min_lr ─ ─ ─ ╱ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ╲__ ─ ─
  0 ───────────┼───────┼──────────────────┼──── steps
             step 0  warmup             total
                     steps              steps

  Key: Warmup prevents early instability from large gradients.
       Cosine decay gives a "soft landing" for better convergence.
Warning

A common mistake is calling scheduler.step() once per micro-batch instead of once per optimizer step. With gradient accumulation the two differ: the scheduler should advance only when the optimizer does. Stepping too frequently makes the LR decay faster than intended, leading to underfitting.
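A minimal sketch of the correct placement under gradient accumulation (the model, data, constant-multiplier lambda, and `accum_steps` value here are all hypothetical, chosen just to exercise the API):

```python
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(10, 2)
optimizer = Adam(model.parameters(), lr=3e-4)
scheduler = LambdaLR(optimizer, lambda step: 1.0)  # placeholder schedule

accum_steps = 4  # hypothetical gradient-accumulation factor
batches = [torch.randn(8, 10) for _ in range(12)]

optimizer.zero_grad()
for i, batch in enumerate(batches):
    loss = model(batch).sum() / accum_steps  # scale loss for accumulation
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        scheduler.step()       # once per optimizer step, not per micro-batch
        optimizer.zero_grad()
```

With 12 micro-batches and `accum_steps = 4`, the scheduler advances only 3 times; stepping it inside the inner loop would advance it 12 times and burn through the schedule 4x too fast.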

Hints

Info
  1. During warmup (step < warmup_steps): lr = max_lr * step / warmup_steps.
  2. During decay: normalize the progress to [0, 1] as progress = (step - warmup_steps) / (total_steps - warmup_steps).
  3. Cosine decay: lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * progress)).
  4. At progress=0, cos(0)=1, so lr=max_lr. At progress=1, cos(pi)=-1, so lr=min_lr.
  5. Rather than subclassing, wrap a function that returns a multiplier of the base LR in torch.optim.lr_scheduler.LambdaLR.
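The endpoint claims in hints 3-4 can be sanity-checked in a few lines (the max_lr and min_lr values are arbitrary examples):

```python
import math

max_lr, min_lr = 3e-4, 1e-5  # arbitrary example values

def cosine(progress: float) -> float:
    # Hint 3: cosine interpolation between max_lr and min_lr
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

print(cosine(0.0))  # cos(0) = 1   -> max_lr
print(cosine(1.0))  # cos(pi) = -1 -> min_lr
print(cosine(0.5))  # halfway      -> midpoint (max_lr + min_lr) / 2
```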

Solution

import math
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR


def cosine_warmup_lr(
    step: int,
    max_lr: float,
    min_lr: float,
    warmup_steps: int,
    total_steps: int,
) -> float:
    """Compute the learning rate at a given step."""
    if step < warmup_steps:
        # Linear warmup: 0 -> max_lr
        return max_lr * step / warmup_steps
    elif step >= total_steps:
        return min_lr
    else:
        # Cosine decay: max_lr -> min_lr
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def get_cosine_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.0,
) -> LambdaLR:
    """
    Create a PyTorch LR scheduler with linear warmup + cosine decay.

    Uses LambdaLR which multiplies the base_lr by the returned factor.
    """

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # Linear warmup from 0 to 1
            return current_step / max(1, warmup_steps)
        # Cosine decay from 1 to min_lr_ratio
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        progress = min(progress, 1.0)  # Clamp at the end
        return min_lr_ratio + 0.5 * (1.0 - min_lr_ratio) * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)


class CosineWarmupScheduler:
    """Standalone scheduler (not a PyTorch subclass) for full control."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_lr: float,
        min_lr: float,
        warmup_steps: int,
        total_steps: int,
    ) -> None:
        self.optimizer = optimizer
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.current_step = 0

    def get_lr(self) -> float:
        return cosine_warmup_lr(
            self.current_step, self.max_lr, self.min_lr,
            self.warmup_steps, self.total_steps,
        )

    def step(self) -> None:
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
        self.current_step += 1


# ---------- demo ----------
if __name__ == "__main__":
    # Visualize the schedule
    max_lr = 3e-4
    min_lr = 1e-5
    warmup_steps = 100
    total_steps = 1000

    # Compute LR at each step
    lrs = [
        cosine_warmup_lr(s, max_lr, min_lr, warmup_steps, total_steps)
        for s in range(total_steps)
    ]

    # Print key points
    print(f"Step 0:    lr = {lrs[0]:.2e}")
    print(f"Step 50:   lr = {lrs[50]:.2e} (mid-warmup)")
    print(f"Step 100:  lr = {lrs[100]:.2e} (peak)")
    print(f"Step 500:  lr = {lrs[500]:.2e} (mid-decay)")
    print(f"Step 999:  lr = {lrs[999]:.2e} (near end)")

    # Verify with PyTorch scheduler
    model = nn.Linear(10, 2)
    optimizer = Adam(model.parameters(), lr=max_lr)
    scheduler = get_cosine_warmup_scheduler(
        optimizer, warmup_steps, total_steps, min_lr_ratio=min_lr / max_lr
    )

    pytorch_lrs = []
    for step in range(total_steps):
        pytorch_lrs.append(optimizer.param_groups[0]["lr"])
        # Simulate a training step
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()

    # Compare
    max_diff = max(abs(a - b) for a, b in zip(lrs, pytorch_lrs))
    print(f"\nMax diff (standalone vs PyTorch): {max_diff:.2e}")

    # Test standalone scheduler class
    optimizer2 = Adam(nn.Linear(10, 2).parameters(), lr=max_lr)
    sched2 = CosineWarmupScheduler(optimizer2, max_lr, min_lr, warmup_steps, total_steps)
    custom_lrs = []
    for _ in range(total_steps):
        custom_lrs.append(sched2.get_lr())
        sched2.step()

    max_diff2 = max(abs(a - b) for a, b in zip(lrs, custom_lrs))
    print(f"Max diff (standalone func vs class): {max_diff2:.2e}")

    # ASCII visualization
    print("\n--- Learning Rate Schedule ---")
    width = 60
    for i in range(0, total_steps, total_steps // 20):
        bar_len = int(lrs[i] / max_lr * width)
        bar = "#" * bar_len
        print(f"Step {i:4d} | {bar:<{width}} | {lrs[i]:.2e}")

Walkthrough

  1. Linear warmup -- For steps 0 to warmup_steps, the LR increases linearly from 0 to max_lr. The formula lr = max_lr * step / warmup_steps gives a straight ramp. At step 0, lr=0 (no update); at step=warmup_steps, lr=max_lr.

  2. Cosine decay -- After warmup, we normalize the remaining steps to a progress value in [0, 1]. The cosine formula 0.5 * (1 + cos(pi * progress)) maps [0, 1] to [1, 0] smoothly. Scaling by (max_lr - min_lr) and adding min_lr gives us decay from max_lr to min_lr.

  3. LambdaLR integration -- PyTorch's LambdaLR expects a function that returns a multiplier applied to the base learning rate. So we return values in [0, 1] (or [min_lr_ratio, 1]) rather than absolute LR values.

  4. Standalone class -- Directly sets param_group["lr"] for full control. This is simpler to understand and debug than subclassing PyTorch schedulers.

  5. Why cosine -- Compared to step decay (sudden drops) or linear decay (constant rate of decrease), cosine decay spends more time near the peak LR (where most learning happens) and gradually slows down, giving the model a "soft landing" that improves final convergence.
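The "soft landing" claim in step 5 is easy to quantify: measure what fraction of the decay phase each schedule spends above some fraction of the peak LR (the 80% threshold here is a hypothetical cutoff chosen for illustration):

```python
import math

steps = 10_000
threshold = 0.8  # fraction of peak LR; illustrative cutoff

def frac_above(decay) -> float:
    """Fraction of decay progress with LR multiplier above `threshold`."""
    return sum(decay(s / steps) > threshold for s in range(steps)) / steps

cosine_frac = frac_above(lambda t: 0.5 * (1 + math.cos(math.pi * t)))
linear_frac = frac_above(lambda t: 1 - t)

print(f"cosine: {cosine_frac:.1%}, linear: {linear_frac:.1%}")
```

Cosine stays above 80% of peak for roughly 30% of the decay (until acos(0.6)/pi of the way through), versus exactly 20% for linear decay.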

Complexity Analysis

  • Computation: O(1) per step -- just one cosine evaluation. Negligible compared to the training step.
  • Hyperparameters: warmup_steps is typically 1-10% of total_steps. min_lr is usually 0 or max_lr/10. These are set once and rarely tuned.

Common schedule variants: (1) cosine with restarts (warm restarts), (2) linear decay, (3) inverse square root (used in the original transformer paper), (4) WSD (warmup-stable-decay, used in some recent large-scale LLM training runs).
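For comparison, variant (3) can be sketched directly; this is the schedule from "Attention Is All You Need", with the paper's default d_model and warmup values:

```python
def inverse_sqrt_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Inverse-square-root schedule from the original transformer paper:
    linear warmup, then decay proportional to 1/sqrt(step)."""
    step = max(step, 1)  # avoid division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# Peaks exactly at step == warmup_steps, then decays like 1/sqrt(step)
peak = inverse_sqrt_lr(4000)
```

Unlike the cosine schedule, this one never reaches a fixed min_lr: the decay branch keeps shrinking as 1/sqrt(step) for as long as training runs.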

Interview Tips

Interview Tip

Key points: (1) Why warmup is needed -- it prevents early instability from large gradients on random weights, and is especially important for Adam, whose adaptive moment estimates are unreliable in the first steps. (2) Why cosine over linear decay -- cosine spends more time near the peak LR, which tends to give better convergence empirically. (3) The linear scaling rule: when increasing the batch size by a factor K, increase the LR by K (and scale warmup accordingly). (4) Learning rate is among the most important hyperparameters -- getting the schedule right often matters more than small architecture changes. (5) Modern alternatives: the WSD schedule and schedules with explicit cooldown phases.
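A tiny worked example of point (3), the linear scaling rule (the baseline values are hypothetical):

```python
base_lr, base_batch = 3e-4, 256  # hypothetical baseline configuration
new_batch = 1024                 # batch size scaled by K = 4

k = new_batch / base_batch
new_lr = base_lr * k             # linear scaling rule: LR scales with K
print(new_lr)                    # 0.0012
```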

Quiz

Quiz — 3 Questions

Why is learning rate warmup important when training transformers?

According to the linear scaling rule, what should you do to the learning rate when doubling the batch size?

Why does cosine decay spend more training time at higher learning rates compared to linear decay?
