Implement Multi-Query Attention (MQA) in PyTorch using primitive tensor ops only. MQA keeps multiple query heads but shares a single key/value head across all of them. It is used in PaLM, Falcon-7B, and Gemma 2B to cut KV cache memory by a factor of H (the number of query heads) relative to standard multi-head attention.
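The factor-of-H saving can be checked with a quick estimate of per-sequence KV cache size. This is a sketch. The helper `kv_cache_bytes` and the model configuration are hypothetical values chosen for illustration. It assumes fp16 storage (2 bytes per element) and the same head dimension for MHA and MQA.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, seq_len: int,
                   d_head: int, bytes_per_elem: int = 2) -> int:
    """Hypothetical helper: bytes needed to cache K and V for one sequence.

    The leading factor of 2 accounts for storing both K and V.
    """
    return 2 * n_layers * n_kv_heads * seq_len * d_head * bytes_per_elem


# Hypothetical config: 32 layers, H = 32 query heads, head dim 128, 4096 tokens, fp16.
H = 32
mha = kv_cache_bytes(n_layers=32, n_kv_heads=H, seq_len=4096, d_head=128)  # one KV head per query head
mqa = kv_cache_bytes(n_layers=32, n_kv_heads=1, seq_len=4096, d_head=128)  # one shared KV head
print(mha / 2**20, "MiB vs", mqa / 2**20, "MiB")  # 2048.0 MiB vs 64.0 MiB
```

Only `n_kv_heads` changes between the two calls, so the ratio is exactly H whatever the other values are.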
Signature: def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor
Q: (..., H, N, d_k), per-head queries (already split across heads)
K: (..., N, d_k), single shared key head
V: (..., N, d_k), single shared value head
Returns: (..., N, H * d_k), outputs concatenated across heads
The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
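The softmax calls are off limits, so the normalisation has to be written by hand. Below is a minimal sketch using the usual max-subtraction trick for numerical stability. The name `stable_softmax` is hypothetical and not part of the problem's API.

```python
import torch


def stable_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Subtract the max along `dim` so exp() cannot overflow. The shift
    # cancels in the ratio below, so the result is mathematically unchanged.
    x_max = x.amax(dim=dim, keepdim=True)
    e = torch.exp(x - x_max)
    return e / e.sum(dim=dim, keepdim=True)


# Large scores like these make a naive exp(x) / exp(x).sum() return nan in float32,
# because exp(1000) overflows to inf and inf / inf is nan.
scores = torch.tensor([[1000.0, 1001.0, 1002.0]])
print(stable_softmax(scores))
```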
The PyTorch idiom for sharing K/V across heads is as follows. Q already has a head axis of size H at position -3, but K and V have no head axis at all. Insert a size-1 head axis at the same position with unsqueeze(-3), and broadcasting will expand it across all H query heads when you compute the scores and again when you weight V. That is the whole trick: MQA is MHA whose single KV head has a size-1 head dim that broadcasts. The kernel is the same, but the KV cache at inference is about H times smaller.
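The following shape check illustrates the unsqueeze-then-broadcast idea. The sizes are arbitrary demo values, and random tensors stand in for real projected queries and keys.

```python
import torch

B, H, N, d_k = 2, 4, 5, 8  # hypothetical demo sizes
Q = torch.randn(B, H, N, d_k)
K = torch.randn(B, N, d_k)

# Insert a size-1 head axis at position -3: (B, N, d_k) -> (B, 1, N, d_k)
K_h = K.unsqueeze(-3)
print(K_h.shape)  # torch.Size([2, 1, 5, 8])

# NumPy-style None indexing produces the same tensor
print(torch.equal(K_h, K[..., None, :, :]))  # True

# matmul broadcasts the size-1 head axis against H: (B, H, N, N)
scores = Q @ K_h.transpose(-2, -1)
print(scores.shape)  # torch.Size([2, 4, 5, 5])

# Same result as explicitly repeating K once per head
K_rep = K_h.repeat(1, H, 1, 1)  # (B, H, N, d_k), an H-fold copy
print(torch.allclose(scores, Q @ K_rep.transpose(-2, -1)))  # True
```

The `repeat` line is there only for comparison. In MQA the stored K stays a single head, and the broadcast inside the matmul lets every query head use it.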
PyTorch idioms vs NumPy:
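tensor.unsqueeze(dim) inserts a size-1 dim. NumPy uses np.expand_dims(arr, dim) or arr[..., None, :, :]. unsqueeze reads more cleanly as a method call, and like expand_dims it accepts negative axes.
torch.matmul (the @ operator) broadcasts leading batch dims the same way NumPy's matmul does, so no explicit expand / broadcast_to call is needed.
Math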
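This sketch assumes the standard scaled dot-product form used in MHA, with 1/√d_k scaling. Each head h attends with its own queries Q_h against the one shared K and V, and the head outputs are concatenated along the feature axis:

```latex
\begin{aligned}
S_h &= \frac{Q_h K^{\top}}{\sqrt{d_k}} \in \mathbb{R}^{N \times N}, \qquad h = 1, \dots, H \\
A_h &= \operatorname{softmax}(S_h) \ \text{(row-wise)}, \qquad
  \operatorname{softmax}(s)_j = \frac{e^{\,s_j - \max_k s_k}}{\sum_{l} e^{\,s_l - \max_k s_k}} \\
O_h &= A_h V \in \mathbb{R}^{N \times d_k} \\
O &= \big[\, O_1 \;\|\; O_2 \;\|\; \cdots \;\|\; O_H \,\big] \in \mathbb{R}^{N \times H d_k}
\end{aligned}
```

Only Q carries the head index h, while K and V are identical in every line. The size-1 broadcast head axis implements exactly this. Leading batch dims (...) are omitted for readability.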
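import torch

def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Q: (..., H, N, d_k); K, V: (..., N, d_k); returns (..., N, H * d_k)
    pass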
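One possible way to fill in the stub is shown below. It is offered as a sketch, not as the official reference solution. It assumes the standard 1/√d_k scaling and that V's feature size equals d_k, as the signature states. It uses only matmul, exp, amax, sum, transpose, and reshape, and avoids every banned call.

```python
import math

import torch


def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Q: (..., H, N, d_k); K, V: (..., N, d_k) shared by every head.
    H, d_k = Q.shape[-3], Q.shape[-1]

    # Give K and V a size-1 head axis so they broadcast across the H query heads.
    K_h = K.unsqueeze(-3)  # (..., 1, N, d_k)
    V_h = V.unsqueeze(-3)  # (..., 1, N, d_k)

    # Scaled scores: (..., H, N, d_k) @ (..., 1, d_k, N) -> (..., H, N, N)
    scores = (Q @ K_h.transpose(-2, -1)) / math.sqrt(d_k)

    # Hand-rolled, numerically stable softmax over the key axis.
    scores = scores - scores.amax(dim=-1, keepdim=True)
    weights = torch.exp(scores)
    weights = weights / weights.sum(dim=-1, keepdim=True)

    # Weighted sum of the shared values: (..., H, N, N) @ (..., 1, N, d_k) -> (..., H, N, d_k)
    out = weights @ V_h

    # Put heads next to the feature axis, then concatenate them:
    # (..., H, N, d_k) -> (..., N, H, d_k) -> (..., N, H * d_k)
    out = out.transpose(-3, -2)
    return out.reshape(*out.shape[:-2], H * d_k)


torch.manual_seed(0)
Q = torch.randn(2, 4, 5, 8)  # batch 2, H = 4 heads, N = 5 tokens, d_k = 8
K = torch.randn(2, 5, 8)
V = torch.randn(2, 5, 8)
print(multi_query_attention(Q, K, V).shape)  # torch.Size([2, 5, 32])
```

`reshape` is used instead of `view` because the tensor is non-contiguous after the transpose, and `reshape` copies only when it has to.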