Calculating the outer product for a sequence of nu

Posted 2019-07-16 09:33

Question:

I have a list of 3D points p stored in an ndarray with shape (N, 3). I want to compute the outer product of each 3D point with itself:

import numpy as np

N = int(1e4)
p = np.random.random((N, 3))
result = np.zeros((N, 3, 3))
for i in range(N):
    result[i, :, :] = np.outer(p[i, :], p[i, :])

Is there a way to compute this outer product without any Python-level loops? The problem is that np.outer does not support anything like an axis argument.

Answer 1:

You can use broadcasting:

p[..., None] * p[:, None, :]

Indexing with None inserts a new axis: at the end of the first term (making it Nx3x1) and in the middle of the second term (making it Nx1x3). The two terms are then broadcast against each other, yielding an Nx3x3 result.
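
The shapes described above can be checked directly. The following is a minimal sketch, assuming a small N and an arbitrary seed (neither is from the original answer); it compares the broadcast result against the explicit loop from the question.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily, for reproducibility only
p = rng.random((5, 3))          # small N so the reference loop stays cheap

a = p[..., None]    # new trailing axis: shape (5, 3, 1)
b = p[:, None, :]   # new middle axis:   shape (5, 1, 3)
result = a * b      # broadcast product: shape (5, 3, 3)

# Reference: one np.outer call per point, as in the question.
expected = np.stack([np.outer(row, row) for row in p])

print(a.shape, b.shape, result.shape)
print(np.allclose(result, expected))
```

Each result[n] is the 3x3 outer product of p[n] with itself, because result[n, i, j] = p[n, i] * p[n, j].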



Answer 2:

A much better solution than my previous one is using np.einsum:

np.einsum('...i,...j', p, p)

which is even faster than the broadcasting approach:

In [ ]: %timeit p[..., None] * p[:, None, :]
514 µs ± 4.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [ ]: %timeit np.einsum('...i,...j', p, p)
169 µs ± 1.75 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

As for how it works, I'm not quite sure; I just messed around with einsum until I got the answer I wanted:

In [ ]: np.all(np.einsum('...i,...j', p, p) == p[..., None] * p[:, None, :])
Out[ ]: True
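
One way to read the subscript string, offered as a sketch rather than as the answer author's own explanation: the ellipsis stands for the leading (batch) axes, and since neither i nor j is repeated, nothing is summed, so each output element is just p[n, i] * p[n, j]. Under that reading, the explicit form 'ni,nj->nij' should give the same array for 2D input. The array size and seed below are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
p = rng.random((4, 3))

implicit = np.einsum('...i,...j', p, p)    # form used in the answer
explicit = np.einsum('ni,nj->nij', p, p)   # same operation with every axis named
broadcast = p[..., None] * p[:, None, :]   # broadcasting form, for comparison

print(explicit.shape)
print(np.allclose(implicit, explicit), np.allclose(explicit, broadcast))

# Spot-check a single element against the definition of the outer product.
print(np.isclose(explicit[2, 0, 1], p[2, 0] * p[2, 1]))
```

The explicit form only works for 2D input, whereas the ellipsis form also accepts extra leading axes.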


Answer 3:

You could at least use apply_along_axis:

result = np.apply_along_axis(lambda point: np.outer(point, point), 1, p)

Surprisingly, however, this is actually slower than the loop in your question:

In [ ]: %%timeit N = int(1e4); p = np.random.random((N, 3))
   ...: result = np.apply_along_axis(lambda point: np.outer(point, point), 1, p)
61.5 ms ± 1.84 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]: %%timeit N = int(1e4); p = np.random.random((N, 3))
   ...: result = np.zeros((N, 3, 3))
   ...: for i in range(N):
   ...:     result[i, :, :] = np.outer(p[i, :], p[i, :])
46 ms ± 709 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
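
To rerun the comparison outside IPython, something like the following sketch with the standard timeit module could be used. The function names are hypothetical, N is reduced from the question's 1e4 to keep the run short, and the timings will differ from the figures quoted in this thread depending on the machine and NumPy version.

```python
import timeit
import numpy as np

def outer_loop(p):
    # Explicit Python loop from the question.
    result = np.zeros((p.shape[0], p.shape[1], p.shape[1]))
    for i in range(p.shape[0]):
        result[i] = np.outer(p[i], p[i])
    return result

def outer_apply(p):
    # apply_along_axis variant; it still calls np.outer once per row.
    return np.apply_along_axis(lambda point: np.outer(point, point), 1, p)

def outer_broadcast(p):
    return p[..., None] * p[:, None, :]

def outer_einsum(p):
    return np.einsum('...i,...j', p, p)

N = 1000  # assumption: smaller than the question's N to keep the run short
p = np.random.random((N, 3))
reference = outer_loop(p)

for fn in (outer_loop, outer_apply, outer_broadcast, outer_einsum):
    # Every variant must agree with the loop before its timing means anything.
    assert np.allclose(fn(p), reference)
    seconds = timeit.timeit(lambda: fn(p), number=10) / 10
    print(f"{fn.__name__}: {seconds * 1e3:.3f} ms per call")
```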