Akshay’s Gradient
ML Coding · intermediate · 50 min

BPE Tokenizer

Implement a BPE Tokenizer

Build a Byte-Pair Encoding tokenizer from scratch -- the algorithm behind GPT's tokenizer. BPE iteratively merges the most frequent pair of adjacent tokens, building a vocabulary that strikes a balance between character-level and word-level representations.

Problem Statement

Implement a BPETokenizer class that:

  1. train(corpus, vocab_size) -- learns merge rules from a text corpus by repeatedly merging the most frequent adjacent pair
  2. encode(text) -- tokenizes a string into a list of token IDs using the learned merges
  3. decode(ids) -- converts token IDs back to a string

Start with a character-level vocabulary (all unique characters in the corpus), then merge iteratively until the vocabulary reaches the target size.

Inputs: Training corpus (string), target vocabulary size (int), text to encode (string).

Outputs: List of integer token IDs (encode), reconstructed string (decode).

Key Concept

BPE starts with individual characters and greedily merges the most frequent adjacent pair into a new token. This creates a vocabulary that represents common subwords (like "ing", "tion", "the") as single tokens while still being able to represent any text character-by-character. The merge order defines priority: earlier merges are applied first during encoding.

┌───────────────────────────────────────────────────────────────┐
│                  BPE Merge Operations                         │
│                                                               │
│  Corpus: "the cat sat the mat"                                │
│                                                               │
│  Step 0 (character tokenization):                             │
│  [t][h][e][ ][c][a][t][ ][s][a][t][ ][t][h][e][ ][m][a][t]    │
│                                                               │
│  Count pairs: (t,h)=2  (h,e)=2  (a,t)=3  (e, )=2  ...         │
│  Most frequent: (a,t) → merge into [at]                       │
│                                                               │
│  Step 1 (after merging (a,t)):                                │
│  [t][h][e][ ][c][at][ ][s][at][ ][t][h][e][ ][m][at]          │
│                                                               │
│  Count pairs: (t,h)=2  (h,e)=2  (e, )=2  (at, )=2  ...        │
│  Tie at 2, first seen wins: (t,h) → merge into [th]           │
│                                                               │
│  Step 2 (after merging (t,h)):                                │
│  [th][e][ ][c][at][ ][s][at][ ][th][e][ ][m][at]              │
│                                                               │
│  Count pairs: (th,e)=2  (e, )=2  (at, )=2  ...                │
│  Tie at 2 again: (th,e) → merge into [the]                    │
│                                                               │
│  Step 3 (after merging (th,e)):                               │
│  [the][ ][c][at][ ][s][at][ ][the][ ][m][at]                  │
│                                                               │
│  ...continues until vocab_size reached                        │
└───────────────────────────────────────────────────────────────┘
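The diagram can be replayed with a minimal string-based sketch. The helpers most_frequent_pair and merge are illustrative names (the solution below works on integer IDs instead), and ties are broken by first occurrence, which is also what max() over a Counter does in the solution.

```python
from collections import Counter
from typing import List, Tuple


def most_frequent_pair(tokens: List[str]) -> Tuple[str, str]:
    # Counter keeps insertion order and max() returns the first maximal key,
    # so ties go to the pair that appears earliest in the sequence.
    counts = Counter(zip(tokens, tokens[1:]))
    return max(counts, key=counts.__getitem__)


def merge(tokens: List[str], pair: Tuple[str, str]) -> List[str]:
    # Replace non-overlapping occurrences of `pair`, scanning left to right.
    out: List[str] = []
    i = 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            out.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return out


tokens = list("the cat sat the mat")
for step in range(1, 4):
    pair = most_frequent_pair(tokens)
    tokens = merge(tokens, pair)
    print(f"Step {step}: merge {pair} -> {tokens}")
```

The three printed steps match the diagram: (a,t) wins outright at step 1, and steps 2 and 3 are ties at count 2 settled in favor of the earliest pair, first (t,h) and then (th,e).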
Warning

The order of merge rules matters during encoding. When tokenizing new text, you must apply merges in the exact order they were learned during training, not by how often pairs occur in the new text; applying them out of order can produce a different tokenization. Also be careful with overlapping pairs: if the sequence is [a][a][a] and you merge (a,a), a left-to-right scan gives [aa][a], while a right-to-left scan gives [a][aa]. Pick one convention (the solution scans left to right) and use it consistently in both training and encoding.
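Both pitfalls are easy to see on plain strings. In this sketch merge_ltr, merge_rtl, and apply_merges are hypothetical helpers; merge_rtl exists only for comparison, since the solution scans left to right.

```python
from typing import List, Tuple


def merge_ltr(tokens: List[str], pair: Tuple[str, str]) -> List[str]:
    """Merge non-overlapping occurrences of `pair`, scanning left to right."""
    out: List[str] = []
    i = 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            out.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return out


def merge_rtl(tokens: List[str], pair: Tuple[str, str]) -> List[str]:
    """Same merge, scanning right to left (shown only for comparison)."""
    out: List[str] = []
    i = len(tokens) - 1
    while i >= 0:
        if i >= 1 and (tokens[i - 1], tokens[i]) == pair:
            out.append(tokens[i - 1] + tokens[i])
            i -= 2
        else:
            out.append(tokens[i])
            i -= 1
    return out[::-1]  # tokens were collected back to front


def apply_merges(tokens: List[str], merges: List[Tuple[str, str]]) -> List[str]:
    """Replay merge rules in the given order."""
    for pair in merges:
        tokens = merge_ltr(tokens, pair)
    return tokens


# Overlap: the scan direction decides which two a's are joined.
print(merge_ltr(list("aaa"), ("a", "a")))  # ['aa', 'a']
print(merge_rtl(list("aaa"), ("a", "a")))  # ['a', 'aa']

# Order: whichever rule runs first consumes the shared "b".
print(apply_merges(list("abc"), [("a", "b"), ("b", "c")]))  # ['ab', 'c']
print(apply_merges(list("abc"), [("b", "c"), ("a", "b")]))  # ['a', 'bc']
```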

Hints

Info
  1. Initialize the vocabulary with all unique characters in the corpus.
  2. Represent text as a list of token IDs (initially character IDs).
  3. Count all adjacent pairs in the current tokenization.
  4. Find the most frequent pair and merge all occurrences into a new token.
  5. Record the merge rule: (token_a, token_b) -> new_token_id.
  6. Repeat until vocab_size is reached or no pair occurs more than once.
  7. For encoding new text: start with characters, then apply merges in the order they were learned.
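Hint 3 has a compact Python idiom: zip a sequence with itself shifted by one and feed the pairs to collections.Counter. This sketch (pair_counts is an illustrative name) counts the same pairs, in the same first-seen order, as the solution's _get_pair_counts loop. It also shows that overlapping runs are counted per position, so a pair's count can exceed the number of merges actually possible.

```python
from collections import Counter
from typing import List


def pair_counts(ids: List[int]) -> Counter:
    # zip(ids, ids[1:]) yields (ids[0], ids[1]), (ids[1], ids[2]), ...
    return Counter(zip(ids, ids[1:]))


print(pair_counts([0, 1, 0, 1, 2]).most_common(1))  # [((0, 1), 2)]

# Overlapping run: (7, 7) is counted twice, but a left-to-right merge
# of [7, 7, 7] can only produce one merged token followed by a 7.
print(pair_counts([7, 7, 7]))  # Counter({(7, 7): 2})
```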

Solution

from typing import List, Dict, Tuple
from collections import Counter


class BPETokenizer:
    """Byte-Pair Encoding tokenizer."""

    def __init__(self) -> None:
        self.merges: List[Tuple[int, int]] = []  # ordered merge rules
        self.vocab: Dict[int, str] = {}           # id -> string

    def _get_pair_counts(self, token_ids: List[int]) -> Counter:
        """Count adjacent pairs in a token sequence."""
        counts: Counter = Counter()
        for i in range(len(token_ids) - 1):
            counts[(token_ids[i], token_ids[i + 1])] += 1
        return counts

    def _merge_pair(
        self, token_ids: List[int], pair: Tuple[int, int], new_id: int
    ) -> List[int]:
        """Replace all occurrences of `pair` with `new_id`."""
        result: List[int] = []
        i = 0
        while i < len(token_ids):
            if (
                i < len(token_ids) - 1
                and token_ids[i] == pair[0]
                and token_ids[i + 1] == pair[1]
            ):
                result.append(new_id)
                i += 2
            else:
                result.append(token_ids[i])
                i += 1
        return result

    def train(self, corpus: str, vocab_size: int) -> None:
        """Learn BPE merge rules from a corpus."""
        # Step 1: Initialize vocab with the unique characters of the corpus
        chars = sorted(set(corpus))
        self.vocab = {i: ch for i, ch in enumerate(chars)}
        self.merges = []  # reset so that retraining starts from scratch
        char_to_id = {ch: i for i, ch in self.vocab.items()}

        # Tokenize corpus at character level
        token_ids = [char_to_id[ch] for ch in corpus]

        # Step 2: Iteratively merge most frequent pairs
        next_id = len(self.vocab)
        while next_id < vocab_size:
            pair_counts = self._get_pair_counts(token_ids)
            if not pair_counts:
                break  # nothing left to merge

            best_pair = max(pair_counts, key=pair_counts.get)  # type: ignore
            if pair_counts[best_pair] < 2:
                break  # no pair appears more than once

            # Record the merge and create new vocab entry
            self.merges.append(best_pair)
            self.vocab[next_id] = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]

            # Apply merge to the corpus
            token_ids = self._merge_pair(token_ids, best_pair, next_id)
            next_id += 1

    def encode(self, text: str) -> List[int]:
        """Encode text using learned merge rules."""
        # Start with character-level tokens (merged tokens are always longer than 1)
        char_to_id = {ch: i for i, ch in self.vocab.items() if len(ch) == 1}
        unseen = set(text).difference(char_to_id)
        if unseen:
            # Character-level BPE has no ID for characters absent from the corpus
            raise ValueError(f"Characters not seen during training: {sorted(unseen)}")
        token_ids = [char_to_id[ch] for ch in text]

        # Apply merges in the order they were learned. train() assigned merge i
        # the ID n_base + i, so recompute it directly; looking the ID up by its
        # merged string is ambiguous, because two different merges such as
        # (a, bc) and (ab, c) both produce "abc".
        n_base = len(char_to_id)
        for merge_idx, pair in enumerate(self.merges):
            token_ids = self._merge_pair(token_ids, pair, n_base + merge_idx)

        return token_ids

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to a string."""
        return "".join(self.vocab[i] for i in ids)


# ---------- demo ----------
if __name__ == "__main__":
    corpus = (
        "the cat sat on the mat. the cat ate the rat. "
        "the dog sat on the log. the dog bit the frog."
    )

    tokenizer = BPETokenizer()
    tokenizer.train(corpus, vocab_size=40)

    print("Vocabulary:")
    for i, token in sorted(tokenizer.vocab.items()):
        if len(token) > 1:
            print(f"  {i}: '{token}'")

    text = "the cat sat"
    encoded = tokenizer.encode(text)
    decoded = tokenizer.decode(encoded)

    print(f"\nOriginal:  '{text}'")
    print(f"Encoded:   {encoded}")
    print(f"Tokens:    {[tokenizer.vocab[i] for i in encoded]}")
    print(f"Decoded:   '{decoded}'")
    assert decoded == text, "Round-trip failed!"
    print("Round-trip encoding/decoding passed.")

Walkthrough

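  1. Character initialization -- We start with every unique character in the corpus as its own token. Any text built from those characters can therefore be represented, including words never seen during training. Characters absent from the training corpus have no ID at all, so encoding them fails; this is one motivation for byte-level BPE.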

  2. Pair counting -- We scan through the token sequence and count every adjacent pair. The pair ("t", "h") might appear 100 times, making it a good merge candidate.

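  3. Greedy merging -- We always merge the most frequent pair first. This is a greedy algorithm: each step takes the locally best merge with no lookahead, so the resulting vocabulary is not guaranteed to be globally optimal (for example, it need not minimize the corpus's total token count), but it works well in practice and is simple to implement.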

  4. Merge application -- When merging pair (a, b) into token ab, we scan left to right and replace every non-overlapping occurrence. This is a linear scan per merge step.

  5. Encoding -- To tokenize new text, we start with characters and replay the merge rules in the order they were learned, applying each one greedily from left to right. The order matters: a later rule such as (th, e) can only fire after the earlier rule (t, h) has produced th.

  6. Decoding -- Concatenate the string for each token ID. Because every merged token's string is exactly the concatenation of its two parts, decode(encode(text)) == text for any text the tokenizer can encode.
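As item 1 notes, a character-level base vocabulary has no ID for characters missing from the training corpus. Byte-level BPE (the GPT-2 approach mentioned under Interview Tips) sidesteps this by starting from the 256 possible byte values of the UTF-8 encoding. The sketch below covers only that base layer, before any merges; to_base_ids and from_base_ids are hypothetical names.

```python
from typing import List


def to_base_ids(text: str) -> List[int]:
    # Any str has a UTF-8 encoding, and every byte lies in range(256),
    # so a fixed 256-entry base vocabulary covers all input.
    return list(text.encode("utf-8"))


def from_base_ids(ids: List[int]) -> str:
    # errors="replace" substitutes U+FFFD if the IDs stop partway through
    # a multi-byte character, instead of raising UnicodeDecodeError.
    return bytes(ids).decode("utf-8", errors="replace")


ids = to_base_ids("naïve café")
print(len(ids), max(ids) < 256)  # 12 True
print(from_base_ids(ids))        # naïve café
print(from_base_ids(ids[:3]))    # "na" followed by U+FFFD
```

Here "ï" and "é" each take two bytes, so the 10-character string becomes 12 IDs. With merges on top, each ID would map to a byte string rather than a str, and decoding would join the bytes before a single UTF-8 decode.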

Complexity Analysis

  • Training: O(V * n), where V = number of merges (at most vocab_size minus the number of initial characters) and n = corpus length. Each merge step scans the corpus once to count pairs and once to apply the merge.
  • Encoding: O(M * n) where M = number of merge rules and n = text length.
  • Decoding: O(n) linear scan.
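The M factor in the encoding bound comes from replaying every merge rule, even the many that never match. A common alternative, essentially the loop in GPT-2's reference encoder (which runs it per pre-tokenized word), is to rank the merges and repeatedly merge whichever adjacent pair has the lowest rank. Below is a sketch adapted to the solution's integer IDs; encode_by_rank is a hypothetical helper. It yields the same tokens as in-order replay, because a merge only creates pairs that contain the new token, and any rule involving that token was learned later. Each round costs O(n) and removes at least one token, so encoding is O(n²) per text regardless of M, which pays off when texts (or words) are short and the merge list is long.

```python
from typing import Dict, List, Tuple


def encode_by_rank(ids: List[int], merges: List[Tuple[int, int]], n_base: int) -> List[int]:
    """Repeatedly merge the adjacent pair whose merge rule was learned earliest.

    Assumes the solution's ID scheme: merges[r] was assigned ID n_base + r.
    """
    ranks: Dict[Tuple[int, int], int] = {pair: r for r, pair in enumerate(merges)}
    ids = list(ids)
    while len(ids) >= 2:
        # Lowest-rank pair currently present; pairs with no rule rank as infinity
        best = min(set(zip(ids, ids[1:])), key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no learned rule applies any more
        new_id = n_base + ranks[best]
        out: List[int] = []
        i = 0
        while i < len(ids):  # merge all non-overlapping occurrences, left to right
            if i + 1 < len(ids) and (ids[i], ids[i + 1]) == best:
                out.append(new_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
    return ids


# Base IDs a=0, b=1, c=2; merge 0 is (a, b) -> 3, merge 1 is (ab, c) -> 4
print(encode_by_rank([0, 1, 2, 0, 1], [(0, 1), (3, 2)], n_base=3))  # [4, 3]
```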

Production implementations (such as Hugging Face tokenizers) speed up training considerably by updating pair counts incrementally after each merge, instead of rescanning the whole corpus, and by keeping candidate pairs in a priority queue.

Interview Tips

Interview Tip

Interviewers look for: (1) Understanding of why BPE is used -- balances vocabulary size with sequence length, handles rare/unseen words gracefully. (2) The greedy nature of the algorithm and its implications. (3) Awareness of alternatives: WordPiece (used in BERT, likelihood-based selection), Unigram (SentencePiece, probabilistic). (4) Practical details: pre-tokenization (splitting on whitespace/punctuation before BPE), special tokens, byte-level BPE (GPT-2 operates on bytes, not characters, ensuring full Unicode coverage).
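Pre-tokenization is easy to sketch: split the corpus into word-like pieces first, then count pairs only inside each piece, weighted by how often it occurs. The regex below is a simplified stand-in and an assumption, not GPT-2's actual pattern; word_frequencies and pair_counts_within_words are illustrative names.

```python
import re
from collections import Counter

# Simplified pre-tokenizer (an assumption, not GPT-2's real pattern):
# a word with an optional leading space, a punctuation run with an optional
# leading space, or a run of whitespace.
PRETOKEN_RE = re.compile(r" ?\w+| ?[^\w\s]+|\s+")


def word_frequencies(corpus: str) -> Counter:
    """Split the corpus into pre-tokens and count each distinct one."""
    return Counter(PRETOKEN_RE.findall(corpus))


def pair_counts_within_words(words: Counter) -> Counter:
    """Count adjacent character pairs inside each pre-token, weighted by frequency.

    Pairs are never formed across two pre-tokens, so no merge can span a word boundary.
    """
    counts: Counter = Counter()
    for word, freq in words.items():
        for pair in zip(word, word[1:]):
            counts[pair] += freq
    return counts


words = word_frequencies("the cat sat the mat.")
print(list(words))  # ['the', ' cat', ' sat', ' the', ' mat', '.']
counts = pair_counts_within_words(words)
print(counts.most_common(3))  # [(('a', 't'), 3), (('t', 'h'), 2), (('h', 'e'), 2)]
```

Compared with the merge diagram, the cross-word pair (e, ' ') no longer exists, because each space now attaches to the start of the following word; note that "the" and " the" become separate pre-tokens. Counting over distinct pre-tokens also means a repeated word is processed once per merge step rather than once per occurrence.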

Quiz

Quiz — 3 Questions

Why does BPE start from individual characters rather than whole words?

How does GPT-2's byte-level BPE differ from character-level BPE?

Why is the BPE algorithm considered greedy, and what are its implications?
