A classic fused kernel computes Y = ReLU(X @ W + b) in one pass. Compute the DRAM bytes used unfused (3 separate kernels) vs fused (1 kernel), and the savings.
Signature: def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list
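Unfused: each of the three kernels (GEMM, bias add, ReLU) reads its inputs from DRAM and writes its output back, so the M*N intermediate is re-read and re-written between kernels.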
Fused: read X + read W + read b + write Y. (M*K + K*N + N + M*N) elements, multiplied by dtype_bytes.
Return [unfused_bytes, fused_bytes, savings_bytes] (all ints).
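A minimal sketch of one possible solution follows. The fused count comes straight from the statement above. The unfused per-kernel breakdown is an assumption, since the statement does not spell it out: GEMM reads X and W and writes an M*N intermediate, bias add reads that intermediate plus b and writes it back, and ReLU reads it once more and writes Y. A grader that counts traffic differently would expect different numbers.

```python
def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list:
    # Fused kernel (from the problem statement): read X, W, b; write Y once.
    fused_elems = M * K + K * N + N + M * N

    # Unfused kernels (ASSUMED breakdown, not given in the statement):
    gemm_elems = M * K + K * N + M * N   # read X, read W, write intermediate
    bias_elems = M * N + N + M * N       # read intermediate, read b, write back
    relu_elems = M * N + M * N           # read intermediate, write Y
    unfused_elems = gemm_elems + bias_elems + relu_elems

    unfused_bytes = unfused_elems * dtype_bytes
    fused_bytes = fused_elems * dtype_bytes
    return [unfused_bytes, fused_bytes, unfused_bytes - fused_bytes]
```

Under this assumption the savings reduce to 4 * M * N * dtype_bytes: the four extra passes over the M*N intermediate that fusion avoids. For example, gemm_bias_relu_bytes(2, 3, 4, 4) returns [212, 116, 96].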