Python, use multiprocessing to further speed up a

Posted 2019-05-04 00:12

Question:

The code shown here is simplified but triggers the same PicklingError. I know there is a lot of discussion about what can and cannot be pickled, but I did not find a solution there.

I wrote a simple Cython script with the following function:

def pow2(int a):
    return a**2

Compilation works, and I can call this function from a Python script.

However, I am wondering how to use this function with multiprocessing:

import numpy as np
from multiprocessing import Pool
from fast import pow2
p = Pool(processes=4)
y = p.map(pow2, np.arange(10, dtype=int))

gives me a PicklingError:

dtw is the name of the package, and fast is fast.pyx.

How can I get around this problem? Thanks in advance

Answer 1:

Instead of using multiprocessing, which has to pickle the data in order to send it between processes, you can use prange, Cython's OpenMP wrapper. In your case you could use it as shown below.

  • note the use of x*x instead of x**2, which avoids the function call pow(x, 2)
  • a part of the array is passed to each thread, using pointers to double (double *)
  • the last thread takes more values when size % num_threads != 0
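
The splitting rule in the last bullet can be checked outside Cython. The following is only a minimal pure-Python sketch of the same arithmetic. The helper name partition is hypothetical and does not appear in the answer, and it uses lists and itertools.accumulate where the answer uses NumPy arrays and np.cumsum.

```python
from itertools import accumulate


def partition(size, num_threads):
    """Split `size` items into `num_threads` contiguous chunks.

    Returns (sub_sizes, pos): the length and the start offset of each chunk.
    The last chunk absorbs the remainder, mirroring the rule described above.
    """
    if num_threads < 1:
        raise ValueError("num_threads must be at least 1")
    # Every chunk starts with the same base length ...
    sub_sizes = [size // num_threads] * num_threads
    # ... and the last one also takes whatever is left over.
    sub_sizes[-1] += size % num_threads
    # Each start offset is the running total of the preceding chunk lengths.
    pos = [0] + list(accumulate(sub_sizes))[:-1]
    return sub_sizes, pos


sub_sizes, pos = partition(10, 4)
print(sub_sizes)  # [2, 2, 2, 4]
print(pos)        # [0, 2, 4, 6]
```

With this rule the last chunk is at most num_threads - 1 elements longer than the others, so the imbalance is negligible for large arrays.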

Code:

#cython: wraparound=False
#cython: boundscheck=False
#cython: cdivision=True
#cython: nonecheck=False
#cython: profile=False
import numpy as np
cimport numpy as np
from cython.parallel import prange

cdef void cpow2(int size, double *inp, double *out) nogil:
    cdef int i
    for i in range(size):
        out[i] = inp[i]*inp[i]

def pow2(np.ndarray[np.float64_t, ndim=1] inp,
         np.ndarray[np.float64_t, ndim=1] out,
         int num_threads=4):
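    cdef int thread
    cdef np.ndarray[np.int32_t, ndim=1] sub_sizes, pos
    size = np.shape(inp)[0]
    # Every thread gets size // num_threads elements ...
    sub_sizes = np.zeros(num_threads, np.int32) + size//num_threads
    pos = np.zeros(num_threads, np.int32)
    # ... and the last thread also takes the remainder.
    sub_sizes[num_threads-1] += size % num_threads
    # Start offset of each thread's chunk within the arrays.
    pos[1:] = np.cumsum(sub_sizes)[:num_threads-1]
    # One iteration per thread; each call squares its own chunk without the GIL.
    for thread in prange(num_threads, nogil=True, chunksize=1,
                         num_threads=num_threads, schedule='static'):
        cpow2(sub_sizes[thread], &inp[pos[thread]], &out[pos[thread]])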

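def main():
    # 642312323 float64 values need roughly 5 GB of memory.
    a = np.arange(642312323).astype(np.float64)
    # Passing out=a squares the array in place.
    pow2(a, out=a, num_threads=4)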
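
If multiprocessing is still preferred, one workaround that is often suggested is to wrap the compiled function in a plain module-level Python function, because such functions are pickled by their qualified name. The sketch below rests on an assumption: the traceback is not shown in the question, so it is not certain that the compiled function object is what fails to pickle. The names pow2_wrapper and _compiled_pow2 are hypothetical, and the fallback definition is only a stand-in so that the example runs without the compiled fast module.

```python
from multiprocessing import Pool

try:
    # The compiled extension from the question (fast.pyx).
    from fast import pow2 as _compiled_pow2
except ImportError:
    # Stand-in so this sketch runs without the compiled module.
    def _compiled_pow2(a):
        return a ** 2


def pow2_wrapper(a):
    """Plain module-level function, so pickle can refer to it by name."""
    # int() converts NumPy integer scalars to the plain int the C signature expects.
    return _compiled_pow2(int(a))


if __name__ == "__main__":
    # The guard keeps worker processes from re-running this block on import.
    with Pool(processes=4) as p:
        y = p.map(pow2_wrapper, range(10))
    print(y)
```

For a function as cheap as pow2, the cost of pickling each argument and result is likely to outweigh any gain from the extra processes, which is the reason the answer recommends prange.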