Numpy grouping using itertools.groupby performance

Posted 2020-01-28 02:08

I have many large lists of integers (more than 35,000,000 elements each) that contain duplicates. I need a count for each distinct integer in a list. The following code works but seems slow. Can anyone better the benchmark using Python, preferably NumPy?

def group():
    import numpy as np
    from itertools import groupby
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    groups = ((k,len(list(g))) for k,g in groupby(values))
    index = np.fromiter(groups,dtype='u4,u2')

if __name__=='__main__':
    from timeit import Timer
    t = Timer("group()","from __main__ import group")
    print(t.timeit(number=1))

which returns:

$ python bench.py 
111.377498865

Cheers!

Edit based on responses:

def group_original():
    import numpy as np
    from itertools import groupby
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    groups = ((k,len(list(g))) for k,g in groupby(values))
    index = np.fromiter(groups,dtype='u4,u2')

def group_gnibbler():
    import numpy as np
    from itertools import groupby
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    groups = ((k,sum(1 for i in g)) for k,g in groupby(values))
    index = np.fromiter(groups,dtype='u4,u2')

def group_christophe():
    import numpy as np
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    counts=values.searchsorted(values, side='right') - values.searchsorted(values, side='left')
    index = np.zeros(len(values),dtype='u4,u2')
    index['f0']=values
    index['f1']=counts
    # Erroneous result: this yields one row per element, not one row per unique value

def group_paul():
    import numpy as np
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    diff = np.concatenate(([1],np.diff(values)))
    idx = np.concatenate((np.where(diff)[0],[len(values)]))
    index = np.empty(len(idx)-1,dtype='u4,u2')
    index['f0']=values[idx[:-1]]
    index['f1']=np.diff(idx)

if __name__=='__main__':
    from timeit import Timer
    timings=[
                ("group_original","Original"),
                ("group_gnibbler","Gnibbler"),
                ("group_christophe","Christophe"),
                ("group_paul","Paul"),
            ]
    for method,title in timings:
        t = Timer("%s()"%method,"from __main__ import %s"%method)
        print("%s: %s secs" % (title, t.timeit(number=1)))

which returns:

$ python bench.py 
Original: 113.385262966 secs
Gnibbler: 71.7464978695 secs
Christophe: 27.1690568924 secs
Paul: 9.06268405914 secs

Note that Christophe's version currently gives incorrect results.
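One way to check any of these variants is to run it on a small array and compare against np.unique(..., return_counts=True), which is available in NumPy 1.9 and later. The following is a minimal sketch, not code from the thread: the helper name group_counts is hypothetical, and it follows the spirit of Paul's run-boundary approach but uses a != comparison instead of np.diff.

```python
import numpy as np

def group_counts(values):
    """Return (unique_values, counts) for a 1-D integer array (hypothetical helper)."""
    values = np.sort(np.asarray(values, dtype='u4'))
    if values.size == 0:
        return values, np.zeros(0, dtype=np.intp)
    # True wherever a new run of equal values starts; the first element always starts one
    starts = np.concatenate(([True], values[1:] != values[:-1]))
    # Start position of every run, plus a sentinel at the end of the array
    idx = np.concatenate((np.flatnonzero(starts), [len(values)]))
    # Run lengths are the gaps between consecutive run starts
    return values[idx[:-1]], np.diff(idx)

sample = np.array([5, 3, 5, 1, 3, 5], dtype='u4')
uniq, counts = group_counts(sample)
ref_uniq, ref_counts = np.unique(sample, return_counts=True)
assert np.array_equal(uniq, ref_uniq)
assert np.array_equal(counts, ref_counts)
```

A check like this on a handful of elements is cheap to run before timing anything on 35,000,000 values.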

10 answers
#2 · 2020-01-28 02:59

By request, here is a Cython version of this. I make two passes through the array. The first pass counts how many unique elements there are, so that I can allocate arrays of the appropriate size for the unique values and the counts.

import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
def dogroup():
    cdef unsigned long tot = 1
    cdef np.ndarray[np.uint32_t, ndim=1] values = np.array(np.random.randint(35000000,size=35000000),dtype=np.uint32)
    cdef unsigned long i, ind, lastval
    values.sort()
    for i in xrange(1,len(values)):
        if values[i] != values[i-1]:
            tot += 1
    cdef np.ndarray[np.uint32_t, ndim=1] vals = np.empty(tot,dtype=np.uint32)
    cdef np.ndarray[np.uint32_t, ndim=1] count = np.empty(tot,dtype=np.uint32)
    vals[0] = values[0]
    ind = 1
    lastval = 0
    for i in xrange(1,len(values)):
        if values[i] != values[i-1]:
            vals[ind] = values[i]
            count[ind-1] = i - lastval
            lastval = i
            ind += 1
    count[ind-1] = len(values) - lastval

Sorting takes by far the most time here. With the values array from my code, sorting takes 4.75 seconds and finding the unique values and counts takes 0.67 seconds. With pure NumPy, using Paul's code (with the same values array and with the fix I suggested in a comment), finding the unique values and counts takes 1.9 seconds; sorting takes the same time, of course.

It makes sense that sorting dominates, because it is O(N log N) while the counting is O(N). You can speed up the sort a little over NumPy's (which uses C's qsort, if I remember correctly), but you really have to know what you are doing and it probably isn't worthwhile. There might also be some way to speed up my Cython code a little more, but that probably isn't worthwhile either.
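For readers without a Cython toolchain, the same two-pass logic can be sketched in plain Python. This version will be far slower than the compiled one and is only meant to show the control flow; the function name two_pass_counts is hypothetical and not part of the answer above.

```python
import numpy as np

def two_pass_counts(values):
    """Two-pass unique/count over a sorted copy of values (illustrative only)."""
    values = np.sort(np.asarray(values, dtype=np.uint32))
    n = len(values)
    if n == 0:
        return np.empty(0, dtype=np.uint32), np.empty(0, dtype=np.uint32)

    # Pass 1: count the unique elements so the outputs can be sized exactly
    tot = 1
    for i in range(1, n):
        if values[i] != values[i - 1]:
            tot += 1

    vals = np.empty(tot, dtype=np.uint32)
    count = np.empty(tot, dtype=np.uint32)

    # Pass 2: record each new value and close out the previous run's count
    vals[0] = values[0]
    ind = 1
    run_start = 0
    for i in range(1, n):
        if values[i] != values[i - 1]:
            vals[ind] = values[i]
            count[ind - 1] = i - run_start
            run_start = i
            ind += 1
    # The last run ends at the end of the array
    count[ind - 1] = n - run_start
    return vals, count

vals, count = two_pass_counts([7, 2, 7, 7, 2, 9])
```

The first pass exists only to avoid over-allocating; with 35,000,000 inputs, sizing the outputs exactly can save a noticeable amount of memory when there are many duplicates.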

啃猪蹄的小仙女
#3 · 2020-01-28 02:59

Sorting is Θ(N log N); I'd go for the amortized O(N) provided by Python's hash table implementation. Use defaultdict(int) to keep counts of the integers and iterate over the array just once:

import collections

counts = collections.defaultdict(int)
for v in values:
    counts[v] += 1

This is theoretically faster, but unfortunately I have no way to check right now. Allocating the additional memory might actually make it slower than your solution, which is in-place.
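To get from the dictionary back to the 'u4,u2' structured array used in the question, something like the following sketch should work on a small input. The call to .tolist() (so the loop sees plain Python ints) and the sorting of the items are my own assumptions, not part of the suggestion above, and nothing here has been timed.

```python
import collections
import numpy as np

values = np.array([5, 3, 5, 1, 3, 5], dtype='u4')

# Count in a single pass; .tolist() yields plain Python ints as dict keys
counts = collections.defaultdict(int)
for v in values.tolist():
    counts[v] += 1

# Build the (value, count) structured array, ordered by value
index = np.array(sorted(counts.items()), dtype='u4,u2')
```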

Edit: If you need to save memory try radix sort, which is much faster on integers than quicksort (which I believe is what numpy uses).
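As a rough illustration of the radix sort idea, here is a least-significant-digit sketch for 32-bit unsigned integers that processes one byte per pass. It is an assumption-laden sketch, not a tuned implementation: the name radix_sort_u32 is hypothetical, each pass leans on np.argsort(kind='stable') to do the stable bucketing, and it makes no claim about beating values.sort() on any particular NumPy version.

```python
import numpy as np

def radix_sort_u32(values):
    """LSD radix sort of uint32 values, one byte per pass (illustrative helper)."""
    out = np.array(values, dtype=np.uint32)  # work on a copy
    for shift in (0, 8, 16, 24):
        # Extract the current byte of every element
        digit = ((out >> np.uint32(shift)) & np.uint32(0xFF)).astype(np.uint8)
        # A stable ordering by this byte preserves the order set by earlier passes
        order = np.argsort(digit, kind='stable')
        out = out[order]
    return out

rng = np.random.RandomState(0)
data = rng.randint(0, 1 << 32, size=1000, dtype=np.uint32)
assert np.array_equal(radix_sort_u32(data), np.sort(data))
```

Unlike values.sort(), this sketch is not in-place: each pass allocates an index array and a reordered copy, which matters at 35,000,000 elements.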

Bombasti
#4 · 2020-01-28 03:05

This is a numpy solution:

def group():
    import numpy as np
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')

    # we sort in place
    values.sort()

    # Once the array is sorted, the number of occurrences of an element is the
    # insertion index found when searching from the right minus the insertion
    # index found when searching from the left.
    #
    # np.dstack() is used here to pair each value with its count, much like
    # Python's zip()

    l = np.dstack((values, values.searchsorted(values, side='right') - \
                   values.searchsorted(values, side='left')))

    index = np.fromiter(l, dtype='u4,u2')

if __name__=='__main__':
    from timeit import Timer
    t = Timer("group()","from __main__ import group")
    print(t.timeit(number=1))

Runs in about 25 seconds on my machine compared to about 96 for your initial solution (which is a nice improvement).

There might still be room for improvement; I don't use NumPy that often.

Edit: added some comments in code.
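A small worked example may make the searchsorted trick easier to follow. The right-minus-left expression produces one count per element, so duplicates appear once per occurrence; keeping only the first element of each run gives one row per unique value. The de-duplication step below is my own addition for illustration, not part of the code above.

```python
import numpy as np

values = np.array([1, 3, 3, 5, 5, 5], dtype='u4')  # already sorted

right = values.searchsorted(values, side='right')  # index just past each run
left = values.searchsorted(values, side='left')    # index of each run's start
per_element = right - left                         # [1, 2, 2, 3, 3, 3]

# Keep only the first element of each run to get one entry per unique value
first = np.concatenate(([True], values[1:] != values[:-1]))
uniq = values[first]
counts = per_element[first]
```

Here uniq is [1, 3, 5] and counts is [1, 2, 3], whereas per_element on its own still has six entries.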

啃猪蹄的小仙女
#5 · 2020-01-28 03:05

Replacing len(list(g)) with sum(1 for i in g) gives a 2x speedup
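Both forms give the same counts; the difference is that sum(1 for i in g) consumes the group lazily instead of first building a temporary list. A tiny check on a sorted list, with values chosen purely for illustration:

```python
from itertools import groupby

values = [1, 3, 3, 5, 5, 5]  # groupby only merges adjacent equal items, so sort first

# Materializes each group as a list just to measure its length
with_list = [(k, len(list(g))) for k, g in groupby(values)]
# Counts the items of each group without building a list
with_sum = [(k, sum(1 for _ in g)) for k, g in groupby(values)]

assert with_list == with_sum == [(1, 1), (3, 2), (5, 3)]
```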
