TorchedUp

214. LogSumExp Trick

Easy

Compute the log-sum-exp of a vector along its last axis, in a numerically stable way.

Signature: def logsumexp(x: np.ndarray) -> np.ndarray

The mathematical identity is:

LSE(x) = log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x))))

Both forms are equivalent in real arithmetic, but the right-hand "shifted" form is the one numerically careful implementations use (scipy.special.logsumexp, for example, subtracts the maximum first). Why? Consider x = [1000, 1001, 1002]:

  • Naive: exp(1000) overflows to inf in float64. log(inf + inf + inf) = inf. Wrong.
  • Stable: subtract max = 1002 first, get exp([-2, -1, 0]) = [0.135, 0.368, 1.0], sum = 1.503, log(1.503) ≈ 0.4076, plus 1002 = 1002.4076. Correct.
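The two bullets above can be reproduced directly in NumPy. This is a minimal sketch for a 1D input only; the variable names are illustrative.

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

# Naive form: exp(1000) exceeds the float64 maximum (~1.8e308), so every term is inf.
with np.errstate(over="ignore"):  # silence the overflow RuntimeWarning
    naive = np.log(np.sum(np.exp(x)))

# Shifted form: subtract the max so the largest exponent becomes exp(0) = 1.
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))

print(naive)   # inf
print(stable)  # ≈ 1002.4076
```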

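The shift cancels exactly in real arithmetic (you add the max back at the end), but in floating point it moves the largest exponent to 0, so the biggest term is exp(0) = 1. Nothing can overflow, and because the sum is at least 1, the log never sees 0 even when every exp(x_i) on its own would underflow.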
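A quick sanity check of that claim, as a sketch with two hypothetical helpers (`naive_lse`, `shifted_lse`, not part of the problem's API): where the naive form is safe the two agree to rounding error, and on very negative inputs the naive form underflows to log(0) = -inf while the shifted form stays finite.

```python
import numpy as np

def naive_lse(x):  # hypothetical helper: the unshifted formula
    return np.log(np.sum(np.exp(x)))

def shifted_lse(x):  # hypothetical helper: the max-shifted formula
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

# Moderate values: both forms agree to within floating-point rounding.
x_small = np.array([0.5, -1.0, 2.0])
print(np.isclose(naive_lse(x_small), shifted_lse(x_small)))  # True

# Very negative values: every exp(x_i) underflows to 0, so the naive log sees 0.
x_neg = np.array([-1000.0, -1001.0, -1002.0])
with np.errstate(divide="ignore"):  # silence the log(0) RuntimeWarning
    print(naive_lse(x_neg))  # -inf
print(shifted_lse(x_neg))    # ≈ -999.5924
```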

Reduce along the last axis only. Input may be 1D or 2D (or higher). Output has one fewer dimension than input.
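One way to meet the "last axis, any rank" requirement is to take the max with `keepdims=True` so it broadcasts back against `x`, then drop that axis at the end. This is a sketch, not the site's reference solution. Two choices here are assumptions of this sketch: it casts input to float64, and it replaces a non-finite row max with 0 (the same guard SciPy uses) so a row of all -inf returns -inf instead of NaN.

```python
import numpy as np

def logsumexp(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-sum-exp over the last axis (sketch)."""
    x = np.asarray(x, dtype=float)  # assumption: compute in float64
    # Row-wise max, kept as a size-1 axis so it broadcasts against x.
    m = np.max(x, axis=-1, keepdims=True)
    # Shifting by an infinite max would give inf - inf = nan; shift by 0 instead.
    m = np.where(np.isfinite(m), m, 0.0)
    with np.errstate(divide="ignore"):
        # For finite rows every shifted exponent is <= 0, so exp() cannot overflow.
        # A row of all -inf sums to 0, and log(0) = -inf is the correct result.
        out = np.log(np.sum(np.exp(x - m), axis=-1))
    # Remove the kept axis so the output has one fewer dimension than x.
    return out + np.squeeze(m, axis=-1)
```

Usage on a 2D batch, one LSE per row:

```python
batch = np.array([[1000.0, 1001.0, 1002.0],
                  [0.0, 0.0, 0.0]])
print(logsumexp(batch))  # ≈ [1002.4076, 1.0986]  (second entry is log(3))
```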

Math

$$\mathrm{LSE}(x) = \log \sum_i e^{x_i} = m + \log \sum_i e^{x_i - m}, \qquad m = \max_i x_i$$
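The identity follows by factoring e^m out of the sum. Choosing m as the maximum also bounds the remaining log term: every x_i − m ≤ 0 with equality for at least one i, so for an n-element input the shifted sum lies between 1 and n.

```latex
\log \sum_{i=1}^{n} e^{x_i}
  = \log\Big( e^{m} \sum_{i=1}^{n} e^{x_i - m} \Big)
  = m + \log \sum_{i=1}^{n} e^{x_i - m},
\qquad
1 \le \sum_{i=1}^{n} e^{x_i - m} \le n
\;\Rightarrow\;
0 \le \log \sum_{i=1}^{n} e^{x_i - m} \le \log n
```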

NumPy

import numpy as np

 

def logsumexp(x: np.ndarray) -> np.ndarray:
    # Reduce along the last axis; the result has one fewer dimension than x.
    pass
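For checking a solution to the stub above, NumPy's `np.logaddexp` ufunc computes log(exp(a) + exp(b)) stably, and its `.reduce` method folds that pairwise along an axis. The sketch below wraps it in a hypothetical `reference_lse` helper; it is an independent reference, not the shifted-max method the problem asks for.

```python
import numpy as np

def reference_lse(x: np.ndarray) -> np.ndarray:
    # Hypothetical test helper: reduce np.logaddexp along the last axis.
    return np.logaddexp.reduce(np.asarray(x, dtype=float), axis=-1)

x = np.array([[1000.0, 1001.0, 1002.0],
              [-1000.0, -1001.0, -1002.0]])
print(reference_lse(x))  # ≈ [1002.4076, -999.5924]
```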
