A standard transformer block computes Q = X @ Wq, K = X @ Wk, V = X @ Wv: three separate GEMMs that each re-read X. Fusing them into a single GEMM with the concatenated weight W_qkv of shape (d_model, 3*d_model) reads X only once.
Signature: def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list
Shapes: X is (seq_len, d_model). Each of Wq, Wk, Wv is (d_model, d_model); each of Q, K, V is (seq_len, d_model). The fused weight is (d_model, 3*d_model) and the fused output is (seq_len, 3*d_model).
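To make the shapes concrete, here is a minimal NumPy sketch that checks the fused GEMM against the three separate ones. The sizes (seq_len = 4, d_model = 8), the random inputs, and the choice to concatenate the weights along the column axis are assumptions made for illustration; they are not part of the problem statement.

```python
import numpy as np

# Illustrative sizes and random data (assumptions, not from the problem).
rng = np.random.default_rng(0)
seq_len, d_model = 4, 8
X = rng.standard_normal((seq_len, d_model))
Wq, Wk, Wv = (rng.standard_normal((d_model, d_model)) for _ in range(3))

# Unfused: three separate GEMMs, each consuming X.
Q, K, V = X @ Wq, X @ Wk, X @ Wv

# Fused: stack the weights column-wise into (d_model, 3*d_model),
# run one GEMM, then split the (seq_len, 3*d_model) output into thirds.
W_qkv = np.concatenate([Wq, Wk, Wv], axis=1)
QKV = X @ W_qkv
Q_f, K_f, V_f = np.split(QKV, 3, axis=1)

print(W_qkv.shape, QKV.shape)
print(np.allclose(Q, Q_f), np.allclose(K, K_f), np.allclose(V, V_f))
```

The fused output carries the same values as Q, K, and V laid side by side, so the fusion changes memory traffic rather than the result.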
Count the DRAM traffic for the unfused version (three independent GEMMs, each reading its inputs and writing its output through DRAM) and the fused version (one GEMM that reads each input once and writes the stacked output once). Multiply element counts by dtype_bytes.
Return [unfused_bytes, fused_bytes, savings_bytes] (all ints).
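One way to read the counting rule above is sketched below. It assumes every operand is read from DRAM exactly once per GEMM and every output is written exactly once, with no caching or tiling effects. This is an interpretation of the problem statement, not the site's reference solution, and the sample arguments are chosen arbitrarily.

```python
def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list:
    # Element counts for one activation, one weight, and one output tensor.
    x_elems = seq_len * d_model
    w_elems = d_model * d_model
    out_elems = seq_len * d_model

    # Unfused: each of the three GEMMs reads X and one weight, writes one output.
    unfused_elems = 3 * (x_elems + w_elems + out_elems)

    # Fused: X is read once; the stacked weight and stacked output are 3x larger.
    fused_elems = x_elems + 3 * w_elems + 3 * out_elems

    unfused_bytes = unfused_elems * dtype_bytes
    fused_bytes = fused_elems * dtype_bytes
    return [unfused_bytes, fused_bytes, unfused_bytes - fused_bytes]


# Arbitrary example: seq_len=4, d_model=8, 2-byte elements (e.g. fp16).
print(qkv_fusion_bytes(4, 8, 2))  # [768, 640, 128]
```

Under these assumptions the weight and output traffic are identical in both versions, so the savings reduce to the two avoided re-reads of X: 2 * seq_len * d_model * dtype_bytes.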
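import numpy as np

def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list:
    # Return [unfused_bytes, fused_bytes, savings_bytes] as ints.
    pass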