A standard transformer block computes Q = X @ Wq, K = X @ Wk, V = X @ Wv: three separate GEMMs that each re-read X. Fusing them into a single GEMM with the concatenated weight W_qkv of shape (d, 3d) reads X only once.
Signature: def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list
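Below, seq = seq_len and d = d_model; each term is an element count, converted to bytes by dtype_bytes.
Unfused: 3 separate GEMMs, each of which reads X (seq*d), reads one weight (d*d), and writes one output (seq*d):
unfused = 3 * (seq*d + d*d + seq*d) * dtype_bytes = 3 * (2*seq*d + d*d) * dtype_bytes
Fused: 1 GEMM that reads X once (seq*d), reads the weight of shape (d, 3d) (3*d*d), and writes the output (3*seq*d):
fused = (seq*d + 3*d*d + 3*seq*d) * dtype_bytes = (4*seq*d + 3*d*d) * dtype_bytes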
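Return [unfused_bytes, fused_bytes, savings_bytes] (all ints), where savings_bytes = unfused_bytes - fused_bytes. The weight and output traffic are identical in both cases, so the savings come entirely from the two avoided re-reads of X.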
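One possible implementation, given as a minimal sketch that follows the two formulas above literally. The local names x_elems, w_elems, and out_elems are illustrative and not part of the problem statement, and the sketch assumes the simplified traffic model described here (no caching, every operand read or written exactly once per GEMM).

```python
def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list:
    # Element counts for one projection (illustrative helper names).
    x_elems = seq_len * d_model      # input X, shape (seq_len, d_model)
    w_elems = d_model * d_model      # one weight, shape (d_model, d_model)
    out_elems = seq_len * d_model    # one output (Q, K, or V)

    # Unfused: three GEMMs, each reads X, reads its weight, writes its output.
    unfused_bytes = 3 * (x_elems + w_elems + out_elems) * dtype_bytes

    # Fused: X is read once; weight and output are three times larger.
    fused_bytes = (x_elems + 3 * w_elems + 3 * out_elems) * dtype_bytes

    # Difference is the two avoided re-reads of X.
    savings_bytes = unfused_bytes - fused_bytes
    return [unfused_bytes, fused_bytes, savings_bytes]


if __name__ == "__main__":
    # seq_len=128, d_model=64, 2-byte dtype (e.g. fp16)
    print(qkv_fusion_bytes(128, 64, 2))  # [122880, 90112, 32768]
```

In the example call, the savings of 32768 bytes equal 2 * 128 * 64 * 2, which matches 2 * seq * d * dtype_bytes.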