我找到了几十个讲解如何在 Python/NumPy 中把 for 循环向量化的例子。不幸的是,我仍然不明白怎样用向量化的形式来减少我这个简单 for 循环的计算时间。在这种情况下,向量化是否可行?
# Question snippet: for every row of lat2, look up the matching row in lat1
# and copy the associated value from dtime into time.
# The literal `...` entries below stand for rows elided from the example.
time = np.zeros(185000)
lat1 = np.array(([48.78,47.45],[38.56,39.53],...)) # ~ 200000 rows
lat2 = np.array(([7.78,5.45],[7.56,5.53],...)) # same number of rows as time
for ii in range(len(time)):
    # Element-wise `&` is required: `and` between two boolean arrays raises
    # "ValueError: The truth value of an array ... is ambiguous".
    pos = np.argwhere((lat1[:, 0] == lat2[ii, 0]) &
                      (lat1[:, 1] == lat2[ii, 1]))
    if pos.size:
        # argwhere returns a (k, 1) array; take the first matching row index.
        # NOTE(review): `dtime` is not defined in this snippet — presumably a
        # per-row array aligned with lat1; confirm against the full script.
        time[ii] = dtime[pos[0, 0]]
要找到所有匹配项,最快的方法也许是先对两个数组排序,然后同步遍历它们,例如下面这个可运行的示例:
import numpy as np
def is_less(a, b):
    """Return True if ``a`` sorts strictly before ``b``.

    Elements are compared starting from the last position and moving
    backward, so that the ordering agrees with ``np.lexsort()``, which
    treats the last key as the most significant one.
    """
    idx = len(a) - 1
    while idx >= 0:
        left, right = a[idx], b[idx]
        if left < right:
            return True
        if left > right:
            return False
        idx -= 1
    # Every position compared equal, so ``a`` is not strictly smaller.
    return False
def is_equal(a, b):
    """Return True if no position in ``a`` differs from the same position in ``b``.

    Only the first ``len(a)`` positions are inspected.
    """
    return not any(a[k] != b[k] for k in range(len(a)))
# Small in-memory example that can be used instead of the file loads below:
# lat1 = np.array(([48.78,47.45],[38.56,39.53]))
# lat2 = np.array(([7.78,5.45],[48.78,47.45],[7.56,5.53]))
lat1 = np.load('arr.npy')
lat2 = np.load('refarr.npy')

# Sort orders for the rows of each array. Passing the transpose makes each
# column one sort key; np.lexsort treats the last key as the primary one,
# which is why is_less() compares from the last column backward.
idx1 = np.lexsort(lat1.transpose())
idx2 = np.lexsort(lat2.transpose())

# Walk both sorted orders in lockstep (merge-style), advancing whichever
# side currently holds the smaller row.
ii = 0
jj = 0
while ii < len(idx1) and jj < len(idx2):
    a = lat1[idx1[ii], :]
    b = lat2[idx2[jj], :]
    if is_equal(a, b):
        # do stuff with match
        # print(...) with a single pre-formatted string prints the same text
        # on Python 2 and Python 3; the bare print statement is Python 2 only.
        print("match found: lat1=%s lat2=%s %d and %d" % ( repr(a), repr(b), idx1[ii], idx2[jj] ))
        # Both cursors advance, so each row takes part in at most one match.
        ii += 1
        jj += 1
    elif is_less(a, b):
        ii += 1
    else:
        jj += 1
这种写法可能不够 Pythonic(也许有人能用生成器或 itertools 想出更好的实现?),但很难想象任何依赖每次只查找一个点的方法能在速度上胜过它。
这里是一个解决方案。我其实不确定它能否被向量化。如果你想让它能够抵御“浮点数比较误差”,就应该修改 is_less 和 is_greater。整个算法其实就是一次二分查找。
import numpy as np
def is_less(a, b):
    """Lexicographically compare two points; True if ``a`` is strictly less.

    Positions are compared from the first element forward; the first
    position where the two differ decides the result.
    """
    for k, left in enumerate(a):
        right = b[k]
        if left < right:
            return True
        if left > right:
            return False
    # All compared positions are equal.
    return False
