Debug: Broken Transformer Encoder (5 Bugs) (Hard)

The implementation below has 5 bugs. Find and fix all of them.

Signature: def transformer_encoder_buggy(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2, ln1_g, ln1_b, ln2_g, ln2_b, n_heads=2)

import numpy as np

def transformer_encoder_buggy(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2,
                               ln1_g, ln1_b, ln2_g, ln2_b, n_heads=2):
    S, d = x.shape
    d_h = d // n_heads

    # Multi-head attention
    Q = x @ Wq.T
    K = x @ Wk.T
    V = x @ Wv.T
    Q = Q.reshape(S, n_heads, d_h).transpose(1, 0, 2)
    K = K.reshape(S, n_heads, d_h).transpose(1, 0, 2)
    V = V.reshape(S, n_heads, d_h).transpose(1, 0, 2)
    heads = []
    for i in range(n_heads):
        scores = Q[i] @ K[i].T / d_h           # BUG 1: missing sqrt
        scores = np.exp(scores)                  # BUG 2: softmax without subtract-max (numerically unstable; can overflow)
        A = scores / scores.sum(-1, keepdims=True)
        heads.append(A @ V[i])
    concat = np.concatenate(heads, axis=-1)
    attn_out = concat @ Wo.T

    # Residual + LayerNorm 1
    x2 = attn_out                                # BUG 3: missing residual (should be x + attn_out)
    mu = x2.mean(-1, keepdims=True)
    var = x2.var(-1, keepdims=True)
    x2 = ln1_g * (x2 - mu) / np.sqrt(var + 1e-5) + ln1_b

    # FFN: Linear -> ReLU -> Linear
    h = np.maximum(0, x2 @ W1.T + b1)
    ff = h @ W2 + b2                             # BUG 4: W2 not transposed (should be W2.T)

    # Residual + LayerNorm 2
    x3 = x2 + ff
    mu = x3.mean(-1, keepdims=True)
    std = np.sqrt(x3.var(-1, keepdims=True))
    x3 = ln2_g * (x3 - mu) / std + ln2_b        # BUG 5: missing eps in denominator (divide-by-zero risk)

    return x3

Bugs summary:

  1. / d_h → should be / np.sqrt(d_h) (scale by sqrt of head dim)
  2. np.exp(scores) without subtract-max → numerically unstable; correct softmax subtracts scores.max(-1, keepdims=True) first
  3. x2 = attn_out → should be x2 = x + attn_out (missing residual connection)
  4. h @ W2 → should be h @ W2.T (projection needs transpose)
  5. / std → should be / np.sqrt(var + 1e-5) (missing eps, can divide by zero)

Implement the corrected version with all 5 bugs fixed.
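
For reference, here is a minimal corrected sketch with all five fixes applied. It keeps the weight conventions used by the buggy code (projections applied as x @ W.T, biases added row-wise, LayerNorm eps of 1e-5). The name transformer_encoder_fixed is only a placeholder; the grader presumably expects the original signature name.

import numpy as np

def transformer_encoder_fixed(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2,
                              ln1_g, ln1_b, ln2_g, ln2_b, n_heads=2):
    S, d = x.shape
    d_h = d // n_heads

    # Multi-head attention: project, then split into (n_heads, S, d_h)
    Q = (x @ Wq.T).reshape(S, n_heads, d_h).transpose(1, 0, 2)
    K = (x @ Wk.T).reshape(S, n_heads, d_h).transpose(1, 0, 2)
    V = (x @ Wv.T).reshape(S, n_heads, d_h).transpose(1, 0, 2)
    heads = []
    for i in range(n_heads):
        scores = Q[i] @ K[i].T / np.sqrt(d_h)            # fix 1: scale by sqrt(d_h)
        scores = scores - scores.max(-1, keepdims=True)  # fix 2: subtract max for stability
        e = np.exp(scores)
        A = e / e.sum(-1, keepdims=True)
        heads.append(A @ V[i])
    attn_out = np.concatenate(heads, axis=-1) @ Wo.T

    # Residual + LayerNorm 1
    x2 = x + attn_out                                    # fix 3: residual connection
    mu = x2.mean(-1, keepdims=True)
    var = x2.var(-1, keepdims=True)
    x2 = ln1_g * (x2 - mu) / np.sqrt(var + 1e-5) + ln1_b

    # FFN: Linear -> ReLU -> Linear
    h = np.maximum(0, x2 @ W1.T + b1)
    ff = h @ W2.T + b2                                   # fix 4: transpose W2

    # Residual + LayerNorm 2
    x3 = x2 + ff
    mu = x3.mean(-1, keepdims=True)
    var = x3.var(-1, keepdims=True)
    x3 = ln2_g * (x3 - mu) / np.sqrt(var + 1e-5) + ln2_b # fix 5: eps inside the sqrt

    return x3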

Language: Python (NumPy)

Test Results

Test case: seed=42, S=3, d=4, n_heads=2
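
The exact harness behind this case is not shown. The sketch below is only an assumption of how such a case might be set up: standard-normal inputs and weights drawn from a seeded generator, with an arbitrary FFN hidden width d_ff=8, calling the corrected sketch above.

import numpy as np

rng = np.random.default_rng(42)          # seed=42; assuming it drives both inputs and weights
S, d, n_heads = 3, 4, 2
d_ff = 8                                 # FFN hidden width: not given, chosen arbitrarily

x = rng.standard_normal((S, d))
Wq, Wk, Wv, Wo = (rng.standard_normal((d, d)) for _ in range(4))
W1, b1 = rng.standard_normal((d_ff, d)), rng.standard_normal(d_ff)
W2, b2 = rng.standard_normal((d, d_ff)), rng.standard_normal(d)
ln1_g, ln1_b = np.ones(d), np.zeros(d)
ln2_g, ln2_b = np.ones(d), np.zeros(d)

out = transformer_encoder_fixed(x, Wq, Wk, Wv, Wo, W1, b1, W2, b2,
                                ln1_g, ln1_b, ln2_g, ln2_b, n_heads=n_heads)
print(out.shape)                         # (3, 4): the encoder preserves the input shape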