Akshay’s Gradient
ML Coding · Advanced · 60 min

Transformer Block

Implement a Transformer Block

Build a complete transformer block: multi-head self-attention, feed-forward network, layer normalization, and residual connections. This is the repeating unit that makes up GPT, BERT, and every other transformer model.

Problem Statement

Implement a TransformerBlock module with:

  1. Pre-norm multi-head self-attention with residual connection
  2. Pre-norm position-wise feed-forward network (FFN) with residual connection
  3. The FFN expands the dimension by a factor of 4, applies GELU activation, then projects back

Use pre-norm architecture (LayerNorm before each sub-layer, as in GPT-2/3) rather than post-norm (original transformer).

Inputs: Tensor x of shape (batch, seq_len, d_model), optional causal mask.

Outputs: Tensor of the same shape (batch, seq_len, d_model).

Key Concept

The pre-norm transformer block computes: x = x + MHA(LayerNorm(x)) then x = x + FFN(LayerNorm(x)). The residual connections ensure gradients flow unimpeded through the network. Pre-norm is more stable than post-norm for training deep models because the input to each sub-layer is always normalized.
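
That computation can be sketched directly with off-the-shelf PyTorch modules standing in for the components built from scratch in the Solution below (module names here are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8

# Stand-ins for the block's components
ln1, ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
mha = nn.MultiheadAttention(d_model, num_heads=2, batch_first=True)
ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                    nn.Linear(4 * d_model, d_model))

x = torch.randn(2, 5, d_model)   # (batch, seq_len, d_model)
h = ln1(x)
x = x + mha(h, h, h)[0]          # residual around attention
x = x + ffn(ln2(x))              # residual around the FFN
print(x.shape)                   # torch.Size([2, 5, 8])
```

Note that only the normalized tensor enters each sub-layer; the raw `x` is carried forward on the residual path.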

Pre-Norm Transformer Block: Residuals + LayerNorm
┌──────────────────────────────────────────────────────────────┐
│              Pre-Norm Transformer Block                      │
│                                                              │
│    Input x ─────────────────────────────────┐ (residual)     │
│       │                                     │                │
│       ▼                                     │                │
│  ┌──────────┐                               │                │
│  │LayerNorm │                               │                │
│  └────┬─────┘                               │                │
│       │                                     │                │
│       ▼                                     │                │
│  ┌──────────────────┐                       │                │
│  │  Multi-Head      │                       │                │
│  │  Self-Attention  │                       │                │
│  └────────┬─────────┘                       │                │
│           │                                 │                │
│           └─────────────  + ◀───────────────┘                │
│                           │                                  │
│    x' ──────────────────────────────────────┐ (residual)     │
│       │                                     │                │
│       ▼                                     │                │
│  ┌──────────┐                               │                │
│  │LayerNorm │                               │                │
│  └────┬─────┘                               │                │
│       │                                     │                │
│       ▼                                     │                │
│  ┌──────────────────┐                       │                │
│  │  Feed-Forward    │                       │                │
│  │ d──▶4d──GELU──▶d │                       │                │
│  └────────┬─────────┘                       │                │
│           │                                 │                │
│           └─────────────  + ◀───────────────┘                │
│                           │                                  │
│                           ▼                                  │
│                   Output x'' (same shape as input)           │
└──────────────────────────────────────────────────────────────┘
Warning

Do not confuse pre-norm and post-norm architectures. In pre-norm (GPT-2/3, LLaMA), LayerNorm is applied before each sub-layer: x = x + Sublayer(LN(x)). In post-norm (original 2017 transformer, BERT), it is applied after: x = LN(x + Sublayer(x)). Pre-norm is more stable for deep models because the residual path carries raw activations, but it may need a final LayerNorm after the last block.
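
The distinction fits in two one-line helpers (a sketch; `sublayer` stands for any attention or FFN module):

```python
import torch
import torch.nn as nn

def pre_norm_step(x, ln, sublayer):
    # GPT-2/3, LLaMA: normalize first, add the result to the raw residual
    return x + sublayer(ln(x))

def post_norm_step(x, ln, sublayer):
    # Original 2017 transformer, BERT: add first, then normalize the sum
    return ln(x + sublayer(x))

d = 16
ln = nn.LayerNorm(d)
sublayer = nn.Linear(d, d)   # stand-in for an attention or FFN sub-layer
x = torch.randn(3, 7, d)
print(pre_norm_step(x, ln, sublayer).shape)   # torch.Size([3, 7, 16])
print(post_norm_step(x, ln, sublayer).shape)  # torch.Size([3, 7, 16])
```

In the pre-norm version the output of the step is never normalized, which is why pre-norm models typically apply one final LayerNorm after the last block.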

Hints

  1. Use nn.LayerNorm(d_model) for normalization.
  2. The FFN is two linear layers: Linear(d_model, 4*d_model) -> GELU -> Linear(4*d_model, d_model).
  3. Apply LayerNorm before the sub-layer (pre-norm), not after.
  4. Add the sub-layer output to the original input (residual): x = x + sublayer(norm(x)).
  5. For self-attention, query = key = value = normalized input.

Solution

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention (compact version for the block)."""

    def __init__(self, d_model: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        # Project Q, K, V in one shot
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, T, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = attn @ V  # (B, H, T, d_k)

        # Merge heads
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class FeedForward(nn.Module):
    """Position-wise feed-forward network with GELU activation."""

    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.gelu(self.fc1(x)))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LN -> MHA -> residual -> LN -> FFN -> residual."""

    def __init__(self, d_model: int, num_heads: int, d_ff: Optional[int] = None) -> None:
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Sub-layer 1: Multi-head self-attention with residual
        x = x + self.attn(self.ln1(x), mask)
        # Sub-layer 2: Feed-forward network with residual
        x = x + self.ffn(self.ln2(x))
        return x


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    B, T, d_model, num_heads = 2, 16, 128, 4

    block = TransformerBlock(d_model=d_model, num_heads=num_heads)
    x = torch.randn(B, T, d_model)

    # Causal mask for autoregressive modeling
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    out = block(x, mask=mask)

    print(f"Input shape:  {x.shape}")   # (2, 16, 128)
    print(f"Output shape: {out.shape}")  # (2, 16, 128)

    # The residual path keeps the output correlated with the input
    diff = (out - x).abs().mean().item()
    print(f"Mean |output - input|: {diff:.4f}")

    # Count parameters
    total = sum(p.numel() for p in block.parameters())
    print(f"Total parameters: {total:,}")
    # Weights: 3*d^2 (QKV) + d^2 (out proj) + 8*d^2 (FFN) = 12*d^2.
    # Biases: 4*d (two LayerNorms, weight + bias each) + 5*d (FFN: 4*d + d).
    expected = 12 * d_model**2 + 9 * d_model
    print(f"Expected: {expected:,}")

Walkthrough

  1. Fused QKV projection -- Instead of three separate linear layers, we use one Linear(d_model, 3*d_model) and split the output. This is more efficient in practice: a single large matrix multiply keeps the hardware better utilized and launches fewer kernels than three smaller multiplies.

  2. Pre-norm architecture -- LayerNorm is applied before each sub-layer. The original transformer applied it after (post-norm), but pre-norm is more stable for deep networks because the residual path carries unnormalized activations, ensuring gradients can flow directly through the skip connection.

  3. Residual connections -- x = x + sublayer(norm(x)) adds the sub-layer output to the original input. Near initialization the sub-layer contribution is a modest perturbation (GPT-2 additionally scales residual-branch weights by 1/sqrt(N) to keep it small), so the block starts close to the identity function. This is what makes deep stacks trainable.

  4. FFN structure -- The two-layer FFN with a 4x expansion acts as a per-token MLP. The GELU nonlinearity is used in GPT-2/3 and most modern transformers (originally ReLU in the 2017 paper).

  5. Stacking -- A full transformer is just N of these blocks in sequence, preceded by an embedding layer and followed by a final LayerNorm and output head.
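
That stacking can be sketched as follows. The class names, hyperparameters, and the compact `Block` built from `nn.MultiheadAttention` are all illustrative stand-ins for the from-scratch `TransformerBlock` above:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Compact pre-norm block (equivalent in structure to TransformerBlock)."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x, mask=None):
        h = self.ln1(x)
        x = x + self.attn(h, h, h, attn_mask=mask)[0]
        return x + self.ffn(self.ln2(x))

class TinyGPT(nn.Module):
    """Embedding -> N blocks -> final LayerNorm -> output head."""
    def __init__(self, vocab_size, d_model, num_heads, num_layers, max_len):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList(Block(d_model, num_heads) for _ in range(num_layers))
        self.ln_f = nn.LayerNorm(d_model)   # final LayerNorm after the last pre-norm block
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        T = idx.shape[1]
        x = self.tok(idx) + self.pos(torch.arange(T, device=idx.device))
        # Causal mask: True marks positions each token may NOT attend to
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=idx.device), 1)
        for block in self.blocks:
            x = block(x, mask)
        return self.head(self.ln_f(x))

model = TinyGPT(vocab_size=100, d_model=32, num_heads=4, num_layers=2, max_len=64)
logits = model(torch.randint(0, 100, (2, 10)))
print(logits.shape)  # torch.Size([2, 10, 100])
```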

Complexity Analysis

  • Attention: O(B * n^2 * d) time, O(B * H * n^2) space for attention weights.
  • FFN: O(B * n * d * d_ff) = O(B * n * 4d^2) time per block.
  • Parameters per block: ~12 * d_model^2 (4d^2 from attention projections, 8d^2 from FFN).
  • For GPT-3 (d=12288, 96 layers): 12 * 12288^2 * 96 ≈ 174B parameters, close to the published 175B; token embeddings account for most of the remainder.
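
That arithmetic generalizes to other configurations. A quick sanity check (the helper name is illustrative; embeddings, biases, and LayerNorms are ignored, which is why GPT-2 small totals 124M published parameters versus the ~85M of block parameters shown here):

```python
# Per-block estimate: 4*d^2 (attention projections) + 8*d^2 (FFN) = 12*d^2
def approx_params(d_model: int, num_layers: int) -> int:
    return 12 * d_model**2 * num_layers  # ignores embeddings, biases, LayerNorms

print(f"GPT-2 small (d=768, 12 layers): {approx_params(768, 12) / 1e6:.0f}M")    # 85M
print(f"GPT-3 (d=12288, 96 layers):     {approx_params(12288, 96) / 1e9:.0f}B")  # 174B
```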

Interview Tips

Understand and be able to discuss: (1) Pre-norm vs. post-norm and why modern models prefer pre-norm. (2) The role of each component -- removing any one (residual, LayerNorm, FFN) significantly hurts performance. (3) Parameter counting -- be able to derive ~12d^2 per block. (4) Why GELU over ReLU (smoother, avoids dead neurons). (5) How to extend to encoder-decoder with cross-attention (add a second attention sub-layer where K, V come from the encoder).
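
Point (5) can be sketched as follows, assuming pre-norm style; the class and attribute names are illustrative:

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Decoder block with an extra cross-attention sub-layer between MHA and FFN."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ln3 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x, enc_out, causal_mask=None):
        h = self.ln1(x)
        x = x + self.self_attn(h, h, h, attn_mask=causal_mask)[0]  # Q = K = V = decoder
        h = self.ln2(x)
        x = x + self.cross_attn(h, enc_out, enc_out)[0]            # Q = decoder, K/V = encoder
        return x + self.ffn(self.ln3(x))

blk = DecoderBlock(d_model=32, num_heads=4)
out = blk(torch.randn(2, 5, 32), torch.randn(2, 9, 32))  # decoder len 5, encoder len 9
print(out.shape)  # torch.Size([2, 5, 32])
```

The decoder's sequence length is preserved; only K and V come from the encoder, so the encoder output can have a different length.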

Quiz

In a pre-norm transformer block, where is LayerNorm applied?

Approximately how many parameters does a single transformer block have (ignoring bias terms)?

Why are residual connections essential for training deep transformers?
