Implement Multi-Query Attention in PyTorch using primitive tensor ops only. MQA has multiple Query heads but only a single Key/Value head shared across all of them — used in PaLM, Falcon, and Gemma to slash KV cache memory by a factor of H.
Signature: def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor
Q: (..., H, N, d_k) — per-head queries (already split across heads)
K: (..., N, d_k) — single shared key head
V: (..., N, d_k) — single shared value head
Returns: (..., N, H * d_k) — outputs concatenated across heads
The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
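For concreteness, a quick shape check of the contract (the sizes B, H, N, d_k below are made-up example values, and multi_query_attention is the function you will write):

import torch

B, H, N, d_k = 2, 4, 5, 8               # illustrative sizes, not part of the spec
Q = torch.randn(B, H, N, d_k)           # per-head queries
K = torch.randn(B, N, d_k)              # single shared key head
V = torch.randn(B, N, d_k)              # single shared value head

out = multi_query_attention(Q, K, V)    # your implementation
assert out.shape == (B, N, H * d_k)     # heads concatenated on the last axis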
The PyTorch idiom for sharing K/V across heads:
# Q has shape (..., H, N, d_k). K, V have shape (..., N, d_k).
# Insert a head axis of size 1 so K/V broadcast across all H query heads:
K_b = K.unsqueeze(-3)                                  # (..., 1, N, d_k)
V_b = V.unsqueeze(-3)                                  # (..., 1, N, d_k)
scores = Q @ K_b.transpose(-1, -2) / math.sqrt(d_k)    # broadcasts to (..., H, N, N)
This is the whole trick: MQA is MHA where the KV head dim is 1 and broadcasts. Same kernel, ~H× less KV cache memory at inference.
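Putting the pieces together, here is a minimal sketch of one way to hand-roll the whole computation under the shape contract above. The max-shifted softmax and the final head merge are assumptions about a reasonable reference approach, not the official solution:

import math
import torch

def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Q: (..., H, N, d_k); K, V: (..., N, d_k)
    d_k = Q.shape[-1]
    K_b = K.unsqueeze(-3)                                # (..., 1, N, d_k), broadcasts over H
    V_b = V.unsqueeze(-3)                                # (..., 1, N, d_k)
    scores = Q @ K_b.transpose(-1, -2) / math.sqrt(d_k)  # (..., H, N, N)
    # Hand-rolled softmax over the last axis, shifted by the row max for numerical stability
    scores = scores - scores.amax(dim=-1, keepdim=True)
    weights = scores.exp()
    weights = weights / weights.sum(dim=-1, keepdim=True)
    out = weights @ V_b                                  # (..., H, N, d_k)
    # Merge heads: (..., H, N, d_k) -> (..., N, H, d_k) -> (..., N, H * d_k)
    out = out.transpose(-3, -2)
    return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])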
PyTorch idioms vs NumPy:
tensor.unsqueeze(dim) inserts a size-1 dim. NumPy uses np.expand_dims(arr, dim) or arr[..., None, :, :]; unsqueeze reads cleaner and works with negative axes. Broadcasting handles the sharing from there, so no explicit expand / broadcast_to call is needed.
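For example (shapes here are illustrative, not part of the spec):

import numpy as np
import torch

K = torch.randn(2, 5, 8)                # (..., N, d_k) with an assumed batch of 2
K_b = K.unsqueeze(-3)                   # (..., 1, N, d_k) in PyTorch
K_np = np.expand_dims(K.numpy(), -3)    # same shape via NumPy
assert K_b.shape == (2, 1, 5, 8) and K_np.shape == (2, 1, 5, 8)

# No explicit expand / broadcast_to call: matmul broadcasting stretches the
# size-1 head axis across all H query heads on its own.
Q = torch.randn(2, 4, 5, 8)             # (..., H=4, N, d_k)
scores = Q @ K_b.transpose(-1, -2)      # (2, 4, 5, 5)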
import torch

def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    pass