Finding intersection of two matrices in Python wit

Posted 2019-06-27 06:34

I'm looking for the most efficient way of finding the intersection of two different-sized matrices. Each matrix has three variables (columns) and a varying number of observations (rows). For example, matrices a and b:

a = np.matrix('1 5 1003; 2 4 1002; 4 3 1008; 8 1 2005')
b = np.matrix('7 9 1006; 4 4 1007; 7 7 1050; 8 2 2003; 9 9 3000; 7 7 1000')

If I set the tolerance for each column as col1 = 1, col2 = 2, and col3 = 10, I want a function that returns the row indices in a and b whose values lie within the respective tolerances of each other, for example:

[x1, x2] = func(a, b, col1, col2, col3)
print x1
>> [2 3]
print x2
>> [1 3]

The indices show that row 2 of a is within the tolerances of row 1 of b, and row 3 of a is within the tolerances of row 3 of b.

I'm thinking I could loop through each element of matrix a, check if it's within the tolerances of each element in b, and do it that way. But it seems inefficient for very large data sets.
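
For reference, the looping approach described above could look roughly like the sketch below. The name match_rows_loop is hypothetical, plain arrays stand in for the matrices, and the sketch assumes that "within tolerance" means an absolute difference strictly below the column's tolerance.

```python
import numpy as np

def match_rows_loop(a, b, tol):
    """Hypothetical brute-force baseline: compare every row of a with every row of b."""
    a = np.asarray(a)
    b = np.asarray(b)
    x1, x2 = [], []
    for i, row_a in enumerate(a):
        for j, row_b in enumerate(b):
            # A pair matches only if every column differs by less than its tolerance
            if all(abs(row_a[k] - row_b[k]) < tol[k] for k in range(len(tol))):
                x1.append(i)
                x2.append(j)
    return x1, x2

a = np.array([[1, 5, 1003], [2, 4, 1002], [4, 3, 1008], [8, 1, 2005]])
b = np.array([[7, 9, 1006], [4, 4, 1007], [7, 7, 1050],
              [8, 2, 2003], [9, 9, 3000], [7, 7, 1000]])

x1, x2 = match_rows_loop(a, b, [1, 2, 10])
print(x1, x2)  # [2, 3] [1, 3]
```

This does na * nb row comparisons in pure Python, which is why it gets slow on large inputs.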

Any suggestions for alternatives to a looping method for accomplishing this?

1 answer
走好不送
#2 · 2019-06-27 06:41

If you don't mind working with NumPy arrays, you could exploit broadcasting for a vectorized solution. Here's the implementation -

# Set tolerance values for each column
tol = [1, 2, 10]

# Get absolute differences between a and b keeping their columns aligned
diffs = np.abs(np.asarray(a[:,None]) - np.asarray(b))

# Compare each row with the triplet from `tol`.
# Get mask of all matching rows and finally get the matching indices
x1,x2 = np.nonzero((diffs < tol).all(2))

Sample run -

In [46]: # Inputs
    ...: a=np.matrix('1 5 1003; 2 4 1002; 4 3 1008; 8 1 2005')
    ...: b=np.matrix('7 9 1006; 4 4 1007; 7 7 1050; 8 2 2003; 9 9 3000; 7 7 1000')
    ...: 

In [47]: # Set tolerance values for each column
    ...: tol = [1, 2, 10]
    ...: 
    ...: # Get absolute differences between a and b keeping their columns aligned
    ...: diffs = np.abs(np.asarray(a[:,None]) - np.asarray(b))
    ...: 
    ...: # Compare each row with the triplet from `tol`.
    ...: # Get mask of all matching rows and finally get the matching indices
    ...: x1,x2 = np.nonzero((diffs < tol).all(2))
    ...: 

In [48]: x1,x2
Out[48]: (array([2, 3]), array([1, 3]))
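
If it helps, the same steps can be wrapped in a small function. This is only a sketch: the name match_within_tol is made up for illustration, and it converts the inputs to plain arrays first so that it accepts either np.matrix or ndarray inputs.

```python
import numpy as np

def match_within_tol(a, b, tol):
    """Hypothetical wrapper around the broadcasting approach."""
    a = np.asarray(a)      # shape (na, ncols)
    b = np.asarray(b)      # shape (nb, ncols)
    tol = np.asarray(tol)  # shape (ncols,)
    # a[:, None] has shape (na, 1, ncols); subtracting b broadcasts to (na, nb, ncols)
    diffs = np.abs(a[:, None] - b)
    # A pair (i, j) matches only when every column is below its tolerance
    mask = (diffs < tol).all(axis=2)
    return np.nonzero(mask)

a = np.array([[1, 5, 1003], [2, 4, 1002], [4, 3, 1008], [8, 1, 2005]])
b = np.array([[7, 9, 1006], [4, 4, 1007], [7, 7, 1050],
              [8, 2, 2003], [9, 9, 3000], [7, 7, 1000]])

x1, x2 = match_within_tol(a, b, [1, 2, 10])
print(x1, x2)  # [2 3] [1 3]
```

The intermediate diffs array holds na * nb * ncols elements, so memory use grows with the product of the two row counts.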

Large data sizes: If the inputs are large enough that the full broadcasted array causes memory issues, you can use the fact that the number of columns is small (3) and loop over the columns instead. The loop runs only 3 times and never builds the (na, nb, 3) array of differences, working on one (na, nb) slice at a time, like so -

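# Work on plain arrays so the indexing behaves the same for ndarray and np.matrix inputs
a_arr = np.asarray(a)
b_arr = np.asarray(b)
na = a_arr.shape[0]
nb = b_arr.shape[0]
accum = np.ones((na,nb),dtype=bool)
for i in range(a_arr.shape[1]):
    # (na,1) column minus (nb,) column broadcasts to an (na,nb) table of differences
    accum &= np.abs(a_arr[:,i,None] - b_arr[:,i]) < tol[i]
x1,x2 = np.nonzero(accum)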
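
As a quick sanity check, the per-column loop can be compared against the fully broadcasted version on random data. This is a sketch under a few assumptions: the function name match_per_column, the random test sizes, and the seed are arbitrary choices for illustration.

```python
import numpy as np

def match_per_column(a, b, tol):
    """Hypothetical wrapper around the per-column loop."""
    a = np.asarray(a)
    b = np.asarray(b)
    accum = np.ones((a.shape[0], b.shape[0]), dtype=bool)
    for i in range(a.shape[1]):
        # Only one (na, nb) temporary exists at a time, instead of (na, nb, ncols)
        accum &= np.abs(a[:, i, None] - b[:, i]) < tol[i]
    return np.nonzero(accum)

rng = np.random.default_rng(0)
a = rng.integers(0, 50, size=(200, 3))
b = rng.integers(0, 50, size=(300, 3))
tol = [1, 2, 10]

x1, x2 = match_per_column(a, b, tol)

# Reference result from the fully broadcasted version
ref1, ref2 = np.nonzero((np.abs(a[:, None] - b) < tol).all(2))
same = np.array_equal(x1, ref1) and np.array_equal(x2, ref2)
print(same)  # True
```

Both versions still allocate an (na, nb) boolean mask, so the saving is a constant factor tied to the number of columns, not a change in how memory scales.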