Algorithm: Approximate Nearest Neighbor Search
Implement approximate nearest neighbor (ANN) search using a simplified HNSW-inspired approach, along with exact brute-force search for comparison. ANN is the backbone of vector databases, semantic search, and retrieval-augmented generation.
Problem Statement
Implement:
- brute_force_knn(query, database, k) -- exact k-nearest neighbor search
- NSWIndex -- an approximate nearest neighbor index using a navigable small world (NSW) graph
- Compare recall and speed of approximate vs. exact search
Inputs: Query vector of shape (d,), database of shape (n, d), number of neighbors k.
Outputs: Indices and distances of the k (approximate) nearest neighbors.
HNSW (Hierarchical Navigable Small World) builds a multi-layer graph where each node is connected to its approximate neighbors. Search starts at a coarse top layer and greedily navigates to finer layers. The key insight: random long-range edges (like a small-world network) allow O(log n) search even in high dimensions, dramatically beating O(n) brute force.
┌──────────────────────────────────────────────────────────────────┐
│ HNSW: Hierarchical Navigable Small World │
│ │
│ Layer 2 (sparse): A ──────────── D │
│ │ │ │
│ │ (long-range │ connections) │
│ │ │ │
│ Layer 1 (medium): A ─── C ────── D ─── F │
│ │ │ │ │ │
│ │ │ │ │ │
│ Layer 0 (dense): A ─ B ─ C ─ D ─ E ─ F ─ G ─ H │
│ │ ╲ │ ╱ │ ╲ │ ╱ │ ╲ │ ╱ │ ╲ │ │
│ └───┴───┴───┴───┴───┴───┴───┘ │
│ (each node has ≤ M connections) │
│ │
│ Search for query Q: │
│ ┌─────────────────────────────────────────────┐ │
│ │ Layer 2: Start at A → greedy → D (closer) │ │
│ │ Layer 1: Start at D → greedy → C (closer) │ │
│ │ Layer 0: Start at C → expand with ef cands │ │
│ │ → visit B, D, E → return top-k │ │
│ └─────────────────────────────────────────────┘ │
│ │
│ Complexity: O(log n) layers × O(ef * M) per layer │
│ vs. brute force: O(n * d) │
└──────────────────────────────────────────────────────────────────┘
The M parameter (max connections per node) directly controls the recall-speed-memory tradeoff. Setting M too low (e.g., 4) creates a poorly connected graph where search gets stuck in local minima. Setting M too high (e.g., 128) wastes memory and slows construction without proportional recall gains. Typical production values are M=16-64. Always benchmark on your specific data distribution.
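The small-world claim above can be felt in a toy, dimension-free setting: a ring where each node also gets one random long-range shortcut. The sketch below (hypothetical, not the HNSW algorithm itself) routes greedily toward a target and typically takes far fewer hops than the n/2 a pure ring would need:

```python
import random

# Toy small-world ring: each node links to its two ring neighbors plus one
# random long-range shortcut, then we route greedily toward a target.
random.seed(0)
n = 1024
graph = {i: {(i - 1) % n, (i + 1) % n, random.randrange(n)} for i in range(n)}

def ring_dist(a, b):
    return min((a - b) % n, (b - a) % n)

def greedy_route(start, target):
    cur, hops = start, 0
    while cur != target:
        # Move to the neighbor closest to the target. Progress is guaranteed
        # because one of the +-1 ring neighbors is always strictly closer.
        cur = min(graph[cur], key=lambda v: ring_dist(v, target))
        hops += 1
    return hops

print(greedy_route(0, n // 2))  # usually far fewer than n // 2 hops
```

Shortcuts let greedy routing skip large arcs of the ring; the same intuition carries over to high-dimensional vectors, where the "ring" is replaced by proximity in L2 distance.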
Hints
- Brute force: compute all distances, use np.argpartition for efficient top-k.
- For the NSW graph: connect each new node to its M nearest existing neighbors during construction.
- Search: start at an entry point, greedily move to the neighbor closest to the query, maintaining a dynamic candidate list.
- Use a max-heap (negative distances with heapq) to maintain the top-k candidates during search.
- The ef (exploration factor) parameter controls the tradeoff between recall and speed.
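The max-heap hint deserves a concrete sketch, since Python's heapq only provides a min-heap. Negating distances makes the heap root the farthest kept candidate, so pruning is O(log k) (helper name is illustrative):

```python
import heapq

def top_k_push(heap, dist, idx, k):
    # Max-heap via negated distances: heap[0] holds the FARTHEST kept candidate,
    # so we can evict the worst element in O(log k) when a better one arrives.
    if len(heap) < k:
        heapq.heappush(heap, (-dist, idx))
    elif dist < -heap[0][0]:
        heapq.heapreplace(heap, (-dist, idx))

heap = []
for idx, dist in enumerate([0.5, 0.2, 0.9, 0.1]):
    top_k_push(heap, dist, idx, k=3)

kept = sorted(-d for d, _ in heap)
print(kept)  # [0.1, 0.2, 0.5]
```

This is exactly the pattern the `results` heap uses in the solution below.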
Solution
import numpy as np
import heapq
from typing import List, Tuple, Set
from collections import defaultdict
def brute_force_knn(
query: np.ndarray, database: np.ndarray, k: int
) -> Tuple[np.ndarray, np.ndarray]:
"""
Exact k-nearest neighbor search.
Returns:
indices: (k,) indices of nearest neighbors
distances: (k,) L2 distances to nearest neighbors
"""
# Compute all distances efficiently
diffs = database - query[None, :] # (n, d)
distances = np.sum(diffs ** 2, axis=1) # (n,) squared L2
# Use argpartition for O(n) top-k selection (faster than full sort)
if k >= len(distances):
top_k_idx = np.argsort(distances)
else:
top_k_idx = np.argpartition(distances, k)[:k]
# Sort the top-k for consistent ordering
top_k_idx = top_k_idx[np.argsort(distances[top_k_idx])]
return top_k_idx, np.sqrt(distances[top_k_idx])
class NSWIndex:
"""Navigable Small World graph for approximate nearest neighbor search."""
def __init__(self, dim: int, M: int = 16, ef_construction: int = 200) -> None:
"""
Args:
dim: Vector dimensionality.
M: Max number of connections per node.
ef_construction: Size of candidate list during construction.
"""
self.dim = dim
self.M = M
self.ef_construction = ef_construction
self.vectors: List[np.ndarray] = []
self.graph: defaultdict = defaultdict(set) # adjacency list
self.entry_point: int = 0
def _distance(self, a: np.ndarray, b: np.ndarray) -> float:
"""Squared L2 distance."""
return float(np.sum((a - b) ** 2))
def _search_layer(
self, query: np.ndarray, entry_points: List[int], ef: int
) -> List[Tuple[float, int]]:
"""
Greedy search on the NSW graph.
Returns:
List of (distance, index) tuples, sorted by distance.
"""
# Candidates: min-heap (closest first)
candidates = []
# Results: max-heap (farthest first, so we can prune)
results = []
visited: Set[int] = set()
for ep in entry_points:
dist = self._distance(query, self.vectors[ep])
heapq.heappush(candidates, (dist, ep))
heapq.heappush(results, (-dist, ep)) # negative for max-heap
visited.add(ep)
while candidates:
c_dist, c_idx = heapq.heappop(candidates)
# If the closest candidate is farther than the farthest result, stop
f_dist = -results[0][0]
if c_dist > f_dist and len(results) >= ef:
break
# Explore neighbors
for neighbor in self.graph[c_idx]:
if neighbor in visited:
continue
visited.add(neighbor)
n_dist = self._distance(query, self.vectors[neighbor])
# Add to results if better than the worst result, or if we need more
if len(results) < ef or n_dist < f_dist:
heapq.heappush(candidates, (n_dist, neighbor))
heapq.heappush(results, (-n_dist, neighbor))
if len(results) > ef:
heapq.heappop(results) # remove farthest
f_dist = -results[0][0]
# Convert max-heap to sorted list
return sorted([(-d, idx) for d, idx in results])
def add(self, vector: np.ndarray) -> None:
"""Add a vector to the index."""
idx = len(self.vectors)
self.vectors.append(vector.copy())
if idx == 0:
return # First node, no neighbors to connect
# Find nearest neighbors using graph search
neighbors = self._search_layer(vector, [self.entry_point], self.ef_construction)
# Connect to top-M nearest neighbors
for dist, neighbor_idx in neighbors[: self.M]:
self.graph[idx].add(neighbor_idx)
self.graph[neighbor_idx].add(idx)
# Prune if neighbor has too many connections
if len(self.graph[neighbor_idx]) > self.M:
# Keep only the M closest connections
dists = [
(self._distance(self.vectors[neighbor_idx], self.vectors[n]), n)
for n in self.graph[neighbor_idx]
]
dists.sort()
self.graph[neighbor_idx] = set(n for _, n in dists[: self.M])
        # Periodically refresh the entry point to a recently added node so
        # searches start from a well-connected region of the graph (heuristic)
        if idx % 100 == 0:
            self.entry_point = idx
def search(
self, query: np.ndarray, k: int, ef: int = 50
) -> Tuple[np.ndarray, np.ndarray]:
"""
Search for k approximate nearest neighbors.
Args:
query: (d,) query vector.
k: Number of neighbors.
ef: Exploration factor (higher = better recall, slower).
Returns:
indices: (k,) neighbor indices.
distances: (k,) L2 distances.
"""
if len(self.vectors) == 0:
return np.array([]), np.array([])
results = self._search_layer(query, [self.entry_point], max(ef, k))
results = results[:k]
indices = np.array([idx for _, idx in results])
distances = np.array([np.sqrt(d) for d, _ in results])
return indices, distances
# ---------- demo ----------
if __name__ == "__main__":
np.random.seed(42)
n, d, k = 5000, 64, 10
# Build database
database = np.random.randn(n, d).astype(np.float32)
queries = np.random.randn(5, d).astype(np.float32)
# Build NSW index
index = NSWIndex(dim=d, M=16, ef_construction=100)
for i in range(n):
index.add(database[i])
print(f"Index built with {len(index.vectors)} vectors")
# Compare exact vs approximate
for q_idx, query in enumerate(queries):
exact_idx, exact_dist = brute_force_knn(query, database, k)
approx_idx, approx_dist = index.search(query, k, ef=100)
# Compute recall: fraction of true neighbors found
recall = len(set(exact_idx) & set(approx_idx)) / k
print(f"Query {q_idx}: recall@{k} = {recall:.2f}")
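The recall computation in the demo can be factored into a small standalone helper (a sketch; the function name is ours, not part of the solution API):

```python
def recall_at_k(exact_idx, approx_idx):
    """Fraction of the true k nearest neighbors recovered by the approximate search."""
    exact, approx = set(map(int, exact_idx)), set(map(int, approx_idx))
    return len(exact & approx) / len(exact)

print(recall_at_k([1, 2, 3, 4], [1, 2, 3, 9]))  # 0.75
```

Note that recall compares index sets only: an approximate method can return neighbors in a slightly different order, and recall@k is unaffected as long as the same k points are found.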
Walkthrough
- Brute force baseline -- Compute all pairwise distances in O(n * d). argpartition selects the top-k in O(n) expected time (vs. O(n log n) for a full sort). This is the gold standard for recall.
- Graph construction -- Each new vector is connected to its M nearest existing neighbors. This creates a navigable graph where nearby vectors are linked. The greedy search during construction itself uses the graph, bootstrapping quality.
- Greedy search -- Starting from an entry point, we explore the graph by always following edges toward vectors closer to the query. The candidate list (size ef) allows backtracking: if the closest candidate is farther than the farthest result, we stop.
- Pruning -- When a node exceeds M connections, we keep only the closest M. This bounds memory and keeps the graph navigable.
- ef parameter -- The exploration factor ef controls recall vs. speed. Higher ef means more candidates are explored, improving recall at the cost of more distance computations. At ef = n, this degenerates to exact search.
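The pruning rule can be isolated as a tiny self-contained sketch (function name and setup are illustrative; the solution performs the same computation inline in `add`):

```python
import numpy as np

def prune_neighbors(node, neighbors, vectors, M):
    """Keep only the M connections closest to `node` (distance-based pruning)."""
    dists = sorted(
        (float(np.sum((vectors[node] - vectors[n]) ** 2)), n) for n in neighbors
    )
    return {n for _, n in dists[:M]}

# Points on a line at positions 0, 1, 2, 10: node 0 keeps its two closest links.
vectors = np.array([[0.0], [1.0], [2.0], [10.0]])
print(prune_neighbors(0, {1, 2, 3}, vectors, M=2))  # {1, 2}
```

Full HNSW uses a more sophisticated diversity-aware heuristic (it avoids keeping neighbors that are closer to each other than to the node), but distance-only pruning is the simplification used here.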
Complexity Analysis
- Index construction: O(n * ef_construction * M * d) -- for each of n vectors, search the graph (ef_construction steps) and compute d-dimensional distances.
- Query time: O(ef * M * d) per query. Independent of n for fixed graph properties.
- Memory: O(n * (d + M)) -- storing vectors and adjacency lists.
- Recall: Typically 95-99% with proper tuning, at 10-100x speedup over brute force.
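Plugging concrete numbers into the O(n * (d + M)) memory bound makes the tradeoff tangible. For a hypothetical one-million-vector, 128-dimensional index with M = 16 (our numbers, not from the text above):

```python
# Back-of-envelope memory for a 1M-vector, 128-dim index with M = 16.
n, d, M = 1_000_000, 128, 16
vector_bytes = n * d * 4  # float32 vectors: 4 bytes per component
graph_bytes = n * M * 4   # adjacency lists as int32 neighbor ids
print(f"vectors: {vector_bytes / 1e9:.2f} GB, graph: {graph_bytes / 1e9:.3f} GB")
```

The vectors themselves dominate: the graph adds only ~12% overhead here, which is why quantizing the vectors (next section) is the usual lever for shrinking an index.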
Interview Tips
Key discussion points: (1) Why graphs beat tree-based methods (KD-trees) in high dimensions -- trees degenerate to brute force when d > ~20, but graph search remains effective. (2) HNSW vs. NSW: HNSW adds hierarchy (like skip lists) for O(log n) entry point selection. (3) The tradeoff triangle: recall, latency, memory. (4) Quantization: reducing from float32 to int8 or binary codes cuts memory 4-32x. (5) Real-world systems: Pinecone, Weaviate, FAISS all use HNSW or IVF+PQ as their core algorithm. (6) Product quantization (PQ) compresses vectors for huge-scale (billions of vectors) search.
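Point (4) on quantization can be sketched with simple symmetric scalar quantization to int8, which already cuts memory 4x. This is a simplification of what production systems do; product quantization goes further by splitting each vector into subvectors and encoding each with a learned codebook:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 64)).astype(np.float32)

# Symmetric scalar quantization: one global scale maps float32 -> int8.
scale = float(np.abs(x).max()) / 127.0
q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
x_hat = q.astype(np.float32) * scale  # dequantize for distance computations

print(f"memory: {x.nbytes} -> {q.nbytes} bytes (4x smaller)")
mse = float(np.mean((x - x_hat) ** 2))
```

Distances computed on `x_hat` are slightly perturbed, so quantized indexes trade a small recall loss for the memory savings; per-dimension or per-block scales reduce that loss further.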
Quiz -- 3 Questions
What does the 'ef' (exploration factor) parameter control in HNSW search?
Why do tree-based ANN methods (like KD-trees) degrade in high dimensions while graph-based methods (HNSW) remain effective?
What is product quantization (PQ) and why is it used alongside HNSW in production systems?