The implementation below has two bugs that cause vanishing gradients during backpropagation through time (BPTT). Find and fix them.
Signature: def rnn_step_buggy(x, h_prev, Wx, Wh, b)
import numpy as np

def rnn_step_buggy(x, h_prev, Wx, Wh, b):
    pre_act = x @ Wx.T + h_prev @ Wh.T + b  # input and recurrent contributions
    h = 1 / (1 + np.exp(-pre_act))          # BUG 1
    grad_scale = np.max(np.abs(h))
    h = h / (grad_scale + 1e-8)             # BUG 2
    return h
Bug 1: Wrong activation function. A vanilla RNN step uses tanh (output range [-1, 1], derivative equal to 1 at 0), not sigmoid (output range [0, 1], maximum derivative 0.25, four times smaller). The smaller per-step derivative compounds across time steps and accelerates vanishing gradients; see the sketch after Bug 2.
Bug 2: Normalizing the hidden state by its maximum absolute value is not part of a standard RNN step. It distorts the recurrence, and the extra data-dependent rescaling of h (and of the gradients that flow through it during BPTT) compounds across time steps and destabilizes training. RNNs do not normalize hidden states this way.
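To make the Bug 1 point concrete, here is a small sketch (values chosen purely for illustration, not part of the exercise) comparing the best-case per-step activation derivative of sigmoid and tanh and the factor each one contributes over T steps of BPTT, ignoring the weight matrices:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z = 0.0  # pre-activation where both derivatives are largest
d_sigmoid = sigmoid(z) * (1.0 - sigmoid(z))  # = 0.25
d_tanh = 1.0 - np.tanh(z) ** 2               # = 1.0

T = 20  # hypothetical number of time steps the gradient passes through
print("sigmoid factor after T steps:", d_sigmoid ** T)  # ~9.1e-13
print("tanh factor after T steps:   ", d_tanh ** T)     # 1.0

Even in this best case the sigmoid's per-step factor of 0.25 shrinks the gradient by roughly twelve orders of magnitude over 20 steps, while tanh's factor of 1 leaves it unchanged.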
Fixed version:
def rnn_step_buggy(x, h_prev, Wx, Wh, b):
    pre_act = x @ Wx.T + h_prev @ Wh.T + b
    h = np.tanh(pre_act)  # fix: tanh instead of sigmoid
    return h              # fix: no max-abs normalization
Implement the corrected rnn_step_buggy (same signature, correct behavior).
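A quick usage sketch of the corrected step, assuming a hypothetical batch of 4 with input dimension 3 and hidden dimension 5 (none of these values come from the exercise), to show the expected call pattern and shapes:

import numpy as np

rng = np.random.default_rng(0)
batch, d_in, d_hid = 4, 3, 5

Wx = rng.standard_normal((d_hid, d_in)) * 0.1   # input-to-hidden weights
Wh = rng.standard_normal((d_hid, d_hid)) * 0.1  # hidden-to-hidden weights
b = np.zeros(d_hid)

h = np.zeros((batch, d_hid))               # initial hidden state
for t in range(10):                        # unroll a short sequence
    x_t = rng.standard_normal((batch, d_in))
    h = rnn_step_buggy(x_t, h, Wx, Wh, b)  # corrected step defined above

print(h.shape)          # (4, 5)
print(np.abs(h).max())  # stays below 1 because of tanh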