TorchedUp

1. Numerically Stable Softmax

Easy

Implement softmax in a way that does not overflow on large inputs.

Signature: def softmax(x: np.ndarray) -> np.ndarray

Example:

  • Input: [1.0, 2.0, 3.0]
  • Output: [0.0900, 0.2447, 0.6652]

Requirements:

  • Output is a valid probability distribution (non-negative, sums to 1)
  • Must produce correct, finite results for inputs containing very large values (e.g., [1000.0, 1001.0, 1002.0]). A naive exp(x) / exp(x).sum() returns nan here, because exp overflows to inf and inf / inf is nan
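To see the failure mode the second requirement describes, here is a minimal sketch of the naive version. The name `naive_softmax` is hypothetical and used only for illustration; it is not part of the required signature.

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    # Direct translation of the definition: exp(x) overflows to inf for large x.
    e = np.exp(x)
    return e / e.sum()

x = np.array([1000.0, 1001.0, 1002.0])

# Silence the overflow / invalid-value warnings that NumPy would otherwise emit here.
with np.errstate(over="ignore", invalid="ignore"):
    naive_out = naive_softmax(x)

print(naive_out)  # every entry is nan, since inf / inf is undefined
```

On small inputs such as [1.0, 2.0, 3.0] the same function gives the expected answer, which is why the bug tends to show up only on large logits.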

Hint: softmax(x) is shift-invariant: softmax(x) = softmax(x - c) for any constant c. Choose c so exp cannot overflow.
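The shift-invariance in the hint follows in one line, because the factor e^{-c} cancels between numerator and denominator. A short derivation, written with the component index i made explicit:

```latex
\mathrm{softmax}(x - c)_i
  = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
  = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}}
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \mathrm{softmax}(x)_i
```

With c = max(x), every exponent x_i - c is at most 0, so each term lies in (0, 1] and cannot overflow. The denominator is at least 1, because the largest element contributes e^0 = 1.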

Math

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$
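The formula translates almost directly into NumPy. The following is one possible sketch, not necessarily the reference solution. It assumes softmax is taken along the last axis, which matches the 3D batch case in the test list; `keepdims=True` is what makes the subtraction and division broadcast per row.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Assumption: normalize along the last axis (works for 1D inputs and batches alike).
    # Subtracting the per-row max makes every exponent <= 0, so exp cannot overflow.
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    # At least one term per row equals exp(0) = 1, so the denominator is never zero.
    return e / e.sum(axis=-1, keepdims=True)

small = softmax(np.array([1.0, 2.0, 3.0]))
large = softmax(np.array([1000.0, 1001.0, 1002.0]))
print(small)
print(large)
```

Both calls should print the same distribution up to floating-point rounding, since the second input is the first shifted by a constant.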

Related problems

  • Softmax (PyTorch) (Easy, PyTorch)

Test Results

  • basic 1D
  • uniform
  • large values (overflow check)
  • output sums to 1
  • output is non-negative
  • shift-invariant: softmax(x) == softmax(x + c)
  • preserves argmax
  • 3D batch (B=2, T=2, V=4): softmax along the last axis for each (B, T) row
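The site's actual test code is not shown, so the following is only a guess at what checks with these names might look like. The helper `run_checks`, the random inputs, the shift of 100.0, and the tolerances are all assumptions. The `softmax` defined here is a stand-in so the sketch runs on its own.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Stand-in implementation (max-shift along the last axis) so the checks can run.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def run_checks(fn) -> list:
    """Hypothetical helper: return the names of the checks that `fn` passes."""
    rng = np.random.default_rng(0)
    x = rng.normal(size=5)
    batch = rng.normal(size=(2, 2, 4))  # (B, T, V)
    out = fn(batch)
    checks = {
        "basic 1D": np.allclose(fn(np.array([1.0, 2.0, 3.0])),
                                [0.0900, 0.2447, 0.6652], atol=1e-4),
        "uniform": np.allclose(fn(np.zeros(4)), 0.25),
        "large values": np.isfinite(fn(np.array([1000.0, 1001.0, 1002.0]))).all(),
        "sums to 1": np.isclose(fn(x).sum(), 1.0),
        "non-negative": (fn(x) >= 0).all(),
        "shift-invariant": np.allclose(fn(x), fn(x + 100.0)),
        "preserves argmax": fn(x).argmax() == x.argmax(),
        # Each (B, T) row must be normalized independently of the others.
        "3D batch": out.shape == batch.shape
                    and np.allclose(out.sum(axis=-1), 1.0)
                    and np.allclose(out[0, 1], fn(batch[0, 1])),
    }
    return [name for name, ok in checks.items() if bool(ok)]

passed = run_checks(softmax)
print(passed)
```

A version that normalizes over the whole array instead of the last axis would pass the 1D checks and fail only the 3D batch check.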