Implement Medusa speculative decoding: multiple draft heads predict the next K tokens simultaneously, and a single verification pass accepts the longest valid prefix.
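The NumPy sketch below shows how K draft heads could turn one hidden state into the `draft_tokens` and `draft_probs` inputs of this problem. The function name `draft_heads_propose`, the plain linear heads, and the greedy (argmax) pick per head are simplifying assumptions for illustration, not a faithful reproduction of Medusa. The point is only that all K heads read the same hidden state and run in one batched step, so the K guesses arrive together.

```python
import numpy as np


def draft_heads_propose(hidden, head_weights):
    """Hypothetical draft-head proposal step (illustrative, assumed linear heads).

    hidden:       (d,) last hidden state from the base model
    head_weights: (K, d, vocab_size), one projection matrix per draft head

    In this sketch head k proposes the token k + 1 positions ahead.
    Returns (draft_tokens, draft_probs) with draft_probs of shape (K, vocab_size).
    """
    # All K heads applied to the same hidden state in a single batched op.
    logits = np.einsum("d,kdv->kv", hidden, head_weights)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    draft_tokens = probs.argmax(axis=1).tolist()  # greedy guess per head (assumption)
    return draft_tokens, probs


rng = np.random.default_rng(0)
K, d, vocab_size = 3, 8, 10
hidden = rng.standard_normal(d)
head_weights = rng.standard_normal((K, d, vocab_size))
tokens, probs = draft_heads_propose(hidden, head_weights)
print(tokens, probs.shape)
```

Real Medusa heads are richer than a single matrix and usually propose several candidates per head, but the input/output contract matches what `medusa_verify` consumes.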
Signature: def medusa_verify(draft_tokens, target_probs, draft_probs, temperature=1.0)
Args:
- draft_tokens: list of K candidate token IDs (from the draft heads)
- target_probs: (K, vocab_size) array of target-model probabilities, one row per position
- draft_probs: (K, vocab_size) array of draft-head probabilities, one row per position
- temperature: sampling temperature (applied to the draft probabilities for acceptance)

Returns: (accepted_tokens, n_accepted)
- accepted_tokens: list of accepted token IDs (may be fewer than K)
- n_accepted: number of accepted tokens (int)

For each position i in the sequence:
    token = draft_tokens[i]
    p_target = target_probs[i, token]
    p_draft = draft_probs[i, token]
    # Accept with probability min(1, p_target / p_draft)
    if random() < min(1.0, p_target / p_draft):
        accept token i; continue to i + 1
    else:
        reject; stop
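The pseudocode above translates into runnable Python roughly as follows. This is a sketch: the name `medusa_verify_stochastic` and the optional `rng` argument (a `random.Random` instance, so runs can be seeded) are illustrative additions, and temperature is left out because the pseudocode does not use it.

```python
import random

import numpy as np


def medusa_verify_stochastic(draft_tokens, target_probs, draft_probs, rng=None):
    """Hypothetical sampled variant: keep token i with probability min(1, p_target / p_draft)."""
    if rng is None:
        rng = random.Random()  # unseeded: results differ between runs
    target_probs = np.asarray(target_probs, dtype=float)
    draft_probs = np.asarray(draft_probs, dtype=float)

    accepted = []
    for i, token in enumerate(draft_tokens):
        p_target = target_probs[i, token]
        p_draft = draft_probs[i, token]  # assumed > 0, since a draft head proposed this token
        if rng.random() < min(1.0, p_target / p_draft):
            accepted.append(token)  # accept, continue to position i + 1
        else:
            break  # first rejection ends verification
    return accepted, len(accepted)


target = np.array([[0.1, 0.9], [0.5, 0.5], [0.9, 0.1]])
draft = np.array([[0.2, 0.8], [0.5, 0.5], [0.1, 0.9]])
tokens, n = medusa_verify_stochastic([1, 0, 1], target, draft, random.Random(42))
print(tokens, n)
```

Here the first two positions have p_target / p_draft >= 1 and are always accepted; the third is kept only when the uniform draw falls below 0.1 / 0.9, so the final count depends on the seed.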
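Deterministic version (for testing): accept token i if target_probs[i, draft_tokens[i]] >= draft_probs[i, draft_tokens[i]]; otherwise reject it and stop. This is exactly the case in which the stochastic rule above accepts with probability 1, since min(1, p_target / p_draft) = 1 whenever p_target >= p_draft.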
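A minimal sketch of `medusa_verify` under the deterministic rule follows, assuming NumPy-compatible probability arrays. One assumption needs flagging: the problem says `temperature` is applied to the draft probabilities but not how, so this sketch rescales each draft row as p ** (1 / temperature) and renormalizes (equivalent to softmax(logits / temperature)), leaving the probabilities untouched at the default 1.0. Substitute the intended definition if it differs.

```python
import numpy as np


def medusa_verify(draft_tokens, target_probs, draft_probs, temperature=1.0):
    """Deterministic verification: accept the longest prefix where
    target_probs[i, tok] >= draft_probs[i, tok].

    ASSUMPTION: temperature rescales draft rows as p ** (1 / T), renormalized.
    """
    target_probs = np.asarray(target_probs, dtype=float)
    draft_probs = np.asarray(draft_probs, dtype=float)
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if temperature != 1.0:
        scaled = draft_probs ** (1.0 / temperature)
        draft_probs = scaled / scaled.sum(axis=1, keepdims=True)

    accepted_tokens = []
    for i, token in enumerate(draft_tokens):
        if target_probs[i, token] >= draft_probs[i, token]:
            accepted_tokens.append(token)  # target agrees at least as strongly: keep it
        else:
            break  # later drafts were conditioned on this token, so drop them too
    return accepted_tokens, len(accepted_tokens)


target = [[0.1, 0.7, 0.2], [0.3, 0.3, 0.4], [0.6, 0.2, 0.2]]
draft = [[0.2, 0.6, 0.2], [0.5, 0.2, 0.3], [0.1, 0.1, 0.8]]
print(medusa_verify([1, 2, 2], target, draft))  # ([1, 2], 2)
```

Ties count as accepted because the rule uses >=. Because the loop stops at the first rejection, a rejected position also discards every later draft token, even ones that would have passed on their own.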
This problem uses the deterministic version (no randomness) to enable reproducible test cases.