Implement the InfoNCE contrastive loss for a single query against one positive and several negatives.
Signature: def info_nce(query: np.ndarray, positives: np.ndarray, negatives: np.ndarray, tau: float = 0.07) -> float
Inputs (assume already L2-normalized — do not renormalize):
query: shape (d,)
positives: shape (d,), a single positive key
negatives: shape (n, d)
Logits:
logit_pos = (query . positives) / tau
logits_neg[i] = (negatives[i] . query) / tau
Loss: -logit_pos + log(exp(logit_pos) + sum(exp(logits_neg)))
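The same loss can be written as an equation. Here q stands for query, k^+ for positives, and k_i for negatives[i]; this notation is introduced only for readability. Rearranging shows that the loss is the negative log of a softmax probability over the n + 1 logits, with the positive as the target:

```latex
s^{+} = \frac{q \cdot k^{+}}{\tau}, \qquad s_i = \frac{k_i \cdot q}{\tau}, \quad i = 1, \dots, n

\mathcal{L} = -s^{+} + \log\Big(e^{s^{+}} + \sum_{i=1}^{n} e^{s_i}\Big)
            = -\log \frac{e^{s^{+}}}{e^{s^{+}} + \sum_{i=1}^{n} e^{s_i}} \;\ge\; 0
```

The fraction is at most 1, so the loss is never negative. It approaches 0 only when e^{s^+} dominates the sum of the negative terms.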
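Use the log-sum-exp trick for numerical stability: subtract the largest logit before exponentiating, then add it back outside the log.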
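Below is a minimal NumPy sketch of the function, written against the signature above. Like the task, it assumes the inputs are already L2-normalized. It stacks the positive logit in front of the negative logits and applies the max-shift identity log(sum(exp(x))) = m + log(sum(exp(x - m))), where m = max(x). After the shift every exponent is at most 0, so exp cannot overflow. The positive's own term is included in the normalizer, as the formula requires.

```python
import numpy as np


def info_nce(query: np.ndarray, positives: np.ndarray, negatives: np.ndarray, tau: float = 0.07) -> float:
    """InfoNCE loss for one query, one positive key, and n negative keys.

    Inputs are assumed to be L2-normalized already; they are not renormalized.
    """
    # Positive logit: scalar dot product scaled by the temperature.
    logit_pos = float(np.dot(query, positives)) / tau
    # Negative logits: (n, d) @ (d,) -> (n,)
    logits_neg = (negatives @ query) / tau
    # All n + 1 logits, positive first, so the positive is part of the normalizer.
    logits = np.concatenate(([logit_pos], logits_neg))
    # Log-sum-exp with a max shift: every exponent is <= 0 (no overflow),
    # and the max term contributes exp(0) = 1 (so the log argument is >= 1).
    m = np.max(logits)
    lse = m + np.log(np.sum(np.exp(logits - m)))
    # Loss = -logit_pos + log(exp(logit_pos) + sum(exp(logits_neg)))
    return float(lse - logit_pos)
```

Here is a small worked case. Take query = positives = [1, 0], negatives = [[0, 1], [-1, 0]] and tau = 1.0. The logits are [1, 0, -1], and the loss is log(1 + e^-1 + e^-2), about 0.408. With a very small tau such as 1e-3, the positive logit becomes 1000. A naive exp(1000) would overflow float64, but the shifted form still returns a finite result.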