Akshay’s Gradient
ML Coding · Intermediate · 55 min

Decision Tree from Scratch

Implement a CART (Classification and Regression Trees) decision tree classifier from scratch. Decision trees are the building blocks of random forests and gradient boosting -- among the most successful algorithm families for tabular data.

Problem Statement

Implement a DecisionTree class that:

  1. Recursively splits the data by finding the feature and threshold that maximize information gain (or minimize Gini impurity)
  2. Stops splitting when a leaf is pure, max depth is reached, or a node has too few samples
  3. Predicts the class by traversing the tree to a leaf and returning the majority class

Inputs: Feature matrix X of shape (n_samples, n_features), labels y of shape (n_samples,).

Outputs: A tree structure that can classify new samples.

Key Concept

CART builds a binary tree by greedily finding the best split at each node. For classification, the best split maximizes information gain: IG = H(parent) - [n_left/n * H(left) + n_right/n * H(right)], where H is an impurity measure: either the Gini impurity 1 - sum(p_i^2) or the entropy -sum(p_i * log(p_i)). This greedy approach does not find the globally optimal tree (finding one is NP-hard), but it is efficient and works well in practice.
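As a quick numeric check of the formula above, here is a small sketch (plain NumPy, independent of the full solution below) computing the information gain of one toy split:

```python
import numpy as np

def gini(y):
    """Gini impurity: 1 - sum(p_i^2) over class frequencies."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / len(y)
    return 1.0 - np.sum(p ** 2)

# Parent node: 4 samples of class 0, 4 of class 1 -> Gini = 0.5
parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])
left   = np.array([0, 0, 0, 1])   # mostly class 0 -> Gini = 0.375
right  = np.array([0, 1, 1, 1])   # mostly class 1 -> Gini = 0.375

ig = gini(parent) - (len(left) / len(parent) * gini(left)
                     + len(right) / len(parent) * gini(right))
print(f"IG = {ig:.4f}")  # 0.5 - (0.5*0.375 + 0.5*0.375) = 0.125
```

A split that separated the classes perfectly would instead achieve the maximum gain of 0.5 here.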

Interactive · Decision Tree Recursive Splitting
┌──────────────────────────────────────────────────────────────────┐
│           Decision Tree Construction (CART)                       │
│                                                                  │
│                     Root: 300 samples                             │
│                     Gini = 0.50                                   │
│                   ┌─────────────────┐                             │
│                   │ Feature 0       │                             │
│                   │ <= 0.42?        │                             │
│                   └────┬───────┬────┘                             │
│              yes ╱              ╲ no                              │
│                ╱                  ╲                               │
│    ┌───────────────┐    ┌───────────────┐                        │
│    │ 180 samples   │    │ 120 samples   │                        │
│    │ Gini = 0.32   │    │ Gini = 0.38   │                        │
│    │ Feature 1     │    │ Feature 0     │                        │
│    │ <= -0.15?     │    │ <= 1.20?      │                        │
│    └──┬────────┬───┘    └──┬────────┬───┘                        │
│     ╱            ╲       ╱            ╲                           │
│   ╱                ╲   ╱                ╲                        │
│  ┌──────┐  ┌──────┐  ┌──────┐  ┌──────────┐                     │
│  │ LEAF │  │ LEAF │  │ LEAF │  │   LEAF   │                      │
│  │ Cls 0│  │ Cls 1│  │ Cls 1│  │   Cls 0  │                     │
│  │Gini=0│  │Gini=0│  │G=0.1 │  │  Gini=0  │                     │
│  └──────┘  └──────┘  └──────┘  └──────────┘                     │
│                                                                  │
│   Stopping conditions: pure node, max_depth, min_samples         │
│   Each split greedily maximizes information gain                  │
└──────────────────────────────────────────────────────────────────┘
Warning

Without depth limits or pruning, a decision tree will create one leaf per training sample, perfectly memorizing the data and achieving 0% training error but terrible generalization. Always set max_depth and min_samples_split to control tree complexity. In interviews, be prepared to discuss pre-pruning (stopping criteria) vs. post-pruning (grow full tree, then cut back).

Hints

  1. For each node, iterate over all features and all unique thresholds (midpoints between sorted values).
  2. For each candidate split, compute the weighted impurity of the two child nodes.
  3. Pick the split with the lowest weighted impurity (highest information gain).
  4. Recursively apply to left and right children.
  5. A leaf node stores the majority class (mode of labels).
  6. Use Gini impurity: 1 - sum(p_i^2) where p_i is the fraction of class i.
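Hint 1's midpoint rule can be sketched in a few lines (assuming the feature column is a 1-D array with possible duplicate values):

```python
import numpy as np

values = np.array([0.3, 1.1, 0.3, 2.0, 1.1])   # one feature column
unique_vals = np.unique(values)                 # sorted unique: [0.3, 1.1, 2.0]
# Candidate thresholds are midpoints between consecutive unique values.
thresholds = (unique_vals[:-1] + unique_vals[1:]) / 2.0
print(thresholds)  # [0.7  1.55]
```

Using midpoints rather than the raw values guarantees every candidate threshold actually separates two observed values, so no split is wasted on an empty child.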

Solution

import numpy as np
from typing import Optional, Tuple
from dataclasses import dataclass


@dataclass
class Node:
    """A node in the decision tree."""
    feature_idx: Optional[int] = None   # split feature index
    threshold: Optional[float] = None    # split threshold
    left: Optional["Node"] = None        # left child (feature <= threshold)
    right: Optional["Node"] = None       # right child (feature > threshold)
    value: Optional[int] = None          # leaf prediction (majority class)


class DecisionTree:
    """CART decision tree classifier."""

    def __init__(self, max_depth: int = 10, min_samples_split: int = 2) -> None:
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.root: Optional[Node] = None

    def _gini(self, y: np.ndarray) -> float:
        """Compute Gini impurity."""
        if len(y) == 0:
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        probs = counts / len(y)
        return 1.0 - np.sum(probs ** 2)

    def _information_gain(
        self, y: np.ndarray, y_left: np.ndarray, y_right: np.ndarray
    ) -> float:
        """Compute information gain from a split."""
        n = len(y)
        if n == 0:
            return 0.0
        parent_gini = self._gini(y)
        weighted_child_gini = (
            len(y_left) / n * self._gini(y_left)
            + len(y_right) / n * self._gini(y_right)
        )
        return parent_gini - weighted_child_gini

    def _best_split(
        self, X: np.ndarray, y: np.ndarray
    ) -> Tuple[Optional[int], Optional[float], float]:
        """Find the best feature and threshold to split on."""
        best_gain = 0.0
        best_feature = None
        best_threshold = None

        n_samples, n_features = X.shape

        for feature_idx in range(n_features):
            values = X[:, feature_idx]
            # Use midpoints between sorted unique values as candidate thresholds
            unique_vals = np.unique(values)
            if len(unique_vals) <= 1:
                continue
            thresholds = (unique_vals[:-1] + unique_vals[1:]) / 2.0

            for threshold in thresholds:
                left_mask = values <= threshold
                right_mask = ~left_mask

                if left_mask.sum() == 0 or right_mask.sum() == 0:
                    continue

                gain = self._information_gain(y, y[left_mask], y[right_mask])
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature_idx
                    best_threshold = threshold

        return best_feature, best_threshold, best_gain

    def _build_tree(self, X: np.ndarray, y: np.ndarray, depth: int) -> Node:
        """Recursively build the decision tree."""
        # Leaf conditions
        n_classes = len(np.unique(y))
        if (
            n_classes == 1                      # pure node
            or depth >= self.max_depth          # max depth reached
            or len(y) < self.min_samples_split  # too few samples
        ):
            # Return leaf with majority class
            return Node(value=int(np.bincount(y).argmax()))

        # Find the best split
        feature_idx, threshold, gain = self._best_split(X, y)

        if feature_idx is None or gain <= 0:
            return Node(value=int(np.bincount(y).argmax()))

        # Split the data
        left_mask = X[:, feature_idx] <= threshold
        right_mask = ~left_mask

        # Recursively build children
        left_child = self._build_tree(X[left_mask], y[left_mask], depth + 1)
        right_child = self._build_tree(X[right_mask], y[right_mask], depth + 1)

        return Node(
            feature_idx=feature_idx,
            threshold=threshold,
            left=left_child,
            right=right_child,
        )

    def fit(self, X: np.ndarray, y: np.ndarray) -> "DecisionTree":
        """Build the decision tree from training data."""
        self.root = self._build_tree(X, y.astype(int), depth=0)
        return self

    def _predict_one(self, x: np.ndarray, node: Node) -> int:
        """Traverse the tree to predict a single sample."""
        if node.value is not None:
            return node.value
        if x[node.feature_idx] <= node.threshold:
            return self._predict_one(x, node.left)
        else:
            return self._predict_one(x, node.right)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict class labels for samples in X."""
        return np.array([self._predict_one(x, self.root) for x in X])


# ---------- demo ----------
if __name__ == "__main__":
    np.random.seed(42)

    # Generate a simple 2D classification dataset
    from sklearn.datasets import make_moons
    X, y = make_moons(n_samples=300, noise=0.2, random_state=42)

    # Split
    train_X, test_X = X[:200], X[200:]
    train_y, test_y = y[:200], y[200:]

    tree = DecisionTree(max_depth=5, min_samples_split=5)
    tree.fit(train_X, train_y)

    train_preds = tree.predict(train_X)
    test_preds = tree.predict(test_X)

    train_acc = (train_preds == train_y).mean()
    test_acc = (test_preds == test_y).mean()
    print(f"Train accuracy: {train_acc:.4f}")
    print(f"Test accuracy:  {test_acc:.4f}")

    # Print tree structure
    def print_tree(node: Node, depth: int = 0) -> None:
        indent = "  " * depth
        if node.value is not None:
            print(f"{indent}Leaf: class={node.value}")
        else:
            print(f"{indent}Feature {node.feature_idx} <= {node.threshold:.3f}")
            print_tree(node.left, depth + 1)
            print_tree(node.right, depth + 1)

    print("\nTree structure:")
    print_tree(tree.root)

Walkthrough

  1. Gini impurity -- Measures the probability of misclassifying a randomly chosen sample if it were labeled according to the class distribution. Gini = 0 means perfect purity (all one class). Maximum Gini for binary classification is 0.5 (50/50 split).

  2. Threshold selection -- For each feature, we sort the unique values and use midpoints as candidate thresholds. This is more efficient than testing every data point and avoids boundary issues.

  3. Information gain -- The reduction in impurity from a split. We greedily pick the split that maximizes this. The weighted sum ensures we account for the sizes of the child nodes.

  4. Recursive building -- The tree grows depth-first. Each internal node stores a feature index and threshold. Each leaf stores the majority class of its training samples.

  5. Stopping criteria -- The tree stops growing when: a node is pure (all same class), maximum depth is reached, or too few samples remain. Without these limits, the tree would overfit by creating a leaf for every training sample.

  6. Prediction -- Traverse from root to leaf by comparing the sample's feature value against each node's threshold. Left for <=, right for >.
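Step 6's traversal can be sketched with a minimal dict-based tree (a hypothetical structure mirroring the Node fields in the solution, not the solution's own code):

```python
# Leaf nodes carry "value"; internal nodes carry a feature index,
# a threshold, and left/right children (left means feature <= threshold).
tree = {
    "feature_idx": 0, "threshold": 0.42,
    "left":  {"value": 0},
    "right": {"feature_idx": 1, "threshold": -0.15,
              "left": {"value": 1}, "right": {"value": 0}},
}

def predict_one(x, node):
    if "value" in node:                 # leaf: return stored majority class
        return node["value"]
    branch = "left" if x[node["feature_idx"]] <= node["threshold"] else "right"
    return predict_one(x, node[branch])

print(predict_one([0.1, 0.0], tree))    # 0  (left at the root)
print(predict_one([0.9, -0.5], tree))   # 1  (right at root, then left)
```

Each prediction touches at most one node per level, which is why prediction cost depends only on tree depth, not on training-set size.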

Complexity Analysis

  • Training time: finding the best split at a node with m samples costs O(d * m^2) as implemented (up to m - 1 candidate thresholds per feature, each evaluated in O(m)); a sort-once-and-scan implementation reduces this to O(d * m * log(m)) per node.
  • Since the nodes at each depth level partition the n samples, total training time is O(d * n^2 * max_depth) for this implementation, and typically much less in practice.
  • Prediction time: O(max_depth) per sample -- just a root-to-leaf traversal.
  • Space: O(number_of_nodes), at most O(min(2^max_depth, n)) since every leaf holds at least one training sample.
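To make the per-node cost concrete, a rough back-of-the-envelope count of impurity evaluations for the naive scan (assuming all n feature values are distinct, so each feature yields n - 1 candidate thresholds, each needing an O(n) pass):

```python
n, d = 100, 5
# Per feature: n - 1 midpoint thresholds, each requiring an O(n)
# masking-and-impurity pass over the node's samples.
evaluations_per_node = d * (n - 1) * n
print(evaluations_per_node)  # 49500, i.e. O(d * n^2) work at the root
```

This is why production implementations sort each feature once and update class counts incrementally while sweeping the thresholds.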

Interview Tips

Common interview extensions: (1) Gini vs. entropy -- in practice they produce very similar trees; Gini is slightly faster (no log computation). (2) How to handle continuous vs. categorical features. (3) Pruning: pre-pruning (max_depth, min_samples) vs. post-pruning (grow full tree, then remove branches that do not improve validation accuracy). (4) How random forests build on this: bagging (random subsets of data) + random subspace (random subsets of features at each split). (5) Why decision trees are prone to overfitting and how ensemble methods fix this.
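Point (1) can be checked numerically. A small sketch comparing the two criteria for a binary node with class-1 fraction p (entropy in bits here, an arbitrary choice of base):

```python
import numpy as np

def gini_binary(p):
    """Gini impurity of a binary node with class-1 fraction p."""
    return 1.0 - (p ** 2 + (1 - p) ** 2)

def entropy_binary(p):
    """Entropy (base 2) of a binary node; 0 by convention at p = 0 or 1."""
    if p in (0.0, 1.0):
        return 0.0
    return -(p * np.log2(p) + (1 - p) * np.log2(1 - p))

for p in (0.1, 0.3, 0.5):
    print(f"p={p}: gini={gini_binary(p):.3f}, entropy={entropy_binary(p):.3f}")
# Both curves vanish at p = 0 or 1 and peak at p = 0.5
# (gini = 0.5, entropy = 1.0), so they rank most candidate splits identically.
```

The shapes differ only in curvature, which is why swapping criteria rarely changes the resulting tree much.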

Quiz — 3 Questions

Why is Gini impurity preferred over accuracy as a splitting criterion?

How do Random Forests improve upon single decision trees?

What is the time complexity of finding the best split at a single node with n samples and d features?
