Implement Multi-Head Attention
Build the full multi-head attention mechanism used in transformers -- splitting the representation into multiple heads to jointly attend to information from different subspaces.
Problem Statement
Implement a MultiHeadAttention module that:
- Projects input x into Q, K, V using learned linear layers
- Splits into num_heads parallel attention heads
- Applies scaled dot-product attention to each head independently
- Concatenates heads and applies a final output projection
Inputs:
- query, key, value: tensors of shape (batch, seq_len, d_model)
- mask: optional causal or padding mask
Outputs:
- output: tensor of shape (batch, seq_len, d_model)
Constraints: d_model must be divisible by num_heads. Each head operates on d_k = d_model // num_heads.
Multi-head attention allows the model to attend to information from different representation subspaces at different positions. A single attention head averages over all subspaces, reducing its expressiveness. Multiple heads give the model the capacity to capture different types of relationships (e.g., syntactic vs. semantic) simultaneously.
┌────────────────────────────────────────────────────────────────────┐
│ Multi-Head Attention │
│ │
│ Input x (B, T, d_model=512) │
│ │ │
│ ├──── W_q ────┐ ├──── W_k ────┐ ├──── W_v ────┐ │
│ ▼ │ ▼ │ ▼ │ │
│ Q (B,T,512) K (B,T,512) V (B,T,512) │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ Reshape: (B, T, 8, 64) ──▶ Transpose: (B, 8, T, 64) │
│ │
│ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ... (8 heads) │
│ │Head 0│ │Head 1│ │Head 2│ │Head 3│ │
│ │ d=64 │ │ d=64 │ │ d=64 │ │ d=64 │ │
│ │ Attn │ │ Attn │ │ Attn │ │ Attn │ │
│ └──┬───┘ └──┬───┘ └──┬───┘ └──┬───┘ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌────────────────────────────────────┐ │
│ │ Transpose + Reshape (concatenate) │ │
│ │ (B, 8, T, 64) ──▶ (B, T, 512) │ │
│ └──────────────────┬─────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ W_o (512→512)│ │
│ │ Output proj │ │
│ └──────┬───────┘ │
│ ▼ │
│ Output (B, T, 512) │
└────────────────────────────────────────────────────────────────────┘
The most error-prone part of multi-head attention is the reshape/transpose sequence. Remember: after .view(B, T, H, d_k), you need .transpose(1, 2) to get (B, H, T, d_k). When merging heads back, you must call .contiguous() before .view() because transpose creates a non-contiguous tensor view. Forgetting .contiguous() will raise a runtime error.
Hints
- Create three linear layers W_q, W_k, W_v of shape (d_model, d_model) and one output projection W_o.
- After projecting, reshape from (B, seq, d_model) to (B, num_heads, seq, d_k) using .view() and .transpose().
- Compute attention in parallel across all heads using batched matrix multiplication (treat B * num_heads as the batch dimension).
- After attention, transpose back and reshape to (B, seq, d_model).
- Apply the output projection W_o.
Solution
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional


class MultiHeadAttention(nn.Module):
    """Multi-Head Attention as described in 'Attention Is All You Need'."""

    def __init__(self, d_model: int, num_heads: int) -> None:
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Projection matrices (combined for all heads)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, seq, d_model) -> (B, num_heads, seq, d_k)."""
        B, seq_len, _ = x.shape
        return x.view(B, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, num_heads, seq, d_k) -> (B, seq, d_model)."""
        B, _, seq_len, _ = x.shape
        return x.transpose(1, 2).contiguous().view(B, seq_len, self.d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            query, key, value: (batch, seq_len, d_model)
            mask: (seq_len, seq_len) bool tensor, True = blocked
        Returns:
            output: (batch, seq_len, d_model)
        """
        # Step 1: Linear projections
        Q = self._split_heads(self.W_q(query))  # (B, H, seq, d_k)
        K = self._split_heads(self.W_k(key))
        V = self._split_heads(self.W_v(value))
        # Step 2: Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: (B, H, seq_q, seq_k)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)  # (B, H, seq, d_k)
        # Step 3: Concatenate heads
        concat = self._merge_heads(attn_output)  # (B, seq, d_model)
        # Step 4: Output projection
        output = self.W_o(concat)  # (B, seq, d_model)
        return output


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    B, seq_len, d_model, num_heads = 2, 10, 64, 8
    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    x = torch.randn(B, seq_len, d_model)
    # Self-attention (query = key = value = x)
    out = mha(x, x, x)
    print(f"Output shape: {out.shape}")  # (2, 10, 64)
    # With causal mask
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    out_causal = mha(x, x, x, mask=causal_mask)
    print(f"Causal output shape: {out_causal.shape}")  # (2, 10, 64)
    # Verify parameter count
    total_params = sum(p.numel() for p in mha.parameters())
    expected = 4 * d_model * d_model  # 4 projection matrices
    print(f"Parameters: {total_params} (expected {expected})")
    assert total_params == expected
Walkthrough
- Projection -- A single nn.Linear(d_model, d_model) layer effectively computes all head projections at once. The weight matrix is (d_model, d_model), which can be thought of as num_heads separate (d_model, d_k) matrices stacked together.
- Split heads -- We reshape the projected tensor from (B, seq, d_model) to (B, num_heads, seq, d_k). This is a zero-cost reshape plus a transpose -- no data is copied.
- Parallel attention -- With the head dimension folded into the batch, torch.matmul performs attention across all heads simultaneously. This is much more efficient than looping over heads.
- Merge heads -- After attention, we reverse the reshape: transpose the head and sequence dimensions back, then view as (B, seq, d_model). The .contiguous() call is needed because transpose returns a non-contiguous view.
- Output projection -- The final linear layer W_o mixes information across heads. Without it, each head's output would be independent.
Complexity Analysis
- Time: O(B * H * n^2 * d_k) = O(B * n^2 * d_model). The per-head computation is O(n^2 * d_k) and there are H heads, but H * d_k = d_model, so total is the same as single-head with full dimension.
- Space: O(B * H * n^2) for the attention weight matrices, plus O(d_model^2) for the four projection weight matrices.
Multi-head attention does not increase total compute over single-head attention with the same d_model -- it just partitions the work differently.
Interview Tips
Be ready to explain: (1) Why multi-head is better than single-head -- different heads learn different attention patterns (positional, syntactic, semantic). (2) The reshape/transpose mechanics -- this is the most error-prone part. (3) Why contiguous() is needed before .view(). (4) How this extends to cross-attention (query from decoder, key/value from encoder). (5) The parameter count: exactly 4 * d_model^2 for the four projections.
Quiz
1. If d_model=512 and num_heads=8, what is the dimension of each attention head?
2. Why is .contiguous() needed before .view() when merging heads back together?
3. Does multi-head attention increase the total compute compared to single-head attention with the same d_model?