
Multi-Query Attention (MQA)

Implement Multi-Query Attention (MQA) — used in PaLM, Falcon, and Gemma to dramatically reduce KV cache memory. MQA has multiple Query heads but only a single Key/Value head shared across all query heads.

Signature: def multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads)

  • x: (S, d_model)
  • Q_proj: (d_model, d_model) — projects input to all Q heads
  • K_proj: (d_h, d_model) — single K head projection (d_h = d_model // n_heads)
  • V_proj: (d_h, d_model) — single V head projection
  • Wo: (d_model, d_model) — output projection
  • n_heads: number of query heads
  • Returns: (S, d_model)
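
The expected shapes can be sanity-checked with random weights. A minimal sketch, assuming the reference implementation from the Algorithm section below (the sizes here are arbitrary):

import numpy as np

S, d_model, n_heads = 3, 8, 2
d_h = d_model // n_heads

rng = np.random.default_rng(0)
x      = rng.standard_normal((S, d_model))
Q_proj = rng.standard_normal((d_model, d_model))  # all Q heads at once
K_proj = rng.standard_normal((d_h, d_model))      # single shared K head
V_proj = rng.standard_normal((d_h, d_model))      # single shared V head
Wo     = rng.standard_normal((d_model, d_model))

out = multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads)
assert out.shape == (S, d_model)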

Algorithm

import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads):
    S, d_model = x.shape
    d_h = d_model // n_heads

    Q = x @ Q_proj.T             # (S, d_model)
    K = x @ K_proj.T             # (S, d_h)  ← single shared head
    V = x @ V_proj.T             # (S, d_h)  ← single shared head

    # Split Q into n_heads: (n_heads, S, d_h)
    Q = Q.reshape(S, n_heads, d_h).transpose(1, 0, 2)

    # Each Q head attends to the SAME K and V:
    heads = []
    for i in range(n_heads):
        scores = Q[i] @ K.T / np.sqrt(d_h)   # (S, S)
        A = softmax(scores)
        heads.append(A @ V)                  # (S, d_h) using shared K/V

    concat = np.concatenate(heads, axis=-1)  # (S, d_model)
    return concat @ Wo.T                     # (S, d_model)
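
Because every head shares the same K and V, the per-head loop can also be written as one batched matmul via NumPy broadcasting. A sketch of the same computation (the name multi_query_attention_vec is ours; it reuses np and softmax from above):

def multi_query_attention_vec(x, Q_proj, K_proj, V_proj, Wo, n_heads):
    S, d_model = x.shape
    d_h = d_model // n_heads
    Q = (x @ Q_proj.T).reshape(S, n_heads, d_h).transpose(1, 0, 2)  # (n_heads, S, d_h)
    K = x @ K_proj.T                                                # (S, d_h)
    V = x @ V_proj.T                                                # (S, d_h)
    scores = Q @ K.T / np.sqrt(d_h)   # (n_heads, S, S): K.T broadcasts across heads
    A = softmax(scores)               # softmax over the last axis
    heads = A @ V                     # (n_heads, S, d_h): V broadcasts across heads
    concat = heads.transpose(1, 0, 2).reshape(S, d_model)
    return concat @ Wo.T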

MHA vs GQA vs MQA

| Method | Q heads | K heads       | V heads | KV cache size |
|--------|---------|---------------|---------|---------------|
| MHA    | H       | H             | H       | H × d_h       |
| GQA    | H       | G (1 < G < H) | G       | G × d_h       |
| MQA    | H       | 1             | 1       | d_h           |

MQA reduces KV cache by H× compared to MHA, enabling longer context at inference time.
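
To make the H× saving concrete, a back-of-the-envelope sketch with hypothetical numbers (H=32 heads, d_h=128, S=8192 tokens, 32 layers, fp16; none of these come from the problem itself):

H, d_h, S, layers, bytes_per = 32, 128, 8192, 32, 2  # hypothetical config, fp16

mha_kv = 2 * H * d_h * S * layers * bytes_per  # K and V for every head
mqa_kv = 2 * 1 * d_h * S * layers * bytes_per  # K and V for one shared head

print(f"MHA KV cache: {mha_kv / 2**30:.1f} GiB")    # 4.0 GiB
print(f"MQA KV cache: {mqa_kv / 2**30:.3f} GiB")    # 0.125 GiB, i.e. H× smaller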

Language: Python (NumPy)
Test Results

  • seed=42, S=3, d=4, n_heads=2
  • uniform scores (zero Q/K): all tokens equally attended
  • non-negative output when x, projections, and Wo are non-negative
  • zero Q/K projections → uniform attention; all output rows equal
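
The zero-Q/K tests can be reasoned through by hand: with Q_proj and K_proj all zeros, every score is 0, softmax is uniform, each head outputs the mean of V's rows, and so every output row is identical. A quick check, assuming the implementation above:

rng = np.random.default_rng(42)
S, d_model, n_heads = 3, 4, 2
d_h = d_model // n_heads

x      = rng.standard_normal((S, d_model))
V_proj = rng.standard_normal((d_h, d_model))
Wo     = rng.standard_normal((d_model, d_model))
zeros_Q = np.zeros((d_model, d_model))
zeros_K = np.zeros((d_h, d_model))

out = multi_query_attention(x, zeros_Q, zeros_K, V_proj, Wo, n_heads)

# Uniform attention averages V's rows; every head produces the same (d_h,) vector.
expected_row = np.tile((x @ V_proj.T).mean(axis=0), n_heads) @ Wo.T
assert np.allclose(out, expected_row)   # broadcasts against all S rows
assert np.allclose(out, out[0])         # all output rows equal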