Algorithm: Online Learning with SGD
Implement online gradient descent -- learning from one sample at a time without storing the full dataset. Online learning is essential for systems that must adapt to changing data streams, like ad click prediction, fraud detection, and recommendation systems.
Problem Statement
Implement:
- OnlineLogisticRegression -- a logistic regression model that updates from one sample at a time
- OnlineSGDClassifier -- generalized online learning with support for different loss functions
- Track performance metrics on a streaming data source and demonstrate adaptation to distribution shift
Inputs: Streaming (feature_vector, label) pairs arriving one at a time.
Outputs: Predictions made before seeing each label, cumulative accuracy, and regret.
Online learning processes one sample at a time: predict, then observe the true label, then update the model. The key metric is regret: the difference between the model's cumulative loss and the loss of the best fixed model in hindsight. Online gradient descent with a learning rate on the order of 1/sqrt(t) achieves O(sqrt(T)) regret, which means the average per-step regret goes to zero as T grows.
┌──────────────────────────────────────────────────────────────────┐
│ Online Learning Loop with Distribution Shift │
│ │
│ ┌─────────────────── Data Stream ──────────────────────┐ │
│ │ x1,y1 x2,y2 ... x999,y999 │ x1000,y1000 ... │ │
│ │ ◄─ Distribution A ──────────► │ ◄── Distribution B│ │
│ └─────────────────────────────────┴────────────────────┘ │
│ │ │
│ ▼ │
│ For each (x, y) in stream: │
│ ┌────────────────┐ │
│ │ 1. PREDICT │ y_hat = sigmoid(w · x + b) │
│ │ (before y) │ pred = 1 if y_hat >= 0.5 else 0 │
│ └───────┬────────┘ │
│ ▼ │
│ ┌────────────────┐ │
│ │ 2. OBSERVE y │ loss = -y*log(p) - (1-y)*log(1-p) │
│ └───────┬────────┘ │
│ ▼ │
│ ┌────────────────┐ │
│ │ 3. UPDATE │ w -= lr * (p - y) * x │
│ │ (one step) │ b -= lr * (p - y) │
│ └───────┬────────┘ │
│ ▼ │
│ Accuracy over time: │
│ ▓▓░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓│▓░░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓ │
│ ◄── learning ──────►│◄── shift ──►◄── adapted ──► │
│ ▲ │
│ Distribution │
│ shift! │
└──────────────────────────────────────────────────────────────────┘
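The loop in the diagram reduces to a few lines of code. Here is a minimal sketch on synthetic data (the generating weights, learning rate, and stream length are illustrative, not part of the solution below):

```python
import numpy as np

rng = np.random.RandomState(0)
w_true = np.array([1.5, -2.0, 0.5])   # hypothetical data-generating weights

w, b, lr = np.zeros(3), 0.0, 0.1      # model state
correct = 0
for t in range(1, 501):
    x = rng.randn(3)
    y = int(rng.rand() < 1.0 / (1.0 + np.exp(-w_true @ x)))   # noisy label
    p = 1.0 / (1.0 + np.exp(-(w @ x + b)))   # 1. predict BEFORE seeing y
    correct += int((p >= 0.5) == y)
    grad = p - y                              # 2. observe y, compute gradient
    w -= lr * grad * x                        # 3. one SGD step
    b -= lr * grad

print(f"cumulative accuracy: {correct / 500:.3f}")
```

Even with the honest predict-first protocol, cumulative accuracy climbs well above chance within a few hundred samples.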
A critical implementation detail: the naive formula 1/(1+exp(-z)) overflows when z is a large negative number, because exp(-z) becomes enormous. The standard fix is the two-branch form np.where(z >= 0, 1/(1+exp(-z)), exp(z)/(1+exp(z))), which always returns the correct value. Note, however, that np.where evaluates both branches, so NumPy may still emit an overflow RuntimeWarning even though the overflowed branch is discarded; clip z, rewrite both branches in terms of exp(-|z|), or wrap the call in np.errstate to silence it.
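One warning-free variant (an alternative to the two-branch form above, not part of the solution below) expresses both branches through exp(-|z|), which always lies in (0, 1] and therefore can never overflow:

```python
import numpy as np

def sigmoid_stable(z: np.ndarray) -> np.ndarray:
    # exp(-|z|) is always in (0, 1], so neither branch can overflow
    e = np.exp(-np.abs(z))
    return np.where(z >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

z = np.array([-1000.0, -10.0, 0.0, 10.0, 1000.0])
with np.errstate(over="raise"):   # turn any overflow into a hard error
    out = sigmoid_stable(z)
print(out)   # saturates cleanly to 0 and 1 at the extremes
```

The errstate guard demonstrates that no overflow occurs for any input, including z = ±1000.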
Hints
- Logistic regression: P(y=1|x) = sigmoid(w.x + b); loss = binary cross-entropy.
- Gradient: dw = (sigmoid(w.x + b) - y) * x, db = sigmoid(w.x + b) - y.
- Update: w -= lr * dw, b -= lr * db.
- Use a decaying learning rate, lr = initial_lr / sqrt(t), for convergence guarantees.
- Track a running accuracy with an exponential moving average for monitoring.
- To demonstrate adaptation, shift the data distribution mid-stream and show the model adjusts.
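The hints mention an exponential-moving-average accuracy tracker; the solution below uses a fixed window instead. A minimal EMA variant might look like this (the class name and alpha value are illustrative):

```python
class EMAAccuracy:
    """Exponential moving average of 0/1 correctness (hypothetical helper)."""

    def __init__(self, alpha: float = 0.02) -> None:
        self.alpha = alpha
        self.value = 0.5   # neutral prior before any observations

    def update(self, predicted: int, actual: int) -> float:
        hit = float(predicted == actual)
        self.value += self.alpha * (hit - self.value)
        return self.value

ema = EMAAccuracy(alpha=0.1)
for p, y in [(1, 1), (0, 0), (1, 0), (1, 1)]:
    ema.update(p, y)
print(ema.value)
```

A larger alpha reacts faster to distribution shift; a smaller alpha gives a smoother but laggier signal.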
Solution
import numpy as np
from typing import List, Tuple

def sigmoid(z: np.ndarray) -> np.ndarray:
    """Numerically stable sigmoid."""
    return np.where(z >= 0, 1 / (1 + np.exp(-z)), np.exp(z) / (1 + np.exp(z)))
class OnlineLogisticRegression:
    """Online logistic regression with SGD updates."""

    def __init__(
        self,
        n_features: int,
        learning_rate: float = 0.1,
        l2_reg: float = 1e-4,
        decay_lr: bool = True,
    ) -> None:
        self.w = np.zeros(n_features)
        self.b = 0.0
        self.lr0 = learning_rate
        self.l2_reg = l2_reg
        self.decay_lr = decay_lr
        self.t = 0  # step counter

    @property
    def lr(self) -> float:
        if self.decay_lr and self.t > 0:
            return self.lr0 / np.sqrt(self.t)
        return self.lr0

    def predict_proba(self, x: np.ndarray) -> float:
        """Predict P(y=1|x)."""
        logit = np.dot(self.w, x) + self.b
        return float(sigmoid(logit))

    def predict(self, x: np.ndarray) -> int:
        """Predict class label."""
        return 1 if self.predict_proba(x) >= 0.5 else 0

    def update(self, x: np.ndarray, y: int) -> float:
        """Update weights from a single sample. Returns the loss."""
        self.t += 1
        p = self.predict_proba(x)
        # Binary cross-entropy loss
        eps = 1e-12
        loss = -(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
        # Gradient of BCE w.r.t. parameters
        error = p - y  # derivative of BCE w.r.t. logit
        grad_w = error * x + self.l2_reg * self.w
        grad_b = error
        # SGD update
        lr = self.lr
        self.w -= lr * grad_w
        self.b -= lr * grad_b
        return float(loss)

    def partial_fit(self, x: np.ndarray, y: int) -> float:
        """Alias for update (scikit-learn convention)."""
        return self.update(x, y)
class OnlineMetrics:
    """Track online learning performance metrics."""

    def __init__(self, window_size: int = 100) -> None:
        self.correct = 0
        self.total = 0
        self.cumulative_loss = 0.0
        self.window_size = window_size
        self.recent_correct: List[int] = []

    def update(self, predicted: int, actual: int, loss: float) -> None:
        self.total += 1
        is_correct = int(predicted == actual)
        self.correct += is_correct
        self.cumulative_loss += loss
        self.recent_correct.append(is_correct)
        if len(self.recent_correct) > self.window_size:
            self.recent_correct.pop(0)

    @property
    def accuracy(self) -> float:
        return self.correct / max(1, self.total)

    @property
    def recent_accuracy(self) -> float:
        if not self.recent_correct:
            return 0.0
        return sum(self.recent_correct) / len(self.recent_correct)

    @property
    def average_loss(self) -> float:
        return self.cumulative_loss / max(1, self.total)
def generate_stream_with_shift(
    n_samples: int,
    n_features: int,
    shift_at: int,
    rng: np.random.RandomState,
) -> List[Tuple[np.ndarray, int]]:
    """Generate a data stream with a distribution shift."""
    stream = []
    # Phase 1: linear boundary defined by w1
    w1 = rng.randn(n_features)
    w1 /= np.linalg.norm(w1)
    # Phase 2: different boundary defined by w2
    w2 = rng.randn(n_features)
    w2 /= np.linalg.norm(w2)
    for i in range(n_samples):
        x = rng.randn(n_features)
        w = w1 if i < shift_at else w2
        # Noisy label: y = 1 with probability sigmoid(w . x)
        logit = np.dot(w, x)
        prob = sigmoid(logit)
        y = int(rng.random() < prob)
        stream.append((x, y))
    return stream
# ---------- demo ----------
if __name__ == "__main__":
    rng = np.random.RandomState(42)
    n_features = 10
    n_samples = 2000
    shift_at = 1000

    # Generate a stream with a distribution shift
    stream = generate_stream_with_shift(n_samples, n_features, shift_at, rng)

    # Online learning
    model = OnlineLogisticRegression(n_features, learning_rate=0.5, decay_lr=False)
    metrics = OnlineMetrics(window_size=100)

    print("Step | Recent Acc | Cumulative Acc | Avg Loss")
    print("-" * 55)
    for i, (x, y) in enumerate(stream):
        # Predict BEFORE seeing the label
        pred = model.predict(x)
        # Observe the label and update
        loss = model.update(x, y)
        metrics.update(pred, y, loss)
        # Print periodically
        if (i + 1) % 200 == 0:
            marker = " <-- SHIFT" if i + 1 == shift_at else ""
            print(
                f"{i+1:4d} | {metrics.recent_accuracy:.4f} "
                f"| {metrics.accuracy:.4f} "
                f"| {metrics.average_loss:.4f}{marker}"
            )

    print(f"\nFinal accuracy: {metrics.accuracy:.4f}")
    print(f"Final recent accuracy: {metrics.recent_accuracy:.4f}")
    print(f"Model adapted to distribution shift: "
          f"{'Yes' if metrics.recent_accuracy > 0.6 else 'No'}")
Walkthrough
- Predict-then-update -- The online learning protocol: first make a prediction using the current model, then observe the true label, then update. This is the honest evaluation protocol -- the model never sees a label before predicting.
- Gradient computation -- For logistic regression with BCE loss, the gradient simplifies to (sigmoid(w.x + b) - y) * x. This is the same as the batch gradient, just computed on a single sample.
- Learning rate decay -- Using lr = lr0 / sqrt(t) ensures convergence: the learning rate is large enough early on for fast adaptation but small enough later for stability. This gives the O(sqrt(T)) regret bound.
- L2 regularization -- Adding l2_reg * w to the gradient is equivalent to weight decay. It prevents the weights from growing unboundedly in the online setting, which is important because there is no fixed dataset to regularize against.
- Distribution shift -- When the data-generating process changes at step 1000, the model's recent accuracy temporarily drops. With a non-decaying learning rate, the model adapts to the new distribution. With a decaying rate, adaptation is slower.
- Metrics tracking -- The windowed recent accuracy provides a responsive signal, while cumulative accuracy shows long-term performance. The average loss connects to the regret analysis.
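One way to sanity-check the claim that the BCE gradient simplifies to (p - y) * x is a finite-difference comparison. This is a quick sketch with arbitrary weights and a random input, not part of the solution:

```python
import numpy as np

def bce_loss(w, b, x, y, eps=1e-12):
    """Binary cross-entropy of a logistic model at (w, b) on one sample."""
    p = 1.0 / (1.0 + np.exp(-(w @ x + b)))
    return -(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

rng = np.random.RandomState(0)
w, b = rng.randn(4), 0.3
x, y = rng.randn(4), 1

p = 1.0 / (1.0 + np.exp(-(w @ x + b)))
analytic = (p - y) * x   # gradient claimed in the walkthrough

# Central finite differences, one coordinate at a time
h = 1e-6
numeric = np.array([
    (bce_loss(w + h * np.eye(4)[i], b, x, y)
     - bce_loss(w - h * np.eye(4)[i], b, x, y)) / (2 * h)
    for i in range(4)
])
print(np.max(np.abs(analytic - numeric)))   # close to zero
```

Agreement to several decimal places confirms that the compact (p - y) * x form is the true gradient, with no separate sigmoid-derivative term needed.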
Complexity Analysis
- Time per update: O(d) where d = number of features. One dot product + one vector addition.
- Space: O(d) for the weight vector. No data storage needed.
- Regret bound: O(sqrt(T)) for online gradient descent, which means average regret is O(1/sqrt(T)).
- Comparison: Batch training requires O(n * d) time per pass over the data and O(n * d) space to store the dataset.
Interview Tips
Key discussion topics: (1) Online vs. batch learning: online handles streaming data, concept drift, and memory constraints. Batch typically achieves better final accuracy on stationary distributions. (2) Regret as the performance metric -- it compares to the best fixed model in hindsight, not the Bayes-optimal changing model. (3) Practical systems: Vowpal Wabbit, online A/B testing, bandits for ad selection. (4) Algorithms beyond SGD: AdaGrad (per-feature learning rates), FTRL (Follow The Regularized Leader, used at Google for ad CTR prediction). (5) The explore-exploit tradeoff: online learning naturally extends to multi-armed bandits and contextual bandits.
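To make the AdaGrad discussion point concrete, here is a minimal per-feature-learning-rate sketch (the class name, constants, and toy data are illustrative; a production system would use a tuned implementation such as Vowpal Wabbit's):

```python
import numpy as np

class AdaGradLogistic:
    """Sketch: logistic regression with AdaGrad per-feature learning rates."""

    def __init__(self, n_features: int, lr: float = 0.5, eps: float = 1e-8) -> None:
        self.w = np.zeros(n_features)
        self.g2 = np.zeros(n_features)   # accumulated squared gradients
        self.lr, self.eps = lr, eps

    def update(self, x: np.ndarray, y: int) -> None:
        p = 1.0 / (1.0 + np.exp(-np.clip(self.w @ x, -30, 30)))
        g = (p - y) * x
        self.g2 += g * g
        # Rarely-updated features keep a large step size; frequent ones shrink
        self.w -= self.lr * g / (np.sqrt(self.g2) + self.eps)

rng = np.random.RandomState(1)
model = AdaGradLogistic(2)
hits = 0
for t in range(400):
    x = rng.randn(2)
    y = int(x[0] - x[1] > 0)              # noiseless linear concept
    hits += int((model.w @ x >= 0) == y)  # predict before updating
    model.update(x, y)
print(f"cumulative accuracy: {hits / 400:.3f}")
```

The key difference from plain SGD is the per-coordinate denominator sqrt(g2): features that fire rarely (common with sparse ad features) retain large effective learning rates, which is why AdaGrad-family methods dominate in CTR prediction.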
Quiz
Why is a non-decaying learning rate useful in online learning with distribution shift?
What is 'regret' in the context of online learning, and what regret bound does online gradient descent achieve?
Why is FTRL (Follow The Regularized Leader) preferred over plain SGD for online learning in ad click-through rate prediction?