Implement KV-Cache for Autoregressive Generation
Implement key-value caching for efficient autoregressive text generation. Without KV-cache, every new token requires recomputing attention over the entire sequence. With it, you only compute attention for the new token.
Problem Statement
Implement a CachedAttention module and a generate function that:
- On the first forward pass, compute K and V for all tokens and cache them
- On subsequent passes, only compute Q, K, V for the new token, append K and V to the cache, and compute attention between the new Q and all cached K, V
- The `generate` function produces tokens autoregressively using the cache
Inputs: Initial prompt token IDs, a model with cached attention layers, number of tokens to generate.
Outputs: Generated token sequence.
In autoregressive generation, token t attends to tokens 0..t. Without caching, generating token t recomputes all K, V vectors for tokens 0..t-1 (redundant work). KV-cache stores these vectors, reducing per-step complexity from O(t * d) to O(d) for the projection, though the attention computation itself is still O(t * d) per step.
┌───────────────────────────────────────────────────────────────┐
│ KV-Cache: Prefill vs. Decode │
│ │
│ PREFILL PHASE (process prompt "The cat sat"): │
│ ┌─────────────────────────────────────────┐ │
│ │ tokens: [The] [cat] [sat] │ │
│ │ Q: q₁ q₂ q₃ (all at once)│ │
│ │ K: k₁ k₂ k₃ ──▶ CACHE │ │
│ │ V: v₁ v₂ v₃ ──▶ CACHE │ │
│ └─────────────────────────────────────────┘ │
│ │
│ DECODE STEP 1 (generate "on"): │
│ ┌─────────────────────────────────────────┐ │
│ │ New token: [on] │ │
│ │ Q: q₄ (only 1 token!) │ │
│ │ K: k₄ ──▶ append to cache │ │
│ │ V: v₄ ──▶ append to cache │ │
│ │ │ │
│ │ Cache now: K=[k₁,k₂,k₃,k₄] │ │
│ │ V=[v₁,v₂,v₃,v₄] │ │
│ │ │ │
│ │ Attention: q₄ @ [k₁,k₂,k₃,k₄]^T │ │
│ └─────────────────────────────────────────┘ │
│ │
│ DECODE STEP 2 (generate "the"): │
│ ┌─────────────────────────────────────────┐ │
│ │ New token: [the] │ │
│ │ Q: q₅ (only 1 token!) │ │
│ │ K: k₅ ──▶ append to cache │ │
│ │ V: v₅ ──▶ append to cache │ │
│ │ │ │
│ │ Cache: K=[k₁,k₂,k₃,k₄,k₅] │ │
│ │ V=[v₁,v₂,v₃,v₄,v₅] │ │
│ └─────────────────────────────────────────┘ │
│ │
│ Without cache: step t recomputes ALL t projections O(t*d) │
│ With cache: step t computes only 1 projection O(d) │
└───────────────────────────────────────────────────────────────┘
When discussing KV-cache in interviews, be prepared to compute the memory footprint. For LLaMA-2 70B: 80 layers, 64 heads per layer (8 KV heads with GQA), d_k=128, and 4K context. KV-cache size = 2 (K and V) * 80 (layers) * 8 (KV heads) * 128 (d_k) * 4096 (seq) * 2 bytes (fp16) = ~1.3 GB per sequence. This is why batch serving is memory-bound, and techniques like paged attention (vLLM) are critical.
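The memory arithmetic above can be sanity-checked with a few lines (a sketch; the layer/head/dimension counts are the LLaMA-2 70B figures quoted above):

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, d_k: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """KV-cache size for one sequence: K and V tensors per layer."""
    return 2 * num_layers * num_kv_heads * d_k * seq_len * bytes_per_elem

# LLaMA-2 70B with GQA: 80 layers, 8 KV heads, d_k=128, 4K context, fp16
size = kv_cache_bytes(num_layers=80, num_kv_heads=8, d_k=128, seq_len=4096)
print(f"{size / 2**30:.2f} GiB per sequence")  # -> 1.25 GiB

# Without GQA (64 KV heads, one per query head) it would be 8x larger: ~10 GiB
```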
Hints
- Add a cache attribute to store past K and V tensors: a tuple `(cached_K, cached_V)`.
- During prefill (first pass): compute Q, K, V for all positions, store K and V in the cache.
- During generation (subsequent passes): compute Q, K, V for only the new token, concatenate K and V with the cache, then compute attention.
- The causal mask is not needed during single-token decode steps: the new query sits at the last position, so the cached keys are exactly the tokens it is allowed to attend to.
- Reset the cache between sequences.
Solution
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Optional
class CachedMultiHeadAttention(nn.Module):
"""Multi-head attention with KV-cache support."""
def __init__(self, d_model: int, num_heads: int) -> None:
super().__init__()
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.cache_k: Optional[torch.Tensor] = None
self.cache_v: Optional[torch.Tensor] = None
def reset_cache(self) -> None:
self.cache_k = None
self.cache_v = None
def forward(self, x: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
B, T, C = x.shape
H, d_k = self.num_heads, self.d_k
Q = self.W_q(x).view(B, T, H, d_k).transpose(1, 2) # (B, H, T, d_k)
K = self.W_k(x).view(B, T, H, d_k).transpose(1, 2)
V = self.W_v(x).view(B, T, H, d_k).transpose(1, 2)
if use_cache:
if self.cache_k is not None:
# Append new K, V to existing cache
K = torch.cat([self.cache_k, K], dim=2) # (B, H, T_cached+T, d_k)
V = torch.cat([self.cache_v, V], dim=2)
# Update cache
self.cache_k = K.detach()
self.cache_v = V.detach()
# Attention: Q attends to all K (including cached)
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
# Causal mask only needed during prefill (T > 1)
if T > 1:
seq_len_k = K.size(2)
# Build causal mask: Q positions can attend to K positions <= their own
q_positions = torch.arange(seq_len_k - T, seq_len_k, device=x.device)
k_positions = torch.arange(seq_len_k, device=x.device)
mask = q_positions.unsqueeze(1) < k_positions.unsqueeze(0) # (T, seq_len_k)
scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
attn = F.softmax(scores, dim=-1)
out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
return self.W_o(out)
class SimpleTransformer(nn.Module):
"""Minimal transformer for demonstrating KV-cache."""
def __init__(self, vocab_size: int, d_model: int, num_heads: int, num_layers: int) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
CachedMultiHeadAttention(d_model, num_heads)
for _ in range(num_layers)
])
self.ln_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, input_ids: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
x = self.embedding(input_ids)
for layer in self.layers:
x = x + layer(x, use_cache=use_cache) # simplified: no FFN/LN for brevity
x = self.ln_f(x)
return self.lm_head(x)
def reset_cache(self) -> None:
for layer in self.layers:
layer.reset_cache()
@torch.no_grad()
def generate(
model: SimpleTransformer,
prompt_ids: torch.Tensor,
max_new_tokens: int,
) -> List[int]:
"""Autoregressive generation with KV-cache."""
model.eval()
model.reset_cache()
generated: List[int] = prompt_ids[0].tolist()
# Prefill: process the entire prompt at once
logits = model(prompt_ids, use_cache=True) # (1, T, vocab)
next_token = logits[0, -1].argmax().item()
generated.append(next_token)
# Decode: generate one token at a time
for _ in range(max_new_tokens - 1):
input_ids = torch.tensor([[next_token]], device=prompt_ids.device)
logits = model(input_ids, use_cache=True) # (1, 1, vocab)
next_token = logits[0, -1].argmax().item()
generated.append(next_token)
return generated
# ---------- demo ----------
if __name__ == "__main__":
torch.manual_seed(42)
vocab_size, d_model, num_heads, num_layers = 100, 64, 4, 2
model = SimpleTransformer(vocab_size, d_model, num_heads, num_layers)
prompt = torch.randint(0, vocab_size, (1, 5))
# Generate with cache
output_cached = generate(model, prompt, max_new_tokens=10)
print(f"Prompt: {prompt[0].tolist()}")
print(f"Generated (with cache): {output_cached}")
# Verify: generate without cache (full recompute each step)
model.reset_cache()
model.eval()
output_no_cache = prompt[0].tolist()
for _ in range(10):
input_ids = torch.tensor([output_no_cache])
logits = model(input_ids, use_cache=False)
next_token = logits[0, -1].argmax().item()
output_no_cache.append(next_token)
print(f"Generated (no cache): {output_no_cache}")
assert output_cached == output_no_cache, "Cache and no-cache outputs differ!"
print("Cache verification passed: outputs match.")
Walkthrough
- Cache structure -- Each attention layer stores `cache_k` and `cache_v` tensors of shape `(B, H, T_cached, d_k)`. These grow by one position each generation step.
- Prefill phase -- The entire prompt is processed in one forward pass. All K, V vectors are computed and stored in the cache. This is a parallel operation, similar to training.
- Decode phase -- Only the new token is projected to Q, K, V. The new K and V are concatenated with the cache. The attention computation uses the full cached sequence as keys/values but only the single new token as the query.
- Efficiency gain -- Without cache: generating `n` tokens requires `n` forward passes, each processing `O(t)` tokens (where `t` grows). Total: `O(n^2)` projections. With cache: each step projects only one token. Total: `O(n)` projections. The attention computation itself is still `O(t)` per step (attending over all cached keys).
- Memory tradeoff -- KV-cache trades memory for compute. For a model with `L` layers, `H` heads, head dimension `d_k`, and sequence length `T`, the cache uses `2 * L * H * d_k * T * sizeof(dtype)` bytes. For a 70B model with 80 layers at 4K context, this is several GB per sequence.
Complexity Analysis
- Without KV-cache: generating `n` tokens takes `O(n^2 * d)` total compute for projections alone, plus `O(n^2 * d)` total for attention.
- With KV-cache: `O(n * d)` total for projections (each step projects one token), `O(n^2 * d)` total for attention (step `t` does `O(t * d)` attention).
- Memory: `O(L * n * d)` for the cache across all layers. This is the bottleneck for long sequences.
Interview Tips
Key discussion points: (1) Explain the asymmetry: Q is only for the new token, K/V are for the full history. (2) Memory analysis: compute the cache size for a specific model (e.g., LLaMA-70B at 4K context). (3) Optimizations: quantized KV-cache (INT8/INT4), paged attention (vLLM), sliding window attention (Mistral). (4) Batch serving complications: different sequences in a batch have different cache lengths, requiring padding or paged memory management. (5) GQA (Grouped Query Attention): reduces KV-cache size by sharing K/V across query heads.
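A minimal sketch of point (5): with GQA only the small set of KV heads is cached, and K/V are expanded to match the query heads at attention time. Toy shapes, with `repeat_interleave` used for the expansion; real implementations often avoid materializing it.

```python
import torch

B, T, d_k = 1, 6, 16
num_q_heads, num_kv_heads = 8, 2   # 4 query heads share each KV head

q = torch.randn(B, num_q_heads, T, d_k)
k = torch.randn(B, num_kv_heads, T, d_k)  # only 2 KV heads need caching
v = torch.randn(B, num_kv_heads, T, d_k)

# Expand K/V so each group of query heads reuses its shared KV head
group = num_q_heads // num_kv_heads
k_exp = k.repeat_interleave(group, dim=1)  # (B, 8, T, d_k)
v_exp = v.repeat_interleave(group, dim=1)

scores = q @ k_exp.transpose(-2, -1) / d_k**0.5
out = scores.softmax(dim=-1) @ v_exp
print(tuple(out.shape))  # (1, 8, 6, 16): full query heads, 4x smaller KV-cache
```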
Quiz — 3 Questions
What is the primary benefit of KV-cache during autoregressive generation?
Why is the causal mask not needed during the single-token decode phase with KV-cache?
How does Grouped Query Attention (GQA) help reduce KV-cache memory?