Implement the contrastive search selection rule from Su et al. 2022. Among a set of candidate tokens, pick the one that balances model confidence against a degeneration penalty (max cosine similarity to previously generated hidden states).
Signature:
def contrastive_select(logits: np.ndarray, candidate_ids: np.ndarray, hidden_states: np.ndarray, candidate_hiddens: np.ndarray, alpha: float) -> int
logits: shape (vocab_size,) — next-token logits
candidate_ids: shape (k,) — token ids to consider
hidden_states: shape (t, d) — hidden states of the previously generated tokens
candidate_hiddens: shape (k, d) — hidden state each candidate would produce if chosen
alpha: float in [0, 1]
Score: (1 - alpha) * softmax(logits)[c] - alpha * max_j cos(candidate_hiddens[c], hidden_states[j])
Return the token id from candidate_ids that maximizes this score (break ties by lower index in candidate_ids). Use a stable softmax.
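A minimal NumPy sketch of the selection rule above (not the only valid implementation; it assumes all hidden-state rows are nonzero so the cosine normalization is well defined):

```python
import numpy as np

def contrastive_select(logits: np.ndarray,
                       candidate_ids: np.ndarray,
                       hidden_states: np.ndarray,
                       candidate_hiddens: np.ndarray,
                       alpha: float) -> int:
    # Stable softmax: subtract the max logit before exponentiating.
    z = logits - logits.max()
    probs = np.exp(z)
    probs /= probs.sum()

    # Model confidence p(c) for each candidate token.          # (k,)
    confidence = probs[candidate_ids]

    # Degeneration penalty: max cosine similarity between each candidate's
    # hidden state and every previously generated hidden state.
    h = hidden_states / np.linalg.norm(hidden_states, axis=1, keepdims=True)        # (t, d)
    c = candidate_hiddens / np.linalg.norm(candidate_hiddens, axis=1, keepdims=True)  # (k, d)
    penalty = (c @ h.T).max(axis=1)                            # (k,)

    scores = (1.0 - alpha) * confidence - alpha * penalty
    # np.argmax returns the first maximizer, so ties break toward
    # the lower index in candidate_ids, as required.
    return int(candidate_ids[np.argmax(scores)])
```

With alpha = 0 this reduces to greedy decoding over the candidate set; as alpha grows, candidates whose hidden states closely resemble prior ones are increasingly suppressed.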