TorchedUp

QKV Projection Fusion (Medium)

A standard transformer block computes Q = X @ Wq, K = X @ Wk, and V = X @ Wv as three separate GEMMs, each of which re-reads X from memory. Fusing them into a single GEMM against the concatenated weight W_qkv = [Wq | Wk | Wv] of shape (d, 3d) reads X only once.
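The fusion is valid because concatenating the weights column-wise and then splitting the output reproduces the three separate projections exactly. The following minimal numpy sketch checks this. The sizes `seq` and `d` and the random weights are arbitrary values chosen for illustration and are not part of the problem.

```python
import numpy as np

# Arbitrary small sizes, for illustration only.
seq, d = 8, 16
rng = np.random.default_rng(0)
X = rng.standard_normal((seq, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

# Unfused: three GEMMs, each one reading X again.
Q, K, V = X @ Wq, X @ Wk, X @ Wv

# Fused: concatenate along the output dimension so W_qkv has shape (d, 3d),
# run a single GEMM, then split the (seq, 3d) result into three (seq, d) blocks.
W_qkv = np.concatenate([Wq, Wk, Wv], axis=1)
QKV = X @ W_qkv
Qf, Kf, Vf = np.split(QKV, 3, axis=1)

print(W_qkv.shape, QKV.shape)  # (16, 48) (8, 48)
```

The split works because column block i of X @ W_qkv depends only on column block i of W_qkv.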

Signature: def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list

Unfused: 3 separate GEMMs. Each one reads X (seq*d elements), reads its weight (d*d), and writes its output (seq*d):

unfused = 3 * (seq*d + d*d + seq*d) * dtype_bytes = 3 * (2*seq*d + d*d) * dtype_bytes

Fused: 1 GEMM that reads X once (seq*d), reads the concatenated weight of shape (d, 3d) (3*d*d), and writes the combined output (3*seq*d):

fused = (seq*d + 3*d*d + 3*seq*d) * dtype_bytes
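Subtracting the two formulas shows where the saving comes from. Write s = seq and b = dtype_bytes, and assume savings is defined as unfused minus fused:

```latex
\begin{aligned}
\text{unfused} &= 3\,(2sd + d^2)\,b = (6sd + 3d^2)\,b \\
\text{fused}   &= (sd + 3d^2 + 3sd)\,b = (4sd + 3d^2)\,b \\
\text{savings} &= \text{unfused} - \text{fused} = 2sd\,b
\end{aligned}
```

Both variants read the same 3d^2 weight elements and write the same 3sd output elements. The entire saving is the two redundant reads of X, and it is independent of the weight size.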

Return [unfused_bytes, fused_bytes, savings_bytes] (all ints), where savings_bytes = unfused_bytes - fused_bytes.
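A possible reference sketch sends both cases through a single per-GEMM traffic model. `gemm_bytes` is a hypothetical helper name introduced here. The sketch assumes savings_bytes = unfused_bytes - fused_bytes and counts each tensor exactly once, as the formulas above do.

```python
def gemm_bytes(m: int, k: int, n: int, dtype_bytes: int) -> int:
    """Hypothetical helper: bytes moved by one (m, k) @ (k, n) GEMM under the
    simple model above (read activation once, read weight once, write output once)."""
    return (m * k + k * n + m * n) * dtype_bytes


def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list:
    s, d = seq_len, d_model
    # Unfused: three (s, d) @ (d, d) GEMMs, each re-reading X.
    unfused = 3 * gemm_bytes(s, d, d, dtype_bytes)
    # Fused: one (s, d) @ (d, 3d) GEMM against the concatenated W_qkv.
    fused = gemm_bytes(s, d, 3 * d, dtype_bytes)
    # Assumed definition of savings: unfused minus fused.
    return [unfused, fused, unfused - fused]


# Example: seq_len=1024, d_model=1024, fp16 (2 bytes per element).
print(qkv_fusion_bytes(1024, 1024, 2))  # -> [18874368, 14680064, 4194304]
```

In this example the fused kernel moves 7/9 of the unfused traffic. The savings equal 2 * seq * d * dtype_bytes, which is exactly two extra reads of X.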
