TorchedUp / Problems (Premium)

Full Transformer (Encoder-Decoder) (Hard)

Implement a mini encoder-decoder Transformer for sequence-to-sequence tasks (like translation).

Signature: def mini_transformer(src, tgt, weights)

  • src: (S, d_model) — source sequence
  • tgt: (T, d_model) — target sequence (teacher-forced during training)
  • weights: dict with all weight matrices (see below)
  • Returns: (T, d_model) — decoder output

Architecture (1 encoder layer + 1 decoder layer)

Encoder layer:

  1. Multi-head self-attention on src (no mask)
  2. Add & LayerNorm
  3. FFN: Linear → ReLU → Linear
  4. Add & LayerNorm → memory
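The FFN sublayer in step 3 can be sketched in NumPy like this (the function name is illustrative, not part of the required API; note the W1/W2 shapes match the weights table below, so the matrices are applied transposed):

```python
import numpy as np

def ffn(x, W1, b1, W2, b2):
    # Position-wise feed-forward: Linear -> ReLU -> Linear.
    # x: (S, d), W1: (d_ff, d), b1: (d_ff,), W2: (d, d_ff), b2: (d,)
    return np.maximum(x @ W1.T + b1, 0.0) @ W2.T + b2
```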

Decoder layer:

  1. Masked multi-head self-attention on tgt (causal mask: strictly upper-triangular positions set to -1e9)
  2. Add & LayerNorm
  3. Cross-attention: Q from decoder, K/V from memory
  4. Add & LayerNorm
  5. FFN: Linear → ReLU → Linear
  6. Add & LayerNorm → output
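The causal mask in step 1 can be built with `np.triu` (a sketch; the exact construction is up to you):

```python
import numpy as np

T = 4  # target length (example value)
# Strictly upper-triangular positions (j > i) get -1e9 so softmax
# assigns them ~zero weight; allowed positions stay at 0.
mask = np.triu(np.full((T, T), -1e9), k=1)
```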

Multi-Head Attention

For n_heads heads with head dim d_h = d_model // n_heads:

Q = x @ Wq.T,  K = x @ Wk.T,  V = x @ Wv.T
# Split into heads: reshape (S, d) → (n_heads, S, d_h)
# Per head: scores = Q_h @ K_h.T / sqrt(d_h)  [+ optional mask]
#           attn = softmax(scores) @ V_h
# Concatenate heads → (S, d) → @ Wo.T
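The steps above can be sketched as a self-contained NumPy function (the names `softmax` and `multi_head_attention` and the optional `mask` argument are illustrative, not part of the problem's required API):

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(q_in, kv_in, Wq, Wk, Wv, Wo, n_heads, mask=None):
    # q_in: (T, d) queries; kv_in: (S, d) keys/values.
    # For self-attention pass the same array as q_in and kv_in;
    # for cross-attention pass decoder state and encoder memory.
    T, d = q_in.shape
    S = kv_in.shape[0]
    d_h = d // n_heads
    Q = q_in @ Wq.T    # (T, d)
    K = kv_in @ Wk.T   # (S, d)
    V = kv_in @ Wv.T   # (S, d)
    # Split into heads: (T, d) -> (n_heads, T, d_h)
    Qh = Q.reshape(T, n_heads, d_h).transpose(1, 0, 2)
    Kh = K.reshape(S, n_heads, d_h).transpose(1, 0, 2)
    Vh = V.reshape(S, n_heads, d_h).transpose(1, 0, 2)
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_h)  # (n_heads, T, S)
    if mask is not None:
        scores = scores + mask  # (T, S) mask broadcasts over heads
    out = softmax(scores) @ Vh  # (n_heads, T, d_h)
    # Concatenate heads back to (T, d), then apply the output projection.
    return out.transpose(1, 0, 2).reshape(T, d) @ Wo.T
```

With identity projection matrices and a causal mask, the first query position can only attend to itself, which gives a quick sanity check.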

LayerNorm

LayerNorm(x, gamma, beta) = gamma * (x - mean) / sqrt(var + 1e-5) + beta
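In NumPy this is a one-liner per statistic, normalizing each token over the feature dimension (a sketch; the function name is illustrative):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each row (token) over the feature dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta
```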

Weights dict keys

| Key | Shape | Description |
|-----|-------|-------------|
| n_heads | int | number of attention heads |
| enc_Wq/Wk/Wv/Wo | (d, d) | encoder self-attention |
| enc_ln1_g/b, enc_ln2_g/b | (d,) | encoder layer norms |
| enc_W1, enc_b1 | (d_ff, d), (d_ff,) | encoder FFN layer 1 |
| enc_W2, enc_b2 | (d, d_ff), (d,) | encoder FFN layer 2 |
| dec_self_Wq/Wk/Wv/Wo | (d, d) | decoder masked self-attention |
| dec_cross_Wq/Wk/Wv/Wo | (d, d) | decoder cross-attention |
| dec_ln1_g/b, dec_ln2_g/b, dec_ln3_g/b | (d,) | decoder layer norms |
| dec_W1, dec_b1 | (d_ff, d), (d_ff,) | decoder FFN layer 1 |
| dec_W2, dec_b2 | (d, d_ff), (d,) | decoder FFN layer 2 |

Language: Python (numpy)

Test Results

  • seed=42, S=3 src tokens, T=2 tgt tokens, d=4, n_heads=2
  • seed=13, S=2 src, T=3 tgt, d=4, n_heads=2 — longer target