GQA, used in LLaMA 2/3 and Mistral, reduces KV cache memory by having multiple query heads share a single set of key/value heads.
With $H_q$ query heads and $H_{kv}$ key/value heads ($H_q > H_{kv}$, $H_q$ divisible by $H_{kv}$), each KV head serves $H_q / H_{kv}$ query heads.
GQA generalizes both extremes: when $H_{kv} = H_q$ it reduces to standard multi-head attention (MHA); when $H_{kv} = 1$ it reduces to multi-query attention (MQA).
Signature: `def grouped_query_attention(Q, K, V, num_kv_heads)`
Inputs:

- `Q`: `(num_q_heads, seq_len, d_k)`
- `K`: `(num_kv_heads, seq_len, d_k)`
- `V`: `(num_kv_heads, seq_len, d_v)`
- `num_kv_heads`: `int`

Output: `(seq_len, num_q_heads * d_v)`, the concatenated outputs of all query heads.

Math
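The grouping can be written out with standard scaled dot-product attention per head; with group size $G = H_q / H_{kv}$, query head $i$ attends against shared KV head $\lfloor i / G \rfloor$ (the floor-index notation is introduced here to make the group assignment explicit):

$$
\text{head}_i = \operatorname{softmax}\!\left(\frac{Q_i\, K_{\lfloor i/G \rfloor}^\top}{\sqrt{d_k}}\right) V_{\lfloor i/G \rfloor}, \qquad i = 0, \dots, H_q - 1
$$

$$
\operatorname{GQA}(Q, K, V) = \operatorname{Concat}\left(\text{head}_0, \dots, \text{head}_{H_q - 1}\right)
$$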
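A minimal NumPy sketch of this signature, assuming unmasked attention over the full sequence and that `num_kv_heads` divides `num_q_heads` evenly (batching and causal masking are left out for brevity):

```python
import numpy as np

def grouped_query_attention(Q, K, V, num_kv_heads):
    """GQA sketch. Q: (num_q_heads, seq_len, d_k); K: (num_kv_heads, seq_len, d_k);
    V: (num_kv_heads, seq_len, d_v). Returns (seq_len, num_q_heads * d_v)."""
    num_q_heads, seq_len, d_k = Q.shape
    assert num_q_heads % num_kv_heads == 0, "num_kv_heads must divide num_q_heads"
    group_size = num_q_heads // num_kv_heads  # query heads per KV head

    outputs = []
    for i in range(num_q_heads):
        g = i // group_size                           # index of the shared KV head
        scores = Q[i] @ K[g].T / np.sqrt(d_k)         # (seq_len, seq_len)
        scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        outputs.append(weights @ V[g])                # (seq_len, d_v)

    # Concatenate the per-head outputs along the feature dimension.
    return np.concatenate(outputs, axis=-1)           # (seq_len, num_q_heads * d_v)
```

For example, with 8 query heads and 2 KV heads (group size 4), inputs of shape `(8, 5, 16)`, `(2, 5, 16)`, and `(2, 5, 16)` produce an output of shape `(5, 128)`.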