TorchedUp

254. Masked Attention (PyTorch)

Medium

Implement scaled dot-product attention with an arbitrary boolean mask, in PyTorch using primitive tensor ops only.

Signature: def attention_masked(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor

  • Q: (..., T_q, d_k)
  • K: (..., T_k, d_k)
  • V: (..., T_k, d_v)
  • attn_mask: boolean tensor (torch.bool) broadcastable to (..., T_q, T_k), where True = attend and False = block; it broadcasts over the leading dims.
  • Returns: (..., T_q, d_v)
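To see how a mask broadcasts over the leading dims, here is a sketch that combines a causal mask of shape (T_q, T_k) with a per-sequence padding mask. The sizes and sequence lengths are hypothetical, chosen only for illustration.

```python
import torch

B, H, T = 2, 4, 5  # hypothetical batch, head, and sequence sizes

# Causal mask: query i may attend to keys 0..i. Shape (T_q, T_k).
causal = torch.ones(T, T, dtype=torch.bool).tril()

# Padding mask: True for real tokens, False for padding. Shape (B, T_k).
lengths = torch.tensor([5, 3])  # hypothetical per-sequence lengths
padding = torch.arange(T)[None, :] < lengths[:, None]

# Insert singleton head and query dims so padding acts as (B, 1, 1, T_k).
combined = causal & padding[:, None, None, :]
print(combined.shape)  # torch.Size([2, 1, 5, 5])

# The combined mask broadcasts against scores of shape (B, H, T_q, T_k).
scores = torch.randn(B, H, T, T)
print(torch.broadcast_shapes(combined.shape, scores.shape))  # torch.Size([2, 4, 5, 5])
```

Combining with `&` keeps a position only if both masks allow it, which matches the True = attend convention.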

The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
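Because the built-in softmax functions are off limits, the softmax has to be written out. Here is a minimal sketch over the last dim. It assumes the common max-subtraction trick for numerical stability, and `softmax_last_dim` is a hypothetical helper name.

```python
import torch

def softmax_last_dim(x: torch.Tensor) -> torch.Tensor:
    """Softmax over the last dimension using only primitive tensor ops."""
    # Subtracting the row max leaves the result unchanged (it cancels in the
    # ratio) but keeps exp() from overflowing for large scores.
    shifted = x - x.amax(dim=-1, keepdim=True)
    e = shifted.exp()
    return e / e.sum(dim=-1, keepdim=True)

x = torch.tensor([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
p = softmax_last_dim(x)
```

Without the shift, `exp(1000.0)` overflows to `inf` in float32 and the second row would become NaN.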

The masking idiom in PyTorch is scores.masked_fill(~mask, float('-inf')) — set blocked positions to -inf before the softmax so exp(-inf) = 0 and they don't contribute to the normalizer.
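A tiny check of the idiom on a single row of scores shows that blocked entries receive exactly zero weight. The max subtraction is skipped here only because the values are small.

```python
import torch

scores = torch.tensor([2.0, 1.0, 0.5, 3.0])
mask = torch.tensor([True, True, False, False])  # True = attend, False = block

# Blocked positions become -inf, so exp() sends them to exactly 0
# and they drop out of the normalizer.
masked = scores.masked_fill(~mask, float('-inf'))
weights = masked.exp() / masked.exp().sum()
print(weights)
```

The surviving weights equal a softmax over only the first two scores.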

PyTorch idioms vs NumPy:

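  • tensor.masked_fill(mask, value) is the idiomatic way to inject -inf. NumPy instead builds an additive mask with np.where(mask, 0.0, -np.inf) and adds it to the scores.
  • Pass a torch.bool tensor, so that ~mask flips True/False (on integer tensors ~ is a bitwise NOT, not a logical one). NumPy uses np.where directly with the bool array.
  • Prefer .swapaxes(-1, -2) over .T for a batch-shape-agnostic transpose: it swaps only the last two dims, whereas .T reverses all of them.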
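The first two idioms produce the same masked scores. This sketch compares them on random data and checks the transpose shape. It assumes both numpy and torch are installed.

```python
import numpy as np
import torch

torch.manual_seed(0)
scores = torch.randn(2, 3, 4, 4)  # (batch, heads, T_q, T_k)
mask = torch.rand(4, 4) > 0.3     # bool, broadcasts over (batch, heads)

# PyTorch: overwrite blocked positions with -inf.
torch_masked = scores.masked_fill(~mask, float('-inf'))

# NumPy: build an additive 0 / -inf mask with np.where and add it.
np_masked = scores.numpy() + np.where(mask.numpy(), 0.0, -np.inf)

print(np.array_equal(torch_masked.numpy(), np_masked))

# .swapaxes(-1, -2) transposes only the last two dims; .T would reverse all four.
K = torch.randn(2, 3, 5, 8)
print(K.swapaxes(-1, -2).shape)  # torch.Size([2, 3, 8, 5])
```

`.transpose(-2, -1)` is an equivalent spelling of the same last-two-dims swap.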

Math

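$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V, \qquad M_{ij} = \begin{cases} 0 & \text{if } \mathrm{mask}_{ij} = 1 \\ -\infty & \text{if } \mathrm{mask}_{ij} = 0 \end{cases}$$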
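Putting the formula together, here is a minimal sketch of attention_masked that follows the rules above: no built-in softmax or attention calls. It is not the site's reference solution. Using `.amax` for the max-subtraction step is this sketch's own choice. The problem statement leaves one edge case open. A query row whose mask is entirely False becomes all -inf, and -inf - (-inf) is NaN, so this sketch returns NaN for such rows.

```python
import math
import torch

def attention_masked(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                     attn_mask: torch.Tensor) -> torch.Tensor:
    d_k = Q.shape[-1]
    # Scaled scores QK^T / sqrt(d_k): shape (..., T_q, T_k).
    scores = Q @ K.swapaxes(-1, -2) / math.sqrt(d_k)
    # Apply M: blocked (False) positions become -inf.
    scores = scores.masked_fill(~attn_mask, float('-inf'))
    # Hand-rolled, numerically stable softmax over the key dimension.
    scores = scores - scores.amax(dim=-1, keepdim=True)
    weights = scores.exp()
    weights = weights / weights.sum(dim=-1, keepdim=True)
    # Weighted sum of values: shape (..., T_q, d_v).
    return weights @ V

torch.manual_seed(0)
B, H, T, d = 2, 4, 6, 8  # hypothetical sizes
Q, K, V = (torch.randn(B, H, T, d) for _ in range(3))
causal = torch.ones(T, T, dtype=torch.bool).tril()
out = attention_masked(Q, K, V, causal)
print(out.shape)  # torch.Size([2, 4, 6, 8])
```

The rules forbid F.scaled_dot_product_attention inside the solution, but it is still handy as an oracle when testing locally. Its boolean attn_mask uses the same convention: True means the position takes part in attention.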

Related problems

  • Masked Attention (Causal + Padding): Medium, NumPy

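Starter code

import torch

def attention_masked(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    # Hand-roll every step: scaled scores, -inf masking, softmax, then weights @ V.
    pass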
