Akshay’s Gradient
ML Coding · Intermediate · 55 min

Multi-Head Attention

Implement Multi-Head Attention

Build the full multi-head attention mechanism used in transformers -- splitting the representation into multiple heads to jointly attend to information from different subspaces.

Problem Statement

Implement a MultiHeadAttention module that:

  1. Projects input x into Q, K, V using learned linear layers
  2. Splits into num_heads parallel attention heads
  3. Applies scaled dot-product attention to each head independently
  4. Concatenates heads and applies a final output projection

Inputs:

  • query, key, value: tensors of shape (batch, seq_len, d_model)
  • mask: optional causal or padding mask

Outputs:

  • output: tensor of shape (batch, seq_len, d_model)

Constraints: d_model must be divisible by num_heads. Each head operates on d_k = d_model // num_heads.
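With the standard transformer-base configuration (d_model = 512, 8 heads), this constraint gives 64-dimensional heads -- a quick standalone check:

```python
d_model, num_heads = 512, 8        # transformer-base configuration
assert d_model % num_heads == 0    # d_model must split evenly across heads
d_k = d_model // num_heads         # per-head dimension
print(d_k)                         # 64
```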

Key Concept

Multi-head attention allows the model to attend to information from different representation subspaces at different positions. A single attention head averages over all subspaces, reducing its expressiveness. Multiple heads give the model the capacity to capture different types of relationships (e.g., syntactic vs. semantic) simultaneously.

Diagram · Multi-Head Attention: Split, Attend, Concatenate
┌──────────────────────────────────────────────────────────────────┐
│                       Multi-Head Attention                       │
│                                                                  │
│  Input x (B, T, d_model=512)                                     │
│      │                                                           │
│      ├──── W_q ────┐  ├──── W_k ────┐  ├──── W_v ────┐           │
│      ▼             │  ▼             │  ▼             │           │
│   Q (B,T,512)      K (B,T,512)      V (B,T,512)                  │
│      │                  │                │                       │
│      ▼                  ▼                ▼                       │
│   Reshape: (B, T, 8, 64) ──▶ Transpose: (B, 8, T, 64)            │
│                                                                  │
│   ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐    ... (8 heads)           │
│   │Head 0│ │Head 1│ │Head 2│ │Head 3│                            │
│   │ d=64 │ │ d=64 │ │ d=64 │ │ d=64 │                            │
│   │ Attn │ │ Attn │ │ Attn │ │ Attn │                            │
│   └──┬───┘ └──┬───┘ └──┬───┘ └──┬───┘                            │
│      │        │        │        │                                │
│      ▼        ▼        ▼        ▼                                │
│   ┌────────────────────────────────────┐                         │
│   │  Transpose + Reshape (concatenate) │                         │
│   │  (B, 8, T, 64) ──▶ (B, T, 512)     │                         │
│   └─────────────────┬──────────────────┘                         │
│                     │                                            │
│                     ▼                                            │
│             ┌───────────────┐                                    │
│             │ W_o (512→512) │                                    │
│             │  Output proj  │                                    │
│             └───────┬───────┘                                    │
│                     ▼                                            │
│             Output (B, T, 512)                                   │
└──────────────────────────────────────────────────────────────────┘
Warning

The most error-prone part of multi-head attention is the reshape/transpose sequence. Remember: after .view(B, T, H, d_k), you need .transpose(1, 2) to get (B, H, T, d_k). When merging heads back, you must call .contiguous() before .view() because transpose creates a non-contiguous tensor view. Forgetting .contiguous() will raise a runtime error.
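The pitfall is easy to reproduce in isolation. A minimal standalone sketch (shapes here are illustrative, not tied to the solution below):

```python
import torch

B, T, H, d_k = 2, 4, 8, 8
attn_out = torch.randn(B, H, T, d_k)       # per-head attention output, contiguous

try:
    attn_out.transpose(1, 2).view(B, T, H * d_k)   # .view() on a non-contiguous tensor
except RuntimeError as e:
    print("view failed:", e)

# Correct merge: make the transposed tensor contiguous first
merged = attn_out.transpose(1, 2).contiguous().view(B, T, H * d_k)
print(merged.shape)                        # torch.Size([2, 4, 64])
```

(`.reshape()` would also work here, since it copies when needed, but being explicit with `.contiguous()` makes the cost visible.)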

Hints
  1. Create three linear layers W_q, W_k, W_v of shape (d_model, d_model) and one output projection W_o.
  2. After projecting, reshape from (B, seq, d_model) to (B, num_heads, seq, d_k) using .view() and .transpose().
  3. Compute attention in parallel across all heads using batched matrix multiplication (treat B * num_heads as the batch dimension).
  4. After attention, transpose back and reshape to (B, seq, d_model).
  5. Apply the output projection W_o.

Solution

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional


class MultiHeadAttention(nn.Module):
    """Multi-Head Attention as described in 'Attention Is All You Need'."""

    def __init__(self, d_model: int, num_heads: int) -> None:
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Projection matrices (combined for all heads)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, seq, d_model) -> (B, num_heads, seq, d_k)."""
        B, seq_len, _ = x.shape
        return x.view(B, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, num_heads, seq, d_k) -> (B, seq, d_model)."""
        B, _, seq_len, _ = x.shape
        return x.transpose(1, 2).contiguous().view(B, seq_len, self.d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            query, key, value: (batch, seq_len, d_model)
            mask: (seq_len, seq_len) bool tensor, True = blocked

        Returns:
            output: (batch, seq_len, d_model)
        """
        # Step 1: Linear projections
        Q = self._split_heads(self.W_q(query))  # (B, H, seq, d_k)
        K = self._split_heads(self.W_k(key))
        V = self._split_heads(self.W_v(value))

        # Step 2: Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: (B, H, seq_q, seq_k)

        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)

        attn_output = torch.matmul(attn_weights, V)  # (B, H, seq, d_k)

        # Step 3: Concatenate heads
        concat = self._merge_heads(attn_output)  # (B, seq, d_model)

        # Step 4: Output projection
        output = self.W_o(concat)  # (B, seq, d_model)
        return output


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    B, seq_len, d_model, num_heads = 2, 10, 64, 8

    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    x = torch.randn(B, seq_len, d_model)

    # Self-attention (query = key = value = x)
    out = mha(x, x, x)
    print(f"Output shape: {out.shape}")  # (2, 10, 64)

    # With causal mask
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    out_causal = mha(x, x, x, mask=causal_mask)
    print(f"Causal output shape: {out_causal.shape}")  # (2, 10, 64)

    # Verify parameter count
    total_params = sum(p.numel() for p in mha.parameters())
    expected = 4 * d_model * d_model  # 4 projection matrices
    print(f"Parameters: {total_params} (expected {expected})")
    assert total_params == expected

Walkthrough

  1. Projection -- A single nn.Linear(d_model, d_model) layer effectively computes all head projections at once. The weight matrix is (d_model, d_model), which can be thought of as num_heads separate (d_model, d_k) matrices stacked together.

  2. Split heads -- We reshape the projected tensor from (B, seq, d_model) to (B, num_heads, seq, d_k). This is a zero-cost reshape plus a transpose -- no data is copied.

  3. Parallel attention -- With the head dimension folded into the batch, torch.matmul performs attention across all heads simultaneously. This is much more efficient than looping over heads.

  4. Merge heads -- After attention, we reverse the reshape: transpose the head and sequence dimensions back, then view as (B, seq, d_model). The .contiguous() call is needed because transpose returns a non-contiguous view.

  5. Output projection -- The final linear layer W_o mixes information across heads. Without it, each head's output would be independent.
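Steps 1-3 above can be verified numerically. The sketch below (standalone, with small illustrative shapes) checks that slicing one combined projection matches separate per-head projections, and that batched matmul matches an explicit loop over heads:

```python
import math
import torch

torch.manual_seed(0)
B, T, d_model, H = 2, 5, 16, 4
d_k = d_model // H
x = torch.randn(B, T, d_model)

# Step 1: one combined weight equals H stacked (d_model, d_k) projections
W = torch.randn(d_model, d_model)
combined = x @ W                                        # (B, T, d_model)
for h in range(H):
    per_head = x @ W[:, h * d_k : (h + 1) * d_k]        # one head's (d_model, d_k) slice
    assert torch.allclose(combined[..., h * d_k : (h + 1) * d_k], per_head, atol=1e-5)

# Steps 2-3: batched matmul over (B, H) equals an explicit loop over heads
Q = K = V = combined.view(B, T, H, d_k).transpose(1, 2) # (B, H, T, d_k)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)       # (B, H, T, T)
out_batched = scores.softmax(dim=-1) @ V                # (B, H, T, d_k)

out_loop = torch.empty_like(out_batched)
for b in range(B):
    for h in range(H):
        s = Q[b, h] @ K[b, h].T / math.sqrt(d_k)
        out_loop[b, h] = s.softmax(dim=-1) @ V[b, h]
assert torch.allclose(out_batched, out_loop, atol=1e-5)
print("per-head slicing and batched attention both match")
```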

Complexity Analysis

  • Time: O(B * H * n^2 * d_k) = O(B * n^2 * d_model). The per-head computation is O(n^2 * d_k) and there are H heads, but H * d_k = d_model, so total is the same as single-head with full dimension.
  • Space: O(B * H * n^2) for the attention weight matrices, plus O(d_model^2) for the four projection weight matrices.

Multi-head attention does not increase total compute over single-head attention with the same d_model -- it just partitions the work differently.
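The equal-compute claim follows from H * d_k = d_model; a back-of-the-envelope count of multiply-accumulates for the two attention matmuls (QK^T and the weighted sum over V), ignoring the projections and using illustrative sizes:

```python
n, d_model, H = 1024, 512, 8              # illustrative sequence length and widths
d_k = d_model // H

single_head = 2 * n * n * d_model         # QK^T plus attn @ V at full width
multi_head = H * (2 * n * n * d_k)        # the same two matmuls, split across H heads
assert single_head == multi_head          # H * d_k == d_model, so compute is identical
```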

Interview Tips

Be ready to explain: (1) Why multi-head is better than single-head -- different heads learn different attention patterns (positional, syntactic, semantic). (2) The reshape/transpose mechanics -- this is the most error-prone part. (3) Why contiguous() is needed before .view(). (4) How this extends to cross-attention (query from decoder, key/value from encoder). (5) The parameter count: exactly 4 * d_model^2 for the four projections.
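For point (4), cross-attention needs no code changes -- only the inputs differ, and the query and key/value sequence lengths may differ. A standalone sketch, using PyTorch's built-in nn.MultiheadAttention for brevity:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
enc_out = torch.randn(2, 12, 64)    # encoder output: (B, src_len, d_model)
dec_x = torch.randn(2, 7, 64)       # decoder states: (B, tgt_len, d_model)

# Query from the decoder, key/value from the encoder
out, attn_w = mha(dec_x, enc_out, enc_out)
print(out.shape)                    # torch.Size([2, 7, 64])
print(attn_w.shape)                 # torch.Size([2, 7, 12]), averaged over heads by default
```

Note that the output length follows the query (7), while the attention weights span query positions by key positions.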

Quiz

If d_model=512 and num_heads=8, what is the dimension of each attention head?

Why is .contiguous() needed before .view() when merging heads back together?

Does multi-head attention increase the total compute compared to single-head attention with the same d_model?
