Implement F.cross_entropy(logits, targets) in PyTorch using primitive tensor ops only.
Signature: def cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor
The rule: you may NOT call F.cross_entropy, F.nll_loss, F.log_softmax, F.softmax, or nn.CrossEntropyLoss. Implement log-softmax + NLL yourself.
Allowed primitives: .exp(), .log(), .max(...).values, .sum(), .mean(), .gather(...), indexing.
Inputs (matching the standard PyTorch API):
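logits: shape (N, C), raw scores (not probabilities)
targets: shape (N,), integer class indices in [0, C) (dtype int64 / torch.long, as F.cross_entropy expects for class-index targets)
Returns: a scalar tensor, the mean over the batch of -log_softmax(logits)[range(N), targets], with log_softmax taken over the class dimension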
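The Returns line picks one log-probability per row: row i, column targets[i]. The sketch below shows two equivalent ways to do that selection with the allowed primitives (advanced indexing and .gather). The log_probs values are made-up numbers standing in for log-softmax output, chosen only to make the selection easy to follow; they are not real log-softmax rows.

```python
import torch

# Toy batch: N = 3 samples, C = 4 classes (illustrative values, not real log-softmax rows).
log_probs = torch.tensor([[-1.0, -2.0, -3.0, -4.0],
                          [-0.5, -1.5, -2.5, -3.5],
                          [-0.1, -0.2, -0.3, -0.4]])
targets = torch.tensor([2, 0, 3])  # int64 class indices in [0, C)

N = log_probs.shape[0]

# Advanced indexing: element (i, targets[i]) for each row i -> shape (N,)
picked_idx = log_probs[torch.arange(N), targets]

# Same selection with gather: the index must have as many dims as the input,
# so targets becomes (N, 1); [:, 0] then drops the trailing dimension.
picked_gather = log_probs.gather(1, targets[:, None])[:, 0]

print(picked_idx)
print(torch.equal(picked_idx, picked_gather))  # True

# Negate and average over the batch -> 0-dim (scalar) tensor
loss = -picked_idx.mean()
```

The gather form avoids building an arange index, which is why it is a common choice when only the listed primitives are available.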
Numerical stability: compute log-softmax row-wise (max and sum taken over the class dimension C) via the subtract-max trick:
log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x))))
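A minimal row-wise sketch of this formula for an (N, C) tensor follows. log_softmax_rows is a hypothetical helper name used only for illustration. The example input shows why the max is subtracted: exp(1000) overflows float32 to inf, while the shifted form stays finite.

```python
import torch

def log_softmax_rows(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: row-wise log-softmax of an (N, C) tensor via subtract-max."""
    # Per-row maximum, kept as shape (N, 1) so it broadcasts against (N, C).
    m = x.max(dim=1, keepdim=True).values
    # After the shift the largest entry in each row is 0, so exp() cannot overflow.
    shifted = x - m
    # log of the row-wise sum of exponentials, also kept as (N, 1).
    log_sum_exp = shifted.exp().sum(dim=1, keepdim=True).log()
    return shifted - log_sum_exp

# Large logits: the naive exp() overflows float32.
x = torch.tensor([[1000.0, 1001.0, 1002.0]])
print(x.exp())              # tensor([[inf, inf, inf]])
print(log_softmax_rows(x))  # approximately [[-2.4076, -1.4076, -0.4076]]
```

Subtracting the row max does not change the result mathematically (it cancels between the two terms), so the trick is pure numerical hygiene.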
This differs from the NumPy version of the problem, which took already-softmaxed probabilities and one-hot labels. The PyTorch convention takes raw logits and integer targets, exactly matching F.cross_entropy.
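The two conventions compute the same loss. The sketch below assumes the NumPy version used the usual formula, the batch mean of -sum(one_hot * log(probs)) over classes; it builds one-hot labels from the integer targets and checks that both forms agree. torch.zeros and index assignment are used only for this comparison, not as part of the allowed-primitives solution.

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])
N, C = logits.shape

# Row-wise log-probabilities from raw logits (subtract-max form).
shifted = logits - logits.max(dim=1, keepdim=True).values
log_probs = shifted - shifted.exp().sum(dim=1, keepdim=True).log()

# Integer-target form (this problem): one entry per row.
loss_int = -log_probs[torch.arange(N), targets].mean()

# One-hot form (assumed NumPy-style): the one-hot mask zeroes out every non-target class.
one_hot = torch.zeros(N, C)
one_hot[torch.arange(N), targets] = 1.0
loss_onehot = -(one_hot * log_probs).sum(dim=1).mean()

print(torch.allclose(loss_int, loss_onehot))  # True
```

Because the one-hot mask keeps a single entry per row, the integer-index form simply skips the multiply-by-zero work.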
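import torch

def cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Implement log-softmax (subtract-max trick) + NLL using only the allowed primitives.
    pass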
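One possible complete solution, written as a minimal sketch using only the listed primitives (.max(...).values, .exp(), .log(), .sum(), .gather(...), .mean(), indexing). It is not the site's reference solution. It assumes targets is an int64 tensor, since .gather requires an int64 index.

```python
import torch

def cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # logits: (N, C) raw scores; targets: (N,) int64 class indices in [0, C)
    # 1) Subtract the per-row max for stability; shape (N, 1) broadcasts over C.
    row_max = logits.max(dim=1, keepdim=True).values
    shifted = logits - row_max
    # 2) log-softmax = shifted - log(sum(exp(shifted))) over the class dimension.
    log_probs = shifted - shifted.exp().sum(dim=1, keepdim=True).log()
    # 3) NLL: log-probability of each row's target class.
    #    gather needs an (N, 1) index; [:, 0] drops the trailing dim back to (N,).
    picked = log_probs.gather(1, targets[:, None])[:, 0]
    # 4) Mean over the batch of the negated picks -> 0-dim tensor.
    return -picked.mean()

# Usage: spot-check against the built-in (fine in a test, forbidden in the solution).
torch.manual_seed(0)
logits = torch.randn(5, 3) * 10
targets = torch.randint(0, 3, (5,))
expected = torch.nn.functional.cross_entropy(logits, targets)
print(torch.allclose(cross_entropy(logits, targets), expected, atol=1e-5))
```

Using .gather rather than torch.arange indexing keeps the solution inside the allowed-primitives list; both select the same entries.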