TorchedUp

259. Flash Attention Tiled (PyTorch)

Hard

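Implement the tiled Flash Attention forward pass in PyTorch using primitive tensor ops only. The point: never materialize the full N×N attention matrix. Instead, iterate over blocks of K/V and maintain running softmax statistics (the row max m and the normalizer l) so the output O is updated incrementally. The result matches standard attention up to floating-point rounding, and the extra memory beyond Q, K, V, and O is O(N) for the statistics plus one score tile at a time.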

Signature: def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, block_size: int = 2) -> torch.Tensor

  • Q: (..., N, d)
  • K: (..., N, d)
  • V: (..., N, d_v)
  • block_size: tile size B
  • Returns: (..., N, d_v) — identical to standard softmax(Q K^T / sqrt(d)) @ V

The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. You must also NOT materialize the full N×N attention matrix in any single tensor — the whole point is to compute it block-by-block.

The online-softmax idea (per (i, j) block pair): see the math reference below for the formal update. The key insight: when a new K/V tile reveals a row-max larger than any seen so far, the previously accumulated O_i and the running normalizer l_i were both computed against a stale max. Rescale both by exp(m_prev - m_new) (a number ≤ 1) before adding the current tile's contribution, then update m to the new max. This keeps the running totals consistent with a single-pass numerically-stable softmax.
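The rescaling step can be seen in isolation on a single row of scores. The following is a minimal sketch, not part of the required solution: the helper name online_softmax_stats and the chunk parameter are hypothetical, chosen only for illustration. It scans a 1-D score vector in chunks and checks that the running (m, l) match the values computed in one pass over the whole vector.

```python
import torch

def online_softmax_stats(scores: torch.Tensor, chunk: int = 2):
    """Scan a 1-D score vector in chunks, returning the running max and normalizer."""
    m = torch.tensor(float("-inf"))  # running max, nothing seen yet
    l = torch.tensor(0.0)            # running sum of exp(score - m)
    for start in range(0, scores.shape[0], chunk):
        tile = scores[start:start + chunk]
        m_new = torch.maximum(m, tile.max())
        # The old normalizer was computed against the stale max m:
        # rescale it by exp(m - m_new) <= 1, then add this tile's terms.
        l = l * torch.exp(m - m_new) + torch.exp(tile - m_new).sum()
        m = m_new
    return m, l

scores = torch.tensor([1.0, 3.0, 2.0, 5.0, 4.0])
m, l = online_softmax_stats(scores)

# One-pass reference over the whole vector.
assert torch.allclose(m, scores.max())
assert torch.allclose(l, torch.exp(scores - scores.max()).sum())
```

On the first chunk the stale max is -inf, so exp(m - m_new) is 0 and the empty accumulator contributes nothing, which is why initializing m to -inf and l to 0 is safe.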

PyTorch idioms vs NumPy:

  • tensor.max(dim=-1).values returns a tensor; without .values you get a (values, indices) named tuple. NumPy's np.max(arr, axis=-1) returns the array directly.
  • torch.maximum(a, b) is the elementwise max (NumPy: np.maximum). tensor.max(...) is the reduction — different op.
  • tensor.unsqueeze(-1) (PyTorch) vs arr[..., None] (NumPy). Both work in PyTorch; unsqueeze is more explicit.
  • torch.full(shape, float('-inf')) matches NumPy's np.full(shape, -np.inf).
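The idioms listed above can be tried together on a small tensor. This is an illustrative sketch; the variable names are arbitrary and not taken from any reference solution.

```python
import torch

x = torch.tensor([[1.0, 4.0, 2.0],
                  [3.0, 0.0, 5.0]])

# Reduction over the last dim; .values drops the indices half of the named tuple.
row_max = x.max(dim=-1).values                 # shape (2,)

# Running max initialized to -inf, then merged elementwise with the new row max.
running = torch.full((2,), float("-inf"))
running = torch.maximum(running, row_max)      # elementwise, not a reduction

# Two equivalent ways to add a trailing axis for broadcasting against x.
col = running.unsqueeze(-1)                    # shape (2, 1)
same = running[..., None]                      # also shape (2, 1)

shifted = x - col                              # each row now has max 0
```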

Math

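\[
\begin{aligned}
m_i^{\text{new}} &= \max\left(m_i^{\text{prev}},\ \max_j S_{ij}\right) \\
P_{ij} &= \exp\left(S_{ij} - m_i^{\text{new}}\right), \qquad S_{ij} = \frac{Q_i K_j^\top}{\sqrt{d}} \\
O_i &\leftarrow \frac{O_i \cdot e^{m_i^{\text{prev}} - m_i^{\text{new}}} \cdot l_i + P_{ij} V_j}{e^{m_i^{\text{prev}} - m_i^{\text{new}}} \cdot l_i + \sum P_{ij}}
\end{aligned}
\]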
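One way the update rule in this section might be turned into code is sketched below. It is a hedged illustration under stated assumptions, not the site's reference solution: it assumes Q is tiled with the same block_size as K/V, that Q, K, and V share their leading batch dimensions, and that the denominator of the O update becomes the new l_i for the next tile. The max and sum are taken over the columns of the current tile.

```python
import math
import torch

def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                    block_size: int = 2) -> torch.Tensor:
    N, d = Q.shape[-2], Q.shape[-1]
    N_k, d_v = K.shape[-2], V.shape[-1]
    scale = 1.0 / math.sqrt(d)
    out_blocks = []
    for i in range(0, N, block_size):
        Q_i = Q[..., i:i + block_size, :]           # (..., rows, d)
        stat_shape = Q_i.shape[:-1]                 # (..., rows)
        m = torch.full(stat_shape, float("-inf"), dtype=Q.dtype, device=Q.device)
        l = torch.zeros(stat_shape, dtype=Q.dtype, device=Q.device)
        O_i = torch.zeros(*stat_shape, d_v, dtype=Q.dtype, device=Q.device)
        for j in range(0, N_k, block_size):
            K_j = K[..., j:j + block_size, :]
            V_j = V[..., j:j + block_size, :]
            # Scores for this tile only: at most block_size x block_size per batch entry.
            S = (Q_i @ K_j.transpose(-2, -1)) * scale
            m_new = torch.maximum(m, S.max(dim=-1).values)
            P = torch.exp(S - m_new.unsqueeze(-1))
            alpha = torch.exp(m - m_new)            # rescale factor, <= 1
            l_new = alpha * l + P.sum(dim=-1)       # denominator of the O update
            # O_i is stored normalized, so undo the old normalizer (alpha * l),
            # add the tile's contribution, and renormalize by l_new.
            O_i = (O_i * (alpha * l).unsqueeze(-1) + P @ V_j) / l_new.unsqueeze(-1)
            m, l = m_new, l_new
        out_blocks.append(O_i)
    return torch.cat(out_blocks, dim=-2)
```

A quick sanity check is to compare the output against un-tiled attention on small random inputs with torch.allclose, using a block_size that does not divide N so the ragged last tile is exercised. The un-tiled reference is only for local verification, since the problem forbids it in the solution itself.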

Related problems

  • Flash Attention (Tiled) (Hard, NumPy)

PyTorch

import torch

def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    pass
