TorchedUp

8. Scaled Dot-Product Attention

Medium

Implement scaled dot-product attention.

Signature: def scaled_dot_product_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray

Each query attends to every key: the query-key dot products are divided by sqrt(d_k), where d_k is the key dimension; the scaled scores are softmax-normalized across keys; and the resulting weights are used to take a weighted sum over the values. See the math reference below.
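The steps described above can be sketched in NumPy as follows. This is one possible approach rather than the reference solution: subtracting the row maximum before exponentiating, and using `np.swapaxes` so that inputs with leading batch dimensions also work, are assumptions that go beyond the problem statement.

```python
import numpy as np


def scaled_dot_product_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    # d_k is the key dimension (last axis of K).
    d_k = K.shape[-1]

    # Scaled similarity scores: (..., n_q, d_k) @ (..., d_k, n_k) -> (..., n_q, n_k).
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)

    # Softmax across keys (last axis). Subtracting the row maximum leaves the
    # result unchanged but keeps np.exp from overflowing on large scores.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # Weighted sum over the values: (..., n_q, n_k) @ (..., n_k, d_v).
    return weights @ V


if __name__ == "__main__":
    I = np.eye(2)
    print(scaled_dot_product_attention(I, I, I))
```

With Q = K = V = I, each output row is just the row of attention weights, so the printed rows sum to 1 and put the most weight on the diagonal.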

Math

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$
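It can help to unpack the formula into its three steps. In the sketch below, the intermediate names S (scaled scores) and A (attention weights) and the indices i (query) and j (key) are introduced here for illustration and do not appear in the problem statement. The small worked case assumes Q = K = V = I with d_k = 2.

```latex
% Step-by-step form of the attention formula (S, A, i, j are illustrative names).
S = \frac{Q K^\top}{\sqrt{d_k}}, \qquad
A_{ij} = \frac{\exp(S_{ij})}{\sum_{j'} \exp(S_{ij'})}, \qquad
\mathrm{Attention}(Q, K, V) = A V

% Worked case: Q = K = V = I_2, so d_k = 2 and S = I_2 / \sqrt{2}.
A_{11} = \frac{e^{1/\sqrt{2}}}{e^{1/\sqrt{2}} + e^{0}} \approx \frac{2.028}{3.028} \approx 0.670,
\qquad A_{12} = 1 - A_{11} \approx 0.330
```

Since V is the identity in this case, the output equals A, so the first output row is approximately (0.670, 0.330).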

Related problems

  • Scaled Dot-Product Attention (PyTorch) · Medium · PyTorch

Test Results

  • identity QKV
  • uniform keys
  • peaked scores
  • output non-negative when V is non-negative
  • single-query attention: output row sums to 1 when V=I
  • shift-invariant on Q when keys are uniform across rows
  • batched 3D input (B=2, N=3, d_k=4)
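The test names above suggest properties that can be checked locally. The sketch below is a guess at what each name means, not the site's actual test code: `attention_ref` is a hypothetical stand-in for your own implementation, `run_checks` is an illustrative helper, and the inputs are made up.

```python
import numpy as np


def attention_ref(Q, K, V):
    # Hypothetical stand-in; swap in your own scaled_dot_product_attention.
    s = Q @ np.swapaxes(K, -1, -2) / np.sqrt(K.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V


def run_checks(attn=attention_ref):
    rng = np.random.default_rng(0)
    results = {}

    # identity QKV: each query matches its own key best, so the largest
    # entry of every output row sits on the diagonal.
    I = np.eye(3)
    results["identity"] = bool(np.all(np.argmax(attn(I, I, I), axis=-1) == np.arange(3)))

    # uniform keys: identical key rows give equal weights, so every output
    # row is the mean of the rows of V.
    Q = rng.normal(size=(3, 4))
    K_same = np.tile(rng.normal(size=(1, 4)), (5, 1))
    V = rng.normal(size=(5, 2))
    out = attn(Q, K_same, V)
    results["uniform_keys"] = bool(np.allclose(out, V.mean(axis=0)))

    # shift-invariant on Q: with identical key rows, shifting Q changes all
    # scores in a row by the same amount, which softmax ignores.
    results["shift_invariant"] = bool(np.allclose(attn(Q + 10.0, K_same, V), out))

    # peaked scores: one dominant score makes the output close to that value row.
    V2 = np.array([[1.0, 2.0], [3.0, 4.0]])
    peaked = attn(np.array([[50.0, 0.0]]), np.eye(2), V2)
    results["peaked"] = bool(np.allclose(peaked, V2[0]))

    # non-negative V: the weights are positive, so the output is non-negative.
    K = rng.normal(size=(5, 4))
    results["non_negative"] = bool(np.all(attn(Q, K, np.abs(V)) >= 0))

    # single query with V = I: the output row is the weight row, which sums to 1.
    results["row_sums_to_one"] = bool(np.allclose(attn(Q[:1], K, np.eye(5)).sum(axis=-1), 1.0))

    # batched 3D input with B=2, N=3, d_k=4: the output keeps the batch shape.
    B = rng.normal(size=(2, 3, 4))
    results["batched_shape"] = attn(B, B, B).shape == (2, 3, 4)

    return results


if __name__ == "__main__":
    for name, ok in run_checks().items():
        print(f"{name}: {'pass' if ok else 'FAIL'}")
```

Passing your own function as `run_checks(my_function)` runs the same checks against it.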