Implement the FFN sublayer used in every Transformer block, in PyTorch with primitive tensor ops only.
FFN(x) = LayerNorm(x + (W2 · GELU(W1 · x + b1) + b2))
Signature: def transformer_mlp(x, W1, b1, W2, b2, gamma, beta) -> torch.Tensor
- x: (..., d_model)
- W1: (d_ff, d_model), b1: (d_ff,)
- W2: (d_model, d_ff), b2: (d_model,)
- gamma, beta: (d_model,)
- eps = 1e-5

The rule: you may NOT call nn.LayerNorm, F.layer_norm, or F.gelu. Hand-roll both LayerNorm (mean/variance) and GELU (the exact erf form below).
Use the exact GELU: 0.5 * h * (1 + erf(h / sqrt(2))). Do not use the tanh approximation — expected outputs are computed with exact GELU.
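A quick sanity check of the erf form against PyTorch's built-ins (fine to use F.gelu for verification, just not inside your solution) shows why the approximation mode matters:

```python
import torch
import torch.nn.functional as F

h = torch.linspace(-3, 3, 7)

# Hand-rolled exact GELU: 0.5 * h * (1 + erf(h / sqrt(2)))
gelu_exact = 0.5 * h * (1 + torch.erf(h / 2.0**0.5))

# Matches the exact built-in to floating-point precision...
exact_builtin = F.gelu(h, approximate='none')

# ...while the tanh approximation drifts by ~1e-4 for |h| near 2-3,
# which is enough to fail tight output comparisons.
tanh_builtin = F.gelu(h, approximate='tanh')
```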
PyTorch idioms:
- torch.erf(h / 2.0**0.5) gives the exact-form GELU. F.gelu(h, approximate='none') matches it; F.gelu(h, approximate='tanh') does not.
- x.var(dim=-1, keepdim=True, unbiased=False) computes the population variance LayerNorm needs.
- @ broadcasts naturally over leading dims, so the same code handles (d,), (B, d), and (B, T, d).
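The variance idiom is the one that usually trips people up: torch defaults to the unbiased (sample) estimator, but LayerNorm uses the biased (population) one. A minimal check, using F.layer_norm purely as an oracle (the actual solution must not call it):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 8)
gamma, beta = torch.ones(8), torch.zeros(8)

# Hand-rolled LayerNorm over the last dim.
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)  # population variance!
ln = (x - mu) / torch.sqrt(var + 1e-5) * gamma + beta

# Oracle for comparison only.
ref = F.layer_norm(x, (8,), gamma, beta, eps=1e-5)
```

With `unbiased=True` (the default) the normalized values would be off by a factor of sqrt((n-1)/n) and the comparison would fail.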
import torch

def transformer_mlp(x, W1, b1, W2, b2, gamma, beta):
    pass
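One possible sketch of a full solution, under the stated conventions (W1 · x with W1 of shape (d_ff, d_model) becomes x @ W1.T for batched inputs); this is an illustration, not the site's hidden reference solution:

```python
import torch

def transformer_mlp(x, W1, b1, W2, b2, gamma, beta):
    """LayerNorm(x + (W2 · GELU(W1 · x + b1) + b2)), primitive ops only."""
    eps = 1e-5
    # Position-wise feed-forward: up-project to d_ff, exact GELU, down-project.
    h = x @ W1.T + b1                             # (..., d_ff)
    h = 0.5 * h * (1 + torch.erf(h / 2.0**0.5))   # exact (erf) GELU
    y = x + (h @ W2.T + b2)                       # residual add, (..., d_model)
    # Hand-rolled LayerNorm over the last dim, population variance.
    mu = y.mean(dim=-1, keepdim=True)
    var = y.var(dim=-1, keepdim=True, unbiased=False)
    return (y - mu) / torch.sqrt(var + eps) * gamma + beta
```

Because every op broadcasts over leading dims, the same body handles (d,), (B, d), and (B, T, d) inputs unchanged.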