Akshay’s Gradient
ML Coding · Intermediate · 35 min

Top-k / Nucleus Sampling

Implement the decoding strategies used in modern language models: temperature scaling, top-k sampling, and nucleus (top-p) sampling. These methods control the tradeoff between creativity and coherence in text generation.

Problem Statement

Implement four functions:

  1. temperature_scale(logits, temperature) -- scale logits by temperature
  2. top_k_sampling(logits, k) -- sample from the top-k most probable tokens
  3. nucleus_sampling(logits, p) -- sample from the smallest set of tokens whose cumulative probability exceeds p
  4. combined_sampling(logits, temperature, top_k, top_p) -- apply all three in sequence

Inputs: Logits of shape (vocab_size,), temperature float, top-k int, top-p float.

Outputs: A sampled token index.

Key Concept

Temperature < 1 sharpens the distribution (more deterministic), temperature > 1 flattens it (more random). Top-k removes all but the k most probable tokens, preventing sampling of very unlikely tokens. Nucleus (top-p) sampling adaptively selects the number of tokens: it includes the smallest set whose cumulative probability exceeds p, so the number of candidates varies by context (fewer when the model is confident, more when uncertain).
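These effects are easy to verify on a toy distribution (a minimal sketch; the three logit values are arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])

# Lower temperature concentrates mass on the argmax; higher spreads it out.
for T in [0.5, 1.0, 2.0]:
    probs = F.softmax(logits / T, dim=-1)
    print(f"T={T}: max prob = {probs.max():.3f}")
    # prints 0.867, 0.665, 0.506 for T=0.5, 1.0, 2.0
```

Note that T=1.0 reproduces the plain softmax, and as T grows the maximum probability approaches 1/3, the uniform value for three tokens.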

Sampling Pipeline: Temperature → Top-k → Nucleus
┌──────────────────────────────────────────────────────────────────┐
│           Sampling Pipeline for Text Generation                   │
│                                                                  │
│   Raw logits: [2.1, 0.5, -1.0, 5.3, 0.1, -0.8, 3.2, ...]      │
│                                                                  │
│   Step 1: Temperature (T=0.7)                                    │
│   logits/T:  [3.0, 0.7, -1.4, 7.6, 0.1, -1.1, 4.6, ...]       │
│   Effect: ▓▓▓▓▓▓  → ▓▓▓▓▓▓▓▓  (sharper distribution)           │
│                                                                  │
│   Step 2: Top-k (k=4)                                           │
│   probs:  [0.05, -, -, 0.72, -, -, 0.21, ...]                  │
│   Keeps:   ✓     ✗  ✗   ✓    ✗  ✗   ✓                          │
│   Effect: removes long-tail unlikely tokens                      │
│                                                                  │
│   Step 3: Nucleus (p=0.9)                                        │
│   Sorted:  0.72 + 0.21 = 0.93 > 0.9 → keep 2 tokens           │
│   ┌─────────────────────────────────────┐                        │
│   │ Token 3: ▓▓▓▓▓▓▓▓▓▓▓▓ 0.72        │                       │
│   │ Token 6: ▓▓▓▓         0.21  ← cutoff here (cumsum > p)     │
│   │ Token 0: ▓            0.05  ← filtered out                 │
│   └─────────────────────────────────────┘                        │
│                                                                  │
│   Step 4: Renormalize and sample                                 │
│   P(token 3) = 0.77, P(token 6) = 0.23                         │
│   → torch.multinomial → sampled token                            │
└──────────────────────────────────────────────────────────────────┘
Warning

When implementing nucleus sampling, watch out for the off-by-one error in the cumulative probability mask. The correct check is cumulative_probs - sorted_probs > p, which shifts the mask by one position so that the token that pushes the cumulative sum over the threshold is included. Using cumulative_probs > p without the shift would exclude that token, potentially filtering out important probability mass.
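The difference between the two masks is visible on a small sorted distribution (a minimal sketch with made-up probabilities):

```python
import torch

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # already sorted descending
p = 0.9
cumulative = torch.cumsum(probs, dim=-1)      # [0.50, 0.80, 0.95, 1.00]

naive_mask = cumulative > p                   # drops the token that crosses p
shifted_mask = cumulative - probs > p         # keeps the token that crosses p

print(naive_mask.tolist())    # [False, False, True, True]  -> only 0.80 mass kept
print(shifted_mask.tolist())  # [False, False, False, True] -> 0.95 mass kept
```

With the naive mask only 0.80 of probability mass survives, below the requested 0.9; the shifted mask keeps the 0.15 token and reaches 0.95.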

Hints

Info
  1. Temperature: simply divide logits by temperature before softmax: logits / temperature.
  2. Top-k: sort logits descending, set everything below the k-th largest to -inf.
  3. Nucleus/top-p: sort probabilities descending, compute cumulative sum, mask tokens where cumsum > p (keep at least one token).
  4. After masking, renormalize with softmax and sample from the resulting distribution using torch.multinomial.
  5. Apply temperature first, then top-k, then top-p (this is the standard order).

Solution

import torch
import torch.nn.functional as F
from typing import Optional


def temperature_scale(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """
    Scale logits by temperature.
    temperature < 1: sharper (more deterministic)
    temperature > 1: flatter (more random)
    temperature = 1: no change
    """
    if temperature <= 0:
        raise ValueError("Temperature must be positive")
    return logits / temperature


def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """
    Set all logits outside the top-k to -inf.
    """
    if k <= 0 or k >= logits.size(-1):
        return logits

    # Find the k-th largest value
    top_k_values, _ = torch.topk(logits, k)
    threshold = top_k_values[..., -1]  # k-th largest value

    # Mask everything below the threshold
    filtered = logits.clone()
    filtered[filtered < threshold] = float("-inf")
    return filtered


def nucleus_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """
    Nucleus (top-p) filtering: keep the smallest set of tokens
    whose cumulative probability exceeds p.
    """
    if p >= 1.0:
        return logits

    # Sort probabilities descending
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Find the cutoff: mask tokens where cumsum > p (but always keep at least one)
    sorted_mask = cumulative_probs - sorted_probs > p  # shift by one position
    sorted_logits[sorted_mask] = float("-inf")

    # Unsort to restore original order
    original_logits = torch.empty_like(logits)
    original_logits.scatter_(0, sorted_indices, sorted_logits)
    return original_logits


def sample_token(logits: torch.Tensor) -> int:
    """Sample a token from logit scores."""
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()


def combined_sampling(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
) -> int:
    """
    Full sampling pipeline: temperature -> top-k -> nucleus -> sample.
    """
    # Step 1: Temperature scaling
    logits = temperature_scale(logits, temperature)

    # Step 2: Top-k filtering
    if top_k > 0:
        logits = top_k_filter(logits, top_k)

    # Step 3: Nucleus (top-p) filtering
    if top_p < 1.0:
        logits = nucleus_filter(logits, top_p)

    # Step 4: Sample from the filtered distribution
    return sample_token(logits)


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    vocab_size = 50

    # Create some logits with a clear mode
    logits = torch.randn(vocab_size)
    logits[5] = 10.0   # token 5 is strongly preferred
    logits[12] = 5.0   # token 12 is second
    logits[30] = 3.0   # token 30 is third

    # Demonstrate temperature effects
    print("=== Temperature effects ===")
    for temp in [0.1, 0.5, 1.0, 2.0]:
        scaled = temperature_scale(logits, temp)
        probs = F.softmax(scaled, dim=-1)
        top_prob = probs[5].item()
        entropy = -(probs * probs.log().clamp(min=-100)).sum().item()
        print(f"  temp={temp:.1f}: P(top)={top_prob:.4f}, entropy={entropy:.2f}")

    # Demonstrate top-k
    print("\n=== Top-k filtering ===")
    for k in [1, 3, 10, 50]:
        filtered = top_k_filter(logits, k)
        n_active = (filtered > float("-inf")).sum().item()
        print(f"  k={k}: {n_active} active tokens")

    # Demonstrate nucleus sampling
    print("\n=== Nucleus (top-p) filtering ===")
    for p in [0.1, 0.5, 0.9, 0.95]:
        filtered = nucleus_filter(logits, p)
        n_active = (filtered > float("-inf")).sum().item()
        probs = F.softmax(filtered, dim=-1)
        mass = probs[probs > 0].sum().item()
        print(f"  p={p}: {n_active} active tokens, total mass={mass:.4f}")

    # Sampling distribution comparison
    print("\n=== Sampling 1000 tokens ===")
    configs = [
        ("Greedy (temp=0.01)",      dict(temperature=0.01)),
        ("Creative (temp=1.5)",     dict(temperature=1.5)),
        ("Top-k=5",                 dict(top_k=5)),
        ("Nucleus p=0.9",           dict(top_p=0.9)),
        ("Combined (t=0.8,k=10,p=0.9)", dict(temperature=0.8, top_k=10, top_p=0.9)),
    ]
    for name, kwargs in configs:
        counts = torch.zeros(vocab_size)
        for _ in range(1000):
            token = combined_sampling(logits.clone(), **kwargs)
            counts[token] += 1
        top3 = counts.topk(3)
        top_info = [(int(i), int(c)) for c, i in zip(top3.values, top3.indices)]
        unique = (counts > 0).sum().item()
        print(f"  {name:40s}: unique={unique:3d}, top3={top_info}")

Walkthrough

  1. Temperature scaling -- Dividing logits by temperature before softmax changes the entropy of the distribution. Low temperature (0.1) makes the model nearly deterministic. High temperature (2.0) makes the distribution nearly uniform. This is mathematically equivalent to raising probabilities to the power 1/temperature and renormalizing.

  2. Top-k filtering -- Sets all logits outside the top-k to -inf so they get zero probability after softmax. This prevents sampling of tokens in the long tail, which tend to be nonsensical. Simple but effective.

  3. Nucleus filtering -- Sorts probabilities descending and finds the smallest prefix whose cumulative sum exceeds p. The trick is the cumulative_probs - sorted_probs > p check, which ensures we include the token that pushes us over the threshold. The key advantage over top-k: the number of included tokens adapts to the model's confidence.

  4. Combined pipeline -- Temperature first (changes the shape), then top-k (hard cutoff), then nucleus (adaptive cutoff). This is the standard order used in libraries like HuggingFace.

  5. Multinomial sampling -- torch.multinomial samples from a categorical distribution. It handles the conversion from probabilities to a sampled index.
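The equivalence claimed in step 1 can be checked numerically (a quick sketch with arbitrary logits):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.1, 0.5, -1.0, 3.2])
T = 0.7

# Direct route: divide logits by temperature, then softmax.
direct = F.softmax(logits / T, dim=-1)

# Equivalent route: raise probabilities to the power 1/T and renormalize.
powered = F.softmax(logits, dim=-1) ** (1.0 / T)
renormed = powered / powered.sum()

assert torch.allclose(direct, renormed, atol=1e-5)
```

Both routes agree because softmax(z)^(1/T) is proportional to exp(z/T), so renormalizing recovers softmax(z/T) exactly.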

Complexity Analysis

  • Temperature: O(V) where V = vocab size.
  • Top-k: O(V log k) using torch.topk (a partial sort over the k largest values).
  • Nucleus: O(V log V) for the sort + O(V) for cumsum and masking.
  • Total: O(V log V), dominated by the sort in nucleus sampling. For V = 50K (typical LLM vocab), this is negligible compared to the model forward pass.

Interview Tips

Interview Tip

Key points: (1) Why nucleus > top-k: top-k uses a fixed number of candidates regardless of the distribution shape, while nucleus adapts. When the model is very confident, nucleus might include only 2-3 tokens; when uncertain, it might include hundreds. (2) Temperature interacts with top-k/top-p: high temperature + low top-p can still produce diverse but reasonable text. (3) In practice, most deployments use temperature=0.7-1.0 with top-p=0.9-0.95. (4) Greedy decoding (temperature near 0) is used for factual tasks; high temperature for creative tasks. (5) Repetition penalty is another common technique applied on top of these.
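Point (1) is easy to demonstrate: the same p yields a tiny nucleus for a peaked distribution and a huge one for a flat distribution. A sketch (nucleus_size is a hypothetical helper mirroring the shifted-mask logic from the solution):

```python
import torch
import torch.nn.functional as F

def nucleus_size(logits, p=0.9):
    """Number of tokens the nucleus keeps for a given distribution."""
    probs = F.softmax(torch.sort(logits, descending=True).values, dim=-1)
    cumulative = torch.cumsum(probs, dim=-1)
    return int((~(cumulative - probs > p)).sum())

confident = torch.zeros(1000)
confident[0] = 10.0            # one dominant token
uncertain = torch.zeros(1000)  # uniform logits: maximal uncertainty

print(nucleus_size(confident))  # 1: the dominant token alone exceeds p
print(nucleus_size(uncertain))  # roughly 900 of 1000 tokens
```

A fixed top-k would keep the same k tokens in both cases, either truncating a genuinely uncertain distribution or letting junk through a confident one.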

Quiz — 3 Questions

What advantage does nucleus (top-p) sampling have over top-k sampling?

What happens mathematically when you set temperature to a value approaching 0?

In what order should temperature, top-k, and top-p be applied, and why does the order matter?
