TorchedUp

257. Multi-Query Attention (PyTorch)

Medium

Implement Multi-Query Attention in PyTorch using primitive tensor ops only. MQA has multiple Query heads but only a single Key/Value head shared across all of them — used in PaLM, Falcon, and Gemma to slash KV cache memory by a factor of H.
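To make the "factor of H" concrete, here is a rough back-of-the-envelope sketch. The helper name `kv_cache_bytes` and the model configuration are hypothetical and chosen only for illustration; they do not describe any of the models named above.

```python
def kv_cache_bytes(num_kv_heads, seq_len, d_k, num_layers, bytes_per_elem=2):
    # Two cached tensors (K and V) per layer, each holding
    # num_kv_heads * seq_len * d_k elements for a single sequence.
    return 2 * num_layers * num_kv_heads * seq_len * d_k * bytes_per_elem

# Hypothetical configuration: 32 query heads, 4096 tokens, d_k = 128,
# 32 layers, 2 bytes per element (fp16/bf16).
H, N, d_k, num_layers = 32, 4096, 128, 32

mha_bytes = kv_cache_bytes(H, N, d_k, num_layers)  # one K/V head per query head
mqa_bytes = kv_cache_bytes(1, N, d_k, num_layers)  # one shared K/V head

print(mha_bytes // mqa_bytes)  # 32, i.e. exactly H
```

Under these assumptions the per-sequence cache drops from 2 GiB to 64 MiB; the ratio depends only on H.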

Signature: def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor

  • Q: (..., H, N, d_k) — per-head queries (already split across heads)
  • K: (..., N, d_k) — single shared key head
  • V: (..., N, d_k) — single shared value head
  • Returns: (..., N, H * d_k) — outputs concatenated across heads
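The return shape implies a final layout step: the per-head result has shape (..., H, N, d_k), and the heads must end up side by side along the last axis. One possible way to do that is sketched below; `merge_heads` is a hypothetical helper name, not part of the problem statement.

```python
import torch

def merge_heads(out: torch.Tensor) -> torch.Tensor:
    # out: (..., H, N, d_k) -> (..., N, H * d_k)
    *lead, H, N, d_k = out.shape
    out = out.transpose(-3, -2)             # (..., N, H, d_k)
    # reshape copies if needed, since the transposed tensor is not contiguous
    return out.reshape(*lead, N, H * d_k)

x = torch.randn(2, 4, 5, 8)                 # batch=2, H=4, N=5, d_k=8
print(merge_heads(x).shape)                 # torch.Size([2, 5, 32])
```

With this layout, columns h * d_k through (h + 1) * d_k - 1 of each output row hold head h's result.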

The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
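Since the built-in softmax functions are off limits, the softmax has to be written from exp and sum. A minimal sketch over the last axis follows; subtracting the row maximum first is a common stability trick and an assumption here, not something the problem requires. The name `softmax_last_dim` is hypothetical.

```python
import torch

def softmax_last_dim(x: torch.Tensor) -> torch.Tensor:
    # Subtract the per-row max so exp never sees large positive inputs.
    x_max = x.max(dim=-1, keepdim=True).values
    exp = torch.exp(x - x_max)
    return exp / exp.sum(dim=-1, keepdim=True)

scores = torch.tensor([[1000.0, 1001.0, 1002.0]])
probs = softmax_last_dim(scores)   # finite values, each row sums to 1
```

Without the max subtraction, exp(1000.0) overflows to inf in float32 and the division produces nan.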

The PyTorch idiom for sharing K/V across heads: Q already has a head axis of size H at position -3. K and V lack a head axis altogether, so insert a size-1 head axis at the same position (unsqueeze) and let broadcasting handle the expansion across all H query heads when you compute scores. This is the whole trick: MQA is MHA where the KV head dim is 1 and broadcasts. Same kernel, ~H× less KV cache memory at inference.
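The broadcasting step can be checked on its own with a small shape experiment. The sizes below are arbitrary and only illustrative.

```python
import torch

B, H, N, d_k = 2, 4, 5, 8
Q = torch.randn(B, H, N, d_k)
K = torch.randn(B, N, d_k)

K_b = K.unsqueeze(-3)                    # (B, 1, N, d_k): size-1 head axis
scores = Q @ K_b.transpose(-2, -1)       # (B, H, N, N): head axis broadcasts

print(K_b.shape)     # torch.Size([2, 1, 5, 8])
print(scores.shape)  # torch.Size([2, 4, 5, 5])
```

Each slice scores[:, h] equals Q[:, h] @ K.transpose(-2, -1), so every query head attends over the same keys without K ever being copied H times.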

PyTorch idioms vs NumPy:

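  • tensor.unsqueeze(dim) inserts a size-1 dim, e.g. K.unsqueeze(-3) turns (..., N, d_k) into (..., 1, N, d_k). The NumPy equivalents are np.expand_dims(arr, axis) or arr[..., None, :, :]. Both unsqueeze and np.expand_dims accept negative axes; unsqueeze simply reads cleaner as a method call.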
  • Broadcasting matches NumPy semantics — size-1 dims expand against larger dims for free, no expand / broadcast_to call needed.

Math

$$\mathrm{MQA}(Q, K, V)_h = \mathrm{softmax}\!\left(\frac{Q_h K^\top}{\sqrt{d_k}}\right) V \quad \text{for } h = 1, \dots, H$$
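Read term by term, the formula maps onto a handful of tensor ops. The following is one possible sketch under the shapes given in the signature, not a reference solution; the max subtraction before exp is an added stability assumption that does not change the result mathematically.

```python
import math
import torch

def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    d_k = Q.shape[-1]
    K = K.unsqueeze(-3)                                  # (..., 1, N, d_k)
    V = V.unsqueeze(-3)                                  # (..., 1, N, d_k)

    # Q_h K^T / sqrt(d_k) for all heads at once via broadcasting.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)    # (..., H, N, N)

    # Hand-rolled softmax over the key axis.
    scores = scores - scores.max(dim=-1, keepdim=True).values
    weights = torch.exp(scores)
    weights = weights / weights.sum(dim=-1, keepdim=True)

    out = weights @ V                                    # (..., H, N, d_k)

    # Concatenate heads along the last axis: (..., N, H * d_k).
    *lead, H, N, _ = out.shape
    return out.transpose(-3, -2).reshape(*lead, N, H * d_k)
```

A quick sanity check is to compare each d_k-wide column block of the output against single-head attention computed separately for that head.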

Related problems

  • Multi-Query Attention (MQA) (Easy, NumPy)


PyTorch

import torch

def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    pass
