TorchedUp

241. Cross-Entropy (PyTorch)

Medium

Implement F.cross_entropy(logits, targets) in PyTorch using primitive tensor ops only.

Signature: def cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor

The rule: you may NOT call F.cross_entropy, F.nll_loss, F.log_softmax, F.softmax, or nn.CrossEntropyLoss. Implement log-softmax and the negative log-likelihood (NLL) step yourself.

Allowed primitives: .exp(), .log(), .max(...).values, .sum(), .mean(), .gather(...), indexing.

Inputs (matching the standard PyTorch API):

  • logits: shape (N, C) — raw scores, not probabilities
  • targets: shape (N,) — integer class indices in [0, C)
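A minimal sketch of a batch in the expected shapes, plus a hypothetical check_inputs helper (not part of the required API) that asserts the contract above. It assumes targets are stored as torch.long (int64), the dtype .gather expects for its index argument.

```python
import torch

def check_inputs(logits: torch.Tensor, targets: torch.Tensor) -> None:
    # Hypothetical helper: validates the shape/dtype contract; not required by the problem.
    assert logits.dim() == 2, "logits must have shape (N, C)"
    assert targets.dim() == 1, "targets must have shape (N,)"
    assert logits.shape[0] == targets.shape[0], "batch sizes must match"
    assert targets.dtype == torch.long, "targets must be int64 class indices"
    num_classes = logits.shape[1]
    assert bool(((targets >= 0) & (targets < num_classes)).all()), "targets must lie in [0, C)"

# Example batch: N = 3 samples, C = 4 classes.
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0],
                       [0.1, 0.2, 0.3, 0.4],
                       [-3.0, 5.0, 1.0, 2.0]])  # raw scores; rows need not sum to 1
targets = torch.tensor([0, 3, 1])                # one class index per row (int64 by default)

check_inputs(logits, targets)
```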

Returns: a scalar tensor, the batch mean of -log_softmax(logits, dim=1)[range(N), targets], where log-softmax is taken over the class dimension.
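The [range(N), targets] selection picks exactly one entry per row: the log-probability of that row's target class. A small sketch with made-up log-probabilities, showing that advanced indexing and .gather select the same values. It assumes torch.arange counts as "indexing" under the rules; the .gather form avoids that question entirely.

```python
import torch

# Made-up log-probabilities for N = 3 samples over C = 4 classes
# (in a real solution these come from the stable log-softmax).
log_probs = torch.log(torch.tensor([[0.70, 0.10, 0.10, 0.10],
                                    [0.25, 0.25, 0.25, 0.25],
                                    [0.05, 0.80, 0.10, 0.05]]))
targets = torch.tensor([0, 3, 1])
N = log_probs.shape[0]

# Option 1: advanced indexing; row i picks column targets[i].
picked_index = log_probs[torch.arange(N), targets]

# Option 2: gather along dim 1; the index needs the same number of dims as the input.
picked_gather = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)

# Both hold log p(y_i | x_i); the loss is their negated batch mean.
loss = -picked_index.mean()
print(loss.item())
```

Here the loss is -(log 0.7 + log 0.25 + log 0.8) / 3, roughly 0.6554.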

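Numerical stability: compute log-softmax row-wise (max and sum over the class dimension) via the subtract-max trick, so that exp() only sees non-positive arguments and cannot overflow: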

log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x))))
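The shift is exact, not an approximation: exp(x - m) / sum(exp(x - m)) = exp(x) / sum(exp(x)) for any constant m, so subtracting the row max changes nothing mathematically while keeping every exponent at or below exp(0) = 1. A sketch comparing the direct formula with the shifted one on large logits, using only the allowed primitives:

```python
import torch

def naive_log_softmax(x: torch.Tensor) -> torch.Tensor:
    # Direct formula: exp() overflows to inf in float32 once scores exceed roughly 88.
    return x - x.exp().sum(dim=1, keepdim=True).log()

def stable_log_softmax(x: torch.Tensor) -> torch.Tensor:
    # Subtract each row's max (keepdim=True so it broadcasts over the C columns).
    shifted = x - x.max(dim=1, keepdim=True).values
    return shifted - shifted.exp().sum(dim=1, keepdim=True).log()

x = torch.tensor([[1000.0, 1001.0, 999.0]])
print(naive_log_softmax(x))   # tensor([[-inf, -inf, -inf]])
print(stable_log_softmax(x))  # finite, approx [-1.4076, -0.4076, -2.4076]
```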

This differs from the NumPy version, which takes already-softmaxed probabilities and one-hot labels. The PyTorch convention takes raw logits and integer targets, exactly matching F.cross_entropy.
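The two conventions compute the same number: multiplying by a one-hot row zeroes every term except the target class, so the sum over classes collapses to the single log-probability that .gather picks out. A sketch on made-up inputs; F.one_hot is used only to build NumPy-style labels for the comparison, not inside any solution.

```python
import torch
import torch.nn.functional as F  # only for one_hot in this comparison

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 3.0, 1.0]])
targets = torch.tensor([0, 2])

# Stable log-softmax from primitives.
shifted = logits - logits.max(dim=1, keepdim=True).values
log_probs = shifted - shifted.exp().sum(dim=1, keepdim=True).log()

# NumPy-style convention: probabilities + one-hot labels, summed over classes.
probs = log_probs.exp()
one_hot = F.one_hot(targets, num_classes=logits.shape[1]).to(probs.dtype)
loss_one_hot = -(one_hot * probs.log()).sum(dim=1).mean()

# PyTorch-style convention: log-probabilities + integer targets, no (N, C) label matrix.
loss_index = -log_probs.gather(1, targets.unsqueeze(1)).mean()
print(loss_one_hot.item(), loss_index.item())
```

The integer-index form skips both the one-hot matrix and the exp/log round trip through probabilities.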

Math

L = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(x_{i,y_i})}{\sum_{c} \exp(x_{i,c})}
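One possible sketch of the formula using only the allowed primitives (stable log-softmax, then .gather for the NLL). It is a hedged illustration, not the site's reference solution; F.cross_entropy appears only in the check below, outside the function. Because every step is a differentiable tensor op, autograd should produce the familiar gradient (softmax(x) - one_hot(y)) / N with respect to the logits.

```python
import torch
import torch.nn.functional as F  # reference check only; not used inside cross_entropy

def cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy from raw logits (N, C) and int64 targets (N,)."""
    # 1. Stable log-softmax: shift each row by its max before exponentiating.
    row_max = logits.max(dim=1, keepdim=True).values             # (N, 1)
    shifted = logits - row_max                                   # (N, C)
    log_sum_exp = shifted.exp().sum(dim=1, keepdim=True).log()   # (N, 1)
    log_probs = shifted - log_sum_exp                            # (N, C)

    # 2. NLL: pick log p(y_i | x_i) per row, negate, average over the batch.
    picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # (N,)
    return -picked.mean()

torch.manual_seed(0)
logits = torch.randn(8, 5, requires_grad=True)
targets = torch.randint(0, 5, (8,))

loss = cross_entropy(logits, targets)
ref = F.cross_entropy(logits, targets)
print(loss.item(), ref.item())  # should agree to float32 precision

# Gradients flow through the primitives.
loss.backward()
```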

Related problems

  • Cross-Entropy Loss (easy, NumPy)


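import torch

def cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Starter stub: implement a stable log-softmax, then the mean NLL over the batch.
    pass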
