During autoregressive generation, a language model generates one token at a time. Without caching, we'd recompute the keys and values for every past token at each step — O(N²) total compute.
The KV cache stores the keys and values from all past tokens. For each new token, we compute its key, value, and query; append the key and value to the cache; and attend the query over all cached keys and values.
Implement kv_cache_attention: given the existing key-value cache (past_keys, past_values) and a new token's (new_key, new_value, query), append the new key/value to the cache and return the attention output for this token.
Signature: def kv_cache_attention(past_keys, past_values, new_key, new_value, query)
- past_keys: (n_past, d_k)
- past_values: (n_past, d_v)
- new_key: (1, d_k), the current token's key
- new_value: (1, d_v), the current token's value
- query: (1, d_k), the current token's query

Return an output of shape (1, d_v). Scale scores by sqrt(d_k) and use standard softmax attention.
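A minimal NumPy sketch of the required function, assuming NumPy is available (the helper uses a numerically stable softmax; variable names beyond the given signature are my own):

```python
import numpy as np

def kv_cache_attention(past_keys, past_values, new_key, new_value, query):
    # Append the new token's key and value to the cache.
    keys = np.concatenate([past_keys, new_key], axis=0)        # (n_past + 1, d_k)
    values = np.concatenate([past_values, new_value], axis=0)  # (n_past + 1, d_v)

    d_k = keys.shape[1]
    # Scaled dot-product scores of the single query against all cached keys.
    scores = query @ keys.T / np.sqrt(d_k)                     # (1, n_past + 1)
    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # Weighted sum of cached values.
    return weights @ values                                    # (1, d_v)
```

Note the per-step cost is O(n_past): only the new token's key and value are computed, and the cache grows by one row each call.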