TorchedUp — Problems
Contrastive Search Token Selection (Hard)

Contrastive Search

Implement the contrastive search selection rule from Su et al. (2022), "A Contrastive Framework for Neural Text Generation". Among a set of candidate tokens, pick the one that balances model confidence against a degeneration penalty: the maximum cosine similarity between the candidate's hidden state and the hidden states of previously generated tokens.

Signature:

def contrastive_select(logits: np.ndarray, candidate_ids: np.ndarray, hidden_states: np.ndarray, candidate_hiddens: np.ndarray, alpha: float) -> int

  • logits: shape (vocab_size,) — next-token logits
  • candidate_ids: shape (k,) — token ids to consider
  • hidden_states: shape (t, d) — hidden states of the previously generated tokens
  • candidate_hiddens: shape (k, d) — hidden state each candidate would produce if chosen
  • alpha: float in [0, 1]

Score for each candidate c: (1 - alpha) * softmax(logits)[candidate_ids[c]] - alpha * max_j cos(candidate_hiddens[c], hidden_states[j])

Return the token id from candidate_ids that maximizes this score (break ties by lower index in candidate_ids). Use a stable softmax.
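One way to satisfy the spec is sketched below (this is a reference sketch, not the official solution): a stable softmax over the full vocabulary, row-normalized hidden states so the candidate-vs-history cosine similarities reduce to a single matrix product, and `np.argmax`, which returns the first maximal index and therefore implements the tie-break rule for free.

```python
import numpy as np

def contrastive_select(logits: np.ndarray, candidate_ids: np.ndarray,
                       hidden_states: np.ndarray, candidate_hiddens: np.ndarray,
                       alpha: float) -> int:
    # Stable softmax: subtract the max logit before exponentiating.
    z = logits - logits.max()
    probs = np.exp(z) / np.exp(z).sum()

    # Normalize rows so that a dot product equals cosine similarity.
    h = hidden_states / np.linalg.norm(hidden_states, axis=1, keepdims=True)
    c = candidate_hiddens / np.linalg.norm(candidate_hiddens, axis=1, keepdims=True)

    # (k, t) similarity matrix; max over history per candidate -> shape (k,).
    max_sim = (c @ h.T).max(axis=1)

    scores = (1 - alpha) * probs[candidate_ids] - alpha * max_sim
    # argmax returns the first maximal index, satisfying the tie-break rule.
    return int(candidate_ids[np.argmax(scores)])
```

On a toy input with candidates 0 and 1, where candidate 0 is the most likely token but its hidden state duplicates the history, `alpha=0.0` selects token 0 (pure likelihood) and `alpha=1.0` selects token 1 (pure degeneration penalty).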

Math

Python (numpy)

Test Results

○alpha=0 → pure likelihood
○alpha=1 → pure penalty avoids repetition
○mixed alpha picks balanced candidate