TorchedUp

255. Causal Attention Mask (PyTorch)

Easy

Implement causal masked attention in PyTorch using primitive tensor ops only. Position i may attend only to positions 0..i, which is what autoregressive decoding (GPT family) requires.

Signature: def causal_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor

  • Q, K: (..., N, d_k); V: (..., N, d_v), where d_v may differ from d_k
  • Returns: (..., N, d_v)

The rule: you may NOT call F.scaled_dot_product_attention (with or without is_causal=True), nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. Hand-roll every step.
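Since the built-in softmax functions are off limits, the softmax has to be assembled from exp, max, and sum. The following is a minimal sketch of one way to do that; the helper name manual_softmax is hypothetical and not part of the problem statement.

```python
import torch


def manual_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax built from primitive ops only (hypothetical helper name)."""
    # Subtract the per-row maximum so exp() cannot overflow.
    # Entries equal to -inf stay -inf after the shift, and exp(-inf) is 0.
    shifted = x - x.max(dim=dim, keepdim=True).values
    exp = torch.exp(shifted)
    # Normalize so that each slice along `dim` sums to 1.
    return exp / exp.sum(dim=dim, keepdim=True)
```

The max subtraction does not change the result mathematically. It assumes every row has at least one finite entry, which holds for a causal mask because the diagonal is never masked.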

Build the mask: the standard PyTorch helpers are torch.triu (to construct an upper-triangular boolean mask strictly above the diagonal) and tensor.masked_fill (to set masked entries to -inf before the softmax). Make sure the mask broadcasts over any leading batch / head axes.
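The mask construction described above can be sketched as follows. The function name build_causal_mask and the example score shape (batch, heads, N, N) are assumptions made for illustration only.

```python
import torch


def build_causal_mask(n: int, device=None) -> torch.Tensor:
    """Boolean (n, n) mask that is True strictly above the diagonal (hypothetical helper)."""
    ones = torch.ones(n, n, dtype=torch.bool, device=device)
    # diagonal=1 keeps only entries with j > i, i.e. the "future" positions.
    return torch.triu(ones, diagonal=1)


# Illustrative scores with assumed leading batch and head axes: (2, 3, N, N).
N = 4
scores = torch.zeros(2, 3, N, N)
mask = build_causal_mask(N)
# The (N, N) mask broadcasts over the leading axes without being expanded by hand.
masked_scores = scores.masked_fill(mask, float("-inf"))
```

Because the mask has only the trailing two dimensions, the same tensor works for inputs with any number of leading batch or head axes.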

Math

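$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V,
\qquad
M_{ij} =
\begin{cases}
0 & j \le i \\
-\infty & j > i
\end{cases}
$$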
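Putting the formula together, one possible implementation that avoids the forbidden calls is sketched below. It is an illustration under the stated shapes, not the site's reference solution; the intermediate variable names are my own.

```python
import math

import torch


def causal_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Causal attention from primitive ops: a sketch, not a reference solution."""
    d_k = Q.shape[-1]
    n_q, n_k = Q.shape[-2], K.shape[-2]

    # Scaled dot-product scores, shape (..., N, N).
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)

    # True strictly above the diagonal (j > i); broadcasts over leading axes.
    mask = torch.triu(
        torch.ones(n_q, n_k, dtype=torch.bool, device=Q.device), diagonal=1
    )
    scores = scores.masked_fill(mask, float("-inf"))

    # Hand-rolled, numerically stable softmax over the key axis.
    scores = scores - scores.max(dim=-1, keepdim=True).values
    weights = torch.exp(scores)
    weights = weights / weights.sum(dim=-1, keepdim=True)

    # Weighted sum of values, shape (..., N, d_v).
    return weights @ V
```

A quick sanity check is that the first output position equals V[..., 0, :], since position 0 can attend only to itself.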

Related problems

  • Causal Attention Mask (NumPy, Easy)

PyTorch

import torch

def causal_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    pass
