TorchedUp

253. Scaled Dot-Product Attention (PyTorch)

Medium

Implement scaled dot-product attention in PyTorch using primitive tensor ops only.

Signature: def scaled_dot_product_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor

Compute softmax(Q K^T / sqrt(d_k)) @ V.

The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.

Allowed primitives: @ (matmul), .swapaxes(-1, -2) / .transpose(-1, -2), .exp(), .max(), .sum(), math.sqrt, basic arithmetic.
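A minimal sketch of one way to meet the rule using only the allowed primitives. It is not the site's reference solution, and it assumes K has shape (..., M, d_k) and V has shape (..., M, d_v), with leading dimensions matching Q's (the statement above only fixes Q's shape).

```python
import math

import torch


def scaled_dot_product_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Assumed shapes: Q (..., N, d_k), K (..., M, d_k), V (..., M, d_v).
    d_k = Q.shape[-1]
    # Swap only the last two axes of K so any leading batch dims survive.
    scores = (Q @ K.swapaxes(-1, -2)) / math.sqrt(d_k)  # (..., N, M)
    # Stable softmax over the key axis: shift each row so its max is 0.
    scores = scores - scores.max(dim=-1, keepdim=True).values
    weights = scores.exp()
    weights = weights / weights.sum(dim=-1, keepdim=True)  # each row sums to 1
    return weights @ V  # (..., N, d_v)


# Usage: batched input shaped (batch, heads, N, d_k).
torch.manual_seed(0)
Q = torch.randn(2, 3, 4, 8)
K = torch.randn(2, 3, 5, 8)
V = torch.randn(2, 3, 5, 6)
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([2, 3, 4, 6])
```

Every reduction uses dim=-1 and the transpose touches only the last two axes, so the same function handles plain 2D (N, d_k) inputs and batched inputs without changes.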

Requirements:

  • Works for 2D (N, d_k) and any leading-batch shape (..., N, d_k) — use .swapaxes(-1, -2) (or .transpose(-1, -2)), not .T (which reverses every axis and breaks on rank-3+ inputs).
  • Numerically stable softmax: subtract scores.max(dim=-1, keepdim=True).values before exponentiating.
  • Use keepdim (not NumPy's keepdims).
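The max-subtraction requirement matters because float32 exp overflows to inf for inputs above roughly 88.7. A row of large scores then becomes inf / inf = nan. The sketch below contrasts the two approaches on a deliberately large score row. naive_softmax and stable_softmax are hypothetical helper names used only for illustration.

```python
import torch


def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    # Exponentiate raw scores directly: exp(1000.0) overflows to inf in float32.
    e = x.exp()
    return e / e.sum(dim=-1, keepdim=True)


def stable_softmax(x: torch.Tensor) -> torch.Tensor:
    # Shift each row so its max is 0; every exp then lies in (0, 1].
    shifted = x - x.max(dim=-1, keepdim=True).values
    e = shifted.exp()
    return e / e.sum(dim=-1, keepdim=True)


scores = torch.tensor([[1000.0, 1001.0, 1002.0]])
print(naive_softmax(scores))   # nan in every position (inf / inf)
print(stable_softmax(scores))  # approximately [[0.0900, 0.2447, 0.6652]]
```

The shift does not change the result, because softmax is invariant to adding a constant c to every entry of a row: the factor e^(-c) cancels between numerator and denominator.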

PyTorch idioms vs the NumPy version:

  • .max(dim=-1, keepdim=True) returns a NamedTuple (values, indices) — access .values. NumPy's .max(axis=-1) returns the array directly.
  • Prefer .swapaxes(-1, -2) (or .transpose(-1, -2)) over K.T. All of them exist in PyTorch, but K.T reverses every axis, so only the explicit swap of the last two axes is rank-agnostic.
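Both idioms above can be checked directly. The sketch below shows that .max(dim=..., keepdim=True) returns a (values, indices) named tuple. It also contrasts .swapaxes(-1, -2) with a full axis reversal on a rank-3 tensor. The reversal is written as permute(2, 1, 0), which PyTorch documents as what .T does for a 3-D tensor. It avoids calling .T itself, because recent releases deprecate .T on tensors that are not 2-D.

```python
import torch

x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.0, 5.0, 4.0]])

# .max(dim=...) returns a named tuple; the maxima live in .values.
result = x.max(dim=-1, keepdim=True)
print(result.values.tolist())   # [[3.0], [5.0]]
print(result.indices.tolist())  # [[0], [1]]

# Rank-3 key tensor: (batch, N, d_k).
K = torch.zeros(2, 5, 8)

# swapaxes(-1, -2) swaps only the last two axes: (batch, d_k, N).
print(tuple(K.swapaxes(-1, -2).shape))  # (2, 8, 5)

# Reversing every axis (what .T does) also moves the batch axis: (d_k, N, batch).
print(tuple(K.permute(2, 1, 0).shape))  # (8, 5, 2)
```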

Math

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
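A small worked instance of the formula, with numbers chosen by hand for illustration: one query, two keys with d_k = 2, and scalar values. The query matches the first key, so that key receives the larger weight and the output lands closer to 10 than to 20.

```python
import math

import torch

# One query, two keys (d_k = 2), one scalar value per key.
Q = torch.tensor([[1.0, 0.0]])
K = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0]])
V = torch.tensor([[10.0],
                  [20.0]])

# Q K^T = [[1, 0]]; dividing by sqrt(2) gives about [[0.7071, 0.0]].
scores = (Q @ K.swapaxes(-1, -2)) / math.sqrt(Q.shape[-1])
# Row-wise stable softmax gives weights of about [[0.6698, 0.3302]].
shifted = scores - scores.max(dim=-1, keepdim=True).values
weights = shifted.exp() / shifted.exp().sum(dim=-1, keepdim=True)
# Weighted sum of values: 10 * 0.6698 + 20 * 0.3302, about 13.30.
out = weights @ V
print(out)
```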

Related problems

  • Scaled Dot-Product Attention (Medium, NumPy)

Starter code (PyTorch):

import math
import torch

def scaled_dot_product_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    pass
