TorchedUp

243. InfoNCE (PyTorch)

Difficulty: Medium

Implement InfoNCE for one query against one positive and several negatives.

Signature: def info_nce(query: torch.Tensor, positives: torch.Tensor, negatives: torch.Tensor, tau: float = 0.07) -> torch.Tensor

The rule: you may NOT call F.cross_entropy, F.log_softmax, F.softmax, or any high-level loss. Implement log-sum-exp by hand.

Allowed primitives: @ (matmul/dot), .exp(), .log(), .max(), .sum(), torch.cat, indexing.
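To show what "log-sum-exp by hand" means in practice, here is a small sketch built only from the allowed primitives. The helper name `logsumexp_by_hand` is hypothetical, and the input values are chosen purely to trigger float32 overflow in the naive version.

```python
import torch


def logsumexp_by_hand(x: torch.Tensor) -> torch.Tensor:
    """Stable log(sum(exp(x))) for a 1-D tensor, using only .max(), .exp(), .sum(), .log()."""
    m = x.max()                              # largest entry; after the shift every exp() is <= 1
    return m + (x - m).exp().sum().log()     # add the shift back outside the log


x = torch.tensor([1000.0, 999.0, 998.0])
naive = x.exp().sum().log()      # exp(1000.) overflows float32, so this evaluates to inf
stable = logsumexp_by_hand(x)    # about 1000.4076, i.e. 1000 + log(1 + e^-1 + e^-2)
```

The shift is exact for any constant; choosing the max guarantees every exponent is at most 0, so nothing overflows and at least one term equals 1, so the sum never underflows to 0.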

Inputs (assume already L2-normalized — do not renormalize):

  • query: shape (d,)
  • positives: shape (d,) — a single positive key
  • negatives: shape (n, d)
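For local testing, one way to produce inputs that meet these assumptions is to draw random vectors and L2-normalize them once in the harness. The helper `make_inputs` and the sizes `d = 8`, `n = 5` are illustrative assumptions, not part of the problem.

```python
import torch


def make_inputs(d: int = 8, n: int = 5, seed: int = 0):
    """Hypothetical helper: random unit-norm query, positive key, and n negative keys."""
    g = torch.Generator().manual_seed(seed)
    query = torch.randn(d, generator=g)
    positives = torch.randn(d, generator=g)
    negatives = torch.randn(n, d, generator=g)
    # Normalize here, in the test harness; info_nce itself must not renormalize.
    query = query / query.norm()
    positives = positives / positives.norm()
    negatives = negatives / negatives.norm(dim=1, keepdim=True)  # each row has unit norm
    return query, positives, negatives


q, k_pos, k_neg = make_inputs()
print(q.shape, k_pos.shape, k_neg.shape)  # torch.Size([8]) torch.Size([8]) torch.Size([5, 8])
```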

Logits: form one logit for the positive (dot product of query with positives, scaled by 1/tau) and one logit per negative (dot product of query with each row of negatives, scaled the same way).
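A minimal sketch of the logit construction, assuming the positive is placed at index 0 so that it is the target class. The helper name `build_logits` and the toy 2-D vectors are hypothetical; only `@`, scaling, indexing, and `torch.cat` are used.

```python
import torch


def build_logits(query, positives, negatives, tau=0.07):
    """Hypothetical helper: return the n + 1 logits with the positive at index 0."""
    pos_logit = (query @ positives) / tau      # 0-dim tensor: q . k+ / tau
    neg_logits = (negatives @ query) / tau     # shape (n,): entry i is q . k_i^- / tau
    # pos_logit[None] turns the scalar into shape (1,) so it can be concatenated.
    return torch.cat([pos_logit[None], neg_logits])


q = torch.tensor([1.0, 0.0])
k_pos = torch.tensor([0.6, 0.8])
k_neg = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
logits = build_logits(q, k_pos, k_neg, tau=0.5)  # values 1.2, 0.0, -2.0
```

Computing `negatives @ query` handles all n negatives in one matrix-vector product instead of a Python loop over rows.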

Loss: the standard InfoNCE / softmax-cross-entropy loss with the positive as the target — see the math reference below. Implement it via -(positive logit) + logsumexp(all logits) and use a numerically stable log-sum-exp (subtract the max before exp, then add it back outside the log).
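Putting the two steps together gives one possible implementation that follows the rule (no `F.cross_entropy`, `F.log_softmax`, `F.softmax`, or `torch.logsumexp`). This is a sketch of one way to satisfy the spec, not the site's reference solution.

```python
import torch


def info_nce(query: torch.Tensor, positives: torch.Tensor,
             negatives: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    # Logits: positive first (the target), then one per negative.
    pos_logit = (query @ positives) / tau          # 0-dim tensor
    neg_logits = (negatives @ query) / tau         # shape (n,)
    logits = torch.cat([pos_logit[None], neg_logits])

    # Numerically stable log-sum-exp: subtract the max before exp, add it back outside the log.
    m = logits.max()
    lse = m + (logits - m).exp().sum().log()

    # -log softmax(logits)[0] == -(positive logit) + logsumexp(all logits)
    return -pos_logit + lse


loss = info_nce(torch.tensor([1.0, 0.0]), torch.tensor([1.0, 0.0]),
                torch.tensor([[0.0, 1.0]]), tau=1.0)  # log(1 + e^-1), about 0.3133
```

With unit-norm inputs every logit lies in [-1/tau, 1/tau], about ±14.3 at tau = 0.07, so float32 would not overflow there even without the shift. The max subtraction costs one reduction and keeps the loss finite for much smaller tau, where the naive version returns inf.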

Math

$$\mathcal{L} = -\log \frac{\exp(q \cdot k^{+} / \tau)}{\exp(q \cdot k^{+} / \tau) + \sum_{i=1}^{n} \exp(q \cdot k_i^{-} / \tau)}$$
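Here q is `query`, k⁺ is `positives`, and k_i⁻ is row i of `negatives`. Expanding the log of the fraction shows why the recipe "-(positive logit) + logsumexp(all logits)" is the same loss. The shorthand z₊ = q·k⁺/τ and z_i = q·k_i⁻/τ is introduced only for this derivation:

```latex
\begin{aligned}
\mathcal{L} &= -\log \frac{e^{z_{+}}}{e^{z_{+}} + \sum_{i=1}^{n} e^{z_i}}
  = -z_{+} + \log\Big(e^{z_{+}} + \sum_{i=1}^{n} e^{z_i}\Big)
  = -z_{+} + \operatorname{LSE}(z_{+}, z_1, \dots, z_n), \\
\operatorname{LSE}(x_0, \dots, x_n) &= m + \log \sum_{j=0}^{n} e^{x_j - m}, \qquad m = \max_j x_j .
\end{aligned}
```

The second line holds for any constant m, because e^{x_j} = e^{m} e^{x_j - m}; taking m as the max is what makes it numerically stable.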

Related problems

  • InfoNCE Contrastive Loss (Medium, NumPy)

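Starter code (PyTorch):

import torch


def info_nce(query: torch.Tensor, positives: torch.Tensor, negatives: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    # Build the n + 1 logits, then return -(positive logit) + logsumexp(all logits).
    pass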
