Implement a BPE Tokenizer
Build a Byte-Pair Encoding tokenizer from scratch -- the algorithm behind GPT's tokenizer. BPE iteratively merges the most frequent pair of adjacent tokens to build a vocabulary that balances character-level and word-level representations.
Problem Statement
Implement a BPETokenizer class that:
- train(corpus, vocab_size) -- learns merge rules from a text corpus by repeatedly merging the most frequent adjacent pair
- encode(text) -- tokenizes a string into a list of token IDs using the learned merges
- decode(ids) -- converts token IDs back to a string
Start with a character-level vocabulary (all unique bytes/characters in the corpus), then iteratively merge until reaching the target vocabulary size.
Inputs: Training corpus (string), target vocabulary size (int), text to encode (string).
Outputs: List of integer token IDs (encode), reconstructed string (decode).
BPE starts with individual characters and greedily merges the most frequent adjacent pair into a new token. This creates a vocabulary that represents common subwords (like "ing", "tion", "the") as single tokens while still being able to represent any text character-by-character. The merge order defines priority: earlier merges are applied first during encoding.
┌───────────────────────────────────────────────────────────────┐
│ BPE Merge Operations │
│ │
│ Corpus: "the cat sat the mat" │
│ │
│ Step 0 — Character tokenization: │
│ [t][h][e][ ][c][a][t][ ][s][a][t][ ][t][h][e][ ][m][a][t] │
│ │
│ Count pairs: (t,h)=2 (h,e)=2 (a,t)=3 (e, )=2 ... │
│ Most frequent: (a,t) → merge into [at] │
│ │
│ Step 1 — After merging (a,t): │
│ [t][h][e][ ][c][at][ ][s][at][ ][t][h][e][ ][m][at] │
│ │
│ Count pairs: (t,h)=2 (h,e)=2 ... │
│ Most frequent: (t,h) → merge into [th] │
│ │
│ Step 2 — After merging (t,h): │
│ [th][e][ ][c][at][ ][s][at][ ][th][e][ ][m][at] │
│ │
│ Next: merge (th,e) → [the] │
│ │
│ Step 3 — After merging (th,e): │
│ [the][ ][c][at][ ][s][at][ ][the][ ][m][at] │
│ │
│ ...continues until vocab_size reached │
└───────────────────────────────────────────────────────────────┘
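The Step 0 pair counts in the diagram can be reproduced in a few lines (a sketch using collections.Counter):

```python
from collections import Counter

corpus = "the cat sat the mat"
tokens = list(corpus)  # character-level tokenization (Step 0)

# Count every adjacent pair of tokens
pair_counts = Counter(zip(tokens, tokens[1:]))

best_pair, best_count = pair_counts.most_common(1)[0]
print(best_pair, best_count)  # ('a', 't') 3 -- the first merge candidate
```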
The order of merge rules matters during encoding. When tokenizing new text, you must apply merges in the exact order they were learned during training, not by pair frequency in the new text; applying them out of order can produce different tokenizations. You must also fix a scan convention for overlapping occurrences: if the sequence is [a][a][a] and the merge is (a,a), a left-to-right scan gives [aa][a], while a right-to-left scan would give [a][aa]. Training and encoding must use the same convention.
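The left-to-right merge convention can be seen in a standalone helper (a sketch working on token strings rather than IDs):

```python
def merge_ltr(tokens, a, b):
    """Merge every adjacent (a, b) occurrence, scanning left-to-right."""
    out, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and tokens[i] == a and tokens[i + 1] == b:
            out.append(a + b)
            i += 2  # skip both halves of the merged pair
        else:
            out.append(tokens[i])
            i += 1
    return out

# Overlapping occurrences: the leftmost pair is always taken first
print(merge_ltr(list("aaa"), "a", "a"))  # ['aa', 'a']
```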
Hints
- Initialize the vocabulary with all unique characters in the corpus.
- Represent text as a list of token IDs (initially character IDs).
- Count all adjacent pairs in the current tokenization.
- Find the most frequent pair and merge all occurrences into a new token.
- Record the merge rule: (token_a, token_b) -> new_token_id.
- Repeat until vocab_size is reached.
- For encoding new text: start with characters, then apply merges in the order they were learned.
Solution
from typing import List, Dict, Tuple
from collections import Counter


class BPETokenizer:
    """Byte-Pair Encoding tokenizer."""

    def __init__(self) -> None:
        self.merges: List[Tuple[int, int]] = []  # ordered merge rules
        self.vocab: Dict[int, str] = {}          # id -> string

    def _get_pair_counts(self, token_ids: List[int]) -> Counter:
        """Count adjacent pairs in a token sequence."""
        counts: Counter = Counter()
        for i in range(len(token_ids) - 1):
            counts[(token_ids[i], token_ids[i + 1])] += 1
        return counts

    def _merge_pair(
        self, token_ids: List[int], pair: Tuple[int, int], new_id: int
    ) -> List[int]:
        """Replace all occurrences of `pair` with `new_id`, left-to-right."""
        result: List[int] = []
        i = 0
        while i < len(token_ids):
            if (
                i < len(token_ids) - 1
                and token_ids[i] == pair[0]
                and token_ids[i + 1] == pair[1]
            ):
                result.append(new_id)
                i += 2
            else:
                result.append(token_ids[i])
                i += 1
        return result

    def train(self, corpus: str, vocab_size: int) -> None:
        """Learn BPE merge rules from a corpus."""
        # Step 1: Initialize vocab with the corpus's unique characters
        chars = sorted(set(corpus))
        self.vocab = {i: ch for i, ch in enumerate(chars)}
        char_to_id = {ch: i for i, ch in self.vocab.items()}

        # Tokenize corpus at character level
        token_ids = [char_to_id[ch] for ch in corpus]

        # Step 2: Iteratively merge most frequent pairs
        next_id = len(self.vocab)
        while next_id < vocab_size:
            pair_counts = self._get_pair_counts(token_ids)
            if not pair_counts:
                break  # nothing left to merge
            best_pair = max(pair_counts, key=pair_counts.get)  # type: ignore
            if pair_counts[best_pair] < 2:
                break  # no pair appears more than once
            # Record the merge and create new vocab entry
            self.merges.append(best_pair)
            self.vocab[next_id] = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
            # Apply merge to the corpus
            token_ids = self._merge_pair(token_ids, best_pair, next_id)
            next_id += 1

    def encode(self, text: str) -> List[int]:
        """Encode text using learned merge rules."""
        # Start with character-level tokens.  Characters never seen during
        # training have no ID and raise a KeyError here.
        char_to_id = {ch: i for i, ch in self.vocab.items() if len(ch) == 1}
        token_ids = [char_to_id[ch] for ch in text]
        # Apply merges in the order they were learned.  Merge number k was
        # assigned id (initial_vocab_size + k) during training, so its id
        # can be recomputed directly.
        base = len(char_to_id)
        for k, pair in enumerate(self.merges):
            token_ids = self._merge_pair(token_ids, pair, base + k)
        return token_ids

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to a string."""
        return "".join(self.vocab[i] for i in ids)


# ---------- demo ----------
if __name__ == "__main__":
    corpus = (
        "the cat sat on the mat. the cat ate the rat. "
        "the dog sat on the log. the dog bit the frog."
    )
    tokenizer = BPETokenizer()
    tokenizer.train(corpus, vocab_size=40)

    print("Vocabulary:")
    for i, token in sorted(tokenizer.vocab.items()):
        if len(token) > 1:
            print(f"  {i}: '{token}'")

    text = "the cat sat"
    encoded = tokenizer.encode(text)
    decoded = tokenizer.decode(encoded)
    print(f"\nOriginal: '{text}'")
    print(f"Encoded:  {encoded}")
    print(f"Tokens:   {[tokenizer.vocab[i] for i in encoded]}")
    print(f"Decoded:  '{decoded}'")
    assert decoded == text, "Round-trip failed!"
    print("Round-trip encoding/decoding passed.")
Walkthrough
- Character initialization -- We start with every unique character in the corpus as its own token. This guarantees we can represent any text drawn from the training alphabet. Note that characters never seen during training have no ID, so this implementation cannot encode them; byte-level BPE avoids the problem by starting from all 256 byte values.
- Pair counting -- We scan the token sequence and count every adjacent pair. The pair ("t", "h") might appear 100 times, making it a good merge candidate.
- Greedy merging -- We always merge the most frequent pair first. This is a greedy algorithm: it does not find the globally optimal vocabulary, but it works well in practice and is fast.
- Merge application -- When merging pair (a, b) into token ab, we scan left-to-right and replace every occurrence. This is a linear scan per merge step.
- Encoding -- To tokenize new text, we start with characters and replay the merge rules in the order they were learned. Each merge is applied greedily from left to right. The order matters: early merges represent the most common patterns.
- Decoding -- Simply concatenate the string representations of each token ID. BPE is losslessly reversible.
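The replay step can be illustrated with a hand-written merge list (hypothetical merges, stored as ordered string pairs for readability rather than IDs):

```python
# Hypothetical merge rules in learned order
merges = [("a", "t"), ("t", "h"), ("th", "e")]

def replay(text):
    tokens = list(text)
    for a, b in merges:  # replay in training order, not by frequency
        out, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == a and tokens[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(tokens[i])
                i += 1
        tokens = out
    return tokens

print(replay("that"))  # ['th', 'at'] -- an unseen word, still tokenized
```

Because merges compose, a word never seen during training still decomposes into learned subword tokens.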
Complexity Analysis
- Training: O(V * n) where V = number of merges (vocab_size - initial_chars) and n = corpus length. Each merge requires a linear scan.
- Encoding: O(M * n) where M = number of merge rules and n = text length.
- Decoding: O(n) linear scan.
Real implementations (like HuggingFace tokenizers) use optimized data structures (tries, priority queues) to speed up training significantly.
Interview Tips
Interviewers look for: (1) Understanding of why BPE is used -- balances vocabulary size with sequence length, handles rare/unseen words gracefully. (2) The greedy nature of the algorithm and its implications. (3) Awareness of alternatives: WordPiece (used in BERT, likelihood-based selection), Unigram (SentencePiece, probabilistic). (4) Practical details: pre-tokenization (splitting on whitespace/punctuation before BPE), special tokens, byte-level BPE (GPT-2 operates on bytes, not characters, ensuring full Unicode coverage).
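The byte-level idea can be sketched in a few lines: every string becomes a sequence of UTF-8 bytes, so the base vocabulary is exactly 256 IDs and no input is ever out of vocabulary. (A sketch of the principle only; GPT-2's actual implementation additionally remaps bytes to printable characters.)

```python
text = "héllo 🙂"  # includes multi-byte UTF-8 characters

# Byte-level base tokenization: every ID is in range(256)
byte_ids = list(text.encode("utf-8"))
print(all(0 <= b < 256 for b in byte_ids))  # True

# Decoding is a lossless inverse
assert bytes(byte_ids).decode("utf-8") == text
```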
Quiz
Why does BPE start from individual characters rather than whole words?
How does GPT-2's byte-level BPE differ from character-level BPE?
Why is the BPE algorithm considered greedy, and what are its implications?