TorchedUp
LearnBetaProblemsSystem DesignSoonPremium

226. KV Cache Size Sweep

Medium

The scalar version of this problem (kv-cache-size) gives one number per call. The production version is a sweep over many (B, T) configs at once — you usually want a heatmap of "KV bytes" across batch sizes and context lengths.

Implement: def kv_cache_size_sweep(B, T, n_layers, d_kv, dtype_bytes) where:

  • B is a 1-D array of shape (N,) — batch sizes per config.
  • T is a 1-D array of shape (N,) — sequence lengths per config (paired with B elementwise).
  • n_layers, d_kv, dtype_bytes are scalars (assumed constant across the sweep).

Return shape (N,) of int64 byte counts.

Formula (per config):

bytes[i] = 2 * B[i] * T[i] * n_layers * d_kv * dtype_bytes

where the leading 2 accounts for both K and V tensors, and d_kv = n_kv_heads * head_dim is the per-token KV vector size.
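As a quick sanity check of the formula, here is a single config worked through in plain Python. The model dimensions (32 layers, 8 KV heads, head_dim of 128, 2-byte dtype) are illustrative assumptions, not values given by the problem.

```python
# Illustrative (assumed) model dimensions, not taken from the problem statement.
n_layers = 32
n_kv_heads = 8
head_dim = 128
dtype_bytes = 2  # e.g. fp16 or bf16

d_kv = n_kv_heads * head_dim  # per-token KV vector size: 1024

B, T = 1, 4096  # one sequence of 4096 tokens
kv_bytes = 2 * B * T * n_layers * d_kv * dtype_bytes

print(kv_bytes)          # 536870912
print(kv_bytes / 2**20)  # 512.0 (MiB)
```

Under these assumed dimensions, one 4096-token sequence needs 512 MiB of KV cache, and the cost scales linearly in both B and T.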

The recipe: vectorize the formula directly. Since B and T are arrays and the rest are scalars, the expression broadcasts naturally:

return (2 * B * T * n_layers * d_kv * dtype_bytes).astype(np.int64)

No loop. No np.vectorize wrapper. Just multiply.
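A minimal sketch of the recipe, with one defensive variation: B and T are cast to int64 before multiplying rather than after, because if they arrive as int32 arrays the product can wrap around before a final astype ever runs. The grid values and model dimensions below are hypothetical, chosen only to show how paired (N,) inputs can be built from a batch-size by context-length grid and reshaped back for a heatmap.

```python
import numpy as np


def kv_cache_size_sweep(B, T, n_layers, d_kv, dtype_bytes):
    # Cast first so the products are computed in int64, even for int32 inputs.
    B = np.asarray(B, dtype=np.int64)
    T = np.asarray(T, dtype=np.int64)
    return 2 * B * T * n_layers * d_kv * dtype_bytes


# Hypothetical sweep grid: 3 batch sizes x 4 context lengths.
batch_sizes = np.array([1, 8, 32])
context_lens = np.array([1024, 4096, 16384, 65536])

# Flatten the grid into paired (N,) arrays, as the function expects.
B_grid, T_grid = np.meshgrid(batch_sizes, context_lens, indexing="ij")
kv_bytes = kv_cache_size_sweep(B_grid.ravel(), T_grid.ravel(),
                               n_layers=32, d_kv=1024, dtype_bytes=2)

# Reshape back to (len(batch_sizes), len(context_lens)) for a heatmap, in GiB.
heatmap_gib = kv_bytes.reshape(B_grid.shape) / 2**30
print(heatmap_gib)
```

Each row of heatmap_gib corresponds to one batch size and each column to one context length, so the array can be handed directly to a plotting routine.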

Math

$$\text{bytes}_i = 2 \cdot B_i \cdot T_i \cdot L \cdot d_{kv} \cdot b$$

NumPy

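import numpy as np

def kv_cache_size_sweep(B, T, n_layers, d_kv, dtype_bytes):
    # Starter stub: return an int64 array of shape (N,) holding the KV-cache bytes per config.
    pass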
