TorchedUp

111. KV Cache Size

Easy

During autoregressive decoding, a Transformer caches the key (K) and value (V) vectors of every past token at every layer, so they do not have to be recomputed at each step. Compute the total KV cache size in bytes.

Signature: def kv_cache_bytes(batch: int, seq_len: int, n_layers: int, n_kv_heads: int, head_dim: int, dtype_bytes: int) -> int

Formula:

bytes = 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes

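The leading 2 accounts for both K and V. The formula uses n_kv_heads rather than n_heads to cover grouped-query attention (GQA) and multi-query attention (MQA), in which several query heads share one K/V head. With plain multi-head attention (MHA), n_kv_heads equals n_heads.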
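A minimal sketch of how the KV head count alone drives the cache size. It assumes a hypothetical model shaped like the Llama-2-7B example that follows (32 query heads, 32 layers, head_dim=128, fp16, B=1, S=2048) and varies only n_kv_heads; the GQA choice of 8 KV heads is an arbitrary illustrative value.

```python
# Hypothetical model: 32 query heads, 32 layers, head_dim=128, fp16, B=1, S=2048.
B, S, L, HEAD_DIM, DTYPE_BYTES = 1, 2048, 32, 128, 2
N_QUERY_HEADS = 32

sizes = {}
for name, n_kv in {"MHA": 32, "GQA": 8, "MQA": 1}.items():
    # Only n_kv_heads changes; the number of query heads stays fixed.
    sizes[name] = 2 * B * S * L * n_kv * HEAD_DIM * DTYPE_BYTES
    group = N_QUERY_HEADS // n_kv  # query heads sharing one KV head
    print(f"{name}: n_kv_heads={n_kv:2d}, {group:2d} query heads per KV head, "
          f"{sizes[name] / 2**20:.0f} MiB")
```

The cache shrinks in direct proportion to n_kv_heads: 1024 MiB for MHA, 256 MiB with 8 KV heads, and 32 MiB with a single shared KV head.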

Example: Llama-2-7B with B=1, S=2048, L=32, n_kv=32, head_dim=128, fp16 (2 bytes) -> 2 * 1 * 2048 * 32 * 32 * 128 * 2 = 1,073,741,824 bytes, about 1.07 GB (exactly 1 GiB).
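The example's arithmetic can be reproduced in a few lines. This sketch also splits the total into a per-token cost (K and V across all layers), which makes the linear growth in B and S explicit; the GB versus GiB labels are the only additions.

```python
# Llama-2-7B figures from the example: L=32, n_kv=32, head_dim=128, fp16.
L, N_KV, HEAD_DIM, DTYPE_BYTES = 32, 32, 128, 2
B, S = 1, 2048

per_token = 2 * L * N_KV * HEAD_DIM * DTYPE_BYTES  # K and V bytes for one token, all layers
total = B * S * per_token

print(f"per token: {per_token:,} bytes ({per_token / 2**10:.0f} KiB)")
print(f"total:     {total:,} bytes")
print(f"         = {total / 1e9:.2f} GB (decimal) = {total / 2**30:.2f} GiB (binary)")
```

Each token costs 512 KiB of cache here, so doubling either the batch or the sequence length doubles the total.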

Math

$$\text{bytes}_{\text{KV}} = 2 \cdot B \cdot S \cdot L \cdot H_{\text{kv}} \cdot d_{\text{head}} \cdot \text{dtype\_bytes}$$


NumPy

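import numpy as np


def kv_cache_bytes(batch: int, seq_len: int, n_layers: int, n_kv_heads: int, head_dim: int, dtype_bytes: int) -> int:
    # Leading 2: one K and one V vector per token, per layer, per KV head.
    # Plain Python ints are used so the product cannot overflow (fixed-width NumPy ints can).
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes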
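To sanity-check the formula, one can allocate a tiny cache as NumPy arrays and compare the measured `.nbytes` with the formula. The dimensions and the per-layer (K, V) layout of shape (batch, n_kv_heads, seq_len, head_dim) are assumptions chosen for illustration; real frameworks lay the cache out in various ways, but the byte count is the same.

```python
import numpy as np

# Small hypothetical dimensions so the arrays are cheap to allocate.
batch, seq_len, n_layers, n_kv_heads, head_dim = 2, 16, 4, 2, 8
dtype = np.float16  # 2 bytes per element

# Assumed layout: one K array and one V array per layer,
# each of shape (batch, n_kv_heads, seq_len, head_dim).
cache = [
    (np.zeros((batch, n_kv_heads, seq_len, head_dim), dtype=dtype),   # K
     np.zeros((batch, n_kv_heads, seq_len, head_dim), dtype=dtype))   # V
    for _ in range(n_layers)
]

measured = sum(k.nbytes + v.nbytes for k, v in cache)
formula = 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * np.dtype(dtype).itemsize
print(measured, formula)
```

Both values come out to 8192 bytes, confirming that the leading 2 corresponds to the separate K and V arrays.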
