LoRA Backward Pass (Hard)

Given the LoRA branch y_lora = (alpha/r) * x @ A.T @ B.T, compute the gradients of the loss w.r.t. A and B (the only trainable matrices — W is frozen).

Signature: def lora_backward(x: np.ndarray, A: np.ndarray, B: np.ndarray, dL_dy: np.ndarray, alpha: float, r: int) -> tuple

Shapes:

  • x: (batch, in)
  • A: (r, in)
  • B: (out, r)
  • dL_dy: (batch, out) — upstream gradient w.r.t. the LoRA output

Returns: (dA, dB) with shapes (r, in) and (out, r).
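
To make the shapes concrete, the forward LoRA branch can be sketched as below; the sizes (batch=4, in=8, r=2, out=6) and alpha=16.0 are illustrative values chosen here, not part of the problem.

import numpy as np

# Illustrative sizes; any consistent shapes work.
batch, d_in, r, d_out = 4, 8, 2, 6
alpha = 16.0

rng = np.random.default_rng(0)
x = rng.standard_normal((batch, d_in))   # (batch, in)
A = rng.standard_normal((r, d_in))       # (r, in)
B = rng.standard_normal((d_out, r))      # (out, r)

# Forward LoRA branch from the problem statement.
y_lora = (alpha / r) * x @ A.T @ B.T
print(y_lora.shape)                      # (4, 6) == (batch, out)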

Hint: let h = x @ A.T (shape (batch, r)), so y_lora = (alpha/r) * h @ B.T. Applying the chain rule gives (see the sketch after this list):

  • dB = (alpha/r) * dL_dy.T @ h
  • dA = (alpha/r) * (dL_dy @ B).T @ x
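
The hint maps directly onto numpy. A minimal sketch of the required function, under the signature above (one possible implementation, not necessarily the reference solution):

import numpy as np

def lora_backward(x: np.ndarray, A: np.ndarray, B: np.ndarray,
                  dL_dy: np.ndarray, alpha: float, r: int) -> tuple:
    # Gradients of the loss w.r.t. A and B for the LoRA branch
    # y_lora = (alpha / r) * x @ A.T @ B.T  (W is frozen, so no dW).
    scale = alpha / r
    h = x @ A.T                   # (batch, r), intermediate activation
    dB = scale * dL_dy.T @ h      # (out, batch) @ (batch, r) -> (out, r)
    dL_dh = scale * dL_dy @ B     # (batch, out) @ (out, r)  -> (batch, r)
    dA = dL_dh.T @ x              # (r, batch) @ (batch, in) -> (r, in)
    return dA, dB

Note that the scale alpha/r appears exactly once in each gradient, mirroring the single scaling in the forward pass.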

Tags: Math

Language: Python (numpy)

Test Cases

  • identity B, simple
  • r=2, batch=2
  • scaling alpha/r
  • gradient matches central-difference numerical estimate
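
The last test case presumably compares the analytic gradients to a finite-difference estimate. A minimal central-difference check along those lines (the helper name, the scalar loss L = sum(y_lora), the step eps, and the tolerance are assumptions made here, not taken from the test):

import numpy as np

def central_difference_check(x, A, B, alpha, r, eps=1e-6):
    # Use L = sum(y_lora) so the upstream gradient dL_dy is all ones.
    def loss(A_, B_):
        return float(np.sum((alpha / r) * x @ A_.T @ B_.T))

    dL_dy = np.ones((x.shape[0], B.shape[0]))
    dA, dB = lora_backward(x, A, B, dL_dy, alpha, r)

    # Central-difference estimate of dL/dA, one entry at a time.
    dA_num = np.zeros_like(A)
    for idx in np.ndindex(*A.shape):
        A_plus, A_minus = A.copy(), A.copy()
        A_plus[idx] += eps
        A_minus[idx] -= eps
        dA_num[idx] = (loss(A_plus, B) - loss(A_minus, B)) / (2 * eps)

    return np.allclose(dA, dA_num, atol=1e-4)

The same perturbation loop, applied to B instead of A, checks dB.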