TorchedUp

27. Flash Attention (Tiled)

Hard

Standard attention computes the full N×N attention matrix, which requires O(N²) memory. Flash Attention restructures the computation with tiling (processing blocks of queries against blocks of keys/values) and maintains running softmax statistics, so it produces exactly the same output while using only O(N) extra memory for those statistics, beyond the inputs and output.
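For comparison, here is a minimal NumPy sketch of the standard approach that Flash Attention replaces. `naive_attention` is a hypothetical helper name, not part of the problem's API; the sketch assumes 2-D inputs of shape (N, d) and the 1/√d score scaling used in this problem's math reference.

```python
import numpy as np


def naive_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Standard attention that materializes the full N x N score matrix."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)              # (N, N) scores: the O(N^2) buffer
    S = S - S.max(axis=1, keepdims=True)  # subtract each row's max for stability
    P = np.exp(S)
    P = P / P.sum(axis=1, keepdims=True)  # row-wise softmax
    return P @ V                          # (N, d_v) output


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
out = naive_attention(Q, K, V)
print(out.shape)  # (8, 4)
```

Every call allocates `S` and `P` with shape (N, N), so memory grows quadratically with sequence length; that quadratic buffer is exactly what the tiled formulation avoids.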

Signature: def flash_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, block_size: int = 2) -> np.ndarray

You must NOT materialize the full N×N attention matrix. Instead, iterate over blocks of the key/value sequence and incrementally accumulate the attention output, using an online softmax that tracks the running per-query max and normalizing constant. The math reference summarises the update; the algorithmic details are up to you.
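One way to meet these constraints is sketched below. It is illustrative only, not the problem's reference solution, and it assumes 2-D inputs (Q and K of shape (N, d), V of shape (N, d_v)), no causal mask, and 1/√d scaling. Rather than renormalizing O_i after every block as the math reference does, it keeps an unnormalized accumulator per query and divides by l_i once after the last key/value block. The two forms are algebraically equivalent.

```python
import numpy as np


def flash_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, block_size: int = 2) -> np.ndarray:
    """Tiled attention with an online softmax; never builds the N x N matrix.

    Sketch only: assumes 2-D inputs and no masking. The largest score
    buffer held at once is block_size x block_size.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, V.shape[1]))

    for qs in range(0, N, block_size):                # block of queries
        Qi = Q[qs:qs + block_size]
        bq = Qi.shape[0]                              # last block may be short
        m = np.full(bq, -np.inf)                      # running row max m_i
        l = np.zeros(bq)                              # running normalizer l_i
        acc = np.zeros((bq, V.shape[1]))              # unnormalized sum of P_ij V_j

        for ks in range(0, K.shape[0], block_size):   # block of keys/values
            Kj = K[ks:ks + block_size]
            Vj = V[ks:ks + block_size]
            S = (Qi @ Kj.T) * scale                   # (bq, bk) scores for this tile
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])            # P_ij = exp(S_ij - m_i^new)
            alpha = np.exp(m - m_new)                 # rescales old stats; 0 on first tile
            l = alpha * l + P.sum(axis=1)
            acc = alpha[:, None] * acc + P @ Vj
            m = m_new

        O[qs:qs + bq] = acc / l[:, None]              # final normalization by l_i
    return O
```

Tile sizes are read from the slices rather than assumed to equal `block_size`, so the sketch also handles sequence lengths that are not a multiple of the block size.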

Math

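Online softmax update, where S_ij is the scaled score of query i against key j in the current key/value block, d is the head dimension, and l_i is the running sum of exponentials for query i:

$$S_{ij} = \frac{Q_i K_j^{\top}}{\sqrt{d}}, \qquad m_i^{\text{new}} = \max\Big(m_i^{\text{prev}},\ \max_j S_{ij}\Big), \qquad P_{ij} = \exp\big(S_{ij} - m_i^{\text{new}}\big)$$

$$O_i \leftarrow \frac{O_i \cdot e^{m_i^{\text{prev}} - m_i^{\text{new}}} \cdot l_i + \sum_j P_{ij} V_j}{e^{m_i^{\text{prev}} - m_i^{\text{new}}} \cdot l_i + \sum_j P_{ij}}$$

The denominator is the updated value of l_i.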
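The m_i and l_i recurrences in the update above can be checked in isolation on a single row of scores. This is a minimal sketch with a hypothetical helper, `online_softmax_stats`; it tracks only the running max and normalizer, not the V accumulation.

```python
import numpy as np


def online_softmax_stats(scores: np.ndarray, block_size: int):
    """Stream a 1-D row of scores block by block, tracking (m, l).

    m is the running max; l is the running sum of exp(s - m).
    """
    m, l = -np.inf, 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, float(block.max()))
        # Rescale the old normalizer to the new max, then add this block's terms.
        l = np.exp(m - m_new) * l + np.exp(block - m_new).sum()
        m = m_new
    return m, l


s = np.array([1.0, 3.0, 0.5, 2.0, 4.0])
m, l = online_softmax_stats(s, block_size=2)
print(m)                  # 4.0
print(np.exp(s - m) / l)  # matches the full softmax of s
```

With blocks [1, 3], [0.5, 2], [4], the max rises from 3 to 4 only at the last block, so the accumulated l is multiplied by e^(3-4) once before the final term is added; the result equals the sum of exp(s - 4) over all five scores.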

Related problems

  • Flash Attention Tiled (PyTorch): Hard, PyTorch


NumPy

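import numpy as np


def flash_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, block_size: int = 2) -> np.ndarray:
    # Starter stub: implement tiled attention with an online softmax here.
    pass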
