
252. Transformer MLP Block (PyTorch)

Easy

Implement the FFN sublayer used in every Transformer block, in PyTorch with primitive tensor ops only.

FFN(x) = LayerNorm(x + W2 * GELU(W1 * x + b1) + b2)

Signature: def transformer_mlp(x, W1, b1, W2, b2, gamma, beta) -> torch.Tensor

  • x: (..., d_model)
  • W1: (d_ff, d_model), b1: (d_ff,)
  • W2: (d_model, d_ff), b2: (d_model,)
  • gamma, beta: (d_model,)
  • LayerNorm eps = 1e-5

The rule: you may NOT call nn.LayerNorm, F.layer_norm, or F.gelu. Hand-roll both LN (mean/var) and GELU (the exact erf form below).

Use the exact GELU: 0.5 * h * (1 + erf(h / sqrt(2))). Do not use the tanh approximation — expected outputs are computed with exact GELU.
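
A minimal sketch under these constraints (hand-rolled LayerNorm and exact-erf GELU; the matmuls transpose W1 and W2 since they are stored as (out, in)):

import torch

def transformer_mlp(x, W1, b1, W2, b2, gamma, beta):
    # Up-projection: (..., d_model) @ (d_model, d_ff) -> (..., d_ff)
    h = x @ W1.T + b1
    # Exact GELU: 0.5 * h * (1 + erf(h / sqrt(2))), no tanh approximation
    h = 0.5 * h * (1.0 + torch.erf(h / 2.0 ** 0.5))
    # Down-projection plus residual: back to (..., d_model)
    y = x + (h @ W2.T + b2)
    # Hand-rolled LayerNorm over the last dim (population variance, eps = 1e-5)
    mean = y.mean(dim=-1, keepdim=True)
    var = y.var(dim=-1, keepdim=True, unbiased=False)
    return gamma * (y - mean) / torch.sqrt(var + 1e-5) + beta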

PyTorch idioms:

  • 0.5 * h * (1 + torch.erf(h / 2.0**0.5)) is the exact-form GELU. F.gelu(h, approximate='none') matches it; F.gelu(h, approximate='tanh') does not (see the check after this list).
  • x.var(dim=-1, keepdim=True, unbiased=False) for population variance.
  • Matmul @ broadcasts naturally over leading dims, so the same code handles (d,), (B, d), and (B, T, d).
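
A quick self-check of both idioms (illustrative test code, not part of the submission):

import torch
import torch.nn.functional as F

h = torch.randn(2, 5, 16, dtype=torch.float64)  # (B, T, d): leading dims broadcast

gelu_exact = 0.5 * h * (1.0 + torch.erf(h / 2.0 ** 0.5))
print(torch.allclose(gelu_exact, F.gelu(h, approximate='none')))  # True
print(torch.allclose(gelu_exact, F.gelu(h, approximate='tanh')))  # generally False: the tanh form is off by up to ~1e-3

var = h.var(dim=-1, keepdim=True, unbiased=False)  # population variance, shape (2, 5, 1)
print(torch.allclose(var, (h - h.mean(-1, keepdim=True)).pow(2).mean(-1, keepdim=True)))  # True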

Math

FFN(x) = LayerNorm(x + W2 * GELU(W1 * x + b1) + b2)

Related problems

  • Transformer MLP Block (Easy, NumPy)

Starter code (PyTorch)

import torch

def transformer_mlp(x, W1, b1, W2, b2, gamma, beta):
    # TODO: hand-rolled LayerNorm and exact-erf GELU (see the constraints above)
    pass
