262. Transformer Decoder Block (PyTorch)

Hard

Implement a GPT-style Pre-LN decoder block in PyTorch using primitive tensor ops only: masked (causal) self-attention plus a feed-forward network (FFN), each wrapped in a Pre-LN residual.

x = x + CausalAttn(LN1(x)) @ Wo
x = x + W2 @ GELU(W1 @ LN2(x) + b1) + b2
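Note the pseudocode above uses the column-vector convention (W1 @ x); with tensors of shape (..., seq, d), PyTorch code conventionally keeps tokens as rows, so the same linear map is written x @ W1.T. A minimal sketch of that shape bookkeeping (the concrete sizes here are illustrative only, not from the problem):

import torch

seq, d, d_ff = 4, 8, 16          # illustrative sizes only
x = torch.randn(seq, d)          # tokens as rows
W1 = torch.randn(d_ff, d)        # (d_ff, d), matching the signature below
b1 = torch.randn(d_ff)

# Column-vector pseudocode "W1 @ x + b1" becomes "x @ W1.T + b1" row-wise.
h = x @ W1.T + b1                # (seq, d_ff)
print(h.shape)                   # torch.Size([4, 16])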

Causal mask: position i may attend to positions 0..i only. Set the strict upper triangle (j > i) of the score matrix to -inf before the stabilizing softmax shift (see the PyTorch idioms and the sketch below).

Signature: def transformer_decoder_block(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2, gamma1, beta1, gamma2, beta2) -> torch.Tensor

  • x: (..., seq, d)
  • Wq, Wk, Wv, Wo: (d, d)
  • W1: (d_ff, d), b1: (d_ff,)
  • W2: (d, d_ff), b2: (d,)
  • gamma1, beta1, gamma2, beta2: (d,)
  • LayerNorm eps = 1e-5. GELU: exact erf form.

The rule: you may NOT call nn.MultiheadAttention, F.scaled_dot_product_attention (even with is_causal=True), nn.LayerNorm, F.layer_norm, or F.gelu. No cross-attention helpers either — this block is decoder-only (no encoder context).
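Since nn.LayerNorm, F.layer_norm, and F.gelu are off-limits, both pieces have to be written out by hand. A minimal sketch of the two helpers (the names layer_norm and gelu are my own, not part of the required signature):

import torch

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize over the last dim; nn.LayerNorm uses the biased variance,
    # hence unbiased=False.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps) * gamma + beta

def gelu(x):
    # Exact erf form: 0.5 * x * (1 + erf(x / sqrt(2))).
    return 0.5 * x * (1.0 + torch.erf(x / 2.0 ** 0.5))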

PyTorch idioms vs NumPy:

  • Build the mask with torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1) — note diagonal=1 (the diagonal itself stays attendable).
  • Apply with scores.masked_fill(mask, float('-inf')) — this is the idiomatic PyTorch way to force masked positions to zero post-softmax. Using a large negative finite value (-1e9) also works but -inf produces exact zeros.
  • Subtract scores.amax(dim=-1, keepdim=True) for numerical stability after masking. The masked -inf entries contribute exp(-inf) = 0 cleanly.
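Putting the three idioms together, a single-head causal attention sketch might look like the following. The 1/sqrt(d) score scaling is the standard GPT choice, and the square projections are applied here as h @ W.T; the statement does not pin down either convention, so treat both as assumptions:

import torch

def causal_self_attention(h, Wq, Wk, Wv):
    # h: (..., seq, d); Wq/Wk/Wv: (d, d). Single head for simplicity.
    q, k, v = h @ Wq.T, h @ Wk.T, h @ Wv.T               # assumed orientation
    scores = q @ k.transpose(-2, -1) / h.shape[-1] ** 0.5  # assumed 1/sqrt(d)
    n = h.shape[-2]
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=h.device),
                      diagonal=1)                        # True strictly above diag
    scores = scores.masked_fill(mask, float('-inf'))
    scores = scores - scores.amax(dim=-1, keepdim=True)  # stability shift
    w = scores.exp()
    w = w / w.sum(dim=-1, keepdim=True)                  # softmax; masked -> 0
    return w @ v                                         # (..., seq, d)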

Math

x′ = x + CausalAttn(LN1(x)),   out = x′ + FFN(LN2(x′))
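For reference, here is one way the whole block could be assembled, with the helpers inlined so the snippet stands alone. The weight orientations (x @ W.T for the rectangular FFN matrices, attn @ Wo on the right per the pseudocode) and the 1/sqrt(d) scaling are assumptions to verify against the hidden tests, not confirmed by the statement:

import torch

def transformer_decoder_block(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2,
                              gamma1, beta1, gamma2, beta2):
    def ln(t, gamma, beta, eps=1e-5):
        mu = t.mean(dim=-1, keepdim=True)
        var = t.var(dim=-1, keepdim=True, unbiased=False)
        return (t - mu) / torch.sqrt(var + eps) * gamma + beta

    # Sublayer 1: Pre-LN causal self-attention with residual.
    h = ln(x, gamma1, beta1)
    q, k, v = h @ Wq.T, h @ Wk.T, h @ Wv.T        # assumed orientation
    scores = q @ k.transpose(-2, -1) / h.shape[-1] ** 0.5
    n = h.shape[-2]
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=x.device),
                      diagonal=1)
    scores = scores.masked_fill(mask, float('-inf'))
    scores = scores - scores.amax(dim=-1, keepdim=True)
    w = scores.exp()
    w = w / w.sum(dim=-1, keepdim=True)
    x = x + (w @ v) @ Wo                          # Wo on the right, per the pseudocode

    # Sublayer 2: Pre-LN feed-forward with residual.
    h = ln(x, gamma2, beta2)
    h = h @ W1.T + b1                             # (..., seq, d_ff)
    h = 0.5 * h * (1.0 + torch.erf(h / 2.0 ** 0.5))  # exact-erf GELU
    return x + h @ W2.T + b2                      # back to (..., seq, d)

As a quick sanity check: with x of shape (2, 5, 8) and matching weights, the output should come back as (2, 5, 8), and row 0 of the attention weights should put all of its mass on position 0.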

Related problems

  • Transformer Decoder Block (Hard, NumPy)

PyTorch

import torch

def transformer_decoder_block(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2,
                              gamma1, beta1, gamma2, beta2):
    pass
