Implement multi-head attention in PyTorch using primitive tensor ops only (no learned projections).
Signature: def multi_head_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, num_heads: int) -> torch.Tensor
- Q, K, V: (..., N, d_model) — the last dim is the full model dim, not the per-head dim.
- num_heads: divides d_model.
- Returns: (..., N, d_model).
- The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step (a softmax sketch follows below).
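Since every built-in softmax is banned, the attention weights must be normalized by hand. A minimal sketch, where the helper name manual_softmax is illustrative rather than part of the required API:

```python
import torch

def manual_softmax(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Subtract the max along `dim` for numerical stability, then normalize.
    shifted = scores - scores.amax(dim=dim, keepdim=True)
    exp = shifted.exp()
    return exp / exp.sum(dim=dim, keepdim=True)
```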
The PyTorch idiom for splitting heads:
```python
# (..., N, d_model) -> (..., N, H, d_k) -> (..., H, N, d_k)
shape = list(x.shape); shape[-1:] = [num_heads, d_k]
x = x.reshape(shape).transpose(-3, -2)
```
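For example, with assumed toy sizes d_model = 8 and num_heads = 2, the idiom produces the expected head layout:

```python
import torch

x = torch.randn(3, 5, 8)                       # (batch=3, N=5, d_model=8)
num_heads, d_k = 2, 8 // 2
shape = list(x.shape); shape[-1:] = [num_heads, d_k]
x = x.reshape(shape).transpose(-3, -2)
print(x.shape)                                 # torch.Size([3, 2, 5, 4]) = (..., H, N, d_k)
```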
After computing weights @ V_heads of shape (..., H, N, d_k), reverse the layout: .transpose(-3, -2).contiguous().reshape(...) back to (..., N, d_model). Strictly speaking, .reshape copies on its own when a view is impossible; .contiguous() is only mandatory before .view, but calling it explicitly after a transpose is the conventional pattern and makes the copy visible. A merge-helper sketch follows below.
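As a sketch, the merge step can be packaged as a helper (merge_heads is an illustrative name, not a required one):

```python
import torch

def merge_heads(x: torch.Tensor) -> torch.Tensor:
    # (..., H, N, d_k) -> (..., N, H, d_k) -> (..., N, H * d_k)
    x = x.transpose(-3, -2).contiguous()
    shape = list(x.shape)
    shape[-2:] = [shape[-2] * shape[-1]]
    return x.reshape(shape)
```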
PyTorch idioms vs NumPy:

- .view never copies: it throws on a non-contiguous tensor when the target shape can't be expressed over the existing strides. .reshape quietly copies when needed, exactly like NumPy's reshape, which is why .reshape (or .contiguous().view(...)) is the safe pattern after a transpose; a short demo follows below.
- .transpose(dim0, dim1) swaps two specific dims, exactly like NumPy's swapaxes. PyTorch also has .permute(*dims) for full reorderings.
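A quick demonstration of the difference:

```python
import torch

t = torch.randn(2, 3, 4).transpose(-2, -1)    # non-contiguous view, shape (2, 4, 3)
print(t.is_contiguous())                      # False
flat = t.reshape(2, 12)                       # works: reshape copies when it must
try:
    t.view(2, 12)                             # view never copies, so this raises
except RuntimeError as err:
    print("view failed:", err)
```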
```python
import torch

def multi_head_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, num_heads: int) -> torch.Tensor:
    pass
```
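Putting the pieces together, a minimal sketch of one way to fill in the stub, assuming the shapes and rules stated above:

```python
import math
import torch

def multi_head_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, num_heads: int) -> torch.Tensor:
    d_model = Q.shape[-1]
    d_k = d_model // num_heads

    def split(x: torch.Tensor) -> torch.Tensor:
        # (..., N, d_model) -> (..., N, H, d_k) -> (..., H, N, d_k)
        shape = list(x.shape)
        shape[-1:] = [num_heads, d_k]
        return x.reshape(shape).transpose(-3, -2)

    Qh, Kh, Vh = split(Q), split(K), split(V)

    # Scaled dot-product scores: (..., H, N, N).
    scores = Qh @ Kh.transpose(-2, -1) / math.sqrt(d_k)

    # Hand-rolled, numerically stable softmax over the key dimension.
    shifted = scores - scores.amax(dim=-1, keepdim=True)
    exp = shifted.exp()
    weights = exp / exp.sum(dim=-1, keepdim=True)

    out = weights @ Vh                          # (..., H, N, d_k)
    out = out.transpose(-3, -2).contiguous()    # (..., N, H, d_k)
    shape = list(out.shape)
    shape[-2:] = [d_model]
    return out.reshape(shape)                   # (..., N, d_model)
```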