Implement scaled dot-product attention in PyTorch using primitive tensor ops only.
Signature: def scaled_dot_product_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor
Compute softmax(Q K^T / sqrt(d_k)) @ V.
The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
Allowed primitives: @ (matmul), .swapaxes(-1, -2) / .transpose(-1, -2), .exp(), .max(), .sum(), math.sqrt, basic arithmetic.
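For reference, the target expression can be written out as below. The symbols s (the scaled score matrix), m_i (the maximum of row i), and the indices i, j, k are notation introduced here for illustration, not taken from the problem statement. Subtracting m_i from every entry of row i leaves the softmax unchanged, because the common factor exp(-m_i) cancels between numerator and denominator.

```latex
\[
\mathrm{Attention}(Q, K, V)
  = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,
\qquad
\mathrm{softmax}(s)_{ij}
  = \frac{\exp(s_{ij} - m_i)}{\sum_{k} \exp(s_{ik} - m_i)},
\quad
m_i = \max_{k} s_{ik}.
\]
```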
Requirements:
- Handle 2D inputs (N, d_k) and any leading-batch shape (..., N, d_k): use .swapaxes(-1, -2) (or .transpose(-1, -2)), not .T, which reverses every axis and breaks on rank-3+ inputs.
- For numerical stability, subtract scores.max(dim=-1, keepdim=True).values before exponentiating.
- Use PyTorch's keepdim (not NumPy's keepdims).

PyTorch idioms vs the NumPy version:
- .max(dim=-1, keepdim=True) returns a NamedTuple (values, indices); access .values. NumPy's .max(axis=-1) returns the array directly.
- Prefer .swapaxes(-1, -2) over K.T. Both exist in PyTorch; only the first is rank-agnostic.
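The two idioms above can be checked on small tensors. This is a minimal sketch; the tensor names and values are made up for illustration. It deliberately avoids calling .T on a rank-3 tensor (recent PyTorch releases warn about that use) and instead uses an explicit permute to show what reversing every axis would produce.

```python
import torch

# A batched key tensor with shape (batch, N, d_k) = (2, 3, 4).
K = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# swapaxes / transpose exchange only the last two axes; the batch axis stays first.
K_t = K.swapaxes(-1, -2)
print(K_t.shape)                      # torch.Size([2, 4, 3])
assert torch.equal(K_t, K.transpose(-1, -2))

# Reversing every axis (the behaviour attributed to .T above) moves the batch axis last.
K_rev = K.permute(2, 1, 0)
print(K_rev.shape)                    # torch.Size([4, 3, 2])

# On a rank-2 tensor the two approaches agree.
K2 = K[0]
assert torch.equal(K2.T, K2.swapaxes(-1, -2))

# .max with dim= returns a (values, indices) pair rather than a plain tensor.
scores = torch.tensor([[1.0, 3.0, 2.0],
                       [0.5, -1.0, 0.0]])
result = scores.max(dim=-1, keepdim=True)
row_max = result.values               # shape (2, 1), broadcastable against scores
print(row_max.shape)                  # torch.Size([2, 1])
print(result.indices.tolist())        # [[1], [0]]
```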
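import math
import torch

def scaled_dot_product_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Starter stub: compute softmax(Q K^T / sqrt(d_k)) @ V using primitive ops only.
    pass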
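One possible way to fill in the stub is sketched below, using only the allowed primitives. It is not an official reference solution. It assumes floating-point inputs whose last dimension is d_k and whose K and V share the same sequence length. The intermediate names (scores, weights) are illustrative.

```python
import math
import torch

def scaled_dot_product_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    d_k = Q.shape[-1]
    # (..., N, d_k) @ (..., d_k, M) -> (..., N, M); swapaxes keeps leading batch axes in place.
    scores = (Q @ K.swapaxes(-1, -2)) / math.sqrt(d_k)
    # Shift each row by its maximum so the largest exponent is exp(0) = 1 (avoids overflow).
    scores = scores - scores.max(dim=-1, keepdim=True).values
    weights = scores.exp()
    # Normalise each row to sum to 1; keepdim=True keeps the shape broadcastable.
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights @ V

torch.manual_seed(0)
Q = torch.randn(2, 5, 8)   # (batch, N, d_k)
K = torch.randn(2, 7, 8)   # (batch, M, d_k)
V = torch.randn(2, 7, 4)   # (batch, M, d_v)

out = scaled_dot_product_attention(Q, K, V)
print(out.shape)           # torch.Size([2, 5, 4])

# The same function handles unbatched (N, d_k) inputs.
out_2d = scaled_dot_product_attention(Q[0], K[0], V[0])
print(out_2d.shape)        # torch.Size([5, 4])
```

A quick sanity check, done outside the function so the rule is not violated, is to compare the result against torch.softmax applied to the same scaled scores.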