Implement InfoNCE for one query against one positive and several negatives.
Signature: def info_nce(query: torch.Tensor, positives: torch.Tensor, negatives: torch.Tensor, tau: float = 0.07) -> torch.Tensor
The rule: you may NOT call F.cross_entropy, F.log_softmax, F.softmax, or any high-level loss. Implement log-sum-exp by hand.
Allowed primitives: @ (matmul/dot), .exp(), .log(), .max(), .sum(), torch.cat, indexing.
Inputs (assume already L2-normalized — do not renormalize):
query: shape (d,)
positives: shape (d,), a single positive key
negatives: shape (n, d)
Logits: form one logit for the positive (dot product of query with positives, scaled by 1/tau) and one logit per negative (dot product of query with each row of negatives, scaled the same way).
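As a rough illustration of the logit step, the sketch below stacks the positive on top of the negatives and takes one matrix-vector product. The helper name `build_logits` and the choice to place the positive at index 0 are assumptions made for this example, not part of the problem statement.

```python
import torch

def build_logits(query: torch.Tensor, positives: torch.Tensor,
                 negatives: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    """Hypothetical helper: return logits of shape (1 + n,), positive first."""
    # positives[None, :] turns the (d,) positive into a (1, d) row so it can
    # be concatenated with the (n, d) negatives.
    keys = torch.cat([positives[None, :], negatives], dim=0)  # (1 + n, d)
    # One dot product per key, all scaled by 1/tau.
    return (keys @ query) / tau

# Toy inputs with d = 2 and n = 2 (already unit length).
query = torch.tensor([1.0, 0.0])
positives = torch.tensor([1.0, 0.0])
negatives = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
logits = build_logits(query, positives, negatives, tau=0.5)
```

With these toy vectors the dot products are 1, 0 and -1, so dividing by tau = 0.5 gives logits 2, 0 and -2.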
Loss: the standard InfoNCE / softmax-cross-entropy loss with the positive as the target — see the math reference below. Implement it via -(positive logit) + logsumexp(all logits) and use a numerically stable log-sum-exp (subtract the max before exp, then add it back outside the log).
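One possible way to put the pieces together is sketched below, using only the allowed primitives. It assumes the positive logit is stored at index 0; that ordering, and the local names `keys`, `m`, and `lse`, are illustrative choices rather than anything the problem prescribes.

```python
import torch

def info_nce(query: torch.Tensor, positives: torch.Tensor,
             negatives: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    # Stack the positive (as a 1 x d row) on top of the n negatives.
    keys = torch.cat([positives[None, :], negatives], dim=0)  # (1 + n, d)
    logits = (keys @ query) / tau                             # (1 + n,)

    # Stable log-sum-exp: subtract the max before exp, add it back outside the log.
    m = logits.max()
    lse = m + (logits - m).exp().sum().log()

    # -(positive logit) + logsumexp(all logits); result is a 0-dim tensor.
    return lse - logits[0]

# Toy check: logits are [2, 0, -2], so the loss is log(1 + e^-2 + e^-4).
query = torch.tensor([1.0, 0.0])
positives = torch.tensor([1.0, 0.0])
negatives = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
loss = info_nce(query, positives, negatives, tau=0.5)
```

Subtracting the max matters at small temperatures: with tau = 1e-3 the same inputs give logits of magnitude 1000, where a naive float32 `exp` would overflow, while the shifted version stays finite.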
Math
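For reference, the loss described above can be written out as follows. The symbols q (query), k^+ (positive key), k_i^- (the i-th negative), s_j (logits), and m (the max logit) are notation assumed here for illustration.

```latex
s_0 = \frac{q \cdot k^{+}}{\tau}, \qquad
s_i = \frac{q \cdot k_i^{-}}{\tau} \quad (i = 1, \dots, n)

\mathcal{L}
  = -\log \frac{e^{s_0}}{\sum_{j=0}^{n} e^{s_j}}
  = -s_0 + \log \sum_{j=0}^{n} e^{s_j}

\log \sum_{j=0}^{n} e^{s_j}
  = m + \log \sum_{j=0}^{n} e^{s_j - m},
  \qquad m = \max_{j} s_j
```

The last identity is the stable log-sum-exp: every exponent s_j - m is at most 0, so no term can overflow.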
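import torch

def info_nce(query: torch.Tensor, positives: torch.Tensor, negatives: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    # TODO: build the logits, then return -(positive logit) + stable logsumexp(all logits)
    pass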