Akshay’s Gradient
ML Coding · Intermediate · 35 min

Online Learning Update

Algorithm: Online Learning with SGD

Implement online gradient descent -- learning from one sample at a time without storing the full dataset. Online learning is essential for systems that must adapt to changing data streams, like ad click prediction, fraud detection, and recommendation systems.

Problem Statement

Implement:

  1. OnlineLogisticRegression -- a logistic regression model that updates from one sample at a time
  2. OnlineSGDClassifier -- generalized online learning with support for different loss functions
  3. Track performance metrics on a streaming data source and demonstrate adaptation to distribution shift

Inputs: Streaming (feature_vector, label) pairs arriving one at a time.

Outputs: Predictions made before seeing each label, cumulative accuracy, and regret.

Key Concept

Online learning processes one sample at a time: predict, then observe the true label, then update the model. The key metric is regret: the difference between the model's cumulative loss and the best fixed model's loss in hindsight. Online gradient descent with learning rate eta = 1/sqrt(T) achieves O(sqrt(T)) regret, which means the average per-step regret goes to zero as T grows.
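
The regret definition can be made concrete on a toy stream. The sketch below (an assumed setup for illustration, not part of the exercise) uses squared loss and an online running-mean predictor, and compares its cumulative loss to the best fixed prediction in hindsight:

```python
import numpy as np

# Toy regret check: an online running-mean predictor vs. the best fixed
# prediction in hindsight, under squared loss.
rng = np.random.default_rng(0)
ys = rng.normal(loc=2.0, scale=1.0, size=1000)  # stream of targets

online_loss = 0.0
running_sum = 0.0
for t, y in enumerate(ys):
    pred = running_sum / t if t > 0 else 0.0  # predict BEFORE seeing y
    online_loss += (pred - y) ** 2
    running_sum += y

# For squared loss, the best fixed prediction in hindsight is the overall mean
best_fixed_loss = ((ys.mean() - ys) ** 2).sum()

regret = online_loss - best_fixed_loss
print(f"cumulative regret: {regret:.2f}")
print(f"average regret:    {regret / len(ys):.4f}")  # shrinks as T grows
```

Most of the regret accumulates in the first few steps, when the online predictor has seen little data; the average per-step regret is already tiny at T = 1000.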

Diagram · Online Learning: Predict-Update Loop with Distribution Shift
┌──────────────────────────────────────────────────────────────────┐
│          Online Learning Loop with Distribution Shift             │
│                                                                  │
│   ┌─────────────────── Data Stream ──────────────────────┐       │
│   │  x1,y1  x2,y2  ...  x999,y999  │  x1000,y1000 ...  │       │
│   │  ◄─ Distribution A ──────────►  │  ◄── Distribution B│       │
│   └─────────────────────────────────┴────────────────────┘       │
│                    │                                              │
│                    ▼                                              │
│   For each (x, y) in stream:                                     │
│   ┌────────────────┐                                             │
│   │ 1. PREDICT     │  y_hat = sigmoid(w · x + b)                │
│   │   (before y)   │  pred = 1 if y_hat >= 0.5 else 0           │
│   └───────┬────────┘                                             │
│           ▼                                                      │
│   ┌────────────────┐                                             │
│   │ 2. OBSERVE y   │  loss = -y*log(p) - (1-y)*log(1-p)         │
│   └───────┬────────┘                                             │
│           ▼                                                      │
│   ┌────────────────┐                                             │
│   │ 3. UPDATE      │  w -= lr * (p - y) * x                     │
│   │   (one step)   │  b -= lr * (p - y)                          │
│   └───────┬────────┘                                             │
│           ▼                                                      │
│   Accuracy over time:                                            │
│   ▓▓░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓│▓░░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓                      │
│   ◄── learning ──────►│◄── shift ──►◄── adapted ──►             │
│                        ▲                                          │
│                   Distribution                                    │
│                     shift!                                        │
└──────────────────────────────────────────────────────────────────┘
Warning

A critical implementation detail: the numerically stable sigmoid is np.where(z >= 0, 1/(1+exp(-z)), exp(z)/(1+exp(z))). The naive formula 1/(1+exp(-z)) overflows when z is a large negative number because exp(-z) becomes enormous. Always use the two-branch formulation in production code.
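
One subtlety worth knowing: np.where evaluates both branches before selecting, so the two-branch expression can still emit an overflow RuntimeWarning even though the values it returns are correct. A fully warning-free variant (a sketch, not part of the reference solution) only ever exponentiates non-positive values:

```python
import numpy as np

def sigmoid_safe(z: np.ndarray) -> np.ndarray:
    """Branch-masked sigmoid: each exponent is guaranteed non-positive."""
    z = np.asarray(z, dtype=float)
    out = np.empty_like(z)
    pos = z >= 0
    out[pos] = 1.0 / (1.0 + np.exp(-z[pos]))  # -z <= 0 here: no overflow
    ez = np.exp(z[~pos])                      # z < 0 here: exp(z) <= 1
    out[~pos] = ez / (1.0 + ez)
    return out

z = np.array([-1000.0, 0.0, 1000.0])
print(sigmoid_safe(z))  # extreme logits saturate cleanly to 0.0 and 1.0
```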

Hints

  1. Logistic regression: P(y=1|x) = sigmoid(w.x + b), loss = binary cross-entropy.
  2. Gradient: dw = (sigmoid(w.x + b) - y) * x, db = sigmoid(w.x + b) - y.
  3. Update: w -= lr * dw, b -= lr * db.
  4. Use a decaying learning rate: lr = initial_lr / sqrt(t) for convergence guarantees.
  5. Track a running accuracy over a sliding window (or with an exponential moving average) for monitoring.
  6. To demonstrate adaptation, shift the data distribution mid-stream and show the model adjusts.

Solution

import numpy as np
from typing import Tuple, List


def sigmoid(z: np.ndarray) -> np.ndarray:
    """Numerically stable sigmoid."""
    return np.where(z >= 0, 1 / (1 + np.exp(-z)), np.exp(z) / (1 + np.exp(z)))


class OnlineLogisticRegression:
    """Online logistic regression with SGD updates."""

    def __init__(
        self,
        n_features: int,
        learning_rate: float = 0.1,
        l2_reg: float = 1e-4,
        decay_lr: bool = True,
    ) -> None:
        self.w = np.zeros(n_features)
        self.b = 0.0
        self.lr0 = learning_rate
        self.l2_reg = l2_reg
        self.decay_lr = decay_lr
        self.t = 0  # step counter

    @property
    def lr(self) -> float:
        if self.decay_lr and self.t > 0:
            return self.lr0 / np.sqrt(self.t)
        return self.lr0

    def predict_proba(self, x: np.ndarray) -> float:
        """Predict P(y=1|x)."""
        logit = np.dot(self.w, x) + self.b
        return float(sigmoid(logit))

    def predict(self, x: np.ndarray) -> int:
        """Predict class label."""
        return 1 if self.predict_proba(x) >= 0.5 else 0

    def update(self, x: np.ndarray, y: int) -> float:
        """
        Update weights from a single sample. Returns the loss.
        """
        self.t += 1
        p = self.predict_proba(x)

        # Binary cross-entropy loss
        eps = 1e-12
        loss = -(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

        # Gradient of BCE w.r.t. parameters
        error = p - y  # derivative of BCE w.r.t. logit
        grad_w = error * x + self.l2_reg * self.w
        grad_b = error

        # SGD update
        lr = self.lr
        self.w -= lr * grad_w
        self.b -= lr * grad_b

        return float(loss)

    def partial_fit(self, x: np.ndarray, y: int) -> float:
        """Alias for update (scikit-learn convention)."""
        return self.update(x, y)


class OnlineMetrics:
    """Track online learning performance metrics."""

    def __init__(self, window_size: int = 100) -> None:
        self.correct = 0
        self.total = 0
        self.cumulative_loss = 0.0
        self.window_size = window_size
        self.recent_correct: List[int] = []

    def update(self, predicted: int, actual: int, loss: float) -> None:
        self.total += 1
        is_correct = int(predicted == actual)
        self.correct += is_correct
        self.cumulative_loss += loss

        self.recent_correct.append(is_correct)
        if len(self.recent_correct) > self.window_size:
            self.recent_correct.pop(0)

    @property
    def accuracy(self) -> float:
        return self.correct / max(1, self.total)

    @property
    def recent_accuracy(self) -> float:
        if not self.recent_correct:
            return 0.0
        return sum(self.recent_correct) / len(self.recent_correct)

    @property
    def average_loss(self) -> float:
        return self.cumulative_loss / max(1, self.total)


def generate_stream_with_shift(
    n_samples: int,
    n_features: int,
    shift_at: int,
    rng: np.random.RandomState,
) -> List[Tuple[np.ndarray, int]]:
    """Generate a data stream with a distribution shift."""
    stream = []

    # Phase 1: linear boundary at w1
    w1 = rng.randn(n_features)
    w1 /= np.linalg.norm(w1)

    # Phase 2: different boundary at w2
    w2 = rng.randn(n_features)
    w2 /= np.linalg.norm(w2)

    for i in range(n_samples):
        x = rng.randn(n_features)
        w = w1 if i < shift_at else w2
        # Label with noise
        logit = np.dot(w, x)
        prob = sigmoid(logit)
        y = int(rng.random() < prob)
        stream.append((x, y))

    return stream


# ---------- demo ----------
if __name__ == "__main__":
    rng = np.random.RandomState(42)
    n_features = 10
    n_samples = 2000
    shift_at = 1000

    # Generate stream with distribution shift
    stream = generate_stream_with_shift(n_samples, n_features, shift_at, rng)

    # Online learning
    model = OnlineLogisticRegression(n_features, learning_rate=0.5, decay_lr=False)
    metrics = OnlineMetrics(window_size=100)

    print("Step | Recent Acc | Cumulative Acc | Avg Loss")
    print("-" * 55)

    for i, (x, y) in enumerate(stream):
        # Predict BEFORE seeing the label
        pred = model.predict(x)

        # Observe the label and update
        loss = model.update(x, y)
        metrics.update(pred, y, loss)

        # Print periodically
        if (i + 1) % 200 == 0:
            marker = " <-- SHIFT" if i + 1 == shift_at else ""
            print(
                f"{i+1:4d} | {metrics.recent_accuracy:.4f}     "
                f"| {metrics.accuracy:.4f}          "
                f"| {metrics.average_loss:.4f}{marker}"
            )

    print(f"\nFinal accuracy: {metrics.accuracy:.4f}")
    print(f"Final recent accuracy: {metrics.recent_accuracy:.4f}")
    print(f"Model adapted to distribution shift: "
          f"{'Yes' if metrics.recent_accuracy > 0.6 else 'No'}")

Walkthrough

  1. Predict-then-update -- The online learning protocol: first make a prediction using the current model, then observe the true label, then update. This is the honest evaluation protocol -- the model never sees a label before predicting.

  2. Gradient computation -- For logistic regression with BCE loss, the gradient simplifies to (sigmoid(w.x) - y) * x. This is the same as the batch gradient, just computed on a single sample.
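
That simplification is easy to verify numerically. The standalone sketch below (redefining sigmoid and the loss for self-containment) checks the analytic gradient against central finite differences:

```python
import numpy as np

def sigmoid(z):
    return np.where(z >= 0, 1 / (1 + np.exp(-z)), np.exp(z) / (1 + np.exp(z)))

def bce_loss(w, b, x, y, eps=1e-12):
    p = float(sigmoid(np.dot(w, x) + b))
    return -(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

rng = np.random.default_rng(0)
w, x = rng.normal(size=3), rng.normal(size=3)
b, y = 0.2, 1

# Analytic gradient from the simplification: dL/dw = (p - y) * x
p = float(sigmoid(np.dot(w, x) + b))
grad_w = (p - y) * x

# Central finite differences as a numerical reference
h = 1e-6
num_grad = np.zeros(3)
for i in range(3):
    e = np.zeros(3)
    e[i] = h
    num_grad[i] = (bce_loss(w + e, b, x, y) - bce_loss(w - e, b, x, y)) / (2 * h)

print(np.max(np.abs(grad_w - num_grad)))  # small: analytic matches numeric
```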

  3. Learning rate decay -- Using lr = lr0 / sqrt(t) ensures convergence: the learning rate is large enough early on for fast adaptation but small enough later for stability. This gives the O(sqrt(T)) regret bound.

  4. L2 regularization -- Adding l2_reg * w to the gradient is equivalent to weight decay. It prevents the weights from growing unboundedly in the online setting, which is important because there is no fixed dataset to regularize against.

  5. Distribution shift -- When the data-generating process changes at step 1000, the model's recent accuracy temporarily drops. With a non-decaying learning rate, the model adapts to the new distribution. With a decaying rate, adaptation is slower.
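
The constant-vs-decaying trade-off shows up even on a scalar problem. In this assumed toy setup, SGD on squared loss tracks a mean that jumps mid-stream:

```python
import numpy as np

# Track a shifting mean with SGD on squared loss: constant vs. decaying lr.
rng = np.random.default_rng(1)
stream = np.concatenate([rng.normal(0.0, 0.1, 900), rng.normal(5.0, 0.1, 100)])

est_const, est_decay = 0.0, 0.0
for t, y in enumerate(stream, start=1):
    est_const -= 0.1 * (est_const - y)                 # constant lr
    est_decay -= (0.1 / np.sqrt(t)) * (est_decay - y)  # lr0 / sqrt(t)

print(f"constant-lr estimate after shift: {est_const:.2f}")  # tracks the new mean ~5
print(f"decaying-lr estimate after shift: {est_decay:.2f}")  # lags well behind
```

By the time the shift arrives, the decayed learning rate is so small that the estimate barely moves in 100 steps, while the constant-rate learner has fully adapted.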

  6. Metrics tracking -- The windowed recent accuracy provides a responsive signal, while cumulative accuracy shows long-term performance. The average loss connects to the regret analysis.
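
Hint 5 mentions an exponential moving average as an alternative to the sliding window used in OnlineMetrics. A minimal EMA tracker (a hypothetical alternative, not part of the solution) could look like:

```python
class EMAAccuracy:
    """Exponentially weighted running accuracy; alpha is the weight on the
    newest observation (hypothetical alternative to a sliding window)."""

    def __init__(self, alpha: float = 0.02) -> None:
        self.alpha = alpha
        self.value = None  # None until the first update

    def update(self, correct: bool) -> float:
        x = 1.0 if correct else 0.0
        if self.value is None:
            self.value = x
        else:
            self.value = self.alpha * x + (1 - self.alpha) * self.value
        return self.value

ema = EMAAccuracy(alpha=0.1)
for c in [True, True, False, True]:
    acc = ema.update(c)
print(f"EMA accuracy: {acc:.3f}")  # -> 0.910
```

Compared to a window, the EMA needs O(1) memory and no list maintenance, at the cost of a tuning parameter (alpha) instead of a window size.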

Complexity Analysis

  • Time per update: O(d) where d = number of features. One dot product + one vector addition.
  • Space: O(d) for the weight vector. No data storage needed.
  • Regret bound: O(sqrt(T)) for online gradient descent, which means average regret is O(1/sqrt(T)).
  • Comparison: Batch training requires O(n * d) time and O(n * d) space to store the dataset.

Interview Tips


Key discussion topics: (1) Online vs. batch learning: online handles streaming data, concept drift, and memory constraints. Batch typically achieves better final accuracy on stationary distributions. (2) Regret as the performance metric -- it compares to the best fixed model in hindsight, not the Bayes-optimal changing model. (3) Practical systems: Vowpal Wabbit, online A/B testing, bandits for ad selection. (4) Algorithms beyond SGD: AdaGrad (per-feature learning rates), FTRL (Follow The Regularized Leader, used at Google for ad CTR prediction). (5) The explore-exploit tradeoff: online learning naturally extends to multi-armed bandits and contextual bandits.
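
To make topic (4) concrete, here is a sketch of an AdaGrad-style per-feature learning rate applied to the same online logistic update. This is an illustration of the idea, not a reference implementation; the class name and hyperparameters are assumptions:

```python
import numpy as np

def sigmoid(z):
    return np.where(z >= 0, 1 / (1 + np.exp(-z)), np.exp(z) / (1 + np.exp(z)))

class AdaGradLogistic:
    """Online logistic regression with AdaGrad per-feature learning rates."""

    def __init__(self, n_features: int, lr: float = 0.5, eps: float = 1e-8) -> None:
        self.w = np.zeros(n_features)
        self.b = 0.0
        self.g2_w = np.zeros(n_features)  # accumulated squared gradients
        self.g2_b = 0.0
        self.lr, self.eps = lr, eps

    def update(self, x: np.ndarray, y: int) -> None:
        p = float(sigmoid(np.dot(self.w, x) + self.b))
        gw, gb = (p - y) * x, p - y
        self.g2_w += gw ** 2
        self.g2_b += gb ** 2
        # Rarely-updated features keep a larger effective step size
        self.w -= self.lr * gw / (np.sqrt(self.g2_w) + self.eps)
        self.b -= self.lr * gb / (np.sqrt(self.g2_b) + self.eps)

rng = np.random.default_rng(0)
model = AdaGradLogistic(n_features=2)
for _ in range(500):
    x = rng.normal(size=2)
    y = int(x[0] - x[1] > 0)  # true boundary has weights (+, -)
    model.update(x, y)
print(model.w)  # learned signs should match: w[0] > 0 > w[1]
```

The per-feature denominator is what makes AdaGrad attractive for sparse, high-dimensional streams like CTR features: frequent features get small steps, rare ones stay plastic.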

Quiz


Why is a non-decaying learning rate useful in online learning with distribution shift?

What is 'regret' in the context of online learning, and what regret bound does online gradient descent achieve?

Why is FTRL (Follow The Regularized Leader) preferred over plain SGD for online learning in ad click-through rate prediction?
