LLM inference is bottlenecked by memory bandwidth, not compute — the GPU stalls waiting for model weights to be loaded for each token. Speculative decoding exploits this: use a tiny draft model to propose K tokens cheaply, then verify all K in a single forward pass of the large target model.
How it works:
1. The draft model proposes K tokens t_1, …, t_K, each with draft probability q(t_i).
2. The target model computes its full distributions p_i(·) at all K positions in a single forward pass.
3. For each i in order: if p_i(t_i) ≥ q(t_i), accept token t_i; otherwise accept with probability p_i(t_i)/q(t_i), i.e. reject with probability 1 - p_i(t_i)/q(t_i), and stop at the first rejection.
4. On rejection at position i, sample the next token from the corrected distribution max(0, p_i - q_i), renormalized.

The result: on average more than 1 token per large-model forward pass, with an output distribution identical to sampling from the large model alone.
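As a concrete reference for the rule above, here is a minimal NumPy sketch of the stochastic accept/reject step. It assumes the full draft distributions are available as a (K, vocab_size) array (the exercise below only supplies the scalar q(t_i) per drafted token), and every name here is illustrative rather than part of the exercise:

```python
import numpy as np

def stochastic_verify(draft_tokens, q_dists, p_dists, rng):
    """Stochastic speculative-sampling verification (textbook form).

    draft_tokens: (K,) drafted token ids
    q_dists:      (K, vocab_size) full draft-model distributions
    p_dists:      (K, vocab_size) full target-model distributions
    Returns (accepted_prefix, resampled_token_or_None).
    """
    K = len(draft_tokens)
    for i in range(K):
        t = draft_tokens[i]
        # Accept t_i with probability min(1, p_i(t_i) / q_i(t_i)).
        if rng.random() < min(1.0, p_dists[i, t] / q_dists[i, t]):
            continue
        # First rejection: resample from max(0, p_i - q_i), renormalized.
        residual = np.maximum(p_dists[i] - q_dists[i], 0.0)
        residual /= residual.sum()
        return draft_tokens[:i], rng.choice(len(residual), p=residual)
    # All K tokens accepted; the full algorithm would additionally sample
    # token K+1 from the target distribution at the next position.
    return draft_tokens, None
```

The acceptance test plus the residual resampling is what makes the procedure distribution-preserving: at each position the accepted mass and the residual mass sum back to exactly p_i.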
Implement the deterministic threshold version, which accepts t_i only when p_i(t_i) ≥ q(t_i). (This is for testability; in practice you'd draw a uniform random number and accept with probability min(1, p_i(t_i)/q(t_i)).)
Signature: def speculative_decode_step(draft_tokens, draft_probs, target_probs)
Parameters:

- draft_tokens: (K,) int array — proposed token indices
- draft_probs: (K,) — draft model probability for each chosen token
- target_probs: (K, vocab_size) — target model full distributions at each position

Returns a tuple (accepted_count, accepted_tokens, correction_prob):
- accepted_count: int — how many tokens were accepted (stop at first rejection)
- accepted_tokens: (accepted_count,) int array
- correction_prob: (vocab_size,) — corrected distribution for the next token
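A minimal sketch of the deterministic variant under the signature above. Two points are assumptions, flagged in the comments: the true correction max(0, p_i - q_i) needs the full draft distribution, so with only the scalar q(t_i) this sketch subtracts draft mass at the drafted token alone; and when every token is accepted, the signature provides no distribution for position K+1, so the last target row is returned as a placeholder.

```python
import numpy as np

def speculative_decode_step(draft_tokens, draft_probs, target_probs):
    """Deterministic threshold verification: accept t_i iff
    p_i(t_i) >= q(t_i), stopping at the first rejection."""
    accepted = []
    for i, t in enumerate(draft_tokens):
        if target_probs[i, t] >= draft_probs[i]:
            accepted.append(t)
            continue
        # First rejection: build the corrected distribution for the next
        # token. ASSUMPTION: only q(t_i) is available, so draft mass is
        # removed at index t alone (the full residual would use all of q_i).
        correction = target_probs[i].astype(float)
        correction[t] = max(0.0, correction[t] - draft_probs[i])
        correction /= correction.sum()
        return len(accepted), np.array(accepted, dtype=int), correction
    # ASSUMPTION: all K accepted; no distribution for position K+1 is
    # given, so fall back to the last target row as a placeholder.
    return (len(accepted), np.array(accepted, dtype=int),
            np.asarray(target_probs[-1], dtype=float))
```

For instance, with K = 2 and a 3-token vocabulary:

```python
tokens = np.array([2, 0])
q = np.array([0.5, 0.9])
p = np.array([[0.2, 0.2, 0.6],   # p_1(t_1=2) = 0.6 >= 0.5: accept
              [0.3, 0.5, 0.2]])  # p_2(t_2=0) = 0.3 <  0.9: reject
count, toks, corr = speculative_decode_step(tokens, q, p)
# count == 1, toks == [2]; corr == [0, 5/7, 2/7] under the assumption above
```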