LSTM Backprop Through Time

Difficulty: Hard

Backpropagating through a single LSTM cell yields gradients for the step's inputs and for all gate parameters. Given the forward-pass inputs, the weights and biases, and the upstream gradients (dh_next, dc_next), compute the gradients w.r.t. x, h_prev, c_prev, both weight matrices, and the bias.

Forward pass for a single LSTM step (H is the hidden size; the gates are stacked in the order i, f, g, o):

gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
i = sigmoid(gates[:H]),  f = sigmoid(gates[H:2H])
g = tanh(gates[2H:3H]),  o = sigmoid(gates[3H:])
c_next = f*c_prev + i*g
h_next = o * tanh(c_next)
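
The forward equations above can be sketched in numpy as follows. This is a minimal sketch, not a reference solution: it assumes unbatched 1-D inputs (x of shape (I,), h_prev and c_prev of shape (H,), W_ih of shape (4H, I), W_hh of shape (4H, H)), and the helper name lstm_forward is hypothetical.

```python
import numpy as np

def sigmoid(z):
    # Logistic function, applied elementwise.
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh):
    """One LSTM step (hypothetical helper; unbatched 1-D inputs assumed)."""
    H = h_prev.shape[0]
    # Stacked pre-activations, shape (4H,), in the order i, f, g, o.
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
    i = sigmoid(gates[:H])          # input gate
    f = sigmoid(gates[H:2 * H])     # forget gate
    g = np.tanh(gates[2 * H:3 * H]) # candidate cell update
    o = sigmoid(gates[3 * H:])      # output gate
    c_next = f * c_prev + i * g
    h_next = o * np.tanh(c_next)
    return h_next, c_next, (i, f, g, o)

# Illustrative inputs only; these are not the platform's test data.
rng = np.random.default_rng(0)
H, I = 2, 3
x = rng.standard_normal(I)
h_prev = np.zeros(H)
c_prev = np.zeros(H)
W_ih = rng.standard_normal((4 * H, I))
W_hh = rng.standard_normal((4 * H, H))
b_ih = np.zeros(4 * H)
b_hh = np.zeros(4 * H)

h_next, c_next, (i, f, g, o) = lstm_forward(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh)
print(h_next.shape, c_next.shape)
```

Returning the activated gates alongside h_next and c_next is a design choice for this sketch: the backward pass needs i, f, g, o, and caching them avoids recomputing the forward step.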

Backward pass (i_gate, f_gate, o_gate are the pre-activation slices of gates, so sigmoid'(i_gate) = i*(1 - i), and tanh'(c_next) = 1 - tanh(c_next)²):

do   = dh_next * tanh(c_next) * sigmoid'(o_gate)
dc   = dh_next * o * tanh'(c_next) + dc_next
di   = dc * g * sigmoid'(i_gate)
df   = dc * c_prev * sigmoid'(f_gate)
dg   = dc * i * (1 - g²)
dgates = concat([di, df, dg, do])
dW_ih += outer(dgates, x);  dW_hh += outer(dgates, h_prev)
dx     = W_ih.T @ dgates
dh_prev = W_hh.T @ dgates
dc_prev = dc * f
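
As a short worked derivation of where these formulas come from, the block below applies the chain rule to the forward equations. It assumes a scalar loss L (a symbol not in the problem statement) whose gradients w.r.t. h_next and c_next are the upstream dh_next and dc_next; the circled dot denotes elementwise multiplication.

```latex
\begin{aligned}
% c_next influences L through h_next and through the next time step
dc &= dh_{\text{next}} \odot o \odot \bigl(1 - \tanh^2(c_{\text{next}})\bigr) + dc_{\text{next}} \\
% activation derivatives expressed through the activated values
\sigma'(z) &= \sigma(z)\bigl(1 - \sigma(z)\bigr), \qquad \tanh'(z) = 1 - \tanh^2(z) \\
% c_next = f * c_prev + i * g, so each factor receives dc times its partner
di &= dc \odot g \odot i \odot (1 - i), \qquad
df = dc \odot c_{\text{prev}} \odot f \odot (1 - f), \qquad
dg = dc \odot i \odot (1 - g^2) \\
% h_next = o * tanh(c_next)
do &= dh_{\text{next}} \odot \tanh(c_{\text{next}}) \odot o \odot (1 - o)
\end{aligned}
```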

Signature: def lstm_bptt(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh, dh_next, dc_next)

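Returns: (dx, dh_prev, dc_prev, dW_ih, dW_hh, db), where db = dgates is the merged bias gradient of shape (4H,). Because b_ih and b_hh are simply added together in the forward pass, both receive this same gradient.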
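
One possible way to fill in this signature is sketched below. It is a minimal sketch rather than the platform's reference solution, and it makes two assumptions: inputs are unbatched 1-D arrays (x of shape (I,), h_prev, c_prev, dh_next, dc_next of shape (H,)), and the forward activations are recomputed inside the function because the signature does not pass them in. Weight gradients are returned for this single step instead of being accumulated.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_bptt(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh, dh_next, dc_next):
    H = h_prev.shape[0]

    # Recompute the forward pass (gate order i, f, g, o).
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
    i = sigmoid(gates[:H])
    f = sigmoid(gates[H:2 * H])
    g = np.tanh(gates[2 * H:3 * H])
    o = sigmoid(gates[3 * H:])
    c_next = f * c_prev + i * g
    tanh_c = np.tanh(c_next)

    # Gradients w.r.t. the pre-activation gates.
    do = dh_next * tanh_c * o * (1.0 - o)
    dc = dh_next * o * (1.0 - tanh_c ** 2) + dc_next
    di = dc * g * i * (1.0 - i)
    df = dc * c_prev * f * (1.0 - f)
    dg = dc * i * (1.0 - g ** 2)
    dgates = np.concatenate([di, df, dg, do])  # shape (4H,)

    # Parameter and input gradients for this step.
    dW_ih = np.outer(dgates, x)        # shape (4H, I)
    dW_hh = np.outer(dgates, h_prev)   # shape (4H, H)
    dx = W_ih.T @ dgates               # shape (I,)
    dh_prev = W_hh.T @ dgates          # shape (H,)
    dc_prev = dc * f                   # shape (H,)
    db = dgates                        # merged bias gradient
    return dx, dh_prev, dc_prev, dW_ih, dW_hh, db

# Illustrative call with random inputs (not the platform's test data).
rng = np.random.default_rng(0)
H, I = 2, 3
out = lstm_bptt(
    rng.standard_normal(I), np.zeros(H), np.zeros(H),
    rng.standard_normal((4 * H, I)), rng.standard_normal((4 * H, H)),
    np.zeros(4 * H), np.zeros(4 * H),
    rng.standard_normal(H), np.zeros(H),
)
print([a.shape for a in out])
```

When this step is called inside a loop over time, the caller would add dW_ih, dW_hh, and db into running totals, which is what the += in the backward equations refers to.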

Language: Python (numpy)

Test cases:

seed 0, H=2, I=3, zero h_prev/c_prev
seed 7, H=3, I=4, non-zero h_prev/c_prev/dc_next
seed 0, H=2, I=3, only dc_next non-zero
gradient matches central-difference numerical estimate
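
A central-difference check of the kind named in the last test case can be sketched as follows. To stay short, it verifies only dc_prev against its closed form; the same pattern applies to the other gradients. The scalar surrogate loss, the helper names forward and numerical_grad, and the random inputs are all assumptions made for illustration and are not the platform's test harness.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh):
    # Hypothetical helper: one LSTM step, gate order i, f, g, o.
    H = h_prev.shape[0]
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
    i = sigmoid(gates[:H])
    f = sigmoid(gates[H:2 * H])
    g = np.tanh(gates[2 * H:3 * H])
    o = sigmoid(gates[3 * H:])
    c_next = f * c_prev + i * g
    h_next = o * np.tanh(c_next)
    return h_next, c_next, f, o

def numerical_grad(fn, v, eps=1e-5):
    # Central difference: (fn(v + eps*e_k) - fn(v - eps*e_k)) / (2*eps) per coordinate.
    grad = np.zeros_like(v)
    for k in range(v.size):
        step = np.zeros_like(v)
        step[k] = eps
        grad[k] = (fn(v + step) - fn(v - step)) / (2.0 * eps)
    return grad

# Illustrative random inputs in float64.
rng = np.random.default_rng(7)
H, I = 3, 4
x = rng.standard_normal(I)
h_prev = rng.standard_normal(H)
c_prev = rng.standard_normal(H)
W_ih = rng.standard_normal((4 * H, I))
W_hh = rng.standard_normal((4 * H, H))
b_ih = rng.standard_normal(4 * H)
b_hh = rng.standard_normal(4 * H)
dh_next = rng.standard_normal(H)
dc_next = rng.standard_normal(H)

def loss(c):
    # Surrogate scalar whose gradients w.r.t. h_next and c_next are dh_next and dc_next.
    h_next, c_next, _, _ = forward(x, h_prev, c, W_ih, W_hh, b_ih, b_hh)
    return float(dh_next @ h_next + dc_next @ c_next)

# Analytic dc_prev from the backward equations.
h_next, c_next, f, o = forward(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh)
dc = dh_next * o * (1.0 - np.tanh(c_next) ** 2) + dc_next
dc_prev_analytic = dc * f

dc_prev_numeric = numerical_grad(loss, c_prev)
max_err = np.max(np.abs(dc_prev_analytic - dc_prev_numeric))
print("max abs error:", max_err)
```

The surrogate loss works because its gradient w.r.t. any input is exactly what the backward pass produces when fed dh_next and dc_next as upstream gradients.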