Cythonized function unexpectedly slow

Posted 2019-07-20 20:58

Question:

I wanted to speed up a function that I use a lot, and I thought about using Cython. However, after trying every Cython optimization I could find in the documentation, the Cython code is about 6 times slower than the Python+NumPy function. Disappointing!

This is my test code: (forward1 is the python function, forward2 is the cython function)

#geometry.py
import numpy as np

def forward1(points, rotation, translation):
    '''points are in columns'''
    return np.dot(rotation, points - translation[:, np.newaxis])

#geometry.pyx
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.float64_t[:,:] forward2(np.float64_t[:,:] points, np.float64_t[:,:] rotation, np.float64_t[:] translation):
    '''points are in columns'''
    cdef unsigned int I, J
    I = points.shape[0]
    J = points.shape[1]
    cdef np.float64_t[:,:] tmp = np.empty((I, J), dtype=np.float64)
    cdef unsigned int i
    for i in range(J):
        tmp[0, i] = points[0, i] - translation[0]
        tmp[1, i] = points[1, i] - translation[1]
    cdef np.float64_t[:,:] result = np.dot(rotation, tmp)
    return result

def test_forward2(points, rotation, translation):
    import timeit
    cdef np.float64_t[:,:] points2 = points
    cdef np.float64_t[:,:] rotation2 = rotation
    cdef np.float64_t[:] translation2 = translation
    t = timeit.Timer(lambda: forward2(points2, rotation2, translation2))
    print min(t.repeat(3, 10))

and then I time it:

t = timeit.Timer(lambda: forward1(points, rotation, translation))
print min(t.repeat(3, 10))
0.000368164520751

test_forward2(points, rotation, translation)
0.0023365181969

Is there anything I can do to the cython code to make it faster?

If forward1 can't be sped up in cython, can I hope any speed up using weave?

EDIT:

Just for the record, another thing I've tried is passing points in Fortran order, since my points are stored in columns and there are quite a few of them. I also defined the local tmp in Fortran order. I think the subtraction part of the function should be faster that way, but numpy.dot seems to require a C-order output (is there any way to work around this?), so overall there is no speedup with this either. I also tried transposing the points so that the subtraction is faster in C order, but the dot product still seems to be the most expensive part.
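
For reference, the transposed layout mentioned above can be sketched in plain NumPy as follows. This is only an illustration under assumed inputs: the names forward_cols and forward_rows and the sample data are made up for the example, and it checks that both layouts give the same numbers, not which one is faster.

```python
import numpy as np

def forward_cols(points, rotation, translation):
    """points has shape (2, N): one point per column, as in the question."""
    return np.dot(rotation, points - translation[:, np.newaxis])

def forward_rows(points_t, rotation, translation):
    """points_t has shape (N, 2): one point per row (transposed layout)."""
    # (R (p - t))^T == (p - t)^T R^T, so the result is the transpose of forward_cols.
    return np.dot(points_t - translation, rotation.T)

# Hypothetical sample data, chosen only for this sketch.
rng = np.random.default_rng(0)
points = rng.standard_normal((2, 1000))
theta = 0.3
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
translation = np.array([1.0, -2.0])

cols = forward_cols(points, rotation, translation)
rows = forward_rows(np.ascontiguousarray(points.T), rotation, translation)
same = np.allclose(cols, rows.T)
```

With points stored one per row, the subtraction broadcasts along the contiguous axis in C order, and the result comes back transposed relative to forward1.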

Also, I noticed that numpy.dot can't take a memoryview as its out argument, even if it's in C order. Is this a bug?
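
For comparison, here is a minimal sketch of the out argument used with a preallocated ndarray, which is the case numpy.dot documents (right shape, right dtype, C-contiguous). The array values are arbitrary. In Cython, wrapping a typed memoryview with np.asarray(view) should give an ndarray over the same buffer that could be passed the same way, but that part is an assumption and is not exercised here.

```python
import numpy as np

rotation = np.eye(2)                                    # arbitrary 2x2 matrix
tmp = np.arange(6, dtype=np.float64).reshape(2, 3)      # arbitrary 2x3 input

# out must be an ndarray with the result's shape and dtype, in C order.
out = np.empty((2, 3), dtype=np.float64)
returned = np.dot(rotation, tmp, out=out)

# The product is written into the preallocated buffer.
wrote_in_place = np.shares_memory(returned, out)
```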

Answer 1:

Just glancing at your code, it looks like something (a subtraction of arrays followed by a dot product) that numpy is already very well optimized for.

Cython is great for speeding up cases where numpy often performs poorly (e.g. iterative algorithms where the iteration is written in Python), but in this case the inner loop is already being performed by a BLAS library.
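
One way to check where the time goes is to time the two steps of forward1 separately. The sketch below uses made-up input sizes and hypothetical helper names (subtract_only, dot_only), so the numbers it prints only mean something on your own machine and data.

```python
import timeit
import numpy as np

# Hypothetical inputs; substitute your real arrays.
rng = np.random.default_rng(0)
points = rng.standard_normal((2, 10000))
rotation = np.array([[0.0, -1.0], [1.0, 0.0]])
translation = np.array([1.0, 2.0])

def subtract_only():
    return points - translation[:, np.newaxis]

shifted = subtract_only()

def dot_only():
    return np.dot(rotation, shifted)

def forward1():
    return np.dot(rotation, points - translation[:, np.newaxis])

# Same measurement style as in the question: best of 3 runs of 10 calls each.
timings = {
    name: min(timeit.repeat(fn, repeat=3, number=10))
    for name, fn in [("subtract", subtract_only),
                     ("dot", dot_only),
                     ("forward1", forward1)]
}
for name, seconds in timings.items():
    print(f"{name}: {seconds:.6f} s for 10 calls")
```

If both steps already run inside compiled numpy/BLAS code, rewriting the loop in Cython leaves little to gain.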

If you want to speed things up, the first place I'd look is which BLAS/LAPACK/ATLAS/etc. libraries numpy is linked against. Using a "tuned" linear algebra library (e.g. ATLAS or Intel's MKL) can make a large difference (>10x in some cases) for operations like this.

To find out what you're currently using, have a look at the output of numpy.show_config().
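
A minimal sketch of doing that from a script follows. The helper name numpy_build_info is made up for the example; it captures what numpy.show_config() prints so the text can be searched or logged. The layout of the output differs between numpy versions, so no particular format is assumed.

```python
import contextlib
import io
import numpy as np

def numpy_build_info():
    """Return the text that numpy.show_config() prints to stdout."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        np.show_config()
    return buffer.getvalue()

config_text = numpy_build_info()
print(config_text)  # look for the BLAS/LAPACK entries (e.g. openblas, mkl, atlas)
```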