Implement Multi-Query Attention (MQA) — used in PaLM, Falcon, and Gemma to dramatically reduce KV cache memory. MQA has multiple Query heads but only a single Key/Value head shared across all query heads.
Signature: def multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads)
Parameters:
- x: (S, d_model) — input sequence
- Q_proj: (d_model, d_model) — projects input to all Q heads
- K_proj: (d_h, d_model) — single K head projection (d_h = d_model // n_heads)
- V_proj: (d_h, d_model) — single V head projection
- Wo: (d_model, d_model) — output projection
- n_heads: number of query heads

Returns: (S, d_model)

Q = x @ Q_proj.T                                     # (S, d_model)
K = x @ K_proj.T                                     # (S, d_h) ← single shared K head
V = x @ V_proj.T                                     # (S, d_h) ← single shared V head

# Split Q into n_heads:
Q = Q.reshape(S, n_heads, d_h).transpose(1, 0, 2)    # (n_heads, S, d_h)

# Each Q head attends to the SAME K and V:
heads = []
for i in range(n_heads):
    scores = Q[i] @ K.T / sqrt(d_h)                  # (S, S)
    A = softmax(scores)                              # row-wise softmax over keys
    heads.append(A @ V)                              # (S, d_h) using shared K/V

concat = concatenate(heads, axis=-1)                 # (S, d_model)
output = concat @ Wo.T                               # (S, d_model)
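Putting the walk-through together, here is a minimal runnable sketch in NumPy. The `softmax` helper and the random-weight usage example at the bottom are assumptions added for illustration; the explicit per-head loop is replaced by a broadcast over the head axis, which gives the same result because every Q head shares the single K and V.

```python
import numpy as np

def softmax(z, axis=-1):
    # Assumed helper: numerically stable softmax over the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads):
    S, d_model = x.shape
    d_h = d_model // n_heads

    Q = x @ Q_proj.T                                   # (S, d_model)
    K = x @ K_proj.T                                   # (S, d_h), single shared K head
    V = x @ V_proj.T                                   # (S, d_h), single shared V head

    # Split Q into heads: (n_heads, S, d_h)
    Q = Q.reshape(S, n_heads, d_h).transpose(1, 0, 2)

    # Every Q head attends to the same K and V; broadcasting over the
    # head axis replaces the per-head loop from the walk-through above.
    scores = Q @ K.T / np.sqrt(d_h)                    # (n_heads, S, S)
    A = softmax(scores, axis=-1)
    heads = A @ V                                      # (n_heads, S, d_h)

    concat = heads.transpose(1, 0, 2).reshape(S, d_model)
    return concat @ Wo.T                               # (S, d_model)

# Hypothetical shapes, just to check the function end to end.
rng = np.random.default_rng(0)
S, d_model, n_heads = 6, 64, 8
d_h = d_model // n_heads
out = multi_query_attention(
    rng.normal(size=(S, d_model)),
    rng.normal(size=(d_model, d_model)),
    rng.normal(size=(d_h, d_model)),
    rng.normal(size=(d_h, d_model)),
    rng.normal(size=(d_model, d_model)),
    n_heads,
)
assert out.shape == (S, d_model)
```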
| Method | Q heads | K heads | V heads | KV cache size |
|--------|---------|---------|---------|---------------|
| MHA    | H       | H             | H | H × d_h |
| GQA    | H       | G (H > G > 1) | G | G × d_h |
| MQA    | H       | 1             | 1 | d_h     |
MQA reduces KV cache by H× compared to MHA, enabling longer context at inference time.
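As a back-of-the-envelope illustration of that reduction (all numbers below are hypothetical, chosen only to make the arithmetic concrete):

```python
# Hypothetical decoder config: 32 layers, H = 32 heads, d_h = 128,
# a 4096-token context, fp16 (2 bytes per element).
n_layers, H, d_h, seq_len, bytes_per_elem = 32, 32, 128, 4096, 2

def kv_cache_bytes(kv_heads):
    # K and V each store (seq_len, kv_heads * d_h) per layer.
    return 2 * n_layers * seq_len * kv_heads * d_h * bytes_per_elem

print(kv_cache_bytes(H) / 2**30)   # MHA:          2.0 GiB per sequence
print(kv_cache_bytes(8) / 2**30)   # GQA (G = 8):  0.5 GiB
print(kv_cache_bytes(1) / 2**30)   # MQA:          0.0625 GiB (64 MiB)
```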