TorchedUp

Numerically Stable Softmax (Easy)

Implement softmax in a way that does not overflow on large inputs.

Signature: def softmax(x: np.ndarray) -> np.ndarray

Example:

  • Input: [1.0, 2.0, 3.0]
  • Output: [0.0900, 0.2447, 0.6652]

Requirements:

  • Output is a valid probability distribution (non-negative, sums to 1)
  • Must produce correct, finite results for inputs containing very large values (e.g., [1000.0, 1001.0, 1002.0]). A naive exp(x) / exp(x).sum() returns nan here, because exp(1000.0) overflows to inf and inf / inf is nan
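
To make the failure mode concrete, here is a small sketch of the naive formula on both example inputs. The name naive_softmax is hypothetical and used only for illustration; the np.errstate context just silences NumPy's overflow and invalid-value warnings and does not change the result.

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    # Hypothetical helper: the textbook formula with no shift applied.
    e = np.exp(x)
    return e / e.sum()

with np.errstate(over="ignore", invalid="ignore"):
    # Small inputs: exp stays finite, so the result is a valid distribution.
    small = naive_softmax(np.array([1.0, 2.0, 3.0]))
    # Large inputs: exp overflows to inf, and inf / inf yields nan everywhere.
    large = naive_softmax(np.array([1000.0, 1001.0, 1002.0]))
```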

Hint: softmax(x) is shift-invariant: softmax(x) = softmax(x - c) for any constant c. Choose c so exp cannot overflow.
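
One common choice, sketched below, is c = max(x): every shifted entry is then at most 0, so each exp value lies in (0, 1] and the denominator is at least 1. This is a minimal sketch under the assumption that the whole array is a single distribution (as in the 1D examples), not necessarily the reference solution; the cast to float64 is also an assumption, added so integer inputs work.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Assumption: treat the entire array as one distribution (1D use case).
    x = np.asarray(x, dtype=np.float64)
    # Shift by the maximum so the largest exponent is exactly 0.
    shifted = x - np.max(x)
    # exp of a non-positive number cannot overflow; tiny values may underflow to 0.
    e = np.exp(shifted)
    # The maximum entry contributes exp(0) = 1, so the sum is never zero.
    return e / e.sum()
```

With this version, [1.0, 2.0, 3.0] and [1000.0, 1001.0, 1002.0] both shift to [-2.0, -1.0, 0.0], so they produce the same output, approximately [0.0900, 0.2447, 0.6652].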

Test cases:

  • basic 1D
  • uniform
  • large values stability
  • output sums to 1
  • output is non-negative
  • shift-invariant: softmax(x) == softmax(x + c)
  • preserves argmax
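
The test names above suggest property checks along the following lines. These are not the site's actual tests: the sample inputs, the tolerances (NumPy defaults), and the compact softmax repeated here so the block runs on its own are all assumptions.

```python
import numpy as np

def softmax(x):
    # Compact max-shifted softmax, included only to keep this block self-contained.
    e = np.exp(x - np.max(x))
    return e / e.sum()

x = np.array([0.5, -1.2, 3.0, 2.0])  # arbitrary sample input
p = softmax(x)

assert np.isclose(p.sum(), 1.0)                 # output sums to 1
assert np.all(p >= 0)                           # output is non-negative
assert np.allclose(p, softmax(x + 100.0))       # shift-invariant
assert np.argmax(p) == np.argmax(x)             # preserves argmax
assert np.allclose(softmax(np.zeros(4)), 0.25)  # uniform input gives uniform output

big = softmax(np.array([1000.0, 1001.0, 1002.0]))
assert np.all(np.isfinite(big))                 # large values stay finite
```

Exact equality (==) in the shift-invariance check is replaced by np.allclose here, since adding c can change the floating-point rounding of the inputs.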