TorchedUp

191. QKV Projection Fusion

Medium

A standard transformer block computes Q = X @ Wq, K = X @ Wk, V = X @ Wv: three separate GEMMs that each re-read X. Fusing them into a single GEMM with the concatenated weight W_qkv of shape (d_model, 3*d_model) reads X only once.

Signature: def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list

Shapes: X is (seq_len, d_model). Each of Wq, Wk, Wv is (d_model, d_model); each of Q, K, V is (seq_len, d_model). The fused weight is (d_model, 3*d_model) and the fused output is (seq_len, 3*d_model).

Count the DRAM traffic for the unfused version (three independent GEMMs, each reading its inputs and writing its output through DRAM) and the fused version (one GEMM that reads each input once and writes the stacked output once). Multiply element counts by dtype_bytes.

Return [unfused_bytes, fused_bytes, savings_bytes] (all ints).
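The counting rules above can be sketched as follows. This is one possible reading of the problem statement, not the site's reference solution: it assumes no caching between the three unfused GEMMs, so each one reads X and its own weight matrix and writes its own output, while the fused GEMM reads X once, reads the (d_model, 3*d_model) weight once, and writes the (seq_len, 3*d_model) output once. The local names x_elems, w_elems, and out_elems are illustrative.

```python
def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list:
    # Element counts for one projection (assumed shapes from the problem statement).
    x_elems = seq_len * d_model      # X: (seq_len, d_model)
    w_elems = d_model * d_model      # one of Wq, Wk, Wv: (d_model, d_model)
    out_elems = seq_len * d_model    # one of Q, K, V: (seq_len, d_model)

    # Unfused: three GEMMs, each reading X and its weight and writing its output.
    unfused_bytes = 3 * (x_elems + w_elems + out_elems) * dtype_bytes

    # Fused: X read once, stacked weight read once, stacked output written once.
    fused_bytes = (x_elems + 3 * w_elems + 3 * out_elems) * dtype_bytes

    savings_bytes = unfused_bytes - fused_bytes
    return [unfused_bytes, fused_bytes, savings_bytes]
```

Under these assumptions the weight and output traffic are identical in both versions, so the savings reduce to the two avoided re-reads of X: 2 * seq_len * d_model * dtype_bytes.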

Math

$$[\,Q \mid K \mid V\,] = X \cdot [\,W_q \mid W_k \mid W_v\,]$$
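To make the identity concrete, here is a small NumPy check that one GEMM against the column-wise concatenated weight gives the same Q, K, and V as three separate GEMMs. The sizes and the random seed are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model = 4, 8  # small illustrative sizes

X = rng.standard_normal((seq_len, d_model))
Wq = rng.standard_normal((d_model, d_model))
Wk = rng.standard_normal((d_model, d_model))
Wv = rng.standard_normal((d_model, d_model))

# Stack the three weights along the output (column) axis: (d_model, 3*d_model).
W_qkv = np.concatenate([Wq, Wk, Wv], axis=1)

# One GEMM yields the stacked output of shape (seq_len, 3*d_model).
QKV = X @ W_qkv

# Split the stacked output back into three (seq_len, d_model) blocks.
Q, K, V = np.split(QKV, 3, axis=1)

# Each block should match the corresponding unfused projection.
fused_matches_unfused = (
    np.allclose(Q, X @ Wq)
    and np.allclose(K, X @ Wk)
    and np.allclose(V, X @ Wv)
)
```

The concatenation axis matters: stacking along axis=1 keeps each weight's columns contiguous, so the split recovers Q, K, and V in that order.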


NumPy

import numpy as np

def qkv_fusion_bytes(seq_len: int, d_model: int, dtype_bytes: int) -> list:
    pass
