Exercise: Beam Search Decoder
Implement beam search, the standard decoding algorithm for sequence generation when greedy decoding is not good enough. It is used in machine translation, speech recognition, and language model generation.
Problem Statement
Implement a beam_search function that:
- Maintains beam_width candidate sequences at each step
- Expands each candidate by all possible next tokens
- Keeps the top-k sequences by cumulative log-probability
- Stops when all beams have produced the EOS token or max length is reached
- Returns the top sequences ranked by score (optionally with length normalization)
Inputs: A model/scoring function, start token, beam width, max length, EOS token.
Outputs: List of (sequence, score) tuples, sorted by score.
Beam search is a breadth-limited search over the space of possible output sequences. At each step, it keeps the beam_width highest-scoring partial sequences. This is a middle ground between greedy decoding (beam_width=1, fast but suboptimal) and exhaustive search (exponentially expensive). Beam search finds better sequences than greedy decoding because it considers multiple hypotheses simultaneously.
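A tiny toy example (with made-up probabilities) makes the gap concrete: the locally best first token can lead to a worse overall sequence, and a beam of width 2 recovers the better path.

```python
import math

# Made-up two-step distribution: the probabilities are illustrative only.
first = {"A": math.log(0.5), "B": math.log(0.4)}
second = {
    "A": {"x": math.log(0.2), "y": math.log(0.1)},
    "B": {"x": math.log(0.9), "y": math.log(0.05)},
}

# Greedy: commit to the best first token, then the best continuation.
g1 = max(first, key=first.get)             # picks "A" (locally best)
g2 = max(second[g1], key=second[g1].get)
greedy_score = first[g1] + second[g1][g2]  # log(0.5 * 0.2) = log(0.10)

# Beam search (width 2): keep both first tokens, score all continuations.
candidates = [
    (f + c_lp, [t, c])
    for t, f in first.items()
    for c, c_lp in second[t].items()
]
best_score, best_seq = max(candidates)     # ["B", "x"]: log(0.4 * 0.9) = log(0.36)

assert best_seq == ["B", "x"]
assert best_score > greedy_score
```

Greedy commits to "A" because 0.5 > 0.4, but the best continuation after "A" is weak; keeping both hypotheses lets the search find the "B x" path with more than three times the probability.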
┌──────────────────────────────────────────────────────────────────┐
│ Beam Search (beam_width = 3) │
│ │
│ Step 0: [BOS] │
│ │ │
│ Step 1: ┌─┼─────────────┐ │
│ ▼ ▼ ▼ │
│ "The" "A" "In" │
│ -0.5 -0.9 -1.2 │
│ │ │ │ │
│ Step 2: │ │ │ │
│ ┌──┴──┐│ ┌──────┘ │
│ ▼ ▼▼ ▼ │
│ "The cat" "The dog" "A new" ← top-3 across ALL expansions │
│ -1.1 -1.3 -1.5 │
│ │ │ │ │
│ Step 3: │ │ │
│ ▼ ▼ ▼ │
│ "The cat sat" "The dog ran" "The cat EOS" ← completed! │
│ -2.0 -2.3 -1.8 │
│ │
│ Completed beams are scored with length normalization: │
│ score = log_prob / ((5 + len) / 6)^alpha │
└──────────────────────────────────────────────────────────────────┘
A common bug is forgetting to normalize scores when comparing beams of different lengths. Without length normalization, beam search strongly favors short sequences because log-probabilities are negative -- summing more of them yields a smaller (worse) total. Always apply length normalization before comparing or ranking beams.
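A small numeric sketch (the scores are invented) shows both the bias and the fix, using the length penalty from the diagram:

```python
# Invented scores for two beams of different lengths.
short_lp, short_len = -2.0, 3   # avg log-prob per token: -0.667
long_lp, long_len = -2.5, 8     # avg log-prob per token: -0.3125 (better fit)

# Raw comparison wrongly prefers the short beam: fewer negative terms summed.
assert short_lp > long_lp

def length_penalty(length: int, alpha: float = 0.6) -> float:
    # GNMT-style penalty: ((5 + len) / 6) ** alpha
    return ((5.0 + length) / 6.0) ** alpha

# Normalized comparison prefers the longer beam with the better per-token fit.
assert long_lp / length_penalty(long_len) > short_lp / length_penalty(short_len)
```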
Hints
- Represent each beam as (token_ids, cumulative_log_prob).
- At each step, for each beam, get the log probabilities of all next tokens.
- Compute new scores: beam_score + log_prob(next_token).
- Flatten all candidates from all beams, take the top beam_width.
- When a beam generates EOS, move it to a "completed" list (do not expand it further).
- For length normalization: divide the score by a length penalty such as ((5 + len) / 6)^alpha, where alpha is typically 0.6-0.7.
Solution
import torch
import torch.nn.functional as F
from typing import List, Tuple, Callable
from dataclasses import dataclass
@dataclass
class Beam:
"""A single beam (hypothesis) in beam search."""
tokens: List[int]
log_prob: float
is_finished: bool = False
@property
def score(self) -> float:
return self.log_prob
def normalized_score(self, alpha: float = 0.6) -> float:
"""Length-normalized score to avoid bias toward short sequences."""
length_penalty = ((5.0 + len(self.tokens)) / 6.0) ** alpha
return self.log_prob / length_penalty
def beam_search(
log_prob_fn: Callable[[torch.Tensor], torch.Tensor],
start_token: int,
beam_width: int = 5,
max_length: int = 50,
eos_token: int = 2,
length_penalty_alpha: float = 0.6,
) -> List[Tuple[List[int], float]]:
"""
Beam search decoding.
Args:
log_prob_fn: Given input_ids (1, seq_len), returns log_probs (vocab_size,)
for the next token.
start_token: The initial token (e.g., BOS).
beam_width: Number of beams to maintain.
max_length: Maximum sequence length.
eos_token: End-of-sequence token ID.
length_penalty_alpha: Length normalization exponent (0 = no normalization).
Returns:
List of (token_list, score) sorted by descending score.
"""
# Initialize with a single beam containing just the start token
active_beams = [Beam(tokens=[start_token], log_prob=0.0)]
completed_beams: List[Beam] = []
for step in range(max_length):
if not active_beams:
break
all_candidates: List[Beam] = []
for beam in active_beams:
if beam.is_finished:
completed_beams.append(beam)
continue
# Get next-token log probabilities
input_ids = torch.tensor([beam.tokens], dtype=torch.long)
next_log_probs = log_prob_fn(input_ids) # (vocab_size,)
# Get top-k candidates from this beam (no need to consider all vocab)
topk_log_probs, topk_indices = torch.topk(next_log_probs, beam_width)
for log_p, token_id in zip(topk_log_probs.tolist(), topk_indices.tolist()):
new_beam = Beam(
tokens=beam.tokens + [token_id],
log_prob=beam.log_prob + log_p,
is_finished=(token_id == eos_token),
)
all_candidates.append(new_beam)
if not all_candidates:
break
# Sort by score (higher is better for log-probs, which are negative)
all_candidates.sort(
key=lambda b: b.normalized_score(length_penalty_alpha),
reverse=True,
)
# Keep top beam_width candidates
active_beams = []
for beam in all_candidates[:beam_width]:
if beam.is_finished:
completed_beams.append(beam)
else:
active_beams.append(beam)
# Add any remaining active beams to completed
completed_beams.extend(active_beams)
# Sort completed beams by normalized score
completed_beams.sort(
key=lambda b: b.normalized_score(length_penalty_alpha),
reverse=True,
)
return [(b.tokens, b.normalized_score(length_penalty_alpha)) for b in completed_beams]
# ---------- demo ----------
if __name__ == "__main__":
torch.manual_seed(42)
# Create a simple "language model" with fixed transition probabilities
vocab_size = 10
EOS = 2
BOS = 1
# Transition matrix: log P(next_token | last_token)
transition_logits = torch.randn(vocab_size, vocab_size)
# Make EOS more likely after longer sequences
transition_logits[:, EOS] = -5.0 # EOS is rare initially
def mock_log_prob_fn(input_ids: torch.Tensor) -> torch.Tensor:
"""Mock LM: log-prob depends on last token and sequence length."""
last_token = input_ids[0, -1].item()
logits = transition_logits[last_token].clone()
# Increase EOS probability with sequence length
seq_len = input_ids.shape[1]
logits[EOS] += seq_len * 0.5
return F.log_softmax(logits, dim=-1)
# Run beam search
results = beam_search(
log_prob_fn=mock_log_prob_fn,
start_token=BOS,
beam_width=4,
max_length=15,
eos_token=EOS,
length_penalty_alpha=0.6,
)
print(f"Found {len(results)} sequences:\n")
for i, (tokens, score) in enumerate(results[:5]):
print(f" Beam {i}: tokens={tokens}, score={score:.4f}")
# Compare with greedy decoding
def greedy_decode(log_prob_fn, start_token, max_length, eos_token):
tokens = [start_token]
total_log_prob = 0.0
for _ in range(max_length):
input_ids = torch.tensor([tokens], dtype=torch.long)
log_probs = log_prob_fn(input_ids)
best_token = log_probs.argmax().item()
total_log_prob += log_probs[best_token].item()
tokens.append(best_token)
if best_token == eos_token:
break
return tokens, total_log_prob
greedy_tokens, greedy_score = greedy_decode(mock_log_prob_fn, BOS, 15, EOS)
print(f"\nGreedy: tokens={greedy_tokens}, raw_score={greedy_score:.4f}")
print(f"Best beam: tokens={results[0][0]}, norm_score={results[0][1]:.4f}")
Walkthrough
- Initialization -- Start with a single beam containing just the start token and zero cumulative log-probability.
- Expansion -- For each active beam, query the model for next-token log-probabilities. Create new candidate beams by appending each of the top-k next tokens. The new beam's score is the parent's score plus the new token's log-probability.
- Pruning -- Sort all candidates (from all beams) by score and keep only the top beam_width. This is the "beam" constraint -- we limit the search to a fixed width.
- EOS handling -- When a beam generates the EOS token, it is moved to the completed list and not expanded further. This frees up a beam slot for other hypotheses.
- Length normalization -- Raw log-probabilities favor shorter sequences (fewer negative terms to sum). The length penalty ((5 + len) / 6)^alpha counteracts this bias. This formulation, from the Google NMT paper, typically uses alpha = 0.6.
- Top-k per beam -- We only consider the top beam_width tokens per beam, not the entire vocabulary. This optimization reduces the number of candidates from beam_width * vocab_size to beam_width^2.
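The per-beam pruning step can be sketched directly with torch.topk (the beam width and vocabulary size here are illustrative):

```python
import torch

beam_width, vocab_size = 4, 10000

# Hypothetical next-token log-probs for each active beam.
next_log_probs = torch.randn(beam_width, vocab_size).log_softmax(dim=-1)

# Naive expansion considers every vocabulary token for every beam.
naive_candidates = beam_width * vocab_size       # 40,000 candidates

# Taking only the top beam_width tokens per beam is safe: the global
# top beam_width candidates can include at most beam_width
# continuations of any single parent beam.
topk_log_probs, topk_ids = torch.topk(next_log_probs, beam_width, dim=-1)
pruned_candidates = topk_log_probs.numel()       # beam_width ** 2 = 16

assert pruned_candidates == beam_width ** 2
```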
Complexity Analysis
- Time per step: O(beam_width * vocab_size) for scoring + O(beam_width^2 * log(beam_width)) for sorting candidates. In practice, the model forward pass dominates.
- Total time: O(max_length * beam_width * model_forward_cost).
- Space: O(beam_width * max_length) for storing beam sequences.
Beam search is much cheaper than exhaustive search (which is O(vocab_size^max_length)), but more expensive than greedy (which is O(max_length * model_forward_cost)).
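These bounds are easy to sanity-check with back-of-the-envelope numbers (the sizes are illustrative):

```python
# Rough cost comparison of the three strategies, counting model forward
# passes for greedy and beam search, and raw path counts for exhaustive search.
vocab_size, max_length, beam_width = 50000, 50, 5

greedy_forwards = max_length               # one forward pass per step
beam_forwards = max_length * beam_width    # one forward pass per beam per step
exhaustive_paths = vocab_size ** max_length  # astronomically large

assert beam_forwards == 250                # 5x greedy, still cheap
assert exhaustive_paths > 10 ** 200        # utterly infeasible
```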
Interview Tips
Be ready to discuss: (1) When beam search helps vs. greedy -- it helps when the locally best token is not globally best (common in translation). (2) Beam search is not guaranteed to find the optimal sequence -- it is an approximate search. (3) The beam width tradeoff: larger beams find better sequences but are slower. (4) Length normalization is critical -- without it, beam search strongly prefers short outputs. (5) Diversity: vanilla beam search often produces very similar beams. Diverse beam search and nucleus sampling are alternatives. (6) For open-ended generation (chatbots), sampling methods (top-k, nucleus) are preferred over beam search because they produce more varied outputs.
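As a contrast to beam search, here is a minimal sketch of the top-k sampling mentioned in point (6); the vocabulary size and k are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def top_k_sample(logits: torch.Tensor, k: int = 50) -> int:
    """Sample the next token from only the k highest-probability tokens."""
    topk_logits, topk_ids = torch.topk(logits, k)
    probs = F.softmax(topk_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1).item()
    return topk_ids[choice].item()

logits = torch.randn(1000)  # mock next-token logits
token = top_k_sample(logits, k=50)

# The sampled token is always among the 50 most likely ones,
# but repeated calls produce varied outputs -- unlike argmax or beam search.
assert token in torch.topk(logits, 50).indices.tolist()
```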
Quiz
Why does beam search without length normalization tend to produce shorter sequences?
What happens when beam_width is set to 1?
Why does the implementation take the top-k tokens per beam rather than considering the entire vocabulary?