TorchedUp
LearnBetaProblemsSystem DesignSoonPremium

77. LSTM Backprop Through Time

Hard

Backpropagating through an LSTM cell yields gradients with respect to every gate's weights. Given the inputs to one forward step (the intermediate activations can be recomputed from them) and the upstream gradients (dh_next, dc_next), compute the gradients with respect to x, h_prev, c_prev, all weight matrices, and the bias.

The forward pass for a single LSTM step (recap):

gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
i = sigmoid(gates[:H]),        f = sigmoid(gates[H:2*H])
g = tanh(gates[2*H:3*H]),      o = sigmoid(gates[3*H:])
c_next = f * c_prev + i * g
h_next = o * tanh(c_next)
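The recap translates directly into a NumPy forward step. This is a minimal sketch assuming unbatched 1-D inputs (x of shape (D,), h_prev and c_prev of shape (H,)), which is consistent with the `W_ih @ x` form above; `sigmoid`, `lstm_forward_step` and the returned cache tuple are hypothetical helpers, not part of the problem's required signature.

```python
import numpy as np


def sigmoid(z):
    # Elementwise logistic function
    return 1.0 / (1.0 + np.exp(-z))


def lstm_forward_step(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh):
    """One unbatched LSTM step with gate order [i, f, g, o].

    Assumed shapes: x (D,), h_prev and c_prev (H,),
    W_ih (4H, D), W_hh (4H, H), b_ih and b_hh (4H,).
    """
    H = h_prev.shape[0]
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
    i = sigmoid(gates[:H])            # input gate
    f = sigmoid(gates[H:2 * H])       # forget gate
    g = np.tanh(gates[2 * H:3 * H])   # candidate cell values
    o = sigmoid(gates[3 * H:])        # output gate
    c_next = f * c_prev + i * g
    h_next = o * np.tanh(c_next)
    # Activations a backward pass could reuse instead of recomputing them
    cache = (i, f, g, o, c_next)
    return h_next, c_next, cache
```

With all weights and biases at zero, every sigmoid gate is 0.5 and g is 0, so c_next is simply 0.5 * c_prev; this is a quick sanity check for the slicing.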

The gate ordering in gates is [i, f, g, o]; match this convention when packing and unpacking gradients. Apply the chain rule through this graph; the Math section below summarises the gate-level gradients.
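Keeping the [i, f, g, o] convention in one place makes misaligned gradients less likely. The sketch below uses two hypothetical helpers, `split_gates` and `pack_gates`, which are not part of the problem's API.

```python
import numpy as np


def split_gates(v, H):
    # Unpack a (4H,) vector into its [i, f, g, o] blocks, in that order
    return v[:H], v[H:2 * H], v[2 * H:3 * H], v[3 * H:]


def pack_gates(d_i, d_f, d_g, d_o):
    # Repack per-gate gradients in the same [i, f, g, o] order as `gates`
    return np.concatenate([d_i, d_f, d_g, d_o])


v = np.arange(8.0)              # H = 2, so 4H = 8
i, f, g, o = split_gates(v, 2)
print(g)                        # [4. 5.]
print(np.array_equal(pack_gates(i, f, g, o), v))  # True
```

Because row block k of W_ih and W_hh feeds gate k, packing gradients in any other order would silently attach each gate's gradient to another gate's weight rows.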

Signature: def lstm_bptt(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh, dh_next, dc_next)

Returns: (dx, dh_prev, dc_prev, dW_ih, dW_hh, db), where db is the merged bias gradient of shape (4H,). Because b_ih and b_hh enter the pre-activation only through their sum, both share this same gradient.
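The shared bias gradient can be confirmed numerically. This sketch uses an arbitrary toy scalar loss (a weighted sum of tanh of the pre-activation, with hypothetical weights `w`, chosen only for illustration) and central differences: the gradients with respect to b_ih and b_hh coincide, and both equal the gradient with respect to the pre-activation itself.

```python
import numpy as np

rng = np.random.default_rng(0)
D, H = 3, 2
x, h_prev = rng.standard_normal(D), rng.standard_normal(H)
W_ih, W_hh = rng.standard_normal((4 * H, D)), rng.standard_normal((4 * H, H))
b_ih, b_hh = rng.standard_normal(4 * H), rng.standard_normal(4 * H)
w = rng.standard_normal(4 * H)  # hypothetical weights for a toy scalar loss


def loss(b_ih, b_hh):
    # Any scalar function of the pre-activation works; tanh is arbitrary here
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
    return np.sum(w * np.tanh(gates))


def numerical_grad(fn, b, eps=1e-6):
    # Central differences, one coordinate at a time
    grad = np.zeros_like(b)
    for k in range(b.size):
        e = np.zeros_like(b)
        e[k] = eps
        grad[k] = (fn(b + e) - fn(b - e)) / (2 * eps)
    return grad


db_ih = numerical_grad(lambda b: loss(b, b_hh), b_ih)
db_hh = numerical_grad(lambda b: loss(b_ih, b), b_hh)
gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
print(np.allclose(db_ih, db_hh))  # True
# Both also match dL/dgates = w * (1 - tanh^2(gates))
print(np.allclose(db_ih, w * (1 - np.tanh(gates) ** 2), atol=1e-6))  # True
```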

Math

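$$
\begin{aligned}
\delta_o &= \delta_h \odot \tanh(c_t) \odot \sigma'(o) \\
\delta_c &= \delta_h \odot o \odot \left(1 - \tanh^2(c_t)\right) + \delta_{c_{\text{next}}} \\
\delta_i &= \delta_c \odot g \odot \sigma'(i), \qquad \delta_f = \delta_c \odot c_{t-1} \odot \sigma'(f) \\
\delta_g &= \delta_c \odot i \odot \left(1 - g^2\right) \\
\frac{\partial L}{\partial W_{ih}} &= \delta_{\text{gates}} \otimes x, \qquad \frac{\partial L}{\partial x} = W_{ih}^\top \delta_{\text{gates}}
\end{aligned}
$$

Here δ_h = dh_next, δ_{c_next} = dc_next, c_t = c_next, c_{t-1} = c_prev, δ_gates = [δ_i; δ_f; δ_g; δ_o] is stacked in gate order, ⊗ is the outer product, and σ'(·) denotes the sigmoid derivative written in terms of the gate's output, e.g. σ'(o) = o ⊙ (1 − o).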
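A direct NumPy transcription of these equations might look like the sketch below. It assumes unbatched 1-D inputs, as in the forward recap, and recomputes the forward step because the signature receives only inputs. The returns not spelled out above follow from the same graph: dW_hh = δ_gates ⊗ h_prev, dh_prev = W_hhᵀ δ_gates, db = δ_gates, and dc_prev = δ_c ⊙ f (c_prev reaches c_next only through f * c_prev). This is an illustrative sketch, not the site's reference solution.

```python
import numpy as np


def sigmoid(z):
    # Elementwise logistic function
    return 1.0 / (1.0 + np.exp(-z))


def lstm_bptt(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh, dh_next, dc_next):
    H = h_prev.shape[0]
    # Recompute the forward step (gate order [i, f, g, o])
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
    i = sigmoid(gates[:H])
    f = sigmoid(gates[H:2 * H])
    g = np.tanh(gates[2 * H:3 * H])
    o = sigmoid(gates[3 * H:])
    c_next = f * c_prev + i * g
    tanh_c = np.tanh(c_next)

    # Output gate, then the total gradient flowing into c_next
    d_o = dh_next * tanh_c * o * (1 - o)
    dc = dh_next * o * (1 - tanh_c ** 2) + dc_next

    # Input, forget and candidate gates (sigmoid'/tanh' via their outputs)
    d_i = dc * g * i * (1 - i)
    d_f = dc * c_prev * f * (1 - f)
    d_g = dc * i * (1 - g ** 2)
    d_gates = np.concatenate([d_i, d_f, d_g, d_o])  # shape (4H,)

    dW_ih = np.outer(d_gates, x)       # (4H, D)
    dW_hh = np.outer(d_gates, h_prev)  # (4H, H)
    db = d_gates                       # shared by b_ih and b_hh
    dx = W_ih.T @ d_gates
    dh_prev = W_hh.T @ d_gates
    dc_prev = dc * f
    return dx, dh_prev, dc_prev, dW_ih, dW_hh, db
```

A central-difference gradient check on a small random instance is a cheap way to validate a sketch like this before relying on it.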


NumPy

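import numpy as np


def lstm_bptt(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh, dh_next, dc_next):
    # Return (dx, dh_prev, dc_prev, dW_ih, dW_hh, db)
    pass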
