Implement Softmax and Cross-Entropy Loss
Implement numerically stable softmax and cross-entropy loss -- the standard output layer and loss function for classification in deep learning.
Problem Statement
Implement three functions:
- `softmax(logits)` -- convert raw logits to a probability distribution along the last axis.
- `cross_entropy_loss(logits, targets)` -- compute the mean cross-entropy loss for a batch.
- `softmax_cross_entropy_backward(probs, targets)` -- compute the gradient of the loss w.r.t. logits.
Inputs:
- `logits`: array of shape `(batch_size, num_classes)` -- raw unnormalized scores.
- `targets`: integer array of shape `(batch_size,)` -- ground-truth class indices.
Outputs:
- Softmax returns probabilities of the same shape as logits.
- Loss returns a scalar (mean over the batch).
- Backward returns gradient of shape `(batch_size, num_classes)`.
Naive softmax (`exp(x) / sum(exp(x))`) overflows for large logits. The trick is to subtract `max(logits)` before exponentiating: `softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))`. This does not change the result mathematically but keeps all exponents non-positive, preventing overflow.
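A quick sketch of the failure mode and the fix (the helper names `naive_softmax` and `stable_softmax` are illustrative, not part of the exercise API):

```python
import numpy as np

def naive_softmax(x):
    # exp(1000) overflows float64 to inf, so every entry becomes inf/inf = nan
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x):
    # Shifting by the row max keeps every exponent <= 0, so exp() stays in [0, 1]
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([[1000.0, 999.0, 998.0]])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(logits))   # [[nan nan nan]]
print(stable_softmax(logits))      # finite probabilities that sum to 1
```

Note that float64 `exp` overflows just past 709, so "large logits" here is anything a badly scaled network can easily produce.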
┌─────────────────────────────────────────────────────────────────────┐
│            Softmax + Cross-Entropy: Forward and Backward            │
│                                                                     │
│  Logits (raw scores)         Softmax probabilities                  │
│  ┌─────────────────┐         ┌────────────────────┐                 │
│  │ [3.2, 1.1, 0.5] │  ──▶    │ [0.84, 0.10, 0.06] │                 │
│  │ [0.3, 2.8, 1.0] │  ──▶    │ [0.07, 0.80, 0.13] │                 │
│  └────────┬────────┘         └─────────┬──────────┘                 │
│           │                            │                            │
│           │ 1. subtract max per row    │ 3. index with targets      │
│           │ 2. exp() / sum(exp())      │    targets = [0, 1]        │
│           ▼                            ▼                            │
│  ┌─────────────────┐         ┌────────────────────┐                 │
│  │ Stable: no      │         │ Pick: 0.84, 0.80   │                 │
│  │ overflow / NaN  │         │ Loss = -mean(      │                 │
│  └─────────────────┘         │   log(0.84),       │                 │
│                              │   log(0.80))       │                 │
│                              └─────────┬──────────┘                 │
│                                        │                            │
│  Backward (elegant shortcut):          ▼                            │
│  ┌──────────────────────────────────────────────────┐               │
│  │ grad = probs - one_hot(targets)                  │               │
│  │   = [0.84-1, 0.10, 0.06]  =  [-0.16, 0.10, 0.06] │               │
│  │   = [0.07, 0.80-1, 0.13]  =  [0.07, -0.20, 0.13] │               │
│  │ grad /= batch_size                               │               │
│  └──────────────────────────────────────────────────┘               │
└─────────────────────────────────────────────────────────────────────┘
Never compute `log(softmax(x))` by first computing softmax and then taking the log. This loses precision because softmax outputs near-zero values for non-max classes, and `log(~0)` is very inaccurate. Use `log_softmax` or the logsumexp trick instead: `log_softmax(x) = x - max(x) - log(sum(exp(x - max(x))))`. In production code, always use `torch.nn.functional.cross_entropy`, which computes this fused operation.
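A small sketch of why the fused form matters. With logits `[800, 0]`, the two-step version underflows the small probability to exactly zero, while `log_softmax` recovers the exact log-probability (the helper name `log_softmax` here is our own NumPy sketch, not the PyTorch function):

```python
import numpy as np

def log_softmax(x):
    # log_softmax(x) = x - max(x) - log(sum(exp(x - max(x))))
    shifted = x - np.max(x, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

logits = np.array([[800.0, 0.0]])

# Two-step version: softmax underflows the small entry to exactly 0,
# so its log is -inf (or a clipped, wildly wrong value).
probs = np.exp(logits - logits.max()) / np.exp(logits - logits.max()).sum()
with np.errstate(divide="ignore"):
    print(np.log(probs))    # second entry is -inf

print(log_softmax(logits))  # second entry is exactly -800.0
```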
Hints
- For numerical stability, compute `x_shifted = logits - max(logits, axis=-1, keepdims=True)`.
- Exponentiate: `exp_x = exp(x_shifted)`.
- Normalize: `probs = exp_x / sum(exp_x, axis=-1, keepdims=True)`.
- Cross-entropy: pick the probability of the correct class with advanced indexing, then take `-log(...)`.
- Clip the probability before taking the log to avoid `log(0)`.
- The gradient of softmax-cross-entropy is elegantly simple: `grad = probs.copy(); grad[i, targets[i]] -= 1` for each sample, then divide by batch size.
Solution
```python
import numpy as np
from typing import Tuple


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax along the last axis."""
    # Subtract max for numerical stability
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=-1, keepdims=True)


def cross_entropy_loss(logits: np.ndarray, targets: np.ndarray) -> Tuple[float, np.ndarray]:
    """
    Compute mean cross-entropy loss and return (loss, probs).

    Args:
        logits: (batch_size, num_classes) raw scores.
        targets: (batch_size,) integer class labels.

    Returns:
        loss: scalar mean cross-entropy.
        probs: (batch_size, num_classes) softmax probabilities.
    """
    probs = softmax(logits)
    batch_size = logits.shape[0]
    # Pick the predicted probability for each true class
    correct_log_probs = -np.log(np.clip(probs[np.arange(batch_size), targets], 1e-12, 1.0))
    loss = float(np.mean(correct_log_probs))
    return loss, probs


def softmax_cross_entropy_backward(
    probs: np.ndarray, targets: np.ndarray
) -> np.ndarray:
    """
    Gradient of cross-entropy loss w.r.t. logits.

    The combined gradient is simply: probs - one_hot(targets),
    averaged over the batch.
    """
    batch_size = probs.shape[0]
    grad = probs.copy()
    grad[np.arange(batch_size), targets] -= 1.0
    grad /= batch_size
    return grad


# ---------- demo ----------
if __name__ == "__main__":
    np.random.seed(0)
    batch_size, num_classes = 4, 5
    logits = np.random.randn(batch_size, num_classes) * 10  # large values
    targets = np.array([0, 3, 2, 4])

    loss, probs = cross_entropy_loss(logits, targets)
    grad = softmax_cross_entropy_backward(probs, targets)

    print(f"Probs (row sums): {probs.sum(axis=1)}")  # should all be 1.0
    print(f"Loss: {loss:.4f}")
    print(f"Grad shape: {grad.shape}")  # (4, 5)

    # Numerical gradient check
    eps = 1e-5
    numerical_grad = np.zeros_like(logits)
    for i in range(batch_size):
        for j in range(num_classes):
            logits[i, j] += eps
            loss_plus, _ = cross_entropy_loss(logits, targets)
            logits[i, j] -= 2 * eps
            loss_minus, _ = cross_entropy_loss(logits, targets)
            logits[i, j] += eps
            numerical_grad[i, j] = (loss_plus - loss_minus) / (2 * eps)
    print(f"Max gradient error: {np.max(np.abs(grad - numerical_grad)):.2e}")
```
Walkthrough
- Stable softmax -- Subtracting the row-wise maximum ensures the largest exponent is `exp(0) = 1`, preventing overflow. The mathematical equivalence holds because `exp(x - c) / sum(exp(x - c)) = exp(x) / sum(exp(x))` for any constant `c`.
- Cross-entropy -- We use advanced indexing `probs[arange(B), targets]` to extract the predicted probability for each sample's true class. The `np.clip` prevents `log(0) = -inf`.
- Combined backward pass -- The derivative of softmax + cross-entropy w.r.t. the logits simplifies to `p - y_one_hot`. This elegant result is one reason the combination is used universally. We divide by `batch_size` because the loss is a mean over the batch.
- Gradient check -- The demo verifies the analytic gradient against a finite-difference numerical gradient, which is standard practice when implementing custom backward passes.
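For reference, the `p - y` shortcut follows from a short derivation. With logits z, probabilities p = softmax(z), and true class t, the per-sample loss is L = -log p_t, and:

```latex
L = -\log p_t, \qquad p_k = \frac{e^{z_k}}{\sum_j e^{z_j}}

\frac{\partial L}{\partial z_k}
  = -\frac{\partial}{\partial z_k}\Bigl(z_t - \log \sum_j e^{z_j}\Bigr)
  = -\Bigl(\mathbb{1}[k = t] - \frac{e^{z_k}}{\sum_j e^{z_j}}\Bigr)
  = p_k - y_k
```

where y is the one-hot encoding of t. The full softmax Jacobian never needs to be materialized.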
Complexity Analysis
- Softmax: O(B * C) time and space where B = batch size, C = classes. Two passes over each row (one to find the max, one to exponentiate and sum).
- Cross-entropy: O(B) additional work for the indexing and log.
- Backward: O(B * C) for the copy and index subtraction.
All operations are embarrassingly parallelizable across the batch dimension.
Interview Tips
Key things interviewers test: (1) Can you explain why subtracting the max is necessary? (Overflow for large logits, not just a trick.) (2) Do you know the combined gradient p - y? This shortcut avoids computing the full Jacobian of softmax. (3) Can you extend to label smoothing? (Replace the one-hot target with a smoothed distribution.) (4) Be ready to discuss the relationship between softmax temperature and prediction sharpness.
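For point (4), a minimal sketch of temperature scaling (the helper name `softmax_with_temperature` is illustrative):

```python
import numpy as np

def softmax_with_temperature(logits, tau=1.0):
    # tau < 1 sharpens the distribution, tau > 1 flattens it;
    # tau -> 0 approaches argmax, tau -> inf approaches uniform.
    z = logits / tau
    z -= z.max(axis=-1, keepdims=True)   # same stability trick as before
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([2.0, 1.0, 0.0])
print(softmax_with_temperature(logits, tau=1.0))   # moderately peaked
print(softmax_with_temperature(logits, tau=0.1))   # nearly one-hot
print(softmax_with_temperature(logits, tau=10.0))  # nearly uniform
```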
Quiz
1. What happens if you compute softmax without subtracting the maximum logit?
2. Why is the combined gradient of softmax + cross-entropy so simple (`probs - one_hot`)?
3. Why do we clip the probability before taking the log in cross-entropy (`np.clip(p, 1e-12, 1.0)`)?