TorchedUp

256. Multi-Head Attention (PyTorch)

Hard

Implement multi-head attention in PyTorch using primitive tensor ops only (no learned projections).

Signature: def multi_head_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, num_heads: int) -> torch.Tensor

  • Q, K, V: (..., N, d_model) (last dim is the full model dim, not the per-head dim)
  • num_heads: divides d_model
  • Returns: (..., N, d_model)

The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
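Since the softmax ban means writing your own, here is a minimal numerically stable version (the function name `softmax_lastdim` is my own, not part of the problem):

```python
import torch

def softmax_lastdim(x: torch.Tensor) -> torch.Tensor:
    # Subtract the per-row max first so exp() cannot overflow for large scores.
    x = x - x.max(dim=-1, keepdim=True).values
    e = x.exp()
    return e / e.sum(dim=-1, keepdim=True)
```

The max-subtraction leaves the result unchanged mathematically (softmax is shift-invariant) but keeps every exponent ≤ 0.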

The PyTorch idiom for splitting heads:

# (..., N, d_model) -> (..., N, H, d_k) -> (..., H, N, d_k)
shape = list(x.shape); shape[-1:] = [num_heads, d_k]
x = x.reshape(shape).transpose(-3, -2)

After computing weights @ V_heads of shape (..., H, N, d_k), reverse the layout: .transpose(-3, -2).contiguous().reshape(..., N, d_model). Strictly, .reshape copies on its own when the transpose leaves the tensor non-contiguous; .contiguous() makes that copy explicit, and it is required if you use .view instead of .reshape.
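A quick sketch of the split/merge roundtrip, using small shapes chosen for illustration (batch 2, N=4 tokens, d_model=8, H=2 heads):

```python
import torch

x = torch.arange(2 * 4 * 8, dtype=torch.float32).reshape(2, 4, 8)
num_heads, d_k = 2, 8 // 2

# Split: (..., N, d_model) -> (..., N, H, d_k) -> (..., H, N, d_k)
shape = list(x.shape); shape[-1:] = [num_heads, d_k]
heads = x.reshape(shape).transpose(-3, -2)
assert heads.shape == (2, 2, 4, 4)

# Merge: (..., H, N, d_k) -> (..., N, H, d_k) -> (..., N, d_model)
back = heads.transpose(-3, -2).contiguous()
merged = back.reshape(*back.shape[:-2], num_heads * d_k)
assert torch.equal(merged, x)  # split then merge is a lossless roundtrip
```

The final assert is the property worth checking in your own solution: merging must exactly invert splitting, or heads end up interleaved in the output.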

PyTorch idioms vs NumPy:

  • PyTorch's .view() demands strides compatible with the new shape and throws on a freshly transposed tensor; .reshape() quietly copies when a view is impossible, just like NumPy's reshape. .contiguous() before .view()/.reshape() is the explicit, safe pattern after a transpose.
  • transpose(dim0, dim1) swaps two specific dims, exactly like swapaxes. PyTorch also has .permute(*dims) for full reorderings.
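The view/reshape distinction above can be checked directly (note it is .view(), not .reshape(), that raises on incompatible strides):

```python
import torch

t = torch.zeros(3, 4).transpose(0, 1)   # non-contiguous view, shape (4, 3)
assert not t.is_contiguous()

# .reshape() copies when a view is impossible, like NumPy:
flat = t.reshape(12)                     # works; returns a copy here

# .view() demands compatible strides and raises instead:
try:
    t.view(12)
except RuntimeError:
    pass                                 # expected on a transposed tensor

# Explicit pattern: make the copy visible, then view/reshape freely.
flat2 = t.contiguous().view(12)
assert torch.equal(flat, flat2)
```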

Math

MHA(Q, K, V) = Concat(head_1, …, head_h), head_i = Attention(Q_i, K_i, V_i)

Attention(Q_i, K_i, V_i) = softmax(Q_i K_iᵀ / √d_k) V_i

where Q_i, K_i, V_i are the i-th d_k-wide slices of the last dimension (no learned projections in this variant) and d_k = d_model / num_heads.
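Putting the pieces together, a hedged end-to-end sketch under the stated rules (the inner helper `split` and the hand-rolled softmax are my own naming, not required by the problem):

```python
import math
import torch

def multi_head_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                         num_heads: int) -> torch.Tensor:
    """Split heads, run scaled dot-product attention per head with a
    hand-rolled softmax, then merge heads back into d_model."""
    d_model = Q.shape[-1]
    d_k = d_model // num_heads

    def split(x: torch.Tensor) -> torch.Tensor:
        # (..., N, d_model) -> (..., N, H, d_k) -> (..., H, N, d_k)
        shape = list(x.shape); shape[-1:] = [num_heads, d_k]
        return x.reshape(shape).transpose(-3, -2)

    Qh, Kh, Vh = split(Q), split(K), split(V)

    # Scaled dot-product scores: (..., H, N, N)
    scores = Qh @ Kh.transpose(-2, -1) / math.sqrt(d_k)

    # Hand-rolled, numerically stable softmax over the last dim.
    scores = scores - scores.max(dim=-1, keepdim=True).values
    exp = scores.exp()
    weights = exp / exp.sum(dim=-1, keepdim=True)

    out = weights @ Vh                        # (..., H, N, d_k)
    out = out.transpose(-3, -2).contiguous()  # (..., N, H, d_k)
    return out.reshape(*out.shape[:-2], d_model)
```

With num_heads=1 this should agree with plain single-head attention, which is a cheap sanity check before running hidden tests.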

Related problems

  • Multi-Head Attention (Hard, NumPy)

PyTorch

import torch

def multi_head_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, num_heads: int) -> torch.Tensor:

    pass


Premium problem

Free accounts include problems #1–20. Upgrade to unlock the editor, hidden test cases, and reference solutions for every problem.
