Implement softmax in a way that does not overflow on large inputs.
Signature: def softmax(x: np.ndarray) -> np.ndarray
Example:
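softmax([1.0, 2.0, 3.0]) → [0.0900, 0.2447, 0.6652]
Requirements:
Must not overflow on large inputs such as softmax([1000.0, 1001.0, 1002.0]). A naive exp(x) / exp(x).sum() returns nan here, because exp(1000.0) overflows to inf and inf / inf is nan.
Hint: softmax(x) is shift-invariant: softmax(x) = softmax(x - c) for any constant c. Choose c so exp cannot overflow.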
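A minimal sketch of the hint, assuming c is chosen as the maximum of x. This is the usual choice; the problem only asks for some c that keeps exp from overflowing. The shift cancels because exp(x_i - c) / Σ_j exp(x_j - c) = e^(-c) exp(x_i) / (e^(-c) Σ_j exp(x_j)). With c = max(x), every exponent is at most 0, so each term is at most 1. The largest term is exactly exp(0) = 1, which also keeps the denominator at least 1. For arrays with more than one dimension, the sketch assumes softmax is taken along the last axis. The problem statement does not specify this.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    # Accept array-likes as well as ndarrays; dtype is preserved.
    x = np.asarray(x)
    # Subtract the maximum along the last axis (assumption: for N-D input,
    # softmax is computed over the last axis). After the shift every
    # exponent is <= 0, so np.exp cannot overflow.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    # The max element contributes exp(0) = 1, so the denominator is >= 1
    # and the division is always safe.
    return exps / np.sum(exps, axis=-1, keepdims=True)


if __name__ == "__main__":
    print(np.round(softmax(np.array([1.0, 2.0, 3.0])), 4))
    print(np.round(softmax(np.array([1000.0, 1001.0, 1002.0])), 4))
```

Under this choice of c, softmax([1000.0, 1001.0, 1002.0]) gives the same result as softmax([1.0, 2.0, 3.0]), because the two inputs differ only by a constant shift of 999.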