Akshay’s Gradient
ML Coding · Advanced · 45 min

Implement a LoRA Layer

Implement Low-Rank Adaptation (LoRA) -- the most popular parameter-efficient fine-tuning method. LoRA freezes the original model weights and adds trainable low-rank decomposition matrices to each target layer.

Problem Statement

Implement a LoRALinear module that:

  1. Wraps an existing nn.Linear layer and freezes its weights
  2. Adds two small matrices A (d_in x r) and B (r x d_out) where r << min(d_in, d_out)
  3. Computes: output = frozen_linear(x) + (x @ A @ B) * scaling
  4. Only A and B are trainable; the original layer's parameters are frozen

Also implement a utility function to apply LoRA to all linear layers in a model.

Inputs: A pre-trained linear layer, rank r, scaling factor alpha.

Outputs: A LoRA-wrapped layer that behaves identically to the original at initialization (because B is initialized to zeros), but can be fine-tuned through A and B.

Key Concept

LoRA is based on the hypothesis that weight updates during fine-tuning have low intrinsic rank. Instead of updating the full weight matrix W (d_in x d_out), LoRA learns a low-rank update: W' = W + (alpha/r) * A @ B, where A is (d_in x r) and B is (r x d_out). This reduces trainable parameters from d_in * d_out to r * (d_in + d_out).
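To make the parameter saving concrete, here is a quick back-of-envelope check of the formula above (plain Python, using the d=4096, r=16 values from the diagram below):

```python
# Trainable-parameter comparison: full fine-tuning vs. LoRA.
d_in, d_out, r = 4096, 4096, 16

full = d_in * d_out        # update the entire weight matrix
lora = r * (d_in + d_out)  # low-rank factors A (d_in x r) and B (r x d_out)

print(f"Full:  {full:,}")           # 16,777,216
print(f"LoRA:  {lora:,}")           # 131,072
print(f"Ratio: {lora / full:.2%}")  # 0.78%
```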

LoRA: Low-Rank Weight Update
┌──────────────────────────────────────────────────────────────────┐
│                    LoRA: Low-Rank Adaptation                     │
│                                                                  │
│              ┌────────────┐                                      │
│     d_in     │ W (frozen) │                                      │
│   ┌─────────▶│ d_in×d_out │──────────────┐                       │
│   │          └────────────┘              │                       │
│   │                                      ▼                       │
│   │                                 ┌──────────┐                 │
│ Input x                             │    +     │──▶ Output       │
│   │                                 └──────────┘                 │
│   │                                      ▲                       │
│   │     r (rank, e.g. 8)                 │                       │
│   │   ┌────────┐   ┌─────────┐           │                       │
│   └──▶│   A    │──▶│    B    │─ × scaling┘                       │
│       │ d_in×r │   │ r×d_out │                                   │
│       │ (rand) │   │ (zeros) │                                   │
│       └────────┘   └─────────┘                                   │
│                                                                  │
│  Parameter comparison (d=4096, r=16):                            │
│  Full fine-tune: 4096 × 4096 = 16,777,216 params                 │
│  LoRA:           4096 × 16 + 16 × 4096 = 131,072 params (0.8%)   │
│                                                                  │
│  After training, MERGE for free inference:                       │
│  W' = W + (alpha/r) × A @ B                                      │
│  (single matrix, no extra compute at inference time)             │
└──────────────────────────────────────────────────────────────────┘
Warning

A common implementation mistake is initializing both A and B to random values. If B is not initialized to zeros, the LoRA contribution is nonzero at initialization, which shifts the model away from the pre-trained weights and can cause unstable training. The correct pattern: A uses Kaiming/random normal initialization, B is initialized to all zeros, ensuring the initial model output is unchanged.
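A minimal sanity check of this initialization pattern (a standalone sketch; the shapes here are illustrative, not from the problem):

```python
import torch

torch.manual_seed(0)
d_in, d_out, r = 32, 64, 4

A = torch.randn(d_in, r)   # random init: gradients flow from step one
B = torch.zeros(r, d_out)  # zero init: no contribution at initialization

x = torch.randn(8, d_in)
delta = x @ A @ B          # the LoRA branch
print(delta.abs().max())   # tensor(0.) -- model output is unchanged at init

B_bad = torch.randn(r, d_out)  # the common mistake: random B
print((x @ A @ B_bad).abs().max() > 0)  # tensor(True) -- shifts the model
```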

Hints

Info
  1. Freeze the original layer: set requires_grad = False for all its parameters.
  2. Initialize A with Kaiming uniform (or random normal) and B with zeros. This ensures the LoRA output is zero at initialization, so the model starts from the pre-trained weights.
  3. The scaling factor is alpha / r. This keeps the magnitude of the LoRA update consistent when you change the rank.
  4. During forward: output = original(x) + (x @ A @ B) * scaling.
  5. After training, you can merge: W_merged = W + scaling * A @ B for zero-overhead inference.

Solution

import torch
import torch.nn as nn
import math
from typing import Optional


class LoRALinear(nn.Module):
    """Linear layer with Low-Rank Adaptation."""

    def __init__(
        self,
        original_layer: nn.Linear,
        rank: int = 4,
        alpha: float = 1.0,
    ) -> None:
        super().__init__()
        self.original = original_layer
        self.rank = rank
        self.scaling = alpha / rank

        d_in = original_layer.in_features
        d_out = original_layer.out_features

        # Freeze original weights
        for param in self.original.parameters():
            param.requires_grad = False

        # LoRA matrices: A projects down, B projects up
        self.lora_A = nn.Parameter(torch.empty(d_in, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, d_out))

        # Initialize A with Kaiming uniform (same as nn.Linear default)
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        # B is initialized to zero so the initial output equals the original

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Original frozen forward pass
        original_out = self.original(x)
        # Low-rank adaptation: x @ A @ B * scaling
        lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
        return original_out + lora_out

    def merge_weights(self) -> nn.Linear:
        """Merge LoRA weights into the original layer for inference."""
        merged = nn.Linear(
            self.original.in_features,
            self.original.out_features,
            bias=self.original.bias is not None,
        )
        with torch.no_grad():
            merged.weight.copy_(
                self.original.weight + self.scaling * (self.lora_A @ self.lora_B).T
            )
            if self.original.bias is not None:
                merged.bias.copy_(self.original.bias)
        return merged


def apply_lora(
    model: nn.Module,
    rank: int = 4,
    alpha: float = 1.0,
    target_modules: Optional[list[str]] = None,
) -> nn.Module:
    """Replace target linear layers with LoRA-wrapped versions."""
    target_modules = target_modules or ["q_proj", "v_proj"]
    # Collect matching layer names first so the module tree is not
    # mutated while named_modules() is iterating over it.
    to_wrap = [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and any(target in name for target in target_modules)
    ]
    for name in to_wrap:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, LoRALinear(getattr(parent, child_name), rank, alpha))
    return model


# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    d_in, d_out, rank = 768, 768, 8

    # Create a "pre-trained" linear layer
    original = nn.Linear(d_in, d_out)
    lora_layer = LoRALinear(original, rank=rank, alpha=16.0)

    # Verify: at initialization, LoRA output equals original
    x = torch.randn(2, 10, d_in)
    with torch.no_grad():
        original_out = original(x)
        lora_out = lora_layer(x)
    print(f"Max diff at init: {(original_out - lora_out).abs().max():.2e}")  # ~0

    # Count parameters
    total_params = sum(p.numel() for p in lora_layer.parameters())
    trainable_params = sum(p.numel() for p in lora_layer.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params
    print(f"\nTotal params:     {total_params:,}")
    print(f"Trainable (LoRA): {trainable_params:,}")
    print(f"Frozen (original):{frozen_params:,}")
    print(f"Trainable ratio:  {trainable_params / total_params:.2%}")

    # Simulate training
    optimizer = torch.optim.Adam(
        [p for p in lora_layer.parameters() if p.requires_grad], lr=1e-3
    )
    target = torch.randn(2, 10, d_out)
    for step in range(100):
        out = lora_layer(x)
        loss = ((out - target) ** 2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"\nAfter training - loss: {loss.item():.4f}")

    # Merge weights for inference
    merged = lora_layer.merge_weights()
    with torch.no_grad():
        lora_out = lora_layer(x)
        merged_out = merged(x)
    print(f"Max diff after merge: {(lora_out - merged_out).abs().max():.2e}")  # ~0

Walkthrough

  1. Freezing -- All parameters of the original linear layer have requires_grad = False. Only the LoRA matrices A and B are trainable. This means the optimizer only stores states for r * (d_in + d_out) parameters instead of d_in * d_out.

  2. Initialization -- B is initialized to zeros, making the initial LoRA contribution zero. The model starts exactly at the pre-trained weights. A uses Kaiming initialization so that gradients have reasonable magnitude from the first step.

  3. Scaling -- The factor alpha / r normalizes the LoRA update. When you increase rank, each individual component contributes less. The alpha hyperparameter controls the overall magnitude of the adaptation (analogous to a learning rate multiplier).

  4. Forward pass -- The computation x @ A @ B is factored: x @ A reduces from d_in to r dimensions, then @ B projects back to d_out. This is cheaper than a full d_in x d_out multiply when r is small.
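The cost argument in step 4 can be checked with a rough FLOP count (counting each multiply-accumulate as two FLOPs; the batch and sequence sizes are illustrative):

```python
# Rough FLOP count for one forward pass: a hypothetical full
# d_in x d_out update matrix vs. the factored LoRA branch.
batch, seq, d_in, d_out, r = 4, 512, 4096, 4096, 16

full_flops = 2 * batch * seq * d_in * d_out            # x @ W_delta
lora_flops = 2 * batch * seq * (d_in * r + r * d_out)  # (x @ A) @ B

print(f"full: {full_flops:.3e}")
print(f"lora: {lora_flops:.3e}")
print(f"lora is {lora_flops / full_flops:.2%} of full")  # 0.78%
```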

  5. Merging -- After training, A @ B can be added to the original weight matrix, yielding a standard linear layer with no inference overhead. This is a key advantage of LoRA over adapter methods.

Complexity Analysis

  • Trainable parameters: r * (d_in + d_out) per layer vs. d_in * d_out for full fine-tuning.
  • For typical values (d=4096, r=16): 131K LoRA params vs. 16.7M full params per layer -- a 128x reduction.
  • Forward pass overhead: two small matrix multiplies, (batch, seq, d_in) @ (d_in, r) followed by (batch, seq, r) @ (r, d_out). For small r, this is negligible.
  • Memory savings: Adam's optimizer states (two extra fp32 moment tensors per parameter) are kept only for the LoRA params. For a 7B model, this cuts optimizer-state memory from ~56GB to roughly 2GB.
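The memory figures above can be reproduced with a back-of-envelope calculation (assuming Adam keeps two fp32 moment tensors per trainable parameter; the LoRA trainable count used here is an illustrative assumption, not a measured value):

```python
# Back-of-envelope Adam optimizer-state memory.
BYTES_PER_PARAM = 2 * 4  # first + second moment, 4 bytes each (fp32)

full_params = 7e9     # full fine-tune: every weight of a 7B model is trainable
lora_params = 0.25e9  # hypothetical LoRA trainable count (~3.5% of 7B)

full_gb = full_params * BYTES_PER_PARAM / 1e9
lora_gb = lora_params * BYTES_PER_PARAM / 1e9
print(f"Full fine-tune optimizer states: {full_gb:.0f} GB")  # 56 GB
print(f"LoRA optimizer states:           {lora_gb:.0f} GB")  # 2 GB
```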

Interview Tips

Interview Tip

Key points to nail:

  1. Why initialize B to zero -- it preserves pre-trained behavior at the start of fine-tuning.
  2. The scaling factor alpha/r -- explain why it is needed when varying rank.
  3. Mergeability -- LoRA can be absorbed into the base weights for zero-cost inference, unlike adapters.
  4. Where to apply LoRA -- typically the Q and V projections in attention layers (the original paper shows these are most effective).
  5. Comparison with other PEFT methods -- adapters add sequential bottleneck layers (slower inference); prefix tuning modifies the input (limited capacity).

Quiz

Why is the B matrix in LoRA initialized to zeros?

What is the purpose of the scaling factor alpha/r in LoRA?

What is the key advantage of LoRA over adapter-based fine-tuning methods?