def is_greater(a, b):
    """Lexicographically compare two points; True if ``a`` is strictly greater.

    Positions are compared from the first element forward; the first
    position where the two differ decides the result.
    """
    for k, left in enumerate(a):
        right = b[k]
        if left > right:
            return True
        if left < right:
            return False
    # All compared positions are equal.
    return False
def binary_search(a, x, lo=0, hi=None):
    """Binary-search ``a`` for an element equal to ``x``.

    The half-open range ``[lo, hi)`` is searched; ``hi`` defaults to
    ``len(a)``. Ordering is decided by ``is_less`` / ``is_greater``, so
    ``a`` is assumed to be sorted consistently with those helpers.

    Returns the index of an element that is neither less nor greater than
    ``x``, or -1 when no such element is found.
    """
    lower = lo
    upper = len(a) if hi is None else hi
    while lower < upper:
        middle = (lower + upper) // 2
        candidate = a[middle]
        if is_less(candidate, x):
            # The target can only lie to the right of the midpoint.
            lower = middle + 1
            continue
        if is_greater(candidate, x):
            # The target can only lie to the left of the midpoint.
            upper = middle
            continue
        return middle
    return -1
def lex_sort(v):
    """Return the rows of ``v`` sorted by columns 1, 2, ... in that priority.

    Column 0 is not used as a sort key. ``np.lexsort`` treats the last key
    as the primary one, so the key columns are supplied in reverse order.
    For a three-column array this is the same as
    ``v[np.lexsort((v[:, 2], v[:, 1]))]``.
    """
    keys = tuple(v[:, col] for col in range(v.shape[1] - 1, 0, -1))
    return v[np.lexsort(keys)]
def sort_and_index(arr):
    """Sort the rows of ``arr`` and keep track of their original positions.

    An index column (0 .. len(arr)-1) is prepended as the first column, the
    combined array is ordered with ``lex_sort``, and the result is split
    apart again.

    Returns a pair ``(arr_ind, arr_cut)``: ``arr_ind`` is the single-column
    array of shuffled original row indices, and ``arr_cut`` holds the sorted
    data columns, suitable for binary searching.
    """
    n_rows = len(arr)
    row_ids = np.indices((n_rows,)).reshape((n_rows, 1))
    ordered = lex_sort(np.hstack([row_ids, arr]))
    return ordered[:, :1], ordered[:, 1:]
import time

# Benchmark: look up every row of lat2 in the sorted copy of lat1 with a
# binary search and report the elapsed wall-clock time.
#lat1 = np.array(([1,2,3], [3,4,5], [5,6,7], [7,8,9])) # ~ 200000 rows
lat1 = np.arange(1,800001,1).reshape((200000,4))
#lat2 = np.array(([3,4,5], [5,6,7], [7,8,9], [1,2,3])) # same number of rows as time
lat2 = np.arange(101,800101,1).reshape((200000,4))
lat1_ind, lat1_cut = sort_and_index(lat1)
time_arr = np.zeros(200000)

start = time.time()
for ii, elem in enumerate(lat2):
    pos = binary_search(lat1_cut, elem)
    if pos == -1:
        #Not found
        continue
    # Map the position in the sorted array back to the original row of lat1.
    pos = lat1_ind[pos][0]
    # At this point ii (row in lat2) and pos (row in lat1) refer to matching rows:
    #print("element in lat2 with index %d has position %d in lat1" % (ii, pos))
# print(...) with a single argument behaves the same on Python 2 and Python 3;
# the bare print statement is a SyntaxError on Python 3.
print(time.time()-start)
被注释掉的那条 print 语句所在的位置,就是你同时拿到 lat1 和 lat2 中对应索引的地方。处理 200000 行大约需要 7 秒。