TorchedUp

28. Grouped Query Attention (GQA)

Medium

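GQA, used in LLaMA 2 (70B), LLaMA 3, and Mistral, reduces KV cache memory by having several query heads share one set of key/value heads. Only the KV heads are cached, so the cache shrinks by the ratio of query heads to KV heads.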
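A back-of-the-envelope estimate makes the saving concrete. For each layer and each cached token, the cache stores one key vector and one value vector per KV head. The helper `kv_cache_bytes` is hypothetical, and the configuration (32 layers, 32 query heads, 8 KV heads, head dimension 128, 4,096 tokens, 16-bit values) is an assumed, illustrative one, not any particular model's published config.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    # Hypothetical helper: keys AND values are cached, hence the leading factor of 2.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Illustrative (assumed) configuration, 16-bit (2-byte) cache entries.
layers, h_q, h_kv, d_head, seq = 32, 32, 8, 128, 4096

mha = kv_cache_bytes(layers, h_q, d_head, seq)   # one KV head per query head
gqa = kv_cache_bytes(layers, h_kv, d_head, seq)  # 8 shared KV heads

print(mha // 2**20, "MiB vs", gqa // 2**20, "MiB")  # 2048 MiB vs 512 MiB
print("reduction:", mha // gqa)                     # reduction: 4
```

The reduction factor is exactly $H_q / H_{kv}$ (here 32 / 8 = 4), independent of the other parameters.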

With $H_q$ query heads and $H_{kv}$ key/value heads ($H_q > H_{kv}$, with $H_q$ divisible by $H_{kv}$), each KV head serves $H_q / H_{kv}$ query heads.

GQA generalizes both standard variants: when $H_{kv} = H_q$ it is multi-head attention (MHA); when $H_{kv} = 1$ it is multi-query attention (MQA).
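The head-to-group assignment can be written as a one-line index computation: query head $h$ uses KV head $\lfloor h / r \rfloor$, where $r = H_q / H_{kv}$. The helper name `kv_head_for` is hypothetical; the sketch prints the mapping for 8 query heads under the MHA, GQA, and MQA settings described above.

```python
def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Hypothetical helper: index of the KV head used by query head `q_head`."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # r = H_q / H_kv
    return q_head // group_size


H_q = 8
for h_kv, name in [(8, "MHA"), (2, "GQA"), (1, "MQA")]:
    print(name, [kv_head_for(h, H_q, h_kv) for h in range(H_q)])
# MHA [0, 1, 2, 3, 4, 5, 6, 7]
# GQA [0, 0, 0, 0, 1, 1, 1, 1]
# MQA [0, 0, 0, 0, 0, 0, 0, 0]
```

Consecutive query heads form a group, which is the same layout as the index range $[g \cdot r, (g+1) \cdot r)$ used for group $g$.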

Signature: def grouped_query_attention(Q, K, V, num_kv_heads)

  • Q: (num_q_heads, seq_len, d_k)
  • K: (num_kv_heads, seq_len, d_k)
  • V: (num_kv_heads, seq_len, d_v)
  • num_kv_heads: int
  • Returns: (seq_len, num_q_heads * d_v) — concatenated outputs of all query heads
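One way to satisfy this signature is to expand K and V so that every query head gets its own copy of its group's KV head, then run ordinary batched attention. This is a minimal sketch, not the site's reference solution. It assumes standard scaled dot-product attention, $\text{softmax}(QK^\top / \sqrt{d_k})\,V$, with no causal mask, and that head outputs are concatenated in head order along the last axis.

```python
import numpy as np


def grouped_query_attention(Q, K, V, num_kv_heads):
    # Assumptions: scaled dot-product attention, no mask, heads concatenated in order.
    num_q_heads, seq_len, d_k = Q.shape
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    r = num_q_heads // num_kv_heads  # query heads per KV group

    # np.repeat along axis 0 gives [k0, k0, ..., k1, k1, ...] (each KV head r times),
    # so query head h is paired with KV head h // r.
    K_rep = np.repeat(K, r, axis=0)  # (num_q_heads, seq_len, d_k)
    V_rep = np.repeat(V, r, axis=0)  # (num_q_heads, seq_len, d_v)

    scores = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d_k)  # (num_q_heads, seq_len, seq_len)
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize exp
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    out = weights @ V_rep  # (num_q_heads, seq_len, d_v)

    # (num_q_heads, seq_len, d_v) -> (seq_len, num_q_heads * d_v)
    return out.transpose(1, 0, 2).reshape(seq_len, num_q_heads * out.shape[-1])


rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 5, 16))  # 8 query heads
K = rng.standard_normal((2, 5, 16))  # 2 KV heads -> groups of 4
V = rng.standard_normal((2, 5, 4))
print(grouped_query_attention(Q, K, V, num_kv_heads=2).shape)  # (5, 32)
```

`np.repeat` materializes copies of K and V, so this version gives back the memory saving at compute time; it is chosen here for readability. Optimized kernels index the shared KV head directly instead of copying it.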

Math

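Heads per KV group:

$$r = \frac{H_q}{H_{kv}}$$

For group $g \in \{0, \dots, H_{kv} - 1\}$:

$$O_{q_h} = \text{Attention}(Q_{q_h}, K_g, V_g), \qquad q_h \in [\,g \cdot r,\ (g+1) \cdot r\,)$$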
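The per-group formula can also be translated literally, looping over groups $g$ and the query heads $q_h$ inside each group. The function name `gqa_by_group` is hypothetical, and, as with any sketch here, unmasked scaled dot-product attention is an assumption since the formula leaves Attention unspecified.

```python
import numpy as np


def softmax(x):
    # Subtract the row max before exponentiating for numerical stability.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def gqa_by_group(Q, K, V, num_kv_heads):
    """Hypothetical literal translation of the per-group formula (no mask assumed)."""
    num_q_heads, seq_len, d_k = Q.shape
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    r = num_q_heads // num_kv_heads  # heads per KV group
    d_v = V.shape[-1]
    O = np.empty((num_q_heads, seq_len, d_v))
    for g in range(num_kv_heads):
        # Query heads g*r, ..., (g+1)*r - 1 all attend with the shared K_g and V_g.
        for q_h in range(g * r, (g + 1) * r):
            scores = Q[q_h] @ K[g].T / np.sqrt(d_k)  # (seq_len, seq_len)
            O[q_h] = softmax(scores) @ V[g]          # (seq_len, d_v)
    # Concatenate head outputs feature-wise: (seq_len, num_q_heads * d_v).
    return O.transpose(1, 0, 2).reshape(seq_len, num_q_heads * d_v)


# MQA special case: 4 query heads sharing a single KV head.
rng = np.random.default_rng(0)
out = gqa_by_group(rng.standard_normal((4, 3, 8)),
                   rng.standard_normal((1, 3, 8)),
                   rng.standard_normal((1, 3, 6)), num_kv_heads=1)
print(out.shape)  # (3, 24)
```

The double loop is slow but mirrors the math one-to-one, which makes it a convenient reference for checking a vectorized implementation.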

Related problems

  • Grouped Query Attention (PyTorch) [Medium, PyTorch]

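Starter code (NumPy):

import numpy as np


def grouped_query_attention(Q, K, V, num_kv_heads):
    # Q: (num_q_heads, seq_len, d_k); K: (num_kv_heads, seq_len, d_k); V: (num_kv_heads, seq_len, d_v)
    # Returns: (seq_len, num_q_heads * d_v)
    pass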
