Standard attention computes the full N×N attention matrix, requiring O(N²) memory. Flash Attention rewrites the computation using tiling — processing blocks of queries against blocks of keys/values — and maintains running softmax statistics to produce the exact same output in O(N) memory.
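For reference, the dense O(N²) baseline being compared against can be sketched as follows. This is a minimal NumPy sketch; the helper name naive_attention is illustrative only and is not part of the required interface.

import numpy as np

def naive_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    # Materializes the full N x N score matrix, hence O(N^2) memory.
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                      # (N, N) scaled scores
    P = np.exp(S - S.max(axis=-1, keepdims=True)) # numerically stable softmax
    P = P / P.sum(axis=-1, keepdims=True)
    return P @ V                                  # (N, d) output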
Signature: def flash_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, block_size: int = 2) -> np.ndarray
You must NOT materialize the full N×N attention matrix. Instead, iterate over blocks of the key/value sequence and update running accumulators.
The key insight: for a query block Q_i and key block K_j, compute local scores S_ij = Q_i K_j^T / sqrt(d). Maintain running max m_i and running sum l_i to perform online softmax normalization:
m_i_new = max(m_i_prev, max(S_ij, axis=-1))
P_ij = exp(S_ij - m_i_new)
correction = exp(m_i_prev - m_i_new)
O_i = (O_i * correction * l_i + P_ij @ V_j) / (correction * l_i + sum(P_ij, axis=-1))
l_i = correction * l_i + sum(P_ij, axis=-1)
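One possible NumPy sketch of this recurrence is below. It assumes Q, K, and V are 2-D arrays of shape (N, d) and that the accumulators for each query block are seeded with m_i = -inf, l_i = 0, and O_i = 0 before the first key/value block; those initial values make the first correction factor exp(-inf) = 0, so the update formulas above apply uniformly.

import numpy as np

def flash_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, block_size: int = 2) -> np.ndarray:
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    for q_start in range(0, N, block_size):
        q_end = min(q_start + block_size, N)
        Q_i = Q[q_start:q_end]                       # (b_q, d) query block
        # Running statistics for this query block.
        m_i = np.full(q_end - q_start, -np.inf)      # running row max
        l_i = np.zeros(q_end - q_start)              # running softmax denominator
        O_i = np.zeros((q_end - q_start, d))         # running normalized output
        for k_start in range(0, N, block_size):
            k_end = min(k_start + block_size, N)
            K_j = K[k_start:k_end]
            V_j = V[k_start:k_end]
            S_ij = Q_i @ K_j.T * scale               # (b_q, b_k) local scores
            m_new = np.maximum(m_i, S_ij.max(axis=-1))
            P_ij = np.exp(S_ij - m_new[:, None])
            correction = np.exp(m_i - m_new)         # rescale old stats to the new max
            l_new = correction * l_i + P_ij.sum(axis=-1)
            O_i = (O_i * (correction * l_i)[:, None] + P_ij @ V_j) / l_new[:, None]
            m_i, l_i = m_new, l_new
        O[q_start:q_end] = O_i
    return O

A quick self-check under the same assumptions: on random inputs, the blocked result should match the dense naive_attention sketch above to floating-point tolerance, e.g. np.allclose(flash_attention(Q, K, V, block_size=2), naive_attention(Q, K, V)).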