Implement Multi-Query Attention (MQA) — used in PaLM, Falcon, and Gemma to dramatically reduce KV cache memory. MQA has multiple Query heads but only a single Key/Value head shared across all query heads.
Signature: def multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads)
Inputs:
- x: (S, d_model)
- Q_proj: (d_model, d_model), projects the input to all Q heads
- K_proj: (d_h, d_model), single K head projection (d_h = d_model // n_heads)
- V_proj: (d_h, d_model), single V head projection
- Wo: (d_model, d_model), output projection
- n_heads: number of query heads

Output: (S, d_model)

Project the input through Q_proj to get the full (S, d_model) query tensor and split it into n_heads heads of size d_h = d_model // n_heads. Project through K_proj and V_proj to get a single (S, d_h) key tensor and a single (S, d_h) value tensor; these are shared across every query head. Run standard scaled dot-product softmax attention per query head against the shared K and V (scaling by sqrt(d_h)), concatenate the per-head outputs back to (S, d_model), and apply the output projection Wo.
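The steps above can be sketched in NumPy as follows. This is one possible reading of the problem, not a reference solution. It assumes the weights are stored as (out_features, in_features), which is what the stated K_proj shape of (d_h, d_model) suggests, so every projection multiplies by the transpose. It also assumes no causal mask, since the problem statement does not mention one. The name multi_query_attention_sketch is hypothetical.

```python
import numpy as np

def multi_query_attention_sketch(x, Q_proj, K_proj, V_proj, Wo, n_heads):
    S, d_model = x.shape
    d_h = d_model // n_heads

    # Assumption: weights are (out_features, in_features), so project with W.T.
    Q = x @ Q_proj.T  # (S, d_model): all query heads
    K = x @ K_proj.T  # (S, d_h): the single shared key head
    V = x @ V_proj.T  # (S, d_h): the single shared value head

    # Split queries into heads: (S, d_model) -> (n_heads, S, d_h).
    Q = Q.reshape(S, n_heads, d_h).transpose(1, 0, 2)

    # K.T is broadcast across the head axis, so every head sees the same keys.
    scores = Q @ K.T / np.sqrt(d_h)  # (n_heads, S, S)

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # V is likewise shared by broadcasting: (n_heads, S, S) @ (S, d_h).
    heads = weights @ V  # (n_heads, S, d_h)

    # Concatenate heads back to (S, d_model) and apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(S, d_model)
    return concat @ Wo.T
```

Because K and V carry no head axis, NumPy broadcasting handles the sharing and no copy of K or V is made per head. If the grader expects weights stored as (in_features, out_features) instead, the transposes on the projections would need to be dropped.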
| Method | Q heads | K heads | V heads | KV cache size |
|--------|---------|---------|---------|---------------|
| MHA | H | H | H | H × d_h |
| GQA | H | G (H > G > 1) | G | G × d_h |
| MQA | H | 1 | 1 | d_h |
MQA reduces KV cache by H× compared to MHA, enabling longer context at inference time.
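To make the H× saving concrete, here is a small back-of-the-envelope sketch. The helper kv_cache_bytes and every number in it (32 query heads, d_h of 128, 32 layers, a 4096-token context, 2 bytes per element) are hypothetical values chosen for illustration and do not describe any particular model. It assumes the cache holds one K and one V tensor per layer for every token.

```python
def kv_cache_bytes(n_kv_heads, d_h, seq_len, n_layers, bytes_per_elem=2):
    # Factor of 2 because both K and V are cached for each layer and token.
    return 2 * n_layers * seq_len * n_kv_heads * d_h * bytes_per_elem

# Hypothetical configuration, for illustration only.
H, d_h, seq_len, n_layers = 32, 128, 4096, 32

mha = kv_cache_bytes(H, d_h, seq_len, n_layers)  # H key/value heads
gqa = kv_cache_bytes(8, d_h, seq_len, n_layers)  # G = 8 key/value heads
mqa = kv_cache_bytes(1, d_h, seq_len, n_layers)  # 1 key/value head

print(mha // 2**20, "MiB for MHA")  # 2048 MiB
print(gqa // 2**20, "MiB for GQA")  # 512 MiB
print(mqa // 2**20, "MiB for MQA")  # 64 MiB
print(mha // mqa)                   # 32, i.e. H
```

Under these assumptions the ratio between MHA and MQA is exactly H, independent of sequence length, layer count, and element size, since those factors cancel.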
import numpy as np

def multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads):
    pass