
Speculative Decoding

LLM inference is bottlenecked by memory bandwidth, not compute — the GPU stalls waiting for model weights to be loaded for each token. Speculative decoding exploits this: use a tiny draft model to propose K tokens cheaply, then verify all K in a single forward pass of the large target model.

How it works:

  1. Draft model autoregressively proposes K tokens: t_1, …, t_K with probabilities q(t_i)
  2. Run the large model in parallel over all K positions, getting distributions p_i(·)
  3. Verify each token left-to-right:
    • If p_i(t_i) ≥ q(t_i): accept token t_i
    • Else: accept with probability p_i(t_i)/q(t_i) (i.e., reject with probability 1 - p_i(t_i)/q(t_i)); stop verifying at the first rejection
  4. If a token is rejected at position i: sample the next token from the corrected distribution max(0, p_i - q_i), renormalized (see the sketch after this list)
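
A minimal sketch of this verification loop, assuming the full draft distributions q_i(·) are available at each position (the function name, shapes, and the rng argument here are illustrative, not part of the problem's API):

```python
import numpy as np

def verify_draft(draft_tokens, draft_dists, target_dists, rng):
    """Stochastic accept/reject over K drafted tokens (sketch).

    draft_tokens: (K,)   proposed token ids
    draft_dists:  (K, V) draft distributions q_i
    target_dists: (K, V) target distributions p_i
    """
    accepted = []
    for i, t in enumerate(draft_tokens):
        p, q = target_dists[i, t], draft_dists[i, t]
        # Accept with probability min(1, p/q); q > 0 because the draft
        # model actually sampled token t at this position.
        if rng.random() < min(1.0, p / q):
            accepted.append(int(t))
            continue
        # First rejection: sample the next token from the corrected
        # distribution max(0, p_i - q_i), renormalized.
        residual = np.maximum(target_dists[i] - draft_dists[i], 0.0)
        residual /= residual.sum()
        return accepted, rng.choice(len(residual), p=residual)
    return accepted, None  # all K tokens accepted

```

With the deterministic threshold, the rng.random() draw is replaced by accepting exactly when p ≥ q, which is what the exercise below asks for.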

The result: on average more than 1 token per large-model forward pass, with identical output distribution to sampling from the large model alone.
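
As a rough illustration of that claim (the acceptance rate alpha below is a made-up number, not from the problem): if each drafted token is accepted independently with probability α, the expected number of drafted tokens accepted before the first rejection, capped at K, is α + α² + … + α^K.

```python
# Hypothetical numbers for illustration only.
alpha, K = 0.8, 4   # per-token acceptance rate, draft length
expected_accepted = sum(alpha**i for i in range(1, K + 1))
print(expected_accepted)  # ≈ 2.36 accepted draft tokens per target pass
# Plus the one token sampled afterward (corrected or regular), so well
# above 1 generated token per large-model forward pass.
```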

Implement the deterministic threshold version (for testability — in practice you'd draw a random number):

Signature: def speculative_decode_step(draft_tokens, draft_probs, target_probs)

  • draft_tokens: (K,) int array — proposed token indices
  • draft_probs: (K,) — draft model probability for each chosen token
  • target_probs: (K, vocab_size) — target model full distributions at each position
  • Returns: (accepted_count, accepted_tokens, correction_prob)
    • accepted_count: int — how many tokens were accepted (stop at first rejection)
    • accepted_tokens: (accepted_count,) int array
    • correction_prob: (vocab_size,) — corrected distribution for the next token
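
A minimal skeleton under the deterministic threshold rule (accept exactly when p_i(t_i) ≥ q(t_i)). Because only the scalar draft probability of each chosen token is provided, this sketch assumes the correction subtracts that scalar at the rejected token's index before clipping and renormalizing, and returns None for correction_prob when all K tokens are accepted; the conventions the hidden tests expect may differ.

```python
import numpy as np

def speculative_decode_step(draft_tokens, draft_probs, target_probs):
    """Deterministic-threshold speculative decoding step (sketch).

    draft_tokens: (K,) int   proposed token ids
    draft_probs:  (K,) float draft probability of each chosen token
    target_probs: (K, V)     target distributions at each position
    """
    K = len(draft_tokens)
    for i in range(K):
        t = draft_tokens[i]
        p, q = target_probs[i, t], draft_probs[i]
        if p >= q:
            continue  # accepted deterministically, move to the next position
        # First rejection at position i: build the corrected distribution.
        # Assumption: remove the draft mass q at the rejected token only,
        # clip at zero, and renormalize (the full draft distribution q_i(.)
        # is not available from the inputs).
        corrected = target_probs[i].astype(float)
        corrected[t] = max(corrected[t] - q, 0.0)
        corrected /= corrected.sum()
        return i, draft_tokens[:i], corrected
    # No rejection: all K tokens accepted, no corrected distribution to return.
    return K, draft_tokens[:K], None
```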

Math


Python (numpy)

Test Results

  • all 3 tokens accepted
  • first token rejected (p < q)
  • 1 accepted then rejected