FlashAttention tiles the attention computation so that Q, K, V, and O blocks fit in SRAM along with two running scalars per row (max and sumexp). Given an SM's SRAM capacity, find the largest block size B such that
B * (4 * d + 2) * dtype_bytes <= sram_bytes
Signature: def flash_attention_tile_size(sram_bytes: int, head_dim: int, dtype_bytes: int) -> list
Return [block_size, sram_used_bytes] (both ints). Block size must be at least 1.
Example: SRAM = 98304 bytes (96 KB), d = 64, fp16 → B = 98304 // (2 * (4*64 + 2)) = 98304 // 516 = 190.
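A minimal reference sketch of the solution, following the constraint above (the function name and signature are taken from the problem; the floor-division approach is one straightforward way to satisfy it):

```python
def flash_attention_tile_size(sram_bytes: int, head_dim: int, dtype_bytes: int) -> list:
    # Each block row stores 4 tiles of width head_dim (Q, K, V, O)
    # plus 2 running scalars (row max and sum of exponentials).
    bytes_per_row = (4 * head_dim + 2) * dtype_bytes
    # Largest B satisfying B * bytes_per_row <= sram_bytes, clamped to >= 1.
    block_size = max(1, sram_bytes // bytes_per_row)
    return [block_size, block_size * bytes_per_row]
```

For the worked example, `flash_attention_tile_size(98304, 64, 2)` gives a block size of 190 using 190 * 516 = 98040 bytes. Note that when SRAM is smaller than a single row's footprint, the clamp to 1 means reported usage can exceed `sram_bytes`, as the "at least 1" requirement dictates.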