Akshay’s Gradient
ML Coding · Intermediate · 40 min

Custom Autograd Function

Exercise: Custom Autograd Function

Learn to write custom backward passes in PyTorch by implementing torch.autograd.Function subclasses. This is essential when you need operations that PyTorch cannot differentiate automatically (or when the auto-generated backward is inefficient).

Problem Statement

Implement two custom autograd functions:

  1. MySigmoid -- sigmoid activation with a manually written backward pass
  2. MyLinearFunction -- a linear layer (y = xW^T + b) with manual backward for x, W, and b

Each must subclass torch.autograd.Function and implement both forward and backward as @staticmethod methods. Verify correctness using torch.autograd.gradcheck.

Inputs: Input tensors, weight matrices, biases.

Outputs: Output tensors with correct gradient computation on backward.

Key Concept

torch.autograd.Function lets you define custom forward and backward passes. The forward method saves tensors needed for the backward pass using ctx.save_for_backward(...). The backward method receives the upstream gradient and returns one gradient per forward input (or None if that input does not require gradients).
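As a minimal sketch of this pattern, here is an illustrative subclass for a simple squaring operation (not part of the exercise itself; the name MySquare is just for demonstration):

```python
import torch
from torch.autograd import Function


class MySquare(Function):
    """Illustrative example: y = x**2 with a manual backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # x is needed for dy/dx = 2x
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # One gradient per forward input; chain rule: dL/dx = dL/dy * 2x
        return grad_output * 2.0 * x


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = MySquare.apply(x)      # always call through .apply, never forward() directly
y.sum().backward()
print(x.grad)              # tensor([2., 4., 6.])
```

Note that callers invoke the function through `.apply(...)`, which is what registers the operation in the autograd graph.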

Interactive · Custom Autograd: Forward Save / Backward Compute
┌─────────────────────────────────────────────────────────────┐
│         Custom Autograd Function: Sigmoid Example           │
│                                                             │
│  FORWARD PASS:                                              │
│  ┌────────────────────────────────────────────┐             │
│  │  Input x ──▶ σ(x) = 1/(1+exp(-x)) ──▶ y  │             │
│  │                                            │             │
│  │  ctx.save_for_backward(y)  ◀── save output │             │
│  │  (not input! backward uses y directly)     │             │
│  └────────────────────────────────────────────┘             │
│                                                             │
│  BACKWARD PASS:                                             │
│  ┌────────────────────────────────────────────┐             │
│  │  grad_output  (dL/dy from upstream)        │             │
│  │       │                                    │             │
│  │       ▼                                    │             │
│  │  local gradient: dσ/dx = σ(x)*(1-σ(x))    │             │
│  │                        = y * (1 - y)       │             │
│  │       │                                    │             │
│  │       ▼                                    │             │
│  │  grad_input = grad_output * y * (1 - y)    │             │
│  │  (chain rule: dL/dx = dL/dy * dy/dx)       │             │
│  └────────────────────────────────────────────┘             │
│                                                             │
│  GRADIENT CHECK (verification):                             │
│  ┌────────────────────────────────────────────┐             │
│  │  For each element x_i:                     │             │
│  │    numeric = (f(x_i+ε) - f(x_i-ε)) / 2ε  │             │
│  │    analytic = backward result              │             │
│  │    assert |numeric - analytic| is small     │             │
│  └────────────────────────────────────────────┘             │
└─────────────────────────────────────────────────────────────┘
Warning

A subtle bug is saving the input tensor when you should save the output (or vice versa). For sigmoid, saving the output y is more efficient because the backward formula y * (1 - y) uses it directly, avoiding recomputing sigmoid. For linear layers, you must save both the input x and weights W because each is needed to compute the gradient of the other. Saving unnecessary tensors wastes memory in the computation graph.
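To make the trade-off concrete, here is a hypothetical variant (SigmoidSavesInput, not from the exercise) that saves the input instead of the output. It is still correct, but the backward pass must recompute the sigmoid:

```python
import torch
from torch.autograd import Function


class SigmoidSavesInput(Function):
    """Correct but wasteful: saves x, so backward recomputes sigmoid(x)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)    # saving the input...
        return torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        y = torch.sigmoid(x)        # ...forces an extra forward recompute here
        return grad_output * y * (1.0 - y)


x = torch.randn(4, dtype=torch.float64, requires_grad=True)
SigmoidSavesInput.apply(x).sum().backward()
ref = torch.sigmoid(x.detach())
assert torch.allclose(x.grad, ref * (1.0 - ref))  # gradients still match
```

The gradients are identical to the output-saving version; the only difference is the redundant `torch.sigmoid` call in `backward`.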

Hints

Info
  1. Sigmoid: forward computes 1 / (1 + exp(-x)). Save the output for backward.
  2. Sigmoid backward: d_loss/d_x = d_loss/d_output * sigmoid(x) * (1 - sigmoid(x)).
  3. Linear forward: y = x @ W.T + b. Save x and W for backward (the bias is not needed to compute any gradient).
  4. Linear backward: dx = grad_output @ W, dW = grad_output.T @ x, db = grad_output.sum(dim=0).
  5. Use torch.autograd.gradcheck(func, inputs, eps=1e-6) to verify your gradients numerically.
  6. Inputs to gradcheck must be float64 tensors with requires_grad=True.
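Hints 5 and 6 combine into a short check. This sketch runs gradcheck against torch.sigmoid itself, so it passes regardless of your own implementation:

```python
import torch
from torch.autograd import gradcheck

# gradcheck requires float64 inputs with requires_grad=True,
# because finite differences are too noisy in float32.
x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
assert gradcheck(torch.sigmoid, (x,), eps=1e-6, atol=1e-4)
print("gradcheck passed")
```

Swap `torch.sigmoid` for `MySigmoid.apply` once your implementation is ready.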

Solution

import torch
from torch.autograd import Function, gradcheck
from typing import Any, Tuple, Optional


class MySigmoid(Function):
    """Custom sigmoid with manual backward."""

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
        output = 1.0 / (1.0 + torch.exp(-x))
        ctx.save_for_backward(output)  # save sigmoid output, not input
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
        (output,) = ctx.saved_tensors
        # d(sigmoid)/dx = sigmoid * (1 - sigmoid)
        grad_input = grad_output * output * (1.0 - output)
        return grad_input


class MyLinearFunction(Function):
    """Custom linear layer: y = x @ W^T + b, with manual backward."""

    @staticmethod
    def forward(
        ctx: Any,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
    ) -> torch.Tensor:
        # Only x and weight are needed in backward; saving bias would waste memory.
        ctx.save_for_backward(x, weight)
        output = x @ weight.T + bias
        return output

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        x, weight = ctx.saved_tensors

        # Gradient w.r.t. input x: grad_output @ W
        grad_x = grad_output @ weight

        # Gradient w.r.t. weight: grad_output^T @ x
        grad_weight = grad_output.T @ x

        # Gradient w.r.t. bias: sum over the batch dimension
        grad_bias = grad_output.sum(dim=0)

        return grad_x, grad_weight, grad_bias


# Convenience wrappers
def my_sigmoid(x: torch.Tensor) -> torch.Tensor:
    return MySigmoid.apply(x)


def my_linear(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    return MyLinearFunction.apply(x, weight, bias)


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)

    # === Test MySigmoid ===
    print("=== MySigmoid ===")
    x = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)

    # Forward
    y = my_sigmoid(x)
    y_ref = torch.sigmoid(x)
    print(f"Max forward diff: {(y - y_ref).abs().max():.2e}")

    # Backward
    loss = y.sum()
    loss.backward()
    grad_custom = x.grad.clone()

    x.grad = None
    loss_ref = torch.sigmoid(x).sum()
    loss_ref.backward()
    grad_ref = x.grad.clone()
    print(f"Max backward diff: {(grad_custom - grad_ref).abs().max():.2e}")

    # Numerical gradient check
    x_check = torch.randn(3, 4, requires_grad=True, dtype=torch.float64)
    assert gradcheck(MySigmoid.apply, (x_check,), eps=1e-6, atol=1e-4)
    print("gradcheck passed for MySigmoid.")

    # === Test MyLinearFunction ===
    print("\n=== MyLinearFunction ===")
    B, in_features, out_features = 5, 4, 3

    x = torch.randn(B, in_features, requires_grad=True, dtype=torch.float64)
    W = torch.randn(out_features, in_features, requires_grad=True, dtype=torch.float64)
    b = torch.randn(out_features, requires_grad=True, dtype=torch.float64)

    # Forward
    y = my_linear(x, W, b)
    y_ref = x @ W.T + b
    print(f"Max forward diff: {(y - y_ref).abs().max():.2e}")

    # Numerical gradient check
    assert gradcheck(MyLinearFunction.apply, (x, W, b), eps=1e-6, atol=1e-4)
    print("gradcheck passed for MyLinearFunction.")

    # === Integration: build a small network with custom ops ===
    print("\n=== Custom Network ===")
    x = torch.randn(8, 4, requires_grad=True, dtype=torch.float64)
    W1 = torch.randn(6, 4, requires_grad=True, dtype=torch.float64)
    b1 = torch.randn(6, requires_grad=True, dtype=torch.float64)
    W2 = torch.randn(2, 6, requires_grad=True, dtype=torch.float64)
    b2 = torch.randn(2, requires_grad=True, dtype=torch.float64)

    # Forward pass: two-layer network with custom ops
    h = my_sigmoid(my_linear(x, W1, b1))
    out = my_linear(h, W2, b2)
    loss = out.sum()
    loss.backward()

    print(f"Output shape: {out.shape}")
    print(f"x.grad shape: {x.grad.shape}")
    print(f"W1.grad shape: {W1.grad.shape}")
    print("All gradients computed successfully.")

Walkthrough

  1. MySigmoid forward -- Computes sigmoid and saves the output (not input) using ctx.save_for_backward. We save the output because the backward formula sigma * (1 - sigma) uses the output directly, avoiding recomputation.

  2. MySigmoid backward -- Multiplies the upstream gradient by the local Jacobian sigma * (1 - sigma). This is element-wise because sigmoid is an element-wise operation.

  3. MyLinearFunction forward -- Computes x @ W^T + b and saves x and W: each is needed to compute the gradient of the other, while the bias gradient needs neither, so the bias is not saved.

  4. MyLinearFunction backward -- Three gradients, one per input:

    • grad_x = grad_output @ W -- the gradient flows back through the weight matrix
    • grad_W = grad_output^T @ x -- each weight's gradient depends on the input it multiplied
    • grad_b = grad_output.sum(0) -- bias gradient is summed over the batch
  5. gradcheck -- Numerically verifies the analytic gradient by computing finite differences. It perturbs each input element by eps and checks that (f(x+eps) - f(x-eps)) / (2*eps) matches the analytic gradient. Always use float64 for gradcheck.
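The finite-difference rule in step 5 can be reproduced by hand for a scalar input. This sketch checks sigmoid at a single point, comparing autograd's analytic gradient with a central difference:

```python
import torch


def f(x):
    return torch.sigmoid(x)


x0 = torch.tensor(0.3, dtype=torch.float64, requires_grad=True)

# Analytic gradient via autograd
y = f(x0)
y.backward()
analytic = x0.grad.item()

# Central finite difference: (f(x+eps) - f(x-eps)) / (2*eps)
eps = 1e-6
with torch.no_grad():
    numeric = (f(x0 + eps) - f(x0 - eps)).item() / (2 * eps)

# In float64 the two should agree to many decimal places
assert abs(numeric - analytic) < 1e-8
print(f"analytic={analytic:.10f}, numeric={numeric:.10f}")
```

In float32 the subtraction `f(x+eps) - f(x-eps)` would lose most of its significant digits at eps=1e-6, which is exactly why gradcheck demands float64.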

Complexity Analysis

  • Forward: Same as the standard operations -- O(n) for sigmoid, O(B * d_in * d_out) for linear.
  • Backward: Same complexity as forward for each gradient computation. The linear backward involves two matrix multiplies, each O(B * d_in * d_out).
  • Memory: ctx.save_for_backward stores references to tensors, increasing memory proportional to the saved tensors' sizes. This is the memory cost of backpropagation.

Interview Tips

Interview Tip

Demonstrate mastery of: (1) The chain rule in practice -- upstream gradient times local Jacobian. (2) Knowing what to save: save outputs when the backward formula uses them (sigmoid), save inputs when needed for other gradients (linear). (3) Shape reasoning: the gradient of a tensor must have the same shape as that tensor. (4) When to use custom autograd: fused kernels (FlashAttention), non-differentiable operations (straight-through estimator for quantization), or memory optimization (gradient checkpointing). (5) Always validate with gradcheck.
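As one concrete instance of case (4), a minimal straight-through estimator for rounding might look like this (an illustrative sketch; the name RoundSTE is hypothetical):

```python
import torch
from torch.autograd import Function


class RoundSTE(Function):
    """Round in forward; pretend it was the identity in backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)   # non-differentiable step

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output      # straight-through: pass the gradient unchanged


x = torch.tensor([0.2, 0.7, 1.4], requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()
print(y)        # tensor([0., 1., 1.])
print(x.grad)   # tensor([1., 1., 1.])
```

Plain `torch.round` has a zero gradient almost everywhere, which would kill training; the custom backward lets gradients flow through the quantization step as if it were not there.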

Quiz

Quiz — 3 Questions

Why should you use float64 (double precision) when running torch.autograd.gradcheck?

In a custom autograd Function, how many gradients must backward() return?

When would you implement a custom autograd Function instead of relying on PyTorch's automatic differentiation?
