Speculative Decoding: Medusa Variant

Difficulty: Hard

Implement the verification step of Medusa-style speculative decoding: multiple draft heads predict the next K tokens simultaneously, and a single verification pass by the target model accepts the longest valid prefix.

Signature: def medusa_verify(draft_tokens, target_probs, draft_probs, temperature=1.0)

  • draft_tokens: list of K candidate token IDs (from draft heads)
  • target_probs: (K, vocab_size) — target model probabilities for each position
  • draft_probs: (K, vocab_size) — draft head probabilities for each position
  • temperature: sampling temperature (applied to draft probs for acceptance)
  • Returns: (accepted_tokens, n_accepted)
    • accepted_tokens: list of accepted token IDs (may be fewer than K)
    • n_accepted: number of accepted tokens (int)
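To make the argument shapes concrete, here is a small sketch with K = 3 and vocab_size = 4. All numbers are invented for illustration; only the layout of the three arguments comes from the signature above. It also shows one way to read off the probability each model assigns to the drafted token at every position.

```python
import numpy as np

# Hypothetical inputs: K = 3 draft positions, vocab_size = 4.
draft_tokens = [2, 0, 3]

# Each row is a probability distribution over the vocabulary.
target_probs = np.array([
    [0.1, 0.1, 0.7, 0.1],
    [0.6, 0.2, 0.1, 0.1],
    [0.3, 0.3, 0.3, 0.1],
])
draft_probs = np.array([
    [0.2, 0.2, 0.5, 0.1],
    [0.5, 0.3, 0.1, 0.1],
    [0.1, 0.1, 0.1, 0.7],
])

# Probability each model assigns to the drafted token at each position.
positions = np.arange(len(draft_tokens))
p_target = target_probs[positions, draft_tokens]  # [0.7, 0.6, 0.1]
p_draft = draft_probs[positions, draft_tokens]    # [0.5, 0.5, 0.7]
```

Row i of each matrix describes position i, so the pair (i, draft_tokens[i]) picks out exactly one entry per row.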

Speculative Decoding Acceptance Criterion

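The verifier walks positions i = 0, ..., K-1 in order and stops at the first rejection:

import random

accepted = []
for i, token in enumerate(draft_tokens):
    p_target = target_probs[i, token]
    p_draft = draft_probs[i, token]  # assumes p_draft > 0

    # Accept with probability min(1, p_target / p_draft)
    if random.random() < min(1.0, p_target / p_draft):
        accepted.append(token)  # accept token i, continue to i+1
    else:
        break  # reject, stop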

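Deterministic version (for testing): accept token i if target_probs[i, token_i] >= draft_probs[i, token_i], otherwise reject and stop. This is exactly the case in which the stochastic rule accepts with probability 1, because the ratio p_target / p_draft is at least 1.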

This problem uses the deterministic version (no randomness) to enable reproducible test cases.
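The deterministic rule can be sketched as a short loop. This is one possible implementation, not a reference solution. It assumes that temperature plays no role in the deterministic comparison and keeps the parameter only so the signature matches. The problem statement does not spell out how temperature should be applied, so treat that as an assumption.

```python
import numpy as np

def medusa_verify(draft_tokens, target_probs, draft_probs, temperature=1.0):
    """Deterministic verification: accept the longest prefix in which the
    target probability of each drafted token is at least its draft probability."""
    # Assumption: temperature is accepted for signature compatibility only
    # and is ignored by the deterministic rule.
    target_probs = np.asarray(target_probs)
    draft_probs = np.asarray(draft_probs)

    accepted_tokens = []
    for i, token in enumerate(draft_tokens):
        if target_probs[i, token] >= draft_probs[i, token]:
            accepted_tokens.append(int(token))
        else:
            break  # the first rejection ends the accepted prefix
    return accepted_tokens, len(accepted_tokens)

# Tiny illustrative input: K = 2 positions, vocab_size = 2.
target = np.array([[0.1, 0.9], [0.8, 0.2]])
draft = np.array([[0.5, 0.5], [0.5, 0.5]])
print(medusa_verify([1, 1], target, draft))  # ([1], 1)
```

In the usage line, the first drafted token passes (0.9 >= 0.5) and the second fails (0.2 < 0.5), so only the first token is returned.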


Test Cases

  • All 3 tokens accepted (target probability >= draft probability at every position)
  • Second token rejected, so only the first token is accepted
  • First token rejected immediately, so 0 tokens are accepted
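The three scenarios above can be reproduced with a vectorized check. The helper name accepted_prefix_length and all probabilities below are hypothetical and chosen only for illustration. The actual hidden test inputs are not known. The sketch computes the per-position comparison in one shot and counts how many positions pass before the first failure.

```python
import numpy as np

def accepted_prefix_length(draft_tokens, target_probs, draft_probs):
    """Hypothetical helper: length of the longest accepted prefix
    under the deterministic rule."""
    tokens = np.asarray(draft_tokens, dtype=int)
    idx = np.arange(len(tokens))
    ok = np.asarray(target_probs)[idx, tokens] >= np.asarray(draft_probs)[idx, tokens]
    # cumprod stays 1 until the first False, then is 0 from there on,
    # so its sum equals the number of leading True values.
    return int(np.cumprod(ok).sum())

# Illustrative data: K = 3 positions, vocab_size = 2.
target = np.array([[0.4, 0.6]] * 3)
draft = np.full((3, 2), 0.5)

print(accepted_prefix_length([1, 1, 1], target, draft))  # 3: all accepted
print(accepted_prefix_length([1, 0, 1], target, draft))  # 1: second token rejected
print(accepted_prefix_length([0, 1, 1], target, draft))  # 0: first token rejected
```

The accepted tokens are then the first n entries of draft_tokens, where n is the returned length.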