
257. Multi-Query Attention (PyTorch)

Medium

Implement Multi-Query Attention in PyTorch using primitive tensor ops only. MQA has multiple Query heads but only a single Key/Value head shared across all of them — used in PaLM, Falcon, and Gemma to slash KV cache memory by a factor of H.

Signature: def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor

  • Q: (..., H, N, d_k) — per-head queries (already split across heads)
  • K: (..., N, d_k) — single shared key head
  • V: (..., N, d_k) — single shared value head
  • Returns: (..., N, H * d_k) — outputs concatenated across heads

The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
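Since every built-in softmax is banned, the usual move is a hand-rolled, numerically stable softmax over the last axis. A minimal sketch (the helper name is mine, not part of the problem; the max-subtraction is a standard stability trick, not something the statement requires):

import torch

def softmax_last_dim(x: torch.Tensor) -> torch.Tensor:
    # Illustrative helper, not prescribed by the problem.
    # Subtract the per-row max so exp() cannot overflow; the softmax value is unchanged.
    x = x - x.amax(dim=-1, keepdim=True)
    e = x.exp()
    return e / e.sum(dim=-1, keepdim=True)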

The PyTorch idiom for sharing K/V across heads:

# Q has shape (..., H, N, d_k). K, V have shape (..., N, d_k).
# Insert a head-axis of size 1 so K/V broadcast across all H query heads:
d_k = Q.shape[-1]       # per-head dimension
K_b = K.unsqueeze(-3)   # (..., 1, N, d_k)
V_b = V.unsqueeze(-3)   # (..., 1, N, d_k)
scores = Q @ K_b.swapaxes(-1, -2) / d_k ** 0.5   # broadcasts to (..., H, N, N)

This is the whole trick: MQA is MHA where the KV head dim is 1 and broadcasts. Same kernel, ~H× less KV cache memory at inference.
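Putting the pieces together, one possible hand-rolled implementation of the signature above. This is only a sketch, not the site's reference solution; the stable-softmax step and the final head concatenation are implementation choices consistent with the stated return shape:

import math
import torch

def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Q: (..., H, N, d_k); K, V: (..., N, d_k)
    d_k = Q.shape[-1]

    # Broadcast the single K/V head across all H query heads.
    K_b = K.unsqueeze(-3)                                  # (..., 1, N, d_k)
    V_b = V.unsqueeze(-3)                                  # (..., 1, N, d_k)

    scores = Q @ K_b.transpose(-1, -2) / math.sqrt(d_k)    # (..., H, N, N)

    # Hand-rolled softmax over the key axis, shifted by the max for stability.
    scores = scores - scores.amax(dim=-1, keepdim=True)
    weights = scores.exp()
    weights = weights / weights.sum(dim=-1, keepdim=True)

    out = weights @ V_b                                    # (..., H, N, d_k)

    # Concatenate the heads: (..., H, N, d_k) -> (..., N, H * d_k)
    return out.transpose(-3, -2).flatten(-2)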

PyTorch idioms vs NumPy:

  • tensor.unsqueeze(dim) inserts a size-1 dim, the counterpart of np.expand_dims(arr, dim) or arr[..., None, :, :] in NumPy. Both accept negative axes; unsqueeze simply chains more cleanly as a method call.
  • Broadcasting matches NumPy semantics: size-1 dims expand against larger dims for free, no explicit expand / broadcast_to call needed (see the shape check below).
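For instance (shapes picked arbitrarily for illustration), the size-1 head axis broadcasts against H with no explicit expand:

import torch

Q = torch.randn(2, 8, 16, 64)        # (batch, H, N, d_k)
K = torch.randn(2, 16, 64)           # (batch, N, d_k)

K_b = K.unsqueeze(-3)                # (2, 1, 16, 64)
scores = Q @ K_b.swapaxes(-1, -2)    # broadcasts to (2, 8, 16, 16)
print(scores.shape)                  # torch.Size([2, 8, 16, 16])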

Math

\mathrm{MQA}(Q, K, V)_h = \operatorname{softmax}\!\left(\frac{Q_h K^\top}{\sqrt{d_k}}\right) V \qquad \text{for } h = 1, \dots, H

Related problems

  • Multi-Query Attention (MQA) (Easy, NumPy)

Starter code

import torch

def multi_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    pass
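One way to sanity-check a finished implementation (assuming a recent PyTorch; the banned F.scaled_dot_product_attention is used here only as a test oracle, with the shared K/V head explicitly expanded to H copies):

import torch
import torch.nn.functional as F

Q = torch.randn(2, 8, 16, 64)                               # (batch, H, N, d_k)
K = torch.randn(2, 16, 64)
V = torch.randn(2, 16, 64)

out = multi_query_attention(Q, K, V)                        # (2, 16, 512)

# Reference: replicate the shared head H times and run the built-in kernel.
K_ref = K.unsqueeze(-3).expand(2, 8, 16, 64).contiguous()
V_ref = V.unsqueeze(-3).expand(2, 8, 16, 64).contiguous()
ref = F.scaled_dot_product_attention(Q, K_ref, V_ref)       # (2, 8, 16, 64)
ref = ref.transpose(-3, -2).flatten(-2)                     # (2, 16, 512)

print(torch.allclose(out, ref, atol=1e-5))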
