Implement scaled dot-product attention with an arbitrary 2D boolean mask that controls which key positions each query position is allowed to attend to.
Signature: def attention_masked(Q, K, V, attn_mask) -> output
Q: shape (T_q, d_k)
K: shape (T_k, d_k)
V: shape (T_k, d_v)
attn_mask: shape (T_q, T_k). 1 (or True) means attend; 0 (or False) means block.
output: shape (T_q, d_v)
Run standard scaled dot-product softmax attention, but apply the mask to the scores before the softmax: push blocked entries to -inf so that exp zeros them and the softmax denominator only counts legal keys. Multiplying by the mask after the softmax is wrong: the softmax was normalized over all keys, so once the blocked weights are zeroed the remaining weights sum to less than 1 and the output is mis-scaled. See the math reference below.
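To make the difference between masking before and after the softmax concrete, here is a minimal sketch for a single query row. The score values, the mask, and the helper name softmax are hypothetical and chosen only for illustration.

```python
import numpy as np

# Hypothetical scores for one query over three keys; the third key is blocked.
scores = np.array([1.0, 2.0, 3.0])
mask = np.array([True, True, False])

def softmax(x):
    # Subtract the max for numerical stability; exp(-inf) evaluates to 0.
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Correct: set blocked scores to -inf, then normalize over the legal keys only.
pre = softmax(np.where(mask, scores, -np.inf))

# Incorrect: normalize over all keys, then zero out the blocked weight.
post = softmax(scores) * mask

print(pre.sum())   # sums to 1 (up to floating point)
print(post.sum())  # sums to less than 1, so the output would be mis-scaled
```

In this example the pre-softmax weights are a proper distribution over the two legal keys, while the post-softmax weights keep only about a third of the total mass.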
Causal mask (decoder-only LMs): position i can attend to positions <= i. The mask is the lower-triangular indicator np.tri(T).
Padding mask (variable-length batched inputs): if some key positions are <PAD> filler, every query masks those positions out. The mask has columns of zeros at the padded positions.
Combined: in causal LMs with padded sequences, both masks are AND'd together, so a position is blocked if either the causal mask or the padding mask blocks it.
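The three mask types above could be built as follows. This is a sketch that assumes T_q = T_k = T; the sequence length and the is_real vector marking which positions hold real tokens are made up for the example.

```python
import numpy as np

T = 4  # hypothetical sequence length, with T_q == T_k == T

# Causal mask: query i may attend to keys 0..i (lower triangle, diagonal included).
causal = np.tri(T, dtype=bool)

# Padding mask: assume the last key position is <PAD> filler.
is_real = np.array([True, True, True, False])
# Every query row shares the same key-validity pattern, giving a column of zeros.
padding = np.broadcast_to(is_real, (T, T))

# Combined mask: a position is allowed only if both masks allow it.
combined = causal & padding

print(combined.astype(int))
```

Each of these arrays has shape (T_q, T_k) and can be passed as attn_mask unchanged.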
You don't need to handle the masks differently — the function takes a single arbitrary 2D mask and applies it.
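One possible way to write the function is sketched below. It is not the site's reference solution. It assumes 2D NumPy inputs with the shapes listed above and that every query row has at least one allowed key; a fully blocked row would produce NaN in this version.

```python
import numpy as np

def attention_masked(Q, K, V, attn_mask):
    """Scaled dot-product attention with a (T_q, T_k) mask; 1/True means attend."""
    Q, K, V = (np.asarray(a, dtype=float) for a in (Q, K, V))
    mask = np.asarray(attn_mask).astype(bool)

    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)            # (T_q, T_k)

    # Apply the mask before the softmax: blocked entries become -inf.
    scores = np.where(mask, scores, -np.inf)

    # Row-wise softmax, shifted by the row max for numerical stability.
    # Assumption: each row has at least one finite score.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                   # exp(-inf) == 0 for blocked keys
    weights = weights / weights.sum(axis=-1, keepdims=True)

    return weights @ V                         # (T_q, d_v)

# Usage with a causal mask: the first query can only see the first key,
# so its output row equals V[0].
rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 4))
K = rng.normal(size=(3, 4))
V = rng.normal(size=(3, 2))
out = attention_masked(Q, K, V, np.tri(3))
print(np.allclose(out[0], V[0]))
```

Because the mask is only ever used elementwise on the score matrix, causal, padding, and combined masks all go through the same code path.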
Math
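The computation described in the problem statement can be written as follows. The symbols S (raw scores), M (the mask attn_mask), and A (attention weights) are labels introduced here for illustration.

```latex
S = \frac{Q K^\top}{\sqrt{d_k}}, \qquad
\tilde{S}_{ij} =
\begin{cases}
S_{ij} & \text{if } M_{ij} = 1 \\
-\infty & \text{if } M_{ij} = 0
\end{cases}, \qquad
A_{ij} = \frac{\exp(\tilde{S}_{ij})}{\sum_{j'=1}^{T_k} \exp(\tilde{S}_{ij'})}, \qquad
\text{output} = A V
```

Since exp(-inf) = 0, blocked keys contribute nothing to either the numerator or the denominator, so each row of A sums to 1 over the allowed keys only.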
import numpy as np

def attention_masked(Q, K, V, attn_mask):
    pass