What is the fastest way to find the closest point

Posted 2019-01-06 13:07

What is the fastest way to find the closest point to a given point in a data array?

For example, suppose I have an array A of 3D points (with coordinates x, y and z, as usual) and point (x_p, y_p, z_p). How do I find the closest point in A to (x_p, y_p, z_p)?

As far as I know, the slowest way to do it is linear search. Are there any better solutions?

Adding an auxiliary data structure is acceptable.
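For reference, the linear-search baseline mentioned above is a single pass over the array that keeps the smallest squared distance seen so far. This is a minimal sketch, assuming points are stored as `double[]` rows; the class name `LinearNearest` is hypothetical.

```java
// Linear-search baseline: O(n) per query, no auxiliary structure.
final class LinearNearest {
    /** Returns the index of the point in pts closest to q, or -1 if pts is empty. */
    static int nearest(double[][] pts, double[] q) {
        int best = -1;
        double bestSq = Double.POSITIVE_INFINITY;
        for (int i = 0; i < pts.length; i++) {
            double dSq = 0;
            for (int k = 0; k < q.length; k++) {
                double diff = pts[i][k] - q[k];
                dSq += diff * diff; // squared distance avoids a sqrt per point
            }
            if (dSq < bestSq) {
                bestSq = dSq;
                best = i;
            }
        }
        return best;
    }
}
```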

9 answers
祖国的老花朵
Post #2 · 2019-01-06 13:19

I needed to do this heavily, for many nearest-neighbor searches in a real-time environment, and hit on an algorithm that is better in terms of both simplicity and speed.

Put a copy of all your points into d lists, where d is the dimensionality of the space (3 in your case), and sort each list by its own coordinate. This costs O(d · n log n) time, and that's it for the data structure.

We maintain these sorted lists, one per dimension, for all the points in question. The trick is that the distance along any single axis is always less than or equal to the Euclidean distance. So if the distance along one axis to a point is greater than the distance to the closest point found so far, that point cannot be closer, and, more importantly, no point further out in that direction can be closer either. Once this holds for all 2·d directions, we have, by definition, found the closest point.
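The inequality the pruning rests on, written out for points p and q in d dimensions: the gap along a single axis i is one term of the Euclidean sum, so it can never exceed the whole.

```latex
|p_i - q_i| = \sqrt{(p_i - q_i)^2} \le \sqrt{\sum_{j=1}^{d} (p_j - q_j)^2} = \lVert p - q \rVert
\quad \text{for every axis } i .
```

Hence if |p_i - q_i| > m, then ||p - q|| > m as well, and the same holds for every point lying further out along axis i in the sorted list.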

For a given query point, we binary search into each sorted list to find the position where the query point would sit in that dimension. If the distance along the +x, -x, +y or -y direction (other dimensions are easy to add) exceeds the smallest known Euclidean distance, then that point's Euclidean distance must exceed it too. Since the list is sorted, every later point in that direction is even farther along that axis, so we can abandon that direction, because it cannot contain a better answer. As we expand outward in these four directions, we keep lowering m, which always equals the Euclidean distance to the closest point found so far.

So all we need is one list per axis, sorted by that axis, which is pretty simple.

Then, to query the lists:

  • We binary search into each of the lists (O(d log n)).
  • We set our current minimum distance, m (initially it can be infinity).
  • For each list, we travel in the positive and negative directions.
  • For each of the 2·d directions,
    • We traverse the list, lowering m when we find closer points.
  • When a direction proves itself to be mathematically fruitless, we stop searching that way.
  • When no direction remains, we have found our closest point.
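A minimal Java sketch of the whole procedure, generalized from the answer's 2D code to the d dimensions (here 3) the question asks about. Assumptions made here, not by the author: points are `double[]` rows, each axis list stores point indices instead of copies, and the class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

// Sketch of the sorted-axis-lists nearest-neighbor search, generalized to d dimensions.
final class AxisSortedNearest {
    private final double[][] pts;     // pts[i] = coordinates of point i
    private final Integer[][] order;  // order[a] = point indices sorted by coordinate a

    AxisSortedNearest(double[][] pts) {
        this.pts = pts;
        int d = pts.length == 0 ? 0 : pts[0].length;
        order = new Integer[d][];
        for (int a = 0; a < d; a++) {
            Integer[] idx = new Integer[pts.length];
            for (int i = 0; i < idx.length; i++) idx[i] = i;
            final int axis = a;
            Arrays.sort(idx, Comparator.comparingDouble(i -> pts[i][axis])); // O(n log n) per axis
            order[a] = idx;
        }
    }

    /** First position in order[a] whose coordinate is >= v (a lower-bound binary search). */
    private int lowerBound(int a, double v) {
        int lo = 0, hi = order[a].length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (pts[order[a][mid]][a] < v) lo = mid + 1; else hi = mid;
        }
        return lo;
    }

    /** Index of the point closest to q, or -1 if there are no points. */
    int nearest(double[] q) {
        int d = order.length;
        int[] up = new int[d];    // next position to visit in the + direction of each axis
        int[] down = new int[d];  // next position to visit in the - direction of each axis
        for (int a = 0; a < d; a++) {
            up[a] = lowerBound(a, q[a]);
            down[a] = up[a] - 1;
        }
        boolean[] alive = new boolean[2 * d]; // direction 2a is +axis a, 2a+1 is -axis a
        Arrays.fill(alive, true);
        int live = 2 * d, best = -1;
        double mSq = Double.POSITIVE_INFINITY; // squared distance to the best point so far
        while (live > 0) {
            for (int dir = 0; dir < 2 * d; dir++) {
                if (!alive[dir]) continue;
                int a = dir >> 1;
                boolean plus = (dir & 1) == 0;
                int pos = plus ? up[a]++ : down[a]--;
                if (pos < 0 || pos >= pts.length) { alive[dir] = false; live--; continue; } // ran out of points
                int i = order[a][pos];
                double gap = pts[i][a] - q[a];
                if (gap * gap > mSq) { alive[dir] = false; live--; continue; } // axis gap alone exceeds m
                double dSq = 0;
                for (int k = 0; k < d; k++) {
                    double t = pts[i][k] - q[k];
                    dSq += t * t;
                }
                if (dSq < mSq) { mSq = dSq; best = i; } // lower m
            }
        }
        return best;
    }
}
```

The outer loop visits the 2·d directions round-robin, so all directions expand at the same pace and m can drop early; a point reached from several axes is simply measured again, which is harmless for k = 1.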

We have sorted lists and need to locate the query point in each of them. We binary search to keep that step at O(log n). Then we take our best current distance (possibly infinity) and move outward in each available direction. As we find new points, we update the closest point found so far. The trick is that we quit a direction as soon as the distance along that one axis is greater than the distance to our current closest point.

So if the closest point we know of is at distance 13, we can stop checking the +x, -x, +y or -y direction as soon as the distance along just that axis exceeds 13. If a point is farther than m in the +x direction, all remaining points in the +x direction are provably farther away too. As we find closer and closer points, the amount of space we need to search shrinks.

If we run out of points in a direction, that direction is finished. If the distance to a point along just that one axis is itself greater than m, that direction is finished.

Once every direction has been shown to contain only points farther away than our best point so far, that best point (at distance m) is the answer.

Since we progressively reduce m, the range we need to scan along each axis shrinks quickly, though, as with all such algorithms, it shrinks less quickly in higher dimensions. But if the distance along just one axis is greater than the best we have so far, then all the remaining points in that direction necessarily can't be better.

In time complexity it seems on par with the better options, but in simplicity of the data structures this algorithm clearly wins, and several other properties make it a serious contender. When you update points, you can re-sort the lists with really good performance, because you are very often sorting already-sorted or nearly sorted lists. You are iterating over arrays. In terms of actual real-world performance, most data structures suck, generally because of caching and memory layout; we are supposed to be agnostic about such things, but they matter a lot, and data right next to the data you are currently using is much faster to access. If you already know where the query point sits in the lists, you can solve it even faster, since you don't need the binary search, and other tricks that reuse information from the previous iteration apply here and there. Additional dimensions are basically free, except that m doesn't converge as fast, but that's because there are more randomly distributed points in a sphere than in a circle of the same radius.


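import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;

public class EuclideanNeighborSearch2D {
    // Minimal point type: the original answer does not show its Point class, so this one is assumed.
    public static class Point {
        public double x, y;

        public Point(double x, double y) {
            this.x = x;
            this.y = y;
        }
    }

    public static final int INVALID = -1;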
    static final Comparator<Point> xsort = new Comparator<Point>() {
        @Override
        public int compare(Point o1, Point o2) {
            return Double.compare(o1.x, o2.x);
        }
    };
    static final Comparator<Point> ysort = new Comparator<Point>() {
        @Override
        public int compare(Point o1, Point o2) {
            return Double.compare(o1.y, o2.y);
        }
    };

    ArrayList<Point> xaxis = new ArrayList<>();
    ArrayList<Point> yaxis = new ArrayList<>();

    boolean dirtySortX = false;
    boolean dirtySortY = false;

    public Point findNearest(float x, float y, float minDistance, float maxDistance) {
        Point find = new Point(x,y);

        sortXAxisList();
        sortYAxisList();

        double findingDistanceMaxSq = maxDistance * maxDistance;
        double findingDistanceMinSq = minDistance * minDistance;

        Point findingIndex = null;

        int posx = Collections.binarySearch(xaxis, find, xsort);
        int posy = Collections.binarySearch(yaxis, find, ysort);
        if (posx < 0) posx = ~posx;
        if (posy < 0) posy = ~posy;

        int mask = 0b1111;

        Point v;

        double vx, vy;
        int o;
        int itr = 0;
        while (mask != 0) {
            if ((mask & (1 << (itr & 3))) == 0) {
                itr++;
                continue; //if that direction is no longer used.
            }
            switch (itr & 3) {
                default:
                case 0: //+x
                    o = posx + (itr++ >> 2);
                    if (o >= xaxis.size()) {
                        mask &= 0b1110;
                        continue;
                    }
                    v = xaxis.get(o);
                    vx = x - v.x;
                    vy = y - v.y;
                    vx *= vx;
                    vy *= vy;
                    if (vx > findingDistanceMaxSq) {
                        mask &= 0b1110;
                        continue;
                    }
                    break;
                case 1: //+y
                    o = posy + (itr++ >> 2);
                    if (o >= yaxis.size()) {
                        mask &= 0b1101;
                        continue;
                    }
                    v = yaxis.get(o);
                    vx = x - v.x;
                    vy = y - v.y;
                    vx *= vx;
                    vy *= vy;
                    if (vy > findingDistanceMaxSq) {
                        mask &= 0b1101;
                        continue;
                    }
                    break;
                case 2: //-x
                    o = posx + ~(itr++ >> 2);
                    if (o < 0) {
                        mask &= 0b1011;
                        continue;
                    }
                    v = xaxis.get(o);
                    vx = x - v.x;
                    vy = y - v.y;
                    vx *= vx;
                    vy *= vy;
                    if (vx > findingDistanceMaxSq) {
                        mask &= 0b1011;
                        continue;
                    }
                    break;
                case 3: //-y
                    o = posy + ~(itr++ >> 2);
                    if (o < 0) {
                        mask = mask & 0b0111;
                        continue;
                    }
                    v = yaxis.get(o);
                    vx = x - v.x;
                    vy = y - v.y;
                    vx *= vx;
                    vy *= vy;
                    if (vy > findingDistanceMaxSq) {
                        mask = mask & 0b0111;
                        continue;
                    }
                    break;
            }
            double d = vx + vy;

            if (d <= findingDistanceMinSq) continue;

            if (d < findingDistanceMaxSq) {
                findingDistanceMaxSq = d;
                findingIndex = v;
            }

        }
        return findingIndex;
    }

    private void sortXAxisList() {
        if (!dirtySortX) return;
        Collections.sort(xaxis, xsort);
        dirtySortX = false;
    }

    private void sortYAxisList() {
        if (!dirtySortY) return;
        Collections.sort(yaxis,ysort);
        dirtySortY = false;
    }

    /**
     * Call this when the points may have been invalidated for some reason,
     * such as being moved outside of this class or otherwise updated.
     */
    public void update() {
        dirtySortX = true;
        dirtySortY = true;
    }

    /**
     * Called to add a point to the sorted lists without needing to re-sort them.
     * @param p Point to add.
     */
    public final void add(Point p) {
        sortXAxisList();
        sortYAxisList();
        int posx = Collections.binarySearch(xaxis, p, xsort);
        int posy = Collections.binarySearch(yaxis, p, ysort);
        if (posx < 0) posx = ~posx;
        if (posy < 0) posy = ~posy;
        xaxis.add(posx, p);
        yaxis.add(posy, p);
    }

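    /**
     * Called to remove a point from the sorted lists without needing to re-sort them.
     * @param p Point to remove.
     */
    public final void remove(Point p) {
        sortXAxisList();
        sortYAxisList();
        removeExact(xaxis, p, xsort);
        removeExact(yaxis, p, ysort);
    }

    // binarySearch may land on any element with an equal coordinate (or report "not found"),
    // so scan the run of equal keys for this exact instance before removing it.
    private static void removeExact(ArrayList<Point> list, Point p, Comparator<Point> cmp) {
        int pos = Collections.binarySearch(list, p, cmp);
        if (pos < 0) return; // p is not in the list
        for (int i = pos; i >= 0 && cmp.compare(list.get(i), p) == 0; i--) {
            if (list.get(i) == p) { list.remove(i); return; }
        }
        for (int i = pos + 1; i < list.size() && cmp.compare(list.get(i), p) == 0; i++) {
            if (list.get(i) == p) { list.remove(i); return; }
        }
    }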
}

Update: Regarding the k-nearest-points problem raised in the comments, very little changes. The only difference is that if point v is found to be closer than the current m (findingDistanceMaxSq), it is added to the heap, and m is set to the Euclidean distance between the query position and the kth element. The regular version of the algorithm is simply the case k = 1: we search for the one element we want, and update m to the distance to that single element whenever a closer v is found.

Keep in mind that I only ever compare squared distances, since I only need to know whether one point is farther away than another, and I don't waste clock cycles on square roots.
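Skipping the square root is safe because squaring is strictly increasing on non-negative numbers, so every comparison keeps its direction:

```latex
\text{for } a, b \ge 0:\quad a < b \iff a^2 < b^2,
\qquad \text{so} \qquad
\lVert p - q \rVert < m \iff \sum_{j=1}^{d} (p_j - q_j)^2 < m^2 .
```

This is why the code can compare the already-squared axis gap (vx or vy) directly against findingDistanceMaxSq.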

I know the ideal data structure for storing the k elements is a size-limited heap, and an array insertion is obviously not optimal for that. But short of pulling in more Java-dependent APIs, there simply wasn't one available for that particular class, though apparently Google Guava provides one. You probably won't notice, since odds are your k is not that big; it does, however, make each insertion into the stored points O(k). Further improvements are possible, such as caching each element's distance from the query point.
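One way to get a size-limited heap without Guava is a plain `java.util.PriorityQueue` ordered farthest-first and capped at k: its head is the kth-closest point kept so far, so its distance is m. This is a minimal standalone sketch, assuming points are `double[]` rows; the class `KClosest` and its methods are hypothetical, not part of the answer's class.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

// Keeps the k closest candidates seen so far: O(log k) per offer instead of O(k) array insertion.
final class KClosest {
    private final int k;
    private final double[] q;                   // query point
    private final PriorityQueue<double[]> heap; // max-heap: farthest kept point at the head

    KClosest(int k, double[] q) {
        this.k = k;
        this.q = q;
        Comparator<double[]> nearestFirst = Comparator.comparingDouble(this::distSq);
        this.heap = new PriorityQueue<>(Math.max(1, k), nearestFirst.reversed());
    }

    double distSq(double[] p) {
        double s = 0;
        for (int i = 0; i < q.length; i++) { double t = p[i] - q[i]; s += t * t; }
        return s;
    }

    /** Squared pruning bound m: infinite until k points are held, then the kth-closest distance. */
    double mSq() {
        if (heap.isEmpty() || heap.size() < k) return Double.POSITIVE_INFINITY;
        return distSq(heap.peek());
    }

    void offer(double[] p) {
        if (k <= 0) return;
        if (heap.size() < k) heap.add(p);
        else if (distSq(p) < mSq()) { heap.poll(); heap.add(p); } // evict the current farthest
    }

    /** The kept points, closest first. */
    List<double[]> result() {
        List<double[]> out = new ArrayList<>(heap);
        out.sort(Comparator.comparingDouble(this::distSq));
        return out;
    }
}
```

Because the sorted-lists search can reach the same point from more than one axis, a caller would still need to skip points it has already offered (for example with a visited set) so duplicates do not occupy heap slots.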

Finally, and likely most pressingly, the project I would use to test the code is in transition, so I haven't managed to test this. But it certainly shows how you do this: you store the k best results so far and make m equal to the distance to the kth closest point. All else remains the same.

Example source (these methods go inside EuclideanNeighborSearch2D):

public static double distanceSq(double x0, double y0, double x1, double y1) {
    double dx = x1 - x0;
    double dy = y1 - y0;
    dx *= dx;
    dy *= dy;
    return dx + dy;
}
public Collection<Point> findNearest(int k, final float x, final float y, float minDistance, float maxDistance) {
    sortXAxisList();
    sortYAxisList();

    double findingDistanceMaxSq = maxDistance * maxDistance;
    double findingDistanceMinSq = minDistance * minDistance;
    ArrayList<Point> kpointsShouldBeHeap = new ArrayList<>(k);
    Comparator<Point> euclideanCompare = new Comparator<Point>() {
        @Override
        public int compare(Point o1, Point o2) {
            return Double.compare(distanceSq(x, y, o1.x, o1.y), distanceSq(x, y, o2.x, o2.y));
        }
    };

    Point find = new Point(x, y);
    int posx = Collections.binarySearch(xaxis, find, xsort);
    int posy = Collections.binarySearch(yaxis, find, ysort);
    if (posx < 0) posx = ~posx;
    if (posy < 0) posy = ~posy;

    int mask = 0b1111;

    Point v;

    double vx, vy;
    int o;
    int itr = 0;
    while (mask != 0) {
        if ((mask & (1 << (itr & 3))) == 0) {
            itr++;
            continue; //if that direction is no longer used.
        }
        switch (itr & 3) {
            default:
            case 0: //+x
                o = posx + (itr++ >> 2);
                if (o >= xaxis.size()) {
                    mask &= 0b1110;
                    continue;
                }
                v = xaxis.get(o);
                vx = x - v.x;
                vy = y - v.y;
                vx *= vx;
                vy *= vy;
                if (vx > findingDistanceMaxSq) {
                    mask &= 0b1110;
                    continue;
                }
                break;
            case 1: //+y
                o = posy + (itr++ >> 2);
                if (o >= yaxis.size()) {
                    mask &= 0b1101;
                    continue;
                }
                v = yaxis.get(o);
                vx = x - v.x;
                vy = y - v.y;
                vx *= vx;
                vy *= vy;
                if (vy > findingDistanceMaxSq) {
                    mask &= 0b1101;
                    continue;
                }
                break;
            case 2: //-x
                o = posx + ~(itr++ >> 2);
                if (o < 0) {
                    mask &= 0b1011;
                    continue;
                }
                v = xaxis.get(o);
                vx = x - v.x;
                vy = y - v.y;
                vx *= vx;
                vy *= vy;
                if (vx > findingDistanceMaxSq) {
                    mask &= 0b1011;
                    continue;
                }
                break;
            case 3: //-y
                o = posy + ~(itr++ >> 2);
                if (o < 0) {
                    mask = mask & 0b0111;
                    continue;
                }
                v = yaxis.get(o);
                vx = x - v.x;
                vy = y - v.y;
                vx *= vx;
                vy *= vy;
                if (vy > findingDistanceMaxSq) {
                    mask = mask & 0b0111;
                    continue;
                }
                break;
        }
        double d = vx + vy;
        if (d <= findingDistanceMinSq) continue;
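        if (d < findingDistanceMaxSq && !kpointsShouldBeHeap.contains(v)) {
            // The same point can be reached from both the x list and the y list, so skip duplicates.
            int insert = Collections.binarySearch(kpointsShouldBeHeap, v, euclideanCompare);
            if (insert < 0) insert = ~insert;
            kpointsShouldBeHeap.add(insert, v);
            if (kpointsShouldBeHeap.size() > k) {
                kpointsShouldBeHeap.remove(k); // keep only the k closest points
            }
            if (k > 0 && kpointsShouldBeHeap.size() == k) {
                // m is now the distance to the current kth closest point.
                Point kthPoint = kpointsShouldBeHeap.get(k - 1);
                findingDistanceMaxSq = distanceSq(x, y, kthPoint.x, kthPoint.y);
            }
        }
    }
    return kpointsShouldBeHeap;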
}
【Aperson】
Post #3 · 2019-01-06 13:21

The "fastest" way to do it, considering the search ONLY, would be to use voxels. With a 1:1 point-voxel map, the access time is constant and really fast: just shift the coordinates so your point origin is at the voxel origin (if needed), round down the position, and index the voxel array with that value. For some cases, this is a good choice. As explained before me, octrees are better when a 1:1 map is hard to get (too many points, too little voxel resolution, too much free space).
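A minimal sketch of the voxel lookup described above, assuming a dense grid with a known origin and cubic voxels of edge `cell`, where each voxel stores the index of at most one point. The class and method names are hypothetical. Note that the point sharing the query's voxel is only the nearest point up to the voxel resolution, since a point in a neighboring voxel can be closer.

```java
import java.util.Arrays;

// Dense voxel grid: each voxel holds the index of at most one point, or -1 if empty.
final class VoxelGrid {
    private final double originX, originY, originZ, cell; // grid origin and voxel edge length
    private final int nx, ny, nz;                          // voxel counts per axis
    private final int[] cells;

    VoxelGrid(double originX, double originY, double originZ, double cell, int nx, int ny, int nz) {
        this.originX = originX; this.originY = originY; this.originZ = originZ;
        this.cell = cell; this.nx = nx; this.ny = ny; this.nz = nz;
        cells = new int[nx * ny * nz];
        Arrays.fill(cells, -1);
    }

    /** Linear index of the voxel containing (x, y, z), or -1 if it lies outside the grid. */
    int voxelOf(double x, double y, double z) {
        int ix = (int) Math.floor((x - originX) / cell); // shift to the grid origin, then round down
        int iy = (int) Math.floor((y - originY) / cell);
        int iz = (int) Math.floor((z - originZ) / cell);
        if (ix < 0 || iy < 0 || iz < 0 || ix >= nx || iy >= ny || iz >= nz) return -1;
        return (iz * ny + iy) * nx + ix;
    }

    void put(int pointIndex, double x, double y, double z) {
        int v = voxelOf(x, y, z);
        if (v >= 0) cells[v] = pointIndex; // assumes a 1:1 map: a later point overwrites an earlier one
    }

    /** Constant-time lookup: the point stored in the query's voxel, or -1. */
    int lookup(double x, double y, double z) {
        int v = voxelOf(x, y, z);
        return v < 0 ? -1 : cells[v];
    }
}
```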

Deceive 欺骗
Post #4 · 2019-01-06 13:22

Check this out: http://www.cs.ucsb.edu/~suri/cs235/ClosestPair.pdf. You can also consult the computational geometry chapter of CLRS.

乱世女痞
Post #5 · 2019-01-06 13:28

You may organize your points in an Octree. Then you only need to search a small subset.

An octree is a fairly simple data structure that you can implement yourself (which would be a valuable learning experience), or you may find some helpful libraries to get you going.
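A minimal bucket-octree sketch of that idea, assuming 3D `double[]` points that all lie inside the root box given to the constructor (the box-distance pruning is only valid under that assumption). The leaf capacity of 8 and all names are arbitrary choices made here. The search skips any octant whose box is already farther away than the best point found, which is how only a small subset gets examined.

```java
import java.util.ArrayList;
import java.util.List;

// Minimal bucket octree over 3D points with exact nearest-neighbor search.
final class Octree {
    private static final int LEAF_CAPACITY = 8; // split a leaf once it holds more than this (arbitrary)
    private final double[] min, max;            // axis-aligned bounds of this node
    private List<double[]> points = new ArrayList<>();
    private Octree[] children;                  // null while this node is a leaf

    Octree(double[] min, double[] max) { this.min = min; this.max = max; }

    void insert(double[] p) {
        if (children != null) { children[childIndex(p)].insert(p); return; }
        points.add(p);
        if (points.size() > LEAF_CAPACITY && max[0] - min[0] > 1e-9) split(); // size guard stops endless splits on duplicates
    }

    private int childIndex(double[] p) {
        int idx = 0;
        for (int a = 0; a < 3; a++) if (p[a] >= (min[a] + max[a]) / 2) idx |= 1 << a;
        return idx;
    }

    private void split() {
        children = new Octree[8];
        for (int c = 0; c < 8; c++) {
            double[] lo = new double[3], hi = new double[3];
            for (int a = 0; a < 3; a++) {
                double mid = (min[a] + max[a]) / 2;
                boolean upper = (c >> a & 1) == 1;
                lo[a] = upper ? mid : min[a];
                hi[a] = upper ? max[a] : mid;
            }
            children[c] = new Octree(lo, hi);
        }
        for (double[] p : points) children[childIndex(p)].insert(p);
        points = null;
    }

    /** Squared distance from q to this node's box (0 if q is inside). */
    private double boxDistSq(double[] q) {
        double s = 0;
        for (int a = 0; a < 3; a++) {
            double t = Math.max(0, Math.max(min[a] - q[a], q[a] - max[a]));
            s += t * t;
        }
        return s;
    }

    /** Closest stored point to q, or null if the tree is empty. */
    double[] nearest(double[] q) {
        double[][] best = {null};
        double[] bestSq = {Double.POSITIVE_INFINITY};
        search(q, best, bestSq);
        return best[0];
    }

    private void search(double[] q, double[][] best, double[] bestSq) {
        if (boxDistSq(q) >= bestSq[0]) return; // the whole box is too far: prune it
        if (children == null) {
            for (double[] p : points) {
                double s = 0;
                for (int a = 0; a < 3; a++) { double t = p[a] - q[a]; s += t * t; }
                if (s < bestSq[0]) { bestSq[0] = s; best[0] = p; }
            }
            return;
        }
        int first = childIndex(q); // visit q's own octant first so the bound shrinks early
        children[first].search(q, best, bestSq);
        for (int c = 0; c < 8; c++) if (c != first) children[c].search(q, best, bestSq);
    }
}
```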

男人必须洒脱
Post #6 · 2019-01-06 13:28

I would use a KD-tree to do this in O(log(n)) time, assuming the points are randomly distributed or you have a way to keep the tree balanced.

http://en.wikipedia.org/wiki/Kd-tree

KD trees are excellent for this kind of spatial query, and even allow you to retrieve the nearest k neighbors to a query point.
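A minimal static k-d tree sketch, assuming all points are known up front as `double[]` rows of equal length. Building by median splits keeps the tree balanced by construction, which sidesteps the balancing concern mentioned above; it does not support insertion afterwards. The names are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

// Minimal static k-d tree: built once by median splits, so it is balanced by construction.
final class KdTree {
    private final double[][] pts; // copy of the point array, reordered into tree layout
    private final int dims;

    KdTree(double[][] points) {
        pts = points.clone();
        dims = pts.length == 0 ? 0 : pts[0].length;
        build(0, pts.length, 0);
    }

    // The median of pts[lo, hi) on `axis` lands in the middle slot; both halves are built recursively.
    private void build(int lo, int hi, int axis) {
        if (hi - lo <= 1) return;
        Arrays.sort(pts, lo, hi, Comparator.comparingDouble(p -> p[axis])); // O(n log^2 n) overall; fine for a sketch
        int mid = (lo + hi) >>> 1;
        int next = (axis + 1) % dims;
        build(lo, mid, next);
        build(mid + 1, hi, next);
    }

    /** Closest point to q, or null if the tree is empty. */
    double[] nearest(double[] q) {
        double[][] best = {null};
        double[] bestSq = {Double.POSITIVE_INFINITY};
        search(0, pts.length, 0, q, best, bestSq);
        return best[0];
    }

    private void search(int lo, int hi, int axis, double[] q, double[][] best, double[] bestSq) {
        if (lo >= hi) return;
        int mid = (lo + hi) >>> 1;
        double[] p = pts[mid];
        double s = 0;
        for (int a = 0; a < dims; a++) { double t = p[a] - q[a]; s += t * t; }
        if (s < bestSq[0]) { bestSq[0] = s; best[0] = p; }
        double diff = q[axis] - p[axis];
        int next = (axis + 1) % dims;
        // Descend on q's side first; cross the splitting plane only if it is closer than the best so far.
        if (diff < 0) {
            search(lo, mid, next, q, best, bestSq);
            if (diff * diff < bestSq[0]) search(mid + 1, hi, next, q, best, bestSq);
        } else {
            search(mid + 1, hi, next, q, best, bestSq);
            if (diff * diff < bestSq[0]) search(lo, mid, next, q, best, bestSq);
        }
    }
}
```

Extending this to the k nearest neighbors mentioned above amounts to replacing the single best point with a bounded max-heap and pruning against the kth distance.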

甜甜的少女心
Post #7 · 2019-01-06 13:29

It's my understanding that a quadtree is for 2D, but you could build something very similar for 3D. This will speed up your search, but it will take much more time to compute the index if done on the fly, so I would suggest computing the index once and then storing it. On every lookup you figure out all of the exterior quads, then work your way in looking for hits; it would look like peeling an orange. The speed will greatly increase as the quads get smaller. Everything has a trade-off.
