
261. Transformer Encoder Block (PyTorch)

Medium

Implement a full Pre-LN Transformer encoder block in PyTorch using primitive tensor ops only.

x = x + Attn(LN1(x)) @ Wo
x = x + W2 @ GELU(W1 @ LN2(x) + b1) + b2

(Standard, non-causal self-attention. Single head.)

Signature: def transformer_encoder_block(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2, gamma1, beta1, gamma2, beta2) -> torch.Tensor

  • x: (..., seq_len, d_model) — supports rank 2 or 3
  • Wq, Wk, Wv, Wo: (d_model, d_model)
  • W1: (d_ff, d_model), b1: (d_ff,)
  • W2: (d_model, d_ff), b2: (d_model,)
  • gamma1, beta1, gamma2, beta2: (d_model,)
  • LayerNorm eps = 1e-5. GELU: exact erf form.
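
A minimal sketch of those two primitives built from allowed ops (the helper names ln and gelu_exact are placeholders, not part of the required signature):

import torch

def ln(t, gamma, beta, eps=1e-5):
    # LayerNorm over the last dim; unbiased=False gives the population variance nn.LayerNorm uses.
    mu = t.mean(dim=-1, keepdim=True)
    var = t.var(dim=-1, keepdim=True, unbiased=False)
    return (t - mu) / torch.sqrt(var + eps) * gamma + beta

def gelu_exact(t):
    # Exact erf-form GELU, t * Phi(t), not the tanh approximation.
    return 0.5 * t * (1.0 + torch.erf(t / 2 ** 0.5))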

The rule: you may NOT call nn.MultiheadAttention, F.scaled_dot_product_attention, F.multi_head_attention_forward, nn.LayerNorm, F.layer_norm, or F.gelu. Build everything from @, an exp-based softmax, .mean, .var, and .erf.

F.softmax is not on the banned list: you may call it, or implement softmax yourself via the shifted-exp / sum trick; either is fine, since the test only checks output values.
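
For instance, a shifted-exp softmax over the last dimension (stable_softmax is just an illustrative helper name; the idioms below spell out the same three ops inline):

def stable_softmax(scores):
    # Subtracting the per-row max leaves the result unchanged but keeps exp() from overflowing.
    scores = scores - scores.amax(dim=-1, keepdim=True)
    w = scores.exp()
    return w / w.sum(dim=-1, keepdim=True)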

PyTorch idioms vs NumPy:

  • Use K.transpose(-1, -2) as the right-hand operand of @ for the QK^T contraction (the PyTorch equivalent of np.swapaxes(K, -1, -2)).
  • Softmax: scores = scores - scores.amax(dim=-1, keepdim=True); w = scores.exp(); w = w / w.sum(dim=-1, keepdim=True). Note .amax returns a tensor directly; .max returns a NamedTuple.
  • All matmuls broadcast over leading dims, so the same code handles (seq, d) and (B, seq, d) without changes.
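
Putting those idioms together, one possible shape of a solution. This is a sketch, not the reference: it reads the attention projections as row-vector matmuls (h @ Wq) and the FFN's W1 @ LN2(x) as h @ W1.T, so double-check that convention against the expected outputs.

import torch

def transformer_encoder_block(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2,
                              gamma1, beta1, gamma2, beta2):
    eps = 1e-5

    def ln(t, gamma, beta):
        mu = t.mean(dim=-1, keepdim=True)
        var = t.var(dim=-1, keepdim=True, unbiased=False)
        return (t - mu) / torch.sqrt(var + eps) * gamma + beta

    # Attention sub-block (Pre-LN): x = x + Attn(LN1(x)) @ Wo
    h = ln(x, gamma1, beta1)
    Q, K, V = h @ Wq, h @ Wk, h @ Wv                      # projection convention assumed: h @ W
    scores = (Q @ K.transpose(-1, -2)) / (Q.shape[-1] ** 0.5)
    scores = scores - scores.amax(dim=-1, keepdim=True)   # shifted-exp softmax
    w = scores.exp()
    w = w / w.sum(dim=-1, keepdim=True)
    x = x + (w @ V) @ Wo

    # FFN sub-block (Pre-LN): x = x + W2 @ GELU(W1 @ LN2(x) + b1) + b2
    h = ln(x, gamma2, beta2)
    h = h @ W1.transpose(-1, -2) + b1                     # (..., seq, d_ff)
    h = 0.5 * h * (1.0 + torch.erf(h / 2 ** 0.5))         # exact erf GELU
    x = x + h @ W2.transpose(-1, -2) + b2                 # back to (..., seq, d_model)
    return x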

Math

x = x + Attn(LN1(x));   x = x + FFN(LN2(x))

Related problems

  • Transformer Encoder Block (NumPy), Medium

PyTorch

import torch


def transformer_encoder_block(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2,
                              gamma1, beta1, gamma2, beta2):
    pass
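
Once the body above is filled in, a quick self-check along these lines should pass (the sizes here are arbitrary; the second call just confirms the batched path matches the unbatched one):

torch.manual_seed(0)
d_model, d_ff, seq = 8, 32, 5
x = torch.randn(seq, d_model)
Wq, Wk, Wv, Wo = (torch.randn(d_model, d_model) for _ in range(4))
W1, b1 = torch.randn(d_ff, d_model), torch.randn(d_ff)
W2, b2 = torch.randn(d_model, d_ff), torch.randn(d_model)
gamma1 = gamma2 = torch.ones(d_model)
beta1 = beta2 = torch.zeros(d_model)

out = transformer_encoder_block(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2,
                                gamma1, beta1, gamma2, beta2)
assert out.shape == (seq, d_model)

xb = x.expand(3, seq, d_model)   # same rows repeated along a leading batch dim
out_b = transformer_encoder_block(xb, Wq, Wk, Wv, Wo, W1, b1, W2, b2,
                                  gamma1, beta1, gamma2, beta2)
assert torch.allclose(out_b[0], out, atol=1e-5)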
