The scalar version of this problem (kv-cache-size) returns one number per call. In production you usually sweep many (B, T) configs at once, for example to draw a heatmap of KV-cache bytes across batch sizes and context lengths.
Implement: def kv_cache_size_sweep(B, T, n_layers, d_kv, dtype_bytes) where:
B is a 1-D array of shape (N,): the batch size of each config.
T is a 1-D array of shape (N,): the sequence length of each config, paired elementwise with B.
n_layers, d_kv, dtype_bytes are scalars, assumed constant across the sweep.
Return an array of shape (N,) of int64 byte counts.
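A common way to produce the paired inputs is to lay out a grid of batch sizes and context lengths, flatten it into two aligned (N,) arrays, and later reshape the (N,) result back into the grid for the heatmap. A minimal sketch, assuming NumPy; the axis values are illustrative, not from the problem:

```python
import numpy as np

# Illustrative sweep axes (hypothetical values, not from the problem statement).
batch_sizes = np.array([1, 8, 32], dtype=np.int64)
context_lengths = np.array([2048, 8192, 32768, 131072], dtype=np.int64)

# indexing="ij": rows follow batch size, columns follow context length.
BB, TT = np.meshgrid(batch_sizes, context_lengths, indexing="ij")  # each (3, 4)

# Flatten into the paired 1-D inputs: config i is (B[i], T[i]).
B = BB.ravel()  # shape (12,)
T = TT.ravel()  # shape (12,)

# Any per-config result of shape (N,) folds back into the grid for plotting:
# heat = result.reshape(BB.shape)
```

Because ravel is row-major, config 4 here is (B=8, T=2048): row 1, column 0 of the grid.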
Formula (per config):
bytes[i] = 2 * B[i] * T[i] * n_layers * d_kv * dtype_bytes
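where the leading 2 accounts for both the K and V tensors, d_kv = n_kv_heads * head_dim is the length of one token's K (or V) vector in a single layer, and dtype_bytes is the storage size of one element (e.g. 2 for fp16/bf16).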
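As a worked check of the formula, take a hypothetical configuration: 8 KV heads of dimension 128 (so d_kv = 1024), 32 layers, fp16 storage (dtype_bytes = 2), and a single sequence of 4096 tokens. These numbers are assumptions chosen for illustration:

```python
# Hypothetical single config (illustrative numbers, not from the problem).
n_kv_heads, head_dim = 8, 128
d_kv = n_kv_heads * head_dim        # 1024 elements per token per layer, for K (and again for V)
n_layers, dtype_bytes = 32, 2       # fp16 -> 2 bytes per element
B, T = 1, 4096

kv_bytes = 2 * B * T * n_layers * d_kv * dtype_bytes
print(kv_bytes, kv_bytes / 2**20)   # 536870912 512.0  (i.e. 512 MiB)
```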
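The recipe: vectorize the formula directly. Since B and T are arrays and the rest are scalars, the expression broadcasts naturally. Cast the arrays to int64 before multiplying: if B or T arrive as int32, the product can overflow partway through, and a trailing .astype(np.int64) would only convert a value that has already wrapped around.
B, T = np.asarray(B, dtype=np.int64), np.asarray(T, dtype=np.int64)
return 2 * B * T * n_layers * d_kv * dtype_bytes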
No loop. No np.vectorize wrapper. Just multiply.
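A complete sketch of the vectorized function, assuming integer inputs. It casts B and T to int64 before multiplying, because int32 arrays (the default integer type on some platforms) could otherwise overflow before any final cast; the usage below deliberately passes int32 inputs whose second config needs 2^35 bytes. The final .astype(np.int64) only guarantees the promised dtype. The result is checked against a plain Python loop, whose integers cannot overflow:

```python
import numpy as np

def kv_cache_size_sweep(B, T, n_layers, d_kv, dtype_bytes):
    """KV-cache bytes for each paired (B[i], T[i]) config, as an int64 array of shape (N,)."""
    # Cast before multiplying so int32 inputs cannot overflow mid-product.
    B = np.asarray(B, dtype=np.int64)
    T = np.asarray(T, dtype=np.int64)
    # Scalars broadcast against the (N,) arrays; the leading 2 covers K and V.
    return (2 * B * T * n_layers * d_kv * dtype_bytes).astype(np.int64)

# Usage with int32 inputs (illustrative config: 32 layers, d_kv=1024, fp16).
B = np.array([1, 8, 32], dtype=np.int32)
T = np.array([4096, 32768, 131072], dtype=np.int32)
sizes = kv_cache_size_sweep(B, T, n_layers=32, d_kv=1024, dtype_bytes=2)

# Reference: the same formula evaluated with Python ints (arbitrary precision).
expected = [2 * b * t * 32 * 1024 * 2 for b, t in zip(B.tolist(), T.tolist())]
print(sizes.tolist())  # [536870912, 34359738368, 549755813888]
```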
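import numpy as np

def kv_cache_size_sweep(B, T, n_layers, d_kv, dtype_bytes):
    # TODO: return an int64 array of shape (N,), one byte count per (B[i], T[i]) pair.
    pass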