Implement the fused softmax-cross-entropy loss as a torch.autograd.Function. The whole point of fusing: the gradient w.r.t. logits collapses to (softmax(logits) - one_hot(target)) / N, dodging the dense C×C softmax Jacobian entirely.
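The collapse follows from writing the per-row loss directly in terms of the logits. For one row with logits z, target class t and p = softmax(z), the derivation below shows why no C×C Jacobian ever needs to be formed (the symbols z, p, t, N are introduced here for illustration):

```latex
\ell = -\log p_t = -z_t + \log \sum_{j=1}^{C} e^{z_j},
\qquad p_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}

\frac{\partial \ell}{\partial z_i}
  = -\mathbb{1}[i = t] + \frac{e^{z_i}}{\sum_{j} e^{z_j}}
  = p_i - \mathbb{1}[i = t]

% Same result via the chain rule through the softmax Jacobian
% \partial p_k / \partial z_i = p_k (\mathbb{1}[k = i] - p_i):
\frac{\partial \ell}{\partial z_i}
  = -\frac{1}{p_t}\, p_t \left(\mathbb{1}[t = i] - p_i\right)
  = p_i - \mathbb{1}[i = t]

% Averaging over N positions scales every row by 1/N:
L = \frac{1}{N} \sum_{n=1}^{N} \ell_n
\;\Longrightarrow\;
\frac{\partial L}{\partial z_{n,i}} = \frac{p_{n,i} - \mathbb{1}[i = t_n]}{N}
```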
The rule: you may NOT call F.cross_entropy, F.nll_loss, F.log_softmax, F.softmax, or nn.CrossEntropyLoss. Implement softmax + log + index gather + mean by hand, and write the analytic backward.
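Doing softmax + log by hand raises one practical issue: exp() of large logits overflows. A common remedy, sketched below, is to subtract the row max first (the shift cancels inside softmax). The helper names `log_softmax_by_hand` and `softmax_by_hand` are hypothetical, not part of the problem's API.

```python
import torch


def log_softmax_by_hand(logits: torch.Tensor) -> torch.Tensor:
    """log softmax over the last dim without F.log_softmax / F.softmax.

    Subtracting the row max keeps every exp() argument <= 0, so exp() stays
    in (0, 1] and cannot overflow; the shift cancels out of softmax.
    """
    m = logits.amax(dim=-1, keepdim=True)            # (..., 1) row max
    shifted = logits - m                              # max entry becomes 0
    lse = torch.log(torch.exp(shifted).sum(dim=-1, keepdim=True))  # log-sum-exp of shifted row
    return shifted - lse                              # log p = (z - m) - log sum exp(z - m)


def softmax_by_hand(logits: torch.Tensor) -> torch.Tensor:
    # softmax as exp of the stable log-softmax
    return torch.exp(log_softmax_by_hand(logits))


# Naive log(exp(z) / sum exp(z)) breaks for large logits; the shifted form does not.
z = torch.tensor([[1000.0, 1001.0, 1002.0]])
naive = torch.log(torch.exp(z) / torch.exp(z).sum(dim=-1, keepdim=True))
print(naive)                   # all nan: exp(1000.) overflows to inf, and inf / inf = nan
print(log_softmax_by_hand(z))  # finite values
```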
Forward: logits (..., C) (any leading dims), integer targets (...) → scalar loss = mean over all positions of -log(softmax(logits)[..., target]). Typical shapes are (B, C) for classification and (B, T, V) for token-level CE; your code should handle both.
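One way to cover both (B, C) and (B, T, V) with a single code path is to flatten all leading dims into N rows, then gather the target logit per row. A minimal sketch of the forward pass, assuming a hypothetical helper name `sxce_forward`; it uses the identity -log p_t = logsumexp(z) - z_t:

```python
import torch


def sxce_forward(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean over all positions of -log softmax(logits)[..., target].

    logits: (..., C) float; targets: (...) integer class ids.
    """
    C = logits.shape[-1]
    z = logits.reshape(-1, C)                  # (N, C), N = logits.numel() // C
    t = targets.reshape(-1).long()             # (N,)
    m = z.amax(dim=-1, keepdim=True)           # stabilizing shift, (N, 1)
    lse = m.squeeze(-1) + torch.log(torch.exp(z - m).sum(dim=-1))  # (N,) log sum exp(z)
    z_t = z.gather(1, t.unsqueeze(1)).squeeze(1)                    # (N,) logit of the target class
    return (lse - z_t).mean()                  # mean of -log p_t


# Classification (B, C) and token-level (B, T, V) go through the same path.
loss_cls = sxce_forward(torch.randn(4, 10), torch.randint(0, 10, (4,)))
loss_tok = sxce_forward(torch.randn(2, 3, 5), torch.randint(0, 5, (2, 3)))
```

With all-zero logits every class has probability 1/C, so the loss should equal log C, which makes a handy sanity check.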
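Backward: dL/dlogits = (softmax(logits) - one_hot(target)) / N, where N is the total number of target positions (i.e. logits.numel() // C). Inside backward, multiply this by the incoming grad_output (a scalar, since the loss is a scalar) and return None as the gradient for the integer targets.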
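The analytic gradient can be computed without any one-hot tensor: take the softmax, subtract 1 at each row's target index, and divide by N. A sketch under the assumption of a hypothetical helper `sxce_grad_logits`, returning a tensor shaped like logits:

```python
import torch


def sxce_grad_logits(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Analytic dL/dlogits = (softmax(logits) - one_hot(target)) / N."""
    C = logits.shape[-1]
    z = logits.reshape(-1, C)                     # (N, C)
    t = targets.reshape(-1).long()                # (N,)
    N = z.shape[0]                                # = logits.numel() // C
    p = torch.exp(z - z.amax(dim=-1, keepdim=True))
    p = p / p.sum(dim=-1, keepdim=True)           # softmax by hand, (N, C), a fresh tensor
    p[torch.arange(N), t] -= 1.0                  # subtract one_hot(target) in place
    return (p / N).reshape(logits.shape)
```

Because each softmax row sums to 1 and exactly one entry per row loses 1, every row of the gradient sums to 0, which is a cheap invariant to assert alongside a comparison with autograd.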
The driver sxce_run(mode, logits, targets) dispatches 'loss' | 'grad_logits' | 'gradcheck'.
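Putting the pieces together, a possible shape for the autograd.Function and the driver is sketched below. The class name `SoftmaxCrossEntropy` is hypothetical, and the return types per mode are assumptions since the problem does not pin them down: 'loss' returns a scalar tensor, 'grad_logits' returns a tensor shaped like logits, and 'gradcheck' returns the bool from torch.autograd.gradcheck run on a float64 copy.

```python
import torch


class SoftmaxCrossEntropy(torch.autograd.Function):
    """Fused softmax + CE; backward uses (softmax - one_hot) / N, no Jacobian."""

    @staticmethod
    def forward(ctx, logits, targets):
        C = logits.shape[-1]
        z = logits.reshape(-1, C)                            # (N, C)
        t = targets.reshape(-1).long()                       # (N,)
        m = z.amax(dim=-1, keepdim=True)
        e = torch.exp(z - m)
        s = e.sum(dim=-1, keepdim=True)
        log_p = (z - m) - torch.log(s)                       # log softmax by hand
        nll = -log_p.gather(1, t.unsqueeze(1)).squeeze(1)    # (N,) per-position loss
        ctx.save_for_backward(e / s, t)                      # softmax probs, flat targets
        ctx.logits_shape = logits.shape
        return nll.mean()

    @staticmethod
    def backward(ctx, grad_output):
        p, t = ctx.saved_tensors
        N = p.shape[0]
        grad = p.clone()                                     # don't mutate the saved tensor
        grad[torch.arange(N, device=p.device), t] -= 1.0     # softmax - one_hot
        grad = grad * (grad_output / N)                      # chain rule with upstream scalar
        return grad.reshape(ctx.logits_shape), None          # no gradient for integer targets


def sxce_run(mode, logits, targets):
    # ASSUMPTION: the return type of each mode is a guess; adapt to the grader.
    if mode == "loss":
        return SoftmaxCrossEntropy.apply(logits, targets)
    if mode == "grad_logits":
        x = logits.detach().clone().requires_grad_(True)
        SoftmaxCrossEntropy.apply(x, targets).backward()
        return x.grad
    if mode == "gradcheck":
        x = logits.detach().double().requires_grad_(True)  # gradcheck wants float64
        return torch.autograd.gradcheck(SoftmaxCrossEntropy.apply, (x, targets))
    raise ValueError(f"unknown mode: {mode!r}")
```

Saving the probabilities rather than the logits means backward does no exp() at all; the trade-off is holding an (N, C) tensor until backward runs, which for large vocabularies may matter.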
import numpy as np
import torch

def sxce_run(mode, logits, targets):
    # mode: 'loss' | 'grad_logits' | 'gradcheck'
    pass