TorchedUp

232. Softmax (PyTorch)

Easy

Implement softmax in PyTorch using primitive tensor ops only.

Signature: def softmax(x: torch.Tensor) -> torch.Tensor

The rule: you may NOT call F.softmax, torch.softmax, or nn.Softmax. The point is to write the algorithm yourself — calling the built-in defeats the lesson. We verify your output matches F.softmax(x, dim=-1) so you can compare while you debug.

Allowed primitives: .exp(), .max(), .sum(), basic arithmetic, broadcasting, indexing.

Requirements:

  • Operate along the last dimension (dim=-1)
  • Output is a valid probability distribution (non-negative, sums to 1 along dim=-1)
  • Numerically stable for large inputs (e.g. [1000., 1001., 1002.] should not produce nan)
  • Returns a torch.Tensor (the harness converts to list automatically)
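
To make the stability requirement concrete, here is a minimal sketch of what goes wrong without it. The helper name naive_softmax is hypothetical and is not part of the problem's harness; it assumes the default float32 dtype.

```python
import torch

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper for illustration: exponentiates without shifting.
    e = x.exp()  # exp(1000.) overflows to inf in float32
    return e / e.sum(dim=-1, keepdim=True)

x = torch.tensor([1000.0, 1001.0, 1002.0])
naive = naive_softmax(x)

# inf / inf is nan, so the result is not a valid probability distribution.
has_nan = bool(torch.isnan(naive).any())
```

Every entry of naive is nan here, which is exactly the failure the third requirement rules out.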

PyTorch idioms you'll learn here vs the NumPy version:

  • .max(dim=-1, keepdim=True) returns a NamedTuple (values, indices) — access .values to get the tensor. NumPy's .max(axis=-1, keepdims=True) returns the array directly. This is the single most common trip-up porting from NumPy.
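  • keepdim, not keepdims; dim, not axis. These are the documented PyTorch spellings. PyTorch does not silently ignore unknown kwargs: an unrecognized keyword raises a TypeError, and some ops accept the NumPy spellings as compatibility aliases. Stick to the PyTorch names so the behavior is predictable.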
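
The first idiom above can be checked with a small sketch. The tensor values and variable names are illustrative assumptions, not taken from the problem's tests.

```python
import torch

x = torch.tensor([[1.0, 3.0, 2.0],
                  [5.0, 4.0, 6.0]])

# With dim given, .max() returns a named tuple, not a tensor.
result = x.max(dim=-1, keepdim=True)
row_max = result.values      # tensor of shape (2, 1) holding the row maxima
row_argmax = result.indices  # tensor of shape (2, 1) holding their positions

# Plain tuple unpacking works as well.
values, indices = x.max(dim=-1, keepdim=True)

# Without dim, .max() returns a 0-dim tensor directly.
global_max = x.max()
```

Because keepdim=True keeps the reduced dimension as size 1, row_max broadcasts cleanly against x in an expression such as x - row_max.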

Hint: softmax is shift-invariant. Subtract the max along dim=-1 before exponentiating to avoid overflow.
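
A short derivation of why the shift is safe, using c for an arbitrary constant (a symbol introduced here for illustration):

$$
\frac{e^{x_i - c}}{\sum_j e^{x_j - c}} = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}} = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

Choosing c = max(x) makes every exponent at most 0, so each term lies in (0, 1] and the denominator is at least 1.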

Math

$$
\mathrm{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}
$$
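
One way the formula might translate into the allowed primitives is sketched below. It is not the site's reference solution, and the example input is the one from the requirements.

```python
import torch

def softmax(x: torch.Tensor) -> torch.Tensor:
    # Max along the last dim; .values extracts the tensor from the named tuple.
    x_max = x.max(dim=-1, keepdim=True).values
    # Shifted values are <= 0, so exp() stays in (0, 1] and cannot overflow.
    e = (x - x_max).exp()
    # keepdim=True lets the row sums broadcast against e.
    return e / e.sum(dim=-1, keepdim=True)

probs = softmax(torch.tensor([1000.0, 1001.0, 1002.0]))
# approximately [0.0900, 0.2447, 0.6652], with no nan

batch = softmax(torch.zeros(2, 4))  # each row is uniform: 0.25 everywhere
```

The same function works on batched input because every reduction runs along dim=-1 and keeps that dimension for broadcasting.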

Related problems

  • Numerically Stable Softmax (Easy, NumPy)

NumPy

import numpy as np

def softmax(x):
    # Starter stub: implement softmax along the last axis here.
    pass
