TorchedUp

215. Masked Attention (Causal + Padding)

Medium

Implement scaled dot-product attention with an arbitrary 2D boolean mask that controls which key positions each query position is allowed to attend to.

Signature: def attention_masked(Q, K, V, attn_mask) -> output

  • Q: shape (T_q, d_k)
  • K: shape (T_k, d_k)
  • V: shape (T_k, d_v)
  • attn_mask: shape (T_q, T_k). 1 (or True) means attend; 0 (or False) means block.
  • Output: shape (T_q, d_v)

Run standard scaled dot-product softmax attention, but apply the mask to the scores before the softmax: set blocked entries to -inf so that exp maps them to zero and the softmax denominator sums over allowed keys only. Multiplying the weights by the mask after the softmax is wrong: the softmax has already normalized over all keys, blocked ones included, so zeroing the blocked weights afterward leaves each row summing to less than 1 and the output mis-scaled. See the math reference below.
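To make the scaling problem concrete, here is a small NumPy sketch that compares the two orderings on one query and three keys. The `softmax` helper and the toy `scores` and `mask` values are made up for illustration and are not part of the problem statement.

```python
import numpy as np

def softmax(x):
    # Row-wise softmax; subtracting the row max keeps exp() from overflowing.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Toy scores for one query over three keys; the last key is blocked.
scores = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[True, True, False]])

# Mask BEFORE the softmax: blocked scores become -inf, and exp(-inf) is 0,
# so the denominator only sums over the two allowed keys.
pre = softmax(np.where(mask, scores, -np.inf))

# Mask AFTER the softmax: the blocked key has already taken its share of
# the probability mass, so the surviving weights no longer sum to 1.
post = softmax(scores) * mask

print(pre.sum(axis=-1))   # allowed weights form a full distribution
print(post.sum(axis=-1))  # strictly less than 1
```

Both variants put exactly zero weight on the blocked key; they differ only in how the remaining weights are normalized.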

Common mask patterns

Causal mask (decoder-only LMs): position i can attend to positions <= i. The mask is the lower-triangular indicator np.tri(T).

Padding mask (variable-length batched inputs): if some key positions are <PAD> filler, every query masks those positions out. The mask has columns of zeros at the padded positions.

Combined: in causal LMs with padded sequences, both masks are AND'd together — block if either causal or padding says block.

You don't need to handle the masks differently — the function takes a single arbitrary 2D mask and applies it.
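The three patterns above can be built in a few lines of NumPy. This is a minimal sketch: the sequence length `T = 5` and the `key_is_real` vector (last two positions padded) are hypothetical values chosen for illustration.

```python
import numpy as np

T = 5  # hypothetical sequence length, with T_q == T_k

# Causal mask: row i allows columns 0..i (lower triangle, diagonal included).
causal = np.tri(T, dtype=bool)

# Padding mask: assume the last two key positions are <PAD> filler.
key_is_real = np.array([True, True, True, False, False])
# Every query sees the same set of real keys, so repeat the row T times.
padding = np.broadcast_to(key_is_real, (T, T))

# Combined: a key is allowed only if BOTH masks allow it,
# which is the same as blocking when either mask says block.
combined = causal & padding

print(combined.astype(int))
```

Any of `causal`, `padding`, or `combined` can be passed as `attn_mask`, since the function only sees a single (T_q, T_k) array.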

Math

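$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V, \qquad M_{ij} = \begin{cases} 0 & \text{if } \mathrm{mask}_{ij} = 1 \\ -\infty & \text{if } \mathrm{mask}_{ij} = 0 \end{cases}$$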
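One way to turn the formula into NumPy is sketched below. It is an illustrative implementation, not the site's reference solution, and it assumes every query row has at least one allowed key (a fully blocked row would produce NaN with a -inf fill). The random test inputs are made up.

```python
import numpy as np

def attention_masked(Q, K, V, attn_mask):
    """Sketch of scaled dot-product attention with a (T_q, T_k) mask."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # (T_q, T_k)

    # Accept 0/1 or boolean masks; True means "attend".
    allowed = np.asarray(attn_mask).astype(bool)
    # Equivalent to adding M: blocked entries become -inf, allowed stay as-is.
    scores = np.where(allowed, scores, -np.inf)

    # Stable softmax over the key axis.
    # Assumption: each row has at least one allowed key, so the row max is finite.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    return weights @ V                               # (T_q, d_v)

# Usage with a causal mask on made-up inputs.
rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((4, 8))
V = rng.standard_normal((4, 3))
out = attention_masked(Q, K, V, np.tri(4))
print(out.shape)  # (4, 3)
```

With the causal mask, the first query can only attend to the first key, so `out[0]` equals `V[0]`, which makes a quick sanity check.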

Related problems

  • Masked Attention (PyTorch) (Medium, PyTorch)

NumPy

import numpy as np

def attention_masked(Q, K, V, attn_mask):
    # Starter stub: arguments follow the signature and shapes given above.
    pass
