A Triton GEMM kernel computes each (BLOCK_M, BLOCK_N) output tile by accumulating over the K dimension in BLOCK_K-sized strips. Each program needs SRAM for an A tile (BLOCK_M, BLOCK_K), a B tile (BLOCK_K, BLOCK_N), and an fp32 accumulator (BLOCK_M, BLOCK_N). The accumulator takes 4 bytes per element regardless of the input dtype.
Signature: def triton_block_feasible(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, sram_bytes, dtype_bytes) -> list
SRAM used per program, in bytes:
used = (BLOCK_M*BLOCK_K + BLOCK_K*BLOCK_N) * dtype_bytes + BLOCK_M*BLOCK_N * 4
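As a worked instance of the formula, take fp16 inputs (dtype_bytes = 2) with BLOCK_M = BLOCK_N = 128 and BLOCK_K = 32. These tile sizes are chosen only for illustration:

```latex
\begin{aligned}
\text{used} &= (128 \cdot 32 + 32 \cdot 128) \cdot 2 + 128 \cdot 128 \cdot 4 \\
            &= 8192 \cdot 2 + 16384 \cdot 4 \\
            &= 16384 + 65536 = 81920 \text{ bytes}
\end{aligned}
```

Here the fp32 accumulator accounts for 65536 of the 81920 bytes (80%). Growing BLOCK_M or BLOCK_N therefore raises SRAM use much faster than growing BLOCK_K.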
Return [feasible_int, sram_used_bytes, utilization], where feasible_int is 1 if used <= sram_bytes and 0 otherwise, sram_used_bytes is used, and utilization is used / sram_bytes (or 0.0 if sram_bytes <= 0).
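A minimal sketch of the function follows, written directly from the formula and return rules above. M, N and K are kept only to match the given signature, because the per-program SRAM formula does not use them. The SRAM capacities in the usage comment (98304 and 65536 bytes) are hypothetical values chosen for illustration. They do not describe any particular GPU.

```python
def triton_block_feasible(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, sram_bytes, dtype_bytes) -> list:
    # M, N, K belong to the signature but do not enter the per-program SRAM formula.
    a_tile = BLOCK_M * BLOCK_K * dtype_bytes   # A tile in the input dtype
    b_tile = BLOCK_K * BLOCK_N * dtype_bytes   # B tile in the input dtype
    acc = BLOCK_M * BLOCK_N * 4                # fp32 accumulator: always 4 bytes/element
    used = a_tile + b_tile + acc

    feasible = 1 if used <= sram_bytes else 0
    # Guard against division by zero (or a negative capacity) as the spec requires.
    utilization = used / sram_bytes if sram_bytes > 0 else 0.0
    return [feasible, used, utilization]


# Usage with the 128x128x32 fp16 tiles from the worked example:
# triton_block_feasible(4096, 4096, 4096, 128, 128, 32, 65536, 2)  -> [0, 81920, 1.25]
```

With this definition, utilization can exceed 1.0 for an infeasible configuration. That shows how far over budget a tile choice is, instead of just flagging it as infeasible.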