Declaring a numpy boolean mask in Cython

Posted on 2019-08-23 12:08

Question:

How should I declare the type of a boolean mask in Cython? Do I actually need to declare it? Here is the example:

cpdef my_func(np.ndarray[np.double_t, ndim = 2] array_a,
            np.ndarray[np.double_t, ndim = 2] array_b,
            np.ndarray[np.double_t, ndim = 2] array_c):

    mask = (array_a > 1) & (array_b == 2) & (array_c == 3)
    array_a[mask] = 0.
    array_b[mask] = array_c[mask]
    return array_a, array_b, array_c

Answer 1:

You need to "cast" np.uint8_t to bool via np.ndarray[np.uint8_t, ndim = 2, cast=True] mask = ..., i.e.

cimport numpy as np
cpdef my_func(np.ndarray[np.double_t, ndim = 2] array_a,
            np.ndarray[np.double_t, ndim = 2] array_b,
            np.ndarray[np.double_t, ndim = 2] array_c):
    cdef np.ndarray[np.uint8_t, ndim = 2, cast=True] mask = (array_a > 1) & (array_b == 2) & (array_c == 3)
    array_a[mask] = 0.
    array_b[mask] = array_c[mask]
    return array_a, array_b, array_c

Otherwise (without cast=True) the code compiles but raises an error at runtime because of the type mismatch.

However, you don't need to declare the type of mask at all and can use it as a Python object. There is some performance penalty or, more precisely, a missed opportunity to speed things up a little through early type binding, but in your case it probably doesn't matter anyway.


One more thing: I don't know what your real code looks like, but I hope you are aware that Cython won't speed up your example at all; there is nothing to gain compared to NumPy.
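
For comparison, here is a minimal sketch of the same function in plain NumPy. The name my_func_numpy and the small test arrays are made up for illustration; the point is only that all the work already happens inside vectorized NumPy calls, so typing the arguments in Cython leaves little to optimize.

```python
import numpy as np


def my_func_numpy(array_a, array_b, array_c):
    # Build the boolean mask with ordinary vectorized comparisons.
    mask = (array_a > 1) & (array_b == 2) & (array_c == 3)
    # Masked assignments modify the input arrays in place.
    array_a[mask] = 0.
    array_b[mask] = array_c[mask]
    return array_a, array_b, array_c


# Hypothetical 2x2 inputs, chosen so that the mask is True on the diagonal.
a = np.array([[2.0, 0.5], [3.0, 4.0]])
b = np.array([[2.0, 2.0], [1.0, 2.0]])
c = np.array([[3.0, 3.0], [3.0, 3.0]])
out_a, out_b, out_c = my_func_numpy(a, b, c)
```

With these inputs, out_a has zeros on the diagonal and out_b has taken the diagonal values of c.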


We can easily verify that a boolean np.array uses 8 bits per value (at least on my system). This is not obvious at all: for example, it could use only one bit per value (much like a bitset):

import sys
import numpy as np
a = np.random.random((10000,))
sys.getsizeof(a)        # 80096
sys.getsizeof(a < .5)   # 10096

Evidently the double array needs 8 bytes per element plus 96 bytes of overhead, whereas the mask needs only one byte per element.
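
Since sys.getsizeof includes the overhead of the array object, which may differ between NumPy versions, a sketch that avoids it is to look at itemsize and nbytes, which describe only the data buffer:

```python
import numpy as np

a = np.random.random((10000,))
mask = a < .5

# itemsize is the number of bytes per element; nbytes is itemsize * size.
double_itemsize = a.itemsize     # 8 bytes per float64 element
mask_itemsize = mask.itemsize    # 1 byte per boolean element
double_bytes = a.nbytes          # 10000 * 8
mask_bytes = mask.nbytes         # 10000 * 1
```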

We can also see that False is represented by 0 and True by 1:

print((a<.5).view(np.uint8))
[1 0 1 ..., 0 0 1]

Using cast=True makes it possible to access the raw bytes of the underlying array, a kind of reinterpret_cast of the array memory.
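
On the Python side, a rough analogy (not what Cython does internally, just an illustration of the same idea) is ndarray.view, which reinterprets the existing buffer without copying it. The small array below is made up for the example:

```python
import numpy as np

mask = np.array([True, False, True])
raw = mask.view(np.uint8)     # same memory, reinterpreted as unsigned bytes
before = raw.tolist()         # byte values backing the boolean array

# No copy was made, so writing through the view changes the mask too.
raw[1] = 1
same_buffer = np.shares_memory(mask, raw)
```

After the assignment, mask[1] is True, because both arrays refer to the same bytes.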

Here is some, albeit old, information.