I have a special algorithm where, as one of the last steps, I need to carry out a multiplication of a 3-D array with a 2-D array such that each matrix slice of the 3-D array is multiplied with the corresponding column of the 2-D array. In other words, if, say, A is an N x N x N array and B is an N x N matrix, I need to compute a matrix C of size N x N where C(:,i) = A(:,:,i)*B(:,i);.
The naive way to implement this is a loop, i.e.,
C = zeros(N,N);
for i = 1:N
    C(:,i) = A(:,:,i)*B(:,i);
end
However, loops aren't the fastest in Matlab and should be avoided. I'm looking for faster ways of doing this. Right now, what I do is to use the fact that (now Mathjax would be great!):
[A1 b1, A2 b2, ..., AN bN] = [A1, A2, ..., AN]*blkdiag(b1,b2,...,bN)
This allows us to get rid of the loop; however, we have to create a block-diagonal matrix of size N^2 x N. I construct it as sparse to be efficient, i.e., like this:
A_long = reshape(A,N,N^2);
b_cell = mat2cell(B,N,ones(1,N)); % convert matrix to cell array of vectors
b_cell{1} = sparse(b_cell{1}); % make first element sparse, this is enough to trigger blkdiag into sparse mode
B_blk = blkdiag(b_cell{:});
C = A_long*B_blk;
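For reference, here is a minimal self-contained sketch that sanity-checks the block-diagonal approach against the naive loop (the variable names and the choice of N are mine, just for illustration):

```matlab
% Sanity check: compare the sparse blkdiag-based result against the naive loop
N = 50;
A = randn(N,N,N);
B = randn(N,N);

% Block-diagonal approach
A_long = reshape(A,N,N^2);
b_cell = mat2cell(B,N,ones(1,N));   % cell array of the columns of B
b_cell{1} = sparse(b_cell{1});      % triggers blkdiag into sparse mode
C_blk = A_long*blkdiag(b_cell{:});

% Naive loop
C_loop = zeros(N,N);
for i = 1:N
    C_loop(:,i) = A(:,:,i)*B(:,i);
end

max(abs(C_blk(:) - C_loop(:)))  % should be on the order of machine precision
```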
According to my benchmarks, this approach is faster than the loop by a factor of around two (for large N), despite the necessary preparations (the multiplication alone is 3 to 4-fold faster than the loop).
Here is a quick benchmark I did, varying the problem size N and measuring the time for the loop and the alternative approach (with and without the preparation steps). For large N the speedup is around 2...2.5.
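For anyone wanting to reproduce such a benchmark, a minimal sketch using timeit might look like the following (this is my own setup, not the exact one used for the plots; local functions in scripts require R2016b or newer):

```matlab
% Compare the naive loop against the sparse blkdiag approach with timeit
N = 200;
A = randn(N,N,N);
B = randn(N,N);

t_loop = timeit(@() loopMult(A,B,N));
t_blk  = timeit(@() blkMult(A,B,N));   % includes the preparation steps
fprintf('loop: %.3g s, blkdiag: %.3g s, speedup: %.2f\n', ...
    t_loop, t_blk, t_loop/t_blk);

function C = loopMult(A,B,N)
    C = zeros(N,N);
    for i = 1:N
        C(:,i) = A(:,:,i)*B(:,i);
    end
end

function C = blkMult(A,B,N)
    A_long = reshape(A,N,N^2);
    b_cell = mat2cell(B,N,ones(1,N));
    b_cell{1} = sparse(b_cell{1});
    C = A_long*blkdiag(b_cell{:});
end
```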
Still, this looks awfully complicated to me. Is there a simpler or better way to achieve this? It looks like a fairly generic/standard problem, so I'd imagine solutions are out there; I just don't know what to search for.
P.S.: blkdiag(A1,...,AN)*B is an obvious alternative, but there the block-diagonal matrix is already N^2 x N^2, so I don't think it can be better than what I did.
edit: Thanks to everyone for commenting! I have carried out a new benchmark in Matlab R2016b. Unfortunately, I do not have both versions on the same computer, so we cannot compare the absolute numbers, but the relative comparison is still interesting, since it has changed a bit. Here it is:
And here is a zoom on the high-N area:
Couple of observations:

- SumRepDot is the solution proposed by Divakar, namely, to use squeeze(sum(bsxfun(@times,A,permute(B,[3,1,2])),2)), which on R2016b simplifies to squeeze(sum(A.*permute(B,[3,1,2]),2)). It is faster than the loop for high N by a factor of around 1.2...1.4.
- The loop is still "slow" in the sense that the multiplication with the sparse block-diagonal matrix is much faster.
- For the latter, the preparation overhead seems to become negligible for high N, which makes it overall a factor of 3...4 faster than the loop. This is a nice result.
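For completeness, here is a small sketch showing that the SumRepDot one-liner matches the loop (the implicit-expansion form needs R2016b or newer; on older versions use the bsxfun variant):

```matlab
% Verify the sum/permute one-liner against the naive loop
N = 50;
A = randn(N,N,N);
B = randn(N,N);

% permute(B,[3,1,2]) is 1 x N x N with element (1,k,i) = B(k,i), so
% A.*permute(B,[3,1,2]) has element (j,k,i) = A(j,k,i)*B(k,i); summing
% over dim 2 contracts the inner index k, and squeeze drops the singleton.
C_sum = squeeze(sum(A.*permute(B,[3,1,2]),2));

C_loop = zeros(N,N);
for i = 1:N
    C_loop(:,i) = A(:,:,i)*B(:,i);
end

max(abs(C_sum(:) - C_loop(:)))  % should be on the order of machine precision
```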