Can I use this parallel iterator pattern with Cython?

Posted 2019-05-26 15:30

Question:

With C++11 I have been using the following pattern for implementing a graph data structure with parallel iterators. Nodes are just indices, edges are entries in an adjacency data structure. For iterating over all nodes, a function (lambda, closure...) is passed to a parallelForNodes method and called with each node as an argument. Iteration details are nicely encapsulated in the method.

Now I would like to try the same concept with Cython. Cython provides the cython.parallel.prange function, which uses OpenMP to parallelize a loop over a range. For the loop to actually run in parallel, Python's Global Interpreter Lock (GIL) has to be released with the nogil=True parameter. Without the GIL, Python objects cannot be used, and that includes the handler function passed in, which makes this tricky.

Is it possible to use this approach with Cython?

class Graph:

    def __init__(self, n=0):
        self.n = n
        self.m = 0  
        self.z = n  # max node id
        self.adja = [[] for i in range(self.z)]
        self.deg = [0 for i in range(self.z)]

    def forNodes(self, handle):
        for u in range(self.z):
            handle(u)

    def parallelForNodes(self, handle):
        # first attempt which will fail...
        for u in prange(self.z, nogil=True):
            handle(u)


# usage 

def initialize(u):
    nonlocal ls
    ls[u] = 1

G.parallelForNodes(initialize)
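
For comparison, the sequential half of this pattern already works in plain Python. The sketch below is only a minimal illustration under some assumptions: `label_all` is a hypothetical helper name that is not part of the original code, and the closure mutates a list owned by the enclosing function, so no `nonlocal` declaration is needed.

```python
class Graph:
    """Minimal graph: nodes are just the indices 0..z-1."""

    def __init__(self, n=0):
        self.n = n
        self.m = 0
        self.z = n  # upper bound on node ids
        self.adja = [[] for _ in range(self.z)]
        self.deg = [0] * self.z

    def forNodes(self, handle):
        # Call the handler once per node, sequentially.
        for u in range(self.z):
            handle(u)


def label_all(G):
    """Hypothetical helper: set one label per node via a closure."""
    ls = [0] * G.z

    def initialize(u):
        ls[u] = 1  # mutates the enclosing list in place

    G.forNodes(initialize)
    return ls
```

Here `initialize` is a Python object that captures `ls`, which is exactly the kind of value that cannot be used inside a prange loop running without the GIL.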

Answer 1:

First, nothing used inside the nogil loop can be a Python object, so both the graph class and the callback have to be declared at the C level.

from cython.parallel import prange

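cdef class Graph:
    cdef int n, m, z  # plain C ints, readable without the GIL

    def __cinit__(self, int n=0):
        self.z = n  # max node id

    # handle is a C function pointer whose type is marked nogil,
    # so it may be called inside the prange loop.
    cdef void parallelForNodes(self, void (*handle)(int) nogil) nogil:
        cdef int u
        for u in prange(self.z, nogil=True):
            handle(u)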

The main catch is that the function pointer type must itself be declared nogil; otherwise Cython will not allow it to be called inside the prange loop.

parallelForNodes does not have to be nogil itself, but there's no reason for it not to be.

Then we need a C function to call:

cdef int[100] ls
cdef void initialize(int u) nogil:
    global ls
    ls[u] = 1

and it just works!

Graph(100).parallelForNodes(initialize)

# Print it!
cdef int[:] ls_ = ls
print(list(ls_))
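
To check what the compiled version should print, it can help to have a sequential pure-Python model of the same logic. This is only a reference sketch, not Cython code: `for_nodes` is a hypothetical stand-in for parallelForNodes, and the standard-library `array` type stands in for the C array.

```python
from array import array

# Stand-in for `cdef int[100] ls`: 100 zero-initialized C ints.
ls = array("i", [0] * 100)


def initialize(u):
    # Same body as the cdef version: mark node u.
    ls[u] = 1


def for_nodes(z, handle):
    # Sequential stand-in for the prange loop.
    for u in range(z):
        handle(u)


for_nodes(100, initialize)
expected = list(ls)
```

Because every iteration writes to a distinct index, the parallel version should produce the same list of 100 ones as this sequential model.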