Given two arrays, A (shape: M X C) and B (shape: N X C), is there a way to subtract each row of B from each row of A without using loops? The final output would be of shape (M*N X C).
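For reference, the operation the question describes can be pinned down with the slow double loop it wants to avoid. This is a minimal sketch: the function name `pairwise_row_diff_loop` is hypothetical, and the row ordering (row `i*N + j` holds `A[i] - B[j]`) is assumed to match the desired result in the example.

```python
import numpy as np

def pairwise_row_diff_loop(A, B):
    """Hypothetical reference version: row i*N + j of the output is A[i] - B[j]."""
    M, C = A.shape
    N = B.shape[0]
    out = np.empty((M * N, C), dtype=np.result_type(A, B))
    for i in range(M):          # every row of A ...
        for j in range(N):      # ... paired with every row of B
            out[i * N + j] = A[i] - B[j]
    return out

A = np.array([[1, 2, 3],
              [100, 200, 300]])
B = np.array([[10, 20, 30],
              [1000, 2000, 3000],
              [-10, -20, -2]])
result = pairwise_row_diff_loop(A, B)
print(result.shape)  # (6, 3)
```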
Example
import numpy as np
A = np.array([[  1,   2,   3],
              [100, 200, 300]])
B = np.array([[  10,   20,   30],
              [1000, 2000, 3000],
              [ -10,  -20,   -2]])
Desired result (may have some other shape):
array([[ -9, -18, -27],
[-999, -1998, -2997],
[ 11, 22, 5],
[ 90, 180, 270],
[-900, -1800, -2700],
[ 110, 220, 302]])
Shape: 6 X 3
(A loop is too slow, and "outer" subtracts each element instead of each row.)
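Assuming "outer" here means `np.subtract.outer`, the sketch below shows why it does not directly give the row-wise result: it pairs every element of A with every element of B, so the result has shape `A.shape + B.shape`, i.e. `(M, C, N, C)`. The row-wise differences are only the entries whose two column indices match, which `np.diagonal` can pick out. This is an illustration of the relationship, not a recommended approach, since it computes C times more differences than needed.

```python
import numpy as np

A = np.array([[1, 2, 3],
              [100, 200, 300]])
B = np.array([[10, 20, 30],
              [1000, 2000, 3000],
              [-10, -20, -2]])

# ufunc.outer pairs every element with every element:
# outer[i, k, j, l] == A[i, k] - B[j, l], so the shape is (M, C, N, C).
outer = np.subtract.outer(A, B)
print(outer.shape)  # (2, 3, 3, 3)

# Row-wise differences are the entries where the column indices match (k == l).
# np.diagonal over axes 1 and 3 keeps exactly those and appends them as the
# last axis, giving shape (M, N, C); flattening gives (M*N, C).
rowwise = np.diagonal(outer, axis1=1, axis2=3).reshape(-1, A.shape[1])
print(rowwise.shape)  # (6, 3)
```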
It's possible to do this efficiently (without any loops) by leveraging broadcasting, like:
In [28]: (A[:, np.newaxis] - B).reshape(-1, A.shape[1])
Out[28]:
array([[ -9, -18, -27],
[ -999, -1998, -2997],
[ 11, 22, 5],
[ 90, 180, 270],
[ -900, -1800, -2700],
[ 110, 220, 302]])
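To make the broadcasting step concrete, here is a sketch that prints the intermediate shapes for the example arrays (the names `A_3D`, `diff`, and `result` are only for illustration):

```python
import numpy as np

A = np.array([[1, 2, 3],
              [100, 200, 300]])        # shape (M, C) = (2, 3)
B = np.array([[10, 20, 30],
              [1000, 2000, 3000],
              [-10, -20, -2]])         # shape (N, C) = (3, 3)

A_3D = A[:, np.newaxis]                # shape (M, 1, C) = (2, 1, 3), a view
# Broadcasting aligns trailing axes: (2, 1, 3) - (3, 3) -> (2, 3, 3),
# so diff[i, j] == A[i] - B[j] for every pair of rows.
diff = A_3D - B
print(A_3D.shape, diff.shape)          # (2, 1, 3) (2, 3, 3)

# Merge the first two axes; in C order, row i*N + j holds A[i] - B[j].
result = diff.reshape(-1, A.shape[1])
print(result.shape)                    # (6, 3)
```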
Or, for a slightly faster solution than plain broadcasting, we can use numexpr like this:
In [31]: A_3D = A[:, np.newaxis]
In [32]: import numexpr as ne
# pass the expression for subtraction as a string to `evaluate` function
In [33]: ne.evaluate('A_3D - B').reshape(-1, A.shape[1])
Out[33]:
array([[ -9, -18, -27],
[ -999, -1998, -2997],
[ 11, 22, 5],
[ 90, 180, 270],
[ -900, -1800, -2700],
[ 110, 220, 302]], dtype=int64)
One more approach, the least efficient one, is to use np.repeat and np.tile to match the shapes of both arrays. Note that it is the least efficient because it makes full copies of the data to match the shapes before subtracting.
In [27]: np.repeat(A, B.shape[0], 0) - np.tile(B, (A.shape[0], 1))
Out[27]:
array([[ -9, -18, -27],
[ -999, -1998, -2997],
[ 11, 22, 5],
[ 90, 180, 270],
[ -900, -1800, -2700],
[ 110, 220, 302]])
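A sketch of the intermediates in the np.repeat/np.tile approach, using the example arrays, makes both the row ordering and the copies explicit (variable names are illustrative):

```python
import numpy as np

A = np.array([[1, 2, 3],
              [100, 200, 300]])
B = np.array([[10, 20, 30],
              [1000, 2000, 3000],
              [-10, -20, -2]])
M, N = A.shape[0], B.shape[0]

# np.repeat repeats each row of A N times in a row: a0, a0, a0, a1, a1, a1
A_rep = np.repeat(A, N, axis=0)
# np.tile stacks M copies of the whole of B: b0, b1, b2, b0, b1, b2
B_tiled = np.tile(B, (M, 1))
print(A_rep.shape, B_tiled.shape)  # (6, 3) (6, 3)

result = A_rep - B_tiled
# Same values and ordering as the broadcasting version
print(np.array_equal(result, (A[:, np.newaxis] - B).reshape(-1, A.shape[1])))  # True
```

Both `A_rep` and `B_tiled` are full `(M*N, C)` arrays built before the subtraction, whereas the broadcasting version only allocates the output.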
Using the Kronecker product (numpy.kron):
>>> import numpy as np
>>> A = np.array([[ 1, 2, 3],
... [100, 200, 300]])
>>> B = np.array([[ 10, 20, 30],
... [1000, 2000, 3000],
... [ -10, -20, -2]])
>>> (m,c) = A.shape
>>> (n,c) = B.shape
>>> np.kron(A,np.ones((n,1))) - np.kron(np.ones((m,1)),B)
array([[ -9., -18., -27.],
[ -999., -1998., -2997.],
[ 11., 22., 5.],
[ 90., 180., 270.],
[ -900., -1800., -2700.],
[ 110., 220., 302.]])
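The output above is floating point because np.ones defaults to float64. A small variation, assuming integer output is wanted, passes the input dtypes to np.ones so that np.kron stays in integers:

```python
import numpy as np

A = np.array([[1, 2, 3],
              [100, 200, 300]])
B = np.array([[10, 20, 30],
              [1000, 2000, 3000],
              [-10, -20, -2]])
(m, c) = A.shape
n = B.shape[0]

# Column vectors of ones with the inputs' integer dtypes instead of float64
ones_n = np.ones((n, 1), dtype=A.dtype)
ones_m = np.ones((m, 1), dtype=B.dtype)

# kron(A, ones_n) repeats each row of A n times; kron(ones_m, B) stacks m copies of B
result = np.kron(A, ones_n) - np.kron(ones_m, B)
print(result)
print(np.issubdtype(result.dtype, np.integer))  # True
```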