Backpropagation through an LSTM cell yields gradients with respect to all gate weights. Given the forward-pass inputs and the upstream gradients (dh_next, dc_next), compute the gradients with respect to x, h_prev, c_prev, and all weight matrices and biases.
Forward pass for a single LSTM step (gate order i, f, g, o; H is the hidden size):
gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
i = sigmoid(gates[:H]), f = sigmoid(gates[H:2H])
g = tanh(gates[2H:3H]), o = sigmoid(gates[3H:])
c_next = f*c_prev + i*g
h_next = o * tanh(c_next)
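A minimal NumPy sketch of this forward step, assuming 1-D vectors (x of shape (D,), h_prev and c_prev of shape (H,), no batch dimension) and the gate order i, f, g, o shown above. The function name lstm_step_forward is hypothetical and is not part of the required signature:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step_forward(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh):
    """One LSTM step with gate order (i, f, g, o).

    Assumed shapes: x (D,), h_prev and c_prev (H,), W_ih (4H, D),
    W_hh (4H, H), b_ih and b_hh (4H,).
    """
    H = h_prev.shape[0]
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh  # pre-activations, shape (4H,)
    i = sigmoid(gates[:H])            # input gate
    f = sigmoid(gates[H:2 * H])       # forget gate
    g = np.tanh(gates[2 * H:3 * H])   # candidate cell value
    o = sigmoid(gates[3 * H:])        # output gate
    c_next = f * c_prev + i * g
    h_next = o * np.tanh(c_next)
    return h_next, c_next
```

As a sanity check, with all-zero weights and biases every sigmoid gate equals 0.5 and g = 0, so c_next = 0.5 * c_prev and h_next = 0.5 * tanh(0.5 * c_prev).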
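Backward pass (do, di, df, dg are gradients w.r.t. the gate pre-activations; for a sigmoid output s, sigmoid'(z) = s * (1 - s)):
do = dh_next * tanh(c_next) * o * (1 - o)
dc = dh_next * o * (1 - tanh(c_next)²) + dc_next
di = dc * g * i * (1 - i)
df = dc * c_prev * f * (1 - f)
dg = dc * i * (1 - g²)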
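These gate gradients rely on the identities sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)) and tanh'(z) = 1 - tanh(z)², which let each derivative be computed from the stored gate outputs (i, f, o, g) rather than from the pre-activations. A short numerical check of both identities, using central differences (the helper name central_diff is illustrative only):

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def central_diff(fn, z, eps=1e-6):
    # Numerical derivative of an elementwise function
    return (fn(z + eps) - fn(z - eps)) / (2.0 * eps)


z = np.linspace(-4.0, 4.0, 9)
s, t = sigmoid(z), np.tanh(z)

# Derivatives expressed through the outputs, as in the backward pass
sig_err = np.max(np.abs(s * (1.0 - s) - central_diff(sigmoid, z)))
tanh_err = np.max(np.abs((1.0 - t ** 2) - central_diff(np.tanh, z)))
print(sig_err, tanh_err)
```

Both maximum errors should be tiny, limited only by finite-difference truncation and floating-point rounding.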
dgates = concat([di, df, dg, do])
dW_ih = outer(dgates, x); dW_hh = outer(dgates, h_prev)   (accumulate with += when summing over time steps)
dx = W_ih.T @ dgates
dh_prev = W_hh.T @ dgates
dc_prev = dc * f
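The weight and input gradients above follow the standard rule for a linear map: if the pre-activations contain W @ v and the upstream gradient is dgates, then dW = outer(dgates, v) and dv = W.T @ dgates (the same rule gives dW_hh and dh_prev from W_hh @ h_prev). A small finite-difference sketch, using random stand-ins for W_ih, x, and dgates (the names u and loss are illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
D, H = 3, 2
W = rng.standard_normal((4 * H, D))   # stands in for W_ih
x = rng.standard_normal(D)
u = rng.standard_normal(4 * H)        # stands in for dgates (upstream gradient)


def loss(W_, x_):
    # Scalar surrogate loss whose gradient w.r.t. (W_ @ x_) is exactly u
    return u @ (W_ @ x_)


dW = np.outer(u, x)   # analytic dL/dW, shape (4H, D)
dx = W.T @ u          # analytic dL/dx, shape (D,)

eps = 1e-6
dW_num = np.zeros_like(W)
for r in range(W.shape[0]):
    for c in range(W.shape[1]):
        Wp = W.copy()
        Wp[r, c] += eps
        Wm = W.copy()
        Wm[r, c] -= eps
        dW_num[r, c] = (loss(Wp, x) - loss(Wm, x)) / (2 * eps)
dx_num = np.array([(loss(W, x + eps * e) - loss(W, x - eps * e)) / (2 * eps)
                   for e in np.eye(D)])

print(np.allclose(dW, dW_num), np.allclose(dx, dx_num))  # prints: True True
```

Because the surrogate loss is linear in both W and x, the central differences match the analytic gradients up to rounding error.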
Signature: def lstm_bptt(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh, dh_next, dc_next)
Returns: (dx, dh_prev, dc_prev, dW_ih, dW_hh, db), where db = dgates is the merged bias gradient of shape (4H,), shared by b_ih and b_hh because both enter the gate sum identically.
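A minimal NumPy sketch of lstm_bptt with this signature, assuming 1-D inputs (no batch dimension) and the gate order i, f, g, o. Since the signature receives no cached activations, the sketch recomputes the forward pass. The helpers forward_loss and numeric_grad are hypothetical and exist only for a finite-difference check: the surrogate loss dh_next · h_next + dc_next · c_next has exactly dh_next and dc_next as its direct upstream gradients.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_bptt(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh, dh_next, dc_next):
    H = h_prev.shape[0]
    # Recompute the forward pass (gate order i, f, g, o)
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
    i = sigmoid(gates[:H])
    f = sigmoid(gates[H:2 * H])
    g = np.tanh(gates[2 * H:3 * H])
    o = sigmoid(gates[3 * H:])
    c_next = f * c_prev + i * g
    tanh_c = np.tanh(c_next)

    # Gradients w.r.t. the gate pre-activations
    do = dh_next * tanh_c * o * (1.0 - o)
    dc = dh_next * o * (1.0 - tanh_c ** 2) + dc_next  # path via h_next plus direct path
    di = dc * g * i * (1.0 - i)
    df = dc * c_prev * f * (1.0 - f)
    dg = dc * i * (1.0 - g ** 2)
    dgates = np.concatenate([di, df, dg, do])

    # Linear-layer gradients
    dW_ih = np.outer(dgates, x)
    dW_hh = np.outer(dgates, h_prev)
    dx = W_ih.T @ dgates
    dh_prev = W_hh.T @ dgates
    dc_prev = dc * f
    db = dgates  # same gradient for b_ih and b_hh
    return dx, dh_prev, dc_prev, dW_ih, dW_hh, db


def forward_loss(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh, dh_next, dc_next):
    # Hypothetical scalar loss whose direct gradients w.r.t. h_next and c_next
    # are dh_next and dc_next
    H = h_prev.shape[0]
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
    i, f = sigmoid(gates[:H]), sigmoid(gates[H:2 * H])
    g, o = np.tanh(gates[2 * H:3 * H]), sigmoid(gates[3 * H:])
    c_next = f * c_prev + i * g
    h_next = o * np.tanh(c_next)
    return dh_next @ h_next + dc_next @ c_next


def numeric_grad(fn, arr, eps=1e-6):
    # Central differences; fn() must read arr, which is perturbed in place
    grad = np.zeros_like(arr)
    for idx in np.ndindex(arr.shape):
        old = arr[idx]
        arr[idx] = old + eps
        lp = fn()
        arr[idx] = old - eps
        lm = fn()
        arr[idx] = old
        grad[idx] = (lp - lm) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
D, H = 3, 2
args = {
    "x": rng.standard_normal(D), "h_prev": rng.standard_normal(H),
    "c_prev": rng.standard_normal(H), "W_ih": rng.standard_normal((4 * H, D)),
    "W_hh": rng.standard_normal((4 * H, H)), "b_ih": rng.standard_normal(4 * H),
    "b_hh": rng.standard_normal(4 * H), "dh_next": rng.standard_normal(H),
    "dc_next": rng.standard_normal(H),
}
names = ["x", "h_prev", "c_prev", "W_ih", "W_hh", "b_ih"]
grads = dict(zip(names, lstm_bptt(**args)))
grads["b_hh"] = grads["b_ih"]  # db serves both bias vectors


def loss():
    return forward_loss(**args)


for name, analytic in grads.items():
    print(name, np.allclose(analytic, numeric_grad(loss, args[name]), atol=1e-6))
```

Each printed line should report True if the analytic gradients agree with the finite-difference estimates. Recomputing the forward pass keeps the function self-contained; a full BPTT loop over a sequence would more likely cache i, f, g, o, and c_next from the forward sweep instead.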