可以将文章内容翻译成中文,广告屏蔽插件可能会导致该功能失效(如失效,请关闭广告屏蔽插件后再试):
问题:
I am trying to perform the following matrix and tensor multiplication, but batched:

$$z_i = \sum_j R_{ij} x_j + \sum_{j,k} T_{ijk} x_j x_k$$

I have a list of $x$ vectors:
x = np.array([[2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0]])
and the following matrix and tensor:
R = np.array(
[
[1.0, 1.0],
[0.0, 1.0],
]
)
T = np.array(
[
[
[2.0, 0.0],
[0.0, 0.0],
],
[
[0.0, 0.0],
[0.0, 2.0],
]
]
)
The batched matrix multiplication is relatively straightforward:
x.dot(R.T)
However, I am struggling with the second (tensor) part. I tried using `np.tensordot`, but with no success so far. What am I missing?
回答1:
Since cache usage isn't an issue on a sequence of small tensors (as it would be for general dot products of large matrices), it is easy to formulate the problem with simple loops.
Example
import numba as nb
import numpy as np
import time
@nb.njit(fastmath=True, parallel=True)
def tensor_mult(T, x):
    """Batched quadratic form with a third-order tensor.

    For every batch row ``x[row]`` this computes
    ``res[row, i] = sum_{j, k} T[i, j, k] * x[row, j] * x[row, k]``.

    Parameters
    ----------
    T : 3-D array of shape (n_out, d, d)
    x : 2-D array of shape (m, d); each row is one vector of the batch.

    Returns
    -------
    2-D array of shape (m, n_out) with the dtype of ``T``.

    Raises
    ------
    ValueError
        If ``x.shape[1]``, ``T.shape[1]`` and ``T.shape[2]`` are not all equal.
        Numba does not bounds-check by default, so without this check a
        mismatch would silently read outside the arrays.
    """
    n_batch = x.shape[0]
    n_out, d_j, d_k = T.shape
    if x.shape[1] != d_j or x.shape[1] != d_k:
        raise ValueError("tensor_mult: x.shape[1] must equal T.shape[1] and T.shape[2]")

    res = np.empty((n_batch, n_out), dtype=T.dtype)
    # Batch rows are independent, so the outer loop is parallelised.
    for row in nb.prange(n_batch):
        for i in range(n_out):
            acc = 0.0
            for j in range(d_j):
                # x[row, j] does not depend on k: accumulate the inner sum
                # first and multiply by it once per j.
                inner = 0.0
                for k in range(d_k):
                    inner += T[i, j, k] * x[row, k]
                acc += x[row, j] * inner
            res[row, i] = acc
    return res
Benchmarking
x = np.random.rand(1000000, 6)
T = np.random.rand(6, 6, 6)

# The first call includes Numba's JIT compilation overhead (about 0.6 s),
# so it is made once before timing starts.
res = tensor_mult(T, x)

n_runs = 10
# perf_counter is a monotonic, high-resolution clock meant for benchmarking.
t1 = time.perf_counter()
for _ in range(n_runs):
    # @Divakar's alternative, for comparison:
    # Tx = np.tensordot(T, x, axes=((1), (1)))
    # out = np.einsum('ikl,lk->li', Tx, x)
    res = tensor_mult(T, x)
# Mean wall-clock seconds per call.
print((time.perf_counter() - t1) / n_runs)
Results (4 cores / 8 threads):
Divakar's solution: 191 ms
Simple loops: 62.4 ms
回答2:
We can combine tensor matrix multiplication via `np.tensordot` with `np.einsum` to do it in two steps:
Tx = np.tensordot(T,x,axes=((1),(1)))
out = np.einsum('ikl,lk->li',Tx,x)
Benchmarking
Setup based on the OP's comments:
In [1]: import numpy as np
In [2]: x = np.random.rand(1000000,6)
In [3]: T = np.random.rand(6,6,6)
Timings -
# @Han Altae-Tran's soln
In [4]: %%timeit
...: W = np.matmul(T,x.T)
...: ZT = np.sum(W*x.T[np.newaxis,:,:], axis=1).T
...:
1 loops, best of 3: 496 ms per loop
# @Paul Panzer's soln-1
In [5]: %timeit np.einsum('ijk,lj,lk->li', T, x, x)
1 loops, best of 3: 831 ms per loop
# @Paul Panzer's soln-2
In [6]: %timeit ((x[:, None, None, :]@T).squeeze()@x[..., None]).squeeze()
1 loops, best of 3: 1.39 s per loop
# @Paul Panzer's soln-3
In [7]: %timeit np.einsum('ijl,lj->li', T@x.T, x)
1 loops, best of 3: 358 ms per loop
# From this post's soln
In [8]: %%timeit
...: Tx = np.tensordot(T,x,axes=((1),(1)))
...: out = np.einsum('ikl,lk->li',Tx,x)
...:
1 loops, best of 3: 168 ms per loop
回答3:
You can more or less directly translate your formula into an `einsum`:
>>> np.einsum('ijk,lj,lk->li', T, x, x)
array([[ 8., 8.],
[18., 18.],
[32., 32.],
[50., 50.]])
Using only `@`:
>>> ((x[:, None, None, :]@T).squeeze()@x[..., None]).squeeze()
array([[ 8., 8.],
[18., 18.],
[32., 32.],
[50., 50.]])
Or a hybrid:
>>> np.einsum('ijl,lj->li', T@x.T, x)
array([[ 8., 8.],
[18., 18.],
[32., 32.],
[50., 50.]])
回答4:
As Paul pointed out, `einsum` is an easy way to accomplish the task, but if speed is a concern, it's generally better to stick to standard NumPy functions.
This can be done by writing out the equation and translating the steps into matrix operations.
Let $X$ be the $m \times d$ matrix of data you want to batch over, and let $Z$ be the $m \times d$ result you want. We will compute $Z^T$ (the transpose) because it's easier.
Notice that for the $R$ contribution we can write

$$Z^T_{il} = \sum_j R_{ij} X_{lj} = \sum_j R_{ij} (X^T)_{jl}.$$

We can then compute this as a NumPy matrix multiply: `R.dot(X.T)`.
Similarly, observe that the $T$ contribution is

$$Z^T_{il} = \sum_j X_{lj} \left( \sum_k T_{ijk} (X^T)_{kl} \right).$$

The quantity inside the parentheses is a batch matrix multiply between $T$ and $X^T$. Thus, if we define it as

$$W_{ijl} = \sum_k T_{ijk} (X^T)_{kl},$$

we can compute it in NumPy using `W = np.matmul(T, X.T)`. Continuing the simplification, the $T$ contribution becomes

$$Z^T_{il} = \sum_j W_{ijl} (X^T)_{jl},$$

which is equivalent to `np.sum(W*X.T[np.newaxis,:,:], axis=1)`. Putting everything together, we end up with:
W = np.matmul(T,X.T)
ZT = R.dot(X.T) + np.sum(W*X.T[np.newaxis,:,:], axis=1)
Z = ZT.T
For larger batch sizes, this is about 3–4 times faster than `einsum` when $d = 2$. Avoiding some of the transposes could make it a bit faster still.