TorchedUp › Problems › FlashAttention Tile Size from SRAM (Hard, Premium)

FlashAttention Tile Size from SRAM

FlashAttention tiles the attention computation so that blocks of Q, K, V, and O fit in SRAM, along with two running scalars per row (the running row max and the running sum of exponentials used by online softmax). Each row of a block therefore needs 4 * d elements plus 2 scalars. Given an SM's SRAM capacity, find the largest block size B such that

B * (4 * d + 2) * dtype_bytes <= sram_bytes

Signature: def flash_attention_tile_size(sram_bytes: int, head_dim: int, dtype_bytes: int) -> list

Return [block_size, sram_used_bytes] (both ints). Block size must be at least 1.

Example: SRAM = 98304 bytes (96 KB), d = 64, fp16 (dtype_bytes = 2) → B = 98304 // (2 * (4*64 + 2)) = 98304 // 516 = 190, and sram_used_bytes = 190 * 516 = 98040.
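A minimal reference sketch of one way to solve this, following the inequality above. It assumes the "at least 1" clamp means the reported SRAM usage may exceed the budget when the SRAM is too small for even a single row; the hidden tests may define the clamp differently.

```python
def flash_attention_tile_size(sram_bytes: int, head_dim: int, dtype_bytes: int) -> list:
    # Bytes needed per row of the tile: 4*d elements for the Q, K, V, and O
    # blocks, plus 2 running scalars (row max and row sum of exponentials).
    bytes_per_row = (4 * head_dim + 2) * dtype_bytes
    # Largest B with B * bytes_per_row <= sram_bytes, clamped to at least 1
    # (assumption: the clamp may push usage past sram_bytes for tiny SRAM).
    block_size = max(1, sram_bytes // bytes_per_row)
    return [block_size, block_size * bytes_per_row]

print(flash_attention_tile_size(98304, 64, 2))  # → [190, 98040]
```

For the second sample test, 48 KB with d = 128 in fp16 gives bytes_per_row = 2 * (4*128 + 2) = 1028, so B = 49152 // 1028 = 47.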

Math

Python (numpy)

Test Results

○ 96 KB SRAM, d=64 fp16
○ 48 KB SRAM, d=128 fp16
○ tiny SRAM clamp 🔒 Premium