Numpy: Difference between dot(a,b) and (a*b).sum()

Posted 2019-04-18 12:10

Question:

For 1-D numpy arrays, these two expressions should, in theory, yield the same result:

(a*b).sum()/a.sum()
dot(a, b)/a.sum()

The latter uses dot() and is faster. But which one is more accurate? Why?

Some context follows.

I wanted to compute the weighted variance of a sample using numpy. I found the dot() expression in another answer, with a comment stating that it should be more accurate. However, no explanation is given there.

Answer 1:

np.dot is one of the NumPy routines that call the BLAS library NumPy was linked against at build time (or NumPy's own implementation if it builds one). This matters because a BLAS library can use multiply-accumulate operations (usually fused multiply-add, FMA), which reduce the number of roundings the computation performs.

Take the following:

>>> import numpy as np
>>> a=np.ones(1000,dtype=np.float128)+1E-14  # np.float128 is not available on every platform
>>> (a*a).sum()  
1000.0000000000199948
>>> np.dot(a,a)
1000.0000000000199948

Not exact, but close enough.

>>> a=np.ones(1000,dtype=np.float64)+1E-14
>>> np.dot(a,a)
1000.0000000000176  #off by 2.3948e-12
>>> (a*a).sum()
1000.0000000000059  #off by 1.40948e-11

np.dot(a, a) is the more accurate of the two because it uses approximately half as many floating-point roundings as the naive (a*a).sum().
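One way to check the two errors without relying on float128 is to build an exact reference with Python's fractions module. This is only a sketch: the variable names are made up for illustration, and the digits you see depend on the NumPy version and on the BLAS library it was built against (recent NumPy releases also use pairwise summation in sum(), which changes its error).

```python
from fractions import Fraction

import numpy as np

a = np.ones(1000, dtype=np.float64) + 1e-14

# Exact sum of squares of the values actually stored in the array,
# computed with rational arithmetic (no rounding at all).
exact = sum(Fraction(x) ** 2 for x in a.tolist())

# Absolute error of each float64 result against the exact reference.
err_dot = abs(Fraction(float(np.dot(a, a))) - exact)
err_sum = abs(Fraction(float((a * a).sum())) - exact)

print("dot error:", float(err_dot))
print("sum error:", float(err_sum))
```

No expected output is shown because the errors vary from one installation to the next.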

A book by Nvidia has the following example, worked to four digits after the decimal point (five significant digits). rn stands for round to nearest at that precision:

x = 1.0008
x^2 = 1.00160064                     #    true value
rn(x^2 - 1) = 1.6006 × 10^-3         #    fused multiply-add
rn(rn(x^2) - 1) = 1.6000 × 10^-3     #    multiply, then add
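The same arithmetic can be reproduced with Python's decimal module, which has a built-in fused multiply-add. This is a sketch that assumes a context with five significant digits and round-half-even stands in for rn; the variable names are illustrative.

```python
from decimal import Context, Decimal, ROUND_HALF_EVEN

# Five significant digits: the leading digit plus four more.
ctx = Context(prec=5, rounding=ROUND_HALF_EVEN)

x = Decimal("1.0008")
minus_one = Decimal(-1)

# Fused multiply-add: x*x - 1 is formed exactly, then rounded once.
fused = ctx.fma(x, x, minus_one)

# Multiply, then add: the product is rounded before the subtraction.
product = ctx.multiply(x, x)
separate = ctx.add(product, minus_one)

print(fused)     # 0.0016006
print(separate)  # 0.0016
```

The digits lost when the product is rounded to 1.0016 are exactly the ones that the subtraction would have exposed.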

Of course floating point numbers are not rounded to the 16th decimal place in base 10, but you get the idea.

Placing np.dot(a,a) in the above notation with some additional pseudocode:

out=0
for x in a:
    out=rn(x*x+out)   #Fused multiply add

While (a*a).sum() is:

arr=np.zeros(a.shape[0])   
for x in range(len(arr)):
    arr[x]=rn(a[x]*a[x])

out=0
for x in arr:
    out=rn(x+out)

From this it's easy to see that the number is rounded twice as many times using (a*a).sum() compared to np.dot(a,a). These small differences, summed, can change the answer minutely. Additional examples can be found here.
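The two loops above can be made runnable in pure Python by emulating rn with exact rational arithmetic. This is a sketch, not how BLAS or NumPy are implemented: rn, fused_dot and multiply_then_sum are hypothetical helper names, and it assumes that converting a Fraction to float rounds to the nearest float64, as CPython does.

```python
from fractions import Fraction


def rn(exact):
    """Round an exact rational value to a float64 (one rounding)."""
    return float(exact)


def fused_dot(values):
    """Emulate the FMA loop: one rounding per element."""
    out, roundings = 0.0, 0
    for x in values:
        out = rn(Fraction(x) * Fraction(x) + Fraction(out))
        roundings += 1
    return out, roundings


def multiply_then_sum(values):
    """Emulate (a*a).sum(): round each product, then each partial sum."""
    products, roundings = [], 0
    for x in values:
        products.append(rn(Fraction(x) * Fraction(x)))
        roundings += 1
    out = 0.0
    for p in products:
        out = rn(Fraction(p) + Fraction(out))
        roundings += 1
    return out, roundings


a = [1.0 + 1e-14] * 1000
exact = sum(Fraction(x) ** 2 for x in a)

fused, n_fused = fused_dot(a)
naive, n_naive = multiply_then_sum(a)

print(n_fused, n_naive)  # 1000 2000
print("fused error:", float(abs(Fraction(fused) - exact)))
print("naive error:", float(abs(Fraction(naive) - exact)))
```

The rounding counts come out as n and 2n by construction; the printed errors show how much that difference matters for this particular input.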