I am curious if anyone can explain what exactly leads to the discrepancy in this particular handling of C versus Fortran ordered arrays in numpy. See the code below:
system:
Ubuntu 18.10
Miniconda python 3.7.1
numpy 1.15.4
import numpy as np

def test_array_sum_function(arr):
    idx = 0
    # Sum row `idx` directly, then via a reduction over axis 1.
    val1 = arr[idx, :].sum()
    val2 = arr.sum(axis=1)[idx]
    print('axis sums:', val1)
    print(' ', val2)
    print(' equal:', val1 == val2)
    print('total sum:', arr.sum())

n = 2_000_000
np.random.seed(42)
rnd = np.random.random(n)

# Same data in row 0, zeros in row 1; only the memory layout differs.
print('Fortran order:')
arrF = np.zeros((2, n), order='F')
arrF[0, :] = rnd
test_array_sum_function(arrF)

print('\nC order:')
arrC = np.zeros((2, n), order='C')
arrC[0, :] = rnd
test_array_sum_function(arrC)
prints:
Fortran order:
axis sums: 999813.1414744433
999813.1414744079
equal: False
total sum: 999813.1414744424
C order:
axis sums: 999813.1414744433
999813.1414744433
equal: True
total sum: 999813.1414744433
This is almost certainly a consequence of numpy sometimes using pairwise summation and sometimes not.
Let's build a diagnostic array:
This strongly suggests that 1D arrays and contiguous axes use pairwise summation while strided axes in a multidimensional array don't.
Note that to see that effect the array has to be large enough, otherwise numpy falls back to ordinary summation.
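One way to see why the two strategies can disagree, without relying on numpy internals, is an input on which sequential accumulation visibly loses information. The sketch below is plain Python; `naive_sum` and `pairwise_sum` are hypothetical helpers written for illustration, and the block size of 8 is an arbitrary choice, not numpy's actual implementation.

```python
import math

def naive_sum(values):
    """Accumulate strictly left to right into one running total."""
    total = 0.0
    for v in values:
        total += v
    return total

def pairwise_sum(values, block=8):
    """Split the input in halves recursively; sum small blocks left to right."""
    n = len(values)
    if n <= block:
        return naive_sum(values)
    mid = n // 2
    return pairwise_sum(values[:mid], block) + pairwise_sum(values[mid:], block)

# One large value followed by many values that are each too small to
# change it when added to it one at a time.
tiny = 2.0 ** -53            # half the spacing between 1.0 and the next float
data = [1.0] + [tiny] * 1023

print("naive   :", naive_sum(data))
print("pairwise:", pairwise_sum(data))
print("fsum    :", math.fsum(data))   # correctly rounded reference
```

The sequential sum stays at exactly 1.0 because every single addition of `tiny` rounds back to 1.0, while the pairwise version first combines the small values with each other and so ends up above 1.0, closer to the `math.fsum` reference.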
Floating point math isn't necessarily associative, i.e. (a+b)+c != a+(b+c). Since you're adding along different axes, the order of operations is different, which can affect the final result. As a simple example, consider the matrix a = [[1e100, 1], [-1e100, 0]], whose exact sum is 1, yet for which a.sum() gives 0 in C order.
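The effect of the traversal order can be reproduced without numpy. The sketch below assumes the matrix is [[1e100, 1], [-1e100, 0]], inferred from the two operation sequences quoted in this answer; `sum_row_major` and `sum_column_major` are hypothetical helpers that only mimic the C and Fortran memory orders on a nested list.

```python
# Matrix assumed from the operation sequences quoted in the answer.
matrix = [[1e100, 1.0],
          [-1e100, 0.0]]

def sum_row_major(m):
    """Add elements row by row (C order), left to right."""
    total = 0.0
    for row in m:
        for x in row:
            total += x
    return total

def sum_column_major(m):
    """Add elements column by column (Fortran order), top to bottom."""
    total = 0.0
    for col in zip(*m):
        for x in col:
            total += x
    return total

print("row-major   :", sum_row_major(matrix))
print("column-major:", sum_column_major(matrix))
print("(1e100 + 1) == 1e100:", (1e100 + 1) == 1e100)
```

In row-major order the 1 is absorbed into 1e100 before the cancellation happens; in column-major order the two large values cancel first, so the 1 survives.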
(Interestingly, a.T.sum() still gives 0, as does aT = a.T; aT.sum(), so I'm not sure how exactly this is implemented in the backend.)
The C order is using the sequence of operations (left-to-right) 1e100 + 1 + (-1e100) + 0, whereas the Fortran order uses 1e100 + (-1e100) + 1 + 0. The problem is that (1e100+1) == 1e100 because floats don't have enough precision to represent that small difference, so the 1 gets lost.
In general, don't do equality testing on floating point numbers; instead compare using a small epsilon (if abs(float1 - float2) < 0.00001, or np.isclose). If you need arbitrary float precision, use the Decimal library or fixed-point representation and ints.
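A minimal sketch of that advice, applied to the two row sums printed in the question. It uses `math.isclose` from the standard library to stay dependency-free (np.isclose plays the same role for arrays); the tolerance values and the precision of 120 digits are assumptions chosen for this example, not recommendations.

```python
import math
from decimal import Decimal, localcontext

# The two row sums printed for the Fortran-ordered array in the question.
val1 = 999813.1414744433
val2 = 999813.1414744079

print("exact equality :", val1 == val2)
print("absolute check :", abs(val1 - val2) < 0.00001)
print("relative check :", math.isclose(val1, val2, rel_tol=1e-9))

# Decimal only helps if the context precision is large enough to hold
# every digit; the default precision is too small for 1e100 + 1.
with localcontext() as ctx:
    ctx.prec = 120
    exact = Decimal("1e100") + Decimal(1) + Decimal("-1e100")
print("Decimal sum    :", exact)
```

Note that an absolute epsilon such as 0.00001 only makes sense for values of a known magnitude; a relative tolerance scales with the numbers being compared.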