将函数应用于使用numpy的阵列时该函数包含一个条件(Applying a function to

2019-10-29 16:20发布

我有当所述函数包含的条件将函数应用于阵列的难度。 我有一个变通方法效率低下,并正在寻找一个有效的(快)的方法。 在一个简单的例子:

pts = np.linspace(0,1,11)
def fun(x, y):
    if x > y:
        return 0
    else:
        return 1

现在,如果我运行:

result = fun(pts, pts)

然后我得到的错误

ValueError异常:具有多于一个元素的数组的真值是不明确的。 使用a.any()或a.all()

在凸起if x > y线。 我的解决方法效率低下,从而给出正确的结果,但速度太慢是:

result = np.full([len(pts)]*2, np.nan)
for i in range(len(pts)):
    for j in range(len(pts)):
        result[i,j] = fun(pts[i], pts[j])

什么是一个更好的(更重要的是,更快)的方式来获得最好的方法?

我有当所述函数包含的条件将函数应用于阵列的难度。 我有一个变通方法效率低下,并正在寻找一个有效的(快)的方法。 在一个简单的例子:

pts = np.linspace(0,1,11)
def fun(x, y):
    if x > y:
        return 0
    else:
        return 1

现在,如果我运行:

result = fun(pts, pts)

然后我得到的错误

ValueError异常:具有多于一个元素的数组的真值是不明确的。 使用a.any()或a.all()

在凸起if x > y线。 我的解决方法效率低下,从而给出正确的结果,但速度太慢是:

result = np.full([len(pts)]*2, np.nan)
for i in range(len(pts)):
    for j in range(len(pts)):
        result[i,j] = fun(pts[i], pts[j])

什么是一个更好的(更重要的是,更快)的方式来获得最好的方法?

编辑 :使用

def fun(x, y):
    if x > y:
        return 0
    else:
        return 1
x = np.array(range(10))
y = np.array(range(10))
xv,yv = np.meshgrid(x,y)
result = fun(xv, yv)  

仍然引起同一ValueError

Answer 1:

该错误是相当明确的 - 假设你有

x = np.array([1,2])
y = np.array([2,1])

这样

(x>y) == np.array([0,1])

什么应该是你的结果if np.array([0,1])声明? 是真的还是假的? numpy告诉你,这是不明确的。 运用

(x>y).all()

要么

(x>y).any()

是明确的,因此numpy是为你提供的解决方案-无论是任何细胞对满足条件,或所有这些-无论是明确的真值。 你必须精确地定义自己,你通过向量x的意思是不是向量y大

numpy溶液在所有对操作xy使得x[i]>y[j]是使用网格生成所有对:

>>> import numpy as np
>>> x=np.array(range(10))
>>> y=np.array(range(10))
>>> xv,yv=np.meshgrid(x,y)
>>> xv[xv>yv]
array([1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8,
       9, 4, 5, 6, 7, 8, 9, 5, 6, 7, 8, 9, 6, 7, 8, 9, 7, 8, 9, 8, 9, 9])
>>> yv[xv>yv]
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
       2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8])

或者发送xvyvfun ,或在功能创建的网格,根据什么更有意义。 这产生所有对xi,yj使得xi>yj 。 如果你想实际的指数只返回xv>yv ,其中每个单元ij对应x[i]y[j] 你的情况:

def fun(x, y):
    xv,yv=np.meshgrid(x,y)
    return xv>yv

将返回一个矩阵,其中fun(x,y)[i][j]为真,如果x[i]>y[j] ,否则返回false。 另外

return  np.where(xv>yv)

将返回对指数的两个数组,这样的元组

for i,j in fun(x,y):

将保证x[i]>y[j]为好。



Answer 2:

In [253]: x = np.random.randint(0,10,5)
In [254]: y = np.random.randint(0,10,5)
In [255]: x
Out[255]: array([3, 2, 2, 2, 5])
In [256]: y
Out[256]: array([2, 6, 7, 6, 5])
In [257]: x>y
Out[257]: array([ True, False, False, False, False])
In [258]: np.where(x>y,0,1)
Out[258]: array([0, 1, 1, 1, 1])

对于笛卡尔比较这两个一维数组,重塑一个,这样它可以使用broadcasting

In [259]: x[:,None]>y
Out[259]: 
array([[ True, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [ True, False, False, False, False]])
In [260]: np.where(x[:,None]>y,0,1)
Out[260]: 
array([[0, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [0, 1, 1, 1, 1]])

你的功能,与if仅适用于标量输入。 如果给定的阵列中, a>b产生一个布尔阵列,其不能在使用if语句。 你迭代的作品,因为它通过标量值。 对于一些复杂的功能,这就是你能做的最好的( np.vectorize可以使迭代简单,但不是更快)。

我的回答是看排列比较,并从中获得了答案。 在这种情况下,3参数where不布尔阵列映射到所期望的1/0的一个很好的工作。 有这样做的映射以及其他方式。

你的双回路需要编码的附加层,广播None



Answer 3:

对于更复杂的例子,或者如果你正在处理的阵列是有点大,或者如果你可以写一个已经预分配数组,你可以考虑Numba

import numba as nb
import numpy as np

@nb.njit()
def fun(x, y):
  if x > y:
    return 0
  else:
    return 1

@nb.njit(parallel=False)
#@nb.njit(parallel=True)
def loop(x,y):
  result=np.empty((x.shape[0],y.shape[0]),dtype=np.int32)
  for i in nb.prange(x.shape[0]):
    for j in range(y.shape[0]):
      result[i,j] = fun(x[i], y[j])
  return result

@nb.njit(parallel=False)
def loop_preallocated(x,y,result):
  for i in nb.prange(x.shape[0]):
    for j in range(y.shape[0]):
      result[i,j] = fun(x[i], y[j])
  return result

计时

x = np.array(range(1000))
y = np.array(range(1000))

#Compilation overhead of the first call is neglected

res=np.where(x[:,None]>y,0,1) -> 2.46ms
loop(single_threaded)         -> 1.23ms
loop(parallel)                -> 1.0ms
loop(single_threaded)*        -> 0.27ms
loop(parallel)*               -> 0.058ms

*也许通过高速缓存的影响。 测试你自己的例子。



文章来源: Applying a function to an array using Numpy when the function contains a condition