TorchedUp

187. GEMM + Bias + ReLU Fusion

Easy

A classic fused kernel computes Y = ReLU(X @ W + b) in one pass. Compute the DRAM traffic, in bytes, for two implementations. The unfused version runs three separate kernels in sequence (GEMM, bias-add, ReLU); each one reads its inputs from DRAM and writes its full output tensor back to DRAM. The fused version is a single kernel that reads each input once and writes Y once. Report both totals and the savings.

Signature: def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list

Shapes: X is (M, K), W is (K, N), b is (N,), and Y is (M, N). The bias is broadcast across the M rows, so count each read of b as N elements, not M*N.

Return [unfused_bytes, fused_bytes, savings_bytes] as ints, where savings_bytes = unfused_bytes - fused_bytes. Convert element counts to bytes by multiplying by dtype_bytes.
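Under these rules, one way to tally the traffic is per kernel, in elements, then scale by d = dtype_bytes. This assumes ideal traffic: each kernel touches each operand exactly once, with no extra re-reads from tiling.

```latex
\begin{aligned}
\text{GEMM} &: MK + KN + MN \\
\text{bias-add} &: MN + N + MN \\
\text{ReLU} &: MN + MN \\
\text{unfused} &= (MK + KN + N + 5MN)\,d \\
\text{fused} &= (MK + KN + N + MN)\,d \\
\text{savings} &= 4MN\,d
\end{aligned}
```

The 4MN term comes from the two (M, N) intermediates, each written once and read once. For example, with M = N = K = 1024 in fp16 (d = 2), the savings are 4 * 1024 * 1024 * 2 = 8,388,608 bytes (8 MiB).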

Math

$Y = \mathrm{ReLU}(XW + b)$
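To make the accounting concrete, this sketch runs the computation in NumPy and tallies, via `ndarray.nbytes`, the bytes of every array each modeled kernel would read or write. NumPy itself does not fuse anything (it materializes `X @ W + b` as a temporary either way), so the "fused" figure is what an ideal single kernel would move, not what NumPy actually moves. The helper name `traced_bytes` is hypothetical.

```python
import numpy as np


def traced_bytes(M: int, N: int, K: int, dtype=np.float32) -> tuple:
    """Hypothetical helper: run Y = ReLU(X @ W + b) and tally modeled DRAM bytes.

    Returns (unfused_bytes, fused_bytes) under the ideal-traffic assumption
    that each kernel reads each input once and writes its output once.
    """
    rng = np.random.default_rng(0)
    X = rng.standard_normal((M, K)).astype(dtype)
    W = rng.standard_normal((K, N)).astype(dtype)
    b = rng.standard_normal(N).astype(dtype)

    # Unfused: three kernels, each round-tripping its output through DRAM.
    Z = X @ W                      # GEMM: read X, W; write Z (M, N)
    gemm = X.nbytes + W.nbytes + Z.nbytes
    Z2 = Z + b                     # bias-add: read Z and b (N elements); write Z2
    bias_add = Z.nbytes + b.nbytes + Z2.nbytes
    Y_unfused = np.maximum(Z2, 0)  # ReLU: read Z2; write Y
    relu = Z2.nbytes + Y_unfused.nbytes
    unfused = gemm + bias_add + relu

    # Fused (modeled): read X, W, b once; write Y once.
    Y_fused = np.maximum(X @ W + b, 0)
    fused = X.nbytes + W.nbytes + b.nbytes + Y_fused.nbytes

    # Both paths compute the same result; only the modeled traffic differs.
    assert np.allclose(Y_unfused, Y_fused)
    return unfused, fused
```

Counting through `nbytes` ties the tally to the actual array shapes and dtype, so a shape or broadcasting mistake shows up as a wrong byte count rather than silently.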


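Starter code (NumPy):

import numpy as np


def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list:
    # Return [unfused_bytes, fused_bytes, savings_bytes] as ints.
    pass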
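A minimal sketch of one way to fill in the starter code, assuming the ideal-traffic convention from the statement: each unfused kernel reads every input and writes its output exactly once, and the broadcast bias counts as N elements. This is not the site's reference solution, and hidden tests could assume a different convention.

```python
def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list:
    """Sketch: DRAM bytes for Y = ReLU(X @ W + b), unfused vs fused."""
    # Element counts for each tensor.
    n_x, n_w, n_b, n_y = M * K, K * N, N, M * N

    # Unfused: three kernels, each round-tripping an (M, N) tensor through DRAM.
    gemm = n_x + n_w + n_y        # read X, W; write Z
    bias_add = n_y + n_b + n_y    # read Z and b; write Z + b
    relu = n_y + n_y              # read Z + b; write Y
    unfused = (gemm + bias_add + relu) * dtype_bytes

    # Fused: read X, W, b once; write Y once.
    fused = (n_x + n_w + n_b + n_y) * dtype_bytes

    return [unfused, fused, unfused - fused]
```

Counting elements first and multiplying by `dtype_bytes` once keeps everything in integer arithmetic, which matches the all-int return type.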
