Exercise: Batch Matrix Multiplication and Einsum
Master batched matrix multiplication using both manual implementation and Einstein summation notation. These operations are the backbone of transformer computations.
Problem Statement
Implement three functions:
- batch_matmul_manual(A, B) -- batched matrix multiply using only basic operations (no torch.bmm or @)
- batch_matmul_einsum(A, B) -- the same operation using torch.einsum
- einsum_attention(Q, K, V) -- compute the full multi-head attention output using only einsum operations
Inputs:
- A: tensor of shape (batch, n, m)
- B: tensor of shape (batch, m, p)
- Q, K, V: tensors of shape (batch, heads, seq_len, d_k)
Outputs:
- Matrix product of shape (batch, n, p) for the first two functions
- Attention output of shape (batch, heads, seq_len, d_k) for the third
Einstein summation (einsum) is a compact notation for expressing tensor contractions. The rule: indices that appear in the input but not the output are summed over. For example, "bij,bjk->bik" means: for each batch b, sum over j to compute C[b,i,k] = sum_j A[b,i,j] * B[b,j,k].
┌────────────────────────────────────────────────────────────────┐
│ Batch Matrix Multiply via Einsum │
│ │
│ Standard: "bij,bjk->bik" │
│ │
│ A (batch, n, m) B (batch, m, p) C (batch, n, p) │
│ ┌───────┐ ┌───────┐ ┌───────┐ │
│ │ b i j │ @ │ b j k │ = │ b i k │ │
│ └───────┘ └───────┘ └───────┘ │
│ │ │ │ │
│ └──────┬───────────┘ │ │
│ ▼ │ │
│ Sum over j (contracted) ─────────────┘ │
│ │
│ Attention: "bhqd,bhkd->bhqk" │
│ │
│ Q (B, H, T_q, d_k) K (B, H, T_k, d_k) │
│ │ │ │
│ └──────────┬───────────┘ │
│ ▼ │
│ Contract over d (dot product) │
│ │ │
│ ▼ │
│ Scores (B, H, T_q, T_k) │
│ │
│ Reading einsum strings: │
│ • Indices in BOTH inputs but NOT output → summed (contracted) │
│ • Indices in output → kept (free indices) │
└────────────────────────────────────────────────────────────────┘
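The contraction rule can be checked directly against explicit loops. This is a minimal sanity-check sketch with small arbitrary tensor sizes (2, 3, 4, 5 are just illustrative):

```python
import torch

# "bij,bjk->bik": j appears in both inputs but not the output, so it is summed.
A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)

C_einsum = torch.einsum("bij,bjk->bik", A, B)

# Explicit loops spelling out C[b, i, k] = sum_j A[b, i, j] * B[b, j, k]
C_loops = torch.zeros(2, 3, 5)
for b in range(2):
    for i in range(3):
        for k in range(5):
            C_loops[b, i, k] = (A[b, i, :] * B[b, :, k]).sum()

assert torch.allclose(C_einsum, C_loops, atol=1e-5)
```

Writing out the loops once makes every einsum string afterwards easy to decode: each letter is a loop variable, and missing output letters are the summed ones.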
The manual batch matmul using unsqueeze and broadcasting creates a temporary 4D tensor of shape (B, n, m, p), which can be very memory-expensive. For large matrices, prefer torch.bmm, the @ operator, or torch.einsum, which use optimized BLAS kernels and do not materialize the full outer product.
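To make the memory cost concrete, here is a back-of-the-envelope comparison using hypothetical sizes (B=8, n=m=p=512 are illustrative, not from the exercise):

```python
# Illustrative sizes: the broadcast trick's intermediate is m times larger
# than the result it produces.
B_, n, m, p = 8, 512, 512, 512

# bmm/einsum materialize only the (B, n, p) result:
result_elems = B_ * n * p
# The unsqueeze-and-multiply approach materializes a (B, n, m, p) temporary:
temp_elems = B_ * n * m * p

# At float32 (4 bytes per element):
print(f"result:    {result_elems * 4 / 2**20:.0f} MiB")  # 8 MiB
print(f"temporary: {temp_elems * 4 / 2**30:.0f} GiB")    # 4 GiB
```

An 8 MiB result requires a 4 GiB scratch tensor with the broadcast approach, which is why the manual version is fine for teaching but not for production shapes.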
When asked about einsum in interviews, the quickest way to reason about any expression is: "indices that appear in the inputs but NOT in the output are summed over." Practice reading expressions backward: given "bhqk,bhkd->bhqd", you can immediately see this is attention_weights @ V because k (the key position) is being contracted.
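That backward reading can be verified in a few lines. A sketch with arbitrary small shapes, where W stands in for softmaxed attention weights:

```python
import torch

# "bhqk,bhkd->bhqd": k (the key position) is in both inputs but not the
# output, so it is contracted -- exactly attention_weights @ V.
B, H, Tq, Tk, d = 2, 4, 5, 7, 16
W = torch.softmax(torch.randn(B, H, Tq, Tk), dim=-1)  # attention weights
V = torch.randn(B, H, Tk, d)

out_einsum = torch.einsum("bhqk,bhkd->bhqd", W, V)
out_matmul = W @ V  # batched matmul over the last two dims

assert torch.allclose(out_einsum, out_matmul, atol=1e-5)
```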
Hints
- For manual batch matmul, use torch.sum(A.unsqueeze(-1) * B.unsqueeze(-3), dim=-2).
- The key insight: A[:, :, :, None] * B[:, None, :, :] creates a 4D tensor, and summing over the shared dimension gives the matrix product.
- Einsum for batch matmul: "bij,bjk->bik".
- For attention: scores = "bhqd,bhkd->bhqk", then output = "bhqk,bhkd->bhqd".
- Do not forget the scaling factor 1/sqrt(d_k) in attention.
Solution
import torch
import torch.nn.functional as F
import math
def batch_matmul_manual(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""
Batch matrix multiply without using bmm or @.
A: (batch, n, m), B: (batch, m, p) -> (batch, n, p)
"""
# Expand dimensions for broadcasting:
# A: (batch, n, m, 1) * B: (batch, 1, m, p) -> (batch, n, m, p)
# Sum over m (dim=-2) to get (batch, n, p)
return torch.sum(A.unsqueeze(-1) * B.unsqueeze(-3), dim=-2)
def batch_matmul_einsum(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""Batch matrix multiply using einsum."""
return torch.einsum("bij,bjk->bik", A, B)
def einsum_attention(
Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
) -> torch.Tensor:
"""
Multi-head attention using only einsum operations.
Q, K, V: (batch, heads, seq_len, d_k)
Returns: (batch, heads, seq_len, d_k)
"""
d_k = Q.size(-1)
# Compute attention scores: (batch, heads, seq_q, seq_k)
scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) / math.sqrt(d_k)
# Softmax over key dimension
attn_weights = F.softmax(scores, dim=-1)
# Weighted sum of values: (batch, heads, seq_q, d_k)
output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, V)
return output
# Bonus: common einsum patterns in transformers
def einsum_examples() -> None:
"""Demonstrate useful einsum patterns."""
B, H, T, D = 2, 4, 8, 16
# 1. Linear projection (no explicit weight broadcasting needed)
x = torch.randn(B, T, D)
W = torch.randn(D, D)
projected = torch.einsum("btd,de->bte", x, W)
assert projected.shape == (B, T, D)
# 2. Outer product (for attention pattern analysis)
a = torch.randn(T)
b = torch.randn(T)
outer = torch.einsum("i,j->ij", a, b)
assert outer.shape == (T, T)
# 3. Batch trace
M = torch.randn(B, D, D)
traces = torch.einsum("bii->b", M)
assert traces.shape == (B,)
# 4. Bilinear form: x^T A y for each batch
x = torch.randn(B, D)
A = torch.randn(D, D)
y = torch.randn(B, D)
bilinear = torch.einsum("bi,ij,bj->b", x, A, y)
assert bilinear.shape == (B,)
print("All einsum patterns verified.")
# ---------- demo ----------
if __name__ == "__main__":
torch.manual_seed(42)
# Test batch matmul
A = torch.randn(4, 3, 5)
B = torch.randn(4, 5, 7)
result_manual = batch_matmul_manual(A, B)
result_einsum = batch_matmul_einsum(A, B)
result_builtin = torch.bmm(A, B)
print(f"Manual shape: {result_manual.shape}")
print(f"Einsum shape: {result_einsum.shape}")
print(f"Max diff (manual vs builtin): {(result_manual - result_builtin).abs().max():.2e}")
print(f"Max diff (einsum vs builtin): {(result_einsum - result_builtin).abs().max():.2e}")
# Test einsum attention
Q = torch.randn(2, 4, 8, 16)
K = torch.randn(2, 4, 8, 16)
V = torch.randn(2, 4, 8, 16)
attn_out = einsum_attention(Q, K, V)
print(f"\nAttention output shape: {attn_out.shape}") # (2, 4, 8, 16)
# Verify against standard implementation
scores_ref = (Q @ K.transpose(-2, -1)) / math.sqrt(16)
attn_ref = F.softmax(scores_ref, dim=-1) @ V
print(f"Max diff (einsum vs standard): {(attn_out - attn_ref).abs().max():.2e}")
einsum_examples()
Walkthrough
- Manual batch matmul -- We use broadcasting to compute the outer product along the contracted dimension. A.unsqueeze(-1) adds a trailing dimension to A, and B.unsqueeze(-3) adds a dimension at position -3. Their element-wise product has shape (B, n, m, p), and summing over m yields the matrix product.
- Einsum batch matmul -- "bij,bjk->bik" is read as: "for each batch index b, contract over j (it appears in both inputs but not in the output) to compute C[b,i,k]". This is more readable and often just as fast as bmm.
- Einsum attention -- Two einsum calls replace the standard Q @ K^T and attn @ V. The string "bhqd,bhkd->bhqk" makes the transposition implicit: d is contracted, producing the (q, k) attention matrix.
- Bonus patterns -- Einsum handles many common operations: projections, outer products, traces, and bilinear forms, all in a single unified notation.
Complexity Analysis
- Batch matmul: O(B * n * m * p) time for all three methods. The manual method materializes a temporary (B, n, m, p) tensor, so it uses far more memory; einsum and bmm avoid that intermediate.
- Einsum attention: Same O(B * H * T^2 * d) as standard attention. Einsum does not change algorithmic complexity; it is notation that PyTorch typically lowers to the same underlying batched-matmul kernels.
Interview Tips
Interviewers test einsum to assess tensor manipulation fluency. Key skills: (1) Read any einsum expression by identifying free indices (in output) vs. contracted indices (summed over). (2) Know common patterns: batch matmul, trace, outer product, bilinear form. (3) Understand that einsum compiles to the same operations as manual code -- it is not slower. (4) Be able to convert between einsum and explicit transpose/reshape/bmm equivalents. (5) For debugging, check shapes by looking at the einsum string: each index corresponds to one dimension.
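Skill (4), converting between einsum and explicit transpose/bmm form, is worth rehearsing once. A sketch for the attention-scores expression, using the same shapes as the demo above:

```python
import torch

# "bhqd,bhkd->bhqk": the contracted index d must become the inner dimension
# of a matmul, so K's last two axes are swapped before multiplying.
Q = torch.randn(2, 4, 8, 16)
K = torch.randn(2, 4, 8, 16)

scores_einsum = torch.einsum("bhqd,bhkd->bhqk", Q, K)
scores_explicit = Q @ K.transpose(-2, -1)

assert torch.allclose(scores_einsum, scores_explicit, atol=1e-4)
```

The general recipe: permute each operand so the contracted index is innermost on one side and outermost on the other, then the einsum collapses to a plain matmul.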
Quiz
In the einsum expression 'bhqd,bhkd->bhqk', which index is being summed over?
What does the einsum expression 'bii->b' compute?
Is torch.einsum slower than explicit operations like torch.bmm?