Implement a LoRA Layer
Implement Low-Rank Adaptation (LoRA) -- the most popular parameter-efficient fine-tuning method. LoRA freezes the original model weights and adds trainable low-rank decomposition matrices to each target layer.
Problem Statement
Implement a LoRALinear module that:
- Wraps an existing `nn.Linear` layer and freezes its weights
- Adds two small matrices `A` (d_in x r) and `B` (r x d_out), where `r << min(d_in, d_out)`
- Computes: `output = frozen_linear(x) + (x @ A @ B) * scaling`
- Only `A` and `B` are trainable; the original layer's parameters are frozen
Also implement a utility function to apply LoRA to all linear layers in a model.
Inputs: A pre-trained linear layer, rank r, scaling factor alpha.
Outputs: A LoRA-wrapped layer that behaves identically to the original at initialization (because B is initialized to zeros), but can be fine-tuned through A and B.
LoRA is based on the hypothesis that weight updates during fine-tuning have low intrinsic rank. Instead of updating the full weight matrix W (d_in x d_out), LoRA learns a low-rank update: W' = W + (alpha/r) * A @ B, where A is (d_in x r) and B is (r x d_out). This reduces trainable parameters from d_in * d_out to r * (d_in + d_out).
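The parameter arithmetic can be checked directly. A quick sketch, using the shapes above:

```python
d_in, d_out, r = 4096, 4096, 16

full = d_in * d_out        # full fine-tuning updates every entry of W
lora = r * (d_in + d_out)  # LoRA trains only A (d_in x r) and B (r x d_out)

print(full)         # 16777216
print(lora)         # 131072
print(full // lora) # 128 -> 128x fewer trainable parameters
```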
┌──────────────────────────────────────────────────────────────────┐
│ LoRA: Low-Rank Adaptation │
│ │
│ d_out │
│ ┌───────────┐ │
│ d_in │ │ │
│ ┌─────────▶│ W (frozen)│─────────────┐ │
│ │ │ d_in×d_out│ │ │
│ │ └───────────┘ │ │
│ │ ▼ │
│ │ ┌──────────┐ │
│ Input x │ + │──▶ Output │
│ │ └──────────┘ │
│ │ ▲ │
│ │ r (rank, e.g. 8) │ │
│ │ ┌─────┐ ┌─────┐ │ │
│ └──▶│ A │───▶│ B │─── × scaling ┘ │
│ │d_in×r│ │r×d_out│ │
│ │(rand)│ │(zeros)│ │
│ └─────┘ └─────┘ │
│ │
│ Parameter comparison (d=4096, r=16): │
│ Full fine-tune: 4096 × 4096 = 16,777,216 params │
│ LoRA: 4096 × 16 + 16 × 4096 = 131,072 params (0.8%) │
│ │
│ After training, MERGE for free inference: │
│ W' = W + (alpha/r) × A @ B │
│ (single matrix, no extra compute at inference time) │
└──────────────────────────────────────────────────────────────────┘
A common implementation mistake is initializing both A and B to random values. If B is not initialized to zeros, the LoRA contribution is nonzero at initialization, which shifts the model away from the pre-trained weights and can cause unstable training. The correct pattern: A uses Kaiming/random normal initialization, B is initialized to all zeros, ensuring the initial model output is unchanged.
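A minimal sketch of why the zero init matters, with small illustrative shapes (these tensors are stand-ins, not part of the solution below):

```python
import torch

torch.manual_seed(0)
d_in, d_out, r = 32, 32, 4
x = torch.randn(8, d_in)
W = torch.randn(d_out, d_in)  # stand-in for frozen pre-trained weights

A = torch.randn(d_in, r) * 0.02

# Correct: B starts at zero, so the LoRA term vanishes at initialization
B_good = torch.zeros(r, d_out)
out_good = x @ W.T + x @ A @ B_good
assert torch.allclose(out_good, x @ W.T)  # identical to pre-trained output

# Mistake: random B makes the model differ from the pre-trained one at step 0
B_bad = torch.randn(r, d_out) * 0.02
out_bad = x @ W.T + x @ A @ B_bad
print((out_bad - x @ W.T).abs().max())  # nonzero shift away from pre-trained weights
```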
Hints
- Freeze the original layer: set `requires_grad = False` for all its parameters.
- Initialize `A` with Kaiming uniform (or random normal) and `B` with zeros. This ensures the LoRA output is zero at initialization, so the model starts from the pre-trained weights.
- The scaling factor is `alpha / r`. This keeps the magnitude of the LoRA update consistent when you change the rank.
- During forward: `output = original(x) + (x @ A @ B) * scaling`.
- After training, you can merge: `W_merged = W + scaling * A @ B` for zero-overhead inference.
Solution
import torch
import torch.nn as nn
import math
from typing import Optional


class LoRALinear(nn.Module):
    """Linear layer with Low-Rank Adaptation."""

    def __init__(
        self,
        original_layer: nn.Linear,
        rank: int = 4,
        alpha: float = 1.0,
    ) -> None:
        super().__init__()
        self.original = original_layer
        self.rank = rank
        self.scaling = alpha / rank
        d_in = original_layer.in_features
        d_out = original_layer.out_features

        # Freeze original weights
        for param in self.original.parameters():
            param.requires_grad = False

        # LoRA matrices: A projects down, B projects up
        self.lora_A = nn.Parameter(torch.empty(d_in, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, d_out))
        # Initialize A with Kaiming uniform (same as nn.Linear default)
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        # B is initialized to zero so the initial output equals the original

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Original frozen forward pass
        original_out = self.original(x)
        # Low-rank adaptation: x @ A @ B * scaling
        lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
        return original_out + lora_out

    def merge_weights(self) -> nn.Linear:
        """Merge LoRA weights into the original layer for inference."""
        merged = nn.Linear(
            self.original.in_features,
            self.original.out_features,
            bias=self.original.bias is not None,
        )
        with torch.no_grad():
            merged.weight.copy_(
                self.original.weight + self.scaling * (self.lora_A @ self.lora_B).T
            )
            if self.original.bias is not None:
                merged.bias.copy_(self.original.bias)
        return merged
def apply_lora(
    model: nn.Module,
    rank: int = 4,
    alpha: float = 1.0,
    target_modules: Optional[list] = None,
) -> nn.Module:
    """Replace target linear layers with LoRA-wrapped versions."""
    target_modules = target_modules or ["q_proj", "v_proj"]
    # Collect matches first: mutating the module tree while iterating
    # named_modules() is fragile, and the inner frozen layer of a fresh
    # wrapper (named e.g. "q_proj.original") would itself match a target.
    matches = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and any(target in name for target in target_modules)
    ]
    for name, module in matches:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, LoRALinear(module, rank, alpha))
    return model
# ---------- demo ----------
if __name__ == "__main__":
    torch.manual_seed(42)
    d_in, d_out, rank = 768, 768, 8

    # Create a "pre-trained" linear layer
    original = nn.Linear(d_in, d_out)
    lora_layer = LoRALinear(original, rank=rank, alpha=16.0)

    # Verify: at initialization, LoRA output equals original
    x = torch.randn(2, 10, d_in)
    with torch.no_grad():
        original_out = original(x)
        lora_out = lora_layer(x)
    print(f"Max diff at init: {(original_out - lora_out).abs().max():.2e}")  # ~0

    # Count parameters
    total_params = sum(p.numel() for p in lora_layer.parameters())
    trainable_params = sum(p.numel() for p in lora_layer.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params
    print(f"\nTotal params:      {total_params:,}")
    print(f"Trainable (LoRA):  {trainable_params:,}")
    print(f"Frozen (original): {frozen_params:,}")
    print(f"Trainable ratio:   {trainable_params / total_params:.2%}")

    # Simulate training
    optimizer = torch.optim.Adam(
        [p for p in lora_layer.parameters() if p.requires_grad], lr=1e-3
    )
    target = torch.randn(2, 10, d_out)
    for step in range(100):
        out = lora_layer(x)
        loss = ((out - target) ** 2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"\nAfter training - loss: {loss.item():.4f}")

    # Merge weights for inference
    merged = lora_layer.merge_weights()
    with torch.no_grad():
        lora_out = lora_layer(x)
        merged_out = merged(x)
    print(f"Max diff after merge: {(lora_out - merged_out).abs().max():.2e}")  # ~0
Walkthrough
- Freezing -- All parameters of the original linear layer have `requires_grad = False`. Only the LoRA matrices `A` and `B` are trainable. This means the optimizer only stores states for `r * (d_in + d_out)` parameters instead of `d_in * d_out`.
- Initialization -- `B` is initialized to zeros, making the initial LoRA contribution zero. The model starts exactly at the pre-trained weights. `A` uses Kaiming initialization so that gradients have reasonable magnitude from the first step.
- Scaling -- The factor `alpha / r` normalizes the LoRA update. When you increase rank, each individual component contributes less. The `alpha` hyperparameter controls the overall magnitude of the adaptation (analogous to a learning rate multiplier).
- Forward pass -- The computation `x @ A @ B` is factored: `x @ A` reduces from `d_in` to `r` dimensions, then `@ B` projects back to `d_out`. This is cheaper than a full `d_in x d_out` multiply when `r` is small.
- Merging -- After training, `A @ B` can be added to the original weight matrix, yielding a standard linear layer with no inference overhead. This is a key advantage of LoRA over adapter methods.
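One practical consequence of mergeability (a sketch with illustrative shapes, not part of the solution above): because the merge is a plain addition, it can also be undone by subtracting the same delta, which is how one base model can serve many adapters.

```python
import torch

torch.manual_seed(0)
d, r, scaling = 64, 8, 2.0
W = torch.randn(d, d)        # base weight, (out, in) layout as in nn.Linear
A = torch.randn(d, r) * 0.02
B = torch.randn(r, d) * 0.02

delta = scaling * (A @ B).T  # transpose to match nn.Linear's weight layout
W_merged = W + delta         # merge: zero-overhead inference
W_restored = W_merged - delta  # unmerge: recover the base weights

print(torch.allclose(W_restored, W, atol=1e-6))  # True
```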
Complexity Analysis
- Trainable parameters: `r * (d_in + d_out)` per layer vs. `d_in * d_out` for full fine-tuning.
- For typical values (d=4096, r=16): 131K LoRA params vs. 16.7M full params per layer -- a 128x reduction.
- Forward pass overhead: two small matrix multiplies, `(B, T, d_in) @ (d_in, r)` and `(B, T, r) @ (r, d_out)`. For small `r`, this is negligible.
- Memory savings: optimizer states (Adam stores 2 extra copies of each parameter) are kept only for the LoRA params. For a 7B model, this reduces GPU memory from ~56GB to ~2GB.
Interview Tips
Key points to nail: (1) Why initialize B to zero -- preserves pre-trained behavior at the start of fine-tuning. (2) The scaling factor alpha/r -- explain why it is needed when varying rank. (3) Mergeability -- LoRA can be absorbed into the base weights for zero-cost inference, unlike adapters. (4) Where to apply LoRA -- typically Q and V projections in attention layers (the original paper shows these are most effective). (5) Comparison with other PEFT methods: adapters add sequential bottleneck layers (slower inference), prefix tuning modifies the input (limited capacity).
Quiz — 3 Questions
Why is the B matrix in LoRA initialized to zeros?
What is the purpose of the scaling factor alpha/r in LoRA?
What is the key advantage of LoRA over adapter-based fine-tuning methods?