Implement Grouped Query Attention in PyTorch using primitive tensor ops only. GQA — used in LLaMA 2/3 and Mistral — sits between MHA and MQA: H_q query heads share H_kv key/value heads (with H_q divisible by H_kv).
When H_kv = H_q it's MHA. When H_kv = 1 it's MQA.
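The grouping can be made concrete by listing which K/V head each query head reads. The sketch below assumes the contiguous grouping that a reshape of the head axis into (H_kv, ratio) implies, so query head h uses K/V head h // ratio. The helper name kv_head_for_query_head is hypothetical and not part of the problem.

```python
def kv_head_for_query_head(H_q: int, H_kv: int) -> list[int]:
    # GQA requires every K/V head to serve the same number of query heads.
    if H_q % H_kv != 0:
        raise ValueError("H_q must be divisible by H_kv")
    ratio = H_q // H_kv
    # Consecutive blocks of `ratio` query heads share one K/V head.
    return [h // ratio for h in range(H_q)]

print(kv_head_for_query_head(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]  (GQA)
print(kv_head_for_query_head(4, 4))  # [0, 1, 2, 3]  (MHA: one K/V head per query head)
print(kv_head_for_query_head(4, 1))  # [0, 0, 0, 0]  (MQA: a single shared K/V head)
```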
Signature: def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor
Q: (..., H_q, N, d_k)
K: (..., H_kv, N, d_k)
V: (..., H_kv, N, d_k)
Returns: (..., N, H_q * d_k), with the head outputs concatenated along the last dim in original head order.
The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
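Because the softmax APIs are off limits, the softmax has to be written by hand. A minimal sketch follows; the name stable_softmax is hypothetical. Subtracting the max along the reduced axis leaves the result unchanged, since the common factor exp(-max) cancels between numerator and denominator, but it keeps exp from overflowing on large scores.

```python
import torch

def stable_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Shift by the max along `dim` so the largest exponent is exp(0) = 1 (no overflow).
    shifted = x - x.amax(dim=dim, keepdim=True)
    exps = shifted.exp()
    # Normalize so every slice along `dim` sums to 1.
    return exps / exps.sum(dim=dim, keepdim=True)

# A naive exp(1000.0) would overflow to inf and produce NaN after normalizing.
scores = torch.tensor([[1000.0, 1001.0, 1002.0]])
print(stable_softmax(scores))  # approximately [[0.0900, 0.2447, 0.6652]]
```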
The PyTorch idiom for grouped broadcasting (no copies): let ratio = H_q // H_kv. Reshape Q from (..., H_q, N, d_k) to (..., H_kv, ratio, N, d_k), splitting its head axis into a pair of axes, and unsqueeze K and V at dim -3 so they become (..., H_kv, 1, N, d_k). K and V now align with Q on the H_kv axis, and Q's ratio axis broadcasts against the size-1 axis on the K/V side at zero copy cost, so every group of ratio consecutive query heads pairs with the same shared K/V head. Run scaled dot-product attention (scores = Q K^T / sqrt(d_k), a softmax over the key axis, then a weighted sum of V) through this broadcast layout, then merge the (H_kv, ratio) axes back into a single H_q head axis before concatenating the heads along the feature dim.
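A minimal sketch of that recipe, written against the signature above and respecting the no-softmax-API rule. It assumes the leading "..." dims of Q, K and V are identical, and it handles no masking or dropout, neither of which the problem asks for.

```python
import math
import torch

def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    *batch, H_q, N, d_k = Q.shape
    H_kv = K.shape[-3]
    if H_q % H_kv != 0:
        raise ValueError("H_q must be divisible by H_kv")
    ratio = H_q // H_kv

    # Split the query head axis: (..., H_q, N, d_k) -> (..., H_kv, ratio, N, d_k).
    Qg = Q.reshape(*batch, H_kv, ratio, N, d_k)
    # Size-1 axis on K/V: (..., H_kv, N, d_k) -> (..., H_kv, 1, N, d_k).
    Kg = K.unsqueeze(-3)
    Vg = V.unsqueeze(-3)

    # matmul broadcasts ratio against 1: scores are (..., H_kv, ratio, N, N).
    scores = Qg @ Kg.transpose(-2, -1) / math.sqrt(d_k)

    # Hand-rolled, numerically stable softmax over the key axis.
    scores = scores - scores.amax(dim=-1, keepdim=True)
    weights = scores.exp()
    weights = weights / weights.sum(dim=-1, keepdim=True)

    # Weighted sum of values: (..., H_kv, ratio, N, d_k).
    out = weights @ Vg

    # Merge (H_kv, ratio) back into H_q; head h = g * ratio + r keeps the original order.
    out = out.reshape(*batch, H_q, N, d_k)
    # Move heads next to features and concatenate: (..., N, H_q * d_k).
    return out.transpose(-3, -2).reshape(*batch, N, H_q * d_k)

Q = torch.randn(2, 8, 5, 4)
K = torch.randn(2, 2, 5, 4)
V = torch.randn(2, 2, 5, 4)
print(grouped_query_attention(Q, K, V).shape)  # torch.Size([2, 5, 32])
```

K and V are never expanded in memory; the only data copy of the output happens in the final reshape after the transpose, which is needed to produce the (..., N, H_q * d_k) layout.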
PyTorch idioms vs NumPy:
- tensor.unsqueeze(dim) is the PyTorch spelling of np.expand_dims: it accepts negative axes and returns a view.
- reshape here splits an axis without moving any data (when the input is contiguous), same as in NumPy. After it, the head dim becomes (H_kv, ratio), so each KV head pairs with ratio query heads automatically.
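A small check of both idioms, using arbitrary example sizes (batch 2, H_kv = 3, H_q = 6, N = 5, d_k = 4). It shows that unsqueeze matches np.expand_dims shape-wise, that writing through the unsqueezed view changes the original tensor, and that reshape on a contiguous tensor puts query head h at (h // ratio, h % ratio).

```python
import numpy as np
import torch

K = torch.randn(2, 3, 5, 4)  # (batch, H_kv, N, d_k)
Q = torch.randn(2, 6, 5, 4)  # (batch, H_q, N, d_k), ratio = 2

# unsqueeze(-3) inserts a size-1 axis before N, like np.expand_dims(..., -3).
K_u = K.unsqueeze(-3)
print(K_u.shape)  # torch.Size([2, 3, 1, 5, 4])
print(np.expand_dims(np.zeros((2, 3, 5, 4)), -3).shape)  # (2, 3, 1, 5, 4)

# K_u shares K's storage: writing through the view changes K.
K_u[0, 0, 0, 0, 0] = 42.0
print(K[0, 0, 0, 0].item())  # 42.0

# reshape on a contiguous tensor is also a view over the same memory.
Q_g = Q.reshape(2, 3, 2, 5, 4)
print(Q_g.data_ptr() == Q.data_ptr())  # True
print(torch.equal(Q_g[:, 1, 0], Q[:, 2]))  # True: head 2 -> group 1, slot 0
```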
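import torch

def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Implement with primitive tensor ops only (no built-in attention or softmax).
    pass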