Fast Exp calculation: possible to improve accuracy

Posted 2019-01-13 11:28

Question:

I am trying out the fast Exp(x) function that was previously described in this answer to an SO question on improving calculation speed in C#:

public static double Exp(double x)
{
  var tmp = (long)(1512775 * x + 1072632447);
  return BitConverter.Int64BitsToDouble(tmp << 32);
}

The expression uses some IEEE floating point "tricks" and is primarily intended for use in neural networks. The function is approximately 5 times faster than the regular Math.Exp(x) function.
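
For readers unfamiliar with the trick, here is a minimal sketch in C (chosen only because it is easy to compile and check; the C# version above is the one under discussion). It assumes IEEE-754 doubles, where the high 32 bits hold the sign, the 11-bit exponent, and the top 20 mantissa bits. Under that assumption, 1512775 is 2^20 / ln(2) truncated to an integer, and 1072632447 is the aligned exponent bias 1023 * 2^20 = 1072693248 minus a tuning term of 60801. The names exp_scale, exp_bias, and exp_neural are illustrative only.

```c
#include <math.h>
#include <stdint.h>
#include <string.h>

/* 2^20 / ln(2): multiplying x by this puts x / ln(2) into the exponent
   field of the high 32-bit word of a double. About 1512775.4. */
double exp_scale(void)
{
    return 1048576.0 / log(2.0);
}

/* 1023 * 2^20: the IEEE-754 exponent bias, aligned to the high word. */
int64_t exp_bias(void)
{
    return (int64_t)1023 << 20;
}

/* Same arithmetic as the C# Exp above: build the high word, shift it into
   place, and reinterpret the 64 bits as a double. */
double exp_neural(double x)
{
    int64_t high_word = (int64_t)(1512775 * x + 1072632447);
    uint64_t bits = (uint64_t)high_word << 32;
    double result;
    memcpy(&result, &bits, sizeof result);
    return result;
}
```

Because the mantissa bits are filled linearly between powers of two, the result is a piecewise-linear interpolation of the exponential, which is where the periodic relative error comes from.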

Unfortunately, the numeric accuracy is only about -4% to +2% relative to the regular Math.Exp(x) function; ideally I would like accuracy at least in the sub-percent range.

I have plotted the quotient between the approximate and the regular Exp functions, and as can be seen in the graph the relative difference appears to be repeated with practically constant frequency.

Is it possible to take advantage of this regularity to improve the accuracy of the "fast exp" function further without substantially reducing the calculation speed, or would the computational overhead of an accuracy improvement outweigh the computational gain of the original expression?

(As a side note, I have also tried one of the alternative approaches proposed in the same SO question, but this approach does not seem to be computationally efficient in C#, at least not for the general case.)

UPDATE MAY 14

Upon request from @Adriano, I have now performed a very simple benchmark. I have performed 10 million computations using each of the alternative exp functions for floating point values in the range [-100, 100]. Since the range of values I am interested in spans from -20 to 0 I have also explicitly listed the function value at x = -5. Here are the results:

      Math.Exp: 62.525 ms, exp(-5) = 0.00673794699908547
Empty function: 13.769 ms
     ExpNeural: 14.867 ms, exp(-5) = 0.00675211846828461
    ExpSeries8: 15.121 ms, exp(-5) = 0.00641270968867667
   ExpSeries16: 32.046 ms, exp(-5) = 0.00673666189488182
          exp1: 15.062 ms, exp(-5) = -12.3333325982094
          exp2: 15.090 ms, exp(-5) = 13.708332516253
          exp3: 16.251 ms, exp(-5) = -12.3333325982094
          exp4: 17.924 ms, exp(-5) = 728.368055056781
          exp5: 20.972 ms, exp(-5) = -6.13293614238501
          exp6: 24.212 ms, exp(-5) = 3.55518353166184
          exp7: 29.092 ms, exp(-5) = -1.8271053775984
      exp7 +/-: 38.482 ms, exp(-5) = 0.00695945286970704

ExpNeural is equivalent to the Exp function specified at the beginning of this text. ExpSeries8 is the formulation that I originally claimed was not very efficient on .NET; when I implemented it exactly as Neil did, it was actually very fast. ExpSeries16 is the analogous formula but with 16 multiplications instead of 8. exp1 through exp7 are the different functions from Adriano's answer below. The final entry, exp7 +/-, is a variant of exp7 in which the sign of x is checked; if x is negative, the function returns 1/exp(-x) instead.

Unfortunately, none of the expN functions listed by Adriano is sufficient in the broader negative value range I am considering. The series expansion approach by Neil Coffey seems more suitable in "my" value range, although it diverges too rapidly for larger negative x, especially when using "only" 8 multiplications.
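
For reference, the sign-checked exp7 variant from the benchmark can be sketched as follows. This is written in C rather than C# purely so it can be compiled and checked easily; exp7 repeats the formula from Adriano's answer, and the name exp7_signed is made up for this illustration.

```c
/* Degree-9 Taylor polynomial of exp(x) in Horner form (the exp7 formula). */
double exp7(double x)
{
    return (362880 + x * (362880 + x * (181440 + x * (60480 + x * (15120
          + x * (3024 + x * (504 + x * (72 + x * (9 + x)))))))))
          * 2.75573192e-6;
}

/* Sign-checked variant: evaluate the polynomial only for non-negative
   arguments and use exp(x) = 1 / exp(-x) for negative ones. */
double exp7_signed(double x)
{
    return (x < 0.0) ? 1.0 / exp7(-x) : exp7(x);
}
```

For negative x the plain polynomial is an alternating sum of large terms, which is why exp7(-5) comes out negative, while the reciprocal form only ever adds positive terms.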

Answer 1:

In case anyone wants to replicate the relative error function shown in the question, here's a way using Matlab (the "fast" exponent is not very fast in Matlab, but it is accurate):

t = 1072632447+[0:ceil(1512775*pi)];
x = (t - 1072632447)/1512775;
ex = exp(x);
t = uint64(t);
import java.lang.Double;
et = arrayfun( @(n) java.lang.Double.longBitsToDouble(bitshift(n,32)), t );
plot(x, et./ex);

Now, the period of the error exactly coincides with when the binary value of tmp overflows from the mantissa into the exponent. Let's break our data into bins by discarding the bits that become the exponent (making it periodic), and keeping only the high eight remaining bits (to make our lookup table a reasonable size):

index = bitshift(bitand(t,uint64(2^20-2^12)),-12) + 1;

Now we calculate the mean required adjustment:

relerrfix = ex./et;
adjust = NaN(1,256);
for i=1:256; adjust(i) = mean(relerrfix(index == i)); end;
et2 = et .* adjust(index);

The relative error is decreased to +/- .0006. Of course, other table sizes are possible as well (for example, a 6-bit table with 64 entries gives +/- .0025), and the error is roughly inversely proportional to the table size. Linear interpolation between table entries would improve the error yet further, but at the expense of performance. Since we've already met the accuracy goal, let's avoid any further performance hits.

At this point it takes only trivial editor skills to turn the values computed by Matlab into a lookup table in C#. For each computation, we add a bitmask, a table lookup, and a double-precision multiply.

static double FastExp(double x)
{
    var tmp = (long)(1512775 * x + 1072632447);
    int index = (int)(tmp >> 12) & 0xFF;
    return BitConverter.Int64BitsToDouble(tmp << 32) * ExpAdjustment[index];
}

The speedup is very similar to the original code -- for my computer, this is about 30% faster compiled as x86 and about 3x as fast for x64. With mono on ideone, it's a substantial net loss (but so is the original).

Complete source code and testcase: http://ideone.com/UwNgx

using System;
using System.Diagnostics;

namespace fastexponent
{
    class Program
    {
        static double[] ExpAdjustment = new double[256] {
            1.040389835,
            1.039159306,
            1.037945888,
            1.036749401,
            1.035569671,
            1.034406528,
            1.033259801,
            1.032129324,
            1.031014933,
            1.029916467,
            1.028833767,
            1.027766676,
            1.02671504,
            1.025678708,
            1.02465753,
            1.023651359,
            1.022660049,
            1.021683458,
            1.020721446,
            1.019773873,
            1.018840604,
            1.017921503,
            1.017016438,
            1.016125279,
            1.015247897,
            1.014384165,
            1.013533958,
            1.012697153,
            1.011873629,
            1.011063266,
            1.010265947,
            1.009481555,
            1.008709975,
            1.007951096,
            1.007204805,
            1.006470993,
            1.005749552,
            1.005040376,
            1.004343358,
            1.003658397,
            1.002985389,
            1.002324233,
            1.001674831,
            1.001037085,
            1.000410897,
            0.999796173,
            0.999192819,
            0.998600742,
            0.998019851,
            0.997450055,
            0.996891266,
            0.996343396,
            0.995806358,
            0.995280068,
            0.99476444,
            0.994259393,
            0.993764844,
            0.993280711,
            0.992806917,
            0.992343381,
            0.991890026,
            0.991446776,
            0.991013555,
            0.990590289,
            0.990176903,
            0.989773325,
            0.989379484,
            0.988995309,
            0.988620729,
            0.988255677,
            0.987900083,
            0.987553882,
            0.987217006,
            0.98688939,
            0.98657097,
            0.986261682,
            0.985961463,
            0.985670251,
            0.985387985,
            0.985114604,
            0.984850048,
            0.984594259,
            0.984347178,
            0.984108748,
            0.983878911,
            0.983657613,
            0.983444797,
            0.983240409,
            0.983044394,
            0.982856701,
            0.982677276,
            0.982506066,
            0.982343022,
            0.982188091,
            0.982041225,
            0.981902373,
            0.981771487,
            0.981648519,
            0.981533421,
            0.981426146,
            0.981326648,
            0.98123488,
            0.981150798,
            0.981074356,
            0.981005511,
            0.980944219,
            0.980890437,
            0.980844122,
            0.980805232,
            0.980773726,
            0.980749562,
            0.9807327,
            0.9807231,
            0.980720722,
            0.980725528,
            0.980737478,
            0.980756534,
            0.98078266,
            0.980815817,
            0.980855968,
            0.980903079,
            0.980955475,
            0.981017942,
            0.981085714,
            0.981160303,
            0.981241675,
            0.981329796,
            0.981424634,
            0.981526154,
            0.981634325,
            0.981749114,
            0.981870489,
            0.981998419,
            0.982132873,
            0.98227382,
            0.982421229,
            0.982575072,
            0.982735318,
            0.982901937,
            0.983074902,
            0.983254183,
            0.983439752,
            0.983631582,
            0.983829644,
            0.984033912,
            0.984244358,
            0.984460956,
            0.984683681,
            0.984912505,
            0.985147403,
            0.985388349,
            0.98563532,
            0.98588829,
            0.986147234,
            0.986412128,
            0.986682949,
            0.986959673,
            0.987242277,
            0.987530737,
            0.987825031,
            0.988125136,
            0.98843103,
            0.988742691,
            0.989060098,
            0.989383229,
            0.989712063,
            0.990046579,
            0.990386756,
            0.990732574,
            0.991084012,
            0.991441052,
            0.991803672,
            0.992171854,
            0.992545578,
            0.992924825,
            0.993309578,
            0.993699816,
            0.994095522,
            0.994496677,
            0.994903265,
            0.995315266,
            0.995732665,
            0.996155442,
            0.996583582,
            0.997017068,
            0.997455883,
            0.99790001,
            0.998349434,
            0.998804138,
            0.999264107,
            0.999729325,
            1.000199776,
            1.000675446,
            1.001156319,
            1.001642381,
            1.002133617,
            1.002630011,
            1.003131551,
            1.003638222,
            1.00415001,
            1.004666901,
            1.005188881,
            1.005715938,
            1.006248058,
            1.006785227,
            1.007327434,
            1.007874665,
            1.008426907,
            1.008984149,
            1.009546377,
            1.010113581,
            1.010685747,
            1.011262865,
            1.011844922,
            1.012431907,
            1.013023808,
            1.013620615,
            1.014222317,
            1.014828902,
            1.01544036,
            1.016056681,
            1.016677853,
            1.017303866,
            1.017934711,
            1.018570378,
            1.019210855,
            1.019856135,
            1.020506206,
            1.02116106,
            1.021820687,
            1.022485078,
            1.023154224,
            1.023828116,
            1.024506745,
            1.025190103,
            1.02587818,
            1.026570969,
            1.027268461,
            1.027970647,
            1.02867752,
            1.029389072,
            1.030114973,
            1.030826088,
            1.03155163,
            1.032281819,
            1.03301665,
            1.033756114,
            1.034500204,
            1.035248913,
            1.036002235,
            1.036760162,
            1.037522688,
            1.038289806,
            1.039061509,
            1.039837792,
            1.040618648
        };

        static double FastExp(double x)
        {
            var tmp = (long)(1512775 * x + 1072632447);
            int index = (int)(tmp >> 12) & 0xFF;
            return BitConverter.Int64BitsToDouble(tmp << 32) * ExpAdjustment[index];
        }

        static void Main(string[] args)
        {
            double[] x = new double[1000000];
            double[] ex = new double[x.Length];
            double[] fx = new double[x.Length];
            Random r = new Random();
            for (int i = 0; i < x.Length; ++i)
                x[i] = r.NextDouble() * 40;

            Stopwatch sw = new Stopwatch();
            sw.Start();
            for (int j = 0; j < x.Length; ++j)
                ex[j] = Math.Exp(x[j]);
            sw.Stop();
            double builtin = sw.Elapsed.TotalMilliseconds;
            sw.Reset();
            sw.Start();
            for (int k = 0; k < x.Length; ++k)
                fx[k] = FastExp(x[k]);
            sw.Stop();
            double custom = sw.Elapsed.TotalMilliseconds;

            double min = 1, max = 1;
            for (int m = 0; m < x.Length; ++m) {
                double ratio = fx[m] / ex[m];
                if (min > ratio) min = ratio;
                if (max < ratio) max = ratio;
            }

            Console.WriteLine("minimum ratio = " + min.ToString() + ", maximum ratio = " + max.ToString() + ", speedup = " + (builtin / custom).ToString());
        }
    }
}


Answer 2:

Try the following alternatives (exp1 is faster, exp7 is more precise).

Code

public static double exp1(double x) { 
    return (6+x*(6+x*(3+x)))*0.16666666f; 
}

public static double exp2(double x) {
    return (24+x*(24+x*(12+x*(4+x))))*0.041666666f;
}

public static double exp3(double x) {
    return (120+x*(120+x*(60+x*(20+x*(5+x)))))*0.0083333333f;
}

public static double exp4(double x) {
    return (720+x*(720+x*(360+x*(120+x*(30+x*(6+x))))))*0.0013888888f;
}

public static double exp5(double x) {
    return (5040+x*(5040+x*(2520+x*(840+x*(210+x*(42+x*(7+x)))))))*0.00019841269f;
}

public static double exp6(double x) {
    return (40320+x*(40320+x*(20160+x*(6720+x*(1680+x*(336+x*(56+x*(8+x))))))))*2.4801587301e-5;
}

public static double exp7(double x) {
  return (362880+x*(362880+x*(181440+x*(60480+x*(15120+x*(3024+x*(504+x*(72+x*(9+x)))))))))*2.75573192e-6;
}

Precision

Function     Error in [-1...1]              Error in [-3.14...3.14]

exp1         0.05           1.8%            8.8742         38.40%
exp2         0.01           0.36%           4.8237         20.80%
exp3         0.0016152      0.59%           2.28            9.80%
exp4         0.0002263      0.0083%         0.9488          4.10%
exp5         0.0000279      0.001%          0.3516          1.50%
exp6         0.0000031      0.00011%        0.1172          0.50%
exp7         0.0000003      0.000011%       0.0355          0.15%

Credits
These implementations of exp() have been calculated by "scoofy" using Taylor series from a tanh() implementation of "fuzzpilz" (whoever they are, I just had these references in my code).



Answer 3:

Taylor series approximations (such as the expX() functions in Adriano's answer) are most accurate near zero and can have huge errors at -20 or even -5. If the input has a known range, such as -20 to 0 as in the original question, you can use a small lookup table and one additional multiply to greatly improve accuracy.

The trick is to recognize that exp() can be separated into integer and fractional parts. For example:

exp(-2.345) = exp(-2.0) * exp(-0.345)

The fractional part will always be between -1 and 1, so a Taylor series approximation will be pretty accurate. The integer part has only 21 possible values, exp(-20) to exp(0), so these can be stored in a small lookup table.



Answer 4:

I have now studied in more detail the paper by Nicol Schraudolph in which the original C implementation of the above function was defined. It does seem that it is probably not possible to substantially improve the accuracy of the exp computation without severely impacting the performance. On the other hand, the approximation is also valid for large magnitudes of x, up to +/- 700, which is of course advantageous.

The function implementation above is tuned to obtain minimum root mean square error. Schraudolph describes how the additive term in the tmp expression can be altered to achieve alternative approximation properties.

"exp" >= exp for all x                      1072693248 -  (-1) = 1072693249
"exp" <= exp for all x                                 - 90253 = 1072602995
"exp" symmetric around exp                             - 45799 = 1072647449
Minimum possible mean deviation                        - 68243 = 1072625005
Minimum possible root-mean-square deviation            - 60801 = 1072632447

He also points out that at a "microscopic" level the approximate "exp" function exhibits stair-case behavior since 32 bits are discarded in the conversion from long to double. This means that the function is piece-wise constant on a very small scale, but the function is at least never decreasing with increasing x.



Answer 5:

The following code should address the accuracy requirements: for inputs in [-87,88] the results have relative error <= 1.73e-3. I do not know C#, so this is C code, but conversion should be fairly straightforward.

I assume that since the accuracy requirement is low, the use of single-precision computation is fine. A classic algorithm is being used in which the computation of exp() is mapped to computation of exp2(). After argument conversion via multiplication by log2(e), exponentiation by the fractional part is handled using a minimax polynomial of degree 2, while exponentiation by the integral part of the argument is performed by direct manipulation of the exponent part of the IEEE-754 single-precision number.

The volatile union facilitates re-interpretation of a bit pattern as either an integer or a single-precision floating-point number, needed for the exponent manipulation. It looks like C# offers dedicated re-interpretation functions for this, which is much cleaner.

The two potential performance problems are the floor() function and float->int conversion. Traditionally both were slow on x86 due to the need to handle dynamic processor state. But SSE (in particular SSE 4.1) provides instructions that allow these operations to be fast. I do not know whether C# can make use of those instructions.

 #include <math.h>   /* floorf */

 /* max. rel. error <= 1.73e-3 on [-87,88] */
 float fast_exp (float x)
 {
   volatile union {
     float f;
     unsigned int i;
   } cvt;

   /* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */
   float t = x * 1.442695041f;
   float fi = floorf (t);
   float f = t - fi;
   int i = (int)fi;
   cvt.f = (0.3371894346f * f + 0.657636276f) * f + 1.00172476f; /* compute 2^f */
   cvt.i += ((unsigned int)i << 23);                            /* scale by 2^i; unsigned shift is defined for negative i */
   return cvt.f;
 }