I have two n-by-k-by-3 arrays a and b, e.g.,
import numpy as np
a = np.array([
[
[1, 2, 3],
[3, 4, 5]
],
[
[4, 2, 4],
[1, 4, 5]
]
])
b = np.array([
[
[3, 1, 5],
[0, 2, 3]
],
[
[2, 4, 5],
[1, 2, 4]
]
])
and I'd like to compute the dot product of each pair of corresponding "triplets" (the length-3 vectors along the last axis), i.e.,
np.sum(a*b, axis=2)
A better way to do that is perhaps einsum, but I can't seem to get the indices straight.
Any hints here?
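A minimal sketch of one einsum formulation that appears to reproduce the np.sum(a*b, axis=2) result, assuming the arrays have the n-by-k-by-3 shape described above. The subscript letters n, k and t are arbitrary labels chosen here for readability; einsum only cares that the same letter marks the same axis in each operand. The triplet axis t is repeated in both inputs and omitted from the output, which tells einsum to multiply elementwise and sum over it.

```python
import numpy as np

# Example data from the question: shape (n, k, 3) with n = 2, k = 2
a = np.array([[[1, 2, 3], [3, 4, 5]],
              [[4, 2, 4], [1, 4, 5]]])
b = np.array([[[3, 1, 5], [0, 2, 3]],
              [[2, 4, 5], [1, 2, 4]]])

# Subscripts: n and k label the two leading axes, t labels the triplet axis.
# t appears in both inputs but not in the output, so einsum multiplies
# elementwise and sums over t, leaving an (n, k) array.
dots = np.einsum('nkt,nkt->nk', a, b)

# Same contraction written with an ellipsis, which works for any number
# of leading axes as long as the last axis is the one being summed.
dots_any = np.einsum('...t,...t->...', a, b)

# Both agree with the np.sum(a*b, axis=2) baseline.
assert np.array_equal(dots, np.sum(a * b, axis=2))
assert np.array_equal(dots, dots_any)
# dots.tolist() -> [[20, 23], [36, 29]]
```

The ellipsis variant is a design choice rather than a requirement: it avoids hard-coding how many leading axes there are, at the cost of being slightly less explicit about the expected shape.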