TorchedUp

269. Backprop: Softmax + Cross-Entropy fused (PyTorch)

Medium

Implement the fused softmax-cross-entropy loss as a torch.autograd.Function. The point of fusing is that the gradient w.r.t. the logits collapses to (softmax(logits) - one_hot(target)) / N, where N is the number of target positions, so the dense C × C softmax Jacobian is never materialized.
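To make the "collapse" concrete, the following sketch compares the dense-Jacobian route with the fused expression for a single sample. The class count `C`, the target index `t`, and all variable names are illustrative assumptions, not part of the problem statement.

```python
import torch

torch.manual_seed(0)
C = 5                                      # assumed number of classes
x = torch.randn(C, dtype=torch.float64)    # logits for one sample
t = 2                                      # assumed target class

# Softmax by hand, shifted by the max for numerical stability.
e = torch.exp(x - x.max())
y = e / e.sum()

# Dense route: J[j, i] = d y_j / d x_i = y_j * (delta_ij - y_i), a C x C matrix.
J = torch.diag(y) - torch.outer(y, y)
# Gradient of -log(y_t) w.r.t. y is nonzero only at index t.
dL_dy = torch.zeros(C, dtype=torch.float64)
dL_dy[t] = -1.0 / y[t]
grad_dense = J.T @ dL_dy                   # chain rule, O(C^2) work and memory

# Fused route: softmax minus one-hot, O(C) work and memory.
grad_fused = y.clone()
grad_fused[t] -= 1.0

print(torch.allclose(grad_dense, grad_fused))
```

Both routes give the same vector; the fused one simply never builds `J`.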

The rule: you may NOT call F.cross_entropy, F.nll_loss, F.log_softmax, F.softmax, or nn.CrossEntropyLoss. Implement softmax + log + index gather + mean by hand, and write the analytic backward.

Forward: logits (..., C) (any leading dims), integer targets (...) → scalar loss = mean over all positions of -log(softmax(logits)[..., target]). Typical shapes are (B, C) for classification and (B, T, V) for token-level CE; your code should handle both.
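One way the forward pass might be written by hand is sketched below. The function name `manual_ce_forward` is hypothetical, and the max-shift before `exp` is a common stability trick that the problem statement does not explicitly require. Leading dimensions are flattened so that (B, C) and (B, T, V) inputs go through the same path.

```python
import torch

def manual_ce_forward(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean of -log(softmax(logits)[..., target]) over all target positions."""
    C = logits.shape[-1]
    x = logits.reshape(-1, C)              # (N, C) with N = logits.numel() / C
    t = targets.reshape(-1).long()         # (N,)
    # Subtract the per-row max so exp() cannot overflow.
    shifted = x - x.max(dim=-1, keepdim=True).values
    log_z = torch.log(torch.exp(shifted).sum(dim=-1, keepdim=True))
    log_probs = shifted - log_z            # log-softmax written out by hand
    picked = log_probs.gather(1, t.unsqueeze(1)).squeeze(1)
    return -picked.mean()

torch.manual_seed(0)
# Classification shape (B, C) and token-level shape (B, T, V).
loss_cls = manual_ce_forward(torch.randn(4, 10), torch.randint(0, 10, (4,)))
loss_tok = manual_ce_forward(torch.randn(2, 3, 7), torch.randint(0, 7, (2, 3)))
print(loss_cls.shape, loss_tok.shape)      # both are 0-dim scalars
```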

Backward: dL/dlogits = (softmax(logits) - one_hot(target)) / N where N is the total number of target positions (i.e. logits.numel() / C).
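The backward formula can be sketched as a standalone function, outside of autograd. The name `manual_ce_grad` is hypothetical; the one-hot subtraction is done by indexed in-place update rather than by building a one-hot tensor, which is a design choice of this sketch.

```python
import torch

def manual_ce_grad(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """(softmax(logits) - one_hot(targets)) / N, reshaped like logits."""
    C = logits.shape[-1]
    x = logits.reshape(-1, C)
    t = targets.reshape(-1).long()
    N = x.shape[0]                          # total number of target positions
    e = torch.exp(x - x.max(dim=-1, keepdim=True).values)
    y = e / e.sum(dim=-1, keepdim=True)     # softmax by hand
    grad = y.clone()
    # Subtract 1 at each row's target column (equivalent to subtracting one_hot).
    grad[torch.arange(N, device=x.device), t] -= 1.0
    return (grad / N).reshape(logits.shape)

torch.manual_seed(0)
g = manual_ce_grad(torch.randn(2, 3, 7), torch.randint(0, 7, (2, 3)))
print(g.shape)
```

A quick sanity property: every length-C slice of the result sums to (approximately) zero, because the softmax row and the one-hot row each sum to one.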

The driver sxce_run(mode, logits, targets) dispatches 'loss' | 'grad_logits' | 'gradcheck'.
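A minimal sketch of how the pieces might fit together follows. The class name `FusedSoftmaxCE` is hypothetical, and the return value of each mode (scalar loss, gradient tensor, boolean from `torch.autograd.gradcheck`) is an assumption, since the problem statement names the modes but not their outputs.

```python
import torch

class FusedSoftmaxCE(torch.autograd.Function):  # hypothetical name
    @staticmethod
    def forward(ctx, logits, targets):
        C = logits.shape[-1]
        x = logits.reshape(-1, C)
        t = targets.reshape(-1).long()
        shifted = x - x.max(dim=-1, keepdim=True).values
        log_probs = shifted - torch.log(torch.exp(shifted).sum(dim=-1, keepdim=True))
        ctx.save_for_backward(log_probs, t)
        ctx.logits_shape = logits.shape
        return -log_probs.gather(1, t.unsqueeze(1)).mean()

    @staticmethod
    def backward(ctx, grad_out):
        log_probs, t = ctx.saved_tensors
        N = log_probs.shape[0]
        grad = torch.exp(log_probs)                      # softmax
        grad[torch.arange(N, device=grad.device), t] -= 1.0
        grad = grad * (grad_out / N)                     # scale by upstream grad
        return grad.reshape(ctx.logits_shape), None      # no grad for targets

def sxce_run(mode, logits, targets):
    # Assumed outputs per mode; the problem statement does not specify them.
    if mode == "loss":
        return FusedSoftmaxCE.apply(logits, targets)
    if mode == "grad_logits":
        x = logits.detach().clone().requires_grad_(True)
        FusedSoftmaxCE.apply(x, targets).backward()
        return x.grad
    if mode == "gradcheck":
        # gradcheck compares against finite differences, so use float64.
        x = logits.detach().clone().double().requires_grad_(True)
        return torch.autograd.gradcheck(
            lambda z: FusedSoftmaxCE.apply(z, targets), (x,)
        )
    raise ValueError(f"unknown mode: {mode!r}")
```

Saving the log-probabilities in `forward` lets `backward` recover the softmax with a single `exp`, at the cost of holding an (N, C) tensor between the two passes.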

Math

$$L = -\frac{1}{B}\sum_{b} \log y_{b,t_b}, \qquad \frac{\partial L}{\partial x_{b,i}} = \frac{1}{B}\left(y_{b,i} - \mathbb{1}_{i=t_b}\right)$$
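For reference, the gradient follows in a few steps from the softmax definition. This is a sketch of the standard derivation, with x the logits, y the softmax output, and t_b the target index of position b; for inputs with extra leading dimensions, B plays the role of N.

```latex
\begin{aligned}
y_{b,i} &= \frac{e^{x_{b,i}}}{\sum_j e^{x_{b,j}}},
\qquad
\log y_{b,t_b} = x_{b,t_b} - \log \sum_j e^{x_{b,j}} \\
\frac{\partial \log y_{b,t_b}}{\partial x_{b,i}}
  &= \mathbb{1}_{i=t_b} - y_{b,i} \\
\frac{\partial L}{\partial x_{b,i}}
  &= -\frac{1}{B}\left(\mathbb{1}_{i=t_b} - y_{b,i}\right)
   = \frac{1}{B}\left(y_{b,i} - \mathbb{1}_{i=t_b}\right)
\end{aligned}
```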

Related problems

  • Backprop: Softmax + Cross-Entropy (fused) (Medium, NumPy)


import torch

def sxce_run(mode, logits, targets):
    # mode is one of 'loss', 'grad_logits', or 'gradcheck'
    pass
