Backpropagation Through Time (BPTT) is the standard algorithm for training RNNs. Given a sequence of inputs and a loss at each step, it computes gradients of the loss with respect to all parameters by unrolling the computation graph through time and propagating gradients backward through the unrolled steps.
For a vanilla RNN with step:
h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b)
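As a quick sanity check of the step equation and its shapes, here is one recurrent step in NumPy (the toy sizes and variable names are illustrative, not part of the problem statement):

```python
import numpy as np

# Toy sizes for illustration only
hidden_size, input_size = 4, 3
rng = np.random.default_rng(0)
W_h = rng.standard_normal((hidden_size, hidden_size))
W_x = rng.standard_normal((hidden_size, input_size))
b = np.zeros(hidden_size)

h_prev = np.zeros(hidden_size)          # h_{t-1}
x_t = rng.standard_normal(input_size)   # x_t

# One recurrent step: h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b)
h_t = np.tanh(W_h @ h_prev + W_x @ x_t + b)
```

Note that `tanh` squashes the hidden state into (-1, 1); its derivative, 1 - h_t^2, is what the backward pass will multiply by at each step.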
Given a sequence of inputs and an integer class target at each step, run the forward pass, then backpropagate through time to obtain dW_h, dW_x, and db.
Signature: def rnn_bptt(xs, h0, W_h, W_x, b, targets, W_out, b_out)
Arguments:
- xs: (T, input_size) — input sequence
- h0: (hidden_size,) — initial hidden state
- W_h: (hidden_size, hidden_size)
- W_x: (hidden_size, input_size)
- b: (hidden_size,)
- targets: (T,) — integer class labels at each step
- W_out: (num_classes, hidden_size) — output projection
- b_out: (num_classes,) — output bias

Returns:
- (dW_h, dW_x, db) — gradients w.r.t. the RNN weights, averaged over the sequence

The output layer uses cross-entropy loss averaged over all T steps. The backward pass flows gradients from each step's loss back through the hidden states to earlier steps.
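One way the whole computation can be sketched: store every hidden state on the forward pass, then walk backward through time, combining each step's softmax cross-entropy gradient (divided by T, since the loss is averaged) with the gradient carried in from the following step. This is a minimal reference sketch under the shapes above, not the only valid implementation:

```python
import numpy as np

def rnn_bptt(xs, h0, W_h, W_x, b, targets, W_out, b_out):
    T = xs.shape[0]

    # Forward pass: keep all hidden states for use in backprop.
    hs = [h0]
    for t in range(T):
        hs.append(np.tanh(W_h @ hs[-1] + W_x @ xs[t] + b))

    dW_h = np.zeros_like(W_h)
    dW_x = np.zeros_like(W_x)
    db = np.zeros_like(b)
    dh_next = np.zeros_like(h0)  # gradient flowing back from step t+1

    # Backward pass through time.
    for t in reversed(range(T)):
        h = hs[t + 1]
        logits = W_out @ h + b_out

        # Softmax cross-entropy gradient w.r.t. logits: p - onehot(target),
        # scaled by 1/T because the loss is averaged over steps.
        p = np.exp(logits - logits.max())
        p /= p.sum()
        p[targets[t]] -= 1.0

        # Gradient w.r.t. h_t: this step's loss plus the carried-in gradient.
        dh = (W_out.T @ p) / T + dh_next

        # Backprop through tanh: dtanh(z)/dz = 1 - tanh(z)^2 = 1 - h^2.
        dz = (1.0 - h * h) * dh
        dW_h += np.outer(dz, hs[t])   # z = W_h @ h_{t-1} + ...
        dW_x += np.outer(dz, xs[t])
        db += dz
        dh_next = W_h.T @ dz          # pass gradient to step t-1

    return dW_h, dW_x, db
```

A common correctness check is to compare these gradients against central finite differences of the averaged cross-entropy loss; the two should agree to several decimal places.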