Implement Grouped Query Attention in PyTorch using primitive tensor ops only. GQA — used in LLaMA 2/3 and Mistral — sits between MHA and MQA: H_q query heads share H_kv key/value heads (with H_q divisible by H_kv).
When H_kv = H_q it's MHA. When H_kv = 1 it's MQA.
Signature: def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor
Q: (..., H_q, N, d_k)
K: (..., H_kv, N, d_k)
V: (..., H_kv, N, d_k)
Returns: (..., N, H_q * d_k) — head outputs concatenated along the last dim, in original head order.
The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
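Since the built-in softmax calls are off limits, the normalization has to be written by hand. A minimal sketch of a numerically stable softmax over the last axis (the helper name manual_softmax is illustrative, not part of the required signature):

import torch

def manual_softmax(x: torch.Tensor) -> torch.Tensor:
    # Subtracting the row-wise max keeps exp() from overflowing; the shift cancels in the ratio.
    x_max = x.amax(dim=-1, keepdim=True)
    exp_x = torch.exp(x - x_max)
    return exp_x / exp_x.sum(dim=-1, keepdim=True)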
The PyTorch idiom for grouped broadcasting (no copies):
ratio = H_q // H_kv
# Reshape Q so its head axis splits into (H_kv, ratio):
Q_g = Q.reshape(*Q.shape[:-3], H_kv, ratio, N, d_k) # (..., H_kv, ratio, N, d_k)
K_g = K.unsqueeze(-3) # (..., H_kv, 1, N, d_k)
V_g = V.unsqueeze(-3) # (..., H_kv, 1, N, d_k)
# Now Q_g and K_g align on H_kv; ratio broadcasts against the size-1 axis.
scores = Q_g @ K_g.swapaxes(-1, -2) / (d_k ** 0.5) # (..., H_kv, ratio, N, N)
This is a zero-copy broadcast: unsqueeze produces a strided view, and PyTorch broadcasting expands the size-1 axis without materializing copies.
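Putting the broadcast together with the remaining steps, here is one possible end-to-end sketch under the shape conventions above (not a reference solution; the stable softmax from earlier is written inline):

import torch

def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    *batch, H_q, N, d_k = Q.shape
    H_kv = K.shape[-3]
    ratio = H_q // H_kv
    # Split the query head axis into (H_kv, ratio); add a broadcast axis to K and V.
    Q_g = Q.reshape(*batch, H_kv, ratio, N, d_k)
    K_g = K.unsqueeze(-3)  # (..., H_kv, 1, N, d_k)
    V_g = V.unsqueeze(-3)  # (..., H_kv, 1, N, d_k)
    # Scaled dot-product scores: (..., H_kv, ratio, N, N).
    scores = Q_g @ K_g.swapaxes(-1, -2) / (d_k ** 0.5)
    # Hand-rolled, numerically stable softmax over the key axis.
    scores = scores - scores.amax(dim=-1, keepdim=True)
    weights = torch.exp(scores)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    out = weights @ V_g  # (..., H_kv, ratio, N, d_k)
    # Merge (H_kv, ratio) back into H_q, move heads next to d_k, and flatten.
    out = out.reshape(*batch, H_q, N, d_k)
    out = out.swapaxes(-3, -2)  # (..., N, H_q, d_k)
    return out.reshape(*batch, N, H_q * d_k)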
PyTorch idioms vs NumPy:
tensor.unsqueeze(dim) is the PyTorch spelling of np.expand_dims — it works with negative axes and stays a view.
reshape here splits an axis without moving any data — same as NumPy. After it, the head dim becomes (H_kv, ratio), so each KV head pairs with ratio query heads automatically.
import torch

def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    pass
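Once the stub is filled in, one illustrative local sanity check (assuming the ban on F.scaled_dot_product_attention applies only to the submitted function, so it can serve as an oracle here; the sizes are arbitrary) is to copy each KV head ratio times with repeat_interleave so ordinary per-head attention sees the same pairing, then compare:

import torch
import torch.nn.functional as F

B, H_q, H_kv, N, d_k = 2, 8, 2, 16, 32  # illustrative sizes
Q = torch.randn(B, H_q, N, d_k)
K = torch.randn(B, H_kv, N, d_k)
V = torch.randn(B, H_kv, N, d_k)

ratio = H_q // H_kv
K_rep = K.repeat_interleave(ratio, dim=-3)  # each KV head duplicated for its query group
V_rep = V.repeat_interleave(ratio, dim=-3)
expected = F.scaled_dot_product_attention(Q, K_rep, V_rep)  # (B, H_q, N, d_k)
expected = expected.swapaxes(-3, -2).reshape(B, N, H_q * d_k)

assert torch.allclose(grouped_query_attention(Q, K, V), expected, atol=1e-5)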