How is nth_element Implemented?

Posted 2020-05-18 17:08

There are a lot of claims on StackOverflow and elsewhere that nth_element is O(n) and that it is typically implemented with Introselect: http://en.cppreference.com/w/cpp/algorithm/nth_element

I want to know how this can be achieved. I looked at Wikipedia's explanation of Introselect and that just left me more confused. How can an algorithm switch between QSort and Median-of-Medians?

I found the Introsort paper here: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.5196&rep=rep1&type=pdf But that says:

In this paper, we concentrate on the sorting problem and return to the selection problem only briefly in a later section.

I've tried to read through the STL itself to understand how nth_element is implemented, but that gets hairy real fast.

Could someone show me pseudo-code for how Introselect is implemented? Or even better, actual C++ code other than the STL of course :)

3 Answers
迷人小祖宗
#2 · 2020-05-18 17:18

You asked two questions, the titular one

How is nth_element Implemented?

Which you already answered:

There are a lot of claims on StackOverflow and elsewhere that nth_element is O(n) and that it is typically implemented with Introselect.

Which I also can confirm from looking at my stdlib implementation. (More on this later.)

And the one where you don't understand the answer:

How can an algorithm switch between QSort and Median-of-Medians?

Let's have a look at pseudocode that I extracted from my stdlib:

nth_element(first, nth, last)
{ 
  if (first == last || nth == last)
    return;

  introselect(first, nth, last, log2(last - first) * 2);
}

introselect(first, nth, last, depth_limit)
{
  while (last - first > 3)
  {
      if (depth_limit == 0)
      {
          // [NOTE by editor] This should be median-of-medians instead.
          // [NOTE by editor] See Azmisov's comment below
          heap_select(first, nth + 1, last);
          // Place the nth largest element in its final position.
          iter_swap(first, nth);
          return;
      }
      --depth_limit;
      cut = unguarded_partition_pivot(first, last);
      if (cut <= nth)
        first = cut;
      else
        last = cut;
  }
  insertion_sort(first, last);
}

Without getting into the details of the referenced functions heap_select and unguarded_partition_pivot, we can clearly see that nth_element gives introselect 2 * log2(size) subdivision steps (twice as many as quickselect needs in the best case) before heap_select kicks in and solves the problem for good.
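To make the control flow concrete, here is a minimal runnable sketch of the same idea. It is not the stdlib source: the names introselect_sketch and nth_element_sketch are made up for illustration, a three-way split built from std::partition stands in for unguarded_partition_pivot, and std::partial_sort stands in for the heap_select plus iter_swap fallback.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical sketch, not the libstdc++ source: quickselect with a depth
// budget. After the call, v[n] holds the value it would have if v were sorted.
void introselect_sketch(std::vector<int>& v, std::size_t n, int depth_limit) {
    assert(n < v.size());
    auto first = v.begin();
    auto last = v.end();
    const auto nth = v.begin() + static_cast<std::ptrdiff_t>(n);

    while (last - first > 3) {
        if (depth_limit == 0) {
            // Budget exhausted: fall back to a heap-based selection.
            // std::partial_sort is an assumption standing in for heap_select.
            std::partial_sort(first, nth + 1, last);
            return;
        }
        --depth_limit;

        // Median-of-three pivot value, copied so that swaps cannot change it.
        const int a = *first;
        const int b = *(first + (last - first) / 2);
        const int c = *(last - 1);
        const int pivot = std::max(std::min(a, b), std::min(std::max(a, b), c));

        // Three-way split: [first, lo) < pivot, [lo, hi) == pivot, [hi, last) > pivot.
        auto lo = std::partition(first, last, [pivot](int x) { return x < pivot; });
        auto hi = std::partition(lo, last, [pivot](int x) { return !(pivot < x); });

        if (nth < lo)       last = lo;    // target lies among the smaller elements
        else if (nth >= hi) first = hi;   // target lies among the larger elements
        else                return;       // target equals the pivot and is in place
    }
    std::sort(first, last);  // at most three elements remain
}

// Uses the same 2 * log2(size) budget as the pseudocode above.
void nth_element_sketch(std::vector<int>& v, std::size_t n) {
    if (n >= v.size()) return;
    const int budget = 2 * static_cast<int>(std::log2(static_cast<double>(v.size())));
    introselect_sketch(v, n, budget);
}
```

Calling introselect_sketch with a depth_limit of 0 forces the fallback path immediately, which is a convenient way to check that both paths agree on v[n].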

我欲成王,谁敢阻挡
#3 · 2020-05-18 17:23

Disclaimer: I don't know how std::nth_element is implemented in any standard library.

If you know how Quicksort works, you can easily modify it to do what is needed for this algorithm. The basic idea of Quicksort is that in each step, you partition the array into two parts such that all elements less than the pivot are in the left sub-array and all elements equal to or greater than the pivot are in the right sub-array. (A modification of Quicksort known as ternary Quicksort creates a third sub-array with all elements equal to the pivot. Then the right sub-array contains only entries strictly greater than the pivot.) Quicksort then proceeds by recursively sorting the left and right sub-arrays.

If you only want to move the n-th element into place, instead of recursing into both sub-arrays, you can tell in every step whether you will need to descend into the left or right sub-array. (You know this because the n-th element in a sorted array has index n so it becomes a matter of comparing indices.) So – unless your Quicksort suffers worst-case degeneration – you roughly halve the size of the remaining array in each step. (You never look at the other sub-array again.) Therefore, on average, you are dealing with arrays of the following lengths in each step:

  1. Θ(N)
  2. Θ(N / 2)
  3. Θ(N / 4)

Each step is linear in the length of the array it is dealing with. (You loop over it once and decide into what sub-array each element should go depending on how it compares to the pivot.)

You can see that after Θ(log(N)) steps, we will eventually reach a singleton array and are done. If the pivot were always exactly the median, the total work would be N (1 + 1/2 + 1/4 + …), which is less than 2 N. In the average case, since we cannot hope that the pivot will always be exactly the median, the total is still on the order of Θ(N).
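The idea described above can be sketched as a plain quickselect. This is only an illustration under my own assumptions: the function name quickselect_sketch is hypothetical, the pivot is simply the middle element, and a Lomuto-style partition is used so that the pivot lands on its final index after every step.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: returns the value that would sit at index n if v were
// sorted. Takes v by value so the caller's data is left untouched.
int quickselect_sketch(std::vector<int> v, std::size_t n) {
    assert(n < v.size());
    std::size_t lo = 0;
    std::size_t hi = v.size() - 1;  // inclusive bounds; n always stays in [lo, hi]

    while (lo < hi) {
        // Move the middle element to the end and use it as the pivot.
        std::swap(v[lo + (hi - lo) / 2], v[hi]);
        const int pivot = v[hi];

        // One linear pass: everything smaller than the pivot goes to the left.
        std::size_t store = lo;
        for (std::size_t i = lo; i < hi; ++i) {
            if (v[i] < pivot) std::swap(v[i], v[store++]);
        }
        std::swap(v[store], v[hi]);  // the pivot is now at its final sorted index

        // Compare indices to decide which single side to continue with.
        if (n == store) return v[store];
        if (n < store) hi = store - 1;  // store > n >= lo, so this cannot underflow
        else           lo = store + 1;
    }
    return v[lo];
}
```

For the input {5, 3, 9, 1, 7, 3, 8}, asking for index 3 returns 5, the middle value of the sorted sequence 1 3 3 5 7 8 9. Note that this sketch has no protection against bad pivots, so its worst case is quadratic.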

爱情/是我丢掉的垃圾
#4 · 2020-05-18 17:27

The code from the STL (version 3.3, I think) is this:

template <class _RandomAccessIter, class _Tp>
void __nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
                   _RandomAccessIter __last, _Tp*) {
  while (__last - __first > 3) {
    _RandomAccessIter __cut =
      __unguarded_partition(__first, __last,
                            _Tp(__median(*__first,
                                         *(__first + (__last - __first)/2),
                                         *(__last - 1))));
    if (__cut <= __nth)
      __first = __cut;
    else 
      __last = __cut;
  }
  __insertion_sort(__first, __last);
}

Let's simplify that a bit:

template <class Iter, class T>
void nth_element(Iter first, Iter nth, Iter last) {
  while (last - first > 3) {
    Iter cut =
      unguarded_partition(first, last,
                          T(median(*first,
                                   *(first + (last - first)/2),
                                   *(last - 1))));
    if (cut <= nth)
      first = cut;
    else 
      last = cut;
  }
  insertion_sort(first, last);
}

What I did here was to remove the double underscores and _Uppercase names, which only exist to protect the code from things the user could legally define as macros. I also removed the last parameter, which only exists to help with template type deduction (without it, T would have to be supplied explicitly or taken from std::iterator_traits<Iter>::value_type), and renamed the iterator type for brevity.

As you should see now, it partitions the range repeatedly until fewer than four elements remain in the remaining range, which is then simply sorted.
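For anyone who wants to compile and step through this, here is a self-contained sketch of the same loop. The helpers are my own stand-ins written from memory of the classic Hoare-style scheme, not copies of the STL internals, and all names ending in _sketch (plus median_of_three and simple_nth_element) are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

// Median of three values (stand-in for the STL's internal median helper).
template <class T>
T median_of_three(const T& a, const T& b, const T& c) {
    return std::max(std::min(a, b), std::min(std::max(a, b), c));
}

// Hoare-style partition without bounds checks. It relies on the pivot being
// the median of three elements of the range, which act as sentinels.
template <class Iter, class T>
Iter unguarded_partition_sketch(Iter first, Iter last, T pivot) {
    while (true) {
        while (*first < pivot) ++first;   // stops at an element >= pivot
        --last;
        while (pivot < *last) --last;     // stops at an element <= pivot
        if (!(first < last)) return first;
        std::iter_swap(first, last);
        ++first;
    }
}

// Simple insertion sort for the tiny range left at the end.
template <class Iter>
void insertion_sort_sketch(Iter first, Iter last) {
    for (Iter i = first; i != last; ++i)
        for (Iter j = i; j != first && *j < *(j - 1); --j)
            std::iter_swap(j, j - 1);
}

// Same loop as the simplified code above; the value type comes from
// iterator_traits instead of an extra template parameter.
template <class Iter>
void simple_nth_element(Iter first, Iter nth, Iter last) {
    using T = typename std::iterator_traits<Iter>::value_type;
    while (last - first > 3) {
        Iter cut = unguarded_partition_sketch(
            first, last,
            T(median_of_three(*first, *(first + (last - first) / 2), *(last - 1))));
        if (cut <= nth)
            first = cut;   // nth lies in the right part
        else
            last = cut;    // nth lies in the left part
    }
    insertion_sort_sketch(first, last);
}
```

A call looks like simple_nth_element(v.begin(), v.begin() + k, v.end()) for a std::vector<int> v with k < v.size(); afterwards v[k] holds the value it would have in sorted order.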

Now, why is that O(n)? First, the final sort handles at most three elements, so it is O(1). What remains is the repeated partitioning. A single partitioning pass is linear in the size of the range it works on. With a reasonably good pivot, each step roughly halves the number of elements that need to be touched in the next step, so the total work is about n + n/2 + n/4 + n/8 + …, which is less than 2n. Since O(2n) = O(n), you have linear complexity on average.
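The arithmetic can be checked with a few lines of code. This sketch assumes the idealized case in which every step halves the range exactly; the function name total_work_when_halving is made up for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: total number of elements touched if every partitioning
// step halves the range exactly (an idealized assumption, not a measurement).
std::uint64_t total_work_when_halving(std::uint64_t n) {
    assert(n > 0);
    std::uint64_t total = 0;
    for (std::uint64_t len = n; len > 0; len /= 2) {
        total += len;  // one linear pass over the current range
    }
    return total;
}
```

For n = 1024 this gives 1024 + 512 + … + 1 = 2047, just below 2n = 2048.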
