Akshay’s Gradient
ML Coding · Intermediate · 50 min

Self-Attention Mechanism

Implement Scaled Dot-Product Self-Attention

Implement the core attention mechanism from "Attention Is All You Need" -- the building block of every modern transformer.

Problem Statement

Implement scaled_dot_product_attention(Q, K, V, mask=None) that:

  1. Computes attention scores: scores = Q @ K^T / sqrt(d_k)
  2. Optionally applies a causal (or padding) mask by setting masked positions to -inf
  3. Applies softmax to get attention weights
  4. Returns weights @ V and the attention weights

Inputs:

  • Q, K: tensors of shape (batch, seq_len, d_k)
  • V: tensor of shape (batch, seq_len, d_v) (often d_v = d_k)
  • mask: optional boolean tensor of shape (seq_len, seq_len) where True means "mask out" (block attention)

Outputs:

  • output: tensor of shape (batch, seq_len, d_v)
  • attn_weights: tensor of shape (batch, seq_len, seq_len)

Key Concept

Scaling by 1/sqrt(d_k) is critical. Without it, when d_k is large, the dot products grow in magnitude, pushing softmax into regions with extremely small gradients (saturation). The scaling keeps the variance of the dot products at approximately 1, regardless of dimension.
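This variance claim is easy to check empirically. A minimal sketch (the sample count and set of dimensions are arbitrary):

```python
import torch

torch.manual_seed(0)
for d_k in (16, 64, 256, 1024):
    q = torch.randn(10_000, d_k)   # unit-variance query vectors
    k = torch.randn(10_000, d_k)   # unit-variance key vectors
    raw = (q * k).sum(dim=-1)      # unscaled dot products
    scaled = raw / d_k ** 0.5      # scaled dot products
    # var(raw) grows like d_k; var(scaled) stays near 1
    print(f"d_k={d_k:4d}  var(raw)={raw.var().item():8.1f}  "
          f"var(scaled)={scaled.var().item():.3f}")
```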

Interactive · Scaled Dot-Product Attention: Q, K, V Flow
┌────────────────────────────────────────────────────────────────┐
│         Scaled Dot-Product Attention                           │
│                                                                │
│   Q (query)    K (key)     V (value)                           │
│   (B,T,d_k)   (B,T,d_k)    (B,T,d_v)                           │
│      │            │           │                                │
│      │     ┌──────┘           │                                │
│      ▼     ▼                  │                                │
│   ┌──────────────┐            │                                │
│   │  Q @ K^T     │            │                                │
│   │  (B, T, T)   │            │                                │
│   └──────┬───────┘            │                                │
│          │                    │                                │
│          ▼                    │                                │
│   ┌──────────────┐            │                                │
│   │  / sqrt(d_k) │            │                                │
│   │  (scaling)   │            │                                │
│   └──────┬───────┘            │                                │
│          │                    │                                │
│          ▼                    │                                │
│   ┌──────────────┐            │                                │
│   │  mask (opt.) │            │                                │
│   │  -inf future │            │                                │
│   └──────┬───────┘            │                                │
│          │                    │                                │
│          ▼                    │                                │
│   ┌──────────────┐            │                                │
│   │  softmax     │            │                                │
│   │  (B, T, T)   │            │                                │
│   └──────┬───────┘            │                                │
│          │            ┌───────┘                                │
│          ▼            ▼                                        │
│   ┌────────────────────┐                                       │
│   │ attn_weights @ V   │                                       │
│   │ (B, T, d_v)        │                                       │
│   └────────────────────┘                                       │
│          │                                                     │
│          ▼                                                     │
│      Output (B, T, d_v)                                        │
└────────────────────────────────────────────────────────────────┘
Warning

A frequent bug is using scores.masked_fill(mask, 0) instead of scores.masked_fill(mask, float('-inf')). Setting masked positions to 0 does NOT block attention -- exp(0) = 1, so those positions still receive nonzero attention weight. You must use -inf so that exp(-inf) = 0. Another common mistake is applying a mask with the wrong dimensions -- ensure it broadcasts correctly over the batch dimension.
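To see the failure mode concretely, here is a small sketch comparing the two fills on a row of four equal scores:

```python
import torch
import torch.nn.functional as F

scores = torch.zeros(1, 4)                         # four equal raw scores
mask = torch.tensor([[False, False, True, True]])  # block the last two keys

wrong = F.softmax(scores.masked_fill(mask, 0.0), dim=-1)
right = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

print(wrong)  # tensor([[0.2500, 0.2500, 0.2500, 0.2500]]) -- masked keys leak
print(right)  # tensor([[0.5000, 0.5000, 0.0000, 0.0000]]) -- masked keys blocked
```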

Hints

  1. Use torch.bmm or the @ operator for batched matrix multiplication.
  2. Transpose K's last two dimensions: K.transpose(-2, -1).
  3. Scale by 1 / sqrt(d_k) where d_k = Q.shape[-1].
  4. If a mask is provided, use scores.masked_fill(mask, float('-inf')) (or the in-place masked_fill_).
  5. Apply F.softmax(scores, dim=-1) along the last dimension (each query attends over all keys).
  6. Final output is attn_weights @ V.

Solution

import torch
import torch.nn.functional as F
import math
from typing import Optional, Tuple


def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled dot-product attention.

    Args:
        Q: (batch, seq_len_q, d_k)
        K: (batch, seq_len_k, d_k)
        V: (batch, seq_len_k, d_v)
        mask: (seq_len_q, seq_len_k) boolean, True = masked out

    Returns:
        output: (batch, seq_len_q, d_v)
        attn_weights: (batch, seq_len_q, seq_len_k)
    """
    d_k = Q.size(-1)

    # Step 1: Compute raw attention scores
    scores = torch.bmm(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores shape: (batch, seq_len_q, seq_len_k)

    # Step 2: Apply mask (e.g., causal mask)
    if mask is not None:
        scores = scores.masked_fill(mask.unsqueeze(0), float("-inf"))

    # Step 3: Softmax over the key dimension
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Weighted sum of values
    output = torch.bmm(attn_weights, V)

    return output, attn_weights


def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create an upper-triangular causal mask (True = blocked)."""
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    batch, seq_len, d_k = 2, 6, 64

    Q = torch.randn(batch, seq_len, d_k)
    K = torch.randn(batch, seq_len, d_k)
    V = torch.randn(batch, seq_len, d_k)

    # Without mask (bidirectional attention)
    out, weights = scaled_dot_product_attention(Q, K, V)
    print(f"Output shape: {out.shape}")          # (2, 6, 64)
    print(f"Weights shape: {weights.shape}")     # (2, 6, 6)
    print(f"Weights sum (row): {weights[0, 0].sum().item():.4f}")  # 1.0

    # With causal mask (autoregressive)
    mask = create_causal_mask(seq_len)
    out_causal, weights_causal = scaled_dot_product_attention(Q, K, V, mask)
    print(f"\nCausal weights[0,0]: {weights_causal[0, 0]}")  # only first pos nonzero
    print(f"Causal weights[0,2]: {weights_causal[0, 2]}")    # first 3 positions nonzero

    # Verify causality: position i should have zero weight on positions > i
    for i in range(seq_len):
        future_weight = weights_causal[0, i, i + 1 :].sum().item()
        assert abs(future_weight) < 1e-6, f"Position {i} attends to the future!"
    print("\nCausality check passed.")
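
As a further sanity check, the same computation can be compared against PyTorch's built-in F.scaled_dot_product_attention (available since PyTorch 2.0). Note that its boolean attn_mask uses the opposite convention (True = may attend), so the sketch below passes is_causal=True instead of a mask:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
Q = torch.randn(2, 6, 64)
K = torch.randn(2, 6, 64)
V = torch.randn(2, 6, 64)

# The tutorial's recipe, inlined: scale, mask with -inf, softmax, weight V.
mask = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1)
scores = torch.bmm(Q, K.transpose(-2, -1)) / 64 ** 0.5
ours = torch.bmm(F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1), V)

# Built-in reference; is_causal=True applies the same lower-triangular mask.
builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(ours, builtin, atol=1e-5))  # True
```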

Walkthrough

  1. Score computation -- Q @ K^T computes the dot product between every query-key pair. For a sequence of length n, this produces an n x n attention matrix. Each entry measures how much one token should attend to another.

  2. Scaling -- Dividing by sqrt(d_k) counteracts the growth of dot-product magnitude with dimension. If Q and K entries have unit variance, their dot product has variance d_k; scaling brings it back to variance 1.

  3. Masking -- Setting positions to -inf before softmax ensures they get probability zero after exponentiation. The causal mask is upper-triangular: position i can only attend to positions <= i.

  4. Softmax -- Converts scores to a probability distribution over keys. Each query independently distributes its attention budget across all (unmasked) keys.

  5. Weighted sum -- The output for each query is a weighted combination of value vectors, where the weights are the attention probabilities.
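The masking step above also covers padding masks. A sketch of building one under the same True-means-blocked convention (the lengths tensor here is a made-up example of per-sequence valid lengths):

```python
import torch

lengths = torch.tensor([4, 6])   # hypothetical: number of real tokens per sequence
seq_len = 6

# True where the KEY position is padding; unsqueeze to (batch, 1, seq_len)
# so it broadcasts over the query dimension of the (batch, T, T) scores.
pad_mask = (torch.arange(seq_len).unsqueeze(0) >= lengths.unsqueeze(1)).unsqueeze(1)
print(pad_mask.shape)   # torch.Size([2, 1, 6])
print(pad_mask[0, 0])   # tensor([False, False, False, False,  True,  True])
```

Unlike the shared (seq_len, seq_len) causal mask, a padding mask differs per batch element, so it would be applied as scores.masked_fill(pad_mask, float('-inf')) directly, without the unsqueeze(0) used for the shared mask in the solution.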

Complexity Analysis

  • Time: O(B * n^2 * d) for the matrix multiplications, where B = batch size, n = sequence length, d = head dimension. This quadratic dependence on sequence length is the main bottleneck of standard attention.
  • Space: O(B * n^2) for storing the attention weight matrix. This is why long-context models need techniques like FlashAttention (which tiles the computation to avoid materializing the full matrix).
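
A back-of-envelope sketch of that O(B * n^2) memory term (batch of 8, fp16 weights, example context lengths):

```python
# Bytes needed to materialize the (B, n, n) attention matrix in fp16.
B = 8
for n in (1_024, 8_192, 131_072):
    gib = B * n * n * 2 / 2**30   # 2 bytes per fp16 entry
    print(f"n={n:>7,}  attention matrix: {gib:8.2f} GiB")
# n=  1,024 -> 0.02 GiB; n=  8,192 -> 1.00 GiB; n=131,072 -> 256.00 GiB
```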

Interview Tips

Interviewers expect you to: (1) Know why you scale by sqrt(d_k) -- explain the variance argument. (2) Implement masking correctly using -inf, not zero. (3) Understand the difference between causal masks (autoregressive) and padding masks (variable-length sequences). (4) Be able to discuss FlashAttention and why the O(n^2) memory of standard attention is the real bottleneck, not the O(n^2 * d) compute.

Quiz

Why do we use -inf (not 0) for masked positions before softmax?

Why does attention scale by 1/sqrt(d_k) instead of 1/d_k?

What is the space complexity bottleneck of standard (non-Flash) attention?
