Algorithm: K-Means Clustering
Implement Lloyd's algorithm for K-Means clustering from scratch. K-Means is one of the most widely used clustering algorithms and a staple ML interview question.
Problem Statement
Implement a KMeans class that:
- Initializes centroids using k-means++ initialization
- Alternates between assignment (assign each point to nearest centroid) and update (recompute centroids as cluster means)
- Converges when centroids stop moving (or max iterations reached)
- Returns cluster assignments and final centroids
Inputs: Data matrix X of shape (n_samples, n_features), number of clusters k.
Outputs: Cluster assignments of shape (n_samples,), centroids of shape (k, n_features).
K-Means minimizes the within-cluster sum of squares (inertia): sum_i ||x_i - mu_{c_i}||^2. Lloyd's algorithm alternates between two steps: (1) assign each point to the nearest centroid, (2) update each centroid as the mean of its assigned points. This monotonically decreases the objective and converges to a local minimum.
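The two alternating steps can be sketched on a tiny 1-D dataset (a minimal illustration with hand-picked initial centroids, not the full solution below):

```python
import numpy as np

# Toy data: two obvious groups on the real line.
X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
centroids = np.array([[1.0], [10.0]])  # initial guesses

for _ in range(5):
    # Assignment step: nearest centroid for each point.
    d2 = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
    labels = np.argmin(d2, axis=1)
    # Update step: each centroid becomes the mean of its assigned points.
    centroids = np.array([X[labels == k].mean(axis=0) for k in range(2)])

inertia = np.sum((X - centroids[labels]) ** 2)
print(labels)     # [0 0 0 1 1 1]
print(centroids)  # [[ 2.] [11.]]
print(inertia)    # 2.0 + 2.0 = 4.0
```

On this data the loop is stable after one pass: the centroids settle at the cluster means 2 and 11, and the inertia is the sum of squared deviations within each group.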
┌──────────────────────────────────────────────────────────────────┐
│ K-Means Iteration Cycle (k=3) │
│ │
│ INITIALIZATION (k-means++) │
│ ┌──────────────────────┐ │
│ │ c1 = random point │ │
│ │ c2 = far from c1 │ (prob ∝ distance^2) │
│ │ c3 = far from c1,c2 │ │
│ └──────────┬───────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────┐ ┌──────────────────────┐ │
│ │ ASSIGNMENT STEP │ │ UPDATE STEP │ │
│ │ │ │ │ │
│ │ For each point x_i: │ │ For each cluster k: │ │
│ │ c_i = argmin ||x_i │────▶│ c_k = mean(points │ │
│ │ - c_k||^2 │ │ in cluster k) │ │
│ │ │ │ │ │
│ │ uses: ||x||^2 + │ │ Handle empty clusters│ │
│ │ ||c||^2 - 2*x@c^T │ │ → reinit to random pt│ │
│ └──────────────────────┘ └───────────┬───────────┘ │
│ ▲ │ │
│ │ ┌────────────────────┐ │ │
│ │ │ CONVERGENCE CHECK │ │ │
│ └────┤ max centroid shift │◀──┘ │
│ no │ < tolerance? │ │
│ └────────┬───────────┘ │
│ yes │ │
│ ▼ │
│ ┌────────────────┐ │
│ │ Return labels, │ │
│ │ centroids, │ │
│ │ inertia │ │
│ └────────────────┘ │
└──────────────────────────────────────────────────────────────────┘
K-Means converges to a local minimum, not the global optimum. The final result depends heavily on initialization. Always run K-Means multiple times with different random seeds (e.g., scikit-learn's n_init=10 default) and select the run with the lowest inertia. K-means++ initialization helps but does not eliminate this problem.
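The restart strategy can be sketched with a minimal Lloyd loop (the helper `run_once` is hypothetical; in practice the class from the Solution section, or scikit-learn's n_init, plays this role):

```python
import numpy as np

def run_once(X, k, rng, n_iter=50):
    """One random-init Lloyd run; returns (inertia, centroids)."""
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        centroids = np.array([
            X[labels == j].mean(axis=0) if (labels == j).any() else centroids[j]
            for j in range(k)
        ])
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d2.min(axis=1).sum(), centroids

rng = np.random.RandomState(0)
X = np.vstack([c + rng.randn(50, 2) for c in ([0, 0], [5, 5], [10, 0])])

# n_init-style restarts: keep the run with the lowest inertia.
best_inertia, best_centroids = min(
    (run_once(X, k=3, rng=rng) for _ in range(10)), key=lambda r: r[0]
)
print(best_inertia)
```

Comparing runs by inertia is valid because inertia is exactly the objective K-Means optimizes; the lowest-inertia run is the best local minimum found.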
Hints
- For k-means++ init: choose the first centroid randomly, then choose each subsequent centroid with probability proportional to D(x)^2 (the squared distance to the nearest existing centroid).
- Assignment step: compute pairwise distances between all points and all centroids. Use broadcasting: ||x - c||^2 = ||x||^2 + ||c||^2 - 2*x@c^T.
- Update step: for each cluster, compute the mean of assigned points.
- Check convergence: if the maximum centroid shift is below a threshold, stop.
- Handle empty clusters by re-initializing them to a random data point.
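The vectorized distance identity from the hints can be sanity-checked against a naive double loop on random data:

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(6, 3)  # 6 points in 3-D
C = rng.randn(2, 3)  # 2 centroids

# Vectorized: ||x||^2 + ||c||^2 - 2 * x @ c^T  -> shape (6, 2)
d2_fast = (X ** 2).sum(axis=1, keepdims=True) + (C ** 2).sum(axis=1) - 2 * X @ C.T

# Naive double loop for comparison.
d2_slow = np.array([[np.sum((x - c) ** 2) for c in C] for x in X])

print(np.allclose(d2_fast, d2_slow))  # True
```

The two agree up to floating-point error; the vectorized form replaces an O(n*k) Python loop with one matrix multiply.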
Solution
import numpy as np


class KMeans:
    """K-Means clustering with k-means++ initialization."""

    def __init__(
        self,
        n_clusters: int = 3,
        max_iter: int = 300,
        tol: float = 1e-4,
        random_state: int = 42,
    ) -> None:
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.rng = np.random.RandomState(random_state)
        self.centroids: np.ndarray = np.array([])
        self.labels: np.ndarray = np.array([])
        self.inertia: float = 0.0

    def _init_centroids_pp(self, X: np.ndarray) -> np.ndarray:
        """K-means++ initialization for better starting centroids."""
        n_samples, n_features = X.shape
        centroids = np.empty((self.n_clusters, n_features))
        # Choose the first centroid uniformly at random.
        idx = self.rng.randint(n_samples)
        centroids[0] = X[idx]
        for k in range(1, self.n_clusters):
            # Squared distance from each point to its nearest existing centroid.
            dists = np.min(
                np.sum((X[:, None, :] - centroids[None, :k, :]) ** 2, axis=2),
                axis=1,
            )  # (n_samples,)
            # Choose the next centroid with probability proportional to D(x)^2.
            probs = dists / dists.sum()
            idx = self.rng.choice(n_samples, p=probs)
            centroids[k] = X[idx]
        return centroids

    def _compute_distances(self, X: np.ndarray, centroids: np.ndarray) -> np.ndarray:
        """Compute squared Euclidean distances: (n_samples, n_clusters)."""
        # Efficient identity: ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x . c
        X_sq = np.sum(X ** 2, axis=1, keepdims=True)            # (n, 1)
        C_sq = np.sum(centroids ** 2, axis=1, keepdims=True).T  # (1, k)
        cross = X @ centroids.T                                 # (n, k)
        return X_sq + C_sq - 2 * cross

    def _assign(self, X: np.ndarray) -> np.ndarray:
        """Assign each point to the nearest centroid."""
        distances = self._compute_distances(X, self.centroids)
        return np.argmin(distances, axis=1)

    def _update(self, X: np.ndarray, labels: np.ndarray) -> np.ndarray:
        """Recompute centroids as the mean of assigned points."""
        new_centroids = np.empty_like(self.centroids)
        for k in range(self.n_clusters):
            mask = labels == k
            if mask.sum() > 0:
                new_centroids[k] = X[mask].mean(axis=0)
            else:
                # Handle an empty cluster: reinitialize to a random point.
                new_centroids[k] = X[self.rng.randint(len(X))]
        return new_centroids

    def fit(self, X: np.ndarray) -> "KMeans":
        """Run K-Means clustering."""
        self.centroids = self._init_centroids_pp(X)
        for _ in range(self.max_iter):
            # Assignment step.
            self.labels = self._assign(X)
            # Update step.
            new_centroids = self._update(X, self.labels)
            # Convergence check: maximum centroid displacement.
            shift = np.max(np.sqrt(np.sum((new_centroids - self.centroids) ** 2, axis=1)))
            self.centroids = new_centroids
            if shift < self.tol:
                break
        # Refresh labels against the final centroids, then compute inertia.
        self.labels = self._assign(X)
        distances = self._compute_distances(X, self.centroids)
        self.inertia = float(np.sum(distances[np.arange(len(X)), self.labels]))
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Assign new points to the nearest centroid."""
        return self._assign(X)


# ---------- demo ----------
if __name__ == "__main__":
    np.random.seed(42)
    # Generate 3 well-separated clusters.
    cluster_centers = np.array([[0, 0], [5, 5], [10, 0]])
    X = np.vstack([
        center + np.random.randn(100, 2) * 0.8
        for center in cluster_centers
    ])

    kmeans = KMeans(n_clusters=3)
    kmeans.fit(X)

    print(f"Found centroids:\n{kmeans.centroids}")
    print(f"True centers:\n{cluster_centers}")
    print(f"Inertia: {kmeans.inertia:.2f}")
    print(f"Label counts: {np.bincount(kmeans.labels)}")

    # Verify centroids are close to the true centers (up to permutation).
    for true_c in cluster_centers:
        dists = np.sqrt(np.sum((kmeans.centroids - true_c) ** 2, axis=1))
        print(f"True center {true_c} -> nearest centroid dist: {dists.min():.4f}")
Walkthrough
- K-means++ initialization -- Choosing initial centroids randomly can lead to poor convergence. K-means++ selects centroids that are spread out: each new centroid is chosen with probability proportional to D(x)^2, favoring points far from existing centroids. This gives an O(log k)-competitive initialization in expectation.
- Efficient distance computation -- Instead of a double loop, we use the algebraic identity ||x-c||^2 = ||x||^2 + ||c||^2 - 2*x.c. The cross-term X @ C^T is a single matrix multiply, making the entire distance computation O(n * k * d).
- Assignment step -- argmin over the distance matrix gives each point's nearest centroid. This is the E-step in the EM interpretation of K-Means.
- Update step -- The centroid of cluster k is the mean of all points assigned to it. Empty clusters are handled by re-initializing to a random point.
- Convergence -- We measure the maximum centroid displacement. When it drops below tol, the algorithm has converged. K-Means always converges because the objective decreases monotonically and there are only finitely many partitions, but it may reach only a local minimum.
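The monotone-decrease claim is easy to observe empirically. A short sketch (toy 2-cluster data, plain random init) records the inertia at each iteration:

```python
import numpy as np

rng = np.random.RandomState(1)
X = np.vstack([rng.randn(40, 2), rng.randn(40, 2) + 6])
centroids = X[rng.choice(len(X), size=2, replace=False)]

history = []
for _ in range(10):
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = d2.argmin(axis=1)
    history.append(d2.min(axis=1).sum())  # inertia before the update
    centroids = np.array([
        X[labels == k].mean(axis=0) if np.any(labels == k) else centroids[k]
        for k in range(2)
    ])

# Inertia never increases between iterations.
print(all(a >= b - 1e-9 for a, b in zip(history, history[1:])))  # True
```

Each assignment step can only move points to closer centroids, and each update step can only move centroids to their clusters' means, so neither step increases the objective.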
Complexity Analysis
- Time per iteration: O(n * k * d) for the distance computation and assignment.
- Total time: O(T * n * k * d) where T = number of iterations (typically 10-300).
- Space: O(n * k) for the distance matrix, O(k * d) for centroids.
- K-means++ init: O(n * k * d) -- same as one iteration.
Interview Tips
Expect these follow-ups: (1) How to choose k? Elbow method (plot inertia vs. k) or silhouette score. (2) Sensitivity to initialization -- run multiple random restarts and pick the lowest inertia. (3) When K-Means fails: non-spherical clusters, clusters of very different sizes, high dimensions (curse of dimensionality). (4) Alternatives: K-Medoids (robust to outliers), DBSCAN (discovers arbitrary shapes), Gaussian Mixture Models (soft assignments). (5) Mini-batch K-Means for large datasets: update centroids on random subsets.
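The elbow method from follow-up (1) can be sketched by sweeping k and recording the best inertia per k (the `kmeans_inertia` helper is a hypothetical stand-in for the full class above):

```python
import numpy as np

def kmeans_inertia(X, k, rng, n_iter=50):
    """One random-init Lloyd run; returns final inertia for a given k."""
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        centroids = np.array([
            X[labels == j].mean(axis=0) if (labels == j).any() else centroids[j]
            for j in range(k)
        ])
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d2.min(axis=1).sum()

rng = np.random.RandomState(0)
X = np.vstack([c + 0.5 * rng.randn(60, 2) for c in ([0, 0], [5, 5], [10, 0])])

# Inertia drops sharply until k reaches the true cluster count, then flattens.
inertias = {}
for k in range(1, 7):
    inertias[k] = min(kmeans_inertia(X, k, rng) for _ in range(5))
    print(k, round(inertias[k], 1))
```

On this 3-blob data the curve shows a clear "elbow" at k=3: going from 1 to 3 clusters cuts inertia dramatically, while further increases give only marginal gains.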
Quiz — 3 Questions
What does k-means++ initialization guarantee compared to random initialization?
Why does K-Means struggle with clusters that have very different sizes or densities?
What is the efficient vectorized formula for computing squared Euclidean distances between n points and k centroids?