
259. Flash Attention Tiled (PyTorch)

Hard

Implement the tiled Flash Attention forward pass in PyTorch using primitive tensor ops only. The point: never materialize the full N×N attention matrix. Instead, iterate over blocks of K/V and maintain running softmax statistics (m, l) so the output O is updated incrementally — exact same result, O(N) memory.

Signature: def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, block_size: int = 2) -> torch.Tensor

  • Q: (..., N, d)
  • K: (..., N, d)
  • V: (..., N, d_v)
  • block_size: tile size B
  • Returns: (..., N, d_v) — identical to standard softmax(Q K^T / sqrt(d)) @ V

The rule: you may NOT call F.scaled_dot_product_attention, nn.MultiheadAttention, F.softmax, torch.softmax, or nn.Softmax. You must also NOT materialize the full N×N attention matrix in any single tensor — the whole point is to compute it block-by-block.
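
For local testing, a naive reference that does materialize the full N×N matrix (i.e. exactly what a submission must not do) is handy for comparing outputs. A minimal sketch:

import math
import torch

def naive_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Reference only: builds the full (..., N, N) score matrix, which the
    # actual problem forbids. Useful purely for verifying flash_attention.
    S = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    # Hand-rolled numerically stable softmax (F.softmax is banned anyway).
    S = S - S.max(dim=-1, keepdim=True).values
    P = S.exp()
    return (P / P.sum(dim=-1, keepdim=True)) @ V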

The online-softmax update (per (i, j) block pair):

S_ij = Q_i @ K_j^T / sqrt(d)              # (B, B): only the current tile
m_new = max(m_prev, S_ij.max(dim=-1).values)    # elementwise max against running stats
P_ij = exp(S_ij - m_new[..., None])             # unnormalized probs for this tile
correction = exp(m_prev - m_new)                # rescales stats computed under the old max
cur_l = correction * l_prev
O_i = (O_i * cur_l[..., None] + P_ij @ V_j) / (cur_l + P_ij.sum(-1))[..., None]
l_i = cur_l + P_ij.sum(-1)                      # updated normalizer
m_i = m_new

The key insight: when a new tile reveals a larger m, we rescale both the running output O_i and the running normalizer l_i by exp(m_prev - m_new) so the math stays consistent with a single-pass softmax.
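
Putting the update together, here is one way the loop can look. This is a sketch, not the site's reference solution: it tiles over K/V while keeping every query row resident, so per-tile tensors are (..., N, B) rather than (..., N, N).

import math
import torch

def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                    block_size: int = 2) -> torch.Tensor:
    N, d = Q.shape[-2], Q.shape[-1]
    scale = 1.0 / math.sqrt(d)
    O = torch.zeros(*Q.shape[:-1], V.shape[-1], dtype=Q.dtype, device=Q.device)
    # Running per-row stats: max m starts at -inf, normalizer l at 0.
    m = torch.full(Q.shape[:-1], float('-inf'), dtype=Q.dtype, device=Q.device)
    l = torch.zeros(Q.shape[:-1], dtype=Q.dtype, device=Q.device)
    for j in range(0, N, block_size):
        K_j = K[..., j:j + block_size, :]        # (..., B, d); last tile may be short
        V_j = V[..., j:j + block_size, :]        # (..., B, d_v)
        S_j = Q @ K_j.transpose(-2, -1) * scale  # (..., N, B): this tile's scores only
        m_new = torch.maximum(m, S_j.max(dim=-1).values)
        P_j = (S_j - m_new.unsqueeze(-1)).exp()
        cur_l = (m - m_new).exp() * l            # old normalizer, rescaled to the new max
        row_sum = P_j.sum(dim=-1)
        O = (O * cur_l.unsqueeze(-1) + P_j @ V_j) / (cur_l + row_sum).unsqueeze(-1)
        l = cur_l + row_sum
        m = m_new
    return O

Against the naive_attention reference above, torch.allclose(flash_attention(Q, K, V), naive_attention(Q, K, V), atol=1e-5) should hold for random inputs, e.g. Q, K of shape (2, 8, 16) and V of shape (2, 8, 32).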

PyTorch idioms vs NumPy (a toy snippet follows the list):

  • tensor.max(dim=-1).values returns a tensor; without .values you get a named tuple of (values, indices). NumPy's np.max(arr, axis=-1) returns the array directly.
  • torch.maximum(a, b) is the elementwise max (NumPy: np.maximum). tensor.max(...) is the reduction — different op.
  • tensor.unsqueeze(-1) (PyTorch) vs arr[..., None] (NumPy). Both work in PyTorch; unsqueeze is more explicit.
  • torch.full(shape, float('-inf')) matches NumPy's np.full(shape, -np.inf).
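
A few toy lines illustrating these (values chosen arbitrarily):

import torch

t = torch.tensor([[1.0, 3.0], [2.0, 0.0]])
row_max = t.max(dim=-1).values                   # tensor([3., 2.]); .values unwraps (values, indices)
clamped = torch.maximum(t, torch.zeros_like(t))  # elementwise max, like np.maximum
col = row_max.unsqueeze(-1)                      # shape (2, 1), same as row_max[..., None]
neg_inf = torch.full((2,), float('-inf'))        # like np.full((2,), -np.inf)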

Math

$$
\begin{aligned}
m_i^{\text{new}} &= \max\!\Big(m_i^{\text{prev}},\ \max_j S_{ij}\Big) \\
P_{ij} &= \exp\!\big(S_{ij} - m_i^{\text{new}}\big), \qquad S_{ij} = \frac{Q_i K_j^{\top}}{\sqrt{d}} \\
O_i &\leftarrow \frac{O_i \cdot e^{\,m_i^{\text{prev}} - m_i^{\text{new}}} \cdot l_i + P_{ij}\, V_j}{e^{\,m_i^{\text{prev}} - m_i^{\text{new}}} \cdot l_i + \sum_j P_{ij}}
\end{aligned}
$$
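
To watch the rescaling reproduce a one-pass softmax numerically, here is a single row processed in two tiles, weighted against toy values (numbers chosen arbitrarily):

import math
import torch

x = torch.tensor([1.0, 2.0, 5.0, 3.0])      # one attention row, two tiles of size 2
v = torch.tensor([10.0, 20.0, 30.0, 40.0])  # toy "values" to weight

m, l, o = float('-inf'), 0.0, 0.0           # running max, normalizer, output
for start in (0, 2):
    x_j, v_j = x[start:start + 2], v[start:start + 2]
    m_new = max(m, x_j.max().item())
    p = (x_j - m_new).exp()
    cur_l = math.exp(m - m_new) * l          # rescale the old normalizer
    o = (o * cur_l + (p * v_j).sum().item()) / (cur_l + p.sum().item())
    l = cur_l + p.sum().item()
    m = m_new

w = (x - x.max()).exp()
w = w / w.sum()                              # one-pass softmax for comparison
assert abs(o - (w * v).sum().item()) < 1e-3  # tiled result matches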

Related problems

  • Flash Attention (Tiled) (Hard, NumPy)

Starter code

import torch

def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    pass
