Implement the tiled Flash Attention forward pass in PyTorch using primitive tensor ops only. The point: never materialize the full N×N attention matrix. Instead, iterate over blocks of K/V and maintain running softmax statistics (m, l) so the output O is updated incrementally — exact same result, O(N) memory.
Signature: def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, block_size: int = 2) -> torch.Tensor
Q: (..., N, d)
K: (..., N, d)
V: (..., N, d_v)
block_size: tile size B
Returns: (..., N, d_v), identical to standard softmax(Q K^T / sqrt(d)) @ V
The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. You must also NOT materialize the full N×N attention matrix in any single tensor; the whole point is to compute it block-by-block.
The online-softmax idea (per (i, j) block pair): see the math reference below for the formal update. The key insight: when a new K/V tile reveals a row-max larger than any seen so far, the previously accumulated O_i and the running normalizer l_i were both computed against a stale max. Rescale both by exp(m_prev - m_new) (a number ≤ 1) before adding the current tile's contribution, then update m to the new max. This keeps the running totals consistent with a single-pass numerically-stable softmax.
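To make the rescaling step concrete, here is a minimal sketch for a single query row with scalar values, in plain Python. The function name online_softmax_weighted_sum and its arguments are hypothetical, chosen for illustration only; it assumes the scores are already scaled by 1/sqrt(d).

```python
import math

def online_softmax_weighted_sum(scores, values, block_size=2):
    """Compute sum_k softmax(scores)[k] * values[k], reading the inputs tile by tile.

    Hypothetical helper for illustration: one query row, scalar values.
    """
    m = float("-inf")  # running max of the scores seen so far
    l = 0.0            # running normalizer, relative to m
    o = 0.0            # running unnormalized output, relative to m
    for start in range(0, len(scores), block_size):
        s_tile = scores[start:start + block_size]
        v_tile = values[start:start + block_size]
        m_new = max(m, max(s_tile))
        # Factor that re-expresses the old totals against the new max.
        # On the first tile m is -inf, so alpha is exp(-inf) = 0.0.
        alpha = math.exp(m - m_new)
        p = [math.exp(s - m_new) for s in s_tile]
        l = alpha * l + sum(p)
        o = alpha * o + sum(pi * vi for pi, vi in zip(p, v_tile))
        m = m_new
    # Normalize once at the end.
    return o / l
```

Because every exponent is taken relative to the running max, no call to math.exp ever sees a positive argument, so large scores such as 1000.0 do not overflow.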
PyTorch idioms vs NumPy:
tensor.max(dim=-1).values returns a tensor; without .values you get the NamedTuple. NumPy's np.max(arr, axis=-1) returns the array directly.
torch.maximum(a, b) is the elementwise max (NumPy: np.maximum). tensor.max(...) is the reduction, a different op.
tensor.unsqueeze(-1) (PyTorch) vs arr[..., None] (NumPy). Both work in PyTorch; unsqueeze is more explicit.
torch.full(shape, float('-inf')) matches NumPy's np.full(shape, -np.inf).
Math
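The online-softmax update described above can be written out for one query tile i and one key/value tile j as follows. The symbols S_ij, P_ij, rowmax and rowsum are notation assumed here for illustration; m, l and O are the running statistics and output from the problem statement, initialized to m_i = -inf, l_i = 0, O_i = 0.

```latex
\begin{aligned}
S_{ij} &= \frac{Q_i K_j^{\top}}{\sqrt{d}} \\
m_i^{\text{new}} &= \max\bigl(m_i^{\text{prev}},\ \operatorname{rowmax}(S_{ij})\bigr) \\
P_{ij} &= \exp\bigl(S_{ij} - m_i^{\text{new}}\bigr) \\
l_i^{\text{new}} &= e^{\,m_i^{\text{prev}} - m_i^{\text{new}}}\, l_i^{\text{prev}} + \operatorname{rowsum}(P_{ij}) \\
O_i^{\text{new}} &= e^{\,m_i^{\text{prev}} - m_i^{\text{new}}}\, O_i^{\text{prev}} + P_{ij} V_j
\end{aligned}
```

After the last K/V tile, the output rows for tile i are O_i / l_i, with l_i broadcast across the d_v columns.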
import torch

def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    pass
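One possible way to fill in the stub is sketched below. This is not the reference solution, only an illustration under stated assumptions: Q, K and V share the same leading batch dimensions, both the query and key/value axes are tiled with the same block_size, and there is no masking. The largest score tensor it creates has shape (..., B, B).

```python
import math
import torch

def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                    block_size: int = 2) -> torch.Tensor:
    N, d = Q.shape[-2], Q.shape[-1]
    scale = 1.0 / math.sqrt(d)
    out_tiles = []
    for i in range(0, N, block_size):
        q = Q[..., i:i + block_size, :]                  # (..., Bq, d)
        rows = q.shape[:-1]                              # (..., Bq)
        # Running statistics for this query tile.
        m = torch.full(rows, float("-inf"), dtype=Q.dtype, device=Q.device)
        l = torch.zeros(rows, dtype=Q.dtype, device=Q.device)
        o = torch.zeros(*rows, V.shape[-1], dtype=Q.dtype, device=Q.device)
        for j in range(0, K.shape[-2], block_size):
            k = K[..., j:j + block_size, :]              # (..., Bk, d)
            v = V[..., j:j + block_size, :]              # (..., Bk, d_v)
            s = (q @ k.transpose(-2, -1)) * scale        # (..., Bq, Bk) score tile
            m_new = torch.maximum(m, s.max(dim=-1).values)
            alpha = torch.exp(m - m_new)                 # rescale factor, at most 1
            p = torch.exp(s - m_new.unsqueeze(-1))       # unnormalized tile weights
            l = alpha * l + p.sum(dim=-1)
            o = alpha.unsqueeze(-1) * o + p @ v
            m = m_new
        # Normalize once per query tile, after all K/V tiles are consumed.
        out_tiles.append(o / l.unsqueeze(-1))
    return torch.cat(out_tiles, dim=-2)
```

A quick sanity check is to compare the result against the unfused formula on a small input, for example N = 5 with block_size = 2 so that the last tile is ragged, using torch.allclose with a small tolerance.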