Implement Flash Attention v2 tiled attention computation. FA2 improves on FA1 by assigning all work to Q-blocks (outer loop over Q, inner loop over K/V) for better GPU thread occupancy.
Signature: def flash_attn_v2(Q, K, V, block_size=2)
Q, K, V: (S, d), the query/key/value matrices
block_size: tile size (default 2)
Returns: (S, d), same output as standard attention
For each Q block, maintain three running statistics across the K/V loop: a per-row running max m, a per-row running sum-of-exps l, and an unnormalized output accumulator o (the numerator).
When processing a new K/V tile, compute the tile's scores, take the per-row tile max, and update the running max. Whenever the running max changes, the previously accumulated l and o were computed against a stale max, so rescale them by exp(m_old - m_new) (a number ≤ 1) before adding the current tile's contribution. After the K/V loop finishes, divide o by l to get the final softmax-attention output for that Q block.
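The update rule above can be sketched in NumPy as follows. This is a minimal illustration rather than a reference solution: the name flash_attn_v2_sketch is hypothetical, and the 1/sqrt(d) score scaling and the absence of a causal mask are assumptions, since the problem statement only says the output should match standard attention.

```python
import numpy as np

def flash_attn_v2_sketch(Q, K, V, block_size=2):
    """Tiled attention with an outer loop over Q blocks and an inner loop over K/V tiles."""
    S, d = Q.shape
    scale = 1.0 / np.sqrt(d)  # assumption: standard scaled dot-product attention
    out = np.zeros((S, V.shape[1]))

    for qs in range(0, S, block_size):
        q = Q[qs:qs + block_size]           # current Q block (last one may be shorter)
        rows = q.shape[0]
        m = np.full(rows, -np.inf)          # per-row running max
        l = np.zeros(rows)                  # per-row running sum of exps
        o = np.zeros((rows, V.shape[1]))    # unnormalized output accumulator

        for ks in range(0, K.shape[0], block_size):
            k = K[ks:ks + block_size]
            v = V[ks:ks + block_size]
            s = (q @ k.T) * scale           # scores for this tile only
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)       # rescale factor, <= 1 (0 on the first tile)
            p = np.exp(s - m_new[:, None])  # tile probabilities against the new max
            l = alpha * l + p.sum(axis=1)
            o = alpha[:, None] * o + p @ v
            m = m_new

        out[qs:qs + rows] = o / l[:, None]  # normalize once, after the K/V loop
    return out
```

On the first K/V tile the running max is -inf, so alpha evaluates to 0 and the (still empty) accumulators are simply overwritten by the tile's contribution.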
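Standard attention materializes the full (S, S) attention matrix in HBM (GPU off-chip high-bandwidth memory), which is slow to read and write. FA2 instead processes the scores in block_size × block_size tiles that fit in SRAM (GPU on-chip memory), never materializing the full matrix. Working memory per Q block = O(block_size × d + block_size²) instead of O(S²).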
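For comparison, a naive baseline that does build the whole (S, S) score matrix might look like the sketch below. The function name standard_attention_reference is hypothetical, and the 1/sqrt(d) scaling is an assumption about what "standard attention" means here. It is useful mainly as a correctness check for a tiled implementation.

```python
import numpy as np

def standard_attention_reference(Q, K, V):
    """Naive attention: materializes the full (S, S) matrix, so memory is O(S^2)."""
    d = Q.shape[1]
    scores = (Q @ K.T) / np.sqrt(d)                      # full (S, S) score matrix
    scores = scores - scores.max(axis=1, keepdims=True)  # subtract row max for stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ V
```

A tiled implementation can then be validated with np.allclose against this baseline on small random inputs.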
The online softmax trick (maintain running max + running sum) allows exact softmax without a second pass over all K/V.
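To isolate the trick from the attention loop, here is a small sketch that computes the softmax max and normalizer of a 1-D NumPy array in a single pass over fixed-size chunks. The helper name online_softmax_stats and the chunk parameter are illustrative, not part of the problem.

```python
import numpy as np

def online_softmax_stats(x, chunk=2):
    """Return (max, sum of exp(x - max)) using one pass over x in chunks."""
    m = -np.inf  # running max
    l = 0.0      # running sum of exps, always relative to the current m
    for i in range(0, len(x), chunk):
        c = x[i:i + chunk]
        m_new = max(m, c.max())
        # Rescale the old sum to the new max, then add this chunk's terms.
        l = l * np.exp(m - m_new) + np.exp(c - m_new).sum()
        m = m_new
    return m, l
```

With these two numbers, softmax(x) is exp(x - m) / l, which matches the usual two-pass computation up to floating-point rounding.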
import numpy as np

def flash_attn_v2(Q, K, V, block_size=2):
    # Starter stub: replace with the tiled implementation.
    pass