TorchedUp
LearnBetaProblemsSystem DesignSoonPremium

88. Flash Attention v2 Work Decomposition

Hard

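Implement the Flash Attention v2 tiled attention computation. FA2 improves on FA1 by reordering the loops. The outer loop runs over Q blocks and the inner loop over K/V blocks, so each Q block is an independent unit of work that can be assigned to its own GPU thread block. This gives better GPU occupancy.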
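The change in loop order can be made concrete by listing the (Q tile, K/V tile) pairs each scheme visits. FA1 loops over K/V blocks on the outside and Q blocks on the inside. FA2 reverses this, so all tiles that contribute to one Q block's output are visited consecutively. The function names below are hypothetical and exist only for illustration. Each pair is written as (q_start, kv_start).

```python
def fa1_schedule(S, block_size):
    """FA1-style order: outer loop over K/V blocks, inner loop over Q blocks."""
    starts = range(0, S, block_size)
    return [(q, kv) for kv in starts for q in starts]


def fa2_schedule(S, block_size):
    """FA2-style order: outer loop over Q blocks, inner loop over K/V blocks."""
    starts = range(0, S, block_size)
    return [(q, kv) for q in starts for kv in starts]


print(fa1_schedule(4, 2))  # [(0, 0), (2, 0), (0, 2), (2, 2)]
print(fa2_schedule(4, 2))  # [(0, 0), (0, 2), (2, 0), (2, 2)]
```

Both orders visit the same set of tiles. In the FA2 order, the output rows of a Q block are touched only inside that block's own inner loop, which is what lets Q blocks run in parallel without sharing partial results.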

Signature: def flash_attn_v2(Q, K, V, block_size=2)

  • Q, K, V: (S, d) — query/key/value matrices
  • block_size: tile size (default 2)
  • Returns: (S, d) — same output as standard attention
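The "standard attention" the tiled version has to match can be written as a plain NumPy reference that materializes the full (S, S) score matrix. This is a minimal sketch. It assumes the usual 1/sqrt(d) scaling of scaled dot-product attention, because the problem statement does not say whether scores are scaled. The name `standard_attention` is hypothetical.

```python
import numpy as np


def standard_attention(Q, K, V):
    """Reference softmax attention that builds the full (S, S) score matrix.

    Assumption: scores are scaled by 1/sqrt(d); drop the scale if the
    expected output uses raw Q @ K.T.
    """
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                         # (S, S)
    scores = scores - scores.max(axis=1, keepdims=True)   # subtract row max for stability
    p = np.exp(scores)
    p /= p.sum(axis=1, keepdims=True)                     # row-wise softmax
    return p @ V                                          # (S, d)
```

As a quick check, with all-zero K every score is 0, so each query attends uniformly and every output row equals the column mean of V.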

Online Softmax Algorithm (per Q block)

For each Q block, maintain three running statistics across the K/V loop: a per-row running max m, a per-row running sum-of-exps l, and an unnormalized output accumulator o (numerator).
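The statistics can be written out for row i of the Q block. Here s_{ij} is the score of query i against key j, and J_t is the set of key indices in the K/V tiles processed so far. Both symbols are introduced here for illustration. After t tiles the three statistics are:

```latex
m_i^{(t)} = \max_{j \in J_t} s_{ij}, \qquad
\ell_i^{(t)} = \sum_{j \in J_t} e^{\,s_{ij} - m_i^{(t)}}, \qquad
\mathbf{o}_i^{(t)} = \sum_{j \in J_t} e^{\,s_{ij} - m_i^{(t)}}\, \mathbf{v}_j

\frac{\mathbf{o}_i^{(T)}}{\ell_i^{(T)}}
  = \sum_{j=1}^{S} \frac{e^{\,s_{ij}}}{\sum_{k=1}^{S} e^{\,s_{ik}}}\, \mathbf{v}_j
  \quad \text{(after the last tile } T\text{)}
```

The factor e^{-m_i} appears in both the numerator and the denominator, so it cancels. Dividing o by ℓ after the last tile therefore gives exactly the softmax-weighted sum of the value rows.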

When processing a new K/V tile, compute the tile's scores, take the per-row tile max, and update the running max. Whenever the running max changes, the previously accumulated l and o were computed against a stale max, so rescale them by exp(m_old - m_new) (a number ≤ 1) before adding the current tile's contribution. After the K/V loop finishes, divide o by l to get the final softmax-attention output for that Q block.
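The per-Q-block procedure above can be sketched in NumPy as follows. This is one possible implementation, not the site's reference solution. It assumes 1/sqrt(d) score scaling and applies no masking. The running max starts at -inf, so the first rescale factor is exp(-inf) = 0, which zeroes out the empty initial l and o.

```python
import numpy as np


def flash_attn_v2(Q, K, V, block_size=2):
    """Tiled attention with online softmax in FA2 loop order (Q outer, K/V inner).

    Sketch only: assumes 1/sqrt(d) scaling and no attention mask.
    """
    S, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((S, V.shape[1]))

    for qs in range(0, S, block_size):                # outer loop: one Q block
        q = Q[qs:qs + block_size]                     # (Bq, d)
        m = np.full(q.shape[0], -np.inf)              # running row max
        l = np.zeros(q.shape[0])                      # running sum of exps
        o = np.zeros((q.shape[0], V.shape[1]))        # unnormalized numerator

        for ks in range(0, K.shape[0], block_size):   # inner loop: K/V tiles
            k = K[ks:ks + block_size]
            v = V[ks:ks + block_size]
            s = (q @ k.T) * scale                     # (Bq, Bk) tile scores
            m_new = np.maximum(m, s.max(axis=1))      # updated running max
            alpha = np.exp(m - m_new)                 # rescale factor, <= 1
            p = np.exp(s - m_new[:, None])            # tile weights vs. new max
            l = alpha * l + p.sum(axis=1)             # rescale old sum, add tile
            o = alpha[:, None] * o + p @ v            # rescale old numerator, add tile
            m = m_new

        out[qs:qs + block_size] = o / l[:, None]      # normalize once, after the K/V loop
    return out
```

Because the last tile can be shorter than block_size, the slicing also handles sequence lengths that are not a multiple of the block size.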

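Why O(block_size × d) Memory?

Standard attention materializes the full (S, S) score matrix in HBM (GPU global memory), because the matrix is far too large for SRAM (GPU on-chip memory). FA2 instead splits it into block_size × block_size tiles that live in SRAM and are discarded after use, so the full matrix is never materialized. The working memory per Q block is O(block_size × d + block_size²), which is O(block_size × d) when block_size ≤ d, instead of O(S²). The only state carried across tiles is the per-row m and l, which is O(S) in total.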
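A rough float count shows the scale of the difference between the two memory footprints. The helper below is hypothetical. Its breakdown of the working set (one Q, K, and V tile, one score tile, the accumulator o, and the m and l vectors) is an assumption made for illustration, not a measurement of any real kernel.

```python
def attention_memory_floats(S, d, block_size):
    """Return (full score matrix floats, rough FA2 per-Q-block working-set floats).

    Assumed working set: Q, K, V tiles + one score tile + output accumulator o
    + running max m and running sum l for the block's rows.
    """
    full_scores = S * S
    working_set = (
        3 * block_size * d           # Q, K, V tiles
        + block_size * block_size    # score tile
        + block_size * d             # output accumulator o
        + 2 * block_size             # running max m and running sum l
    )
    return full_scores, working_set


print(attention_memory_floats(4096, 64, 64))  # (16777216, 20608)
```

The working-set term does not depend on S, so doubling the sequence length quadruples the full score matrix but leaves the per-block footprint unchanged.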

The online softmax trick (maintaining a running max and a running sum) computes the exact softmax, not an approximation, without a second pass over all K/V.
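The trick can be isolated from attention by looking at a single row of scores arriving in chunks. The sketch below reads each chunk once and keeps only a running max m and a running sum l of exp(x - m). From these, any probability follows as exp(x_k - m) / l. The name `online_softmax_stats` is hypothetical, and chunks are assumed to be non-empty.

```python
import math


def online_softmax_stats(chunks):
    """One pass over non-empty score chunks; returns (m, l) with l = sum(exp(x - m))."""
    m, l = -math.inf, 0.0
    for chunk in chunks:
        m_new = max(m, max(chunk))                # update running max
        l = l * math.exp(m - m_new)               # rescale old sum to the new max
        l += sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    return m, l


m, l = online_softmax_stats([[1000.0], [1001.0]])
print(m, math.isclose(l, 1.0 + math.exp(-1.0)))  # 1001.0 True
```

A naive exp(1000.0) would overflow, but subtracting the running max keeps every exponent at or below 0. In attention, the same pass also accumulates the V-weighted numerator o, so the final division needs no second visit to K/V.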


NumPy

import numpy as np

 

def flash_attn_v2(Q, K, V, block_size=2):
    # Starter stub: replace with the tiled FA2 computation; returns None for now.
    pass
