TorchedUp

172. InfoNCE Contrastive Loss

Medium

Implement the InfoNCE contrastive loss for a single query against one positive and several negatives.

Signature: def info_nce(query: np.ndarray, positives: np.ndarray, negatives: np.ndarray, tau: float = 0.07) -> float

Inputs (assume already L2-normalized — do not renormalize):

  • query: shape (d,)
  • positives: shape (d,) — a single positive key
  • negatives: shape (n, d)
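To make the shape contract concrete, here is a minimal sketch of how unit-norm inputs of the expected shapes might be generated for local testing. `make_inputs` is a hypothetical helper, not part of the problem. It normalizes once up front, so the loss function can take its inputs as given.

```python
import numpy as np


def make_inputs(d: int, n: int, seed: int = 0):
    """Hypothetical helper: random unit-norm query, positive and negatives."""
    rng = np.random.default_rng(seed)
    query = rng.standard_normal(d)           # shape (d,)
    positives = rng.standard_normal(d)       # shape (d,), a single positive key
    negatives = rng.standard_normal((n, d))  # shape (n, d), one negative per row
    # L2-normalize here, once, so the loss itself never has to.
    query /= np.linalg.norm(query)
    positives /= np.linalg.norm(positives)
    negatives /= np.linalg.norm(negatives, axis=1, keepdims=True)
    return query, positives, negatives


q, k_pos, k_neg = make_inputs(d=8, n=5)
assert q.shape == (8,) and k_pos.shape == (8,) and k_neg.shape == (5, 8)
assert np.allclose(np.linalg.norm(k_neg, axis=1), 1.0)
```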

Logits:

  • logit_pos = (query . positives) / tau
  • logits_neg[i] = (negatives[i] . query) / tau
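The per-row negative logits can be computed for all n negatives at once with a single matrix-vector product. A minimal sketch, using a hypothetical helper `compute_logits`; the check at the end compares it with the explicit per-row loop (that identity holds whether or not the vectors are normalized).

```python
import numpy as np


def compute_logits(query, positives, negatives, tau=0.07):
    """Hypothetical helper: the two logit formulas, vectorized."""
    logit_pos = float(query @ positives) / tau  # scalar
    logits_neg = (negatives @ query) / tau      # shape (n,); entry i is negatives[i] . query / tau
    return logit_pos, logits_neg


# The matrix-vector product matches the per-row definition.
rng = np.random.default_rng(0)
q = rng.standard_normal(4)
neg = rng.standard_normal((3, 4))
_, ln = compute_logits(q, q, neg, tau=0.5)
loop = np.array([np.dot(neg[i], q) / 0.5 for i in range(3)])
assert np.allclose(ln, loop)
```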

Loss: -logit_pos + log(exp(logit_pos) + sum(exp(logits_neg)))

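Compute the log term with the log-sum-exp trick (subtract the largest logit before exponentiating) so that small values of tau do not overflow exp.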
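Putting the pieces together, one possible NumPy implementation of the signature above, assuming the inputs are already unit-norm as stated. It places the positive logit in front of the negatives and applies the max-shift form of log-sum-exp; treat it as a sketch, not the site's reference solution.

```python
import numpy as np


def info_nce(query: np.ndarray, positives: np.ndarray, negatives: np.ndarray,
             tau: float = 0.07) -> float:
    """InfoNCE loss for one query, one positive key and n negative keys.

    Inputs are assumed to be L2-normalized already; they are not renormalized.
    """
    logit_pos = float(np.dot(query, positives)) / tau
    logits_neg = (negatives @ query) / tau  # shape (n,)
    # The log term runs over all n + 1 logits, positive included.
    all_logits = np.concatenate(([logit_pos], logits_neg))
    # Log-sum-exp trick: shift by the max so the largest exponent is exp(0) = 1.
    m = all_logits.max()
    lse = m + np.log(np.sum(np.exp(all_logits - m)))
    return float(-logit_pos + lse)
```

For example, with tau = 1e-3 and a negative identical to the positive, both logits are 1000; evaluating exp(1000) directly overflows to inf, while the shifted version returns approximately log 2 ≈ 0.693. If SciPy is available, `scipy.special.logsumexp(all_logits)` computes the same log term.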

Math

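$$\mathcal{L} = -\log \frac{\exp(q \cdot k^{+}/\tau)}{\exp(q \cdot k^{+}/\tau) + \sum_{i} \exp(q \cdot k_i^{-}/\tau)}$$

where q is query, k^+ is positives, and k_i^- is negatives[i].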
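The fractional form of the loss and the Loss line given earlier are the same quantity: taking the logarithm splits the quotient, and factoring out the largest logit m gives the numerically stable form. Below, s^+ and s_i^- are shorthand introduced here for logit_pos and logits_neg[i].

$$
\begin{aligned}
\mathcal{L} &= -\log \frac{\exp(s^{+})}{\exp(s^{+}) + \sum_{i=1}^{n} \exp(s_i^{-})}, \qquad s^{+} = \frac{q \cdot k^{+}}{\tau},\quad s_i^{-} = \frac{q \cdot k_i^{-}}{\tau} \\
&= -s^{+} + \log\Big(\exp(s^{+}) + \sum_{i=1}^{n} \exp(s_i^{-})\Big) \\
&= -s^{+} + m + \log\Big(\exp(s^{+} - m) + \sum_{i=1}^{n} \exp(s_i^{-} - m)\Big), \qquad m = \max\big(s^{+}, s_1^{-}, \dots, s_n^{-}\big)
\end{aligned}
$$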

Related problems

  • InfoNCE (PyTorch) · Medium · PyTorch

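Starter code (NumPy):

import numpy as np


def info_nce(query: np.ndarray, positives: np.ndarray, negatives: np.ndarray, tau: float = 0.07) -> float:
    # TODO: return the InfoNCE loss as a Python float
    pass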
