Akshay’s Gradient
ML Coding · Beginner · 35 min

Softmax & Cross-Entropy

Implement Softmax and Cross-Entropy Loss

Implement numerically stable softmax and cross-entropy loss -- the standard output layer and loss function for classification in deep learning.

Problem Statement

Implement three functions:

  1. softmax(logits) -- convert raw logits to a probability distribution along the last axis.
  2. cross_entropy_loss(logits, targets) -- compute the mean cross-entropy loss for a batch.
  3. softmax_cross_entropy_backward(probs, targets) -- compute the gradient of the loss w.r.t. logits.

Inputs:

  • logits: array of shape (batch_size, num_classes) -- raw unnormalized scores.
  • targets: integer array of shape (batch_size,) -- ground-truth class indices.

Outputs:

  • Softmax returns probabilities of the same shape as logits.
  • Loss returns a scalar (mean over the batch).
  • Backward returns gradient of shape (batch_size, num_classes).
Key Concept

Naive softmax (exp(x) / sum(exp(x))) overflows for large logits. The trick is to subtract max(logits) before exponentiating: softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))). This does not change the result mathematically but keeps all exponents non-positive, preventing overflow.
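A quick sketch of the overflow and the fix (function names here are illustrative, not part of the exercise's required API):

```python
import numpy as np

def naive_softmax(x):
    # Overflows: exp(1000) is inf in float64, so inf / inf gives nan
    e = np.exp(x)
    return e / e.sum()

def stable_softmax(x):
    # Shift by the max: the largest exponent becomes exp(0) = 1
    e = np.exp(x - np.max(x))
    return e / e.sum()

x = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(x))   # [nan nan nan]
print(stable_softmax(x))      # [0.09003057 0.24472847 0.66524096]
```

Note that the stable version gives the same answer as softmax([-2, -1, 0]) would, since shifting all logits by a constant leaves the distribution unchanged.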

Interactive · Softmax and Cross-Entropy Forward/Backward Flow
┌─────────────────────────────────────────────────────────────────┐
│          Softmax + Cross-Entropy: Forward and Backward          │
│                                                                 │
│  Logits (raw scores)          Softmax Probabilities             │
│  ┌─────────────────┐          ┌────────────────────┐            │
│  │ [3.2, 1.1, 0.5] │   ──▶    │ [0.84, 0.10, 0.06] │            │
│  │ [0.3, 2.8, 1.0] │   ──▶    │ [0.07, 0.80, 0.13] │            │
│  └─────────────────┘          └─────────┬──────────┘            │
│         │                               │                       │
│         │ 1. Subtract max per row       │ 3. Index with targets │
│         │ 2. exp() / sum(exp())         │    targets = [0, 1]   │
│         ▼                               ▼                       │
│  ┌─────────────────┐          ┌────────────────────┐            │
│  │ Stable: no      │          │ Pick: 0.84, 0.80   │            │
│  │ overflow/NaN    │          │ Loss = -mean(      │            │
│  └─────────────────┘          │   log(0.84),       │            │
│                               │   log(0.80))       │            │
│                               └─────────┬──────────┘            │
│                                         │                       │
│  Backward (elegant shortcut):           ▼                       │
│  ┌───────────────────────────────────────────────────┐          │
│  │ grad = probs - one_hot(targets)                   │          │
│  │ row 0: [0.84-1, 0.10, 0.06] = [-0.16, 0.10, 0.06] │          │
│  │ row 1: [0.07, 0.80-1, 0.13] = [ 0.07,-0.20, 0.13] │          │
│  │ grad /= batch_size                                │          │
│  └───────────────────────────────────────────────────┘          │
└─────────────────────────────────────────────────────────────────┘
Warning

Never compute log(softmax(x)) by first computing softmax and then taking the log. This loses precision because softmax outputs near-zero values for non-max classes, and log(~0) is very inaccurate. Use log_softmax or the logsumexp trick instead: log_softmax(x) = x - max(x) - log(sum(exp(x - max(x)))). In production code, always use torch.nn.functional.cross_entropy, which computes this fused operation directly from logits.
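The log-softmax identity above can be sketched directly in NumPy (a minimal illustration of the trick, not the fused PyTorch kernel):

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """log(softmax(x)) computed directly: x - max(x) - log(sum(exp(x - max(x))))."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

x = np.array([[1000.0, 0.0]])
# Naive route: softmax(x) rounds to [1., 0.], so log gives [0., -inf].
# Direct log_softmax stays finite and exact:
print(log_softmax(x))  # second entry is exactly -1000.0, not -inf
```

This works because the underflow happens inside the sum (where a tiny exp contributes nothing) rather than inside a log.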

Hints

  1. For numerical stability, compute x_shifted = logits - max(logits, axis=-1, keepdims=True).
  2. Exponentiate: exp_x = exp(x_shifted).
  3. Normalize: probs = exp_x / sum(exp_x, axis=-1, keepdims=True).
  4. Cross-entropy: pick the probability of the correct class with advanced indexing, then take -log(...).
  5. Clip the probability before taking the log to avoid log(0).
  6. The gradient of softmax-cross-entropy is elegantly simple: grad = probs.copy(); grad[i, targets[i]] -= 1 for each sample, then divide by batch size.

Solution

import numpy as np
from typing import Tuple


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax along the last axis."""
    # Subtract max for numerical stability
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=-1, keepdims=True)


def cross_entropy_loss(logits: np.ndarray, targets: np.ndarray) -> Tuple[float, np.ndarray]:
    """
    Compute mean cross-entropy loss and return (loss, probs).

    Args:
        logits:  (batch_size, num_classes) raw scores.
        targets: (batch_size,) integer class labels.

    Returns:
        loss:  scalar mean cross-entropy.
        probs: (batch_size, num_classes) softmax probabilities.
    """
    probs = softmax(logits)
    batch_size = logits.shape[0]
    # Pick the predicted probability for each true class
    correct_log_probs = -np.log(np.clip(probs[np.arange(batch_size), targets], 1e-12, 1.0))
    loss = float(np.mean(correct_log_probs))
    return loss, probs


def softmax_cross_entropy_backward(
    probs: np.ndarray, targets: np.ndarray
) -> np.ndarray:
    """
    Gradient of cross-entropy loss w.r.t. logits.

    The combined gradient is simply: probs - one_hot(targets),
    averaged over the batch.
    """
    batch_size = probs.shape[0]
    grad = probs.copy()
    grad[np.arange(batch_size), targets] -= 1.0
    grad /= batch_size
    return grad


# ---------- demo ----------
if __name__ == "__main__":
    np.random.seed(0)
    batch_size, num_classes = 4, 5
    logits = np.random.randn(batch_size, num_classes) * 10  # large values
    targets = np.array([0, 3, 2, 4])

    loss, probs = cross_entropy_loss(logits, targets)
    grad = softmax_cross_entropy_backward(probs, targets)

    print(f"Probs (row sums): {probs.sum(axis=1)}")  # should all be 1.0
    print(f"Loss: {loss:.4f}")
    print(f"Grad shape: {grad.shape}")  # (4, 5)

    # Numerical gradient check
    eps = 1e-5
    numerical_grad = np.zeros_like(logits)
    for i in range(batch_size):
        for j in range(num_classes):
            logits[i, j] += eps
            loss_plus, _ = cross_entropy_loss(logits, targets)
            logits[i, j] -= 2 * eps
            loss_minus, _ = cross_entropy_loss(logits, targets)
            logits[i, j] += eps
            numerical_grad[i, j] = (loss_plus - loss_minus) / (2 * eps)
    print(f"Max gradient error: {np.max(np.abs(grad - numerical_grad)):.2e}")

Walkthrough

  1. Stable softmax -- Subtracting the row-wise maximum ensures the largest exponent is exp(0) = 1, preventing overflow. The mathematical equivalence holds because exp(x - c) / sum(exp(x - c)) = exp(x) / sum(exp(x)) for any constant c.

  2. Cross-entropy -- We use advanced indexing probs[arange(B), targets] to extract the predicted probability for each sample's true class. The np.clip prevents log(0) = -inf.

  3. Combined backward pass -- The derivative of softmax + cross-entropy w.r.t. logits simplifies to p - y_one_hot. This elegant result is one reason this combination is universally used. We divide by batch_size because the loss is a mean.

  4. Gradient check -- The demo verifies the analytic gradient against a finite-difference numerical gradient, which is standard practice when implementing custom backward passes.
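The simplification claimed in step 3 takes one line of calculus. For sample i with true class t_i, using the same symbols (z for logits, p for probabilities):

```latex
L_i = -\log p_{t_i} = -z_{t_i} + \log \sum_k e^{z_k}
\quad\Longrightarrow\quad
\frac{\partial L_i}{\partial z_j}
  = -\mathbf{1}[j = t_i] + \frac{e^{z_j}}{\sum_k e^{z_k}}
  = p_j - \mathbf{1}[j = t_i].
```

Averaging over the batch is what produces the division by batch_size in the backward function.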

Complexity Analysis

  • Softmax: O(B * C) time and space where B = batch size, C = classes. Two passes over each row (one for max, one for sum).
  • Cross-entropy: O(B) additional work for the indexing and log.
  • Backward: O(B * C) for the copy and index subtraction.

All operations are embarrassingly parallelizable across the batch dimension.

Interview Tips

Interview Tip

Key things interviewers test: (1) Can you explain why subtracting the max is necessary? (Overflow for large logits, not just a trick.) (2) Do you know the combined gradient p - y? This shortcut avoids computing the full Jacobian of softmax. (3) Can you extend to label smoothing? (Replace the one-hot target with a smoothed distribution.) (4) Be ready to discuss the relationship between softmax temperature and prediction sharpness.
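For point (3), a minimal sketch of uniform label smoothing (the helper name and the choice of eps are illustrative):

```python
import numpy as np

def smoothed_targets(targets: np.ndarray, num_classes: int, eps: float = 0.1) -> np.ndarray:
    """Soften one-hot targets: the true class gets 1 - eps + eps/C, the rest eps/C."""
    one_hot = np.eye(num_classes)[targets]
    return one_hot * (1.0 - eps) + eps / num_classes

targets = np.array([0, 2])
q = smoothed_targets(targets, num_classes=3, eps=0.1)
print(q)  # each row sums to 1; the true class gets 0.9 + 0.1/3, the others 0.1/3
```

With a smoothed target distribution q in place of the one-hot vector, the gradient generalizes from probs - one_hot(targets) to probs - q; the rest of the backward pass is unchanged.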

Quiz

Quiz — 3 Questions

What happens if you compute softmax without subtracting the maximum logit?

Why is the combined gradient of softmax + cross-entropy so simple (probs - one_hot)?

Why do we clip the probability before taking the log in cross-entropy (np.clip(p, 1e-12, 1.0))?
