Akshay’s Gradient
ML Coding · Beginner · 30 min

Batch Matrix Multiply

Exercise: Batch Matrix Multiplication and Einsum

Master batched matrix multiplication using both manual implementation and Einstein summation notation. These operations are the backbone of transformer computations.

Problem Statement

Implement three functions:

  1. batch_matmul_manual(A, B) -- batched matrix multiply using only basic operations (no torch.bmm or @)
  2. batch_matmul_einsum(A, B) -- same operation using torch.einsum
  3. einsum_attention(Q, K, V) -- compute multi-head scaled dot-product attention (scores, softmax, and weighted sum) using only einsum operations

Inputs:

  • A: tensor of shape (batch, n, m)
  • B: tensor of shape (batch, m, p)
  • Q, K, V: tensors of shape (batch, heads, seq_len, d_k)

Outputs:

  • Matrix product of shape (batch, n, p) for the first two functions
  • Attention output of shape (batch, heads, seq_len, d_k) for the third

Key Concept

Einstein summation (einsum) is a compact notation for expressing tensor contractions. The rule: indices that appear in the input but not the output are summed over. For example, "bij,bjk->bik" means: for each batch b, sum over j to compute C[b,i,k] = sum_j A[b,i,j] * B[b,j,k].
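As a quick sanity check on this rule (a minimal sketch, not part of the exercise), the contraction can be verified against an explicit loop over the summed index:

```python
import torch

# Small shapes so the quadruple loop stays cheap
A = torch.randn(2, 3, 4)  # (batch, i, j)
B = torch.randn(2, 4, 5)  # (batch, j, k)

C_einsum = torch.einsum("bij,bjk->bik", A, B)

# Explicit loop: j appears in both inputs but not the output, so sum over it
C_loop = torch.zeros(2, 3, 5)
for b in range(2):
    for i in range(3):
        for k in range(5):
            for j in range(4):
                C_loop[b, i, k] += A[b, i, j] * B[b, j, k]

assert torch.allclose(C_einsum, C_loop, atol=1e-5)
```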

Batch Matmul and Einsum Notation
┌────────────────────────────────────────────────────────────────┐
│            Batch Matrix Multiply via Einsum                    │
│                                                                │
│  Standard: "bij,bjk->bik"                                      │
│                                                                │
│  A (batch, n, m)     B (batch, m, p)     C (batch, n, p)      │
│  ┌───────┐           ┌───────┐           ┌───────┐            │
│  │ b i j │    @      │ b j k │    =      │ b i k │            │
│  └───────┘           └───────┘           └───────┘            │
│       │                  │                    │                │
│       └──────┬───────────┘                    │                │
│              ▼                                │                │
│     Sum over j (contracted)      ─────────────┘                │
│                                                                │
│  Attention: "bhqd,bhkd->bhqk"                                  │
│                                                                │
│  Q (B, H, T_q, d_k)    K (B, H, T_k, d_k)                    │
│       │                      │                                 │
│       └──────────┬───────────┘                                 │
│                  ▼                                              │
│     Contract over d (dot product)                              │
│                  │                                              │
│                  ▼                                              │
│     Scores (B, H, T_q, T_k)                                   │
│                                                                │
│  Reading einsum strings:                                       │
│  • Indices in BOTH inputs but NOT output → summed (contracted) │
│  • Indices in output → kept (free indices)                     │
└────────────────────────────────────────────────────────────────┘
Warning

The manual batch matmul using unsqueeze and broadcasting creates a temporary 4D tensor of shape (B, n, m, p), which can be very memory-expensive. For large matrices, prefer torch.bmm, the @ operator, or torch.einsum, which use optimized BLAS kernels and do not materialize the full outer product.
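To make the cost concrete, here is a small sketch (toy shapes chosen for illustration) showing the intermediate that the broadcasting approach materializes:

```python
import torch

Bsz, n, m, p = 4, 16, 32, 24
A = torch.randn(Bsz, n, m)
B = torch.randn(Bsz, m, p)

# Broadcasting materializes every pairwise product before summing:
intermediate = A.unsqueeze(-1) * B.unsqueeze(-3)
assert intermediate.shape == (Bsz, n, m, p)

# The intermediate is m times larger than the (Bsz, n, p) result
result_elems = Bsz * n * p
print(intermediate.numel() // result_elems)  # -> 32, i.e. a factor of m
```

With realistic transformer shapes (say m = 4096), that factor makes the manual version thousands of times more memory-hungry than `torch.bmm`, which never builds the intermediate.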

Interview Tip

When asked about einsum in interviews, the quickest way to reason about any expression is: "indices that appear in the inputs but NOT in the output are summed over." Practice reading expressions backward: given "bhqk,bhkd->bhqd", you can immediately see this is attention_weights @ V because k (the key position) is being contracted.
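That backward reading can be checked numerically (a small sketch with arbitrary shapes, not from the exercise):

```python
import torch

B, H, Tq, Tk, d = 2, 4, 5, 7, 16
weights = torch.randn(B, H, Tq, Tk).softmax(dim=-1)  # attention weights
V = torch.randn(B, H, Tk, d)                          # values

# k (the key position) appears in both inputs but not the output -> contracted
out_einsum = torch.einsum("bhqk,bhkd->bhqd", weights, V)
out_matmul = weights @ V  # the explicit equivalent

assert torch.allclose(out_einsum, out_matmul, atol=1e-5)
```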

Hints

Info
  1. For manual batch matmul, use torch.sum(A.unsqueeze(-1) * B.unsqueeze(-3), dim=-2).
  2. The key insight: A[:,:,:,None] * B[:,None,:,:] creates a 4D tensor, and summing over the shared dimension gives the matrix product.
  3. Einsum for batch matmul: "bij,bjk->bik".
  4. For attention: scores = "bhqd,bhkd->bhqk", then output = "bhqk,bhkd->bhqd".
  5. Do not forget the scaling factor 1/sqrt(d_k) in attention.

Solution

import torch
import torch.nn.functional as F
import math


def batch_matmul_manual(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Batch matrix multiply without using bmm or @.
    A: (batch, n, m), B: (batch, m, p) -> (batch, n, p)
    """
    # Expand dimensions for broadcasting:
    # A: (batch, n, m, 1) * B: (batch, 1, m, p) -> (batch, n, m, p)
    # Sum over m (dim=-2) to get (batch, n, p)
    return torch.sum(A.unsqueeze(-1) * B.unsqueeze(-3), dim=-2)


def batch_matmul_einsum(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Batch matrix multiply using einsum."""
    return torch.einsum("bij,bjk->bik", A, B)


def einsum_attention(
    Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
) -> torch.Tensor:
    """
    Multi-head attention using only einsum operations.
    Q, K, V: (batch, heads, seq_len, d_k)
    Returns: (batch, heads, seq_len, d_k)
    """
    d_k = Q.size(-1)

    # Compute attention scores: (batch, heads, seq_q, seq_k)
    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) / math.sqrt(d_k)

    # Softmax over key dimension
    attn_weights = F.softmax(scores, dim=-1)

    # Weighted sum of values: (batch, heads, seq_q, d_k)
    output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, V)
    return output


# Bonus: common einsum patterns in transformers
def einsum_examples() -> None:
    """Demonstrate useful einsum patterns."""
    B, H, T, D = 2, 4, 8, 16

    # 1. Linear projection (no explicit weight broadcasting needed)
    x = torch.randn(B, T, D)
    W = torch.randn(D, D)
    projected = torch.einsum("btd,de->bte", x, W)
    assert projected.shape == (B, T, D)

    # 2. Outer product (for attention pattern analysis)
    a = torch.randn(T)
    b = torch.randn(T)
    outer = torch.einsum("i,j->ij", a, b)
    assert outer.shape == (T, T)

    # 3. Batch trace
    M = torch.randn(B, D, D)
    traces = torch.einsum("bii->b", M)
    assert traces.shape == (B,)

    # 4. Bilinear form: x^T A y for each batch
    x = torch.randn(B, D)
    A = torch.randn(D, D)
    y = torch.randn(B, D)
    bilinear = torch.einsum("bi,ij,bj->b", x, A, y)
    assert bilinear.shape == (B,)

    print("All einsum patterns verified.")


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)

    # Test batch matmul
    A = torch.randn(4, 3, 5)
    B = torch.randn(4, 5, 7)

    result_manual = batch_matmul_manual(A, B)
    result_einsum = batch_matmul_einsum(A, B)
    result_builtin = torch.bmm(A, B)

    print(f"Manual shape:  {result_manual.shape}")
    print(f"Einsum shape:  {result_einsum.shape}")
    print(f"Max diff (manual vs builtin):  {(result_manual - result_builtin).abs().max():.2e}")
    print(f"Max diff (einsum vs builtin):  {(result_einsum - result_builtin).abs().max():.2e}")

    # Test einsum attention
    Q = torch.randn(2, 4, 8, 16)
    K = torch.randn(2, 4, 8, 16)
    V = torch.randn(2, 4, 8, 16)

    attn_out = einsum_attention(Q, K, V)
    print(f"\nAttention output shape: {attn_out.shape}")  # (2, 4, 8, 16)

    # Verify against standard implementation
    scores_ref = (Q @ K.transpose(-2, -1)) / math.sqrt(16)
    attn_ref = F.softmax(scores_ref, dim=-1) @ V
    print(f"Max diff (einsum vs standard): {(attn_out - attn_ref).abs().max():.2e}")

    einsum_examples()

Walkthrough

  1. Manual batch matmul -- We use broadcasting to compute the outer product along the contracted dimension. A.unsqueeze(-1) adds a trailing dimension to A, B.unsqueeze(-3) adds a dimension at position -3. Their element-wise product has shape (B, n, m, p), and summing over m yields the matrix product.

  2. Einsum batch matmul -- "bij,bjk->bik" is read as: "for each batch index b, contract over j (it appears in both inputs but not in the output) to compute C[b,i,k]". This is more readable and often just as fast as bmm.

  3. Einsum attention -- Two einsum calls replace the standard Q @ K^T and attn @ V. The string "bhqd,bhkd->bhqk" makes the transposition implicit: d is contracted, producing the (q, k) attention matrix.

  4. Bonus patterns -- Einsum handles many common operations: projections, outer products, traces, and bilinear forms, all in a single unified notation.

Complexity Analysis

  • Batch matmul: O(B * n * m * p) for all three methods. The manual method creates a temporary (B, n, m, p) tensor (more memory), while einsum and bmm are more memory-efficient.
  • Einsum attention: Same O(B * H * T^2 * d) as standard attention. Einsum does not change algorithmic complexity; for expressions like these, PyTorch lowers it to the same underlying matmul kernels, so performance is typically on par with explicit bmm/matmul calls.
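The performance claim can be spot-checked with a rough micro-benchmark (a sketch only: absolute timings depend on hardware, backend, and tensor sizes, so treat the numbers as indicative, not definitive):

```python
import time
import torch

A = torch.randn(64, 128, 128)
B = torch.randn(64, 128, 128)

def bench(fn, iters=50):
    """Average wall-clock time per call, with one warm-up run."""
    fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

t_bmm = bench(lambda: torch.bmm(A, B))
t_einsum = bench(lambda: torch.einsum("bij,bjk->bik", A, B))
print(f"bmm:    {t_bmm * 1e3:.3f} ms/call")
print(f"einsum: {t_einsum * 1e3:.3f} ms/call")
```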

Interview Tips

Interview Tip

Interviewers test einsum to assess tensor manipulation fluency. Key skills: (1) Read any einsum expression by identifying free indices (in output) vs. contracted indices (summed over). (2) Know common patterns: batch matmul, trace, outer product, bilinear form. (3) Understand that einsum typically lowers to the same operations as the explicit equivalents, so it is usually not slower. (4) Be able to convert between einsum and explicit transpose/reshape/bmm equivalents. (5) For debugging, check shapes by reading the einsum string: each index corresponds to one dimension.
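Skill (4) in practice: here is a short sketch (arbitrary shapes) converting two einsum expressions to their explicit-op equivalents and back:

```python
import torch

# Linear projection: "btd,de->bte" is just a matmul over the last dim
x = torch.randn(2, 8, 16)
W = torch.randn(16, 16)
proj_einsum = torch.einsum("btd,de->bte", x, W)
proj_matmul = x @ W  # broadcasting handles the leading (b, t) dims
assert torch.allclose(proj_einsum, proj_matmul, atol=1e-5)

# Batch trace: "bii->b" extracts the diagonal, then sums it
M = torch.randn(3, 5, 5)
trace_einsum = torch.einsum("bii->b", M)
trace_explicit = M.diagonal(dim1=-2, dim2=-1).sum(-1)
assert torch.allclose(trace_einsum, trace_explicit, atol=1e-5)
```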

Quiz

Quiz — 3 Questions

In the einsum expression 'bhqd,bhkd->bhqk', which index is being summed over?

What does the einsum expression 'bii->b' compute?

Is torch.einsum slower than explicit operations like torch.bmm?
