During autoregressive decoding, a Transformer caches the key (K) and value (V) tensors for every past token at every layer. Compute the total KV cache size in bytes.
Signature: def kv_cache_bytes(batch: int, seq_len: int, n_layers: int, n_kv_heads: int, head_dim: int, dtype_bytes: int) -> int
Formula:
bytes = 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes
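The leading 2 accounts for storing both K and V. The formula uses n_kv_heads (not n_heads) because in grouped-query attention (GQA) and multi-query attention (MQA) several query heads share one KV head, so only n_kv_heads K/V pairs are cached per layer.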
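To see how much sharing KV heads saves, here is a minimal sketch that evaluates the formula for one hypothetical configuration (32 query heads, 32 layers, head_dim 128, fp16, batch 1, 4096 cached tokens; not a specific published model) under MHA (32 KV heads), GQA (8 KV heads) and MQA (1 KV head). The names `configs` and `sizes` are illustrative only.

```python
def kv_cache_bytes(batch: int, seq_len: int, n_layers: int, n_kv_heads: int,
                   head_dim: int, dtype_bytes: int) -> int:
    # Factor 2: one K tensor and one V tensor per layer,
    # each holding batch * seq_len * n_kv_heads * head_dim elements.
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes

# Hypothetical model: 32 query heads; only the number of KV heads varies.
configs = {"MHA": 32, "GQA": 8, "MQA": 1}

sizes = {
    name: kv_cache_bytes(batch=1, seq_len=4096, n_layers=32,
                         n_kv_heads=n_kv, head_dim=128, dtype_bytes=2)
    for name, n_kv in configs.items()
}

for name, size in sizes.items():
    # Reduction relative to full multi-head attention equals n_heads / n_kv_heads.
    print(f"{name}: {size / 2**30:.4f} GiB ({sizes['MHA'] // size}x reduction vs MHA)")
```

With these settings the cache shrinks from 2 GiB (MHA) to 0.5 GiB (GQA) and 0.0625 GiB (MQA): the size scales linearly with n_kv_heads, independent of the number of query heads.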
Example: Llama-2-7B with B=1, S=2048, L=32, n_kv=32, head_dim=128, fp16 (2 bytes) -> 1,073,741,824 bytes ≈ 1.07 GB (exactly 1 GiB).
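The arithmetic in the example can be checked directly. This sketch also computes the per-token cost (seq_len=1), which is a handy way to estimate how the cache grows as each new token is decoded; `total` and `per_token` are illustrative names.

```python
def kv_cache_bytes(batch: int, seq_len: int, n_layers: int, n_kv_heads: int,
                   head_dim: int, dtype_bytes: int) -> int:
    # 2 (K and V) * tokens cached * layers * KV heads * head size * bytes per element
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes

# Llama-2-7B settings from the example: fp16, no GQA (32 KV heads).
total = kv_cache_bytes(batch=1, seq_len=2048, n_layers=32,
                       n_kv_heads=32, head_dim=128, dtype_bytes=2)

# Cost of caching a single additional token across all layers.
per_token = kv_cache_bytes(batch=1, seq_len=1, n_layers=32,
                           n_kv_heads=32, head_dim=128, dtype_bytes=2)

print(total)                   # 1073741824
print(f"{total / 1e9:.2f} GB")  # 1.07 GB (decimal)
print(total / 2**30)           # 1.0 (GiB)
print(per_token)               # 524288 bytes = 512 KiB per token
```

Because the cost is linear in every factor, doubling the context to 4096 tokens or the batch to 2 simply doubles the total.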
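def kv_cache_bytes(batch: int, seq_len: int, n_layers: int, n_kv_heads: int, head_dim: int, dtype_bytes: int) -> int:
    # 2x for K and V; use n_kv_heads (not n_heads) so GQA/MQA models are counted correctly
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes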