A classic fused kernel computes Y = ReLU(X @ W + b) in one pass. Compute the DRAM traffic in bytes for the unfused version (three separate kernels: GEMM, then bias-add, then ReLU, each reading its inputs from DRAM and writing its full output tensor back to DRAM) and for the fused version (one kernel that reads each input once and writes Y once), and report the savings.
Signature: def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list
Shapes: X is (M, K), W is (K, N), b is (N,), Y is (M, N). Bias is broadcast (do not count it as M*N).
Return [unfused_bytes, fused_bytes, savings_bytes] (all ints — multiply element counts by dtype_bytes).
import numpy as np

def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list:
    # Return [unfused_bytes, fused_bytes, savings_bytes].
    pass
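One way to fill in the stub is sketched below. It assumes the simplest traffic model consistent with the statement: no caching between kernels, the bias vector is read once as N elements, and each unfused kernel writes a full (M, N) tensor that the next kernel reads back. The intermediate names T1 and T2 are illustrative and do not appear in the problem, and the site's hidden tests may use a different accounting.

```python
def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list:
    # Element counts for X (M, K), W (K, N), b (N,), and any (M, N) tensor.
    x, w, b, y = M * K, K * N, N, M * N

    # Unfused, three kernels (T1 and T2 are hypothetical intermediate names):
    #   GEMM:     read X and W,  write T1
    #   bias-add: read T1 and b, write T2
    #   ReLU:     read T2,       write Y
    unfused_elems = (x + w + y) + (y + b + y) + (y + y)

    # Fused, one kernel: read X, W, b once and write Y once.
    fused_elems = x + w + b + y

    unfused = unfused_elems * dtype_bytes
    fused = fused_elems * dtype_bytes
    return [unfused, fused, unfused - fused]
```

Under this model the savings reduce to 4 * M * N * dtype_bytes, since fusion removes two intermediate writes and two intermediate reads of an (M, N) tensor. For example, gemm_bias_relu_bytes(2, 3, 4, 4) returns [212, 116, 96].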