TorchedUp

178. Contrastive Search Token Selection

Hard

Implement the contrastive search selection rule from Su et al. 2022. Among a set of candidate tokens, pick the one that balances model confidence against a degeneration penalty (max cosine similarity to previously generated hidden states).
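The degeneration penalty is the largest cosine similarity between a candidate's hidden state and any earlier hidden state. The helper below is a minimal sketch of that one term. It assumes `hidden_states` has at least one row, and its `eps` guard against zero-length vectors is an added assumption that the problem does not specify.

```python
import numpy as np


def degeneration_penalty(candidate_hiddens: np.ndarray, hidden_states: np.ndarray,
                         eps: float = 1e-8) -> np.ndarray:
    """Max cosine similarity of each candidate to any previous hidden state.

    candidate_hiddens: (k, d); hidden_states: (t, d) with t >= 1. Returns (k,).
    eps is an assumed guard against zero-norm rows, not part of the spec.
    """
    # L2-normalise each row so that a dot product equals cosine similarity.
    cand = candidate_hiddens / np.maximum(
        np.linalg.norm(candidate_hiddens, axis=1, keepdims=True), eps)
    prev = hidden_states / np.maximum(
        np.linalg.norm(hidden_states, axis=1, keepdims=True), eps)
    # (k, d) @ (d, t) -> (k, t) pairwise cosines; keep the max over the t past states.
    return (cand @ prev.T).max(axis=1)


# Usage: the second candidate points exactly along the first previous state.
prev_states = np.array([[1.0, 0.0], [0.0, 1.0]])
cands = np.array([[1.0, 1.0], [2.0, 0.0]])
print(degeneration_penalty(cands, prev_states))  # ≈ [0.7071, 1.0]
```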

Signature:

def contrastive_select(logits: np.ndarray, candidate_ids: np.ndarray, hidden_states: np.ndarray, candidate_hiddens: np.ndarray, alpha: float) -> int

  • logits: shape (vocab_size,) — next-token logits
  • candidate_ids: shape (k,) — token ids to consider
  • hidden_states: shape (t, d) — hidden states of the previously generated tokens
  • candidate_hiddens: shape (k, d) — hidden state each candidate would produce if chosen
  • alpha: float in [0, 1]
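The signature takes `candidate_ids` as given. In Su et al. the candidate set is the model's top-k tokens, so one plausible way to build the inputs is sketched below. The sizes, the random placeholder hidden states, and `alpha = 0.6` are hypothetical toy values. In real decoding, row `i` of `candidate_hiddens` would come from running the model on the context extended by `candidate_ids[i]`.

```python
import numpy as np

# Hypothetical toy sizes; real values come from the model being decoded.
vocab_size, k, t, d = 50, 4, 6, 8
rng = np.random.default_rng(0)

logits = rng.normal(size=vocab_size)                     # (vocab_size,)
# Top-k token ids, highest logit first. Softmax is monotonic, so this is also
# the top-k by probability; kind="stable" keeps equal logits in id order.
candidate_ids = np.argsort(-logits, kind="stable")[:k]   # (k,)
# Random placeholders standing in for model outputs: one row per previously
# generated token, and row i of candidate_hiddens pairs with candidate_ids[i].
hidden_states = rng.normal(size=(t, d))                  # (t, d)
candidate_hiddens = rng.normal(size=(k, d))              # (k, d)
alpha = 0.6                                              # arbitrary value in [0, 1]

# Shape contract implied by the signature.
assert candidate_hiddens.shape == (candidate_ids.shape[0], hidden_states.shape[1])
```

Because ties are broken by position in `candidate_ids`, the order in which candidates are listed matters, not just which ids are present.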

Score for the i-th candidate, with token id c = candidate_ids[i]: (1 - alpha) * softmax(logits)[c] - alpha * max_j cos(candidate_hiddens[i], hidden_states[j])

Return the token id from candidate_ids that maximizes this score (break ties by lower index in candidate_ids). Use a stable softmax.
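A minimal NumPy sketch of one way to meet this spec follows. It is not the site's reference solution. It assumes `hidden_states` has at least one row, because with t = 0 the max over an empty set is undefined and this code raises `ValueError`. The `eps` guard against zero-norm vectors is an added assumption.

```python
import numpy as np


def contrastive_select(logits: np.ndarray, candidate_ids: np.ndarray, hidden_states: np.ndarray,
                       candidate_hiddens: np.ndarray, alpha: float) -> int:
    # Stable softmax: subtracting the max logit leaves the probabilities
    # unchanged but keeps np.exp from overflowing on large logits.
    shifted = logits - np.max(logits)
    probs = np.exp(shifted)
    probs /= probs.sum()

    # Model confidence of each candidate, in candidate order: shape (k,).
    confidence = probs[candidate_ids]

    # Cosine similarity via row-normalised dot products; eps is an assumed
    # guard against all-zero vectors, not part of the problem statement.
    eps = 1e-12
    cand = candidate_hiddens / np.maximum(
        np.linalg.norm(candidate_hiddens, axis=1, keepdims=True), eps)
    prev = hidden_states / np.maximum(
        np.linalg.norm(hidden_states, axis=1, keepdims=True), eps)
    penalty = (cand @ prev.T).max(axis=1)  # (k,): max_j cos(h_i, h_j)

    scores = (1.0 - alpha) * confidence - alpha * penalty
    # np.argmax returns the first maximal position, which implements the
    # "lower index in candidate_ids" tie-break.
    return int(candidate_ids[np.argmax(scores)])


# Usage: these logits would overflow a naive softmax; candidate 0 repeats the past.
logits = np.array([1000.0, 999.0, 0.0])
candidate_ids = np.array([0, 1])
hidden_states = np.array([[1.0, 0.0]])
candidate_hiddens = np.array([[1.0, 0.0], [0.0, 1.0]])
print(contrastive_select(logits, candidate_ids, hidden_states, candidate_hiddens, alpha=0.0))  # 0
print(contrastive_select(logits, candidate_ids, hidden_states, candidate_hiddens, alpha=0.5))  # 1
```

Computing all k scores as one vector and taking a single `argmax` avoids a Python loop. It also makes the tie-breaking rule explicit rather than dependent on comparison order.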

Math

$$x^* = \arg\max_{c} \left[ (1-\alpha)\, p_\theta(c \mid x_{<t}) - \alpha \max_{j<t} \cos(h_c, h_j) \right]$$
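A worked example with hypothetical numbers shows how α trades confidence against the penalty. It uses two candidates a and b, with p(a) = 0.6, max cos = 0.9 and p(b) = 0.3, max cos = 0.1. These values are illustrative and not taken from the paper.

```latex
\begin{aligned}
\alpha = 0.5:\quad s(a) &= 0.5 \cdot 0.6 - 0.5 \cdot 0.9 = 0.30 - 0.45 = -0.15, \\
s(b) &= 0.5 \cdot 0.3 - 0.5 \cdot 0.1 = 0.15 - 0.05 = 0.10 \;\Rightarrow\; x^* = b, \\
\alpha = 0:\quad s(a) &= 0.6,\quad s(b) = 0.3 \;\Rightarrow\; x^* = a.
\end{aligned}
```

With α = 0 the rule reduces to picking the most probable candidate. Raising α lets a less likely but less repetitive token win.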

NumPy

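import numpy as np


def contrastive_select(logits: np.ndarray, candidate_ids: np.ndarray, hidden_states: np.ndarray,
                       candidate_hiddens: np.ndarray, alpha: float) -> int:
    # Starter stub: return the selected token id from candidate_ids.
    pass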
