TorchedUp

258. Grouped Query Attention (PyTorch)

Medium

Implement Grouped Query Attention in PyTorch using primitive tensor ops only. GQA, used in LLaMA 2/3 and Mistral, sits between multi-head attention (MHA) and multi-query attention (MQA): H_q query heads share H_kv key/value heads, and H_q must be divisible by H_kv.

When H_kv = H_q it's MHA. When H_kv = 1 it's MQA.

Signature: def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor

  • Q: (..., H_q, N, d_k)
  • K: (..., H_kv, N, d_k)
  • V: (..., H_kv, N, d_k)
  • Returns: (..., N, H_q * d_k) — head outputs concatenated along last dim, in original head order.

The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
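Since the built-in softmax calls are off limits, the normalization has to be written out by hand. A minimal sketch, assuming the usual max-subtraction trick for numerical stability (the helper name `stable_softmax` is hypothetical, not part of the problem's API):

```python
import torch


def stable_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Subtract the max along `dim` so exp() never overflows; softmax is
    # invariant to adding the same constant to every entry along `dim`.
    shifted = x - x.amax(dim=dim, keepdim=True)
    exp = torch.exp(shifted)
    # Normalize so each slice along `dim` sums to 1.
    return exp / exp.sum(dim=dim, keepdim=True)


scores = torch.tensor([[1000.0, 1001.0, 1002.0]])
print(stable_softmax(scores))
```

Without the shift, exp(1000.0) overflows to inf in float32 and the division produces nan.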

The PyTorch idiom for grouped broadcasting (no explicit copies of K or V): let ratio = H_q // H_kv. Split Q's head axis into a pair of axes (H_kv, ratio) via reshape, and add a size-1 axis on K and V at the matching position. Now K and V align with Q on the H_kv axis, and the new ratio axis on the query side broadcasts against the size-1 axis on the KV side, so you never repeat K or V yourself. Because reshape splits the axis in row-major order, query head h lands in group h // ratio, which means every block of ratio consecutive query heads pairs with the same shared K/V head. Run scaled dot-product softmax attention through this aligned/broadcast layout, then merge the (H_kv, ratio) axes back into a single H_q head axis before concatenating along the feature dim.
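One way to put this recipe together, offered as a minimal sketch rather than the site's reference solution. It assumes, as the signature does, that K and V share Q's sequence length N and that V's feature size equals d_k; the softmax is written inline to respect the no-softmax rule.

```python
import math

import torch


def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Q: (..., H_q, N, d_k); K, V: (..., H_kv, N, d_k)
    *batch, H_q, N, d_k = Q.shape
    H_kv = K.shape[-3]
    assert H_q % H_kv == 0, "H_q must be divisible by H_kv"
    r = H_q // H_kv

    # Split the query head axis into (H_kv, r): head h lands at (h // r, h % r).
    Qg = Q.reshape(*batch, H_kv, r, N, d_k)
    # Size-1 axis so each K/V head broadcasts across its r query heads.
    Kg = K.unsqueeze(-3)  # (..., H_kv, 1, N, d_k)
    Vg = V.unsqueeze(-3)  # (..., H_kv, 1, N, d_k)

    # Scaled dot-product scores: (..., H_kv, r, N, N)
    scores = Qg @ Kg.transpose(-2, -1) / math.sqrt(d_k)

    # Hand-rolled, numerically stable softmax over the key axis.
    scores = scores - scores.amax(dim=-1, keepdim=True)
    weights = torch.exp(scores)
    weights = weights / weights.sum(dim=-1, keepdim=True)

    out = weights @ Vg  # (..., H_kv, r, N, d_k)

    # Merge (H_kv, r) back into H_q; row-major order restores h = g * r + j.
    out = out.reshape(*batch, H_q, N, d_k)
    # Put heads next to the feature dim and concatenate: (..., N, H_q * d_k).
    return out.transpose(-3, -2).reshape(*batch, N, H_q * d_k)
```

With H_kv = H_q the ratio axis has size 1 and this is plain MHA; with H_kv = 1 every query head shares one K/V head, which is MQA.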

PyTorch idioms vs NumPy:

  • tensor.unsqueeze(dim) is the PyTorch spelling of np.expand_dims — works with negative axes, stays a view.
  • reshape here splits an axis without moving any data — same as NumPy. After it, the head dim becomes (H_kv, ratio) so each KV head pairs with ratio query heads automatically.
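The two idioms can be checked directly. This small sketch uses illustrative sizes (not taken from the problem) to confirm that the split and the unsqueeze are views of the original storage, that query head h sits at group h // r, slot h % r, and that broadcasting the size-1 axis yields a zero stride rather than copied data:

```python
import torch

# Illustrative sizes, chosen only for this demo.
B, H_q, H_kv, N, d_k = 2, 8, 2, 5, 4
r = H_q // H_kv

Q = torch.randn(B, H_q, N, d_k)
K = torch.randn(B, H_kv, N, d_k)

Qg = Q.reshape(B, H_kv, r, N, d_k)  # split the head axis into (H_kv, r)
Kg = K.unsqueeze(-3)                # like np.expand_dims(K, -3)

print(Qg.shape)
print(Kg.shape)

# Both are views: they start at the same memory as the originals.
print(Qg.data_ptr() == Q.data_ptr(), Kg.data_ptr() == K.data_ptr())

# Query head h lives at group h // r, slot h % r.
h = 5
print(torch.equal(Qg[:, h // r, h % r], Q[:, h]))

# Broadcasting the size-1 axis to r is a stride-0 view, not a copy.
print(Kg.expand(B, H_kv, r, N, d_k).stride()[2])
```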

Math

For each group $g \in [0, H_{kv})$ and each query head $q_h = g \cdot r + j$ with $j \in [0, r)$:

$$O_{q_h} = \mathrm{Attention}(Q_{q_h}, K_g, V_g), \qquad r = H_q / H_{kv}$$
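The formula above leaves Attention itself implicit. Assuming the standard scaled dot-product form that the broadcasting recipe runs, with the softmax expanded by hand and shifted by its row maximum for stability, each head computes the following (the inverse index map recovers the group from the head):

```latex
g = \left\lfloor q_h / r \right\rfloor, \qquad j = q_h \bmod r

\mathrm{Attention}(Q_{q_h}, K_g, V_g)
  = \operatorname{softmax}\!\left(\frac{Q_{q_h} K_g^{\top}}{\sqrt{d_k}}\right) V_g

\operatorname{softmax}(s)_{ik}
  = \frac{\exp\!\left(s_{ik} - \max_{k'} s_{ik'}\right)}
         {\sum_{k''} \exp\!\left(s_{ik''} - \max_{k'} s_{ik'}\right)}
```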

Related problems

  • Grouped Query Attention (GQA) (Medium, NumPy)

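import torch


def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Starter stub: hand-roll the attention here (no built-in softmax or SDPA).
    pass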
