Removing rows with duplicates in a NumPy array

Posted 2019-01-18 04:39

Question:

I have an (N, 3) array of NumPy values:

>>> vals = numpy.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
>>> vals
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 7],
       [0, 4, 5],
       [2, 2, 1],
       [0, 0, 0],
       [5, 4, 3]])

I'd like to remove rows from the array that have a duplicate value. For example, the result for the above array should be:

>>> duplicates_removed
array([[1, 2, 3],
       [4, 5, 6],
       [0, 4, 5],
       [5, 4, 3]])

I'm not sure how to do this efficiently with NumPy without looping (the array could be quite large). Does anyone know how I could do this?

Answer 1:

This is an option:

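import numpy
vals = numpy.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
# True for every row in which any two of the three columns are equal
a = (vals[:,0] == vals[:,1]) | (vals[:,1] == vals[:,2]) | (vals[:,0] == vals[:,2])
# Delete the flagged rows (axis=0 means rows)
vals = numpy.delete(vals, numpy.where(a)[0], axis=0)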
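Because `a` is already a boolean mask with one entry per row, the `numpy.where`/`numpy.delete` round trip can be replaced by plain boolean indexing. A minimal sketch, assuming exactly three columns as in the question; the helper name `drop_rows_with_repeats_3col` is hypothetical:

```python
import numpy

def drop_rows_with_repeats_3col(vals):
    # Hypothetical helper: only valid for arrays with exactly 3 columns.
    # True for rows where any pair of the three columns is equal.
    has_repeat = ((vals[:, 0] == vals[:, 1])
                  | (vals[:, 1] == vals[:, 2])
                  | (vals[:, 0] == vals[:, 2]))
    # Keep the rows where the mask is False; this returns a new array.
    return vals[~has_repeat]

vals = numpy.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
print(drop_rows_with_repeats_3col(vals))
```

Both forms build the same mask; indexing with `~has_repeat` just skips converting it to row indices first.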


Answer 2:

Here's an approach that handles a generic number of columns and is still vectorized -

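import numpy as np

def rows_uniq_elems(a):
    # Sort each row so that equal values end up next to each other
    a_sorted = np.sort(a,axis=-1)
    # Keep rows in which every pair of neighbouring sorted values differs
    return a[(a_sorted[...,1:] != a_sorted[...,:-1]).all(-1)]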

Steps:

  • Sort along each row.

  • Compare consecutive elements in each sorted row. Any row with at least one pair of equal neighbours contains a duplicate element. This gives a mask of valid rows, so the final step is simply to select the valid rows from the input array using that mask.
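The steps above can be traced on the question's 3-column array by looking at the intermediate mask. A minimal sketch; the names `vals_sorted`, `neighbours_differ` and `valid` are only illustrative:

```python
import numpy as np

vals = np.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])

# Step 1: sort each row so equal values become neighbours.
vals_sorted = np.sort(vals, axis=-1)

# Step 2: compare each element with its right-hand neighbour.
# A False anywhere in a row means two sorted neighbours are equal.
neighbours_differ = vals_sorted[:, 1:] != vals_sorted[:, :-1]

# A row is valid only if every neighbouring pair differs.
valid = neighbours_differ.all(axis=-1)
print(valid.tolist())   # [True, True, False, True, False, False, True]

# Step 3: select the valid rows from the original (unsorted) array.
result = vals[valid]
print(result)
```

Note that the mask is computed on the sorted copy but applied to the original array, so the surviving rows keep their original element order.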

Sample run -

In [49]: a
Out[49]: 
array([[1, 2, 3, 7],
       [4, 5, 6, 7],
       [7, 8, 7, 8],
       [0, 4, 5, 6],
       [2, 2, 1, 1],
       [0, 0, 0, 3],
       [5, 4, 3, 2]])

In [50]: rows_uniq_elems(a)
Out[50]: 
array([[1, 2, 3, 7],
       [4, 5, 6, 7],
       [0, 4, 5, 6],
       [5, 4, 3, 2]])


Answer 3:

numpy.array([v for v in vals if len(set(v)) == len(v)])

Mind you, this still loops over the rows. Some loop is unavoidable, but here it runs in Python rather than inside NumPy's compiled code. It should still work fine even for millions of rows.
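One practical difference: when no row survives, `numpy.array([])` returns a 1-D float array of shape (0,) rather than an (0, N) array of the input dtype. A sketch that keeps the same per-row `set()` test but collects a boolean mask with `numpy.fromiter` avoids that; the name `rows_uniq_elems_set` is hypothetical:

```python
import numpy as np

def rows_uniq_elems_set(a):
    # Hypothetical helper: same per-row set() test as the list comprehension,
    # but it builds a boolean mask instead of a list of rows.
    mask = np.fromiter((len(set(row)) == len(row) for row in a),
                       dtype=bool, count=len(a))
    # Boolean indexing keeps the dtype and the column count,
    # even when no row survives.
    return a[mask]

vals = np.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
print(rows_uniq_elems_set(vals))

all_dupes = np.array([[1, 1, 2], [3, 3, 3]])
print(rows_uniq_elems_set(all_dupes).shape)   # (0, 3)
```

The Python-level loop is the same as in the one-liner, so the speed should be comparable; only the shape and dtype of the result change.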



Answer 4:

It's six years on, but this question helped me, so I compared the speed of the answers given by Divakar, Benjamin, Marcelo Cantos and Curtis Patrick.

import numpy as np
vals = np.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])

def rows_uniq_elems1(a):
    idx = a.argsort(1)
    a_sorted = a[np.arange(idx.shape[0])[:,None], idx]
    return a[(a_sorted[:,1:] != a_sorted[:,:-1]).all(-1)]

def rows_uniq_elems2(a):
    # Store the mask separately so the input array is not overwritten
    mask = (a[:,0] == a[:,1]) | (a[:,1] == a[:,2]) | (a[:,0] == a[:,2])
    return np.delete(a, np.where(mask)[0], axis=0)

def rows_uniq_elems3(a):
    return np.array([v for v in a if len(set(v)) == len(v)])

def rows_uniq_elems4(a):
    return np.array([v for v in a if len(np.unique(v)) == len(v)])
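Before comparing timings it is worth confirming that the candidates return the same rows, since a wrong result can look like a fast one. A minimal cross-check sketch; the function names below are hypothetical standalone copies of the approaches above:

```python
import numpy as np

def by_sorting(a):
    # Sort each row, then require all neighbouring values to differ.
    s = np.sort(a, axis=1)
    return a[(s[:, 1:] != s[:, :-1]).all(axis=1)]

def by_pairs_3col(a):
    # Fixed three-column comparison (only valid for exactly 3 columns).
    mask = (a[:, 0] == a[:, 1]) | (a[:, 1] == a[:, 2]) | (a[:, 0] == a[:, 2])
    return a[~mask]

def by_set(a):
    return np.array([v for v in a if len(set(v)) == len(v)])

def by_unique(a):
    return np.array([v for v in a if len(np.unique(v)) == len(v)])

vals = np.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
reference = by_sorting(vals)
for f in (by_pairs_3col, by_set, by_unique):
    # np.array_equal checks both shape and element values.
    assert np.array_equal(f(vals), reference), f.__name__
print("all candidates agree")
```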

Results:

%timeit rows_uniq_elems1(vals)
10000 loops, best of 3: 67.9 µs per loop

%timeit rows_uniq_elems2(vals)
10000 loops, best of 3: 156 µs per loop

%timeit rows_uniq_elems3(vals)
1000 loops, best of 3: 59.5 µs per loop

%timeit rows_uniq_elems4(vals)
10000 loops, best of 3: 268 µs per loop

It seems that using set beats numpy.unique. In my case I needed to do this over a much larger array:

bigvals = np.random.randint(0,10,3000).reshape([1000,3])

%timeit rows_uniq_elems1(bigvals)
10000 loops, best of 3: 276 µs per loop

%timeit rows_uniq_elems2(bigvals)
10000 loops, best of 3: 192 µs per loop

%timeit rows_uniq_elems3(bigvals)
10000 loops, best of 3: 6.5 ms per loop

%timeit rows_uniq_elems4(bigvals)
10000 loops, best of 3: 35.7 ms per loop

The methods without list comprehensions are much faster. However, the number of columns is hard-coded, which makes them difficult to extend to more than three columns, so in my case, at least, the list comprehension with set is the best answer.
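The hard-coded three-way comparison can be extended to any column count by looping over pairs of columns instead of over rows, which keeps the work vectorized across rows. The number of pairs grows as N(N-1)/2 in the column count N, so this sketch suits narrow arrays; `rows_uniq_elems_pairs` is a hypothetical name:

```python
import itertools
import numpy as np

def rows_uniq_elems_pairs(a):
    # Hypothetical generalisation of the fixed three-column comparison:
    # OR together "column i equals column j" for every pair i < j.
    has_repeat = np.zeros(a.shape[0], dtype=bool)
    for i, j in itertools.combinations(range(a.shape[1]), 2):
        has_repeat |= a[:, i] == a[:, j]
    # Keep rows in which no pair of columns matched.
    return a[~has_repeat]

four_col = np.array([[1,2,3,7],[4,5,6,7],[7,8,7,8],[0,4,5,6],
                     [2,2,1,1],[0,0,0,3],[5,4,3,2]])
print(rows_uniq_elems_pairs(four_col))
```

For three columns this performs exactly the three comparisons of rows_uniq_elems2; for wide rows the sort-based version avoids the quadratic number of pairs.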

EDITED because I confused rows and columns in bigvals



Answer 5:

Identical to Marcelo's answer, but I think using numpy.unique() instead of set() may convey more clearly what you are after.

numpy.array([v for v in vals if len(numpy.unique(v)) == len(v)])