Avoid overflow with the softplus function in Python

Posted 2020-06-12 06:02

Question:

I am trying to implement the following softplus function:

log(1 + exp(x))

I've tried it with both math and numpy, using float64 as the data type, but whenever x gets too large (e.g. x = 1000) the result is inf.

Can you help me evaluate this function correctly for large numbers?

Answer 1:

There is a relation which one can use:

log(1+exp(x)) = log(1+exp(x)) - log(exp(x)) + x = log(1+exp(-x)) + x

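For x > 0 the right-hand side is safe, because exp(-x) cannot overflow; for x <= 0 the original form is already safe. Both cases combine into a single implementation that is overflow-free and mathematically exact: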

log(1+exp(-abs(x))) + max(x,0)

This works both for math and numpy functions (use e.g.: np.log, np.exp, np.abs, np.maximum).
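For reference, here is a minimal NumPy sketch of the formula above. The function name `softplus` is just an illustrative choice, not something from the question. NumPy's `np.logaddexp(0, x)` computes log(exp(0) + exp(x)), which is the same quantity, so it is used here as a cross-check.

```python
import numpy as np

def softplus(x):
    """Overflow-safe log(1 + exp(x)) for scalars or arrays."""
    x = np.asarray(x, dtype=np.float64)
    # exp(-|x|) is at most 1, so it cannot overflow;
    # log1p keeps precision when exp(-|x|) is tiny.
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)

xs = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
print(softplus(xs))
# Cross-check against NumPy's built-in log(exp(a) + exp(b)) with a = 0
print(np.allclose(softplus(xs), np.logaddexp(0.0, xs)))
```

If a dependency-free formula is not required, calling `np.logaddexp(0, x)` directly is probably the shortest option.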



Answer 2:

Since for x>30 we have log(1+exp(x)) ~= log(exp(x)) = x, a simple stable implementation is

import numpy as np

def safe_softplus(x, limit=30):
  # Scalar x only: above the limit, log(1+exp(x)) equals x to within 1e-10.
  if x>limit:
    return x
  else:
    return np.log1p(np.exp(x))

In fact | log(1+exp(30)) - 30 | = log(1+exp(-30)), which is about 9.4e-14 and therefore < 1e-10, so this implementation makes errors smaller than 1e-10 and never overflows. In particular, for x=1000 the error of this approximation is far below float64 resolution, so it cannot even be measured on the computer.
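Note that `if x>limit` only works for scalar x; with a NumPy array the comparison produces an array and the `if` raises a ValueError. Below is a rough sketch of an array-friendly variant of the same threshold idea. The name `safe_softplus_array` is hypothetical, and clamping the input with `np.minimum` is my own addition to keep the unused branch of `np.where` from overflowing.

```python
import numpy as np

def safe_softplus_array(x, limit=30.0):
    """Threshold-based softplus that also accepts arrays."""
    x = np.asarray(x, dtype=np.float64)
    # np.where evaluates both branches, so clamp the argument of exp
    # to keep the branch that is discarded from overflowing.
    small = np.log1p(np.exp(np.minimum(x, limit)))
    return np.where(x > limit, x, small)

# Error of the approximation right at the threshold: log(1 + exp(-30))
print(np.log1p(np.exp(-30.0)))  # roughly 9.4e-14

print(safe_softplus_array(np.array([-5.0, 0.0, 31.0, 1000.0])))
```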



Answer 3:

I use this code to work with arrays:

import numpy as np

def safe_softplus(x):
    # Zero out entries with x >= 100 before exp, then return x itself for them.
    inRanges = (x < 100)
    return np.log1p(np.exp(x*inRanges))*inRanges + x*(1-inRanges)