Implement Scaled Dot-Product Self-Attention
Implement the core attention mechanism from "Attention Is All You Need" -- the building block of every modern transformer.
Problem Statement
Implement scaled_dot_product_attention(Q, K, V, mask=None) that:
- Computes attention scores: scores = Q @ K^T / sqrt(d_k)
- Optionally applies a causal (or padding) mask by setting masked positions to -inf
- Applies softmax to get attention weights
- Returns weights @ V and the attention weights
Inputs:
- Q, K, V: tensors of shape (batch, seq_len, d_k)
- mask: optional boolean tensor of shape (seq_len, seq_len) where True means "mask out" (block attention)
Outputs:
- output: tensor of shape (batch, seq_len, d_k)
- attn_weights: tensor of shape (batch, seq_len, seq_len)
Scaling by 1/sqrt(d_k) is critical. Without it, when d_k is large, the dot products grow in magnitude, pushing softmax into regions with extremely small gradients (saturation). The scaling keeps the variance of the dot products at approximately 1, regardless of dimension.
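The variance argument is easy to check empirically. A minimal sketch, assuming Q and K entries are drawn from a unit-variance Gaussian:

```python
import torch

torch.manual_seed(0)
for d_k in (16, 256, 4096):
    q = torch.randn(10_000, d_k)   # unit-variance query rows
    k = torch.randn(10_000, d_k)   # unit-variance key rows
    raw = (q * k).sum(dim=-1)      # unscaled dot products: variance ~ d_k
    scaled = raw / d_k ** 0.5      # scaled dot products: variance ~ 1
    print(f"d_k={d_k:5d}  var(raw)={raw.var():9.1f}  var(scaled)={scaled.var():.2f}")
```

The unscaled variance tracks d_k almost exactly, while the scaled variance stays near 1 for every dimension, which is what keeps softmax out of its saturated regime.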
┌────────────────────────────────────────────────────────────────┐
│ Scaled Dot-Product Attention │
│ │
│ Q (query) K (key) V (value) │
│ (B,T,d_k) (B,T,d_k) (B,T,d_v) │
│ │ │ │ │
│ │ ┌──────┘ │ │
│ ▼ ▼ │ │
│ ┌──────────────┐ │ │
│ │ Q @ K^T │ │ │
│ │ (B, T, T) │ │ │
│ └──────┬───────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌──────────────┐ │ │
│ │ / sqrt(d_k) │ │ │
│ │ (scaling) │ │ │
│ └──────┬───────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌──────────────┐ │ │
│ │ mask (opt.) │ │ │
│ │ -inf future │ │ │
│ └──────┬───────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌──────────────┐ │ │
│ │ softmax │ │ │
│ │ (B, T, T) │ │ │
│ └──────┬───────┘ │ │
│ │ ┌───────┘ │
│ ▼ ▼ │
│ ┌────────────────────┐ │
│ │ attn_weights @ V │ │
│ │ (B, T, d_v) │ │
│ └────────────────────┘ │
│ │ │
│ ▼ │
│ Output (B, T, d_v) │
└────────────────────────────────────────────────────────────────┘
A frequent bug is filling masked positions with 0 -- scores.masked_fill(mask, 0) -- instead of scores.masked_fill(mask, float('-inf')). Setting masked scores to 0 does NOT block attention: exp(0) = 1, so those positions still receive nonzero attention weight. You must use -inf so that exp(-inf) = 0. Another common mistake is applying a mask with the wrong dimensions; ensure it broadcasts correctly over the batch dimension.
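The difference is easy to see on a toy example. A minimal sketch comparing the buggy and correct fill values on uniform scores with a causal mask:

```python
import torch
import torch.nn.functional as F

scores = torch.zeros(1, 3, 3)  # uniform raw scores for 3 positions
mask = torch.triu(torch.ones(3, 3, dtype=torch.bool), diagonal=1)  # True = future

wrong = F.softmax(scores.masked_fill(mask, 0.0), dim=-1)            # fill is a no-op here
right = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)  # future truly blocked

print(wrong[0, 0])  # tensor([0.3333, 0.3333, 0.3333]) -- future still attended
print(right[0, 0])  # tensor([1., 0., 0.]) -- all weight on the only visible key
```

With zero-fill, the first query still spreads a third of its attention onto each future position; with -inf, the masked entries get exactly zero probability.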
Hints
- Use torch.bmm or the @ operator for batched matrix multiplication.
- Transpose K's last two dimensions: K.transpose(-2, -1).
- Scale by 1 / sqrt(d_k) where d_k = Q.shape[-1].
- If a mask is provided, use scores.masked_fill_(mask, float('-inf')).
- Apply F.softmax(scores, dim=-1) along the last dimension (each query attends over all keys).
- The final output is attn_weights @ V.
Solution
import torch
import torch.nn.functional as F
import math
from typing import Optional, Tuple
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled dot-product attention.

    Args:
        Q: (batch, seq_len_q, d_k)
        K: (batch, seq_len_k, d_k)
        V: (batch, seq_len_k, d_v)
        mask: (seq_len_q, seq_len_k) boolean, True = masked out

    Returns:
        output: (batch, seq_len_q, d_v)
        attn_weights: (batch, seq_len_q, seq_len_k)
    """
    d_k = Q.size(-1)

    # Step 1: Compute raw attention scores
    scores = torch.bmm(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores shape: (batch, seq_len_q, seq_len_k)

    # Step 2: Apply mask (e.g., causal mask)
    if mask is not None:
        scores = scores.masked_fill(mask.unsqueeze(0), float("-inf"))

    # Step 3: Softmax over the key dimension
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Weighted sum of values
    output = torch.bmm(attn_weights, V)
    return output, attn_weights
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create an upper-triangular causal mask (True = blocked)."""
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    batch, seq_len, d_k = 2, 6, 64
    Q = torch.randn(batch, seq_len, d_k)
    K = torch.randn(batch, seq_len, d_k)
    V = torch.randn(batch, seq_len, d_k)

    # Without mask (bidirectional attention)
    out, weights = scaled_dot_product_attention(Q, K, V)
    print(f"Output shape: {out.shape}")       # (2, 6, 64)
    print(f"Weights shape: {weights.shape}")  # (2, 6, 6)
    print(f"Weights sum (row): {weights[0, 0].sum().item():.4f}")  # 1.0

    # With causal mask (autoregressive)
    mask = create_causal_mask(seq_len)
    out_causal, weights_causal = scaled_dot_product_attention(Q, K, V, mask)
    print(f"\nCausal weights[0,0]: {weights_causal[0, 0]}")  # only first pos nonzero
    print(f"Causal weights[0,2]: {weights_causal[0, 2]}")    # first 3 positions nonzero

    # Verify causality: position i should have zero weight on positions > i
    for i in range(seq_len):
        future_weight = weights_causal[0, i, i + 1 :].sum().item()
        assert abs(future_weight) < 1e-6, f"Position {i} attends to the future!"
    print("\nCausality check passed.")
Walkthrough
1. Score computation -- Q @ K^T computes the dot product between every query-key pair. For a sequence of length n, this produces an n x n attention matrix. Each entry measures how much one token should attend to another.
2. Scaling -- Dividing by sqrt(d_k) counteracts the growth of dot-product magnitude with dimension. If Q and K entries have unit variance, their dot product has variance d_k; scaling brings it back to variance 1.
3. Masking -- Setting positions to -inf before softmax ensures they get probability zero after exponentiation. The causal mask is upper-triangular: position i can only attend to positions <= i.
4. Softmax -- Converts scores to a probability distribution over keys. Each query independently distributes its attention budget across all (unmasked) keys.
5. Weighted sum -- The output for each query is a weighted combination of value vectors, where the weights are the attention probabilities.
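These five steps can be cross-checked against PyTorch's built-in fused kernel, F.scaled_dot_product_attention (available since PyTorch 2.0). A sketch that recomputes causal attention manually, using the same steps as the Solution, and compares:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 6, 64) for _ in range(3))

# Manual causal attention: score, scale, mask, softmax, weighted sum.
mask = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1)
scores = torch.bmm(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
scores = scores.masked_fill(mask, float("-inf"))
manual = torch.bmm(F.softmax(scores, dim=-1), V)

# PyTorch's fused implementation with its own causal masking.
builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```

Agreement with the fused kernel is a quick way to catch masking or scaling bugs in an interview setting.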
Complexity Analysis
- Time: O(B * n^2 * d) for the matrix multiplications, where B = batch size, n = sequence length, d = head dimension. This quadratic dependence on sequence length is the main bottleneck of standard attention.
- Space: O(B * n^2) for storing the attention weight matrix. This is why long-context models need techniques like FlashAttention (which tiles the computation to avoid materializing the full matrix).
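The quadratic memory term dominates quickly. A back-of-the-envelope sketch (fp16 weights, example sizes chosen for illustration):

```python
def attn_matrix_bytes(batch: int, n: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize the (batch, n, n) attention matrix."""
    return batch * n * n * bytes_per_elem

for n in (1_024, 8_192, 32_768):
    gib = attn_matrix_bytes(batch=1, n=n) / 2**30
    print(f"n={n:6d}: {gib:8.3f} GiB per attention head")
```

At a 32k context a single head's attention matrix already costs 2 GiB in fp16, before multiplying by heads, layers, and batch, which is exactly the matrix FlashAttention avoids materializing.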
Interview Tips
Interviewers expect you to: (1) Know why you scale by sqrt(d_k) -- explain the variance argument. (2) Implement masking correctly using -inf, not zero. (3) Understand the difference between causal masks (autoregressive) and padding masks (variable-length sequences). (4) Be able to discuss FlashAttention and why the O(n^2) memory of standard attention is the real bottleneck, not the O(n^2 * d) compute.
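Point (3) comes up often. A sketch of how a padding mask could be built from per-sequence lengths, using the same True = "mask out" convention as above (the helper name and shapes are illustrative, not from the original problem):

```python
import torch

def padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """(batch, 1, max_len) bool mask; True marks padding key positions."""
    positions = torch.arange(max_len)                        # (max_len,)
    return (positions[None, :] >= lengths[:, None]).unsqueeze(1)

lengths = torch.tensor([3, 5])        # actual lengths of two padded sequences
mask = padding_mask(lengths, max_len=5)
print(mask[0, 0])  # last two key positions masked for the length-3 sequence
print(mask[1, 0])  # nothing masked for the full-length sequence
```

The (batch, 1, max_len) shape broadcasts over query positions; note that since this mask already carries a batch dimension, it would be applied to the scores directly, without the unsqueeze(0) the Solution uses for a 2D causal mask.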
Quiz — 3 Questions
Why do we use -inf (not 0) for masked positions before softmax?
Why does attention scale by 1/sqrt(d_k) instead of 1/d_k?
What is the space complexity bottleneck of standard (non-Flash) attention?