Exercise: Custom Autograd Function
Learn to write custom backward passes in PyTorch by implementing torch.autograd.Function subclasses. This is essential when you need operations that PyTorch cannot differentiate automatically (or when the auto-generated backward is inefficient).
Problem Statement
Implement two custom autograd functions:
- MySigmoid -- sigmoid activation with a manually written backward pass
- MyLinearFunction -- a linear layer (y = xW^T + b) with manual backward for x, W, and b
Each must subclass torch.autograd.Function and implement both forward and backward as static methods (decorated with @staticmethod). Verify correctness using torch.autograd.gradcheck.
Inputs: Input tensors, weight matrices, biases.
Outputs: Output tensors with correct gradient computation on backward.
torch.autograd.Function lets you define custom forward and backward passes. The forward method saves tensors needed for the backward pass using ctx.save_for_backward(...). The backward method receives the upstream gradient and returns one gradient per forward input (or None if that input does not require gradients).
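The one-gradient-per-input rule also applies to non-tensor inputs: they receive None. Here is a minimal sketch with a hypothetical ScaleBy function (not part of the exercise) whose second argument is a plain Python float; non-tensor state is stashed directly on ctx rather than through save_for_backward:

```python
import torch
from torch.autograd import Function

class ScaleBy(Function):
    """Hypothetical example: y = x * c, where c is a plain Python float.

    Forward takes two inputs (x, c), so backward must return two values;
    the non-differentiable constant c gets None.
    """

    @staticmethod
    def forward(ctx, x, c):
        ctx.c = c  # non-tensor state goes directly on ctx
        return x * c

    @staticmethod
    def backward(ctx, grad_output):
        # dL/dx = dL/dy * c; no gradient for the constant c -> None
        return grad_output * ctx.c, None

x = torch.randn(3, requires_grad=True)
y = ScaleBy.apply(x, 2.5)
y.sum().backward()
print(torch.allclose(x.grad, torch.full_like(x, 2.5)))  # True
```

The same None convention is what lets a custom Function skip gradient work for inputs that do not require it (ctx.needs_input_grad reports which ones do).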
┌─────────────────────────────────────────────────────────────┐
│ Custom Autograd Function: Sigmoid Example │
│ │
│ FORWARD PASS: │
│ ┌────────────────────────────────────────────┐ │
│ │ Input x ──▶ σ(x) = 1/(1+exp(-x)) ──▶ y │ │
│ │ │ │
│ │ ctx.save_for_backward(y) ◀── save output │ │
│ │ (not input! backward uses y directly) │ │
│ └────────────────────────────────────────────┘ │
│ │
│ BACKWARD PASS: │
│ ┌────────────────────────────────────────────┐ │
│ │ grad_output (dL/dy from upstream) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ local gradient: dσ/dx = σ(x)*(1-σ(x)) │ │
│ │ = y * (1 - y) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ grad_input = grad_output * y * (1 - y) │ │
│ │ (chain rule: dL/dx = dL/dy * dy/dx) │ │
│ └────────────────────────────────────────────┘ │
│ │
│ GRADIENT CHECK (verification): │
│ ┌────────────────────────────────────────────┐ │
│ │ For each element x_i: │ │
│ │ numeric = (f(x_i+ε) - f(x_i-ε)) / 2ε │ │
│ │ analytic = backward result │ │
│ │ assert |numeric - analytic| is small │ │
│ └────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
A subtle bug is saving the input tensor when you should save the output (or vice versa). For sigmoid, saving the output y is more efficient because the backward formula y * (1 - y) uses it directly, avoiding recomputing sigmoid. For linear layers, you must save both the input x and weights W because each is needed to compute the gradient of the other. Saving unnecessary tensors wastes memory in the computation graph.
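The save-the-output trick is not specific to sigmoid. As an illustrative aside (MyTanh is not part of the exercise), tanh has the same property: its derivative 1 - tanh(x)^2 = 1 - y^2 can be written purely in terms of the output:

```python
import torch
from torch.autograd import Function

class MyTanh(Function):
    """Sketch: tanh also benefits from saving the output, because
    d(tanh)/dx = 1 - tanh(x)^2 = 1 - y^2 needs only y."""

    @staticmethod
    def forward(ctx, x):
        y = torch.tanh(x)
        ctx.save_for_backward(y)  # save the output, not the input
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        return grad_output * (1.0 - y * y)

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(MyTanh.apply, (x,), eps=1e-6, atol=1e-4)
```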
Hints
- Sigmoid forward: compute 1 / (1 + exp(-x)). Save the output for backward.
- Sigmoid backward: d_loss/d_x = d_loss/d_output * sigmoid(x) * (1 - sigmoid(x)).
- Linear forward: y = x @ W.T + b. Save x, W, b for backward.
- Linear backward: dx = grad_output @ W, dW = grad_output.T @ x, db = grad_output.sum(dim=0).
- Use torch.autograd.gradcheck(func, inputs, eps=1e-6) to verify your gradients numerically.
- Inputs to gradcheck must be float64 tensors with requires_grad=True.
Solution
import torch
from torch.autograd import Function, gradcheck
from typing import Any, Tuple, Optional
class MySigmoid(Function):
    """Custom sigmoid with manual backward."""

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
        output = 1.0 / (1.0 + torch.exp(-x))
        ctx.save_for_backward(output)  # save sigmoid output, not input
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
        (output,) = ctx.saved_tensors
        # d(sigmoid)/dx = sigmoid * (1 - sigmoid)
        grad_input = grad_output * output * (1.0 - output)
        return grad_input
class MyLinearFunction(Function):
    """Custom linear layer: y = x @ W^T + b, with manual backward."""

    @staticmethod
    def forward(
        ctx: Any,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
    ) -> torch.Tensor:
        ctx.save_for_backward(x, weight, bias)
        output = x @ weight.T + bias
        return output

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        x, weight, bias = ctx.saved_tensors
        # Gradient w.r.t. input x: grad_output @ W
        grad_x = grad_output @ weight
        # Gradient w.r.t. weight: grad_output^T @ x
        grad_weight = grad_output.T @ x
        # Gradient w.r.t. bias: sum over batch dimension
        grad_bias = grad_output.sum(dim=0)
        return grad_x, grad_weight, grad_bias
# Convenience wrappers
def my_sigmoid(x: torch.Tensor) -> torch.Tensor:
    return MySigmoid.apply(x)


def my_linear(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    return MyLinearFunction.apply(x, weight, bias)
# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)

    # === Test MySigmoid ===
    print("=== MySigmoid ===")
    x = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)

    # Forward
    y = my_sigmoid(x)
    y_ref = torch.sigmoid(x)
    print(f"Max forward diff: {(y - y_ref).abs().max():.2e}")

    # Backward
    loss = y.sum()
    loss.backward()
    grad_custom = x.grad.clone()
    x.grad = None
    loss_ref = torch.sigmoid(x).sum()
    loss_ref.backward()
    grad_ref = x.grad.clone()
    print(f"Max backward diff: {(grad_custom - grad_ref).abs().max():.2e}")

    # Numerical gradient check
    x_check = torch.randn(3, 4, requires_grad=True, dtype=torch.float64)
    assert gradcheck(MySigmoid.apply, (x_check,), eps=1e-6, atol=1e-4)
    print("gradcheck passed for MySigmoid.")

    # === Test MyLinearFunction ===
    print("\n=== MyLinearFunction ===")
    B, in_features, out_features = 5, 4, 3
    x = torch.randn(B, in_features, requires_grad=True, dtype=torch.float64)
    W = torch.randn(out_features, in_features, requires_grad=True, dtype=torch.float64)
    b = torch.randn(out_features, requires_grad=True, dtype=torch.float64)

    # Forward
    y = my_linear(x, W, b)
    y_ref = x @ W.T + b
    print(f"Max forward diff: {(y - y_ref).abs().max():.2e}")

    # Numerical gradient check
    assert gradcheck(MyLinearFunction.apply, (x, W, b), eps=1e-6, atol=1e-4)
    print("gradcheck passed for MyLinearFunction.")

    # === Integration: build a small network with custom ops ===
    print("\n=== Custom Network ===")
    x = torch.randn(8, 4, requires_grad=True, dtype=torch.float64)
    W1 = torch.randn(6, 4, requires_grad=True, dtype=torch.float64)
    b1 = torch.randn(6, requires_grad=True, dtype=torch.float64)
    W2 = torch.randn(2, 6, requires_grad=True, dtype=torch.float64)
    b2 = torch.randn(2, requires_grad=True, dtype=torch.float64)

    # Forward pass: two-layer network with custom ops
    h = my_sigmoid(my_linear(x, W1, b1))
    out = my_linear(h, W2, b2)
    loss = out.sum()
    loss.backward()
    print(f"Output shape: {out.shape}")
    print(f"x.grad shape: {x.grad.shape}")
    print(f"W1.grad shape: {W1.grad.shape}")
    print("All gradients computed successfully.")
Walkthrough
- MySigmoid forward -- Computes sigmoid and saves the output (not the input) using ctx.save_for_backward. We save the output because the backward formula sigma * (1 - sigma) uses the output directly, avoiding recomputation.
- MySigmoid backward -- Multiplies the upstream gradient by the local derivative sigma * (1 - sigma). This is element-wise because sigmoid is an element-wise operation.
- MyLinearFunction forward -- Computes x @ W^T + b and saves the inputs. x is needed for W's gradient and W for x's gradient; the bias gradient needs only grad_output, so saving b is optional.
- MyLinearFunction backward -- Three gradients, one per input:
  - grad_x = grad_output @ W -- the gradient flows back through the weight matrix
  - grad_W = grad_output^T @ x -- each weight's gradient depends on the input it multiplied
  - grad_b = grad_output.sum(0) -- the bias gradient is summed over the batch
- gradcheck -- Numerically verifies the analytic gradient with finite differences. It perturbs each input element by eps and checks that (f(x+eps) - f(x-eps)) / (2*eps) matches the analytic gradient. Always use float64 for gradcheck.
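The central-difference idea behind gradcheck can be reproduced by hand. The numeric_grad helper below is a simplified sketch of what gradcheck does internally (not its actual implementation), for a scalar-valued function:

```python
import torch

def numeric_grad(f, x, eps=1e-6):
    """Central-difference gradient of a scalar function f at x,
    perturbing one element at a time (simplified sketch of gradcheck)."""
    grad = torch.zeros_like(x)
    flat = x.reshape(-1)       # view into x: in-place edits perturb x
    gflat = grad.reshape(-1)   # view into grad
    for i in range(flat.numel()):
        orig = flat[i].item()
        flat[i] = orig + eps
        f_plus = f(x).item()
        flat[i] = orig - eps
        f_minus = f(x).item()
        flat[i] = orig         # restore the element
        gflat[i] = (f_plus - f_minus) / (2 * eps)
    return grad

x = torch.randn(5, dtype=torch.float64)   # float64: small eps needs precision
loss = lambda t: torch.sigmoid(t).sum()
numeric = numeric_grad(loss, x.clone())

x_req = x.clone().requires_grad_(True)
torch.sigmoid(x_req).sum().backward()     # analytic gradient via autograd
print(torch.allclose(numeric, x_req.grad, atol=1e-6))  # True
```

This also makes the float64 requirement concrete: with float32, the rounding error in f(x+eps) - f(x-eps) is comparable to the difference itself, and the check becomes unreliable.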
Complexity Analysis
- Forward: Same as the standard operations -- O(n) for sigmoid, O(B * d_in * d_out) for linear.
- Backward: Same complexity as forward for each gradient computation. The linear backward involves two matrix multiplies, each O(B * d_in * d_out).
- Memory: ctx.save_for_backward stores references to tensors, increasing memory proportional to the saved tensors' sizes. This is the memory cost of backpropagation.
Interview Tips
Demonstrate mastery of: (1) The chain rule in practice -- upstream gradient times local Jacobian. (2) Knowing what to save: save outputs when the backward formula uses them (sigmoid), save inputs when needed for other gradients (linear). (3) Shape reasoning: the gradient of a tensor must have the same shape as that tensor. (4) When to use custom autograd: fused kernels (FlashAttention), non-differentiable operations (straight-through estimator for quantization), or memory optimization (gradient checkpointing). (5) Always validate with gradcheck.
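One of the use cases in point (4), the straight-through estimator, fits in a few lines. RoundSTE below is an illustrative sketch, not part of the exercise: forward applies rounding (whose true gradient is zero almost everywhere), and backward passes the upstream gradient through as if forward had been the identity:

```python
import torch
from torch.autograd import Function

class RoundSTE(Function):
    """Straight-through estimator sketch: forward rounds to the nearest
    integer, backward pretends the op was the identity so gradients can
    still flow through the non-differentiable step."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # pretend d(round)/dx == 1

x = torch.tensor([0.2, 0.7, 1.4], requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()
print(y)       # tensor([0., 1., 1.])
print(x.grad)  # tensor([1., 1., 1.])
```

Note that a straight-through backward deliberately fails gradcheck: the analytic "gradient" disagrees with the numeric one by design, which is exactly why it must be hand-written.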
Quiz
Why should you use float64 (double precision) when running torch.autograd.gradcheck?
In a custom autograd Function, how many gradients must backward() return?
When would you implement a custom autograd Function instead of relying on PyTorch's automatic differentiation?