Implement a Transformer Block
Build a complete transformer block: multi-head self-attention, feed-forward network, layer normalization, and residual connections. This is the repeating unit that makes up GPT, BERT, and every other transformer model.
Problem Statement
Implement a TransformerBlock module with:
- Pre-norm multi-head self-attention with residual connection
- Pre-norm position-wise feed-forward network (FFN) with residual connection
- The FFN expands the dimension by a factor of 4, applies GELU activation, then projects back
Use pre-norm architecture (LayerNorm before each sub-layer, as in GPT-2/3) rather than post-norm (original transformer).
Inputs: Tensor x of shape (batch, seq_len, d_model), optional causal mask.
Outputs: Tensor of the same shape (batch, seq_len, d_model).
The pre-norm transformer block computes: x = x + MHA(LayerNorm(x)) then x = x + FFN(LayerNorm(x)). The residual connections ensure gradients flow unimpeded through the network. Pre-norm is more stable than post-norm for training deep models because the input to each sub-layer is always normalized.
┌──────────────────────────────────────────────────────────────┐
│ Pre-Norm Transformer Block │
│ │
│ Input x ─────────────────────────────────┐ (residual) │
│ │ │ │
│ ▼ │ │
│ ┌──────────┐ │ │
│ │LayerNorm │ │ │
│ └────┬─────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌──────────────────┐ │ │
│ │ Multi-Head │ │ │
│ │ Self-Attention │ │ │
│ └────────┬─────────┘ │ │
│ │ │ │
│ └───────────── + ◀───────────────┘ │
│ │ │
│ x' ═══════════════════════════════════════┐ (residual) │
│ │ │ │
│ ▼ │ │
│ ┌──────────┐ │ │
│ │LayerNorm │ │ │
│ └────┬─────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌──────────────────┐ │ │
│ │ Feed-Forward │ │ │
│ │ d──▶4d──GELU──▶d │ │ │
│ └────────┬─────────┘ │ │
│ │ │ │
│ └───────────── + ◀───────────────┘ │
│ │ │
│ ▼ │
│ Output x'' (same shape as input) │
└──────────────────────────────────────────────────────────────┘
Do not confuse pre-norm and post-norm architectures. In pre-norm (GPT-2/3, LLaMA), LayerNorm is applied before each sub-layer: x = x + Sublayer(LN(x)). In post-norm (original 2017 transformer, BERT), it is applied after: x = LN(x + Sublayer(x)). Pre-norm is more stable for deep models because the residual path carries raw activations, but it may need a final LayerNorm after the last block.
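The difference between the two placements is easiest to see in code. A minimal sketch, using a plain `nn.Linear` as a stand-in for the attention or FFN sub-layer (the names here are illustrative, not part of the solution below):

```python
import torch
import torch.nn as nn

d_model = 8
ln = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for MHA or FFN

x = torch.randn(2, 4, d_model)

# Pre-norm (GPT-2/3, LLaMA): norm inside the residual branch
pre_norm_out = x + sublayer(ln(x))

# Post-norm (original 2017 transformer, BERT): norm after the add
post_norm_out = ln(x + sublayer(x))

assert pre_norm_out.shape == post_norm_out.shape == x.shape
```

Note that in the pre-norm version the raw `x` reaches the output through the skip connection untouched, while in the post-norm version every path passes through a LayerNorm.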
Hints
- Use `nn.LayerNorm(d_model)` for normalization.
- The FFN is two linear layers: `Linear(d_model, 4*d_model)` -> GELU -> `Linear(4*d_model, d_model)`.
- Apply LayerNorm before the sub-layer (pre-norm), not after.
- Add the sub-layer output to the original input (residual): `x = x + sublayer(norm(x))`.
- For self-attention, query = key = value = normalized input.
Solution
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention (compact version for the block)."""

    def __init__(self, d_model: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        # Project Q, K, V in one shot
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, T, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = attn @ V  # (B, H, T, d_k)
        # Merge heads
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class FeedForward(nn.Module):
    """Position-wise feed-forward network with GELU activation."""

    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.gelu(self.fc1(x)))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LN -> MHA -> residual -> LN -> FFN -> residual."""

    def __init__(self, d_model: int, num_heads: int, d_ff: Optional[int] = None) -> None:
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Sub-layer 1: multi-head self-attention with residual
        x = x + self.attn(self.ln1(x), mask)
        # Sub-layer 2: feed-forward network with residual
        x = x + self.ffn(self.ln2(x))
        return x


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    B, T, d_model, num_heads = 2, 16, 128, 4
    block = TransformerBlock(d_model=d_model, num_heads=num_heads)
    x = torch.randn(B, T, d_model)

    # Causal mask for autoregressive modeling
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    out = block(x, mask=mask)
    print(f"Input shape: {x.shape}")     # (2, 16, 128)
    print(f"Output shape: {out.shape}")  # (2, 16, 128)

    # Verify residual connection: output stays close to input at random init
    diff = (out - x).abs().mean().item()
    print(f"Mean |output - input|: {diff:.4f}")  # small due to residual

    # Count parameters
    total = sum(p.numel() for p in block.parameters())
    print(f"Total parameters: {total:,}")
    # Attention: 3*d^2 (QKV) + d^2 (out proj); FFN: 8*d^2 weights + 5*d biases;
    # LayerNorms: 2 * 2*d = 4*d. Total: 12*d^2 + 9*d.
    expected = 12 * d_model**2 + 9 * d_model
    print(f"Expected: {expected:,}")
```
Walkthrough
- Fused QKV projection -- instead of three separate linear layers, one `Linear(d_model, 3*d_model)` produces Q, K, and V and the output is split. A single large matrix multiply is more efficient on modern hardware than three smaller ones.
- Pre-norm architecture -- `LayerNorm` is applied before each sub-layer. The original transformer applied it after (post-norm), but pre-norm is more stable for deep networks because the residual path carries unnormalized activations, so gradients flow directly through the skip connection.
- Residual connections -- `x = x + sublayer(norm(x))` adds the sub-layer output to the original input. At initialization, sub-layer outputs are small, so the block approximately computes the identity function. This is what makes deep networks trainable.
- FFN structure -- the two-layer FFN with a 4x expansion acts as a per-token MLP. The GELU nonlinearity is used in GPT-2/3 and most modern transformers (the 2017 paper used ReLU).
- Stacking -- a full transformer is just `N` of these blocks in sequence, preceded by an embedding layer and followed by a final LayerNorm and output head.
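To make the stacking concrete, here is a sketch of a minimal decoder-only model. For brevity it uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the hand-written `MultiHeadAttention` above; the class and argument names (`MiniTransformer`, `vocab_size`, `max_len`) are illustrative, not part of the exercise:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Pre-norm block using nn.MultiheadAttention as a stand-in."""
    def __init__(self, d_model: int, num_heads: int) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x, mask=None):
        h = self.ln1(x)
        x = x + self.attn(h, h, h, attn_mask=mask, need_weights=False)[0]
        return x + self.ffn(self.ln2(x))

class MiniTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, max_len):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList(Block(d_model, num_heads) for _ in range(num_layers))
        self.ln_f = nn.LayerNorm(d_model)  # final LayerNorm after the last pre-norm block
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        B, T = idx.shape
        x = self.tok_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
        # Additive causal mask: -inf above the diagonal
        mask = torch.triu(torch.full((T, T), float("-inf"), device=idx.device), diagonal=1)
        for block in self.blocks:
            x = block(x, mask=mask)
        return self.head(self.ln_f(x))  # (B, T, vocab_size) logits

model = MiniTransformer(vocab_size=100, d_model=64, num_heads=4, num_layers=3, max_len=32)
logits = model(torch.randint(0, 100, (2, 10)))
print(logits.shape)  # torch.Size([2, 10, 100])
```

Note the final `ln_f`: with pre-norm blocks the last residual stream is never normalized inside a block, which is why GPT-2-style models add one LayerNorm after the stack.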
Complexity Analysis
- Attention: O(B * n^2 * d) time, O(B * H * n^2) space for attention weights.
- FFN: O(B * n * d * d_ff) = O(B * n * 4d^2) time per block.
- Parameters per block: ~12 * d_model^2 (4d^2 from attention projections, 8d^2 from FFN).
- For GPT-3 (d=12288, 96 layers): 12 * 12288^2 * 96 ≈ 174B parameters from the blocks alone, close to the published 175B (embeddings and biases account for the remainder).
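The back-of-the-envelope count above can be checked in a few lines (bias and embedding terms ignored; the `d_model` and layer count are GPT-3's published figures):

```python
def block_params(d_model: int) -> int:
    # Attention: QKV (3*d^2) + output projection (d^2); FFN: d*4d + 4d*d = 8*d^2
    return 12 * d_model ** 2

# GPT-3: d_model = 12288, 96 layers
total = 96 * block_params(12288)
print(f"{total / 1e9:.1f}B")  # 173.9B -- embeddings make up most of the gap to 175B
```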
Interview Tips
Understand and be able to discuss: (1) Pre-norm vs. post-norm and why modern models prefer pre-norm. (2) The role of each component -- removing any one (residual, LayerNorm, FFN) significantly hurts performance. (3) Parameter counting -- be able to derive ~12d^2 per block. (4) Why GELU over ReLU (smoother, avoids dead neurons). (5) How to extend to encoder-decoder with cross-attention (add a second attention sub-layer where K, V come from the encoder).
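For point (5), a sketch of what the cross-attention extension looks like. This again uses `nn.MultiheadAttention` as a stand-in for a hand-written MHA, and `DecoderBlock`/`memory` are illustrative names; the key detail is that K and V come from the encoder output while Q comes from the decoder stream:

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Pre-norm decoder block: self-attention, cross-attention, FFN."""
    def __init__(self, d_model: int, num_heads: int) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ln3 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x, memory, causal_mask=None):
        h = self.ln1(x)
        x = x + self.self_attn(h, h, h, attn_mask=causal_mask, need_weights=False)[0]
        h = self.ln2(x)
        # Query from the decoder stream; key and value from the encoder output
        x = x + self.cross_attn(h, memory, memory, need_weights=False)[0]
        return x + self.ffn(self.ln3(x))

block = DecoderBlock(d_model=64, num_heads=4)
tgt = torch.randn(2, 5, 64)     # decoder input
memory = torch.randn(2, 9, 64)  # encoder output (a different length is fine)
out = block(tgt, memory)
print(out.shape)  # torch.Size([2, 5, 64])
```

Because only the queries come from the decoder, the output keeps the decoder's sequence length regardless of the encoder's.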
Quiz
In a pre-norm transformer block, where is LayerNorm applied?
Approximately how many parameters does a single transformer block have (ignoring bias terms)?
Why are residual connections essential for training deep transformers?