
KV Cache Size (Easy)

During autoregressive decoding, a Transformer caches the K and V projections of every past token at every layer so they are not recomputed on each step. Compute the total size of this KV cache in bytes.

Signature: def kv_cache_bytes(batch: int, seq_len: int, n_layers: int, n_kv_heads: int, head_dim: int, dtype_bytes: int) -> int

Formula:

bytes = 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes

The leading factor of 2 accounts for storing both K and V. Note that the formula uses n_kv_heads, not n_heads: with grouped-query or multi-query attention (GQA/MQA), several query heads share a single KV head, which shrinks the cache proportionally.

Example: Llama-2-7B with B=1, S=2048, L=32, n_kv=32, head_dim=128, fp16 (2 bytes) -> 1.07 GB.
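A direct implementation of the signature above is a single integer product; no numpy is needed. The second call below is an illustrative GQA variant (n_kv_heads=8 instead of 32) to show the cache shrinking by the sharing factor:

```python
def kv_cache_bytes(batch: int, seq_len: int, n_layers: int,
                   n_kv_heads: int, head_dim: int, dtype_bytes: int) -> int:
    # 2x for K and V; each is a [batch, seq_len, n_kv_heads, head_dim]
    # tensor per layer, with dtype_bytes bytes per element.
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes


# Llama-2-7B example from above: B=1, S=2048, L=32, n_kv=32, head_dim=128, fp16
print(kv_cache_bytes(1, 2048, 32, 32, 128, 2))  # 1073741824 bytes ≈ 1.07 GB

# Hypothetical GQA variant sharing 4 query heads per KV head (n_kv=8): 4x smaller
print(kv_cache_bytes(1, 2048, 32, 8, 128, 2))   # 268435456 bytes
```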

Test Results

- B=1, S=2048, L=32, GQA (n_kv=8)
- B=4, S=512
- long context (32k) (Premium)