log-sum-exp trick why not recursive

Posted 2019-02-13 00:29

Question:

I have been researching the log-sum-exp problem. I have a list of numbers stored as logarithms, which I would like to sum and store the result as a logarithm as well.

The naive algorithm is:

import math

def naive(listOfLogs):
    return math.log10(sum(10**x for x in listOfLogs))

Many websites, including "logsumexp implementation in C?" and http://machineintelligence.tumblr.com/post/4998477107/, recommend using:

def recommend(listOfLogs):
    maxLog = max(listOfLogs)
    return maxLog + math.log10(sum(10**(x-maxLog) for x in listOfLogs))

which is equivalent to:

def recommend(listOfLogs):
    maxLog = max(listOfLogs)
    return maxLog + naive((x-maxLog) for x in listOfLogs)

What I don't understand is: if the recommended algorithm is better, why shouldn't we call it recursively? Would that provide even more benefit?

def recursive(listOfLogs):
    maxLog = max(listOfLogs)
    return maxLog + recursive((x-maxLog) for x in listOfLogs)

While I'm asking, are there other tricks to make this calculation more numerically stable?

Answer 1:

Some background for others: when you're computing an expression of the following type directly

ln( exp(x_1) + exp(x_2) + ... )

you can run into two kinds of problems:

  • exp(x_i) can overflow (x_i is too big), resulting in numbers that you can't add together
  • exp(x_i) can underflow (x_i is too small), resulting in a bunch of zeroes

If all the values are big, or all are small, we can divide by some exp(const) and add const to the outside of the ln to get the same value. Thus if we can pick the right const, we can shift the values into some range to prevent overflow/underflow.
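A minimal Python sketch of this shift, assuming natural logarithms as in the expression above. The function names `direct` and `shifted` and the sample values are illustrative and not taken from the question:

```python
import math

def direct(xs):
    # Literal evaluation of ln(exp(x_1) + exp(x_2) + ...).
    return math.log(sum(math.exp(x) for x in xs))

def shifted(xs, const):
    # Divide every term by exp(const) inside the ln, then add const back outside.
    return const + math.log(sum(math.exp(x - const) for x in xs))

xs = [1000.0, 1001.0]

try:
    direct(xs)
    overflowed = False
except OverflowError:
    # math.exp(1000.0) does not fit in a double, so Python raises.
    overflowed = True

print(overflowed)
print(shifted(xs, 1000.0))  # two different shifts...
print(shifted(xs, 1001.0))  # ...agree up to rounding
```

Any const that brings the terms into a representable range gives the same answer up to rounding; the choice only matters for avoiding overflow and underflow.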

The OP's question is, why do we pick max(x_i) for this const instead of any other value? Why don't we recursively do this calculation, picking the max out of each subset and computing the logarithm repeatedly?

The answer: because it doesn't matter.

The reason? Let's say x_1 = 10 is big, and x_2 = -10 is small. (These numbers aren't even very large in magnitude, right?) The expression

ln( exp(10) + exp(-10) ) 

will give you a value very close to 10. If you don't believe me, go try it. In fact, in general, ln( exp(x_1) + exp(x_2) + ... ) will be very close to max(x_i) if some particular x_i is much bigger than all the others. (As an aside, this functional form, asymptotically, actually lets you mathematically pick the maximum from a set of numbers.)
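For anyone who wants to try it, here is a small sketch. The helper `scaled_lse` and the scale factor `t` are assumptions introduced here to illustrate the aside about picking the maximum; they do not appear in the question:

```python
import math

# ln(exp(10) + exp(-10)): both terms are representable, so no shift is needed.
value = math.log(math.exp(10.0) + math.exp(-10.0))
print(value - 10.0)  # tiny positive gap, roughly exp(-20)

def scaled_lse(xs, t):
    # (1/t) * ln(sum(exp(t * x))), shifted by max(xs) to avoid overflow.
    m = max(xs)
    return m + math.log(sum(math.exp(t * (x - m)) for x in xs)) / t

# As t grows, the result approaches max(xs) = 3.0.
for t in (1.0, 10.0, 100.0):
    print(t, scaled_lse([1.0, 2.0, 3.0], t))
```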

Hence, the reason we pick the max instead of any other value is because the smaller values will hardly affect the result. If they underflow, they would have been too small to affect the sum anyway, because it would be dominated by the largest number and anything close to it. In computing terms, the contribution of the small numbers will be less than an ulp after computing the ln. So there's no reason to waste time computing the expression for the smaller values recursively if they will be lost in your final result anyway.
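A short sketch of that claim in double precision. The function name `stable` and the sample inputs are illustrative assumptions:

```python
import math

def stable(xs):
    # Shift by the maximum so the largest term becomes exp(0) = 1.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# exp(-40) is nowhere near underflow, yet it is already below half an ulp
# of 1.0, so adding it to 1.0 changes nothing.
print(1.0 + math.exp(-40.0) == 1.0)
print(stable([50.0, 10.0]) == 50.0)

# exp(-2000) does underflow to 0.0, and the result is still just the max.
print(math.exp(-2000.0) == 0.0)
print(stable([1000.0, -1000.0]) == 1000.0)
```

All four comparisons print True: a term is lost to rounding long before it is small enough to underflow.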

If you wanted to be really persnickety about implementing this, you'd divide by exp(max(x_i) - some_constant) or so to 'center' the resulting values around 1 to avoid both overflow and underflow, and that might give you a few extra digits of precision in the result. But avoiding overflow is much more important than avoiding underflow, because the former determines the result and the latter doesn't, so it's much simpler just to do it this way.



Answer 2:

Not really any better to do it recursively. The problem's just that you want to make sure your finite-precision arithmetic doesn't swamp the answer in noise. By dealing with the max on its own, you ensure that any junk is kept small in the final answer because the most significant component of it is guaranteed to get through.

Apologies for the waffly explanation. Try it with some numbers yourself (a sensible list to start with might be [1E-5,1E25,1E-5]) and see what happens to get a feel for it.
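One way to run that experiment, reusing the `naive` and `recommend` functions from the question and treating the suggested list as base-10 logarithms (that reading of the list is an assumption):

```python
import math

def naive(listOfLogs):
    return math.log10(sum(10**x for x in listOfLogs))

def recommend(listOfLogs):
    maxLog = max(listOfLogs)
    return maxLog + math.log10(sum(10**(x - maxLog) for x in listOfLogs))

logs = [1E-5, 1E25, 1E-5]

try:
    print(naive(logs))
except OverflowError as err:
    # 10**1E25 is far beyond the range of a double.
    print("naive failed:", err)

# The two small entries underflow to 0.0 after the shift; the max survives.
print(recommend(logs))
```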



Answer 3:

As you have defined it, your recursive function will never terminate. That's because ((x-maxLog) for x in listOfLogs) still has the same number of elements as listOfLogs.

I don't think that this is easily fixable either, without significantly impacting either the performance or the precision (compared to the non-recursive version).
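A sketch of what happens in practice. `recursive_with_lists` is a hypothetical variant that passes a list instead of a generator, so that the lack of a base case shows up directly:

```python
def recursive_with_lists(listOfLogs):
    maxLog = max(listOfLogs)
    # The argument below is as long as listOfLogs, so the input never
    # shrinks and there is no base case to stop on.
    return maxLog + recursive_with_lists([x - maxLog for x in listOfLogs])

try:
    recursive_with_lists([1.0, 2.0, 3.0])
    hit_limit = False
except RecursionError:
    hit_limit = True
print(hit_limit)

def recursive(listOfLogs):
    # The question's version, which passes a generator along.
    maxLog = max(listOfLogs)
    return maxLog + recursive((x - maxLog) for x in listOfLogs)

try:
    recursive([1.0, 2.0, 3.0])
    gen_error = False
except ValueError:
    # max() exhausts the generator, so two calls down the input is empty.
    gen_error = True
print(gen_error)
```

In CPython the list variant stops only when it hits the interpreter's recursion limit, while the generator version as literally written fails earlier with a ValueError because max() consumes the generator. Either way, no result is ever returned.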