Akshay’s Gradient
ML Coding · advanced · 60 min

ANN/HNSW Index

Algorithm: Approximate Nearest Neighbor Search

Implement approximate nearest neighbor (ANN) search using a simplified HNSW-inspired approach, along with exact brute-force search for comparison. ANN search is the backbone of vector databases, semantic search, and retrieval-augmented generation (RAG).

Problem Statement

Implement:

  1. brute_force_knn(query, database, k) -- exact k-nearest neighbor search
  2. NSWIndex -- an approximate nearest neighbor index using a navigable small world (NSW) graph (a flat, single-layer simplification of HNSW)
  3. Compare recall and speed of approximate vs. exact search

Inputs: Query vector of shape (d,), database of shape (n, d), number of neighbors k.

Outputs: Indices and distances of the k (approximate) nearest neighbors.
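The comparison in item 3 is usually reported as recall@k. This is the fraction of the exact k nearest neighbors N_k(q) that the approximate search also returns in its result set N̂_k(q), averaged over a query set Q. The demo in the solution computes the inner fraction for each query.

```latex
\operatorname{recall@}k \;=\; \frac{1}{|Q|} \sum_{q \in Q} \frac{\lvert \mathcal{N}_k(q) \cap \hat{\mathcal{N}}_k(q) \rvert}{k}
```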

Key Concept

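HNSW (Hierarchical Navigable Small World) builds a multi-layer graph. Every vector appears in layer 0, and each vector is also promoted to higher layers with exponentially decreasing probability. As a result, the upper layers are sparse and their edges span long distances. Search starts at the sparse top layer and greedily moves toward the query. The node it reaches becomes the entry point for the next, denser layer. The key insight is that these long-range edges act like the shortcuts of a small-world network, so each layer needs only a few hops. In practice, search cost grows roughly as O(log n) even in high dimensions, dramatically beating O(n) brute force.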
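To make "exponentially decreasing probability" concrete, here is a minimal sketch of the level-assignment rule from the HNSW paper. A node's top level is floor(-ln(U) · mL), where U is uniform on (0, 1] and mL = 1/ln(M). Under this rule, a node reaches layer L or higher with probability M^(-L). The helper names `assign_level` and `layer_counts` are illustrative and not taken from any library.

```python
import math
import random


def assign_level(m: int, rng: random.Random) -> int:
    """Sample a node's top layer: floor(-ln(U) * mL) with mL = 1 / ln(M)."""
    m_l = 1.0 / math.log(m)
    u = 1.0 - rng.random()  # uniform on (0, 1], so log(u) is always defined
    return int(-math.log(u) * m_l)  # int() floors a non-negative float


def layer_counts(n: int, m: int, seed: int = 0) -> list:
    """Count how many of n nodes appear in each layer (a node at level L is in layers 0..L)."""
    rng = random.Random(seed)
    tops = [assign_level(m, rng) for _ in range(n)]
    return [sum(1 for t in tops if t >= layer) for layer in range(max(tops) + 1)]


if __name__ == "__main__":
    counts = layer_counts(n=100_000, m=16)
    for layer, count in enumerate(counts):
        # Each layer keeps roughly 1/M of the nodes of the layer below it.
        print(f"layer {layer}: {count} nodes")
```

With M = 16, about 1/16 of the nodes reach layer 1 and about 1/256 reach layer 2. This is why the upper layers are sparse and their edges are long.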

HNSW Multi-Layer Graph Structure and Search
┌──────────────────────────────────────────────────────────────────┐
│         HNSW: Hierarchical Navigable Small World                 │
│                                                                  │
│  Layer 2 (sparse):    A ───────── D   (long-range connections)   │
│                       │           │                              │
│  Layer 1 (medium):    A ───── C ─ D ───── F                      │
│                       │       │   │       │                      │
│  Layer 0 (dense):     A ─ B ─ C ─ D ─ E ─ F ─ G ─ H              │
│                       (each node has ≤ M connections)            │
│                                                                  │
│  Search for query Q:                                             │
│  ┌─────────────────────────────────────────────┐                 │
│  │ Layer 2: Start at A → greedy → D (closer)   │                 │
│  │ Layer 1: Start at D → greedy → C (closer)   │                 │
│  │ Layer 0: Start at C → expand ef candidates  │                 │
│  │          → visit B, D, E → return top-k     │                 │
│  └─────────────────────────────────────────────┘                 │
│                                                                  │
│  Complexity: ~O(log n) layers × O(ef * M * d) per layer          │
│  vs. brute force: O(n * d)                                       │
└──────────────────────────────────────────────────────────────────┘
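The search trace in the diagram can be reproduced with a small sketch. It makes these assumptions:

- The eight nodes A to H sit at hypothetical positions 0 to 7 on a line.
- Each layer uses the edges drawn in the diagram.
- The query is at 2.2, a made-up value near C.

For brevity, every layer uses plain greedy search (ef = 1). Real HNSW switches to a beam of size ef on layer 0.

```python
from typing import Dict, List, Set

# Hypothetical 1-D coordinates for the diagram's nodes A..H.
POS: Dict[str, float] = {name: float(i) for i, name in enumerate("ABCDEFGH")}


def undirected(edges: List[str]) -> Dict[str, Set[str]]:
    """Build an adjacency map from two-letter edge strings such as 'AD'."""
    graph: Dict[str, Set[str]] = {}
    for a, b in edges:
        graph.setdefault(a, set()).add(b)
        graph.setdefault(b, set()).add(a)
    return graph


# Layers from the diagram, top (sparse) to bottom (dense).
LAYERS = [
    undirected(["AD"]),                                      # layer 2
    undirected(["AC", "CD", "DF"]),                          # layer 1
    undirected(["AB", "BC", "CD", "DE", "EF", "FG", "GH"]),  # layer 0
]


def greedy(graph: Dict[str, Set[str]], start: str, query: float) -> str:
    """Move to the closest neighbor until no neighbor is closer to the query."""
    current = start
    while True:
        best = min(
            graph.get(current, set()),
            key=lambda n: abs(POS[n] - query),
            default=current,
        )
        if abs(POS[best] - query) >= abs(POS[current] - query):
            return current
        current = best


def descend(query: float, entry: str = "A") -> List[str]:
    """Return the node where greedy search ends on each layer, top to bottom."""
    trace = []
    node = entry
    for graph in LAYERS:
        node = greedy(graph, node, query)  # the endpoint seeds the next, denser layer
        trace.append(node)
    return trace


if __name__ == "__main__":
    print(descend(2.2))  # ['D', 'C', 'C']
```

Each layer's endpoint becomes the next layer's starting node. This is the "Start at D", "Start at C" handoff shown in the diagram.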
Warning

The M parameter (max connections per node) directly controls the recall-speed-memory tradeoff. Setting M too low (e.g., 4) creates a poorly connected graph where search gets stuck in local minima. Setting M too high (e.g., 128) wastes memory and slows construction without proportional recall gains. Typical production values are M=16-64. Always benchmark on your specific data distribution.
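A back-of-the-envelope sketch shows how M drives memory. It assumes:

- float32 vectors;
- 4-byte neighbor ids;
- exactly M stored links per node.

It ignores upper layers and allocator overhead. Real HNSW implementations such as hnswlib allow 2M links on layer 0, so treat these numbers as lower bounds. The function name is hypothetical.

```python
def index_memory_mb(n: int, d: int, m: int) -> dict:
    """Rough memory estimate (MB) for vectors plus a single layer of M links per node."""
    vector_bytes = n * d * 4  # float32 components
    link_bytes = n * m * 4    # int32 neighbor ids
    mb = 1024 ** 2
    return {"vectors_mb": vector_bytes / mb, "links_mb": link_bytes / mb}


if __name__ == "__main__":
    for m in (4, 16, 64, 128):
        est = index_memory_mb(n=1_000_000, d=128, m=m)
        print(f"M={m:>3}: vectors {est['vectors_mb']:.0f} MB, links {est['links_mb']:.0f} MB")
```

At M = 128 with d = 128, the link table alone is as large as the vectors themselves.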

Hints

Info
  1. Brute force: compute all distances, use np.argpartition for efficient top-k.
  2. For the NSW graph: connect each new node to its M nearest existing neighbors during construction.
  3. Search: start at a random entry point, greedily move to the neighbor closest to the query, maintain a dynamic candidate list.
  4. Use a min-heap of candidates to expand next, plus a max-heap (negated distances with heapq) holding the best results found so far, so the worst kept result can be evicted cheaply.
  5. The ef (exploration factor) parameter controls the tradeoff between recall and speed.
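Hints 1 and 4 describe two ways to keep the k smallest distances. Below is a minimal sketch of the heap version from hint 4. It assumes distances arrive one at a time as (index, distance) pairs. The function name `top_k_smallest` is illustrative.

```python
import heapq
from typing import Iterable, List, Tuple


def top_k_smallest(items: Iterable[Tuple[int, float]], k: int) -> List[Tuple[float, int]]:
    """Keep the k smallest distances seen so far using a size-k max-heap."""
    if k <= 0:
        return []
    heap: List[Tuple[float, int]] = []  # (-distance, index); heap[0] is the worst kept item
    for idx, dist in items:
        if len(heap) < k:
            heapq.heappush(heap, (-dist, idx))
        elif dist < -heap[0][0]:
            heapq.heapreplace(heap, (-dist, idx))  # pop the worst, push the new item
    return sorted((-neg_dist, idx) for neg_dist, idx in heap)


if __name__ == "__main__":
    dists = [5.0, 1.0, 4.0, 2.0, 3.0]
    print(top_k_smallest(enumerate(dists), k=2))  # [(1.0, 1), (2.0, 3)]
```

The heap suits streaming search, where distances are discovered one at a time. np.argpartition suits brute force, where all distances are already in one array.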

Solution

import numpy as np
import heapq
from typing import List, Tuple, Set
from collections import defaultdict


def brute_force_knn(
    query: np.ndarray, database: np.ndarray, k: int
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Exact k-nearest neighbor search.

    Returns:
        indices: (k,) indices of nearest neighbors
        distances: (k,) L2 distances to nearest neighbors
    """
    # Compute all distances efficiently
    diffs = database - query[None, :]       # (n, d)
    distances = np.sum(diffs ** 2, axis=1)  # (n,) squared L2

    # Use argpartition for O(n) top-k selection (faster than full sort)
    if k >= len(distances):
        top_k_idx = np.argsort(distances)
    else:
        top_k_idx = np.argpartition(distances, k)[:k]
        # Sort the top-k for consistent ordering
        top_k_idx = top_k_idx[np.argsort(distances[top_k_idx])]

    return top_k_idx, np.sqrt(distances[top_k_idx])


class NSWIndex:
    """Navigable Small World graph for approximate nearest neighbor search."""

    def __init__(self, dim: int, M: int = 16, ef_construction: int = 200) -> None:
        """
        Args:
            dim: Vector dimensionality.
            M: Max number of connections per node.
            ef_construction: Size of candidate list during construction.
        """
        self.dim = dim
        self.M = M
        self.ef_construction = ef_construction
        self.vectors: List[np.ndarray] = []
        self.graph: defaultdict = defaultdict(set)  # adjacency list
        self.entry_point: int = 0

    def _distance(self, a: np.ndarray, b: np.ndarray) -> float:
        """Squared L2 distance."""
        return float(np.sum((a - b) ** 2))

    def _search_layer(
        self, query: np.ndarray, entry_points: List[int], ef: int
    ) -> List[Tuple[float, int]]:
        """
        Greedy search on the NSW graph.

        Returns:
            List of (distance, index) tuples, sorted by distance.
        """
        # Candidates: min-heap (closest first)
        candidates = []
        # Results: max-heap (farthest first, so we can prune)
        results = []
        visited: Set[int] = set()

        for ep in entry_points:
            dist = self._distance(query, self.vectors[ep])
            heapq.heappush(candidates, (dist, ep))
            heapq.heappush(results, (-dist, ep))  # negative for max-heap
            visited.add(ep)

        while candidates:
            c_dist, c_idx = heapq.heappop(candidates)

            # If the closest candidate is farther than the farthest result, stop
            f_dist = -results[0][0]
            if c_dist > f_dist and len(results) >= ef:
                break

            # Explore neighbors
            for neighbor in self.graph[c_idx]:
                if neighbor in visited:
                    continue
                visited.add(neighbor)

                n_dist = self._distance(query, self.vectors[neighbor])

                # Add to results if better than the worst result, or if we need more
                if len(results) < ef or n_dist < f_dist:
                    heapq.heappush(candidates, (n_dist, neighbor))
                    heapq.heappush(results, (-n_dist, neighbor))
                    if len(results) > ef:
                        heapq.heappop(results)  # remove farthest
                    f_dist = -results[0][0]

        # Convert max-heap to sorted list
        return sorted([(-d, idx) for d, idx in results])

    def add(self, vector: np.ndarray) -> None:
        """Add a vector to the index."""
        idx = len(self.vectors)
        self.vectors.append(vector.copy())

        if idx == 0:
            return  # First node, no neighbors to connect

        # Find nearest neighbors using graph search
        neighbors = self._search_layer(vector, [self.entry_point], self.ef_construction)

        # Connect to top-M nearest neighbors
        for dist, neighbor_idx in neighbors[: self.M]:
            self.graph[idx].add(neighbor_idx)
            self.graph[neighbor_idx].add(idx)

            # Prune if the neighbor now exceeds M connections. Only the neighbor's
            # list is trimmed, so some edges become one-directional.
            if len(self.graph[neighbor_idx]) > self.M:
                # Keep only the M closest connections
                dists = [
                    (self._distance(self.vectors[neighbor_idx], self.vectors[n]), n)
                    for n in self.graph[neighbor_idx]
                ]
                dists.sort()
                self.graph[neighbor_idx] = set(n for _, n in dists[: self.M])

        # Periodically move the entry point to the newest node
        # (a simple heuristic; it does not search for the centroid)
        if idx % 100 == 0:
            self.entry_point = idx

    def search(
        self, query: np.ndarray, k: int, ef: int = 50
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Search for k approximate nearest neighbors.

        Args:
            query: (d,) query vector.
            k: Number of neighbors.
            ef: Exploration factor (higher = better recall, slower).

        Returns:
            indices: (k,) neighbor indices.
            distances: (k,) L2 distances.
        """
        if len(self.vectors) == 0:
            return np.array([], dtype=int), np.array([], dtype=float)

        results = self._search_layer(query, [self.entry_point], max(ef, k))
        results = results[:k]

        indices = np.array([idx for _, idx in results])
        distances = np.array([np.sqrt(d) for d, _ in results])
        return indices, distances


# ---------- demo ----------
import time

if __name__ == "__main__":
    np.random.seed(42)
    n, d, k = 5000, 64, 10

    # Build database
    database = np.random.randn(n, d).astype(np.float32)
    queries = np.random.randn(5, d).astype(np.float32)

    # Build NSW index (pure Python, so this takes a while for n = 5000)
    t0 = time.perf_counter()
    index = NSWIndex(dim=d, M=16, ef_construction=100)
    for i in range(n):
        index.add(database[i])
    print(f"Index built with {len(index.vectors)} vectors in {time.perf_counter() - t0:.1f}s")

    # Compare exact vs approximate: recall and per-query latency
    for q_idx, query in enumerate(queries):
        t0 = time.perf_counter()
        exact_idx, exact_dist = brute_force_knn(query, database, k)
        t_exact = time.perf_counter() - t0

        t0 = time.perf_counter()
        approx_idx, approx_dist = index.search(query, k, ef=100)
        t_approx = time.perf_counter() - t0

        # Compute recall: fraction of true neighbors found
        recall = len(set(exact_idx) & set(approx_idx)) / k
        print(
            f"Query {q_idx}: recall@{k} = {recall:.2f}, "
            f"exact {t_exact * 1e3:.2f} ms, approx {t_approx * 1e3:.2f} ms"
        )
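The demo measures recall on five queries and calls brute_force_knn once per query. For larger evaluations, a vectorized variant may help. It computes ground truth for a whole batch of queries at once using the identity ||q - x||² = ||q||² - 2q·x + ||x||², so most of the work is one matrix multiply. This is a sketch; the name `batch_brute_force_knn` is hypothetical. Its memory grows as (number of queries) × n, so very large batches should be processed in chunks.

```python
import numpy as np


def batch_brute_force_knn(queries: np.ndarray, database: np.ndarray, k: int):
    """Exact k-NN for a batch of queries; returns (m, k) indices and L2 distances, sorted."""
    if k < 1:
        raise ValueError("k must be at least 1")
    q = np.asarray(queries, dtype=np.float64)  # float64 limits cancellation error
    x = np.asarray(database, dtype=np.float64)
    k = min(k, x.shape[0])
    sq = (q ** 2).sum(axis=1, keepdims=True) - 2.0 * q @ x.T + (x ** 2).sum(axis=1)[None, :]
    np.maximum(sq, 0.0, out=sq)  # clamp tiny negatives caused by rounding
    idx = np.argpartition(sq, k - 1, axis=1)[:, :k]  # unordered k smallest per row
    order = np.argsort(np.take_along_axis(sq, idx, axis=1), axis=1)
    idx = np.take_along_axis(idx, order, axis=1)  # sort each row's winners by distance
    return idx, np.sqrt(np.take_along_axis(sq, idx, axis=1))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    database = rng.standard_normal((5000, 64)).astype(np.float32)
    queries = rng.standard_normal((100, 64)).astype(np.float32)
    idx, dist = batch_brute_force_knn(queries, database, k=10)
    print(idx.shape, dist.shape)  # (100, 10) (100, 10)
```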

Walkthrough

  1. Brute force baseline -- Compute the distance from the query to every database vector in O(n * d). argpartition selects the k smallest in O(n) time (vs. O(n log n) for a full sort), and only those k are then sorted, in O(k log k). This exact result is the ground truth for measuring recall.

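  2. Graph construction -- Each new vector is connected to its M nearest existing neighbors. These neighbors are found by a graph search (with ef_construction candidates) over the partially built index, so the neighbor lists are approximate rather than exact. Nodes inserted early, while the graph is still small, tend to receive long-range edges that later act as shortcuts for navigation.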

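  3. Greedy (best-first) search -- Starting from an entry point, we repeatedly pop the closest unexpanded candidate and score its unvisited neighbors. Unlike pure greedy descent, we keep the ef best results found so far. This lets the search pass through nodes that are not immediate improvements and escape some local minima. The search stops once the result list is full and the closest remaining candidate is farther than the farthest of the ef results.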

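  4. Pruning -- When a node exceeds M connections, we keep only its M closest neighbors. This bounds memory and per-hop cost. However, it also makes some edges one-directional and can drop useful long-range links. HNSW's neighbor-selection heuristic prefers spatially diverse neighbors partly for this reason.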

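  5. ef parameter -- The exploration factor ef controls recall vs. speed. A higher ef keeps more candidates alive, improving recall at the cost of more distance computations. With ef >= n, the stopping rule never fires and every node reachable from the entry point is scored. The search is then exact, as long as pruning has not left any node unreachable.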
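Steps 3 and 5 can be seen on a tiny hand-built graph where plain greedy search gets stuck. The sketch assumes 1-D points and made-up node names and positions. It mirrors the stopping and admission rules of `_search_layer`:

- From entry E, neighbor A looks closest to the query but is a dead end.
- The true nearest neighbor C is reachable only through B, which at first looks worse than E.

The function `beam_search` is illustrative.

```python
import heapq
from typing import Dict, List, Set, Tuple

# Hypothetical 1-D positions and edges, chosen so greedy search hits a local minimum.
POS: Dict[str, float] = {"E": 0.0, "A": 3.0, "B": -1.0, "C": 9.9}
GRAPH: Dict[str, Set[str]] = {"E": {"A", "B"}, "A": {"E"}, "B": {"E", "C"}, "C": {"B"}}


def beam_search(query: float, entry: str, ef: int) -> List[Tuple[float, str]]:
    """Best-first search keeping the ef closest nodes found (same rules as _search_layer)."""

    def dist(node: str) -> float:
        return abs(POS[node] - query)

    candidates = [(dist(entry), entry)]  # min-heap: closest unexpanded node first
    results = [(-dist(entry), entry)]    # max-heap via negation: worst kept result first
    visited = {entry}
    while candidates:
        c_dist, c_node = heapq.heappop(candidates)
        f_dist = -results[0][0]
        if c_dist > f_dist and len(results) >= ef:
            break  # closest remaining candidate is farther than the worst kept result
        for nb in GRAPH[c_node]:
            if nb in visited:
                continue
            visited.add(nb)
            n_dist = dist(nb)
            if len(results) < ef or n_dist < f_dist:
                heapq.heappush(candidates, (n_dist, nb))
                heapq.heappush(results, (-n_dist, nb))
                if len(results) > ef:
                    heapq.heappop(results)  # drop the farthest result
                f_dist = -results[0][0]
    return sorted((-neg, node) for neg, node in results)


if __name__ == "__main__":
    for ef in (1, 2, 3):
        best_dist, best_node = beam_search(10.0, "E", ef)[0]
        print(f"ef={ef}: nearest found = {best_node} (distance {best_dist:.1f})")
```

With ef = 3, the search keeps B in its result list long enough to expand it, which leads to C. This is how a larger ef trades extra distance computations for recall.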

Complexity Analysis

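  • Index construction: roughly O(n * ef_construction * M * d). Each of the n insertions runs a graph search that expands on the order of ef_construction nodes and scores up to M neighbors per expanded node, at O(d) per distance. Pruning adds up to O(M^2 * d) per insertion.
  • Query time: roughly O(ef * M * d) per query, assuming the search expands on the order of ef nodes. This is not truly independent of n, because the number of hops grows slowly with n. HNSW's hierarchy exists to keep that growth logarithmic.
  • Memory: O(n * (d + M)) -- storing vectors and adjacency lists.
  • Recall: Well-tuned, optimized implementations typically reach 95-99% recall, often at 10-100x speedups over brute force on large datasets. This pure-Python version may be slower than NumPy's vectorized brute force at n = 5000, so benchmark before drawing conclusions.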
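A rough cost model makes these bullets concrete. It counts only distance arithmetic and assumes each query expands about ef nodes with M neighbors each. This is an estimate in the spirit of the bullets above, not a measurement, and the sizes below are illustrative.

```python
def query_cost(n: int, d: int, ef: int, m: int) -> dict:
    """Approximate floating-point work per query: brute force vs. graph search."""
    brute = n * d       # one d-dimensional distance per database vector
    graph = ef * m * d  # ~ef expanded nodes, each scoring up to M neighbors
    return {"brute_force": brute, "graph": graph, "ratio": brute / graph}


if __name__ == "__main__":
    for n in (10_000, 1_000_000, 100_000_000):
        c = query_cost(n=n, d=128, ef=100, m=16)
        print(f"n={n:>11,}: brute {c['brute_force']:.2e}, graph {c['graph']:.2e}, ratio {c['ratio']:.0f}x")
```

The ratio grows linearly with n because this model treats graph cost as fixed. In practice the hop count also grows slowly with n, so the real gap is somewhat smaller.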

Interview Tips


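Key discussion points:

1. Why graphs beat tree-based methods (KD-trees) in high dimensions. As d grows, tree pruning stops eliminating branches and the search degenerates toward brute force; a common rule of thumb puts this beyond roughly 20 dimensions. Graph search remains effective.
2. HNSW vs. NSW. HNSW adds a hierarchy of progressively sparser layers, analogous to a skip list. The coarse upper layers pick a good entry point for the dense bottom layer, so search scales roughly logarithmically.
3. The tradeoff triangle: recall, latency, memory.
4. Quantization. Replacing float32 components with int8 cuts memory 4x, and 1-bit binary codes cut it 32x.
5. Real-world systems. Pinecone, Weaviate, and FAISS are commonly cited examples. FAISS, for instance, implements both HNSW and IVF+PQ indexes, and HNSW is Weaviate's default vector index.
6. Product quantization (PQ) compresses vectors so that search over billions of vectors fits in memory.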
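The 4x and 32x figures in point 4 are easy to check numerically. Here is a minimal sketch. It assumes simple per-dimension min/max scalar quantization to int8 and sign-based binary codes packed with np.packbits. These are illustrative schemes, not the exact method of any particular database, and the helper names are hypothetical.

```python
import numpy as np


def quantize_int8(x: np.ndarray):
    """Per-dimension min/max scalar quantization of float32 vectors to int8."""
    lo = x.min(axis=0)
    scale = (x.max(axis=0) - lo) / 255.0
    scale[scale == 0] = 1.0                   # avoid division by zero for constant dims
    codes = np.round((x - lo) / scale) - 128  # integers in [-128, 127]
    return codes.astype(np.int8), lo, scale


def dequantize_int8(codes: np.ndarray, lo: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Map int8 codes back to approximate float32 values."""
    return (codes.astype(np.float32) + 128) * scale + lo


def binarize(x: np.ndarray) -> np.ndarray:
    """1 bit per dimension (sign of each component), 8 dimensions per byte."""
    return np.packbits(x > 0, axis=1)


def hamming(a: np.ndarray, b: np.ndarray) -> int:
    """Number of differing bits between two packed binary codes."""
    return int(np.unpackbits(np.bitwise_xor(a, b)).sum())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((10_000, 128)).astype(np.float32)
    codes, lo, scale = quantize_int8(x)
    bits = binarize(x)
    print("int8 ratio:", x.nbytes / codes.nbytes)   # 4.0
    print("binary ratio:", x.nbytes / bits.nbytes)  # 32.0
    err = np.abs(dequantize_int8(codes, lo, scale) - x).max()
    print("max int8 reconstruction error:", err)
```

The int8 error per component is bounded by half a quantization step (scale / 2). Binary codes keep only the sign pattern and are compared with Hamming distance. This is a much coarser approximation, which is why production systems often use binary codes for a first pass and then re-rank the survivors.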

Quiz


What does the 'ef' (exploration factor) parameter control in HNSW search?

Why do tree-based ANN methods (like KD-trees) degrade in high dimensions while graph-based methods (HNSW) remain effective?

What is product quantization (PQ) and why is it used alongside HNSW in production systems?
