TorchedUp

189. FlashAttention Tile Size from SRAM

Hard

FlashAttention tiles the attention computation so that the Q, K, V, and O blocks fit in SRAM together with two running scalars per row (the row max and the running sum of exponentials, sumexp). Given an SM's SRAM capacity in bytes and a head dimension d (head_dim), find the largest block size B such that

B * (4 * d + 2) * dtype_bytes <= sram_bytes
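
The factor 4d + 2 can be read as a per-row element count. The breakdown below is a sketch that assumes each of the four blocks holds B rows of d elements and that the two running scalars are stored at the same width as the tiles, which is what the constraint above implies:

```latex
\underbrace{4 \cdot B \cdot d}_{Q,\ K,\ V,\ O \text{ blocks}}
\;+\;
\underbrace{2 \cdot B}_{\text{row max, sumexp}}
\;=\; B\,(4d + 2) \ \text{elements}
\quad\Longrightarrow\quad
B\,(4d + 2) \cdot \texttt{dtype\_bytes} \;\le\; \texttt{sram\_bytes}
```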

Signature: def flash_attention_tile_size(sram_bytes: int, head_dim: int, dtype_bytes: int) -> list

Return [block_size, sram_used_bytes] (both ints). Block size must be at least 1.

Example: SRAM = 98304 bytes (96 KiB), d = 64, fp16 (dtype_bytes = 2) → B = 98304 // (2 * (4*64 + 2)) = 98304 // 516 = 190, which uses 190 * 516 = 98040 bytes.
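
A minimal sketch of one possible solution follows; it is not the site's reference solution. It assumes that sram_used_bytes means B * (4 * d + 2) * dtype_bytes (the left-hand side of the constraint) and that the "at least 1" rule is applied as a clamp, so when not even one row fits, the reported usage can exceed sram_bytes.

```python
def flash_attention_tile_size(sram_bytes: int, head_dim: int, dtype_bytes: int) -> list:
    # Bytes per tile row: one row each of Q, K, V, O (4 * head_dim elements)
    # plus the running max and sumexp (2 elements), all at dtype_bytes width.
    bytes_per_row = (4 * head_dim + 2) * dtype_bytes

    # Largest B that satisfies the constraint, via integer floor division.
    # Assumption: clamp to 1 to honor "block size must be at least 1".
    block_size = max(1, sram_bytes // bytes_per_row)

    # Assumption: "used" bytes is the left-hand side of the constraint.
    sram_used_bytes = block_size * bytes_per_row
    return [block_size, sram_used_bytes]


# The example from the problem statement: 96 KiB of SRAM, d = 64, fp16.
print(flash_attention_tile_size(98304, 64, 2))  # [190, 98040]
```

Integer floor division avoids floating-point rounding, so both returned values are plain ints without any casting.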

Math

$$B = \left\lfloor \frac{S}{(4d + 2) \cdot s} \right\rfloor$$

NumPy

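import numpy as np

def flash_attention_tile_size(sram_bytes: int, head_dim: int, dtype_bytes: int) -> list:
    # Return [block_size, sram_used_bytes] (both ints).
    pass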
