Question:
I am having difficulty applying a function to an array when the function contains a condition. I have an inefficient workaround and am looking for an efficient (fast) approach. In a simple example:
import numpy as np
pts = np.linspace(0,1,11)
def fun(x, y):
    if x > y:
        return 0
    else:
        return 1
Now, if I run:
result = fun(pts, pts)
then I get the error
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
raised at the if x > y line. My inefficient workaround, which gives the correct result but is too slow, is:
result = np.full([len(pts)]*2, np.nan)
for i in range(len(pts)):
    for j in range(len(pts)):
        result[i,j] = fun(pts[i], pts[j])
What is the best way to obtain this in a nicer (and more importantly, faster) way?
EDIT: using
def fun(x, y):
    if x > y:
        return 0
    else:
        return 1
x = np.array(range(10))
y = np.array(range(10))
xv,yv = np.meshgrid(x,y)
result = fun(xv, yv)
still raises the same ValueError.
Answer 1:
The error is quite explicit - suppose you have
x = np.array([1,2])
y = np.array([2,1])
such that
x > y  # array([False,  True])
What should the result of the statement if np.array([False, True]) be? Is it true or false? numpy is telling you this is ambiguous. Using
(x>y).all()
or
(x>y).any()
is explicit, so numpy is offering you two solutions: either any element pair fulfills the condition, or all of them do. Both give an unambiguous truth value. You have to define for yourself exactly what you mean by "vector x is larger than vector y".
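To make the ambiguity concrete, here is a minimal sketch using the same two arrays. The names mask, msg, any_larger and all_larger are only for illustration.

```python
import numpy as np

x = np.array([1, 2])
y = np.array([2, 1])

mask = x > y  # element-wise comparison, one boolean per pair

# Using the array directly in an `if` raises the ValueError from the question.
try:
    if mask:
        pass
    msg = "no error"
except ValueError as exc:
    msg = str(exc)

any_larger = bool(mask.any())  # is at least one x[i] > y[i]?
all_larger = bool(mask.all())  # is every x[i] > y[i]?
print(msg)
print(any_larger, all_larger)
```

Here any() and all() give different answers, which is exactly why numpy refuses to pick one for you.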
The numpy solution to operate on all pairs of x and y such that x[i]>y[j] is to use a mesh grid to generate all pairs:
>>> import numpy as np
>>> x=np.array(range(10))
>>> y=np.array(range(10))
>>> xv,yv=np.meshgrid(x,y)
>>> xv[xv>yv]
array([1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8,
9, 4, 5, 6, 7, 8, 9, 5, 6, 7, 8, 9, 6, 7, 8, 9, 7, 8, 9, 8, 9, 9])
>>> yv[xv>yv]
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8])
Either send xv and yv to fun, or create the mesh in the function, depending on what makes more sense. This generates all pairs xi,yj such that xi>yj. If you want the actual indices, just return xv>yv. Note that with the default indexing of np.meshgrid, cell [i][j] corresponds to x[j] and y[i]; pass indexing='ij' so that it corresponds to x[i] and y[j]. In your case:
def fun(x, y):
    xv, yv = np.meshgrid(x, y, indexing='ij')
    return xv > yv
will return a matrix where fun(x,y)[i][j] is True if x[i]>y[j], or False otherwise. Alternatively
    return np.where(xv > yv)
will return a tuple of two index arrays (row indices and column indices), such that
for i, j in zip(*fun(x, y)):
will guarantee x[i]>y[j] as well.
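Which axis corresponds to which input depends on the indexing argument of np.meshgrid, so it is worth checking on arrays of different lengths. A small sketch, with arbitrary example values and illustrative variable names:

```python
import numpy as np

x = np.array([0, 5, 10])
y = np.array([1, 6])

# Default indexing='xy': the result has shape (len(y), len(x)),
# so cell [i, j] compares x[j] with y[i].
xv, yv = np.meshgrid(x, y)
mask_xy = xv > yv

# indexing='ij': the result has shape (len(x), len(y)),
# so cell [i, j] compares x[i] with y[j].
xi, yi = np.meshgrid(x, y, indexing="ij")
mask_ij = xi > yi

# np.argwhere returns one (i, j) row per True cell.
pairs = [(int(i), int(j)) for i, j in np.argwhere(mask_ij)]
print(mask_xy.shape, mask_ij.shape)
print(pairs)
```

The two masks are transposes of each other, so either works as long as the indices are read in the matching order.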
Answer 2:
In [253]: x = np.random.randint(0,10,5)
In [254]: y = np.random.randint(0,10,5)
In [255]: x
Out[255]: array([3, 2, 2, 2, 5])
In [256]: y
Out[256]: array([2, 6, 7, 6, 5])
In [257]: x>y
Out[257]: array([ True, False, False, False, False])
In [258]: np.where(x>y,0,1)
Out[258]: array([0, 1, 1, 1, 1])
For a Cartesian comparison of these two 1d arrays, reshape one so it can use broadcasting:
In [259]: x[:,None]>y
Out[259]:
array([[ True, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[ True, False, False, False, False]])
In [260]: np.where(x[:,None]>y,0,1)
Out[260]:
array([[0, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[0, 1, 1, 1, 1]])
Your function, with the if, only works for scalar inputs. If given arrays, x > y produces a boolean array, which cannot be used in an if statement. Your iteration works because it passes scalar values. For some complex functions that's the best you can do (np.vectorize can make the iteration simpler, but not faster).
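As a rough sketch of the np.vectorize route, assuming the scalar fun and the pts array from the question (vfun, result_vec and result_loop are illustrative names):

```python
import numpy as np

def fun(x, y):
    # Scalar-only function from the question.
    if x > y:
        return 0
    else:
        return 1

pts = np.linspace(0, 1, 11)

# np.vectorize wraps the scalar function so it accepts arrays and broadcasts
# them; internally it still calls fun once per element pair.
vfun = np.vectorize(fun)
result_vec = vfun(pts[:, None], pts)

# Reference: the explicit double loop from the question.
result_loop = np.empty((len(pts), len(pts)), dtype=int)
for i in range(len(pts)):
    for j in range(len(pts)):
        result_loop[i, j] = fun(pts[i], pts[j])

print(np.array_equal(result_vec, result_loop))
```

This removes the hand-written loops but not the per-element Python calls, so it is a convenience rather than an optimization.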
My answer is to look at the array comparison and derive the answer from that. In this case, the three-argument np.where does a nice job of mapping the boolean array onto the desired 1/0. There are other ways of doing this mapping as well.
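A few equivalent ways to map the boolean array onto 0/1, sketched on the pts array from the question. The variable names are illustrative, and the list of alternatives is not exhaustive.

```python
import numpy as np

pts = np.linspace(0, 1, 11)
mask = pts[:, None] > pts  # shape (11, 11), True where pts[i] > pts[j]

via_where = np.where(mask, 0, 1)   # three-argument where
via_invert = (~mask).astype(int)   # logical NOT, then cast True/False to 1/0
via_arith = 1 - mask.astype(int)   # same mapping using arithmetic

print(np.array_equal(via_where, via_invert))
print(np.array_equal(via_where, via_arith))
```

All three negate the comparison rather than replacing it with <=, so they also agree with the original if/else when NaN values are present.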
Reproducing your double loop takes one extra piece of code: indexing with None, which adds the axis needed for broadcasting.
Answer 3:
For a more complex example, or if the arrays you are dealing with are a bit larger, or if you can write to an already preallocated array, you could consider Numba.
Example
import numba as nb
import numpy as np
@nb.njit()
def fun(x, y):
    if x > y:
        return 0
    else:
        return 1

@nb.njit(parallel=False)
#@nb.njit(parallel=True)
def loop(x,y):
    result=np.empty((x.shape[0],y.shape[0]),dtype=np.int32)
    for i in nb.prange(x.shape[0]):
        for j in range(y.shape[0]):
            result[i,j] = fun(x[i], y[j])
    return result

@nb.njit(parallel=False)
def loop_preallocated(x,y,result):
    for i in nb.prange(x.shape[0]):
        for j in range(y.shape[0]):
            result[i,j] = fun(x[i], y[j])
    return result
Timings
x = np.array(range(1000))
y = np.array(range(1000))
#Compilation overhead of the first call is neglected
res=np.where(x[:,None]>y,0,1) -> 2.46ms
loop(single_threaded) -> 1.23ms
loop(parallel) -> 1.0ms
loop(single_threaded)* -> 0.27ms
loop(parallel)* -> 0.058ms
*Maybe influenced by cache. Test on your own examples.
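To test on your own examples, a timing harness along the following lines could be used. This is only a sketch: it covers the plain Python loop and the broadcasting version so that it runs without Numba, the function names double_loop and broadcast_where are made up here, and the numbers will differ from machine to machine.

```python
import timeit
import numpy as np

def double_loop(x, y):
    # Pure-Python reference implementation (slow).
    result = np.empty((x.shape[0], y.shape[0]), dtype=np.int32)
    for i in range(x.shape[0]):
        for j in range(y.shape[0]):
            result[i, j] = 0 if x[i] > y[j] else 1
    return result

def broadcast_where(x, y):
    # Vectorized version using broadcasting.
    return np.where(x[:, None] > y, 0, 1)

x = np.arange(200)
y = np.arange(200)

# Check that both versions agree before timing them.
same = np.array_equal(double_loop(x, y), broadcast_where(x, y))

# Best of several repeats, in milliseconds per call.
t_loop = min(timeit.repeat(lambda: double_loop(x, y), number=1, repeat=3)) * 1e3
t_where = min(timeit.repeat(lambda: broadcast_where(x, y), number=10, repeat=3)) / 10 * 1e3
print(f"same result: {same}")
print(f"double loop: {t_loop:.3f} ms, np.where: {t_where:.3f} ms")
```

A Numba function can be timed the same way, as long as it is called once beforehand so that compilation is not included in the measurement.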