
Multi-Head Attention (Hard)

Implement multi-head attention (without learned projection matrices).

Signature: def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, num_heads: int) -> np.ndarray

  • Split Q, K, V along the last dimension into num_heads heads
  • Compute scaled dot-product attention for each head independently
  • Concatenate results along the last dimension

Assume d_model % num_heads == 0.
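A minimal sketch of one possible solution follows. It assumes the heads are taken as contiguous slices of the last dimension and that scores are scaled by sqrt(d_head); both are the usual conventions but are not spelled out in the statement, and the softmax helper is our own addition.

import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row-wise max for numerical stability before exponentiating.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, num_heads: int) -> np.ndarray:
    d_model = Q.shape[-1]
    d_head = d_model // num_heads  # assumes d_model % num_heads == 0
    outputs = []
    for h in range(num_heads):
        # Slice the h-th head out of the last dimension (assumed contiguous split).
        q = Q[..., h * d_head:(h + 1) * d_head]
        k = K[..., h * d_head:(h + 1) * d_head]
        v = V[..., h * d_head:(h + 1) * d_head]
        # Scaled dot-product attention for this head.
        scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_head)
        weights = softmax(scores, axis=-1)
        outputs.append(weights @ v)
    # Concatenate the per-head results along the last dimension.
    return np.concatenate(outputs, axis=-1)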

Math
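Presumably each head applies the standard scaled dot-product attention; the scaling factor sqrt(d_head) below is the usual convention, not something the statement fixes:

head_h = softmax(Q_h K_h^T / sqrt(d_head)) V_h,   where d_head = d_model / num_heads

output = concat(head_1, ..., head_num_heads) along the last dimension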

Language: Python (numpy)

Test Results

  • 1 head equals single-head attention
  • 2 heads, d = 4
  • 2 heads, uniform input (Premium)
  • non-negative output when V is non-negative
  • per-head rows sum to 1: with V = I, each output row sums to num_heads (heads = 2 → 2)
  • V = I per head: each head's rows sum to 1 (numeric check: total sum = num_heads)
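As a rough illustration of the first listed test, using the sketch above (the shapes and random inputs here are our own choices): with num_heads=1 the result should match plain single-head scaled dot-product attention.

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 4))
K = rng.standard_normal((3, 4))
V = rng.standard_normal((3, 4))

# Single-head attention computed directly, scaled by sqrt(d_model) = sqrt(4).
single = softmax(Q @ K.T / np.sqrt(4), axis=-1) @ V
multi = multi_head_attention(Q, K, V, num_heads=1)
assert np.allclose(single, multi)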