TorchedUp

Grouped Query Attention (GQA)
Difficulty: Medium

GQA, used in LLaMA 2 (the 70B model), LLaMA 3, and Mistral, reduces KV cache memory by having each group of query heads share a single key/value head.
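The cache stores one K and one V tensor per KV head per layer, so its size scales with the number of KV heads, not the number of query heads. Relative to MHA, the saving is a factor of (query heads) / (KV heads). Below is a back-of-envelope sketch. It assumes fp16/bf16 storage (2 bytes per element), batch size 1, and a hypothetical model shape that is not taken from any particular model.

```python
def kv_cache_bytes(num_layers, num_kv_heads, seq_len, head_dim,
                   bytes_per_elem=2, batch=1):
    """Bytes needed to cache K and V for one context.

    The leading factor 2 counts both the K and the V tensors.
    bytes_per_elem=2 assumes fp16/bf16 storage.
    """
    return 2 * batch * num_layers * num_kv_heads * seq_len * head_dim * bytes_per_elem


# Hypothetical shape: 32 layers, 32 query heads, head_dim 128, 4096-token context.
mha = kv_cache_bytes(num_layers=32, num_kv_heads=32, seq_len=4096, head_dim=128)
gqa = kv_cache_bytes(num_layers=32, num_kv_heads=8, seq_len=4096, head_dim=128)
print(mha / 2**30, gqa / 2**30, mha // gqa)  # → 2.0 0.5 4  (GiB, GiB, reduction factor)
```

The cache per token does not depend on the number of query heads at all. Only the KV head count moves it.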

With $H_q$ query heads and $H_{kv}$ key/value heads ($H_q \ge H_{kv}$, with $H_q$ divisible by $H_{kv}$), each KV head serves a group of $H_q / H_{kv}$ query heads.

GQA generalizes both extremes. When $H_{kv} = H_q$, it is standard multi-head attention (MHA). When $H_{kv} = 1$, it is multi-query attention (MQA).
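The text does not fix which KV head a given query head uses. The sketch below assumes contiguous grouping, which is the layout `np.repeat(K, g, axis=0)` produces. An interleaved mapping such as `h % num_kv_heads` would also be valid, but it gives different outputs. The helper name `kv_head_for_query` is hypothetical. The usage loop shows the two degenerate cases from the paragraph above.

```python
def kv_head_for_query(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the KV head index used by query head `q_head`.

    Assumes contiguous grouping: with g = num_q_heads // num_kv_heads,
    query heads 0..g-1 share KV head 0, heads g..2g-1 share KV head 1, etc.
    """
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


for num_kv in (2, 4, 1):  # GQA with ratio 2, then MHA, then MQA (4 query heads)
    print(num_kv, [kv_head_for_query(h, 4, num_kv) for h in range(4)])
# 2 [0, 0, 1, 1]
# 4 [0, 1, 2, 3]
# 1 [0, 0, 0, 0]
```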

Signature: def grouped_query_attention(Q, K, V, num_kv_heads)

  • Q: (num_q_heads, seq_len, d_k)
  • K: (num_kv_heads, seq_len, d_k)
  • V: (num_kv_heads, seq_len, d_v)
  • num_kv_heads: int
  • Returns: (seq_len, num_q_heads * d_v) — concatenated outputs of all query heads
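Below is a minimal NumPy sketch of the signature above. It makes three assumptions that the listed signature does not state: scaling by $1/\sqrt{d_k}$, no causal mask, and contiguous grouping (query head $h$ uses KV head $h // g$). The page's reference solution may differ in these details.

```python
import numpy as np


def grouped_query_attention(Q, K, V, num_kv_heads):
    """Grouped query attention (sketch).

    Q: (num_q_heads, seq_len, d_k)
    K: (num_kv_heads, seq_len, d_k)
    V: (num_kv_heads, seq_len, d_v)
    Returns: (seq_len, num_q_heads * d_v)

    Assumes 1/sqrt(d_k) scaling, no mask, and contiguous grouping.
    """
    num_q_heads, seq_len, d_k = Q.shape
    if K.shape[0] != num_kv_heads or V.shape[0] != num_kv_heads:
        raise ValueError("K and V must have num_kv_heads heads")
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group_size = num_q_heads // num_kv_heads

    # Expand each KV head to cover its group of query heads.
    K_rep = np.repeat(K, group_size, axis=0)  # (num_q_heads, seq_len, d_k)
    V_rep = np.repeat(V, group_size, axis=0)  # (num_q_heads, seq_len, d_v)

    # Per-head scaled scores: (num_q_heads, seq_len, seq_len)
    scores = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d_k)

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # Weighted sum of values: (num_q_heads, seq_len, d_v)
    out = weights @ V_rep

    # Concatenate heads along the feature axis in head order.
    return out.transpose(1, 0, 2).reshape(seq_len, num_q_heads * out.shape[-1])


rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 5, 8))
K = rng.standard_normal((2, 5, 8))
V = rng.standard_normal((2, 5, 16))
print(grouped_query_attention(Q, K, V, num_kv_heads=2).shape)  # (5, 64)
```

`np.repeat` materializes $g$ copies of each KV head, which keeps the code easy to read. A leaner variant would reshape Q to `(num_kv_heads, g, seq_len, d_k)` and let `matmul` broadcasting share each K and V across its group without copying.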

Math
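The formulation below is consistent with the signature above, but it rests on assumptions that the page does not state here. It uses scaled dot-product attention with $1/\sqrt{d_k}$ scaling, no causal mask, and a row-wise softmax. It also assumes contiguous grouping and takes $n$ = seq_len. The symbol $\kappa(h)$ is notation introduced for this sketch, giving the KV head used by query head $h$.

```latex
g = \frac{H_q}{H_{kv}}, \qquad \kappa(h) = \left\lfloor \frac{h}{g} \right\rfloor, \qquad h = 0, \dots, H_q - 1

O_h = \operatorname{softmax}\!\left( \frac{Q_h K_{\kappa(h)}^{\top}}{\sqrt{d_k}} \right) V_{\kappa(h)} \in \mathbb{R}^{n \times d_v}

\operatorname{GQA}(Q, K, V) = \operatorname{concat}\left( O_0, O_1, \dots, O_{H_q - 1} \right) \in \mathbb{R}^{n \times H_q d_v}
```

With $H_{kv} = H_q$, $g = 1$ and $\kappa(h) = h$, which is MHA. With $H_{kv} = 1$, $\kappa(h) = 0$ for every head, which is MQA.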

Language: Python (numpy)

Test cases

  • 4 Q heads, 2 KV heads (ratio = 2)
  • GQA degenerates to MQA (num_kv_heads = 1)
  • GQA degenerates to MHA (num_kv_heads = num_q_heads)
  • Non-negative output when V is non-negative
  • V = I per group: each query head's attention row sums to 1, so each row of the output sums to num_q_heads
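The test cases above can be checked locally with a sketch like the one below. `gqa_reference` is a hypothetical loop-based implementation, not the site's grader. It assumes $1/\sqrt{d_k}$ scaling, no causal mask, and contiguous head grouping. The hidden tests may use different inputs and tolerances.

```python
import numpy as np


def softmax(x):
    # Row-wise softmax, shifted by the row max for numerical stability.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def gqa_reference(Q, K, V, num_kv_heads):
    # Hypothetical reference: one query head at a time, query head h -> KV head h // g.
    num_q_heads, seq_len, d_k = Q.shape
    g = num_q_heads // num_kv_heads
    heads = []
    for h in range(num_q_heads):
        kv = h // g
        weights = softmax(Q[h] @ K[kv].T / np.sqrt(d_k))
        heads.append(weights @ V[kv])
    return np.concatenate(heads, axis=-1)  # (seq_len, num_q_heads * d_v)


rng = np.random.default_rng(0)
H_q, H_kv, n, d = 4, 2, 6, 8
Q = rng.standard_normal((H_q, n, d))
K = rng.standard_normal((H_kv, n, d))
V = rng.standard_normal((H_kv, n, d))

# 4 Q heads, 2 KV heads: output is (seq_len, num_q_heads * d_v).
out = gqa_reference(Q, K, V, H_kv)
assert out.shape == (n, H_q * d)

# Duplicating each KV head g times and running as MHA reproduces the GQA output.
g = H_q // H_kv
mha = gqa_reference(Q, np.repeat(K, g, axis=0), np.repeat(V, g, axis=0), H_q)
assert np.allclose(out, mha)

# MQA (num_kv_heads = 1): every query head attends over the same single K/V head.
mqa = gqa_reference(Q, K[:1], V[:1], 1)
assert np.allclose(mqa, gqa_reference(Q, np.repeat(K[:1], H_q, axis=0),
                                      np.repeat(V[:1], H_q, axis=0), H_q))

# Non-negative V gives non-negative output (each row is a convex combination of V rows).
assert (gqa_reference(Q, K, np.abs(V), H_kv) >= 0).all()

# V = I per group: each head's block equals its attention weights (rows sum to 1),
# so each row of the concatenated output sums to num_q_heads.
V_eye = np.stack([np.eye(n)] * H_kv)
rows = gqa_reference(Q, K, V_eye, H_kv).sum(axis=-1)
assert np.allclose(rows, H_q)
print("all checks passed")
```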