TorchedUp

83. Multi-Query Attention (MQA)

Easy

Implement Multi-Query Attention (MQA), the attention variant used in PaLM, Falcon, and Gemma to cut KV cache memory at inference time. MQA has multiple query heads but only a single key/value head, which is shared across all query heads.

Signature: def multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads)

  • x: (S, d_model)
  • Q_proj: (d_model, d_model) — projects input to all Q heads
  • K_proj: (d_h, d_model) — single K head projection (d_h = d_model // n_heads)
  • V_proj: (d_h, d_model) — single V head projection
  • Wo: (d_model, d_model) — output projection
  • n_heads: number of query heads
  • Returns: (S, d_model)
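A quick way to confirm these shapes is to build random inputs for a small configuration. `make_mqa_inputs` and its `seed` argument are hypothetical helpers used only for illustration; they are not part of the problem's required API.

```python
import numpy as np


def make_mqa_inputs(S, d_model, n_heads, seed=0):
    """Hypothetical helper: random inputs with the shapes from the signature."""
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    d_h = d_model // n_heads  # per-head dimension
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((S, d_model))
    Q_proj = rng.standard_normal((d_model, d_model))  # all query heads at once
    K_proj = rng.standard_normal((d_h, d_model))      # one shared key head
    V_proj = rng.standard_normal((d_h, d_model))      # one shared value head
    Wo = rng.standard_normal((d_model, d_model))      # output projection
    return x, Q_proj, K_proj, V_proj, Wo


x, Q_proj, K_proj, V_proj, Wo = make_mqa_inputs(S=4, d_model=8, n_heads=2)
print(K_proj.shape)  # (4, 8): d_h = 8 // 2 = 4
```

Note that K_proj and V_proj are n_heads times smaller than Q_proj, which is where MQA's savings come from.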

Algorithm

Project the input through Q_proj to get the full (S, d_model) query tensor and split it into n_heads heads of size d_h = d_model // n_heads. Project through K_proj and V_proj to get a single (S, d_h) key tensor and a single (S, d_h) value tensor; both are shared by every query head. For each query head, run standard scaled dot-product attention against the shared K and V, dividing the scores by sqrt(d_h) before the softmax. Finally, concatenate the per-head outputs back to (S, d_model) and apply the output projection Wo.
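A minimal NumPy sketch of these steps follows. It assumes the weights use an (out_features, in_features) layout, so each projection is applied as `x @ W.T`; that is the only orientation that fits K_proj's (d_h, d_model) shape, but for the square Q_proj and Wo it is an assumption. It also assumes head h owns query columns h*d_h to (h+1)*d_h, and it applies no causal mask because the statement does not mention one. The `softmax` helper is local to this sketch.

```python
import numpy as np


def softmax(z, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads):
    S, d_model = x.shape
    d_h = d_model // n_heads

    # Assumed layout: weights are (out_features, in_features), applied as x @ W.T.
    Q = x @ Q_proj.T  # (S, d_model), all query heads
    K = x @ K_proj.T  # (S, d_h), one key head shared by all query heads
    V = x @ V_proj.T  # (S, d_h), one value head shared by all query heads

    # Split queries into heads: (S, d_model) -> (n_heads, S, d_h).
    Q = Q.reshape(S, n_heads, d_h).transpose(1, 0, 2)

    # matmul broadcasts the 2-D K.T across the head axis: every head sees the same K.
    scores = Q @ K.T / np.sqrt(d_h)   # (n_heads, S, S)
    weights = softmax(scores, axis=-1)
    heads = weights @ V               # (n_heads, S, d_h)

    # Concatenate heads back to (S, d_model) and apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(S, d_model)
    return concat @ Wo.T


# Usage with random weights of the documented shapes.
rng = np.random.default_rng(0)
S, d_model, n_heads = 5, 8, 2
d_h = d_model // n_heads
x = rng.standard_normal((S, d_model))
Q_proj = rng.standard_normal((d_model, d_model))
K_proj = rng.standard_normal((d_h, d_model))
V_proj = rng.standard_normal((d_h, d_model))
Wo = rng.standard_normal((d_model, d_model))
out = multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads)
print(out.shape)  # (5, 8)
```

Because K and V carry no head axis, NumPy's matmul broadcasting reuses the same (S, d_h) matrices for every query head, so the shared key and value are never copied per head.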

MHA vs GQA vs MQA

| Method | Q heads | K heads | V heads | KV cache per token per layer (K or V) |
|--------|---------|---------|---------|----------------------------------------|
| MHA | H | H | H | H × d_h |
| GQA | H | G (1 < G < H) | G | G × d_h |
| MQA | H | 1 | 1 | d_h |
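To see how the rows of the table relate, here is a hypothetical `grouped_query_attention` sketch in which `n_kv_heads` plays the role of G. It uses the same assumed conventions as a plain NumPy MQA implementation (weights applied as `x @ W.T`, no mask) and repeats each K/V head across its group of query heads with `np.repeat`. Setting `n_kv_heads=1` gives MQA and setting it to `n_heads` gives MHA.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads, n_kv_heads):
    """Hypothetical GQA helper: K_proj and V_proj have shape (n_kv_heads * d_h, d_model).

    n_kv_heads == n_heads gives MHA; n_kv_heads == 1 gives MQA.
    """
    S, d_model = x.shape
    d_h = d_model // n_heads
    assert n_heads % n_kv_heads == 0, "each K/V head must serve the same number of query heads"
    group = n_heads // n_kv_heads  # query heads per K/V head

    Q = (x @ Q_proj.T).reshape(S, n_heads, d_h).transpose(1, 0, 2)     # (H, S, d_h)
    K = (x @ K_proj.T).reshape(S, n_kv_heads, d_h).transpose(1, 0, 2)  # (G, S, d_h)
    V = (x @ V_proj.T).reshape(S, n_kv_heads, d_h).transpose(1, 0, 2)  # (G, S, d_h)

    # Query head h uses K/V head h // group: repeat each K/V head `group` times.
    K = np.repeat(K, group, axis=0)  # (H, S, d_h)
    V = np.repeat(V, group, axis=0)  # (H, S, d_h)

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_h)  # (H, S, S)
    heads = softmax(scores) @ V                       # (H, S, d_h)
    return heads.transpose(1, 0, 2).reshape(S, d_model) @ Wo.T


rng = np.random.default_rng(0)
S, d_model, H = 4, 8, 4
d_h = d_model // H
x = rng.standard_normal((S, d_model))
Q_proj = rng.standard_normal((d_model, d_model))
Wo = rng.standard_normal((d_model, d_model))
K1 = rng.standard_normal((d_h, d_model))  # single shared key head (MQA)
V1 = rng.standard_normal((d_h, d_model))  # single shared value head (MQA)

mqa = grouped_query_attention(x, Q_proj, K1, V1, Wo, H, n_kv_heads=1)
# MHA whose H key/value heads are identical copies of the single MQA head.
mha = grouped_query_attention(x, Q_proj, np.tile(K1, (H, 1)), np.tile(V1, (H, 1)), Wo, H, n_kv_heads=H)
print(np.allclose(mqa, mha))  # True
```

The final check shows that MQA computes exactly what MHA would if all of its K/V heads were forced to be identical; MQA simply stores that one head instead of H copies.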

MQA shrinks the KV cache by a factor of H relative to MHA, which lets inference fit longer contexts (or larger batches) into the same memory.
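The H-fold saving is easy to check with back-of-the-envelope arithmetic. The sketch below counts bytes for both K and V across all layers. `kv_cache_bytes` is a hypothetical helper, and the configuration (32 layers, H = 32 heads, d_h = 128, 4096 tokens, 2-byte fp16 values, batch size 1) is illustrative rather than taken from any specific model.

```python
def kv_cache_bytes(n_layers, seq_len, n_kv_heads, d_h, bytes_per_elem=2, batch=1):
    # Factor 2: both K and V are cached, each (seq_len, n_kv_heads * d_h) per layer.
    return 2 * batch * n_layers * seq_len * n_kv_heads * d_h * bytes_per_elem


# Hypothetical config: 32 layers, H = 32 query heads, d_h = 128, 4096 tokens, fp16.
cfg = dict(n_layers=32, seq_len=4096, d_h=128)
mha = kv_cache_bytes(n_kv_heads=32, **cfg)  # MHA: one K/V head per query head
gqa = kv_cache_bytes(n_kv_heads=8, **cfg)   # GQA with G = 8
mqa = kv_cache_bytes(n_kv_heads=1, **cfg)   # MQA: a single shared K/V head

for name, b in [("MHA", mha), ("GQA (G=8)", gqa), ("MQA", mqa)]:
    print(f"{name}: {b / 2**20:.0f} MiB")
# MHA: 2048 MiB
# GQA (G=8): 512 MiB
# MQA: 64 MiB
print(mha // mqa)  # 32
```

The ratio between MHA and MQA is exactly H = 32 here, independent of the sequence length, layer count, or precision chosen.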

Related problems

  • Multi-Query Attention (PyTorch) (Medium, PyTorch)

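Starter code (NumPy):

```python
import numpy as np


def multi_query_attention(x, Q_proj, K_proj, V_proj, Wo, n_heads):
    # x: (S, d_model) -> returns (S, d_model)
    pass
```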
