Efficiently calculating weighted distance in MATLAB

Posted 2019-02-22 06:24

Question:

Several posts exist about efficiently calculating pairwise distances in MATLAB. These posts tend to focus on quickly calculating Euclidean distance between large numbers of points.

I need to create a function that quickly calculates the pairwise distances between smaller numbers of points (typically fewer than 1000 pairs). Within the grander scheme of the program I am writing, this function will be executed many thousands of times, so even small gains in efficiency are important. The function needs to be flexible in two ways:

  1. On any given call, the distance metric can be Euclidean OR city-block.
  2. The dimensions of the data are weighted.
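For reference, the two requirements above are commonly written as a weighted Minkowski distance. This is an assumption about the intended definition rather than something stated in the post; the symbols $a_{ik}$ and $b_{jk}$ (coordinate $k$ of the $i$-th and $j$-th points), $w_k$ (weight of dimension $k$), and $d$ (number of dimensions) are introduced here only for illustration.

```latex
D_{ij} = \left( \sum_{k=1}^{d} w_k \, \lvert a_{ik} - b_{jk} \rvert^{r} \right)^{1/r},
\qquad r = 1 \ \text{(city-block)}, \quad r = 2 \ \text{(Euclidean)}
```

Under this definition the root is taken once, after the weighted sum over dimensions.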

As far as I can tell, no solution to this particular problem has been posted. The Statistics Toolbox offers pdist and pdist2, which accept many different distance functions, but not weighting. I have seen extensions of these functions that allow for weighting, but these extensions do not allow users to select different distance functions.

Ideally, I would like to avoid using functions from the Statistics Toolbox (I am not certain the user of the function will have access to that toolbox).

I have written two functions to accomplish this task. The first uses tricky calls to repmat and permute, and the second simply uses for-loops.

function [D] = pairdist1(A, B, wts, distancemetric)

% get some information about the data
    numA = size(A,1);
    numB = size(B,1);

    if strcmp(distancemetric,'cityblock')
        r=1;
    elseif strcmp(distancemetric,'euclidean')
        r=2;
    else error('Function only accepts "cityblock" and "euclidean" distance')
    end

%   format weights for multiplication
    wts = repmat(wts,[numA,1,numB]);

%   get featural differences between A and B pairs
    A = repmat(A,[1 1 numB]);
    B = repmat(permute(B,[3,2,1]),[numA,1,1]);
    differences = abs(A-B).^r;

%   weigh difference values before combining them
    differences = differences.*wts;

%   combine features, then take the r-th root of the sum to get distance
    D = permute(sum(differences,2),[1,3,2]).^(1/r);
end

AND:

function [D] = pairdist2(A, B, wts, distancemetric)

% get some information about the data
    numA = size(A,1);
    numB = size(B,1);

    if strcmp(distancemetric,'cityblock')
        r=1;
    elseif strcmp(distancemetric,'euclidean')
        r=2;
    else error('Function only accepts "cityblock" and "euclidean" distance')
    end

%   use for-loops to generate differences
    D = zeros(numA,numB);
    for i=1:numA
        for j=1:numB
            differences = abs(A(i,:) - B(j,:)).^r;
            differences = differences.*wts;
            % take the r-th root of the weighted sum, not of each term
            D(i,j) = sum(differences,2).^(1/r);
        end
    end
end

Here are the performance tests:

A = rand(10,3);
B = rand(80,3);
wts = [0.1 0.5 0.4];
distancemetric = 'cityblock';


tic
D1 = pairdist1(A,B,wts,distancemetric);
toc

tic
D2 = pairdist2(A,B,wts,distancemetric);
toc

Elapsed time is 0.000238 seconds.
Elapsed time is 0.005350 seconds.
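A single tic/toc measurement of a sub-millisecond call can be noisy. As a hedged alternative, the sketch below uses MATLAB's built-in timeit (available since R2013b), which runs each function several times and returns the median. It assumes pairdist1 and pairdist2 from above are on the path; the variable names t1 and t2 are introduced here for illustration.

```matlab
% Benchmark sketch: timeit calls each function handle multiple times
% and returns the median run time in seconds.
A = rand(10,3);
B = rand(80,3);
wts = [0.1 0.5 0.4];
distancemetric = 'cityblock';

t1 = timeit(@() pairdist1(A, B, wts, distancemetric));
t2 = timeit(@() pairdist2(A, B, wts, distancemetric));

fprintf('pairdist1: %.6f s\npairdist2: %.6f s\n', t1, t2);
```

The absolute numbers will differ by machine and MATLAB release, so only the ratio between the two timings is meaningful.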

It's clear that the repmat-and-permute version runs much faster than the double-for-loop version, at least for smaller datasets. However, I also know that calls to repmat often slow things down, so I am wondering whether anyone in the SO community has advice for improving the efficiency of either function!

EDIT

@Luis Mendo offered a nice cleanup of the repmat-and-permute function using bsxfun. I compared his function with my original on datasets of varying size. As the data become larger, the bsxfun version becomes the clear winner!

EDIT #2

I have finished writing the function and it is available on GitHub [link]. I ended up finding a pretty good vectorized method for computing Euclidean distance [link], so I use that method in the Euclidean case, and I took @Divakar's advice for city-block. It is still not as fast as pdist2, but it's much faster than either of the approaches I laid out earlier in this post, and it easily accepts weightings.
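The linked method is not reproduced in this post. As an illustration only, one common vectorized approach to weighted Euclidean distance expands the squared difference into three matrix products, which avoids building a 3-D array. The function name weuclid is hypothetical, and this sketch may differ from the linked method and from the code on GitHub.

```matlab
function D = weuclid(A, B, wts)
% Hypothetical sketch of weighted Euclidean distance via the expansion
%   sum_k w_k*(a_k - b_k)^2 = sum_k w_k*a_k^2 + sum_k w_k*b_k^2 - 2*sum_k w_k*a_k*b_k
% A is numA-by-d, B is numB-by-d, wts has d nonnegative entries.
    w  = wts(:);                              % d-by-1 column of weights
    AA = (A.^2) * w;                          % numA-by-1 weighted squared norms
    BB = (B.^2) * w;                          % numB-by-1 weighted squared norms
    AB = bsxfun(@times, A, w.') * B.';        % numA-by-numB weighted inner products
    D2 = bsxfun(@plus, AA, BB.') - 2*AB;      % squared distances
    D  = sqrt(max(D2, 0));                    % clamp tiny negatives from rounding
end
```

A call such as D = weuclid(rand(10,3), rand(80,3), [0.1 0.5 0.4]) returns a 10-by-80 matrix. The expansion can lose some precision when points are nearly identical, which is why the squared distances are clamped at zero before the square root.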

Answer 1:

You can replace repmat with bsxfun. Doing so avoids explicitly replicating the arrays, so it is more memory-efficient and probably faster:

function D = pairdist1(A, B, wts, distancemetric)

    if strcmp(distancemetric,'cityblock')
        r=1;
    elseif strcmp(distancemetric,'euclidean')
        r=2;
    else
        error('Function only accepts "cityblock" and "euclidean" distance')
    end

    differences = abs(bsxfun(@minus, A, permute(B, [3 2 1]))).^r;
    differences = bsxfun(@times, differences, wts);
    D = permute(sum(differences,2),[1,3,2]).^(1/r);

end
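On MATLAB R2016b and later, arithmetic operators expand singleton dimensions implicitly, so the same broadcasting can be written without bsxfun. The sketch below is an assumption-laden variant, not part of the original answer: the name pairdist_implicit is hypothetical, it takes the exponent r directly instead of a metric string, it expects wts as a 1-by-d row vector, and it takes the root after the weighted sum, as in the usual weighted Minkowski definition.

```matlab
function D = pairdist_implicit(A, B, wts, r)
% Hypothetical helper; requires MATLAB R2016b or later (implicit expansion).
% A is numA-by-d, B is numB-by-d, wts is 1-by-d.
% r = 1 gives city-block distance, r = 2 gives Euclidean distance.
    differences = abs(A - permute(B, [3 2 1])).^r;             % numA-by-d-by-numB
    D = permute(sum(differences .* wts, 2), [1 3 2]).^(1/r);   % numA-by-numB
end
```

On older releases the bsxfun form is still needed, since the subtraction of mismatched sizes raises a dimension error there.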


Answer 2:

For r = 1 (the "cityblock" case), you can use bsxfun to get elementwise subtractions and then use matrix multiplication, which should speed things up. The implementation would look something like this:

%// Calculate absolute elementwise subtractions
absm = abs(bsxfun(@minus,permute(A,[1 3 2]),permute(B,[3 1 2])));

%// Perform matrix multiplications with the given weights and reshape
D = reshape(reshape(absm,[],size(A,2))*wts(:),size(A,1),[]);
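To check that the matrix-multiplication form gives the same result as a direct computation, a small script along these lines can be used. The test data mirror the sizes used in the question; Dref and the tolerance 1e-12 are choices made here for illustration.

```matlab
% Sanity-check sketch: compare the vectorized city-block distance
% against a plain double loop on small random inputs.
A = rand(10,3);
B = rand(80,3);
wts = [0.1 0.5 0.4];

% Vectorized version from the answer above
absm = abs(bsxfun(@minus, permute(A,[1 3 2]), permute(B,[3 1 2])));
D = reshape(reshape(absm,[],size(A,2))*wts(:), size(A,1), []);

% Reference version: weighted sum of absolute differences per pair
Dref = zeros(size(A,1), size(B,1));
for i = 1:size(A,1)
    for j = 1:size(B,1)
        Dref(i,j) = sum(wts .* abs(A(i,:) - B(j,:)));
    end
end

% The two results should agree up to floating-point rounding
assert(max(abs(D(:) - Dref(:))) < 1e-12);
```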