
Vanilla RNN + Backprop Through Time

Backpropagation Through Time (BPTT) is the algorithm for training RNNs. Given a sequence of inputs and a loss at each step, compute gradients of the loss with respect to all parameters by unrolling the computation graph through time.

For a vanilla RNN with step:

h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b)
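
In numpy, one step of this recurrence is a single line. A minimal sketch (the rnn_step helper and its names are illustrative, not part of the required signature):

import numpy as np

def rnn_step(h_prev, x_t, W_h, W_x, b):
    # h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b)
    return np.tanh(W_h @ h_prev + W_x @ x_t + b)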

Given a sequence of inputs and integer class targets at each step, compute the forward pass, then backpropagate to obtain dW_h, dW_x, and db.

Signature: def rnn_bptt(xs, h0, W_h, W_x, b, targets, W_out, b_out)

  • xs: (T, input_size) — input sequence
  • h0: (hidden_size,) — initial hidden state
  • W_h: (hidden_size, hidden_size)
  • W_x: (hidden_size, input_size)
  • b: (hidden_size,)
  • targets: (T,) — integer class labels at each step
  • W_out: (num_classes, hidden_size) — output projection
  • b_out: (num_classes,) — output bias
  • Returns: (dW_h, dW_x, db) — gradients w.r.t. RNN weights, averaged over sequence

The output layer uses cross-entropy loss averaged over all T steps. The backward pass flows gradients from each step's loss back through the hidden states to earlier steps.
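
For concreteness, here is a minimal sketch of one possible implementation, assuming standard softmax cross-entropy on the logits W_out @ h_t + b_out with mean reduction over the T steps (an illustration, not the site's reference solution):

import numpy as np

def rnn_bptt(xs, h0, W_h, W_x, b, targets, W_out, b_out):
    T = xs.shape[0]
    hs = [h0]                       # hs[t+1] is the hidden state after step t
    logits = []

    # Forward pass: unroll the recurrence and cache what the backward pass needs.
    for t in range(T):
        h = np.tanh(W_h @ hs[-1] + W_x @ xs[t] + b)
        hs.append(h)
        logits.append(W_out @ h + b_out)

    dW_h = np.zeros_like(W_h)
    dW_x = np.zeros_like(W_x)
    db = np.zeros_like(b)
    dh_next = np.zeros_like(h0)     # gradient flowing back from step t+1

    # Backward pass: walk the sequence in reverse.
    for t in reversed(range(T)):
        # Softmax cross-entropy gradient; the 1/T implements the mean over steps.
        p = np.exp(logits[t] - logits[t].max())
        p /= p.sum()
        p[targets[t]] -= 1.0
        dlogits = p / T

        # Gradient into h_t: this step's loss plus gradient from future steps.
        dh = W_out.T @ dlogits + dh_next
        dz = dh * (1.0 - hs[t + 1] ** 2)   # backprop through tanh

        dW_h += np.outer(dz, hs[t])        # hs[t] is h_{t-1}
        dW_x += np.outer(dz, xs[t])
        db += dz
        dh_next = W_h.T @ dz               # send gradient to step t-1

    return dW_h, dW_x, db

Dividing each step's logit gradient by T makes the returned gradients those of the mean loss; dh_next is the term that carries gradient from later steps back through W_h, which is the "through time" part.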

Difficulty: Medium
Tags: Math
Language: Python (numpy)

Test Results

  • T=3 sequence, seed 42
  • all-zero inputs (db non-zero, dW_x zero)
  • T=1 single step (BPTT = regular backprop)
  • gradient matches central-difference numerical estimate (see the check sketched after this list)
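
The last test can be reproduced by hand. A sketch of a central-difference check on db, assuming the rnn_bptt sketched above; the rnn_loss helper, shapes, scales, and seed here are illustrative, not the harness's actual fixture:

import numpy as np

def rnn_loss(xs, h0, W_h, W_x, b, targets, W_out, b_out):
    # Forward-only mean cross-entropy, used as the scalar for finite differences.
    # (rnn_loss is an illustrative helper, not part of the problem's signature.)
    T, loss, h = xs.shape[0], 0.0, h0
    for t in range(T):
        h = np.tanh(W_h @ h + W_x @ xs[t] + b)
        z = W_out @ h + b_out
        z = z - z.max()                      # stabilize log-sum-exp
        loss += np.log(np.exp(z).sum()) - z[targets[t]]
    return loss / T

# Illustrative fixture; the harness's actual shapes and seed may differ.
rng = np.random.default_rng(42)
T, n_in, n_h, n_cls = 3, 4, 5, 6
xs = rng.standard_normal((T, n_in))
h0 = rng.standard_normal(n_h)
W_h = 0.1 * rng.standard_normal((n_h, n_h))
W_x = 0.1 * rng.standard_normal((n_h, n_in))
b = np.zeros(n_h)
W_out = 0.1 * rng.standard_normal((n_cls, n_h))
b_out = np.zeros(n_cls)
targets = rng.integers(0, n_cls, size=T)

dW_h, dW_x, db = rnn_bptt(xs, h0, W_h, W_x, b, targets, W_out, b_out)

# Central difference on each entry of b: (L(b+eps) - L(b-eps)) / (2*eps).
eps = 1e-5
db_num = np.zeros_like(b)
for i in range(b.size):
    bp, bm = b.copy(), b.copy()
    bp[i] += eps
    bm[i] -= eps
    db_num[i] = (rnn_loss(xs, h0, W_h, W_x, bp, targets, W_out, b_out)
                 - rnn_loss(xs, h0, W_h, W_x, bm, targets, W_out, b_out)) / (2 * eps)

assert np.allclose(db, db_num, atol=1e-6)

The same loop applied entrywise to W_h and W_x checks the other two returned gradients.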