TorchedUp

45. GRU Cell

Medium

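The Gated Recurrent Unit (GRU) is a streamlined alternative to the LSTM. It uses two gates (reset and update) where the LSTM uses three (input, forget, output), and it merges the cell state and hidden state into a single hidden state. Counting the candidate, a GRU cell therefore has three weight blocks instead of the LSTM's four, which reduces parameters while retaining the ability to capture long-range dependencies.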
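To make the parameter savings concrete, here is a small count assuming the PyTorch-style cell layout used in this problem: each gate block has an input-to-hidden weight (H x input_size), a hidden-to-hidden weight (H x H), and two biases of length H. The helper name `rnn_cell_param_count` and the sizes are illustrative assumptions.

```python
def rnn_cell_param_count(input_size: int, hidden_size: int, num_blocks: int) -> int:
    """Hypothetical helper: parameters in a PyTorch-style cell with `num_blocks` gate blocks.

    Each block has an input-to-hidden weight (H x I), a hidden-to-hidden
    weight (H x H), and two biases of length H (b_ih and b_hh).
    """
    per_block = hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size
    return num_blocks * per_block


input_size, hidden_size = 10, 20  # illustrative sizes
gru_params = rnn_cell_param_count(input_size, hidden_size, num_blocks=3)   # [r, z, n]
lstm_params = rnn_cell_param_count(input_size, hidden_size, num_blocks=4)  # input, forget, candidate, output
print(gru_params, lstm_params)  # 1920 2560
```

Whatever the sizes, the GRU cell ends up with exactly 3/4 of the LSTM cell's parameters under this layout.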

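Given input x and previous hidden state h_prev, where [a, b] denotes concatenation and * is the elementwise product, the GRU conceptually computes: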

r = sigmoid(W_r @ [h_prev, x] + b_r)   # reset gate
z = sigmoid(W_z @ [h_prev, x] + b_z)   # update gate
n = tanh(W_n @ [r * h_prev, x] + b_n)  # candidate hidden
h = (1 - z) * n + z * h_prev           # new hidden state
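The four conceptual lines can be transcribed almost literally into NumPy. This is a sketch of the conceptual form only: the separate matrices `W_r`, `W_z`, `W_n` of shape (H, H + input_size) and the function name `gru_cell_concat` are illustrative assumptions, and the candidate applies r to h_prev before the matrix product exactly as written, which differs from the PyTorch convention this problem asks for.

```python
import numpy as np


def sigmoid(a: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-a))


def gru_cell_concat(x, h_prev, W_r, W_z, W_n, b_r, b_z, b_n):
    """Hypothetical conceptual GRU step; each W_* has shape (H, H + input_size)."""
    hx = np.concatenate([h_prev, x])                             # [h_prev, x]
    r = sigmoid(W_r @ hx + b_r)                                  # reset gate
    z = sigmoid(W_z @ hx + b_z)                                  # update gate
    n = np.tanh(W_n @ np.concatenate([r * h_prev, x]) + b_n)     # candidate hidden
    return (1 - z) * n + z * h_prev                              # new hidden state


rng = np.random.default_rng(0)
input_size, H = 3, 4
x, h_prev = rng.standard_normal(input_size), rng.standard_normal(H)
W_r, W_z, W_n = (rng.standard_normal((H, H + input_size)) for _ in range(3))
b_r, b_z, b_n = (np.zeros(H) for _ in range(3))
h = gru_cell_concat(x, h_prev, W_r, W_z, W_n, b_r, b_z, b_n)
print(h.shape)  # (4,)
```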

In practice, the per-gate weights are stacked into two matrices: W_ih, of shape (3*H, input_size), acts on x, and W_hh, of shape (3*H, H), acts on h_prev, where H = hidden_size. The gate blocks are stacked in the order [r, z, n], so the first H rows of each weight and bias belong to the reset gate, the next H to the update gate, and the last H to the candidate.
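With this [r, z, n] row ordering, each gate's block can be recovered by splitting a stacked weight or bias into three equal row chunks. A minimal sketch; the helper name `split_gates` is hypothetical.

```python
import numpy as np


def split_gates(stacked: np.ndarray):
    """Hypothetical helper: split a (3*H, ...) weight or (3*H,) bias into [r, z, n] blocks."""
    assert stacked.shape[0] % 3 == 0, "leading dimension must be 3*H"
    return np.split(stacked, 3, axis=0)  # rows [0:H], [H:2H], [2H:3H]


H, input_size = 4, 3
W_ih = np.arange(3 * H * input_size, dtype=float).reshape(3 * H, input_size)
W_ir, W_iz, W_in = split_gates(W_ih)
print(W_ir.shape, W_iz.shape, W_in.shape)   # (4, 3) (4, 3) (4, 3)
print(np.array_equal(W_iz, W_ih[H:2 * H]))  # True
```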

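Follow the PyTorch GRUCell convention: for the candidate n, the reset gate multiplies the whole hidden-to-hidden pre-activation (W_hn h_prev + b_hn, bias included) rather than h_prev itself as in the conceptual formula, and it is never applied to the input-to-hidden term.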
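The placement of r matters because elementwise scaling does not commute with a matrix product: r * (W_hn @ h_prev + b_hn) is generally not equal to W_hn @ (r * h_prev) + b_hn. The sketch below computes both candidates from the same random, illustrative values (all variable names are assumptions for this demo); the two agree only in special cases such as r being all ones.

```python
import numpy as np

rng = np.random.default_rng(1)
H, input_size = 4, 3
x, h_prev = rng.standard_normal(input_size), rng.standard_normal(H)
W_in, W_hn = rng.standard_normal((H, input_size)), rng.standard_normal((H, H))
b_in, b_hn = rng.standard_normal(H), rng.standard_normal(H)
r = 1.0 / (1.0 + np.exp(-rng.standard_normal(H)))  # illustrative reset-gate values in (0, 1)

# PyTorch GRUCell convention: r gates the whole hidden-to-hidden term, bias included.
n_pytorch = np.tanh(W_in @ x + b_in + r * (W_hn @ h_prev + b_hn))

# Conceptual form: r scales h_prev before the matrix product.
n_concept = np.tanh(W_in @ x + b_in + W_hn @ (r * h_prev) + b_hn)

print(np.abs(n_pytorch - n_concept).max())  # typically nonzero
```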

Signature: def gru_cell(x, h_prev, W_ih, W_hh, b_ih, b_hh)

  • x: (input_size,)
  • h_prev: (hidden_size,)
  • W_ih: (3*hidden_size, input_size)
  • W_hh: (3*hidden_size, hidden_size)
  • b_ih, b_hh: (3*hidden_size,)
  • Returns: h_next (hidden_size,)
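Putting the pieces together, one possible NumPy implementation of this signature under the PyTorch GRUCell convention is sketched here. It is not the site's reference solution; it simply transcribes the equations in this problem statement, and the private helper `_sigmoid` is an assumption (NumPy has no built-in sigmoid).

```python
import numpy as np


def _sigmoid(a: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-a))


def gru_cell(x, h_prev, W_ih, W_hh, b_ih, b_hh):
    """One GRU step; weights and biases are stacked in [r, z, n] row order."""
    gi = W_ih @ x + b_ih        # input-to-hidden pre-activations, shape (3*H,)
    gh = W_hh @ h_prev + b_hh   # hidden-to-hidden pre-activations, shape (3*H,)
    i_r, i_z, i_n = np.split(gi, 3)
    h_r, h_z, h_n = np.split(gh, 3)

    r = _sigmoid(i_r + h_r)            # reset gate
    z = _sigmoid(i_z + h_z)            # update gate
    n = np.tanh(i_n + r * h_n)         # candidate: r gates only the hidden term (bias included)
    return (1.0 - z) * n + z * h_prev  # interpolate between candidate and previous state


rng = np.random.default_rng(0)
input_size, H = 3, 4
x, h_prev = rng.standard_normal(input_size), rng.standard_normal(H)
W_ih, W_hh = rng.standard_normal((3 * H, input_size)), rng.standard_normal((3 * H, H))
b_ih, b_hh = rng.standard_normal(3 * H), rng.standard_normal(3 * H)
h_next = gru_cell(x, h_prev, W_ih, W_hh, b_ih, b_hh)
print(h_next.shape)  # (4,)
```

If PyTorch is available, copying the same arrays into the `weight_ih`, `weight_hh`, `bias_ih`, and `bias_hh` parameters of a `torch.nn.GRUCell` should reproduce `h_next` up to floating-point tolerance.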

Math

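$$
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
$$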
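A quick sanity check on these equations: if every weight and bias is zero, both gates sit at σ(0) = 1/2 and the candidate is tanh(0) = 0, so the cell simply halves the previous state. More generally, because z_t lies in (0, 1) elementwise, each component of h_t lies between the corresponding components of n_t and h_{t-1}; when z_t is close to 1 the previous state is carried forward almost unchanged, which is how the GRU retains information over long spans.

$$
r_t = z_t = \sigma(0) = \tfrac{1}{2}, \qquad
n_t = \tanh\left(0 + \tfrac{1}{2} \odot 0\right) = 0, \qquad
h_t = \left(1 - \tfrac{1}{2}\right) \odot 0 + \tfrac{1}{2} \odot h_{t-1} = \tfrac{1}{2}\, h_{t-1}
$$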


Starter code (NumPy):

import numpy as np


def gru_cell(x, h_prev, W_ih, W_hh, b_ih, b_hh):
    # Return h_next with shape (hidden_size,).
    pass
