I have two n-by-k-by-3 arrays a and b, e.g.,
import numpy as np
a = np.array([
[
[1, 2, 3],
[3, 4, 5]
],
[
[4, 2, 4],
[1, 4, 5]
]
])
b = np.array([
[
[3, 1, 5],
[0, 2, 3]
],
[
[2, 4, 5],
[1, 2, 4]
]
])
and I'd like to compute the dot product of all corresponding pairs of "triplets", i.e.,
np.sum(a*b, axis=2)
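To pin down what "dot product of all pairs of triplets" means, here is a plain-loop sketch that calls np.dot on a[i, j] and b[i, j] for every (i, j) and collects the results into an n-by-k array. The helper name pairwise_dots is hypothetical and only for illustration. It should match np.sum(a*b, axis=2).

```python
import numpy as np

def pairwise_dots(a, b):
    """Hypothetical reference helper: np.dot(a[i, j], b[i, j]) for every (i, j)."""
    n, k, _ = a.shape
    out = np.empty((n, k), dtype=np.result_type(a, b))
    for i in range(n):
        for j in range(k):
            # Dot product of the two corresponding length-3 triplets
            out[i, j] = np.dot(a[i, j], b[i, j])
    return out

a = np.array([[[1, 2, 3], [3, 4, 5]],
              [[4, 2, 4], [1, 4, 5]]])
b = np.array([[[3, 1, 5], [0, 2, 3]],
              [[2, 4, 5], [1, 2, 4]]])

print(pairwise_dots(a, b))
# [[20 23]
#  [36 29]]
```

The loop is only a readable reference. The vectorized np.sum(a*b, axis=2) gives the same array without Python-level iteration.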
A better way to do that is perhaps np.einsum, but I can't seem to get the indices straight. Any hints here?
You are losing the third axis of those two 3D input arrays with that sum-reduction, while keeping the first two axes aligned. Thus, with np.einsum, we would use the same subscripts for the first two axes of both inputs and the same subscript for the third axis too, but leave that third subscript out of the output string, signalling that we are reducing along that axis for both inputs. So the solution would be: