Akshay’s Gradient
ML Coding · advanced · 45 min

KV-Cache

Implement KV-Cache for Autoregressive Generation

Implement key-value caching for efficient autoregressive text generation. Without KV-cache, every new token requires recomputing attention over the entire sequence. With it, you only compute attention for the new token.

Problem Statement

Implement a CachedAttention module and a generate function that:

  1. On the first forward pass, compute K and V for all tokens and cache them
  2. On subsequent passes, only compute Q, K, V for the new token, append K and V to the cache, and compute attention between the new Q and all cached K, V
  3. The generate function produces tokens autoregressively using the cache

Inputs: Initial prompt token IDs, a model with cached attention layers, number of tokens to generate.

Outputs: Generated token sequence.

Key Concept

In autoregressive generation, token t attends to tokens 0..t. Without caching, generating token t recomputes the K and V vectors for tokens 0..t-1, which is redundant work because those vectors do not change. KV-cache stores them, so each step projects only the new token: with d x d projection matrices, per-step projection cost drops from O(t * d^2) to O(d^2). The attention computation itself is still O(t * d) per step, because the new query attends over all t cached positions.
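The claim that earlier K vectors never need recomputing can be checked directly. The sketch below is a minimal illustration, assuming a single stand-in key projection `W_k` and random hidden states `x` (both hypothetical, not part of the solution): keys computed for a prefix and then extended by one new key match the keys recomputed for the whole sequence. In deeper layers the same property holds only because causal masking keeps each position's hidden state independent of later tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
W_k = nn.Linear(d_model, d_model, bias=False)  # stand-in key projection (assumption)

x = torch.randn(1, 4, d_model)  # hidden states for tokens 0..3

with torch.no_grad():
    k_prefix = W_k(x[:, :3])    # keys for tokens 0..2: what a cache would hold
    k_new = W_k(x[:, 3:])       # key for the new token only
    k_cached = torch.cat([k_prefix, k_new], dim=1)  # cache after appending
    k_full = W_k(x)             # keys recomputed from scratch for tokens 0..3

# The projection is applied per position, so both routes give the same keys.
keys_match = torch.allclose(k_full, k_cached, atol=1e-6)
print(keys_match)
```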

KV-Cache: Prefill vs. Decode Phases
┌───────────────────────────────────────────────────────────────┐
│              KV-Cache: Prefill vs. Decode                     │
│                                                               │
│  PREFILL PHASE (process prompt "The cat sat"):                │
│  ┌─────────────────────────────────────────┐                  │
│  │ tokens: [The] [cat] [sat]               │                  │
│  │    Q:    q₁    q₂    q₃    (all at once)│                  │
│  │    K:    k₁    k₂    k₃    ──▶ CACHE    │                  │
│  │    V:    v₁    v₂    v₃    ──▶ CACHE    │                  │
│  └─────────────────────────────────────────┘                  │
│                                                               │
│  DECODE STEP 1 (generate "on"):                               │
│  ┌─────────────────────────────────────────┐                  │
│  │ New token: [on]                         │                  │
│  │    Q:    q₄  (only 1 token!)            │                  │
│  │    K:    k₄  ──▶ append to cache        │                  │
│  │    V:    v₄  ──▶ append to cache        │                  │
│  │                                         │                  │
│  │ Cache now: K=[k₁,k₂,k₃,k₄]              │                  │
│  │            V=[v₁,v₂,v₃,v₄]              │                  │
│  │                                         │                  │
│  │ Attention: q₄ @ [k₁,k₂,k₃,k₄]^T         │                  │
│  └─────────────────────────────────────────┘                  │
│                                                               │
│  DECODE STEP 2 (generate "the"):                              │
│  ┌─────────────────────────────────────────┐                  │
│  │ New token: [the]                        │                  │
│  │    Q:    q₅  (only 1 token!)            │                  │
│  │    K:    k₅  ──▶ append to cache        │                  │
│  │    V:    v₅  ──▶ append to cache        │                  │
│  │                                         │                  │
│  │ Cache: K=[k₁,k₂,k₃,k₄,k₅]               │                  │
│  │        V=[v₁,v₂,v₃,v₄,v₅]               │                  │
│  └─────────────────────────────────────────┘                  │
│                                                               │
│  Without cache: step t recomputes ALL t projections  O(t*d²)  │
│  With cache:    step t computes only 1 projection    O(d²)    │
└───────────────────────────────────────────────────────────────┘
Interview Tip

When discussing KV-cache in interviews, be prepared to compute the memory footprint. For LLaMA-2 70B: 80 layers, 64 heads per layer (8 KV heads with GQA), d_k=128, and 4K context. KV-cache size = 2 (K and V) * 80 (layers) * 8 (KV heads) * 128 (d_k) * 4096 (seq) * 2 bytes (fp16) = ~1.3 GB per sequence. This is why batch serving is memory-bound, and techniques like paged attention (vLLM) are critical.
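The arithmetic in the tip can be reproduced with a few lines. This is a minimal sketch; `kv_cache_bytes` and its parameter names are hypothetical helpers, and the 64-KV-head line is a hypothetical comparison (the same shape without GQA), not a real model configuration.

```python
def kv_cache_bytes(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    seq_len: int,
    bytes_per_elem: int = 2,  # fp16
    batch_size: int = 1,
) -> int:
    """Bytes needed to hold K and V for every layer, KV head, and position."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem * batch_size


# Numbers from the tip: 80 layers, 8 KV heads (GQA), d_k=128, 4K context, fp16.
gqa_bytes = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=4096)
# Hypothetical comparison: the same shape with one KV head per query head (64).
mha_bytes = kv_cache_bytes(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=4096)

print(f"GQA (8 KV heads):  {gqa_bytes / 1e9:.2f} GB")   # 1.34 GB
print(f"MHA (64 KV heads): {mha_bytes / 1e9:.2f} GB")   # 10.74 GB
```

Multiplying by `batch_size` shows why serving many sequences at once quickly becomes memory-bound.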

Hints

  1. Add a cache attribute to store past K and V tensors: a tuple (cached_K, cached_V).
  2. During prefill (first pass): compute Q, K, V for all positions, store K and V in cache.
  3. During generation (subsequent passes): compute Q, K, V for only the new token, concatenate K and V with the cache, then compute attention.
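  4. No causal mask is needed when decoding a single token: the cache holds only past positions plus the new token itself, so there are no future keys to hide. Prefill with more than one token still needs the mask.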
  5. Reset the cache between sequences.
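Hint 4 can be verified numerically. The sketch below uses random single-head `Q`, `K`, `V` tensors (hypothetical stand-ins for one attention head) and compares the last row of fully masked causal attention against an unmasked single-query step over the same keys and values.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_k = 5, 16
Q = torch.randn(T, d_k)
K = torch.randn(T, d_k)
V = torch.randn(T, d_k)

# Full causal attention over all T positions, as in prefill or training.
scores = (Q @ K.T) / math.sqrt(d_k)
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True = future key
full_out = F.softmax(scores.masked_fill(future, float("-inf")), dim=-1) @ V

# Decode-style step: only the last query, all T keys/values, no mask at all.
q_last = Q[-1:]
step_out = F.softmax((q_last @ K.T) / math.sqrt(d_k), dim=-1) @ V

# The last row of the causal mask hides nothing, so both results agree.
rows_match = torch.allclose(full_out[-1:], step_out, atol=1e-6)
print(rows_match)
```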

Solution

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List


class CachedMultiHeadAttention(nn.Module):
    """Multi-head attention with KV-cache support."""

    def __init__(self, d_model: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.cache_k: Optional[torch.Tensor] = None
        self.cache_v: Optional[torch.Tensor] = None

    def reset_cache(self) -> None:
        self.cache_k = None
        self.cache_v = None

    def forward(self, x: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
        B, T, C = x.shape
        H, d_k = self.num_heads, self.d_k

        Q = self.W_q(x).view(B, T, H, d_k).transpose(1, 2)  # (B, H, T, d_k)
        K = self.W_k(x).view(B, T, H, d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, H, d_k).transpose(1, 2)

        if use_cache:
            if self.cache_k is not None:
                # Append new K, V to existing cache
                K = torch.cat([self.cache_k, K], dim=2)  # (B, H, T_cached+T, d_k)
                V = torch.cat([self.cache_v, V], dim=2)
            # Update cache
            self.cache_k = K.detach()
            self.cache_v = V.detach()

        # Attention: Q attends to all K (including cached)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)

        # Causal mask only needed during prefill (T > 1)
        if T > 1:
            seq_len_k = K.size(2)
            # Build causal mask: Q positions can attend to K positions <= their own
            q_positions = torch.arange(seq_len_k - T, seq_len_k, device=x.device)
            k_positions = torch.arange(seq_len_k, device=x.device)
            mask = q_positions.unsqueeze(1) < k_positions.unsqueeze(0)  # (T, seq_len_k)
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        attn = F.softmax(scores, dim=-1)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)


class SimpleTransformer(nn.Module):
    """Minimal transformer for demonstrating KV-cache."""

    def __init__(self, vocab_size: int, d_model: int, num_heads: int, num_layers: int) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            CachedMultiHeadAttention(d_model, num_heads)
            for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = x + layer(x, use_cache=use_cache)  # simplified: no FFN/LN for brevity
        x = self.ln_f(x)
        return self.lm_head(x)

    def reset_cache(self) -> None:
        for layer in self.layers:
            layer.reset_cache()


@torch.no_grad()
def generate(
    model: SimpleTransformer,
    prompt_ids: torch.Tensor,
    max_new_tokens: int,
) -> List[int]:
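    """Greedy autoregressive generation with KV-cache (assumes batch size 1)."""
    model.eval()
    model.reset_cache()
    generated: List[int] = prompt_ids[0].tolist()
    if max_new_tokens <= 0:
        return generated  # nothing to generate, so skip the prefill pass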

    # Prefill: process the entire prompt at once
    logits = model(prompt_ids, use_cache=True)  # (1, T, vocab)
    next_token = logits[0, -1].argmax().item()
    generated.append(next_token)

    # Decode: generate one token at a time
    for _ in range(max_new_tokens - 1):
        input_ids = torch.tensor([[next_token]], device=prompt_ids.device)
        logits = model(input_ids, use_cache=True)  # (1, 1, vocab)
        next_token = logits[0, -1].argmax().item()
        generated.append(next_token)

    return generated


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    vocab_size, d_model, num_heads, num_layers = 100, 64, 4, 2

    model = SimpleTransformer(vocab_size, d_model, num_heads, num_layers)
    prompt = torch.randint(0, vocab_size, (1, 5))

    # Generate with cache
    output_cached = generate(model, prompt, max_new_tokens=10)
    print(f"Prompt: {prompt[0].tolist()}")
    print(f"Generated (with cache): {output_cached}")

    # Verify: generate without cache (full recompute each step)
    model.reset_cache()
    model.eval()
    output_no_cache = prompt[0].tolist()
    for _ in range(10):
        input_ids = torch.tensor([output_no_cache])
        logits = model(input_ids, use_cache=False)
        next_token = logits[0, -1].argmax().item()
        output_no_cache.append(next_token)

    print(f"Generated (no cache):   {output_no_cache}")
    assert output_cached == output_no_cache, "Cache and no-cache outputs differ!"
    print("Cache verification passed: outputs match.")

Walkthrough

  1. Cache structure -- Each attention layer stores cache_k and cache_v tensors of shape (B, H, T_cached, d_k). These grow by one position each generation step.

  2. Prefill phase -- The entire prompt is processed in one forward pass. All K, V vectors are computed and stored in the cache. This is a parallel operation, similar to training.

  3. Decode phase -- Only the new token is projected to Q, K, V. The new K and V are concatenated with the cache. The attention computation uses the full cached sequence as keys/values but only the single new token as the query.

  4. Efficiency gain -- Without cache: generating n tokens requires n forward passes, each processing O(t) tokens (where t grows). Total: O(n^2) projections. With cache: each step projects only 1 token. Total: O(n) projections. The attention computation itself is still O(n) per step (attending over all cached keys).

  5. Memory tradeoff -- KV-cache trades memory for compute. For a model with L layers, H KV heads, d_k head dim, and sequence length T, the cache uses 2 * L * H * d_k * T * sizeof(dtype) bytes per sequence. With 80 layers, d_k=128, and 4K context in fp16, that is about 10.7 GB with 64 KV heads (full multi-head attention) or about 1.3 GB with 8 KV heads (GQA, as in LLaMA-2 70B).
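The projection counts in step 4 can be made concrete with a small counting sketch. The two helper functions are hypothetical and only count how many token positions pass through the Q/K/V projections of one layer; they mirror the demo above, where the no-cache loop re-feeds the whole sequence and the cached path runs one prefill followed by single-token steps.

```python
def projections_without_cache(prompt_len: int, new_tokens: int) -> int:
    """Each step re-feeds the prompt plus everything generated so far."""
    return sum(prompt_len + i for i in range(new_tokens))


def projections_with_cache(prompt_len: int, new_tokens: int) -> int:
    """Prefill projects the prompt once; every later step projects 1 token."""
    if new_tokens <= 0:
        return 0
    return prompt_len + (new_tokens - 1)


# Same setting as the demo: a 5-token prompt and 10 generated tokens.
no_cache = projections_without_cache(prompt_len=5, new_tokens=10)
cached = projections_with_cache(prompt_len=5, new_tokens=10)
print(no_cache, cached)  # 95 14
```

The gap widens quadratically: doubling the number of generated tokens roughly quadruples the uncached count while the cached count grows linearly.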

Complexity Analysis

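  • Without KV-cache: step t re-projects all t tokens, O(t * d^2) with d x d projection matrices, and recomputes full causal attention, O(t^2 * d). Over n generated tokens this totals O(n^2 * d^2) for projections plus O(n^3 * d) for attention.
  • With KV-cache: step t projects 1 token, O(d^2), and attends over t cached positions, O(t * d). Over n generated tokens this totals O(n * d^2) for projections plus O(n^2 * d) for attention.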
  • Memory: O(L * n * d) for the cache across all layers. This is the bottleneck for long sequences.
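The totals above follow from an arithmetic series over the generation steps. As a sketch, let P denote the cost of projecting one token (an assumed symbol, not used elsewhere on this page), let step t operate on a sequence of length t, and ignore the prompt length for simplicity:

```latex
\begin{align*}
\text{projections, no cache:} \quad & \sum_{t=1}^{n} t\,P = \frac{n(n+1)}{2}\,P = O(n^2 P) \\
\text{projections, with cache:} \quad & \sum_{t=1}^{n} P = nP = O(nP) \\
\text{attention, with cache:} \quad & \sum_{t=1}^{n} O(t\,d) = O\!\left(\frac{n(n+1)}{2}\,d\right) = O(n^2 d)
\end{align*}
```

The cache removes the quadratic term from the projections only; the attention total stays quadratic in n because each new query still reads every cached key and value.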

Interview Tips

Interview Tip

Key discussion points:

  1. Explain the asymmetry: Q is only for the new token, K/V are for the full history.
  2. Memory analysis: compute the cache size for a specific model (e.g., LLaMA-2 70B at 4K context).
  3. Optimizations: quantized KV-cache (INT8/INT4), paged attention (vLLM), sliding window attention (Mistral).
  4. Batch serving complications: different sequences in a batch have different cache lengths, requiring padding or paged memory management.
  5. GQA (Grouped Query Attention): reduces KV-cache size by sharing K/V across query heads.
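The GQA point can be illustrated with a single decode step. This is a minimal sketch with made-up sizes (8 query heads, 2 KV heads) and random tensors standing in for real projections; it is not how any particular library implements GQA. The cache stores only the KV heads, and each KV head is expanded to serve its group of query heads at attention time.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d_k = 1, 6, 16
num_q_heads, num_kv_heads = 8, 2          # illustrative sizes (assumption)
group = num_q_heads // num_kv_heads       # query heads sharing one KV head

q = torch.randn(B, num_q_heads, 1, d_k)         # query for the single new token
cache_k = torch.randn(B, num_kv_heads, T, d_k)  # cache holds KV heads only
cache_v = torch.randn(B, num_kv_heads, T, d_k)

# Expand each KV head across its group of query heads; the cache itself stays small.
k = cache_k.repeat_interleave(group, dim=1)     # (B, num_q_heads, T, d_k)
v = cache_v.repeat_interleave(group, dim=1)

attn = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(d_k), dim=-1)
out = attn @ v                                  # (B, num_q_heads, 1, d_k)

# Cache size compared with storing one K/V head per query head.
mha_elems = 2 * B * num_q_heads * T * d_k
gqa_elems = cache_k.numel() + cache_v.numel()
reduction = mha_elems // gqa_elems
print(out.shape, reduction)
```

The cache shrinks by the group size (4x here), which matches the 64-to-8 head reduction in the LLaMA-2 70B example.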

Quiz

What is the primary benefit of KV-cache during autoregressive generation?

Why is the causal mask not needed during the single-token decode phase with KV-cache?

How does Grouped Query Attention (GQA) help reduce KV-cache memory?
