During autoregressive decoding, a Transformer caches the K and V tensors for every past token at every layer. Compute the total size of this KV cache in bytes.
Signature: def kv_cache_bytes(batch: int, seq_len: int, n_layers: int, n_kv_heads: int, head_dim: int, dtype_bytes: int) -> int
Formula:
bytes = 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes
The leading 2 accounts for both K and V. Using n_kv_heads (not n_heads) reflects grouped-query and multi-query attention (GQA/MQA): many models share each KV head across several query heads.
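The formula translates directly into the given signature; a minimal implementation:

def kv_cache_bytes(batch: int, seq_len: int, n_layers: int,
                   n_kv_heads: int, head_dim: int, dtype_bytes: int) -> int:
    # Each of K and V holds batch * seq_len * n_kv_heads * head_dim elements
    # per layer, at dtype_bytes per element; the factor of 2 covers both tensors.
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes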
Example: Llama-2-7B with B=1, S=2048, L=32, n_kv=32, head_dim=128, fp16 (2 bytes) -> 2 * 1 * 2048 * 32 * 32 * 128 * 2 = 1,073,741,824 bytes ≈ 1.07 GB.
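Plugging the example into the function (the result is exactly 2^30 bytes, i.e. 1 GiB, which is ~1.07 in decimal GB):

size = kv_cache_bytes(batch=1, seq_len=2048, n_layers=32,
                      n_kv_heads=32, head_dim=128, dtype_bytes=2)
print(size)        # 1073741824
print(size / 1e9)  # 1.073741824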