Implement the tiled Flash Attention forward pass in PyTorch using primitive tensor ops only. The point: never materialize the full N×N attention matrix. Instead, iterate over blocks of K/V and maintain running softmax statistics (m, l) so the output O is updated incrementally — exact same result, O(N) memory.
Signature: def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, block_size: int = 2) -> torch.Tensor
Q: (..., N, d)
K: (..., N, d)
V: (..., N, d_v)
block_size: tile size B
Returns: (..., N, d_v) — identical to standard softmax(Q K^T / sqrt(d)) @ V
The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. You must also NOT materialize the full N×N attention matrix in any single tensor — the whole point is to compute it block-by-block.
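To check your result, a quadratic-memory reference is handy. The helper below (`naive_attention` is our name, not part of the required API) computes the standard formula with a hand-rolled softmax, so it works even where the banned ops are off-limits:

```python
import math
import torch

def naive_attention(Q, K, V):
    """O(N^2)-memory reference: softmax(Q K^T / sqrt(d)) @ V.
    For testing only -- it materializes the full N x N matrix."""
    d = Q.shape[-1]
    S = Q @ K.transpose(-2, -1) / math.sqrt(d)   # (..., N, N): the matrix the problem forbids
    S = S - S.max(dim=-1, keepdim=True).values   # subtract row max for numerical stability
    P = S.exp()
    P = P / P.sum(dim=-1, keepdim=True)          # manual softmax, no torch.softmax
    return P @ V
```

Your `flash_attention` output should match this reference to floating-point tolerance for any `block_size`.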
The online-softmax update (per (i, j) block pair):
S_ij = Q_i @ K_j^T / sqrt(d) # (B, B) — only the current tile
m_new = max(m_prev, S_ij.max(dim=-1).values)
P_ij = exp(S_ij - m_new[..., None])
correction = exp(m_prev - m_new)
cur_l = correction * l_prev
O_i = (O_i * cur_l[..., None] + P_ij @ V_j) / (cur_l + P_ij.sum(-1))[..., None]
l_i = cur_l + P_ij.sum(-1)
m_i = m_new
The key insight: when a new tile reveals a larger m, we rescale both the running output O_i and the running normalizer l_i by exp(m_prev - m_new) so the math stays consistent with a single-pass softmax.
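The update rules above can be exercised in isolation on a single query tile. This sketch (the helper name `online_softmax_tile_update` is ours, not part of the problem's API) loops over K/V tiles and reproduces a full softmax over all the keys — the full problem additionally tiles over Q:

```python
import math
import torch

def online_softmax_tile_update(Q_i, K_tiles, V_tiles):
    """Apply the online-softmax block update over a list of K/V tiles
    for one fixed query tile Q_i. Sketch only: the real solution also
    iterates over query tiles and stacks the results."""
    d = Q_i.shape[-1]
    m = torch.full(Q_i.shape[:-1], float('-inf'))           # running row max, (..., B)
    l = torch.zeros(Q_i.shape[:-1])                         # running normalizer
    O = torch.zeros(*Q_i.shape[:-1], V_tiles[0].shape[-1])  # running (already normalized) output
    for K_j, V_j in zip(K_tiles, V_tiles):
        S = Q_i @ K_j.transpose(-2, -1) / math.sqrt(d)      # only the current tile
        m_new = torch.maximum(m, S.max(dim=-1).values)
        P = torch.exp(S - m_new[..., None])
        cur_l = torch.exp(m - m_new) * l                    # rescale the old normalizer
        l = cur_l + P.sum(-1)
        O = (O * cur_l[..., None] + P @ V_j) / l[..., None]
        m = m_new
    return O
```

Note that on the first tile `exp(m - m_new)` is `exp(-inf) = 0`, so the initial zero state drops out cleanly without any special-casing.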
PyTorch idioms vs NumPy:
tensor.max(dim=-1).values returns a tensor; without .values you get a (values, indices) named tuple. NumPy's np.max(arr, axis=-1) returns the array directly.
torch.maximum(a, b) is the elementwise max (NumPy: np.maximum); tensor.max(...) is the reduction, a different op.
tensor.unsqueeze(-1) (PyTorch) vs arr[..., None] (NumPy). Both work in PyTorch; unsqueeze is more explicit.
torch.full(shape, float('-inf')) matches NumPy's np.full(shape, -np.inf).
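These idioms in action (a standalone snippet, not part of the solution):

```python
import torch

t = torch.tensor([[1.0, 3.0], [2.0, 0.0]])

# Reduction along a dim returns a (values, indices) named tuple:
mx = t.max(dim=-1)
vals = mx.values                  # tensor([3., 2.])

# Elementwise max of two tensors (NumPy's np.maximum):
em = torch.maximum(t[0], t[1])    # tensor([2., 3.])

# Adding a trailing dim: both spellings are equivalent in PyTorch.
assert torch.equal(t.unsqueeze(-1), t[..., None])

# A -inf fill, e.g. to initialize the running max m:
m = torch.full((2,), float('-inf'))
```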
import torch

def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    pass