Implement softmax in PyTorch using primitive tensor ops only.
Signature: def softmax(x: torch.Tensor) -> torch.Tensor
The rule: you may NOT call F.softmax, torch.softmax, or nn.Softmax. The point is to write the algorithm yourself — calling the built-in defeats the lesson. We verify your output matches F.softmax(x, dim=-1) so you can compare while you debug.
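While debugging, a local comparison against the built-in can be sketched as below. The helper name `matches_reference` and the `atol` tolerance are assumptions made for illustration; the actual harness and its tolerance are not described here. `F.softmax` appears only as the reference and must not be used inside your own solution.

```python
import torch
import torch.nn.functional as F


def matches_reference(candidate, x: torch.Tensor, atol: float = 1e-6) -> bool:
    """Return True if candidate(x) agrees with F.softmax(x, dim=-1).

    `candidate` is any callable taking and returning a torch.Tensor.
    The name and the tolerance are hypothetical, not the site's harness.
    """
    expected = F.softmax(x, dim=-1)  # reference only
    actual = candidate(x)
    # Check the shape first so a wrongly reduced output fails cleanly.
    return actual.shape == expected.shape and torch.allclose(
        actual, expected, atol=atol
    )
```

Calling `matches_reference(softmax, torch.randn(4, 5))` with your own `softmax` returns a plain bool, which makes it easy to try several input shapes in a loop.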
Allowed primitives: .exp(), .max(), .sum(), basic arithmetic, broadcasting, indexing.
Requirements:
- Normalize along the last dimension (dim=-1)
- Numerically stable ([1000., 1001., 1002.] should not produce nan)
- Return a torch.Tensor (the harness converts to list automatically)
PyTorch idioms you'll learn here vs the NumPy version:
- .max(dim=-1, keepdim=True) returns a named tuple (values, indices); access .values to get the tensor. NumPy's .max(axis=-1, keepdims=True) returns the array directly. This is the single most common trip-up when porting from NumPy.
- keepdim, not keepdims; dim, not axis. PyTorch does not silently ignore unknown kwargs: an unrecognized keyword raises a TypeError. Recent releases accept axis and keepdims as NumPy-compatibility aliases on many ops, but the native spelling is the one to learn and the one that works everywhere.
Hint: softmax is shift-invariant. Subtract the max along dim=-1 before exponentiating to avoid overflow.
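The hint and the two idioms above can be put together in a minimal sketch. This is one possible implementation under the stated rules, not the site's reference solution; the variable names `row_max`, `shifted`, `naive`, and `stable` are chosen here for illustration. It also shows what goes wrong without the shift.

```python
import torch
import torch.nn.functional as F  # used only to check the result, never inside softmax


def softmax(x: torch.Tensor) -> torch.Tensor:
    # .max(dim=...) returns (values, indices); .values is the tensor we need.
    row_max = x.max(dim=-1, keepdim=True).values
    # Subtracting the row max leaves the result unchanged (shift invariance)
    # but makes every exponent <= 0, so exp() cannot overflow.
    shifted = (x - row_max).exp()
    # keepdim=True keeps the reduced axis so the division broadcasts per row.
    return shifted / shifted.sum(dim=-1, keepdim=True)


x = torch.tensor([[1000.0, 1001.0, 1002.0], [0.0, 1.0, 2.0]])

# Naive version: exp(1000.) overflows to inf in float32, and inf / inf is nan.
naive = x.exp() / x.exp().sum(dim=-1, keepdim=True)

stable = softmax(x)

print(torch.isnan(naive[0]).all().item())               # first row is all nan
print(torch.allclose(stable, F.softmax(x, dim=-1)))     # agrees with the built-in
```

Both rows of `x` differ only by a constant offset of 1000, so the stable version gives them the same probabilities, which is shift invariance in action.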
Math
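For reference, the standard softmax definition and the shift-invariance identity behind the hint can be written as follows. The symbols $x_i$, $n$, and $c$ are notation chosen here: $x_1, \dots, x_n$ are the entries along the last dimension and $c$ is any constant.

```latex
\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}

\operatorname{softmax}(x - c)_i
  = \frac{e^{x_i - c}}{\sum_{j=1}^{n} e^{x_j - c}}
  = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_{j=1}^{n} e^{x_j}}
  = \operatorname{softmax}(x)_i
```

Choosing $c = \max_j x_j$ makes every exponent $x_i - c \le 0$, so each term lies in $(0, 1]$ and the denominator is at least 1. That rules out both overflow in the numerator and division by zero.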
import torch

def softmax(x: torch.Tensor) -> torch.Tensor:
    # Your implementation here, using only the allowed primitives.
    pass