Algorithm: Decision Tree from Scratch
Implement a CART (Classification and Regression Trees) decision tree classifier from scratch. Decision trees are the building blocks of random forests and gradient boosting, which are among the most successful algorithms for tabular data.
Problem Statement
Implement a DecisionTree class that:
- Recursively splits the data by finding the feature and threshold that maximize information gain (or minimize Gini impurity)
- Stops splitting when a leaf is pure, max depth is reached, or a node has too few samples
- Predicts the class by traversing the tree to a leaf and returning the majority class
Inputs: Feature matrix X of shape (n_samples, n_features), labels y of shape (n_samples,).
Outputs: A tree structure that can classify new samples.
CART builds a binary tree by greedily finding the best split at each node. For classification, the best split maximizes information gain: IG = H(parent) - [n_left/n * H(left) + n_right/n * H(right)], where H is the Gini impurity 1 - sum(p_i^2) or entropy -sum(p_i * log(p_i)). This greedy approach does not find the globally optimal tree but is efficient and works well in practice.
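A quick numeric check of the gain formula above, as a small sketch with a made-up 6-sample parent node split 4/2:

```python
import numpy as np

def gini(y):
    # Gini impurity: 1 - sum_i p_i^2
    _, counts = np.unique(y, return_counts=True)
    p = counts / len(y)
    return 1.0 - float(np.sum(p ** 2))

parent = np.array([0, 0, 0, 1, 1, 1])   # Gini = 1 - (0.25 + 0.25) = 0.5
left   = np.array([0, 0, 0, 1])         # Gini = 1 - (9/16 + 1/16) = 0.375
right  = np.array([1, 1])               # Gini = 0 (pure)
n = len(parent)
ig = gini(parent) - (len(left) / n * gini(left) + len(right) / n * gini(right))
print(ig)  # 0.5 - (4/6 * 0.375 + 2/6 * 0.0) = 0.25
```

Note that the pure right child contributes nothing to the weighted term, so all the residual impurity comes from the larger left child.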
┌──────────────────────────────────────────────────────────────────┐
│ Decision Tree Construction (CART) │
│ │
│ Root: 300 samples │
│ Gini = 0.50 │
│ ┌─────────────────┐ │
│ │ Feature 0 │ │
│ │ <= 0.42? │ │
│ └────┬───────┬────┘ │
│ yes ╱ ╲ no │
│ ╱ ╲ │
│ ┌───────────────┐ ┌───────────────┐ │
│ │ 180 samples │ │ 120 samples │ │
│ │ Gini = 0.32 │ │ Gini = 0.38 │ │
│ │ Feature 1 │ │ Feature 0 │ │
│ │ <= -0.15? │ │ <= 1.20? │ │
│ └──┬────────┬───┘ └──┬────────┬───┘ │
│ ╱ ╲ ╱ ╲ │
│ ╱ ╲ ╱ ╲ │
│ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────────┐ │
│ │ LEAF │ │ LEAF │ │ LEAF │ │ LEAF │ │
│ │ Cls 0│ │ Cls 1│ │ Cls 1│ │ Cls 0 │ │
│ │Gini=0│ │Gini=0│ │G=0.1 │ │ Gini=0 │ │
│ └──────┘ └──────┘ └──────┘ └──────────┘ │
│ │
│ Stopping conditions: pure node, max_depth, min_samples │
│ Each split greedily maximizes information gain │
└──────────────────────────────────────────────────────────────────┘
Without depth limits or pruning, a decision tree will create one leaf per training sample, perfectly memorizing the data and achieving 0% training error but terrible generalization. Always set max_depth and min_samples_split to control tree complexity. In interviews, be prepared to discuss pre-pruning (stopping criteria) vs. post-pruning (grow full tree, then cut back).
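To make the overfitting point concrete, here is a minimal sketch using scikit-learn's DecisionTreeClassifier rather than the from-scratch class below; the dataset, noise level, and depth values are illustrative choices of mine:

```python
from sklearn.datasets import make_moons
from sklearn.tree import DecisionTreeClassifier

X, y = make_moons(n_samples=300, noise=0.3, random_state=0)
X_tr, X_te = X[:200], X[200:]
y_tr, y_te = y[:200], y[200:]

results = {}
for depth in (None, 3):  # unlimited depth vs. a pre-pruned tree
    clf = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X_tr, y_tr)
    results[depth] = (clf.score(X_tr, y_tr), clf.score(X_te, y_te))
    print(f"max_depth={depth}: train={results[depth][0]:.2f}, "
          f"test={results[depth][1]:.2f}")
# The unlimited tree memorizes the training set (train accuracy 1.0);
# the depth-limited tree gives up training accuracy and usually generalizes better.
```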
Hints
- For each node, iterate over all features and all unique thresholds (midpoints between sorted values).
- For each candidate split, compute the weighted impurity of the two child nodes.
- Pick the split with the lowest weighted impurity (highest information gain).
- Recursively apply to left and right children.
- A leaf node stores the majority class (mode of labels).
- Use Gini impurity: 1 - sum(p_i^2), where p_i is the fraction of class i.
Solution
import numpy as np
from typing import Optional, Tuple
from dataclasses import dataclass
@dataclass
class Node:
"""A node in the decision tree."""
feature_idx: Optional[int] = None # split feature index
threshold: Optional[float] = None # split threshold
left: Optional["Node"] = None # left child (feature <= threshold)
right: Optional["Node"] = None # right child (feature > threshold)
value: Optional[int] = None # leaf prediction (majority class)
class DecisionTree:
"""CART decision tree classifier."""
def __init__(self, max_depth: int = 10, min_samples_split: int = 2) -> None:
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.root: Optional[Node] = None
def _gini(self, y: np.ndarray) -> float:
"""Compute Gini impurity."""
if len(y) == 0:
return 0.0
_, counts = np.unique(y, return_counts=True)
probs = counts / len(y)
return 1.0 - np.sum(probs ** 2)
def _information_gain(
self, y: np.ndarray, y_left: np.ndarray, y_right: np.ndarray
) -> float:
"""Compute information gain from a split."""
n = len(y)
if n == 0:
return 0.0
parent_gini = self._gini(y)
weighted_child_gini = (
len(y_left) / n * self._gini(y_left)
+ len(y_right) / n * self._gini(y_right)
)
return parent_gini - weighted_child_gini
def _best_split(
self, X: np.ndarray, y: np.ndarray
) -> Tuple[Optional[int], Optional[float], float]:
"""Find the best feature and threshold to split on."""
best_gain = 0.0
best_feature = None
best_threshold = None
n_samples, n_features = X.shape
for feature_idx in range(n_features):
values = X[:, feature_idx]
# Use midpoints between sorted unique values as candidate thresholds
unique_vals = np.unique(values)
if len(unique_vals) <= 1:
continue
thresholds = (unique_vals[:-1] + unique_vals[1:]) / 2.0
for threshold in thresholds:
left_mask = values <= threshold
right_mask = ~left_mask
if left_mask.sum() == 0 or right_mask.sum() == 0:
continue
gain = self._information_gain(y, y[left_mask], y[right_mask])
if gain > best_gain:
best_gain = gain
best_feature = feature_idx
best_threshold = threshold
return best_feature, best_threshold, best_gain
def _build_tree(self, X: np.ndarray, y: np.ndarray, depth: int) -> Node:
"""Recursively build the decision tree."""
# Leaf conditions
n_classes = len(np.unique(y))
if (
n_classes == 1 # pure node
or depth >= self.max_depth # max depth reached
or len(y) < self.min_samples_split # too few samples
):
# Return leaf with majority class
return Node(value=int(np.bincount(y).argmax()))
# Find the best split
feature_idx, threshold, gain = self._best_split(X, y)
if feature_idx is None or gain <= 0:
return Node(value=int(np.bincount(y).argmax()))
# Split the data
left_mask = X[:, feature_idx] <= threshold
right_mask = ~left_mask
# Recursively build children
left_child = self._build_tree(X[left_mask], y[left_mask], depth + 1)
right_child = self._build_tree(X[right_mask], y[right_mask], depth + 1)
return Node(
feature_idx=feature_idx,
threshold=threshold,
left=left_child,
right=right_child,
)
def fit(self, X: np.ndarray, y: np.ndarray) -> "DecisionTree":
"""Build the decision tree from training data."""
self.root = self._build_tree(X, y.astype(int), depth=0)
return self
def _predict_one(self, x: np.ndarray, node: Node) -> int:
"""Traverse the tree to predict a single sample."""
if node.value is not None:
return node.value
if x[node.feature_idx] <= node.threshold:
return self._predict_one(x, node.left)
else:
return self._predict_one(x, node.right)
def predict(self, X: np.ndarray) -> np.ndarray:
"""Predict class labels for samples in X."""
return np.array([self._predict_one(x, self.root) for x in X])
# ---------- demo ----------
if __name__ == "__main__":
np.random.seed(42)
# Generate a simple 2D classification dataset
from sklearn.datasets import make_moons
X, y = make_moons(n_samples=300, noise=0.2, random_state=42)
# Split
train_X, test_X = X[:200], X[200:]
train_y, test_y = y[:200], y[200:]
tree = DecisionTree(max_depth=5, min_samples_split=5)
tree.fit(train_X, train_y)
train_preds = tree.predict(train_X)
test_preds = tree.predict(test_X)
train_acc = (train_preds == train_y).mean()
test_acc = (test_preds == test_y).mean()
print(f"Train accuracy: {train_acc:.4f}")
print(f"Test accuracy: {test_acc:.4f}")
# Print tree structure
def print_tree(node: Node, depth: int = 0) -> None:
indent = " " * depth
if node.value is not None:
print(f"{indent}Leaf: class={node.value}")
else:
print(f"{indent}Feature {node.feature_idx} <= {node.threshold:.3f}")
print_tree(node.left, depth + 1)
print_tree(node.right, depth + 1)
print("\nTree structure:")
print_tree(tree.root)
Walkthrough
- Gini impurity -- Measures the probability of misclassifying a randomly chosen sample if it were labeled according to the node's class distribution. Gini = 0 means perfect purity (all one class); the maximum for binary classification is 0.5 (a 50/50 split).
- Threshold selection -- For each feature, sort the unique values and use midpoints between consecutive values as candidate thresholds. This is more efficient than testing every data point and avoids boundary issues.
- Information gain -- The reduction in impurity from a split. We greedily pick the split that maximizes it; the weighted sum accounts for the sizes of the child nodes.
- Recursive building -- The tree grows depth-first. Each internal node stores a feature index and threshold; each leaf stores the majority class of its training samples.
- Stopping criteria -- The tree stops growing when a node is pure (all one class), the maximum depth is reached, or too few samples remain. Without these limits, the tree would overfit by creating a leaf for every training sample.
- Prediction -- Traverse from root to leaf by comparing the sample's feature value against each node's threshold: left for <=, right for >.
Complexity Analysis
- Training time: the naive search implemented here costs O(d * n^2) per node (up to n candidate thresholds per feature, each evaluated with an O(n) mask). Sorting each feature once and sweeping with running class counts reduces this to O(d * n * log(n)) per node.
- With a max_depth limit, each level of the tree processes every sample at most once, so total training cost is roughly O(d * n^2 * max_depth) for this implementation, or O(d * n * log(n) * max_depth) with the sorted sweep.
- Prediction time: O(max_depth) per sample -- just a tree traversal.
- Space: O(number_of_nodes), bounded by both O(2^max_depth) and O(n), since every leaf holds at least one training sample.
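The O(d * n * log(n))-per-node figure assumes the standard trick of sorting each feature once and sweeping left to right with running class counts, instead of re-partitioning the data for every threshold. A sketch of that sweep for a single feature (the helper name and signature are my own, not part of the solution above):

```python
import numpy as np

def best_threshold_sorted(values, y, n_classes):
    # Sort once, then sweep left-to-right while maintaining running class
    # counts, so each candidate threshold is scored in O(n_classes)
    # instead of O(n).
    order = np.argsort(values)
    v, labels = values[order], y[order]
    n = len(y)
    left = np.zeros(n_classes)
    right = np.bincount(labels, minlength=n_classes).astype(float)
    parent_gini = 1.0 - np.sum((right / n) ** 2)
    best_gain, best_thr = 0.0, None
    for i in range(n - 1):
        left[labels[i]] += 1   # move sample i from right child to left child
        right[labels[i]] -= 1
        if v[i] == v[i + 1]:   # no valid threshold between equal values
            continue
        nl, nr = i + 1, n - i - 1
        g_left = 1.0 - np.sum((left / nl) ** 2)
        g_right = 1.0 - np.sum((right / nr) ** 2)
        gain = parent_gini - (nl / n * g_left + nr / n * g_right)
        if gain > best_gain:
            best_gain, best_thr = gain, (v[i] + v[i + 1]) / 2.0
    return best_thr, best_gain

thr, gain = best_threshold_sorted(
    np.array([1.0, 2.0, 3.0, 4.0]), np.array([0, 0, 1, 1]), n_classes=2
)
print(thr, gain)  # 2.5 0.5
```

The sort dominates at O(n log n) per feature; the sweep itself is linear.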
Interview Tips
Common interview extensions: (1) Gini vs. entropy -- in practice they produce very similar trees; Gini is slightly faster (no log computation). (2) How to handle continuous vs. categorical features. (3) Pruning: pre-pruning (max_depth, min_samples) vs. post-pruning (grow full tree, then remove branches that do not improve validation accuracy). (4) How random forests build on this: bagging (random subsets of data) + random subspace (random subsets of features at each split). (5) Why decision trees are prone to overfitting and how ensemble methods fix this.
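For extension (4), a minimal bagging sketch (using scikit-learn's tree for brevity; the dataset, number of trees, and hyperparameters are illustrative assumptions):

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.tree import DecisionTreeClassifier

X, y = make_moons(n_samples=300, noise=0.3, random_state=0)
X_tr, X_te = X[:200], X[200:]
y_tr, y_te = y[:200], y[200:]

# Bagging: each tree trains on a bootstrap sample of the data;
# max_features="sqrt" adds the random-subspace step (a random subset
# of features considered at each split).
rng = np.random.default_rng(0)
trees = []
for _ in range(25):
    idx = rng.integers(0, len(X_tr), size=len(X_tr))  # sample with replacement
    tree = DecisionTreeClassifier(max_features="sqrt", random_state=0)
    trees.append(tree.fit(X_tr[idx], y_tr[idx]))

# Majority vote over the 25 trees (an odd count, so no ties)
votes = np.stack([t.predict(X_te) for t in trees])
forest_pred = (votes.mean(axis=0) > 0.5).astype(int)
print("forest accuracy:", (forest_pred == y_te).mean())
```

Averaging many deep, decorrelated trees keeps their low bias while cutting their variance, which is exactly the fix for the overfitting discussed above.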
Quiz
Why is Gini impurity preferred over accuracy as a splitting criterion?
How do Random Forests improve upon single decision trees?
What is the time complexity of finding the best split at a single node with n samples and d features?