Nearest neighbors in a given range

Posted 2019-07-13 00:11

Question:

I faced the problem of quickly finding the nearest neighbors in a given range.

Example dataset:

id  | string | float
0   |   AA   |  0.1
12  |   BB   |  0.5
2   |   CC   |  0.3
102 |   AA   |  1.1
33  |   AA   |  2.8
17  |   AA   |  0.5

For each row, output the number of rows that satisfy both of the following conditions:

  1. the string field is equal to the current row's string
  2. current float - del <= float field < current float

For this example, with del = 1.5, the expected output is:

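id  | count
0   |  0
12  |  0
2   |  0
102 |  2  (the string matches rows 0, 33 and 17, but only rows 0 and 17 satisfy the float condition: 1.1-1.5 <= 0.1 < 1.1 and 1.1-1.5 <= 0.5 < 1.1)
33  |  0  (the string matches rows 0, 102 and 17, but 2.8-1.5 = 1.3 is greater than 0.1, 1.1 and 0.5)
17  |  1  (only row 0 qualifies: 0.5-1.5 <= 0.1 < 0.5)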
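To have a reference for checking faster approaches, the expected counts can be reproduced with a direct O(n²) loop. This is a minimal sketch assuming the rule implied by the worked example: same string and current float - del <= other float < current float (the strict upper bound matches the `x[1] > y[1]` test in the question's code, and it also keeps a row from counting itself). The function name `count_in_range` is hypothetical.

```python
def count_in_range(rows, delta):
    """Brute-force O(n^2): for each row, count rows with the same string
    whose float lies in the half-open interval [float - delta, float)."""
    counts = []
    for _, s, f in rows:
        n = sum(1 for _, s2, f2 in rows if s2 == s and f - delta <= f2 < f)
        counts.append(n)
    return counts


# the example dataset as (id, string, float) tuples
rows = [
    (0, "AA", 0.1),
    (12, "BB", 0.5),
    (2, "CC", 0.3),
    (102, "AA", 1.1),
    (33, "AA", 2.8),
    (17, "AA", 0.5),
]
print(count_in_range(rows, 1.5))  # [0, 0, 0, 2, 0, 1]
```

This is far too slow for a large dataset, but it pins down the exact semantics any faster method has to match.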

To solve this, I used sklearn's BallTree with a custom metric, but on a large dataset it runs for a very long time, because the custom Python metric is called back for every distance evaluated during the tree walk. Can someone suggest another solution, or a way to make a custom metric as fast as the built-in metrics from sklearn.neighbors.DistanceMetric?
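One alternative that avoids a custom metric entirely is a different technique (sorting plus binary search, not a BallTree fix). Within one string group the condition is a half-open interval on a single float, so after sorting each group once, every count is the difference of two binary searches, for O(n log n) overall. This is a minimal NumPy sketch that assumes the strings are first turned into integer codes with `np.unique(..., return_inverse=True)`. The function name `range_counts` is hypothetical.

```python
import numpy as np


def range_counts(codes, vals, delta):
    """For each row, count rows with the same code whose value v2 satisfies
    v - delta <= v2 < v. Sort once, then binary-search inside each group."""
    codes = np.asarray(codes)
    vals = np.asarray(vals, dtype=float)
    order = np.lexsort((vals, codes))  # sort by code first, then by value
    sc, sv = codes[order], vals[order]

    # [start, end) boundaries of each run of equal codes in the sorted arrays
    bounds = np.flatnonzero(sc[1:] != sc[:-1]) + 1
    starts = np.concatenate(([0], bounds))
    ends = np.concatenate((bounds, [len(sc)]))

    counts_sorted = np.empty(len(sv), dtype=np.int64)
    for s, e in zip(starts, ends):
        g = sv[s:e]  # sorted values of one string group
        n_below = np.searchsorted(g, g, side="left")            # values < v
        n_too_low = np.searchsorted(g, g - delta, side="left")  # values < v - delta
        counts_sorted[s:e] = n_below - n_too_low

    counts = np.empty_like(counts_sorted)
    counts[order] = counts_sorted  # scatter back to the original row order
    return counts


strings = np.array(["AA", "BB", "CC", "AA", "AA", "AA"])
vals = np.array([0.1, 0.5, 0.3, 1.1, 2.8, 0.5])
_, codes = np.unique(strings, return_inverse=True)  # AA -> 0, BB -> 1, CC -> 2
print(range_counts(codes, vals, 1.5))  # [0 0 0 2 0 1]
```

The Python loop runs once per distinct string rather than once per pair, so the cost is dominated by the sort and the vectorized `searchsorted` calls.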

My code:

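from sklearn.neighbors import BallTree

# X is assumed to be a numeric array: column 0 holds the string field encoded
# as a number (BallTree cannot store strings), column 1 holds the float field
def distance(x, y):
    # same string and y's float below x's float: return the gap between them
    if x[0] == y[0] and x[1] > y[1]:
        return x[1] - y[1]
    # otherwise return the sum of the floats, intended to fall outside the radius
    return x[1] + y[1]

delta = 1.5  # renamed from del, which is a reserved word in Python
# leaf_size=X.shape[0] keeps all points in a single leaf node
tree2 = BallTree(X, leaf_size=X.shape[0], metric=distance)
mas = tree2.query_radius(X, r=delta, count_only=True)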
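Evaluating the `distance` function above on a few example rows shows why its results need care, independent of speed. scikit-learn's documentation for user-defined metrics states that a function used within a BallTree must be a true metric (non-negativity, identity, symmetry, triangle inequality). This function is not zero on identical points and not symmetric, and its fallback branch can return a value inside the radius. The sketch assumes a hypothetical numeric encoding AA=0, BB=1, CC=2 for the string column.

```python
import numpy as np


def distance(x, y):
    # copied from the question
    if x[0] == y[0] and x[1] > y[1]:
        return x[1] - y[1]
    return x[1] + y[1]


# hypothetical numeric encoding of example rows: [string code, float]
row0 = np.array([0, 0.1])    # id 0,   AA
row12 = np.array([1, 0.5])   # id 12,  BB
row102 = np.array([0, 1.1])  # id 102, AA

delta = 1.5
print(distance(row0, row0))                            # not 0 (it is 2 * 0.1)
print(distance(row102, row0), distance(row0, row102))  # the two orders differ
print(distance(row0, row12) <= delta)                  # True, despite different strings
```

So for this example the distance values themselves place a row with a different string, and the query row itself, within radius 1.5.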