How to vectorize finding the closest point out of

Posted 2019-02-19 13:12

Question:

BigList = rand(20, 3)
LittleList = rand(5, 3)

I'd like to find, for each row in the big list, the 'closest' row in the little list, as defined by the Euclidean norm (i.e. based on the sum of squared differences between the corresponding values across the k=3 dimensions).

I can see how to do this using two loops, but it seems like there ought to be a better way using built-in matrix operations.
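For reference, here is a minimal sketch of the two-loop version the question alludes to. The small hand-picked BigList/LittleList values are hypothetical stand-ins for the rand data, chosen so the result is easy to check by eye. It compares squared distances, since taking the square root would not change which row is closest.

```matlab
% Hypothetical example data (the question uses rand(20,3) and rand(5,3))
BigList    = [0 0 0; 1 1 1; 0.9 0 0; 5 5 5];
LittleList = [0 0 0; 1 0 0; 1 1 1];

numBig    = size(BigList, 1);
numLittle = size(LittleList, 1);
minIdx    = zeros(numBig, 1);   % index into LittleList of the closest row
minDist   = inf(numBig, 1);     % squared distance to that row

for i = 1:numBig
    for j = 1:numLittle
        dSq = sum((BigList(i,:) - LittleList(j,:)).^2);   % squared Euclidean distance
        if dSq < minDist(i)      % strict '<' keeps the first index on ties, like min
            minDist(i) = dSq;
            minIdx(i)  = j;
        end
    end
end
% For this data, minIdx is [1; 3; 2; 3]
```

Every vectorized answer below should return the same minIdx as this loop, which makes it a handy correctness check.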

Answer 1:

Approach #1

MATLAB's pdist2 function (part of the Statistics and Machine Learning Toolbox) computes the "pairwise distance between two sets of observations". With it, you can calculate the Euclidean distance matrix and then find, along each row, the index of the minimum value; that index identifies the row of littleList that is "closest" to the corresponding row of bigList.

Here's the one-liner with it -

[~,minIdx] = min(pdist2(bigList,littleList),[],2); %// minIdx is what you are after
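pdist2 returns plain Euclidean distances by default, not squared ones. Because the square root is monotonic, it does not change which row attains the minimum, so the squared distances used by the other approaches give the same indices. A small sketch checking this with core functions only (no toolbox needed); the data values are hypothetical:

```matlab
% Hypothetical example data
bigList    = [0 0 0; 1 1 1; 0.9 0 0; 5 5 5];
littleList = [0 0 0; 1 0 0; 1 1 1];

% numBig-by-numLittle matrix of squared differences summed over the 3 columns
dSq  = sum(bsxfun(@minus, permute(bigList, [1 3 2]), permute(littleList, [3 1 2])).^2, 3);
dEuc = sqrt(dSq);   % plain Euclidean distances, the quantity pdist2 computes by default

[~, idxSq]  = min(dSq,  [], 2);
[~, idxEuc] = min(dEuc, [], 2);   % same indices as idxSq

closestRows = littleList(idxEuc, :);   % the closest rows themselves, one per row of bigList
```

Indexing littleList with the returned indices gives the actual closest points, if those are what you need rather than their positions.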

Approach #2

If you care about performance, here's a method that leverages MATLAB's fast matrix multiplication; most of the code presented here is taken from this smart solution.

dim = size(bigList,2); %// number of columns (3 here)
numA = size(bigList,1);
numB = size(littleList,1);

%// Per dimension, [1, -2a, a^2] . [b^2, b, 1] = (a-b)^2,
%// so helpA * helpB' holds all squared distances
helpA = zeros(numA,3*dim);
helpB = zeros(numB,3*dim);
for idx = 1:dim
    helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*bigList(:,idx), bigList(:,idx).^2 ];
    helpB(:,3*idx-2:3*idx) = [littleList(:,idx).^2 ,    littleList(:,idx), ones(numB,1)];
end
[~,minIdx] = min(helpA * helpB',[],2); %// minIdx is what you are after
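Why the product yields squared distances: take a = bigList(i,:), b = littleList(j,:) and D = dim. Each helper row packs three entries per dimension, and their dot product expands term by term:

```latex
\text{helpA}(i,:) = [\,1,\ -2a_1,\ a_1^2,\ \dots,\ 1,\ -2a_D,\ a_D^2\,], \qquad
\text{helpB}(j,:) = [\,b_1^2,\ b_1,\ 1,\ \dots,\ b_D^2,\ b_D,\ 1\,]

(\text{helpA}\,\text{helpB}^{\mathsf T})_{ij}
  = \sum_{d=1}^{D} \left( b_d^2 - 2a_d b_d + a_d^2 \right)
  = \sum_{d=1}^{D} (a_d - b_d)^2
  = \lVert a - b \rVert^2
```

So a single matrix product produces the full numA-by-numB matrix of squared distances, and min along dimension 2 picks the closest row.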

Benchmarking

Benchmarking Code -

N1 = 1750; N2 = 4*N1; %// data sizes (littleList: N1 rows, bigList: N2 rows)
littleList = rand(N1, 3);
bigList = rand(N2, 3);

for k = 1:50000
    tic(); elapsed = toc(); %// Warm up tic/toc
end

disp('------------- With squeeze + bsxfun + permute based approach [LuisMendo]')
tic
d = squeeze(sum((bsxfun(@minus, bigList, permute(littleList, [3 2 1]))).^2, 2));
[~, ind] = min(d,[],2);
toc,  clear d ind

disp('------------- With double permutes + bsxfun based approach [Shai]')
tic
d = bsxfun( @minus, permute( bigList, [1 3 2] ), permute( littleList, [3 1 2] ) ); %//diff in third dimension
d = sum( d.^2, 3 ); %// sq euclidean distance
[~,minIdx] = min( d, [], 2 );
toc
clear d minIdx

disp('------------- With bsxfun + matrix-multiplication based approach [Shai]')
tic
nb = sum( bigList.^2, 2 ); %// norm of bigList's items
nl = sum( littleList.^2, 2 ); %// norm of littleList's items
d = bsxfun(@plus, nb, nl.' ) - 2 * bigList * littleList'; %// all the distances
[~,minIdx] = min(d,[],2);
toc, clear nb nl d minIdx

disp('------------- With matrix multiplication based approach  [Divakar]')
tic
dim = 3;
numA = size(bigList,1);
numB = size(littleList,1);

helpA = zeros(numA,3*dim);
helpB = zeros(numB,3*dim);
for idx = 1:dim
    helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*bigList(:,idx), bigList(:,idx).^2 ];
    helpB(:,3*idx-2:3*idx) = [littleList(:,idx).^2 ,    littleList(:,idx), ones(numB,1)];
end
[~,minIdx] = min(helpA * helpB',[],2);
toc, clear dim numA numB helpA helpB idx minIdx

disp('------------- With pdist2 based approach [Divakar]')
tic
[~,minIdx] = min(pdist2(bigList,littleList),[],2);
toc, clear minIdx

Benchmark results -

------------- With squeeze + bsxfun + permute based approach [LuisMendo]
Elapsed time is 0.718529 seconds.
------------- With double permutes + bsxfun based approach [Shai]
Elapsed time is 0.971690 seconds.
------------- With bsxfun + matrix-multiplication based approach [Shai]
Elapsed time is 0.328442 seconds.
------------- With matrix multiplication based approach  [Divakar]
Elapsed time is 0.159092 seconds.
------------- With pdist2 based approach [Divakar]
Elapsed time is 0.310850 seconds.

Quick conclusions: In this run, the pure matrix-multiplication approach was the fastest (about 0.16 s). Shai's second approach, which combines bsxfun with matrix multiplication, ran very close to the pdist2-based one (about 0.33 s vs. 0.31 s), so no clear winner could be declared between those two.



Answer 2:

The proper way is, of course, to use a nearest-neighbor search algorithm.
However, if your dimension is not too high and your data sets are not too large, then you can simply use bsxfun:

d = bsxfun( @minus, permute( bigList, [1 3 2] ), permute( littleList, [3 1 2] ) ); %//diff in third dimension
d = sum( d.^2, 3 ); %// sq euclidean distance
[minDist minIdx] = min( d, [], 2 );
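The size caveat comes from the intermediate array: this bsxfun call materializes a numBig-by-numLittle-by-3 array of differences. One way to bound that memory is to process bigList in blocks of rows. This is a sketch under assumptions: the block size chunk = 256 is an arbitrary, hypothetical choice, and the data is random.

```matlab
bigList    = rand(1000, 3);   % hypothetical data
littleList = rand(50, 3);

chunk   = 256;                % hypothetical block size; tune to available memory
numBig  = size(bigList, 1);
minIdx  = zeros(numBig, 1);
minDist = zeros(numBig, 1);
lp = permute(littleList, [3 1 2]);   % 1-by-numLittle-by-3, reused for every block

for s = 1:chunk:numBig
    rows = s:min(s + chunk - 1, numBig);   % rows of bigList in this block
    % block-by-numLittle squared distances; only this block's differences are in memory
    d = sum(bsxfun(@minus, permute(bigList(rows, :), [1 3 2]), lp).^2, 3);
    [minDist(rows), minIdx(rows)] = min(d, [], 2);
end
```

Peak memory then scales with chunk * numLittle * 3 instead of numBig * numLittle * 3, at the cost of a short loop over blocks.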

In addition to the matrix-multiplication approach proposed here, there is another matrix-multiplication approach that needs no loops:

nb = sum( bigList.^2, 2 ); %// squared norm of each row of bigList
nl = sum( littleList.^2, 2 ); %// squared norm of each row of littleList
d = bsxfun( @plus, nb, nl.' ) - 2 * bigList * littleList'; %// all squared distances
[minDist, minIdx] = min( d, [], 2 );
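If you need the actual distances rather than just the indices, keep in mind that this expansion subtracts nearly equal numbers, so floating-point rounding can leave tiny negative values where the true squared distance is zero or very small. A sketch that clamps at zero before taking the square root (example data hypothetical):

```matlab
% Hypothetical example data
bigList    = [0 0 0; 1 1 1; 0.9 0 0; 5 5 5];
littleList = [0 0 0; 1 0 0; 1 1 1];

nb  = sum(bigList.^2, 2);      % squared norm of each bigList row (column vector)
nl  = sum(littleList.^2, 2);   % squared norm of each littleList row (column vector)
dSq = bsxfun(@plus, nb, nl.') - 2 * (bigList * littleList.');   % ||a||^2 + ||b||^2 - 2<a,b>
dSq = max(dSq, 0);             % clamp small negative values caused by rounding

[minSq, minIdx] = min(dSq, [], 2);
minDist = sqrt(minSq);         % Euclidean distance to the closest row
```

The clamp does not change minIdx in practice; it only keeps sqrt from returning complex numbers.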

The observation behind this method is that, for the Euclidean (L2) norm,

||a - b||^2 = ||a||^2 + ||b||^2 - 2<a,b>

where <a,b> is the dot product of the two vectors.
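A quick worked check of the identity with two arbitrary vectors, a = [1 2 3] and b = [4 0 1]:

```latex
\lVert a - b \rVert^2 = (1-4)^2 + (2-0)^2 + (3-1)^2 = 9 + 4 + 4 = 17

\lVert a \rVert^2 + \lVert b \rVert^2 - 2\langle a, b \rangle
  = (1 + 4 + 9) + (16 + 0 + 1) - 2\,(4 + 0 + 3) = 14 + 17 - 14 = 17
```

Both sides agree. In the code, nb and nl hold the squared norms, and bigList * littleList' supplies every dot product at once.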



Answer 3:

You can do it with bsxfun:

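d = squeeze(sum((bsxfun(@minus, BigList, permute(LittleList, [3 2 1]))).^2, 2)); %// numBig-by-numLittle squared distances
[~, ind] = min(d,[],2); %// ind(i) is the row of LittleList closest to BigList(i,:)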
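One caveat with squeeze: when BigList has a single row, sum(...,2) is 1-by-1-by-M and squeeze turns it into an M-by-1 column, so min(d,[],2) no longer returns the intended index. A sketch of a variant that swaps squeeze for permute(..., [1 3 2]), which keeps the numBig-by-numLittle shape in every case (example values are hypothetical):

```matlab
BigList    = [0.9 0 0];              % hypothetical single query row (edge case)
LittleList = [0 0 0; 1 0 0; 1 1 1];

diffs = bsxfun(@minus, BigList, permute(LittleList, [3 2 1]));   % numBig-by-3-by-numLittle
dSqueeze = squeeze(sum(diffs.^2, 2));              % becomes numLittle-by-1 here, not 1-by-numLittle
dPermute = permute(sum(diffs.^2, 2), [1 3 2]);     % always numBig-by-numLittle

[~, ind] = min(dPermute, [], 2);   % closest row of LittleList (row 2 for this data)
```

For BigList with more than one row, both versions produce the same matrix; permute just removes the dependence on how many rows there are.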