Efficient implementation of a sequence of matrix-vector multiplications

Posted 2019-05-06 20:19

Question:

I have a special algorithm where, as one of the last steps, I need to carry out a multiplication of a 3-D array with a 2-D array such that each matrix slice of the 3-D array is multiplied with the corresponding column of the 2-D array. In other words, if, say, A is an N x N x N array and B is an N x N matrix, I need to compute a matrix C of size N x N where C(:,i) = A(:,:,i)*B(:,i).

The naive way to implement this is a loop, i.e.,

C = zeros(N,N);
for i = 1:N
    C(:,i) = A(:,:,i)*B(:,i);
end

However, loops aren't the fastest in Matlab and should be avoided, so I'm looking for faster ways of doing this. Right now, what I do is use the fact that (MathJax would be great here!):

[A1 b1, A2 b2, ..., AN bN] = [A1, A2, ..., AN]*blkdiag(b1,b2,...,bN)

This gets rid of the loop; however, it requires creating a block-diagonal matrix of size N^2 x N. To keep that efficient, I build it as a sparse matrix, like this:

A_long = reshape(A,N,N^2);
b_cell = mat2cell(B,N,ones(1,N)); % convert matrix to cell array of vectors
b_cell{1} = sparse(b_cell{1});    % make first element sparse, this is enough to trigger blkdiag into sparse mode
B_blk = blkdiag(b_cell{:});
C = A_long*B_blk;
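
As a quick sanity check (a minimal sketch with random test data, not part of the timing), the two approaches can be compared directly:

% Minimal sanity check with random data: the block-diagonal result should
% match the naive loop up to rounding.
N = 50;
A = randn(N,N,N);
B = randn(N,N);

% naive loop
C_loop = zeros(N,N);
for i = 1:N
    C_loop(:,i) = A(:,:,i)*B(:,i);
end

% sparse block-diagonal approach
A_long = reshape(A,N,N^2);
b_cell = mat2cell(B,N,ones(1,N));
b_cell{1} = sparse(b_cell{1});
C_blk = full(A_long*blkdiag(b_cell{:}));  % full() just in case the result comes back sparse

max(abs(C_loop(:) - C_blk(:)))   % should be on the order of machine precision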

According to my benchmarks, this approach is faster than the loop by a factor of around two (for large N), despite the necessary preparation steps (the multiplication alone is 3 to 4 times faster than the loop).

Here is a quick benchmark I did, varying the problem size N and measuring the time for the loop and the alternative approach (with and without the preparation steps). For large N the speedup is around 2...2.5.
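
A rough sketch of how such a timing comparison can be set up (tic/toc-based, with illustrative sizes; not the exact script behind the numbers above):

% Rough benchmark sketch; the sizes and single-run tic/toc timings are
% only illustrative.
Ns = [50 100 200 400];
t_loop = zeros(size(Ns));
t_blk  = zeros(size(Ns));
for k = 1:numel(Ns)
    N = Ns(k);
    A = randn(N,N,N);
    B = randn(N,N);

    tic;                                   % naive loop
    C1 = zeros(N,N);
    for i = 1:N
        C1(:,i) = A(:,:,i)*B(:,i);
    end
    t_loop(k) = toc;

    tic;                                   % sparse block-diagonal approach, including preparation
    A_long = reshape(A,N,N^2);
    b_cell = mat2cell(B,N,ones(1,N));
    b_cell{1} = sparse(b_cell{1});
    C2 = A_long*blkdiag(b_cell{:});
    t_blk(k) = toc;
end
disp([Ns(:) t_loop(:) t_blk(:) t_loop(:)./t_blk(:)])   % last column: speedup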

Still, this looks awfully complicated to me. Is there a simpler or better way to achieve this? It seems like a fairly generic/standard problem, so I imagine solutions already exist; I just don't know what to search for.

P.S.: blkdiag(A1,...,AN)*B is an obvious alternative, but there the block-diagonal matrix is already N^2 x N^2, so I don't think it can be better than what I did.
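
For concreteness, that alternative would look roughly like this (a sketch only, assuming A, B and N as above, with B stacked into a single N^2-by-1 vector so that the dimensions match):

% Sketch of the P.S. alternative: block diagonal built from the slices of A
% (size N^2 x N^2) acting on B stacked as one long vector.
A_cell = squeeze(mat2cell(A, N, N, ones(1,N)));  % N cells, each an N x N slice
A_cell{1} = sparse(A_cell{1});                   % trigger sparse blkdiag again
A_blk = blkdiag(A_cell{:});                      % N^2 x N^2 block diagonal
C_alt = reshape(A_blk*B(:), N, N);               % C_alt(:,i) = A(:,:,i)*B(:,i)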

edit: Thanks to everyone for commenting! I have carried out a new benchmark on Matlab R2016b. Unfortunately, I do not have both versions on the same computer, so the absolute numbers cannot be compared, but the relative comparison is still interesting, since it has changed a bit. Here it is:

And here is a zoom on the high-N area:

A couple of observations:

  • SumRepDot is the solution proposed by Divakar, namely to use squeeze(sum(bsxfun(@times,A,permute(B,[3,1,2])),2)), which on R2016b simplifies to squeeze(sum(A.*permute(B,[3,1,2]),2)) (see the sketch after this list). It is faster than the loop for high N by a factor of around 1.2...1.4.
  • The loop is still "slow" in the sense that the multiplication with the sparse block-diagonal matrix is much faster.
  • For the latter, the preparation overhead seems to become negligible for high N, which makes it overall a factor of 3...4 faster than the loop. This is a nice result.
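
For reference, here is a small self-contained sketch of SumRepDot next to the loop (the R2016b implicit-expansion form; the bsxfun form is the pre-R2016b equivalent):

% SumRepDot vs. the naive loop -- small self-contained check.
N = 100;
A = randn(N,N,N);
B = randn(N,N);

% naive loop
C_loop = zeros(N,N);
for i = 1:N
    C_loop(:,i) = A(:,:,i)*B(:,i);
end

% SumRepDot: A(m,j,i)*B(j,i) summed over j gives C(m,i)
C_srd = squeeze(sum(A.*permute(B,[3,1,2]), 2));               % R2016b and newer
% C_srd = squeeze(sum(bsxfun(@times,A,permute(B,[3,1,2])),2));  % older releases

max(abs(C_loop(:) - C_srd(:)))   % should be on the order of machine precision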