Akshay’s Gradient
ML Coding · Advanced · 50 min

Beam Search Decoding

Exercise: Beam Search Decoder

Implement beam search -- the standard decoding algorithm for sequence generation when greedy decoding is not sufficient. Used in machine translation, speech recognition, and language model generation.

Problem Statement

Implement a beam_search function that:

  1. Maintains beam_width candidate sequences at each step
  2. Expands each candidate by all possible next tokens
  3. Keeps the top-k sequences by cumulative log-probability
  4. Stops when all beams have produced the EOS token or max length is reached
  5. Returns the top sequences ranked by score (optionally with length normalization)

Inputs: A model/scoring function, start token, beam width, max length, EOS token.

Outputs: List of (sequence, score) tuples, sorted by score.

Key Concept

Beam search is a breadth-limited search over the space of possible output sequences. At each step, it keeps the beam_width highest-scoring partial sequences. This is a middle ground between greedy decoding (beam_width=1, fast but suboptimal) and exhaustive search (exponentially expensive). Beam search typically finds higher-scoring sequences than greedy decoding because it keeps multiple hypotheses alive at once, though it remains an approximate search with no optimality guarantee.
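As a concrete toy illustration (the tokens and probabilities here are made up), greedy decoding can lock in a locally best first token and miss the globally best sequence:

```python
import math

# Hypothetical two-step model: P(first token), then P(second | first)
step1 = {"A": 0.6, "B": 0.4}
step2 = {"A": {"x": 0.5, "y": 0.5}, "B": {"x": 0.9, "y": 0.1}}

def seq_log_prob(t1, t2):
    return math.log(step1[t1]) + math.log(step2[t1][t2])

# Greedy picks "A" first (0.6 > 0.4), then its best continuation "x":
greedy_score = seq_log_prob("A", "x")   # log(0.6 * 0.5) = log 0.30

# A beam of width 2 keeps "B" alive and recovers the true best sequence:
best_score = max(seq_log_prob(t1, t2) for t1 in step1 for t2 in step2[t1])
print(round(math.exp(best_score), 2))   # prints 0.36, from "B x" -- greedy missed it
```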

Beam Search Tree Expansion
┌──────────────────────────────────────────────────────────────────┐
│               Beam Search (beam_width = 3)                       │
│                                                                  │
│   Step 0:  [BOS]                                                 │
│              │                                                   │
│   Step 1:  ┌─┼─────────────┐                                    │
│            ▼ ▼             ▼                                     │
│          "The" "A"       "In"                                    │
│         -0.5  -0.9      -1.2                                     │
│            │    │          │                                      │
│   Step 2:  │    │          │                                      │
│         ┌──┴──┐│   ┌──────┘                                     │
│         ▼     ▼▼   ▼                                             │
│     "The cat" "The dog" "A new"  ← top-3 across ALL expansions  │
│      -1.1      -1.3     -1.5                                     │
│         │        │        │                                      │
│   Step 3:        │        │                                      │
│         ▼        ▼        ▼                                      │
│    "The cat sat" "The dog ran"  "The cat EOS" ← completed!      │
│      -2.0         -2.3          -1.8                             │
│                                                                  │
│   Completed beams are scored with length normalization:           │
│   score = log_prob / ((5 + len) / 6)^alpha                       │
└──────────────────────────────────────────────────────────────────┘
Warning

A common bug is forgetting to normalize scores when comparing beams of different lengths. Without length normalization, beam search strongly favors short sequences because log-probabilities are negative -- summing more of them yields a smaller (worse) total. Always apply length normalization before comparing or ranking beams.
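A quick sketch of the effect, using the GNMT-style penalty from the solution below and two made-up beam scores:

```python
def length_penalty(length, alpha=0.6):
    # GNMT-style penalty: ((5 + len) / 6) ** alpha
    return ((5.0 + length) / 6.0) ** alpha

# Two hypothetical finished beams: a short one, and a longer one with
# better per-token quality.
short_lp, short_len = -4.0, 4
long_lp, long_len = -5.5, 12

# Raw cumulative log-probs favor the short beam...
assert short_lp > long_lp

# ...but normalization flips the ranking in favor of the longer beam.
assert long_lp / length_penalty(long_len) > short_lp / length_penalty(short_len)
```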

Hints

Info
  1. Represent each beam as (token_ids, cumulative_log_prob).
  2. At each step, for each beam, get the log probabilities of all next tokens.
  3. Compute new scores: beam_score + log_prob(next_token).
  4. Flatten all candidates from all beams, take the top beam_width.
  5. When a beam generates EOS, move it to a "completed" list (do not expand it further).
  6. For length normalization: divide the score by a length penalty, either seq_len^alpha or the GNMT variant ((5 + len) / 6)^alpha used in the solution below, with alpha typically 0.6-0.7.
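The core loop in hints 1-4 can be sketched in a few lines (the beams and the next-token distribution here are made up; a real decoder would query the model for each beam):

```python
import heapq
import math

# Hint 1: beams as (token_list, cumulative_log_prob)
beams = [(["The"], -0.5), (["A"], -0.9)]

# Hint 2: next-token log-probs (one fixed distribution for simplicity;
# a real model would condition on each beam's tokens)
next_log_probs = {"cat": math.log(0.5), "dog": math.log(0.3), "new": math.log(0.2)}

# Hint 3: candidate score = beam score + next-token log-prob
candidates = [
    (tokens + [tok], score + lp)
    for tokens, score in beams
    for tok, lp in next_log_probs.items()
]

# Hint 4: flatten across beams and keep the top beam_width
beams = heapq.nlargest(2, candidates, key=lambda c: c[1])
# beams[0] is now (["The", "cat"], ...) -- the best candidate overall
```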

Solution

import torch
import torch.nn.functional as F
from typing import List, Tuple, Callable, Optional
from dataclasses import dataclass, field
import heapq


@dataclass
class Beam:
    """A single beam (hypothesis) in beam search."""
    tokens: List[int]
    log_prob: float
    is_finished: bool = False

    @property
    def score(self) -> float:
        return self.log_prob

    def normalized_score(self, alpha: float = 0.6) -> float:
        """Length-normalized score to avoid bias toward short sequences."""
        length_penalty = ((5.0 + len(self.tokens)) / 6.0) ** alpha
        return self.log_prob / length_penalty


def beam_search(
    log_prob_fn: Callable[[torch.Tensor], torch.Tensor],
    start_token: int,
    beam_width: int = 5,
    max_length: int = 50,
    eos_token: int = 2,
    length_penalty_alpha: float = 0.6,
) -> List[Tuple[List[int], float]]:
    """
    Beam search decoding.

    Args:
        log_prob_fn: Given input_ids (1, seq_len), returns log_probs (vocab_size,)
                     for the next token.
        start_token: The initial token (e.g., BOS).
        beam_width:  Number of beams to maintain.
        max_length:  Maximum sequence length.
        eos_token:   End-of-sequence token ID.
        length_penalty_alpha: Length normalization exponent (0 = no normalization).

    Returns:
        List of (token_list, score) sorted by descending score.
    """
    # Initialize with a single beam containing just the start token
    active_beams = [Beam(tokens=[start_token], log_prob=0.0)]
    completed_beams: List[Beam] = []

    for step in range(max_length):
        if not active_beams:
            break

        all_candidates: List[Beam] = []

        for beam in active_beams:
            if beam.is_finished:
                # Defensive: finished beams are normally routed to
                # completed_beams during pruning and never reach here
                completed_beams.append(beam)
                continue

            # Get next-token log probabilities
            input_ids = torch.tensor([beam.tokens], dtype=torch.long)
            next_log_probs = log_prob_fn(input_ids)  # (vocab_size,)

            # Get top-k candidates from this beam (no need to consider all vocab)
            topk_log_probs, topk_indices = torch.topk(next_log_probs, beam_width)

            for log_p, token_id in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                new_beam = Beam(
                    tokens=beam.tokens + [token_id],
                    log_prob=beam.log_prob + log_p,
                    is_finished=(token_id == eos_token),
                )
                all_candidates.append(new_beam)

        if not all_candidates:
            break

        # Rank candidates by length-normalized score (higher is better;
        # log-probs are negative)
        all_candidates.sort(
            key=lambda b: b.normalized_score(length_penalty_alpha),
            reverse=True,
        )

        # Keep top beam_width candidates
        active_beams = []
        for beam in all_candidates[:beam_width]:
            if beam.is_finished:
                completed_beams.append(beam)
            else:
                active_beams.append(beam)

    # Add any remaining active beams to completed
    completed_beams.extend(active_beams)

    # Sort completed beams by normalized score
    completed_beams.sort(
        key=lambda b: b.normalized_score(length_penalty_alpha),
        reverse=True,
    )

    return [(b.tokens, b.normalized_score(length_penalty_alpha)) for b in completed_beams]


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)

    # Create a simple "language model" with fixed transition probabilities
    vocab_size = 10
    EOS = 2
    BOS = 1

    # Transition matrix: log P(next_token | last_token)
    transition_logits = torch.randn(vocab_size, vocab_size)
    # Keep EOS unlikely at the start; the mock LM below boosts it as length grows
    transition_logits[:, EOS] = -5.0

    def mock_log_prob_fn(input_ids: torch.Tensor) -> torch.Tensor:
        """Mock LM: log-prob depends on last token and sequence length."""
        last_token = input_ids[0, -1].item()
        logits = transition_logits[last_token].clone()
        # Increase EOS probability with sequence length
        seq_len = input_ids.shape[1]
        logits[EOS] += seq_len * 0.5
        return F.log_softmax(logits, dim=-1)

    # Run beam search
    results = beam_search(
        log_prob_fn=mock_log_prob_fn,
        start_token=BOS,
        beam_width=4,
        max_length=15,
        eos_token=EOS,
        length_penalty_alpha=0.6,
    )

    print(f"Found {len(results)} sequences:\n")
    for i, (tokens, score) in enumerate(results[:5]):
        print(f"  Beam {i}: tokens={tokens}, score={score:.4f}")

    # Compare with greedy decoding
    def greedy_decode(log_prob_fn, start_token, max_length, eos_token):
        tokens = [start_token]
        total_log_prob = 0.0
        for _ in range(max_length):
            input_ids = torch.tensor([tokens], dtype=torch.long)
            log_probs = log_prob_fn(input_ids)
            best_token = log_probs.argmax().item()
            total_log_prob += log_probs[best_token].item()
            tokens.append(best_token)
            if best_token == eos_token:
                break
        return tokens, total_log_prob

    greedy_tokens, greedy_score = greedy_decode(mock_log_prob_fn, BOS, 15, EOS)
    print(f"\nGreedy: tokens={greedy_tokens}, raw_score={greedy_score:.4f}")
    print(f"Best beam: tokens={results[0][0]}, norm_score={results[0][1]:.4f}")

Walkthrough

  1. Initialization -- Start with a single beam containing just the start token and zero cumulative log-probability.

  2. Expansion -- For each active beam, query the model for next-token log-probabilities. Create new candidate beams by appending each of the top-k next tokens. The new beam's score is the parent's score plus the new token's log-probability.

  3. Pruning -- Sort all candidates (from all beams) by score and keep only the top beam_width. This is the "beam" constraint -- we limit the search to a fixed width.

  4. EOS handling -- When a beam generates the EOS token, it is moved to the completed list and not expanded further. This frees up a beam slot for other hypotheses.

  5. Length normalization -- Raw log-probabilities favor shorter sequences (fewer negative terms to sum). The length penalty ((5 + len) / 6)^alpha counteracts this bias. The formulation from the Google NMT paper uses alpha=0.6.

  6. Top-k per beam -- We only consider the top beam_width tokens per beam, not the entire vocabulary. This optimization reduces the number of candidates from beam_width * vocab_size to beam_width^2.
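To put numbers on step 6, with a hypothetical beam width of 4 and a 32,000-token vocabulary:

```python
beam_width, vocab_size = 4, 32_000   # hypothetical decoding setup

# Expanding every beam by the full vocabulary:
naive = beam_width * vocab_size      # candidates per step

# Taking only the top beam_width tokens per beam (step 6):
pruned = beam_width ** 2             # candidates per step

# This pruning is lossless for a single step: any candidate in the global
# top beam_width must also be among the top beam_width continuations of
# its own parent beam.
print(naive, pruned)
```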

Complexity Analysis

  • Time per step: O(beam_width * vocab_size) for scoring + O(beam_width^2 * log(beam_width)) for sorting candidates. In practice, the model forward pass dominates.
  • Total time: O(max_length * beam_width * model_forward_cost).
  • Space: O(beam_width * max_length) for storing beam sequences.

Beam search is much cheaper than exhaustive search (which is O(vocab_size^max_length)), but more expensive than greedy (which is O(max_length * model_forward_cost)).

Interview Tips

Interview Tip

Be ready to discuss: (1) When beam search helps vs. greedy -- it helps when the locally best token is not globally best (common in translation). (2) Beam search is not guaranteed to find the optimal sequence -- it is an approximate search. (3) The beam width tradeoff: larger beams find better sequences but are slower. (4) Length normalization is critical -- without it, beam search strongly prefers short outputs. (5) Diversity: vanilla beam search often produces very similar beams. Diverse beam search and nucleus sampling are alternatives. (6) For open-ended generation (chatbots), sampling methods (top-k, nucleus) are preferred over beam search because they produce more varied outputs.

Quiz — 3 Questions

  1. Why does beam search without length normalization tend to produce shorter sequences?
  2. What happens when beam_width is set to 1?
  3. Why does the implementation take the top-k tokens per beam rather than considering the entire vocabulary?
