There are a lot of claims on StackOverflow and elsewhere that nth_element is O(n) and that it is typically implemented with Introselect: http://en.cppreference.com/w/cpp/algorithm/nth_element
I want to know how this can be achieved. I looked at Wikipedia's explanation of Introselect and that just left me more confused. How can an algorithm switch between QSort and Median-of-Medians?
I found the Introsort paper here: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.5196&rep=rep1&type=pdf But that says:
In this paper, we concentrate on the sorting problem and return to the selection problem only briefly in a later section.
I've tried to read through the STL itself to understand how nth_element is implemented, but that gets hairy real fast.
Could someone show me pseudo-code for how Introselect is implemented? Or even better, actual C++ code other than the STL of course :)
You asked two questions. The titular one you already answered yourself: nth_element is typically implemented with Introselect. I can confirm this from looking at my stdlib implementation. (More on this later.)
And the one where you don't understand the answer: how can an algorithm switch between QSort and Median-of-Medians?
Let's have a look at pseudo code that I extracted from my stdlib:
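A minimal C++ sketch of the kind of loop libstdc++ uses, assuming its structure: a depth limit of 2 * floor(log2(size)), median-of-three pivots, a heap-select fallback once the limit is used up, and insertion sort for ranges of at most three elements. The names floor_log2, median_of_three, heap_select, insertion_sort and introselect are hypothetical stand-ins, not the library's internals, and the partition step is swapped for a simple three-way split built from two std::partition calls instead of libstdc++'s unguarded partition. Note that the fallback here is heap selection, not median-of-medians, so the worst case of this scheme is O(n log n) rather than O(n).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// floor(log2(n)) for n >= 1 (hypothetical helper).
inline int floor_log2(std::ptrdiff_t n) {
    int k = 0;
    while (n > 1) { n >>= 1; ++k; }
    return k;
}

// Median of three values, used as the pivot value (returned by copy).
template <class T>
T median_of_three(const T& a, const T& b, const T& c) {
    if (a < b) return b < c ? b : (a < c ? c : a);
    return a < c ? a : (b < c ? c : b);
}

// Fallback: keep the (middle - first) smallest elements in a max-heap at the
// front. Afterwards *first is the largest of them, i.e. the wanted element.
template <class It>
void heap_select(It first, It middle, It last) {
    std::make_heap(first, middle);
    for (It i = middle; i < last; ++i) {
        if (*i < *first) {                  // *i belongs among the smallest
            std::pop_heap(first, middle);   // current maximum moves to middle - 1
            std::iter_swap(middle - 1, i);  // replace it with *i
            std::push_heap(first, middle);  // restore the heap property
        }
    }
}

template <class It>
void insertion_sort(It first, It last) {
    for (It i = first; i != last; ++i)
        std::rotate(std::upper_bound(first, i, *i), i, std::next(i));
}

// Rearranges [first, last) so that *nth is what a full sort would put there.
template <class It>
void introselect(It first, It nth, It last) {
    assert(first <= nth && nth <= last);
    if (first == last || nth == last) return;
    int depth_limit = 2 * floor_log2(last - first);
    while (last - first > 3) {
        if (depth_limit-- == 0) {           // too many poor pivots: stop partitioning
            heap_select(first, std::next(nth), last);
            std::iter_swap(first, nth);
            return;
        }
        auto pivot = median_of_three(*first, *(first + (last - first) / 2), *(last - 1));
        // Three-way split: [first, lt) < pivot, [lt, gt) == pivot, [gt, last) > pivot.
        It lt = std::partition(first, last, [&](const auto& x) { return x < pivot; });
        It gt = std::partition(lt, last, [&](const auto& x) { return !(pivot < x); });
        if (nth < lt) last = lt;            // continue in the left part only
        else if (nth >= gt) first = gt;     // continue in the right part only
        else return;                        // *nth equals the pivot: already in place
    }
    insertion_sort(first, last);
}
```

For example, introselect(v.begin(), v.begin() + k, v.end()) on a std::vector<int> leaves the k-th smallest value at v[k], with nothing larger before it and nothing smaller after it.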
Without getting into details about the referenced functions heap_select and unguarded_partition_pivot, we can clearly see that nth_element gives introselect 2 * log2(size) subdivision steps (twice as many as quickselect needs in the best case) until heap_select kicks in and solves the problem for good.
Disclaimer: I don't know how std::nth_element is implemented in any standard library.
If you know how Quicksort works, you can easily modify it to do what is needed for this algorithm. The basic idea of Quicksort is that in each step, you partition the array into two parts such that all elements less than the pivot are in the left sub-array and all elements equal to or greater than the pivot are in the right sub-array. (A modification of Quicksort known as ternary Quicksort creates a third sub-array with all elements equal to the pivot. Then the right sub-array contains only entries strictly greater than the pivot.) Quicksort then proceeds by recursively sorting the left and right sub-arrays.
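To make the partition step concrete, here is a sketch of the three-way ("ternary") variant in the Dutch-national-flag style. The function name partition3 and the use of std::vector<int> with index bounds are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Three-way partition of v[lo, hi) around the value `pivot`.
// Returns {lt, gt} with v[lo, lt) < pivot, v[lt, gt) == pivot, v[gt, hi) > pivot.
std::pair<std::size_t, std::size_t>
partition3(std::vector<int>& v, std::size_t lo, std::size_t hi, int pivot) {
    assert(lo <= hi && hi <= v.size());
    std::size_t lt = lo, i = lo, gt = hi;
    while (i < gt) {
        if (v[i] < pivot)      std::swap(v[lt++], v[i++]);  // grow the "less" block
        else if (pivot < v[i]) std::swap(v[i], v[--gt]);    // grow the "greater" block
        else                   ++i;                         // equal: leave in the middle
    }
    return {lt, gt};
}
```

One linear pass suffices, and the middle block of equal elements never needs to be visited again.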
If you only want to move the n-th element into place, instead of recursing into both sub-arrays, you can tell in every step whether you will need to descend into the left or right sub-array. (You know this because the n-th element in a sorted array has index n, so it becomes a matter of comparing indices.) So, unless your Quicksort suffers worst-case degeneration, you roughly halve the size of the remaining array in each step. (You never look at the other sub-array again.) Therefore, in the ideal case, you are dealing with arrays of the following lengths in each step: N, N/2, N/4, N/8, …, 1.
Each step is linear in the length of the array it is dealing with. (You loop over it once and decide into what sub-array each element should go depending on how it compares to the pivot.)
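The descend-into-one-side loop can be sketched like this, assuming a middle-element pivot and Lomuto partitioning (both chosen only for brevity). The hypothetical work counter adds up the length of every sub-array that gets scanned, which makes the N + N/2 + N/4 + … sum observable.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Returns the value that would sit at index k if v were sorted (v is reordered).
// `work` accumulates the length of every sub-array that gets scanned.
int quickselect(std::vector<int>& v, std::size_t k, std::size_t& work) {
    assert(k < v.size());
    std::size_t lo = 0, hi = v.size() - 1;        // inclusive bounds of the live sub-array
    while (lo < hi) {
        work += hi - lo + 1;                      // one linear pass over v[lo..hi]
        std::swap(v[lo + (hi - lo) / 2], v[hi]);  // middle element becomes the pivot
        std::size_t store = lo;
        for (std::size_t j = lo; j < hi; ++j)     // Lomuto partition around v[hi]
            if (v[j] < v[hi]) std::swap(v[store++], v[j]);
        std::swap(v[store], v[hi]);               // pivot lands at its sorted index
        if (k == store) return v[k];
        if (k < store) hi = store - 1;            // the other side is never looked at again
        else lo = store + 1;
    }
    return v[k];
}
```

On already sorted input the middle element is always the true median, so work stays below 2N; on adversarial input this simple pivot choice can still degrade to quadratic time.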
You can see that after Θ(log N) steps we reach a singleton array and are done. Summing N (1 + 1/2 + 1/4 + …) gives at most 2N. In the average case, since we cannot hope that the pivot will always be exactly the median, the total is larger by a constant factor but still on the order of Θ(N).
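Written out for the ideal case, where every pivot splits its sub-array exactly in half, the total work is a geometric series:

$$
\sum_{i=0}^{\lfloor \log_2 N \rfloor} \frac{N}{2^i} \;<\; N \sum_{i=0}^{\infty} \frac{1}{2^i} \;=\; 2N
$$

With less ideal pivots the ratio between consecutive terms is larger than 1/2, but as long as the expected shrink factor per step stays below some constant c < 1, the sum is still bounded by N/(1 - c), i.e. Θ(N).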
The code from the STL (version 3.3, I think) is this:
Let's simplify that a bit:
What I did here was to remove the double underscores and _Uppercase names, which only protect the code from identifiers the user could legally define as macros. I also removed the last parameter, which only helps with template type deduction, and renamed the iterator type for brevity.
As you can see now, it repeatedly partitions the range and keeps only the part that contains nth, until fewer than four elements remain; those are then simply sorted.
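A loop of that shape might look roughly like this. It is a sketch in the spirit of the description, not the STL 3.3 source verbatim; nth_element_sketch, median, unguarded_partition and insertion_sort are hypothetical names. The partition omits bounds checks, which is only safe because the pivot value is the median of three elements taken from the range itself.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

template <class T>
T median(const T& a, const T& b, const T& c) {
    if (a < b) return b < c ? b : (a < c ? c : a);
    return a < c ? a : (b < c ? c : b);
}

// Hoare-style partition without bounds checks. Both scans always stop because
// the pivot value occurs in the range; being a median of three, it also keeps
// the cut strictly inside the range. Afterwards
// every element of [first, cut) <= pivot <= every element of [cut, last).
template <class It, class T>
It unguarded_partition(It first, It last, const T& pivot) {
    while (true) {
        while (*first < pivot) ++first;
        --last;
        while (pivot < *last) --last;
        if (!(first < last)) return first;
        std::iter_swap(first, last);
        ++first;
    }
}

template <class It>
void insertion_sort(It first, It last) {
    for (It i = first; i != last; ++i)
        std::rotate(std::upper_bound(first, i, *i), i, std::next(i));
}

template <class It>
void nth_element_sketch(It first, It nth, It last) {
    assert(first <= nth && nth <= last);
    while (last - first > 3) {
        // Copy the pivot value: partitioning moves elements around.
        auto pivot = median(*first, *(first + (last - first) / 2), *(last - 1));
        It cut = unguarded_partition(first, last, pivot);
        if (cut <= nth) first = cut;   // nth is in the right part
        else last = cut;               // nth is in the left part
    }
    insertion_sort(first, last);       // at most three elements remain
}
```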
Now, why is that O(n)? Firstly, the final sorting of at most three elements is O(1), since the number of elements is bounded by three. What remains is the repeated partitioning. Partitioning by itself is O(n) in the length of the range. Here, though, each step on average roughly halves the number of elements that need to be touched in the next step, so the total is n + n/2 + n/4 + n/8 + …, which sums to less than 2n. Since O(2n) = O(n), you get linear complexity on average. (With consistently bad pivots the worst case is still quadratic.)