TorchedUp

20. Multi-Head Attention

Hard

Implement multi-head attention (without learned projection matrices).

Signature: def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, num_heads: int) -> np.ndarray

  • Split Q, K, V along the last dimension into num_heads heads
  • Compute scaled dot-product attention for each head independently
  • Concatenate results along the last dimension

Assume d_model % num_heads == 0.
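The three steps above can be sketched in NumPy as follows. This is a minimal sketch, not a reference solution. It assumes that V has the same last dimension as Q and K, so V also splits evenly into heads. It also assumes inputs may carry leading batch axes, which the batched 3D test case suggests. The helpers `softmax` and `scaled_dot_product_attention` are illustrative names and are not part of the required signature.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max before exponentiating for numerical stability
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def scaled_dot_product_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    # Q: (..., N, d_k), K: (..., M, d_k), V: (..., M, d_v)
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)  # (..., N, M)
    return softmax(scores, axis=-1) @ V                 # (..., N, d_v)

def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, num_heads: int) -> np.ndarray:
    d_model = Q.shape[-1]
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    # Split the last axis into num_heads equal, contiguous chunks (one per head)
    q_heads = np.split(Q, num_heads, axis=-1)
    k_heads = np.split(K, num_heads, axis=-1)
    v_heads = np.split(V, num_heads, axis=-1)
    # Attend within each head independently; each head scales by sqrt(d_model // num_heads)
    outputs = [scaled_dot_product_attention(q, k, v)
               for q, k, v in zip(q_heads, k_heads, v_heads)]
    # Concatenate the per-head outputs back along the last axis
    return np.concatenate(outputs, axis=-1)

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4))  # N=3 tokens, d_model=4
print(multi_head_attention(X, X, X, num_heads=2).shape)  # (3, 4)
```

Because `np.split` keeps each head's columns contiguous, concatenating the head outputs in order restores the original feature layout. Note that each head scales its scores by the square root of its own width, d_model / num_heads, not by the square root of d_model.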

Math

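$$\mathrm{MHA}(Q,K,V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O$$

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i, \qquad d_k = d_{\text{model}} / h$$

Here h = num_heads, and Q_i, K_i, V_i are the i-th slices of Q, K, V along the last dimension. This problem uses no learned projections, so W^O is omitted (equivalently, W^O = I).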
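The formula can also be computed without a Python loop over heads. You reshape the feature axis into (h, d_k) and run all heads in one batched matrix product. The sketch below also accepts an optional output matrix to show where W^O would apply in the full formula. `mha_vectorized` and its `W_O` argument are hypothetical names chosen for illustration. Passing no `W_O` gives the projection-free version this problem asks for.

```python
from typing import Optional

import numpy as np

def mha_vectorized(Q: np.ndarray, K: np.ndarray, V: np.ndarray, num_heads: int,
                   W_O: Optional[np.ndarray] = None) -> np.ndarray:
    *batch, n, d_model = Q.shape
    m = K.shape[-2]
    d_k = d_model // num_heads
    d_v = V.shape[-1] // num_heads

    def split_heads(x: np.ndarray, length: int, d: int) -> np.ndarray:
        # (..., L, h*d) -> (..., L, h, d) -> (..., h, L, d)
        return np.swapaxes(x.reshape(*x.shape[:-2], length, num_heads, d), -2, -3)

    q, k, v = split_heads(Q, n, d_k), split_heads(K, m, d_k), split_heads(V, m, d_v)
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)  # (..., h, N, M)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)      # softmax over keys
    heads = weights @ v                                 # (..., h, N, d_v)
    # Concat(head_1, ..., head_h): move the head axis next to features, then flatten
    concat = np.swapaxes(heads, -2, -3).reshape(*batch, n, num_heads * d_v)
    # Optional output projection W^O (None means no projection, as in this problem)
    return concat if W_O is None else concat @ W_O
```

The reshape splits the feature axis in row-major order, so head i owns columns i*d_k through (i+1)*d_k - 1. That is the same contiguous split as slicing, and passing `W_O=np.eye(d_model)` leaves the result unchanged.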

Related problems

  • Multi-Head Attention (PyTorch) (hard, PyTorch)

Test Results

  • 1 head equals single-head attention
  • 2 heads, d=4
  • 2 heads, uniform input
  • Non-negative output when V is non-negative
  • Per-head row sums to 1: V=I → output row sums to num_heads (heads=2 → 2)
  • V=I per head: each head row sums to 1 (numeric: total sum = num_heads)
  • Batched 3D input (B=2, N=3, d=4, heads=2)
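Several of the properties listed above can be checked locally with a short script. This is a sketch under stated assumptions. The exact inputs used by the tests are not shown here, so the sizes, the random seed, and the "V = I per head" construction are all illustrative choices. That construction tiles an identity so that every head's slice of V is square. `mha` is a hypothetical compact copy of a projection-free implementation, repeated so the block runs on its own.

```python
import numpy as np

def mha(Q: np.ndarray, K: np.ndarray, V: np.ndarray, num_heads: int) -> np.ndarray:
    # Compact projection-free multi-head attention (hypothetical helper for these checks)
    outs = []
    for q, k, v in zip(*(np.split(x, num_heads, axis=-1) for x in (Q, K, V))):
        s = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
        w = np.exp(s - s.max(axis=-1, keepdims=True))
        outs.append((w / w.sum(axis=-1, keepdims=True)) @ v)
    return np.concatenate(outs, axis=-1)

rng = np.random.default_rng(0)
h, d = 2, 4
n = d // h  # "V = I per head" needs each head's V slice to be square: N == d_k
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# 1 head reduces to plain scaled dot-product attention over all d features
s = Q @ K.T / np.sqrt(d)
w = np.exp(s - s.max(axis=-1, keepdims=True))
assert np.allclose(mha(Q, K, V, 1), (w / w.sum(axis=-1, keepdims=True)) @ V)

# Non-negative V gives non-negative output: each row is a convex combination of V rows
assert (mha(Q, K, np.abs(V), h) >= 0).all()

# V = I per head: each head outputs its attention weights, so rows sum to num_heads
V_eye = np.tile(np.eye(n), (1, h))  # shape (n, d); every d_k-wide slice is I
assert np.allclose(mha(Q, K, V_eye, h).sum(axis=-1), h)

# Batched 3D input (B=2, N=3, d=4) keeps its shape
X = rng.standard_normal((2, 3, d))
assert mha(X, X, X, h).shape == (2, 3, d)
```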