Flash Attention v2 Work Decomposition (Hard)

Implement the Flash Attention v2 (FA2) tiled attention computation. FA2 improves on FA1 by assigning all work to Q blocks (outer loop over Q, inner loop over K/V), which gives better GPU thread occupancy.

Signature: def flash_attn_v2(Q, K, V, block_size=2)

  • Q, K, V: (S, d) — query/key/value matrices
  • block_size: tile size (default 2)
  • Returns: (S, d) — same output as standard attention
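
For reference, one way the expected output can be computed densely — a minimal numpy sketch of standard attention, not necessarily the grader's exact reference implementation:

import numpy as np

def standard_attention(Q, K, V):
    # Dense reference: softmax(Q @ K.T / sqrt(d)) @ V, materializing the full (S, S) matrix.
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                   # (S, S)
    scores -= scores.max(axis=-1, keepdims=True)    # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V                              # (S, d)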

Online Softmax Algorithm (per Q block)

For each Q block, maintain running statistics across K/V blocks:

# Initialization (for each Q block Q_i of shape (b, d), where b = block_size):
m_i = np.full((b, 1), -np.inf)   # running max per query row
l_i = np.zeros((b, 1))           # running sum of exp(scores)
o_i = np.zeros((b, d))           # running output accumulator

# For each K/V block j:
S_ij = Q_i @ K_j.T / np.sqrt(d)            # scores for this tile, shape (b, block_size)
m_ij = S_ij.max(axis=-1, keepdims=True)    # per-row max of this tile
m_new = np.maximum(m_i, m_ij)              # updated running max

# Rescale the old statistics to the new max, then accumulate this tile
l_i = np.exp(m_i - m_new) * l_i + np.exp(S_ij - m_new).sum(axis=-1, keepdims=True)
o_i = np.exp(m_i - m_new) * o_i + np.exp(S_ij - m_new) @ V_j
m_i = m_new

# After all K/V blocks:
O[q_block] = o_i / l_i
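
Putting the recurrence together, here is one possible numpy sketch of the full tiled routine (illustrative only, not the official solution; it also handles a block_size that does not evenly divide S by taking partial tiles):

import numpy as np

def flash_attn_v2(Q, K, V, block_size=2):
    S, d = Q.shape
    O = np.zeros((S, d))
    scale = 1.0 / np.sqrt(d)

    for q_start in range(0, S, block_size):          # outer loop: Q blocks
        q_end = min(q_start + block_size, S)
        Q_i = Q[q_start:q_end]                        # (b, d)
        b = q_end - q_start

        m_i = np.full((b, 1), -np.inf)                # running max
        l_i = np.zeros((b, 1))                        # running normalizer
        o_i = np.zeros((b, d))                        # running output

        for k_start in range(0, S, block_size):       # inner loop: K/V blocks
            k_end = min(k_start + block_size, S)
            K_j = K[k_start:k_end]
            V_j = V[k_start:k_end]

            S_ij = Q_i @ K_j.T * scale                # (b, c) score tile
            m_new = np.maximum(m_i, S_ij.max(axis=-1, keepdims=True))
            P_ij = np.exp(S_ij - m_new)               # unnormalized tile probabilities

            alpha = np.exp(m_i - m_new)               # rescale factor for the old statistics
            l_i = alpha * l_i + P_ij.sum(axis=-1, keepdims=True)
            o_i = alpha * o_i + P_ij @ V_j
            m_i = m_new

        O[q_start:q_end] = o_i / l_i                  # final normalization per Q block

    return O

Against the dense sketch above, np.allclose(flash_attn_v2(Q, K, V), standard_attention(Q, K, V)) should hold for random inputs.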

Why O(sqrt(N)) Memory?

Standard attention materializes the full (S, S) attention matrix, which for long sequences cannot fit in a GPU's small on-chip SRAM and must spill to slower HBM. FA2 tiles the computation into block_size × block_size score tiles that do fit in SRAM, never materializing the full matrix: the per-tile working set is O(block_size × d + block_size²) instead of O(S²). For example, with S = 4096 the full score matrix holds ~16.8M entries, while a single 128 × 128 tile holds only 16,384.

The online softmax trick (maintain running max + running sum) allows exact softmax without a second pass over all K/V.
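
A quick way to convince yourself the rescaling is exact: compute a softmax-weighted sum over one row of scores in two halves using the running-statistics update, and compare it with the direct computation (the names and sizes below are illustrative):

import numpy as np

rng = np.random.default_rng(0)
s = rng.normal(size=8)          # one row of scores
v = rng.normal(size=(8, 4))     # corresponding values

# Direct: softmax(s) @ v
p = np.exp(s - s.max())
direct = (p / p.sum()) @ v

# Online: process the row in two halves with running (m, l, o)
m, l, o = -np.inf, 0.0, np.zeros(4)
for chunk in (slice(0, 4), slice(4, 8)):
    s_c, v_c = s[chunk], v[chunk]
    m_new = max(m, s_c.max())
    alpha = np.exp(m - m_new)   # rescale old statistics to the new max
    p_c = np.exp(s_c - m_new)
    l = alpha * l + p_c.sum()
    o = alpha * o + p_c @ v_c
    m = m_new

assert np.allclose(o / l, direct)   # exact, not an approximation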

Test Results

  • seed=42, S=4, d=4, block_size=2 — same result as standard attention
  • S=2, block_size=1 — tiny blocks, same result
  • non-negative output when V is non-negative
  • matches standard attention (reference: 3x2 case, block_size=2)
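
A local check along the lines of these tests might look like the following (the seed and shapes are assumptions, not the grader's exact cases; standard_attention is the dense sketch from above):

import numpy as np

rng = np.random.default_rng(42)
S, d = 4, 4
Q, K, V = rng.normal(size=(3, S, d))

out = flash_attn_v2(Q, K, V, block_size=2)
ref = standard_attention(Q, K, V)
assert out.shape == (S, d)
assert np.allclose(out, ref, atol=1e-6)

# Non-negative V should give non-negative output (softmax weights are >= 0)
V_pos = np.abs(V)
assert (flash_attn_v2(Q, K, V_pos, block_size=1) >= 0).all()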