Standard attention computes the full N×N attention matrix, requiring O(N²) memory. Flash Attention restructures the computation with tiling (processing blocks of queries against blocks of keys/values) and maintains running softmax statistics, so it computes exact attention (identical to the standard result up to floating-point rounding) in O(N) memory.
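For comparison, here is a minimal sketch of standard attention that does materialize the N×N matrix. It assumes scaled dot-product attention, softmax(QKᵀ/√d)V, with no masking; the statement does not specify the scaling. The name naive_attention is hypothetical. It is mainly useful as a reference to check a tiled implementation against.

```python
import numpy as np

def naive_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Hypothetical reference attention that DOES build the full N x N score matrix.

    Assumes softmax(Q K^T / sqrt(d)) V with no mask (scaling is an assumption).
    """
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                  # (N, N_k): the O(N^2) buffer
    scores -= scores.max(axis=1, keepdims=True)    # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)  # row-wise softmax
    return weights @ V                             # (N, d_v)
```

The `scores` array is exactly the O(N²) buffer that Flash Attention avoids.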
Signature: def flash_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, block_size: int = 2) -> np.ndarray
You must NOT materialize the full N×N attention matrix. Instead, iterate over blocks of the key/value sequence and incrementally accumulate the attention output, using an online softmax that tracks the running per-query max and normalizing constant. The math reference summarises the update; the algorithmic details are up to you.
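The online softmax is easiest to see for a single query. The sketch below uses a hypothetical helper, online_softmax_weighted_sum, that visits one query's scores in blocks. It keeps only a running max m, a running normalizer l, and a running weighted sum of values. Whenever the max increases, the old l and accumulator are rescaled by exp(m_old − m_new), so everything stays expressed relative to the current max.

```python
import numpy as np

def online_softmax_weighted_sum(scores: np.ndarray, values: np.ndarray,
                                block_size: int = 2) -> np.ndarray:
    """Compute softmax(scores) @ values for ONE query, visiting scores in blocks.

    scores: (N,) attention logits for a single query (hypothetical helper).
    values: (N, d_v) value rows.
    Only the running statistics (m, l, acc) survive between blocks.
    """
    m = -np.inf                          # running max of scores seen so far
    l = 0.0                              # running sum of exp(score - m)
    acc = np.zeros(values.shape[1])      # running sum of exp(score - m) * value
    for start in range(0, scores.shape[0], block_size):
        s = scores[start:start + block_size]
        v = values[start:start + block_size]
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)        # rescales old stats to the new max (0 on the first block)
        p = np.exp(s - m_new)            # this block's unnormalized weights
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l                       # normalize once at the end
```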
Math
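One common way to write the block-wise update follows. It assumes scores S_ij = Q_i K_jᵀ/√d for query block i and key/value block j, and an accumulator Õ that is divided by ℓ only once, after the last of the T key/value blocks. Exponentials of per-row vectors are elementwise, and subtracting m from S_ij broadcasts across each row.

```latex
\begin{aligned}
m_i^{(0)} &= -\infty, \quad \ell_i^{(0)} = 0, \quad \tilde{O}_i^{(0)} = 0 \\
S_{ij} &= \frac{Q_i K_j^{\top}}{\sqrt{d}} \\
m_i^{(j)} &= \max\!\big(m_i^{(j-1)},\ \operatorname{rowmax}(S_{ij})\big) \\
\tilde{P}_{ij} &= \exp\!\big(S_{ij} - m_i^{(j)}\big) \\
\ell_i^{(j)} &= e^{\,m_i^{(j-1)} - m_i^{(j)}} \odot \ell_i^{(j-1)} + \operatorname{rowsum}(\tilde{P}_{ij}) \\
\tilde{O}_i^{(j)} &= \operatorname{diag}\!\big(e^{\,m_i^{(j-1)} - m_i^{(j)}}\big)\,\tilde{O}_i^{(j-1)} + \tilde{P}_{ij} V_j \\
O_i &= \operatorname{diag}\!\big(\ell_i^{(T)}\big)^{-1}\,\tilde{O}_i^{(T)}
\end{aligned}
```

Every earlier block's contribution is rescaled to the final max before the division. As a result, O_i equals softmax(S_i)V exactly in real arithmetic.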
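import numpy as np

def flash_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, block_size: int = 2) -> np.ndarray:
    # Iterate over key/value blocks, updating a running max and normalizer per query.
    pass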
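Here is a minimal sketch of one way to fill in the stub. It assumes scaled dot-product scores, no mask, and float64 accumulation. It tiles both the queries and the keys/values with the same block_size, so the temporaries per step are one block_size × block_size score tile plus per-query statistics. It is an illustration under those assumptions, not a reference solution.

```python
import numpy as np

def flash_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, block_size: int = 2) -> np.ndarray:
    """Tiled attention with an online softmax; never forms the N x N score matrix.

    Assumes softmax(Q K^T / sqrt(d)) V with no mask (scaling is an assumption).
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((N, V.shape[1]))                     # final output, O(N * d_v)
    for qs in range(0, N, block_size):
        q = Q[qs:qs + block_size] * scale               # (Bq, d) query block, pre-scaled
        m = np.full(q.shape[0], -np.inf)                # running row max
        l = np.zeros(q.shape[0])                        # running normalizer
        acc = np.zeros((q.shape[0], V.shape[1]))        # unnormalized output accumulator
        for ks in range(0, K.shape[0], block_size):
            s = q @ K[ks:ks + block_size].T             # (Bq, Bk) score tile only
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)                   # rescale factor for old stats (0 on first tile)
            p = np.exp(s - m_new[:, None])              # tile weights relative to the new max
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ V[ks:ks + block_size]
            m = m_new
        out[qs:qs + block_size] = acc / l[:, None]      # normalize once per query block
    return out
```

To check the rescaling, compare the output against a dense softmax(QKᵀ/√d)V for several block sizes, including 1 and sizes larger than N.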