Compute the log-sum-exp of an array along its last axis in a numerically stable way.
Signature: def logsumexp(x: np.ndarray) -> np.ndarray
The mathematical identity is:
LSE(x) = log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x))))
Both forms are equivalent in real arithmetic, but the right-hand "shifted" form is the one numerical libraries use in practice. Why? Consider x = [1000, 1001, 1002]:
Naive: exp(1000) overflows to inf in float64, so log(inf + inf + inf) = inf. Wrong.
Shifted: subtract max = 1002 first to get exp([-2, -1, 0]) = [0.135, 0.368, 1.0]; sum = 1.503, log(1.503) ≈ 0.4076, plus 1002 = 1002.4076. Correct.
The shift cancels exactly in real math (you add max back at the end), but in floating point it pulls the largest exponent down to 0, where exp is well-behaved.
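The comparison for x = [1000, 1001, 1002] can be reproduced with a few lines of NumPy. This is only an illustrative sketch: the variable names (`naive`, `shifted`, `m`) are chosen here and are not part of the problem, and `np.errstate` is used just to silence the expected overflow warning.

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

# Naive form: exp(1000) exceeds the float64 range and becomes inf,
# so the sum and its log are inf as well.
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(x)))

# Shifted form: subtract the max first so the largest exponent is 0,
# then add the max back after taking the log.
m = np.max(x)
shifted = m + np.log(np.sum(np.exp(x - m)))

print(naive)    # inf
print(shifted)  # close to 1002.4076
```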
Reduce along the last axis only. Input may be 1D or 2D (or higher). Output has one fewer dimension than input.
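One way the last-axis reduction might be written is sketched below. It is a sketch under assumptions, not the reference solution: the name `logsumexp_sketch` is hypothetical, the input is assumed to contain finite values (a row consisting entirely of -inf would need separate handling), and the input is cast to float64.

```python
import numpy as np

def logsumexp_sketch(x: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: stable log-sum-exp over the last axis."""
    x = np.asarray(x, dtype=np.float64)
    # keepdims=True keeps a trailing axis of size 1 so the max
    # broadcasts against x in the subtraction below.
    m = np.max(x, axis=-1, keepdims=True)
    # After the shift every exponent is <= 0, so exp cannot overflow.
    s = np.sum(np.exp(x - m), axis=-1)
    # Drop the kept axis from m so it matches the shape of s.
    return np.squeeze(m, axis=-1) + np.log(s)

batch = np.array([[1000.0, 1001.0, 1002.0],
                  [0.0, 0.0, 0.0]])
print(logsumexp_sketch(batch).shape)     # (2,)
print(logsumexp_sketch(batch[0]).shape)  # ()
```

Using `keepdims=True` for the max and squeezing afterwards keeps the same code path for 1D, 2D, and higher-dimensional inputs, and the output always has one fewer dimension than the input.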
import numpy as np

def logsumexp(x: np.ndarray) -> np.ndarray:
    # Your implementation here: reduce along the last axis.
    pass