Implement the Flash Attention v2 tiled attention computation. FA2 improves on FA1 by assigning work per Q block (outer loop over Q blocks, inner loop over K/V blocks), so query blocks can be processed independently, improving GPU thread occupancy.
Signature: def flash_attn_v2(Q, K, V, block_size=2)
Q, K, V: (S, d) — query/key/value matrices
block_size: tile size (default 2)
Returns: (S, d) — same output as standard attention

For each Q block, maintain running statistics across K/V blocks:
# Initialization (for each Q block):
m_i = -inf # running max per query row
l_i = 0 # running sum of exp(scores)
o_i = 0 # running output accumulator
# For each K/V block j:
S_ij = Qi @ Kj.T / sqrt(d) # scores for this tile
m_ij = max(S_ij, axis=-1) # per-row max of this tile
m_new = max(m_i, m_ij) # updated max
# Rescale the old statistics to the new max, then fold in this tile
# (m_i, m_new, and l_i broadcast against the rows of S_ij and o_i)
l_i = exp(m_i - m_new) * l_i + sum(exp(S_ij - m_new), axis=-1)
o_i = exp(m_i - m_new) * o_i + exp(S_ij - m_new) @ Vj
m_i = m_new
# After all K/V blocks, normalize each row by its softmax denominator:
O[q_block] = o_i / l_i
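A minimal NumPy sketch of the spec above (a sketch, not the reference solution; intermediate names like alpha and P_ij are my own), followed by a check against dense attention:

import numpy as np

def flash_attn_v2(Q, K, V, block_size=2):
    S, d = Q.shape
    O = np.zeros((S, d))
    scale = 1.0 / np.sqrt(d)
    for q0 in range(0, S, block_size):              # outer loop: Q blocks
        Qi = Q[q0:q0 + block_size]
        m_i = np.full(len(Qi), -np.inf)             # running max per query row
        l_i = np.zeros(len(Qi))                     # running sum of exp(scores)
        o_i = np.zeros((len(Qi), d))                # running output accumulator
        for k0 in range(0, S, block_size):          # inner loop: K/V blocks
            Kj, Vj = K[k0:k0 + block_size], V[k0:k0 + block_size]
            S_ij = Qi @ Kj.T * scale                # scores for this tile
            m_new = np.maximum(m_i, S_ij.max(axis=-1))
            alpha = np.exp(m_i - m_new)             # rescale factor for old stats
            P_ij = np.exp(S_ij - m_new[:, None])    # unnormalized tile probabilities
            l_i = alpha * l_i + P_ij.sum(axis=-1)
            o_i = alpha[:, None] * o_i + P_ij @ Vj
            m_i = m_new
        O[q0:q0 + block_size] = o_i / l_i[:, None]  # normalize per row
    return O

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q, K, V = rng.normal(size=(3, 6, 4))            # S=6, d=4: illustrative sizes
    scores = Q @ K.T / np.sqrt(4)
    P = np.exp(scores - scores.max(-1, keepdims=True))
    ref = (P / P.sum(-1, keepdims=True)) @ V
    assert np.allclose(flash_attn_v2(Q, K, V, block_size=2), ref)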
Standard attention materializes the full (S, S) attention matrix in GPU main memory (HBM). FA2 tiles the computation into block_size × block_size score blocks that fit in SRAM (fast on-chip memory), never materializing the full matrix. Per-tile memory = O(block_size × d) instead of O(S²).
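For concreteness (illustrative sizes, not part of the problem): with S = 4096, d = 64, and block_size = 128, the full score matrix has 4096² ≈ 16.8M entries, while a Q/K/V tile holds 128 × 64 = 8,192 entries and a score tile 128² = 16,384 entries, a working set roughly a thousand times smaller.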
The online softmax trick (maintain a running max + running sum) yields the exact softmax in a single pass, with no second pass over all of K/V.
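A tiny sketch of the trick in isolation (the row values and three-way chunking are illustrative): the running (max, sum) pair ends up matching a one-shot max-shifted softmax denominator.

import numpy as np

x = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 0.0])        # one row of scores
m, l = -np.inf, 0.0                                  # running max, running sum of exp
for chunk in np.split(x, 3):                         # visit two scores at a time
    m_new = max(m, chunk.max())
    l = np.exp(m - m_new) * l + np.exp(chunk - m_new).sum()
    m = m_new

ref = np.exp(x - x.max()).sum()                      # one-shot denominator
assert m == x.max() and np.isclose(l, ref)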