Akshay’s Gradient
ML Coding · Beginner · 35 min

K-Means Clustering


Implement Lloyd's algorithm for K-Means clustering from scratch. K-Means is one of the most widely used clustering algorithms and a staple ML interview question.

Problem Statement

Implement a KMeans class that:

  1. Initializes centroids using k-means++ initialization
  2. Alternates between assignment (assign each point to nearest centroid) and update (recompute centroids as cluster means)
  3. Converges when centroids stop moving (or max iterations reached)
  4. Returns cluster assignments and final centroids

Inputs: Data matrix X of shape (n_samples, n_features), number of clusters k.

Outputs: Cluster assignments of shape (n_samples,), centroids of shape (k, n_features).

Key Concept

K-Means minimizes the within-cluster sum of squares (inertia): sum_i ||x_i - mu_{c_i}||^2. Lloyd's algorithm alternates between two steps: (1) assign each point to the nearest centroid, (2) update each centroid as the mean of its assigned points. This monotonically decreases the objective and converges to a local minimum.
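The objective can be made concrete with a tiny worked example (toy numbers, not part of the lesson's dataset): assign four points to two fixed centroids, then sum the squared distances to each point's assigned centroid.

```python
import numpy as np

# Toy data: 4 points and 2 candidate centroids (illustrative values only)
X = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0], [11.0, 0.0]])
centroids = np.array([[0.5, 0.0], [10.5, 0.0]])

# Assignment step: each point goes to its nearest centroid
dists = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)  # (4, 2)
labels = np.argmin(dists, axis=1)

# Inertia: sum of squared distances to each point's assigned centroid
inertia = np.sum(dists[np.arange(len(X)), labels])
print(labels, inertia)  # [0 0 1 1] 1.0 (each point is 0.5 away, 4 * 0.25)
```

Any other assignment or centroid placement for this data gives a strictly larger inertia, which is exactly what Lloyd's two steps drive down.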

K-Means Lloyd's Algorithm Cycle (k=3)

INITIALIZATION (k-means++)
┌──────────────────────┐
│  c1 = random point   │
│  c2 = far from c1    │  (prob ∝ distance^2)
│  c3 = far from c1,c2 │
└──────────┬───────────┘
           │
           ▼
┌──────────────────────┐     ┌──────────────────────┐
│   ASSIGNMENT STEP    │     │     UPDATE STEP      │
│                      │     │                      │
│  For each point x_i: │     │  For each cluster k: │
│  c_i = argmin ||x_i  │────▶│  c_k = mean(points   │
│         - c_k||^2    │     │       in cluster k)  │
│                      │     │                      │
│  uses: ||x||^2 +     │     │ Handle empty clusters│
│  ||c||^2 - 2*x@c^T   │     │ → reinit to random pt│
└──────────────────────┘     └───────────┬──────────┘
           ▲                             │
           │    ┌────────────────────┐   │
           └────┤  CONVERGENCE CHECK │◀──┘
        no      │  max centroid shift│
                │  < tolerance?      │
                └────────┬───────────┘
                    yes  │
                         ▼
                ┌────────────────┐
                │ Return labels, │
                │ centroids,     │
                │ inertia        │
                └────────────────┘
Warning

K-Means converges to a local minimum, not the global optimum. The final result depends heavily on initialization. Always run K-Means multiple times with different random seeds (e.g., scikit-learn's n_init=10 default) and select the run with the lowest inertia. K-means++ initialization helps but does not eliminate this problem.
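To see why restarts matter, here is a minimal sketch of the restart strategy (plain random initialization rather than k-means++, and `lloyd` is an illustrative helper, not part of the solution class): run Lloyd's algorithm several times and keep the lowest-inertia result.

```python
import numpy as np

def lloyd(X, k, rng, n_iter=50):
    """One K-Means run with plain random initialization (illustrative sketch)."""
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        d = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
        labels = np.argmin(d, axis=1)
        for j in range(k):
            pts = X[labels == j]
            if len(pts):  # empty clusters keep their old centroid here
                centroids[j] = pts.mean(axis=0)
    d = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
    return centroids, float(d.min(axis=1).sum())

rng = np.random.RandomState(0)
X = np.vstack([c + rng.randn(50, 2) * 0.5 for c in ([0, 0], [5, 5], [10, 0])])

# Run several restarts and keep the lowest-inertia solution
runs = [lloyd(X, k=3, rng=rng) for _ in range(10)]
best_centroids, best_inertia = min(runs, key=lambda r: r[1])
print(f"best inertia over 10 restarts: {best_inertia:.2f}")
```

Individual runs can land in different local minima; taking the minimum over restarts is exactly what scikit-learn's `n_init` does.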

Hints

  1. For k-means++ init: choose the first centroid randomly, then choose each subsequent centroid with probability proportional to D(x)^2 (distance to nearest existing centroid).
  2. Assignment step: compute pairwise distances between all points and all centroids. Use broadcasting: ||x - c||^2 = ||x||^2 + ||c||^2 - 2*x@c^T.
  3. Update step: for each cluster, compute the mean of assigned points.
  4. Check convergence: if the maximum centroid shift is below a threshold, stop.
  5. Handle empty clusters by re-initializing them to a random data point.
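Hint 2's identity is easy to sanity-check against the naive broadcasted computation. A quick sketch with made-up random data:

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(100, 5)   # n=100 points, d=5 features
C = rng.randn(4, 5)     # k=4 centroids

# Naive: explicit broadcasting over the point/centroid axes -> (n, k)
naive = np.sum((X[:, None, :] - C[None, :, :]) ** 2, axis=2)

# Identity: ||x||^2 + ||c||^2 - 2*x.c, with the cross term as one matmul
fast = (np.sum(X ** 2, axis=1, keepdims=True)
        + np.sum(C ** 2, axis=1)[None, :]
        - 2 * X @ C.T)

print(np.allclose(naive, fast))  # True
```

The identity version avoids materializing the (n, k, d) intermediate array, which is what makes it practical for large n.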

Solution

import numpy as np


class KMeans:
    """K-Means clustering with k-means++ initialization."""

    def __init__(
        self,
        n_clusters: int = 3,
        max_iter: int = 300,
        tol: float = 1e-4,
        random_state: int = 42,
    ) -> None:
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.rng = np.random.RandomState(random_state)
        self.centroids: np.ndarray = np.array([])
        self.labels: np.ndarray = np.array([])
        self.inertia: float = 0.0

    def _init_centroids_pp(self, X: np.ndarray) -> np.ndarray:
        """K-means++ initialization for better starting centroids."""
        n_samples, n_features = X.shape
        centroids = np.empty((self.n_clusters, n_features))

        # Choose first centroid uniformly at random
        idx = self.rng.randint(n_samples)
        centroids[0] = X[idx]

        for k in range(1, self.n_clusters):
            # Compute distance from each point to nearest existing centroid
            dists = np.min(
                np.sum((X[:, None, :] - centroids[None, :k, :]) ** 2, axis=2),
                axis=1,
            )  # (n_samples,)
            # Choose next centroid with probability proportional to distance^2
            probs = dists / dists.sum()
            idx = self.rng.choice(n_samples, p=probs)
            centroids[k] = X[idx]

        return centroids

    def _compute_distances(self, X: np.ndarray, centroids: np.ndarray) -> np.ndarray:
        """Compute squared Euclidean distances: (n_samples, n_clusters)."""
        # Efficient: ||x-c||^2 = ||x||^2 + ||c||^2 - 2*x.c
        X_sq = np.sum(X ** 2, axis=1, keepdims=True)       # (n, 1)
        C_sq = np.sum(centroids ** 2, axis=1, keepdims=True).T  # (1, k)
        cross = X @ centroids.T                              # (n, k)
        return X_sq + C_sq - 2 * cross

    def _assign(self, X: np.ndarray) -> np.ndarray:
        """Assign each point to the nearest centroid."""
        distances = self._compute_distances(X, self.centroids)
        return np.argmin(distances, axis=1)

    def _update(self, X: np.ndarray, labels: np.ndarray) -> np.ndarray:
        """Recompute centroids as the mean of assigned points."""
        new_centroids = np.empty_like(self.centroids)
        for k in range(self.n_clusters):
            mask = labels == k
            if mask.sum() > 0:
                new_centroids[k] = X[mask].mean(axis=0)
            else:
                # Handle empty cluster: reinitialize to a random point
                new_centroids[k] = X[self.rng.randint(len(X))]
        return new_centroids

    def fit(self, X: np.ndarray) -> "KMeans":
        """Run K-Means clustering."""
        self.centroids = self._init_centroids_pp(X)

        for iteration in range(self.max_iter):
            # Assignment step
            self.labels = self._assign(X)

            # Update step
            new_centroids = self._update(X, self.labels)

            # Check convergence
            shift = np.max(np.sqrt(np.sum((new_centroids - self.centroids) ** 2, axis=1)))
            self.centroids = new_centroids

            if shift < self.tol:
                break

        # Final assignment so labels are consistent with the converged centroids
        self.labels = self._assign(X)

        # Compute final inertia
        distances = self._compute_distances(X, self.centroids)
        self.inertia = float(np.sum(distances[np.arange(len(X)), self.labels]))
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Assign new points to nearest centroid."""
        return self._assign(X)


# ---------- demo ----------
if __name__ == "__main__":
    np.random.seed(42)

    # Generate 3 well-separated clusters
    cluster_centers = np.array([[0, 0], [5, 5], [10, 0]])
    X = np.vstack([
        center + np.random.randn(100, 2) * 0.8
        for center in cluster_centers
    ])

    kmeans = KMeans(n_clusters=3)
    kmeans.fit(X)

    print(f"Found centroids:\n{kmeans.centroids}")
    print(f"True centers:\n{cluster_centers}")
    print(f"Inertia: {kmeans.inertia:.2f}")
    print(f"Label counts: {np.bincount(kmeans.labels)}")

    # Verify centroids are close to true centers (up to permutation)
    for true_c in cluster_centers:
        dists = np.sqrt(np.sum((kmeans.centroids - true_c) ** 2, axis=1))
        print(f"True center {true_c} -> nearest centroid dist: {dists.min():.4f}")

Walkthrough

  1. K-means++ initialization -- Choosing initial centroids randomly can lead to poor convergence. K-means++ selects centroids that are spread out: each new centroid is chosen with probability proportional to D(x)^2, favoring points far from existing centroids. This gives O(log k)-competitive initialization.

  2. Efficient distance computation -- Instead of a double loop, we use the algebraic identity ||x-c||^2 = ||x||^2 + ||c||^2 - 2*x.c. The cross-term X @ C^T is a single matrix multiply, making the entire distance computation O(n * k * d).

  3. Assignment step -- argmin over the distance matrix gives each point's nearest centroid. This is the E-step in the EM interpretation of K-Means.

  4. Update step -- The centroid of cluster k is the mean of all points assigned to it. Empty clusters are handled by re-initializing to a random point.

  5. Convergence -- We measure the maximum centroid displacement. When it drops below tol, the algorithm has converged. K-Means always converges because the objective decreases monotonically, but it may reach a local minimum.
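The monotone decrease in step 5 can be observed directly. The sketch below (a stripped-down Lloyd loop with plain random initialization, illustrative only) records the inertia after every assignment step:

```python
import numpy as np

rng = np.random.RandomState(1)
X = np.vstack([c + rng.randn(60, 2) for c in ([0, 0], [6, 6], [12, 0])])
centroids = X[rng.choice(len(X), size=3, replace=False)]

inertias = []
for _ in range(15):
    # Assignment step: nearest centroid per point
    d = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
    labels = np.argmin(d, axis=1)
    inertias.append(d[np.arange(len(X)), labels].sum())
    # Update step: cluster means (empty clusters keep their old centroid)
    for j in range(3):
        pts = X[labels == j]
        if len(pts):
            centroids[j] = pts.mean(axis=0)

# Each iteration's inertia is <= the previous one (monotone decrease)
print(all(b <= a + 1e-9 for a, b in zip(inertias, inertias[1:])))  # True
```

Both steps can only lower (or preserve) the objective: assignment is optimal given the centroids, and the mean is the optimal centroid given the assignments.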

Complexity Analysis

  • Time per iteration: O(n * k * d) for the distance computation and assignment.
  • Total time: O(T * n * k * d) where T = number of iterations (typically 10-300).
  • Space: O(n * k) for the distance matrix, O(k * d) for centroids.
  • K-means++ init: O(n * k * d) -- same as one iteration.

Interview Tips


Expect these follow-ups: (1) How to choose k? Elbow method (plot inertia vs. k) or silhouette score. (2) Sensitivity to initialization -- run multiple random restarts and pick the lowest inertia. (3) When K-Means fails: non-spherical clusters, clusters of very different sizes, high dimensions (curse of dimensionality). (4) Alternatives: K-Medoids (robust to outliers), DBSCAN (discovers arbitrary shapes), Gaussian Mixture Models (soft assignments). (5) Mini-batch K-Means for large datasets: update centroids on random subsets.
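The elbow method from follow-up (1) can be sketched without scikit-learn. `best_inertia` below is a hypothetical helper (a stripped-down Lloyd loop with random restarts), used only to print inertia against k:

```python
import numpy as np

rng = np.random.RandomState(0)
# Data with 3 true clusters
X = np.vstack([c + rng.randn(80, 2) * 0.6 for c in ([0, 0], [5, 5], [10, 0])])

def best_inertia(X, k, restarts=5, n_iter=30):
    """Lowest inertia over a few random-restart Lloyd runs (minimal sketch)."""
    best = np.inf
    for _ in range(restarts):
        centroids = X[rng.choice(len(X), size=k, replace=False)]
        for _ in range(n_iter):
            d = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
            labels = np.argmin(d, axis=1)
            for j in range(k):
                pts = X[labels == j]
                if len(pts):
                    centroids[j] = pts.mean(axis=0)
        d = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
        best = min(best, float(d.min(axis=1).sum()))
    return best

# Inertia drops sharply up to the true k=3, then flattens: the "elbow"
inertias = {k: best_inertia(X, k) for k in range(1, 7)}
for k, v in inertias.items():
    print(k, round(v, 1))
```

Inertia always decreases as k grows (it hits zero at k = n), so you look for the k where the marginal improvement collapses rather than the minimum itself.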

Quiz


What does k-means++ initialization guarantee compared to random initialization?

Why does K-Means struggle with clusters that have very different sizes or densities?

What is the efficient vectorized formula for computing squared Euclidean distances between n points and k centroids?
