
Debug: RNN Vanishing Gradient (Medium)

The implementation below has two bugs that make gradients vanish during backpropagation through time (BPTT). Find and fix them.

Signature: def rnn_step_buggy(x, h_prev, Wx, Wh, b)

import numpy as np

def rnn_step_buggy(x, h_prev, Wx, Wh, b):
    pre_act = x @ Wx.T + h_prev @ Wh.T + b
    h = 1 / (1 + np.exp(-pre_act))   # BUG 1
    grad_scale = np.max(np.abs(h))
    h = h / (grad_scale + 1e-8)       # BUG 2
    return h

Bug 1: Wrong activation function. Vanilla RNNs use tanh (output range [-1, 1], derivative of 1 at a zero pre-activation), not sigmoid (output range [0, 1], maximum derivative of 0.25, i.e. 4x smaller per step), so sigmoid accelerates the vanishing of gradients across timesteps.
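
As a quick numerical check of that gradient claim (a standalone sketch, not part of the problem's starter code): at a zero pre-activation the sigmoid derivative is 0.25 while the tanh derivative is 1, and those per-step factors multiply across BPTT steps.

import numpy as np

z = 0.0
sig = 1 / (1 + np.exp(-z))
sigmoid_grad = sig * (1 - sig)      # 0.25 at z = 0
tanh_grad = 1 - np.tanh(z) ** 2     # 1.0 at z = 0

# Over T timesteps the per-step activation derivatives multiply.
T = 20
print(sigmoid_grad ** T)   # ~9.1e-13
print(tanh_grad ** T)      # 1.0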

Bug 2: Normalizing the hidden state by its maximum absolute value destroys the signal's scale (the largest unit is always pinned to 1) and makes gradients vanish even faster through BPTT. RNNs do not normalize hidden states this way.
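
A small illustration of what the normalization throws away (the hidden-state values below are hypothetical, chosen only for the demo): dividing by the maximum absolute value pins the largest unit to 1, so differently scaled states become indistinguishable.

import numpy as np

h_small = np.array([0.01, 0.02, 0.03])
h_large = np.array([0.30, 0.60, 0.90])

for h in (h_small, h_large):
    normed = h / (np.max(np.abs(h)) + 1e-8)
    print(normed)   # both print roughly [0.333, 0.667, 1.0]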

Fixed version:

def rnn_step_buggy(x, h_prev, Wx, Wh, b):
    pre_act = x @ Wx.T + h_prev @ Wh.T + b
    h = np.tanh(pre_act)   # fix: tanh not sigmoid
    return h               # fix: no normalization

Implement the corrected rnn_step_buggy (same signature, correct behavior).
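
One way to exercise the corrected step, assuming the fixed rnn_step_buggy above; the dimensions, seed, and number of timesteps here are illustrative only.

import numpy as np

rng = np.random.default_rng(0)
input_dim, hidden_dim, T = 4, 8, 10

Wx = rng.standard_normal((hidden_dim, input_dim))
Wh = rng.standard_normal((hidden_dim, hidden_dim))
b = np.zeros(hidden_dim)

h = np.zeros(hidden_dim)
for _ in range(T):
    x = rng.standard_normal(input_dim)
    h = rnn_step_buggy(x, h, Wx, Wh, b)   # corrected implementation above

print(h.shape)                 # (8,)
print(np.all(np.abs(h) < 1))   # True: tanh keeps values in (-1, 1)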


Test Results

- zero input/hidden state → tanh(0) = 0 (sigmoid plus normalization would give non-zero output)
- large positive input → tanh saturates near 1
- negative pre-activation → tanh gives a negative output (sigmoid cannot)
- random weights (seed=42) → tanh output range check
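
A sketch of how these checks might be written against the corrected function (the grader's actual tests are not shown; the shapes, tolerances, and the helper name run_checks are assumptions):

import numpy as np

def run_checks(step):
    # zero input and hidden state -> tanh(0) = 0
    h = step(np.zeros(3), np.zeros(2), np.zeros((2, 3)), np.zeros((2, 2)), np.zeros(2))
    assert np.allclose(h, 0.0)

    # large positive pre-activation -> saturates near +1
    h = step(np.full(3, 10.0), np.zeros(2), np.ones((2, 3)), np.zeros((2, 2)), np.zeros(2))
    assert np.allclose(h, 1.0, atol=1e-6)

    # negative pre-activation -> negative output (impossible with sigmoid)
    h = step(np.full(3, -10.0), np.zeros(2), np.ones((2, 3)), np.zeros((2, 2)), np.zeros(2))
    assert np.all(h < 0)

    # random weights: output stays strictly inside (-1, 1)
    rng = np.random.default_rng(42)
    h = step(rng.standard_normal(3), rng.standard_normal(2),
             rng.standard_normal((2, 3)), rng.standard_normal((2, 2)),
             rng.standard_normal(2))
    assert np.all(np.abs(h) < 1)

run_checks(rnn_step_buggy)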