Implement the pre-allocated KV cache that real LLM inference engines
(vLLM, SGLang, TensorRT-LLM) use. The naive np.concatenate approach
allocates a new array and copies the entire cache on every decode step.
For a 4096-token generation that is thousands of allocations, plus copy
work that grows quadratically with sequence length: a likely TLE here and
unacceptable in production. The production pattern is: pre-allocate the
full (n_layers, max_seq_len, d_model) buffer at KVCache(...) time, track a
per-layer write position, and slice-write each new entry into the next
free slot.
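To make the contrast concrete, here is a minimal sketch of the two approaches for a single layer. The RowBuffer class and its attribute names are hypothetical and are not taken from engine/model.py; the sketch only illustrates the allocate-once, slice-write idea.

```python
import numpy as np


class RowBuffer:
    """Hypothetical single-layer buffer: allocate once, then slice-write."""

    def __init__(self, max_rows, d_model, dtype=np.float32):
        self.buf = np.zeros((max_rows, d_model), dtype=dtype)  # one allocation
        self.pos = 0  # index of the next free row

    def append(self, row):
        self.buf[self.pos] = row  # in-place write, no reallocation
        self.pos += 1

    def view(self):
        return self.buf[:self.pos]  # populated rows only, as a view


rng = np.random.default_rng(0)
rows = rng.standard_normal((8, 4)).astype(np.float32)

# Naive approach: every step allocates a new, larger array and copies.
naive = np.empty((0, 4), dtype=np.float32)
for r in rows:
    naive = np.concatenate([naive, r[None, :]], axis=0)

# Pre-allocated approach: same contents, a single allocation up front.
buf = RowBuffer(max_rows=8, d_model=4)
for r in rows:
    buf.append(r)

same_contents = np.array_equal(naive, buf.view())
```

Both loops end with identical contents; the difference is that the second one never allocates after construction.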
This is a Build-in-Context problem. A working Model (with realistic
prefill and decode_step signatures) is sitting in engine/model.py —
you can read it to see exactly how your KVCache.append is called. The
test harness invokes Model.generate(prompt, max_new_tokens) and compares
the output token sequence to the reference; if your cache is correct, the
sequence matches deterministically.
The contract (read engine/model.py for the calling code):

- KVCache(n_layers, d_model, max_seq_len, dtype=np.float32) pre-allocates.
  Each layer's buffer must have max_seq_len as its outer dim, sized to
  len(prompt) + max_new_tokens. The __init__ is already written; you
  fill in append.
- append(layer_idx, k, v) is called two ways:
  - prefill: k and v are batched with shape (prompt_len, d_model),
    and one call covers the whole prompt for that layer.
  - decode_step: k and v are single tokens with shape (d_model,),
    and one call adds exactly one row.
  Your append must handle both.
- The read side returns (K_full, V_full) covering only the populated
  rows: a slice/view of the pre-allocated buffer, not the full buffer
  with trailing zeros.

Why this matters in production
A real attention call reads the cache: softmax(Q @ K_full.T / sqrt(d)) @ V_full.
If K_full is the entire pre-allocated buffer (with zeros after the live
data), the attention weights spread across non-existent past tokens and
the output is garbage. If your write index is wrong, you overwrite past
K/V and history is corrupted. If your dtype handling is wrong, fp16
caches silently upcast to fp32 and your memory budget is double what it
should be.
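The contract and the failure modes above can be sketched together. This is an illustrative sketch under stated assumptions, not the reference solution: the class SketchKVCache, its attributes (k, v, pos), the read method name get, and the attend helper are all hypothetical, and the real __init__ in the problem may store things differently.

```python
import numpy as np


class SketchKVCache:
    """Hypothetical cache; attribute and method names are illustrative."""

    def __init__(self, n_layers, d_model, max_seq_len, dtype=np.float32):
        shape = (n_layers, max_seq_len, d_model)
        self.k = np.zeros(shape, dtype=dtype)
        self.v = np.zeros(shape, dtype=dtype)
        self.pos = [0] * n_layers  # per-layer write position

    def append(self, layer_idx, k, v):
        # Promote a (d_model,) decode row to (1, d_model) so both call
        # shapes share one code path.
        k2, v2 = np.atleast_2d(k), np.atleast_2d(v)
        start = self.pos[layer_idx]
        end = start + k2.shape[0]
        if end > self.k.shape[1]:
            raise ValueError("cache overflow: max_seq_len exceeded")
        # Slice assignment casts into the buffer's dtype; the buffer
        # itself is never reallocated or upcast.
        self.k[layer_idx, start:end] = k2
        self.v[layer_idx, start:end] = v2
        self.pos[layer_idx] = end

    def get(self, layer_idx):  # hypothetical name for the read method
        n = self.pos[layer_idx]
        return self.k[layer_idx, :n], self.v[layer_idx, :n]  # views


def attend(q, K, V):
    """Single-query softmax attention over the rows of K and V."""
    scores = K @ q / np.sqrt(q.shape[0])
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V


rng = np.random.default_rng(0)
cache = SketchKVCache(n_layers=1, d_model=4, max_seq_len=8, dtype=np.float16)
cache.append(0, rng.standard_normal((3, 4)), rng.standard_normal((3, 4)))  # prefill
cache.append(0, rng.standard_normal(4), rng.standard_normal(4))  # decode step

K_full, V_full = cache.get(0)
q = rng.standard_normal(4)
good = attend(q, K_full.astype(np.float32), V_full.astype(np.float32))
# Passing the whole buffer lets the four empty rows take softmax weight.
bad = attend(q, cache.k[0].astype(np.float32), cache.v[0].astype(np.float32))
```

Here K_full has four rows and stays float16, while bad differs from good because the zero rows still receive nonzero attention weight. Using np.atleast_2d is one way to unify the two call shapes; an explicit ndim check would work equally well.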