TorchedUp
GEMM + Bias + ReLU Fusion (Easy)

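A classic fused kernel computes Y = ReLU(X @ W + b) in a single pass, with X of shape (M, K), W of shape (K, N) and b of shape (N,). Compute the DRAM traffic in bytes for the unfused version (three separate kernels) and for the fused version (one kernel), and the bytes saved by fusing.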
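The snippet below is a NumPy illustration of the computation, not a real fused kernel. The sizes and random inputs are arbitrary choices. The unfused pipeline materializes two full M x N intermediates, Y0 and Y1. A fused kernel never writes these to DRAM. NumPy does not fuse operations, so the single-expression version only shows the result a fused kernel must match. The final tally uses each array's `.nbytes` and the problem's accounting rule: each tensor is read or written exactly once per kernel.

```python
import numpy as np

# Small float32 example; shapes follow the problem: X is (M, K), W is (K, N), b is (N,)
M, N, K = 4, 3, 5
rng = np.random.default_rng(0)
X = rng.standard_normal((M, K)).astype(np.float32)
W = rng.standard_normal((K, N)).astype(np.float32)
b = rng.standard_normal(N).astype(np.float32)

# Unfused: each step materializes a full M x N array
Y0 = X @ W               # kernel 1: GEMM
Y1 = Y0 + b              # kernel 2: bias add (b broadcasts across rows)
Y2 = np.maximum(Y1, 0)   # kernel 3: ReLU

# Same math as one expression. NumPy still evaluates it step by step,
# so this only shows the output a fused kernel has to reproduce.
Y = np.maximum(X @ W + b, 0)
print(np.allclose(Y2, Y))  # True

# DRAM traffic, counting one read or write per tensor per kernel
unfused_traffic = (
    (X.nbytes + W.nbytes + Y0.nbytes)      # GEMM
    + (Y0.nbytes + b.nbytes + Y1.nbytes)   # bias add
    + (Y1.nbytes + Y2.nbytes)              # ReLU
)
fused_traffic = X.nbytes + W.nbytes + b.nbytes + Y.nbytes
print(unfused_traffic, fused_traffic)  # 392 200
```

With M=4, N=3, K=5 in float32, the difference is 192 bytes. That equals the two intermediates, each written once and read once: 4 * 12 elements * 4 bytes.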

Signature: def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list

Unfused (counts are in elements; multiply by dtype_bytes to get bytes):

  • GEMM: read X (M*K) + read W (K*N) + write Y0 (M*N)
  • Bias add: read Y0 (M*N) + read b (N) + write Y1 (M*N)
  • ReLU: read Y1 (M*N) + write Y2 (M*N)
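Summing the three kernels gives the unfused total. Here d stands for dtype_bytes, because the counts above are element counts:

```latex
\text{unfused} = \big[(MK + KN + MN) + (MN + N + MN) + (MN + MN)\big]\, d
             = (MK + KN + 5MN + N)\, d
```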

Fused: read X + read W + read b + write Y (M*K + K*N + N + M*N)
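Subtracting the fused count from the unfused total shows where the savings come from. The X, W and b reads cancel. What remains is the intermediates Y0 and Y1, each written once and read once, for 4MN elements (again with d = dtype_bytes):

```latex
\text{fused}   = (MK + KN + N + MN)\, d \\
\text{savings} = \text{unfused} - \text{fused}
               = (MK + KN + 5MN + N)\, d - (MK + KN + N + MN)\, d
               = 4MN\, d
```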

Return [unfused_bytes, fused_bytes, savings_bytes] (all ints), where savings_bytes = unfused_bytes - fused_bytes.
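Below is a minimal sketch of the requested function, transcribing the per-kernel counts above. It assumes the page's idealized model, in which every tensor moves through DRAM exactly once per kernel. Real GEMMs may re-read tiles of X and W, and that traffic is not counted here.

```python
def gemm_bias_relu_bytes(M: int, N: int, K: int, dtype_bytes: int) -> list:
    # Unfused: three kernels, each reading its inputs from and writing its output to DRAM
    gemm = M * K + K * N + M * N       # read X, read W, write Y0
    bias = M * N + N + M * N           # read Y0, read b, write Y1
    relu = M * N + M * N               # read Y1, write Y2
    unfused = (gemm + bias + relu) * dtype_bytes

    # Fused: read X, W, b once and write Y once; Y0 and Y1 never reach DRAM
    fused = (M * K + K * N + N + M * N) * dtype_bytes

    return [unfused, fused, unfused - fused]


print(gemm_bias_relu_bytes(2, 2, 2, 4))  # [120, 56, 64]
```

The savings always equal 4 * M * N * dtype_bytes, so they grow with the output size and do not depend on K.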
