TorchedUp

270. KV Cache in a Decoder Loop

Medium

Implement the pre-allocated KV cache that real LLM inference engines (vLLM, SGLang, TensorRT-LLM) use. The naive np.concatenate approach reallocates the cache on every decode step and copies the entire existing history into the new array. For a 4096-token generation that means thousands of allocations per layer, and the total copy work grows quadratically with sequence length. That is easily TLE territory here and a real throughput cost in production. The production pattern is to pre-allocate the full (n_layers, max_seq_len, d_model) buffer when KVCache(...) is constructed, track a per-layer write position, and slice-write each new entry into the next free rows.
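To make the cost difference concrete, here is a small sketch comparing the two strategies for a single layer's K tensor. The helper names naive_append and preallocated_fill are hypothetical and are not part of engine/model.py. Each np.concatenate call copies every existing row, so T appends copy T(T-1)/2 rows in total. The pre-allocated buffer instead writes each row exactly once, in place.

```python
import numpy as np


def naive_append(cache, row):
    """Grow the cache by one row with np.concatenate (allocates and copies)."""
    if cache is None:
        return row[None, :].copy()
    # A brand-new (t + 1, d) array is allocated and all t old rows are copied.
    return np.concatenate([cache, row[None, :]], axis=0)


def preallocated_fill(rows, max_seq_len):
    """Write rows into a buffer allocated once up front."""
    buf = np.zeros((max_seq_len, rows.shape[1]), dtype=rows.dtype)
    pos = 0
    for row in rows:
        buf[pos] = row  # in-place write into the next free slot
        pos += 1
    return buf[:pos]  # view over the populated rows only


T, d = 64, 8
rows = np.random.default_rng(0).standard_normal((T, d)).astype(np.float32)

cache, rows_copied = None, 0
for row in rows:
    if cache is not None:
        rows_copied += cache.shape[0]  # rows np.concatenate has to copy
    cache = naive_append(cache, row)

print(rows_copied)  # 2016, i.e. T * (T - 1) // 2
```

Both strategies end with identical contents. Only the amount of allocation and copying differs.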

This is a Build-in-Context problem. A working Model (with realistic prefill and decode_step signatures) is sitting in engine/model.py — you can read it to see exactly how your KVCache.append is called. The test harness invokes Model.generate(prompt, max_new_tokens) and compares the output token sequence to the reference; if your cache is correct, the sequence matches deterministically.

The contract (read engine/model.py for the calling code):

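  • KVCache(n_layers, d_model, max_seq_len, dtype=np.float32): pre-allocate. Note that the argument order differs from the buffer shape. The sequence axis of the (n_layers, max_seq_len, d_model) buffer has length max_seq_len, which is sized to len(prompt) + max_new_tokens. The __init__ is already written; you fill in append.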
  • append(layer_idx, k, v) is called two ways:
    • From prefill: k and v are batched with shape (prompt_len, d_model) and one call covers the whole prompt for that layer.
    • From decode_step: k and v are single tokens with shape (d_model,) and one call adds exactly one row. Your append must handle both.
  • append returns (K_full, V_full) for that layer, covering only the populated rows. This should be a slice (view) of the pre-allocated buffer, not the full buffer with trailing zeros.
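One way append could satisfy this contract is sketched below. It is a minimal sketch that assumes the already-written __init__ stores the buffers as self.k and self.v and keeps the write positions in a per-layer list called self.pos. Those attribute names are hypothetical, so match whatever the provided __init__ and engine/model.py actually use.

```python
import numpy as np


class KVCache:
    def __init__(self, n_layers, d_model, max_seq_len, dtype=np.float32):
        # Hypothetical attribute names; the provided __init__ may differ.
        self.max_seq_len = max_seq_len
        self.k = np.zeros((n_layers, max_seq_len, d_model), dtype=dtype)
        self.v = np.zeros((n_layers, max_seq_len, d_model), dtype=dtype)
        self.pos = [0] * n_layers  # next free row, tracked per layer

    def append(self, layer_idx, k, v):
        k = np.asarray(k)
        v = np.asarray(v)
        # decode_step passes one (d_model,) row; promote it to (1, d_model)
        # so prefill and decode share the same slice-write path.
        if k.ndim == 1:
            k, v = k[None, :], v[None, :]

        start = self.pos[layer_idx]
        end = start + k.shape[0]
        if end > self.max_seq_len:
            raise ValueError(
                f"layer {layer_idx}: writing rows {start}:{end} "
                f"exceeds max_seq_len={self.max_seq_len}"
            )

        # Slice assignment casts incoming values to the buffer's dtype,
        # so an fp16 cache stays fp16 even if k/v arrive as fp32.
        self.k[layer_idx, start:end] = k
        self.v[layer_idx, start:end] = v
        self.pos[layer_idx] = end

        # Basic slicing returns views: populated rows only, no copy.
        return self.k[layer_idx, :end], self.v[layer_idx, :end]


# Usage: one prefill call for a 5-token prompt, then one decode step.
cache = KVCache(n_layers=2, d_model=4, max_seq_len=8, dtype=np.float16)
K, V = cache.append(0, np.ones((5, 4), np.float32), np.ones((5, 4), np.float32))
K, V = cache.append(0, np.full(4, 3.0), np.full(4, 4.0))
print(K.shape, K.dtype)  # (6, 4) float16
```

Promoting the 1-D decode input to a single-row 2-D array keeps one write path for both call sites. The overflow check runs before any write, so a bad call leaves the cache untouched.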

Why this matters in production

A real attention call reads the cache: softmax(Q @ K_full.T / sqrt(d)) @ V_full. If K_full is the entire pre-allocated buffer (with zeros after the live data), each all-zero key row scores exactly 0 rather than -inf, so it still receives softmax weight. The attention weights spread across non-existent past tokens and the output is garbage. If your write index is wrong, you overwrite past K/V and corrupt the history. If your dtype handling is wrong, fp16 caches silently upcast to fp32 and your memory budget doubles.
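The first and last failure modes are easy to reproduce. The sketch below uses a hypothetical single-query attend helper, which is not the engine's attention code. It compares attention over the populated slice with attention over the whole zero-padded buffer. It then shows that np.concatenate promotes an fp16 cache to fp32, while slice assignment into an fp16 buffer does not.

```python
import numpy as np


def attend(q, K, V):
    """Single-query scaled dot-product attention: softmax(K q / sqrt(d)) @ V."""
    scores = K @ q / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max())  # subtract max for numerical stability
    w /= w.sum()
    return w @ V, w


rng = np.random.default_rng(0)
d, live, max_seq_len = 8, 3, 10
K_buf = np.zeros((max_seq_len, d), np.float32)
V_buf = np.zeros((max_seq_len, d), np.float32)
K_buf[:live] = rng.standard_normal((live, d))
V_buf[:live] = rng.standard_normal((live, d))
q = rng.standard_normal(d).astype(np.float32)

out_ok, w_ok = attend(q, K_buf[:live], V_buf[:live])  # populated rows only
out_bad, w_bad = attend(q, K_buf, V_buf)  # also includes 7 all-zero rows

# Zero key rows score exactly 0, so each still gets exp(0 - max) > 0 weight.
print(w_bad[live:].sum() > 0)        # True
print(np.allclose(out_ok, out_bad))  # False

# dtype: concatenating an fp32 row onto an fp16 cache promotes the result...
cache16 = np.zeros((4, d), np.float16)
grown = np.concatenate([cache16, np.ones((1, d), np.float32)])
print(grown.dtype)  # float32
# ...whereas writing into a pre-allocated fp16 buffer keeps fp16.
cache16[0] = np.ones(d, np.float32)
print(cache16.dtype)  # float16
```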

Math

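$$K^{(\ell)}[\mathrm{pos}:\mathrm{pos}+n,\ :] \leftarrow k_{\mathrm{new}}, \qquad V^{(\ell)}[\mathrm{pos}:\mathrm{pos}+n,\ :] \leftarrow v_{\mathrm{new}}$$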
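Read as a schedule, the update applies the same slice-write twice over: once with n = prompt_len during prefill, then with n = 1 on each decode step, advancing pos by n each time. The trace below is a sketch under one assumption: every newly generated token is fed back through decode_step exactly once. Under that assumption, the write ranges are contiguous and non-overlapping, and pos never exceeds len(prompt) + max_new_tokens. The helper write_ranges is hypothetical.

```python
def write_ranges(prompt_len, max_new_tokens):
    """Return the (start, end) rows written per append call for one layer."""
    pos, ranges = 0, []
    # One prefill call covering the whole prompt, then one row per decode step.
    for n in [prompt_len] + [1] * max_new_tokens:
        ranges.append((pos, pos + n))
        pos += n
    return ranges, pos


ranges, final_pos = write_ranges(prompt_len=5, max_new_tokens=3)
print(ranges)     # [(0, 5), (5, 6), (6, 7), (7, 8)]
print(final_pos)  # 8, i.e. 5 + 3
```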

Related problems

  • KV Cache: medium, NumPy
  • Paged Attention (vLLM): hard, NumPy
  • KV Cache Quantization (INT8): medium, NumPy
  • Prefix Caching (Prompt KV Reuse): hard, NumPy


import numpy as np


def run_generate(...):
    pass
