Algorithm: Top-k and Nucleus Sampling
Implement the decoding strategies used in modern language models: temperature scaling, top-k sampling, and nucleus (top-p) sampling. These methods control the tradeoff between creativity and coherence in text generation.
Problem Statement
Implement four functions:
- temperature_scale(logits, temperature) -- scale logits by temperature
- top_k_sampling(logits, k) -- sample from the top-k most probable tokens
- nucleus_sampling(logits, p) -- sample from the smallest set of tokens whose cumulative probability exceeds p
- combined_sampling(logits, temperature, top_k, top_p) -- apply all three in sequence
Inputs: Logits of shape (vocab_size,), temperature float, top-k int, top-p float.
Outputs: A sampled token index.
Temperature < 1 sharpens the distribution (more deterministic), temperature > 1 flattens it (more random). Top-k removes all but the k most probable tokens, preventing sampling of very unlikely tokens. Nucleus (top-p) sampling adaptively selects the number of tokens: it includes the smallest set whose cumulative probability exceeds p, so the number of candidates varies by context (fewer when the model is confident, more when uncertain).
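A quick sketch of the temperature effect on a toy distribution (the three logit values are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])

base = F.softmax(logits, dim=-1)          # T = 1: unchanged distribution
sharp = F.softmax(logits / 0.5, dim=-1)   # T < 1: top token gains probability
flat = F.softmax(logits / 2.0, dim=-1)    # T > 1: closer to uniform

# Sharpening raises the mode's probability; flattening lowers it
assert sharp[0] > base[0] > flat[0]
```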
┌──────────────────────────────────────────────────────────────────┐
│ Sampling Pipeline for Text Generation │
│ │
│ Raw logits: [2.1, 0.5, -1.0, 5.3, 0.1, -0.8, 3.2, ...] │
│ │
│ Step 1: Temperature (T=0.7) │
│ logits/T: [3.0, 0.7, -1.4, 7.6, 0.1, -1.1, 4.6, ...] │
│ Effect: ▓▓▓▓▓▓ → ▓▓▓▓▓▓▓▓ (sharper distribution) │
│ │
│ Step 2: Top-k (k=4) │
│ probs: [0.05, -, -, 0.72, -, -, 0.21, ...] │
│ Keeps: ✓ ✗ ✗ ✓ ✗ ✗ ✓ │
│ Effect: removes long-tail unlikely tokens │
│ │
│ Step 3: Nucleus (p=0.9) │
│ Sorted: 0.72 + 0.21 = 0.93 > 0.9 → keep 2 tokens │
│ ┌─────────────────────────────────────┐ │
│ │ Token 3: ▓▓▓▓▓▓▓▓▓▓▓▓ 0.72 │ │
│ │ Token 6: ▓▓▓▓ 0.21 ← cutoff here (cumsum > p) │
│ │ Token 0: ▓ 0.05 ← filtered out │
│ └─────────────────────────────────────┘ │
│ │
│ Step 4: Renormalize and sample │
│ P(token 3) = 0.77, P(token 6) = 0.23 │
│ → torch.multinomial → sampled token │
└──────────────────────────────────────────────────────────────────┘
When implementing nucleus sampling, watch out for the off-by-one error in the cumulative probability mask. The correct check is cumulative_probs - sorted_probs > p, which shifts the mask by one position so the token that pushes the cumulative sum over the threshold is included. Using cumulative_probs > p without the shift would exclude that token, potentially filtering out important probability mass.
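A toy example makes the shift concrete (the probability values here are chosen by hand, not taken from the diagram above):

```python
import torch

sorted_probs = torch.tensor([0.72, 0.21, 0.05, 0.02])
cumulative = torch.cumsum(sorted_probs, dim=-1)  # [0.72, 0.93, 0.98, 1.00]
p = 0.9

# Naive mask: also drops the 0.21 token, even though it is needed to reach p
naive = cumulative > p                   # [False, True, True, True]
# Shifted mask: keeps the token that crosses the threshold
shifted = cumulative - sorted_probs > p  # [False, False, True, True]

assert naive.sum().item() == 3    # naive masking would keep only 1 token
assert shifted.sum().item() == 2  # shifted masking keeps 2 (0.72 + 0.21 = 0.93)
```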
Hints
- Temperature: simply divide logits by temperature before softmax: logits / temperature.
- Top-k: sort logits descending, set everything below the k-th largest to -inf.
- Nucleus/top-p: sort probabilities descending, compute cumulative sum, mask tokens where cumsum > p (keep at least one token).
- After masking, renormalize with softmax and sample from the resulting distribution using torch.multinomial.
- Apply temperature first, then top-k, then top-p (this is the standard order).
Solution
import torch
import torch.nn.functional as F


def temperature_scale(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """
    Scale logits by temperature.

    temperature < 1: sharper (more deterministic)
    temperature > 1: flatter (more random)
    temperature = 1: no change
    """
    if temperature <= 0:
        raise ValueError("Temperature must be positive")
    return logits / temperature


def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Set all logits outside the top-k to -inf."""
    if k <= 0 or k >= logits.size(-1):
        return logits
    # Find the k-th largest value
    top_k_values, _ = torch.topk(logits, k)
    threshold = top_k_values[..., -1]  # k-th largest value
    # Mask everything below the threshold
    filtered = logits.clone()
    filtered[filtered < threshold] = float("-inf")
    return filtered


def nucleus_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """
    Nucleus (top-p) filtering: keep the smallest set of tokens
    whose cumulative probability exceeds p.
    """
    if p >= 1.0:
        return logits
    # Sort logits descending and convert to probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Mask tokens where the cumulative mass *before* them already exceeds p;
    # the shift keeps the token that crosses the threshold (and at least one token)
    sorted_mask = cumulative_probs - sorted_probs > p
    sorted_logits[sorted_mask] = float("-inf")
    # Unsort to restore original order
    original_logits = torch.empty_like(logits)
    original_logits.scatter_(0, sorted_indices, sorted_logits)
    return original_logits


def sample_token(logits: torch.Tensor) -> int:
    """Sample a token index from logit scores."""
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()


def combined_sampling(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
) -> int:
    """Full sampling pipeline: temperature -> top-k -> nucleus -> sample."""
    # Step 1: Temperature scaling
    logits = temperature_scale(logits, temperature)
    # Step 2: Top-k filtering
    if top_k > 0:
        logits = top_k_filter(logits, top_k)
    # Step 3: Nucleus (top-p) filtering
    if top_p < 1.0:
        logits = nucleus_filter(logits, top_p)
    # Step 4: Sample from the filtered distribution
    return sample_token(logits)


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    vocab_size = 50

    # Create some logits with a clear mode
    logits = torch.randn(vocab_size)
    logits[5] = 10.0   # token 5 is strongly preferred
    logits[12] = 5.0   # token 12 is second
    logits[30] = 3.0   # token 30 is third

    # Demonstrate temperature effects
    print("=== Temperature effects ===")
    for temp in [0.1, 0.5, 1.0, 2.0]:
        scaled = temperature_scale(logits, temp)
        probs = F.softmax(scaled, dim=-1)
        top_prob = probs[5].item()
        entropy = -(probs * probs.log().clamp(min=-100)).sum().item()
        print(f"  temp={temp:.1f}: P(top)={top_prob:.4f}, entropy={entropy:.2f}")

    # Demonstrate top-k
    print("\n=== Top-k filtering ===")
    for k in [1, 3, 10, 50]:
        filtered = top_k_filter(logits, k)
        n_active = (filtered > float("-inf")).sum().item()
        print(f"  k={k}: {n_active} active tokens")

    # Demonstrate nucleus sampling
    print("\n=== Nucleus (top-p) filtering ===")
    for p in [0.1, 0.5, 0.9, 0.95]:
        filtered = nucleus_filter(logits, p)
        n_active = (filtered > float("-inf")).sum().item()
        probs = F.softmax(filtered, dim=-1)
        mass = probs[probs > 0].sum().item()
        print(f"  p={p}: {n_active} active tokens, total mass={mass:.4f}")

    # Sampling distribution comparison
    print("\n=== Sampling 1000 tokens ===")
    configs = [
        ("Greedy (temp=0.01)", dict(temperature=0.01)),
        ("Creative (temp=1.5)", dict(temperature=1.5)),
        ("Top-k=5", dict(top_k=5)),
        ("Nucleus p=0.9", dict(top_p=0.9)),
        ("Combined (t=0.8,k=10,p=0.9)", dict(temperature=0.8, top_k=10, top_p=0.9)),
    ]
    for name, kwargs in configs:
        counts = torch.zeros(vocab_size)
        for _ in range(1000):
            token = combined_sampling(logits.clone(), **kwargs)
            counts[token] += 1
        top3 = counts.topk(3)
        top_info = [(int(i), int(c)) for c, i in zip(top3.values, top3.indices)]
        unique = (counts > 0).sum().item()
        print(f"  {name:40s}: unique={unique:3d}, top3={top_info}")
Walkthrough
- Temperature scaling -- Dividing logits by temperature before softmax changes the entropy of the distribution. Low temperature (0.1) makes the model nearly deterministic. High temperature (2.0) makes the distribution nearly uniform. This is mathematically equivalent to raising probabilities to the power 1/temperature and renormalizing.
- Top-k filtering -- Sets all logits outside the top-k to -inf so they get zero probability after softmax. This prevents sampling of tokens in the long tail, which tend to be nonsensical. Simple but effective.
- Nucleus filtering -- Sorts probabilities descending and finds the smallest prefix whose cumulative sum exceeds p. The trick is the cumulative_probs - sorted_probs > p check, which ensures we include the token that pushes us over the threshold. The key advantage over top-k: the number of included tokens adapts to the model's confidence.
- Combined pipeline -- Temperature first (changes the shape), then top-k (hard cutoff), then nucleus (adaptive cutoff). This is the standard order used in libraries like HuggingFace.
- Multinomial sampling -- torch.multinomial samples from a categorical distribution, handling the conversion from probabilities to a sampled index.
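The adaptivity claim can be checked directly by counting how many tokens survive top-p filtering for a confident distribution versus a flat one. This is a small sketch; nucleus_size is a hypothetical helper written for this demonstration, not part of the solution above:

```python
import torch
import torch.nn.functional as F

def nucleus_size(logits: torch.Tensor, p: float) -> int:
    """Count how many tokens top-p filtering would keep."""
    probs = F.softmax(torch.sort(logits, descending=True).values, dim=-1)
    cum = torch.cumsum(probs, dim=-1)
    # Shifted mask: a token is kept if the mass *before* it is still <= p
    return int((cum - probs <= p).sum().item())

torch.manual_seed(0)
confident = torch.randn(100)
confident[0] = 12.0           # one token dominates the distribution
uncertain = torch.zeros(100)  # perfectly flat distribution

# A confident model keeps very few tokens; an uncertain one keeps many
assert nucleus_size(confident, 0.9) == 1
assert nucleus_size(uncertain, 0.9) > 50
```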
Complexity Analysis
- Temperature: O(V), where V = vocab size.
- Top-k: O(V) using torch.topk (which uses a partial sort internally).
- Nucleus: O(V log V) for the sort, plus O(V) for the cumsum and masking.
- Total: O(V log V), dominated by the sort in nucleus sampling. For V = 50K (a typical LLM vocab), this is negligible compared to the model forward pass.
Interview Tips
Key points:
1. Why nucleus beats top-k: top-k uses a fixed number of candidates regardless of the distribution's shape, while nucleus adapts. When the model is very confident, nucleus might include only 2-3 tokens; when it is uncertain, it might include hundreds.
2. Temperature interacts with top-k/top-p: high temperature plus a low top-p can still produce diverse but reasonable text.
3. In practice, most deployments use temperature = 0.7-1.0 with top-p = 0.9-0.95.
4. Greedy decoding (temperature near 0) is used for factual tasks; high temperature for creative tasks.
5. Repetition penalty is another common technique applied on top of these.
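The repetition penalty mentioned above can be sketched as follows. This follows the CTRL-style formulation (divide positive logits by the penalty, multiply negative ones); the function name and the 1.2 default are illustrative choices, not a fixed standard:

```python
import torch

def apply_repetition_penalty(logits: torch.Tensor, generated: list[int],
                             penalty: float = 1.2) -> torch.Tensor:
    """Discourage already-generated tokens: shrink their positive logits,
    push their negative logits further down (CTRL-style penalty)."""
    out = logits.clone()
    for tok in set(generated):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out

logits = torch.tensor([3.0, 1.0, -0.5])
penalized = apply_repetition_penalty(logits, generated=[0, 2])
assert penalized[0] < logits[0]   # positive logit shrinks: 3.0 -> 2.5
assert penalized[2] < logits[2]   # negative logit drops further: -0.5 -> -0.6
assert penalized[1] == logits[1]  # token never generated stays unchanged
```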
Quiz
1. What advantage does nucleus (top-p) sampling have over top-k sampling?
2. What happens mathematically when you set the temperature to a value approaching 0?
3. In what order should temperature, top-k, and top-p be applied, and why does the order matter?