TorchedUp

65. Vanilla RNN + Backprop Through Time

Medium

Backpropagation Through Time (BPTT) is the algorithm for training RNNs. Given a sequence of inputs and a loss at each step, it computes the gradients of the loss with respect to all parameters by unrolling the computation graph through time and applying the chain rule backward across the steps.

For a vanilla RNN with step:

h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b)
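As a minimal sketch, the step above can be written as a single NumPy function. The name `rnn_step`, the sizes, and the random initialization are illustrative assumptions, not part of the problem statement.

```python
import numpy as np

def rnn_step(h_prev, x_t, W_h, W_x, b):
    """One vanilla RNN step: h_t = tanh(W_h @ h_prev + W_x @ x_t + b)."""
    return np.tanh(W_h @ h_prev + W_x @ x_t + b)

# Illustrative sizes and weights (assumed, chosen only for the demo).
rng = np.random.default_rng(0)
hidden_size, input_size = 4, 3
W_h = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
W_x = rng.normal(scale=0.1, size=(hidden_size, input_size))
b = np.zeros(hidden_size)

h0 = np.zeros(hidden_size)        # initial hidden state
x1 = rng.normal(size=input_size)  # first input vector
h1 = rnn_step(h0, x1, W_h, W_x, b)
print(h1.shape)  # hidden state keeps shape (hidden_size,)
```

Because tanh squashes its input, every entry of the hidden state lies strictly between -1 and 1.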

Given a sequence of inputs and an integer class target at each step, run the forward pass, then backpropagate through time to get dW_h, dW_x, and db.

Signature: def rnn_bptt(xs, h0, W_h, W_x, b, targets, W_out, b_out)

  • xs: (T, input_size) — input sequence
  • h0: (hidden_size,) — initial hidden state
  • W_h: (hidden_size, hidden_size)
  • W_x: (hidden_size, input_size)
  • b: (hidden_size,)
  • targets: (T,) — integer class labels at each step
  • W_out: (num_classes, hidden_size) — output projection
  • b_out: (num_classes,) — output bias
  • Returns: (dW_h, dW_x, db) — gradients w.r.t. RNN weights, averaged over sequence

The output layer uses cross-entropy loss averaged over all T steps. The backward pass flows gradients from each step's loss back through the hidden states to earlier steps.
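The forward and backward passes described above can be sketched as follows. This is one possible implementation under stated assumptions, not the site's reference solution: it assumes the loss is the mean cross-entropy over the T steps (so each step's output gradient is scaled by 1/T), float inputs, and that only the RNN gradients are returned. The local names `hs`, `dh_next`, and `da` are illustrative.

```python
import numpy as np

def rnn_bptt(xs, h0, W_h, W_x, b, targets, W_out, b_out):
    T = xs.shape[0]

    # Forward pass: hs[0] is h0, hs[t + 1] is the hidden state after input t.
    hs = [h0]
    for t in range(T):
        hs.append(np.tanh(W_h @ hs[-1] + W_x @ xs[t] + b))

    dW_h = np.zeros_like(W_h)
    dW_x = np.zeros_like(W_x)
    db = np.zeros_like(b)
    dh_next = np.zeros_like(h0)  # gradient flowing back from step t + 1

    # Backward pass: walk the unrolled graph from the last step to the first.
    for t in reversed(range(T)):
        h_t, h_prev = hs[t + 1], hs[t]

        # Softmax over the output logits (shifted by the max for stability).
        logits = W_out @ h_t + b_out
        exp = np.exp(logits - logits.max())
        probs = exp / exp.sum()

        # Gradient of the mean cross-entropy w.r.t. the logits at step t.
        dlogits = probs.copy()
        dlogits[targets[t]] -= 1.0
        dlogits /= T

        # Hidden state receives gradient from this step's loss and from step t + 1.
        dh = W_out.T @ dlogits + dh_next
        da = dh * (1.0 - h_t ** 2)  # back through tanh

        dW_h += np.outer(da, h_prev)
        dW_x += np.outer(da, xs[t])
        db += da
        dh_next = W_h.T @ da  # pass gradient on to step t - 1

    return dW_h, dW_x, db
```

A finite-difference comparison against the forward loss is a cheap way to check a sketch like this on small random inputs.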

Math

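$$
\begin{aligned}
h_t &= \tanh(W_h h_{t-1} + W_x x_t + b) \\
\hat{y}_t &= \mathrm{softmax}(W_{out} h_t + b_{out}) \\
L &= -\frac{1}{T} \sum_{t=1}^{T} \log \hat{y}_t[c_t]
\end{aligned}
$$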
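Applying the chain rule to these definitions gives a backward recursion, sketched here for reference. The symbols $z_t$ (output logits), $e_{c_t}$ (one-hot vector for the target class $c_t$), and $\delta_t$ (gradient at the tanh pre-activation) are notation introduced for this sketch and do not appear in the problem statement.

$$
\begin{aligned}
\frac{\partial L}{\partial z_t} &= \frac{1}{T}\left(\hat{y}_t - e_{c_t}\right), \qquad z_t = W_{out} h_t + b_{out} \\
\frac{\partial L}{\partial h_t} &= W_{out}^\top \frac{\partial L}{\partial z_t} + W_h^\top \delta_{t+1}, \qquad \delta_{T+1} = 0 \\
\delta_t &= \frac{\partial L}{\partial h_t} \odot \left(1 - h_t^2\right) \\
\frac{\partial L}{\partial W_h} &= \sum_{t=1}^{T} \delta_t h_{t-1}^\top, \qquad
\frac{\partial L}{\partial W_x} = \sum_{t=1}^{T} \delta_t x_t^\top, \qquad
\frac{\partial L}{\partial b} = \sum_{t=1}^{T} \delta_t
\end{aligned}
$$

The factor $1 - h_t^2$ is the derivative of tanh evaluated at the pre-activation, and the term $W_h^\top \delta_{t+1}$ carries gradient from later steps back to step $t$.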


NumPy

import numpy as np

def rnn_bptt(xs, h0, W_h, W_x, b, targets, W_out, b_out):
    pass
