Optimize/vectorize Mahalanobis distance calculation

Posted 2019-03-04 07:32

I have the following piece of MATLAB code, which calculates Mahalanobis distances between a vector and a matrix over many iterations. I am trying to find a faster method by vectorizing it, but without success so far.

S.data=0+(20-0).*rand(15000,3);
S.a=0+(20-0).*rand(2500,3);

S.resultat=ones(length(S.data),length(S.a))*nan;
S.b=ones(length(S.a),3,length(S.a))*nan;

for i=1:length(S.data)
    for j=1:length(S.a)
        S.a2=S.a;
        S.a2(j,:)=S.data(i,:);
        S.b(:,:,j)=S.a2;
        if j==length(S.a)
            for k=1:length(S.a)
                S.resultat(i,k)=mahal(S.a(k,:),S.b(:,:,k));
            end
        end
    end
end

I have now modified the code to avoid one of the loops, but it is still very slow. If anyone has an idea, I would be very grateful!

S.data=0+(20-0).*rand(15000,3);
S.a=0+(20-0).*rand(2500,3);

S.resultat=ones(length(S.data),length(S.a))*nan;
for i=1:length(S.data)
    for j=1:length(S.a)
        S.a2=S.a;
        S.a2(j,:)=S.data(i,:);
        S.resultat(i,j)=mahal(S.a(j,:),S.a2);
    end
end
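As a sanity check that dropping the loop did not change the results, the following sketch runs both versions of the question's code on small random inputs. It avoids the Statistics Toolbox. The anonymous function mahal_ref is a hypothetical stand-in for mahal that uses the textbook squared distance (y - mean(X)) * inv(cov(X)) * (y - mean(X))'. The sizes (30 and 8 rows) are arbitrary. The first version runs its k-loop after the j-loop, which is equivalent to the question's `if j==length(S.a)` check.

```matlab
% Equivalence check of the two versions in the question (small, arbitrary sizes).
% mahal_ref is a hypothetical stand-in for mahal: squared Mahalanobis distance
% of the row vector y from the rows of X, using the sample covariance cov(X).
mahal_ref = @(y, X) ((y - mean(X,1)) / cov(X)) * (y - mean(X,1))';

data = 20*rand(30,3);
a    = 20*rand(8,3);
n    = size(a,1);

% Version 1: fill the 3D array b, then loop over its pages
res1 = nan(size(data,1), n);
for i = 1:size(data,1)
    b = nan(n, 3, n);
    for j = 1:n
        a2 = a;
        a2(j,:) = data(i,:);      % replace row j with the current data point
        b(:,:,j) = a2;
    end
    for k = 1:n
        res1(i,k) = mahal_ref(a(k,:), b(:,:,k));
    end
end

% Version 2: compute each distance as soon as its modified matrix is built
res2 = nan(size(data,1), n);
for i = 1:size(data,1)
    for j = 1:n
        a2 = a;
        a2(j,:) = data(i,:);
        res2(i,j) = mahal_ref(a(j,:), a2);
    end
end

fprintf('max |res1 - res2| = %g\n', max(abs(res1(:) - res2(:))));
```

Both versions call the distance function on identical matrices, so the printed difference should be zero or at rounding level.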

1 answer
唯我独甜
Post #2 · 2019-03-04 08:20

Introduction and solution code

You can replace the innermost loop that calls mahal with a partially vectorized version. Some values are pre-calculated for all iterations at once with the help of bsxfun. They are then used inside a shortened, hacked version of mahal that still loops.
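Here is a small illustration of the bsxfun pre-calculation, with arbitrary array sizes. On a 3D array, mean(B,1) returns one row of column means per page. bsxfun(@minus, ...) then expands that row along the first dimension, so every page is centered in a single call. That centering is the first thing mahal does before calling qr. On MATLAB R2016b or later, writing B - meanB gives the same result through implicit expansion.

```matlab
% Center every page of a 3D array in one call (sizes are arbitrary).
B = rand(5, 3, 4);                     % 4 pages, each a 5-by-3 data matrix
meanB = mean(B, 1);                    % 1-by-3-by-4: column means of each page
B_meanB = bsxfun(@minus, B, meanB);    % subtract each page's means from its rows

% Every page of B_meanB should now have (numerically) zero column means
colMeans = mean(B_meanB, 1);
fprintf('largest remaining column mean: %g\n', max(abs(colMeans(:))));
```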

Basically, you have a 2D array, which we'll call A for easy reference, and a 3D array, which we'll call B. Let the output be stored in a variable out. With these names, the innermost code snippet can be extracted as follows.

Original loopy code

for k=1:size(A,1)
    out(k)=mahal(A(k,:),B(:,:,k));
end

So, what I did was look inside mahal.m for the portions that could be vectorized when the inputs are a 2D and a 3D array. mahal uses qr internally, and qr works on a single 2D matrix, so that step cannot be vectorized across the pages of B and has to stay in a loop. The result is the following hacked code.
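A short derivation shows why the QR step reproduces mahal. It is written in standard notation, not copied from mahal.m. Let X be one page of B, an n-by-p matrix, and let m be the 1-by-p row of its column means. Let C be X with m subtracted from every row. If C = QR is the economy-size QR decomposition, then Q has orthonormal columns, so:

```latex
\Sigma = \frac{C^\top C}{n-1} = \frac{R^\top Q^\top Q R}{n-1} = \frac{R^\top R}{n-1},
\qquad
d^2(y) = (y-m)\,\Sigma^{-1}(y-m)^\top
       = (n-1)\,(y-m)\,R^{-1}R^{-\top}(y-m)^\top
       = (n-1)\,\bigl\lVert R^{-\top}(y-m)^\top \bigr\rVert_2^2 .
```

In MATLAB terms:

- R'\(y-m)' computes R^{-T}(y-m)^T with a triangular solve.
- sum(...^2) takes the squared norm.
- n-1 is size(B,1)-1. In this problem it equals size(A,1)-1, because every page of B has as many rows as A.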

Hacked code

% Pre-calculate values instead of recomputing them inside the loop
meanB = mean(B,1);                             % column means of each page of B
B_meanB = bsxfun(@minus,B,meanB);              % each page of B centered on its column means
A_B_meanB = A' - reshape(meanB,size(B,2),[]);  % column k is A(k,:)' minus the mean of page k

% The QR step stays in a loop: qr works on one page at a time
out2 = zeros(size(A,1),1);
for k = 1:size(A,1)
    [~,R] = qr(B_meanB(:,:,k),0);              % economy-size QR of the centered page
    out2(k) = sum((R'\A_B_meanB(:,k)).^2)*(size(B,1)-1);
end
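To check that the hacked loop matches the definition of the Mahalanobis distance, the sketch below builds A and B the way the question does: page k of B is A with row k replaced by one data point. It then compares the QR-based values with the covariance-based formula. The sizes and the variable x are illustrative assumptions. The reference uses base MATLAB only, so no Statistics Toolbox is needed.

```matlab
% Illustrative check: QR-based loop vs. the direct covariance formula.
A = 20*rand(10, 3);                  % stands in for S.a
x = 20*rand(1, 3);                   % stands in for one row of S.data
n = size(A, 1);

% Page k of B is A with row k replaced by x, as in the question
B = repmat(A, [1 1 n]);
for k = 1:n
    B(k,:,k) = x;
end

% QR-based version (same steps as the hacked code)
meanB = mean(B, 1);
B_meanB = bsxfun(@minus, B, meanB);
A_B_meanB = A' - reshape(meanB, size(B,2), []);
out2 = zeros(n, 1);
for k = 1:n
    [~, R] = qr(B_meanB(:,:,k), 0);
    out2(k) = sum((R' \ A_B_meanB(:,k)).^2) * (size(B,1) - 1);
end

% Direct definition: (y - m) * inv(cov(X)) * (y - m)'
out_ref = zeros(n, 1);
for k = 1:n
    d = A(k,:) - mean(B(:,:,k), 1);
    out_ref(k) = (d / cov(B(:,:,k))) * d';
end

fprintf('max abs difference: %g\n', max(abs(out2 - out_ref)));
```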

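To extend this hack to the code in the question, a few more tweaks pre-calculate values that the nested loops would otherwise recompute: the transpose of A, a replicated copy of S.a, and the linear indices of the elements that get overwritten in each iteration.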

Final solution code

A = S.a;                 % get data from S
[rx,cx] = size(A);       % get size parameters
Atr = A';                % pre-calculate the transpose of A

% Pre-calculate the replicated base array and the linear indices of the
% elements overwritten in each iteration (row p of page p, for every p)
B_rep = repmat(S.a,1,1,rx);
B_idx = bsxfun(@plus,((0:cx-1)*rx + 1)',(0:rx-1)*(rx*cx+1));

out = zeros(size(S.data,1),rx);   % initialize output array
for i=1:size(S.data,1)

    B = B_rep;
    B(B_idx) = repmat(S.data(i,:)',1,rx);    % put data row i into row p of page p
    meanB = mean(B,1);                       % column means of each page

    B_meanB = bsxfun(@minus,B,meanB);        % each page centered on its column means
    A_B_meanB = Atr - reshape(meanB,cx,[]);  % column jj is A(jj,:)' minus the mean of page jj
    for jj = 1:rx
        [~,R] = qr(B_meanB(:,:,jj),0);
        out(i,jj) = sum((R'\A_B_meanB(:,jj)).^2)*(rx-1);
    end

end
S.resultat = out;
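The least obvious part of the final code is B_idx. Column p of B_idx holds the linear indices of row p on page p of the replicated array. A single assignment therefore overwrites the p-th row of the p-th page for every p at once. The tiny sketch below uses made-up sizes (rx = 4, cx = 3) and a made-up data row x to show the effect.

```matlab
% Tiny demo of the linear-index trick used for B_idx (sizes are arbitrary).
A = reshape(1:12, 4, 3);             % 4-by-3, so rx = 4 and cx = 3
[rx, cx] = size(A);
x = [100 200 300];                   % stands in for one row of S.data

B_rep = repmat(A, [1 1 rx]);
% Column p of B_idx: linear indices of elements (p, 1:cx, p) in the rx-by-cx-by-rx array
B_idx = bsxfun(@plus, ((0:cx-1)*rx + 1)', (0:rx-1)*(rx*cx + 1));

B = B_rep;
B(B_idx) = repmat(x', 1, rx);

% Page 2 is A with row 2 replaced by x: [1 5 9; 100 200 300; 3 7 11; 4 8 12]
disp(B(:,:,2))
```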

Benchmarking

Here is the benchmarking code comparing the proposed solution against the code listed in the question:

%// Random inputs
S.data=0+(20-0).*rand(1500,3); %(size 10x reduced for a quicker runtime test)
S.a=0+(20-0).*rand(250,3);

S.resultat=ones(length(S.data),length(S.a))*nan;
disp('----------------------------- With original code')
tic

S.b=ones(length(S.a),3,length(S.a))*nan;
for i=1:length(S.data)
    for j=1:length(S.a)
        S.a2=S.a;
        S.a2(j,:)=S.data(i,:);
        S.b(:,:,j)=S.a2;
        if j==length(S.a)
            for k=1:length(S.a);
                S.resultat(i,k)=mahal(S.a(k,:),S.b(:,:,k));
            end
        end
    end
end

toc
clear i j k
S = rmfield(S,'a2');   % clear cannot remove struct fields; S.resultat is reset below

S.resultat=ones(length(S.data),length(S.a))*nan;
disp('----------------------------- With proposed solution code')
tic

[ ... Proposed solution code ...]

toc

Runtimes -

----------------------------- With original code
Elapsed time is 17.734394 seconds.
----------------------------- With proposed solution code
Elapsed time is 6.602860 seconds.

That is roughly a 2.7x speedup (17.73 s vs. 6.60 s) on the 10x-reduced inputs, from the proposed approach and a few tweaks.
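The timing alone does not show that both versions return the same numbers. The sketch below checks this on small inputs. It runs the final solution code and compares it with the question's second version. mahal is replaced by a hypothetical base-MATLAB stand-in, mahal_ref, that computes the squared distance with cov(X), so no Statistics Toolbox is required. The input sizes are arbitrary assumptions.

```matlab
% Correctness check of the final solution on small inputs (sizes arbitrary).
% mahal_ref is a hypothetical stand-in for mahal (squared distance, sample covariance).
mahal_ref = @(y, X) ((y - mean(X,1)) / cov(X)) * (y - mean(X,1))';

S.data = 20*rand(40, 3);
S.a    = 20*rand(12, 3);

% Final solution code from the answer
A = S.a;
[rx, cx] = size(A);
Atr = A';
B_rep = repmat(S.a, [1 1 rx]);
B_idx = bsxfun(@plus, ((0:cx-1)*rx + 1)', (0:rx-1)*(rx*cx + 1));
out = zeros(size(S.data,1), rx);
for i = 1:size(S.data,1)
    B = B_rep;
    B(B_idx) = repmat(S.data(i,:)', 1, rx);
    meanB = mean(B, 1);
    B_meanB = bsxfun(@minus, B, meanB);
    A_B_meanB = Atr - reshape(meanB, cx, []);
    for jj = 1:rx
        [~, R] = qr(B_meanB(:,:,jj), 0);
        out(i,jj) = sum((R' \ A_B_meanB(:,jj)).^2) * (rx - 1);
    end
end

% Reference: the question's second version, with mahal_ref in place of mahal
ref = zeros(size(S.data,1), rx);
for i = 1:size(S.data,1)
    for j = 1:rx
        a2 = S.a;
        a2(j,:) = S.data(i,:);
        ref(i,j) = mahal_ref(S.a(j,:), a2);
    end
end

fprintf('max relative difference: %g\n', max(abs(out(:) - ref(:))) / max(abs(ref(:))));
```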
