You mentioned that the size of the array was always known to be 23. Moreover, the type used is unsigned short. In this case, you might try to use a sorting network of size 23; since your type is unsigned short, sorting the whole array with a sorting network might be even faster than partially sorting it with std::nth_element. Here is a very straightforward C++14 implementation of a sorting network of size 23 with 118 compare-exchange units, as described in "Using Symmetry and Evolutionary Search to Minimize Sorting Networks":
template<typename RandomIt, typename Compare = std::less<>>
void network_sort23(RandomIt first, Compare compare={})
{
swap_if(first[1u], first[20u], compare);
swap_if(first[2u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[0u], first[7u], compare);
swap_if(first[15u], first[22u], compare);
swap_if(first[4u], first[11u], compare);
swap_if(first[6u], first[12u], compare);
swap_if(first[10u], first[16u], compare);
swap_if(first[8u], first[18u], compare);
swap_if(first[14u], first[19u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[4u], first[14u], compare);
swap_if(first[11u], first[18u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[0u], first[9u], compare);
swap_if(first[13u], first[22u], compare);
swap_if(first[5u], first[15u], compare);
swap_if(first[7u], first[17u], compare);
swap_if(first[1u], first[10u], compare);
swap_if(first[12u], first[21u], compare);
swap_if(first[8u], first[19u], compare);
swap_if(first[17u], first[22u], compare);
swap_if(first[0u], first[5u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[0u], first[1u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[0u], first[3u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[6u], first[15u], compare);
swap_if(first[7u], first[16u], compare);
swap_if(first[8u], first[11u], compare);
swap_if(first[11u], first[14u], compare);
swap_if(first[4u], first[11u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[17u], first[20u], compare);
swap_if(first[2u], first[5u], compare);
swap_if(first[9u], first[12u], compare);
swap_if(first[10u], first[13u], compare);
swap_if(first[15u], first[18u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[4u], first[7u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[3u], first[9u], compare);
swap_if(first[13u], first[19u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[8u], first[14u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[18u], first[21u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[4u], first[9u], compare);
swap_if(first[13u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[5u], first[8u], compare);
swap_if(first[14u], first[17u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[11u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[6u], first[9u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[12u], first[15u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[6u], first[11u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[9u], first[12u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[11u], first[12u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[11u], first[12u], compare);
}
The swap_if utility function compares two parameters x and y with the predicate compare and swaps them if compare(y, x). My example uses a generic swap_if function, but you can use an optimized version if you know that you will be comparing unsigned short values with operator< anyway. You might not need such a function if your compiler recognizes and optimizes the compare-exchange, but unfortunately not all compilers do that: I am using g++ 5.2 with -O3 and I still need the following function for performance. Note that it takes no compare argument, so the calls in network_sort23 have to be adapted to use it:
void swap_if(unsigned short& x, unsigned short& y)
{
unsigned short dx = x;
unsigned short dy = y;
unsigned short tmp = x = std::min(dx, dy);
y ^= dx ^ tmp;
}
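For reference, here is a minimal sketch of what a generic swap_if could look like next to the specialized one, so that the three-argument calls in network_sort23 compile unchanged. It is an assumption about the shape of such a function, not the cpp-sort implementation; network_sort3 is a made-up helper that only exists to show the calls.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <utility>

// Generic compare-exchange: afterwards compare(y, x) no longer holds.
template<typename T, typename Compare>
void swap_if(T& x, T& y, Compare compare)
{
    if (compare(y, x)) {
        using std::swap;
        swap(x, y);
    }
}

// Overload picked for unsigned short compared with std::less<>:
// same min/xor trick as above, with an unused comparator parameter
// so that the network code does not need to change.
inline void swap_if(unsigned short& x, unsigned short& y, std::less<>)
{
    unsigned short dx = x;
    unsigned short dy = y;
    unsigned short tmp = x = std::min(dx, dy);
    y ^= dx ^ tmp;
}

// Hypothetical 3-input network (3 compare-exchange units), for illustration only.
template<typename RandomIt, typename Compare = std::less<>>
void network_sort3(RandomIt first, Compare compare = {})
{
    swap_if(first[0u], first[1u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[0u], first[1u], compare);
    // Postcondition: the three elements are ordered with respect to compare.
    assert(!compare(first[1u], first[0u]) && !compare(first[2u], first[1u]));
}
```

Calling network_sort3 on an unsigned short array with the default comparator goes through the specialized overload; any other type or comparator falls back to the generic template.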
Now, just to make sure that it is indeed faster, I decided to time std::nth_element when required to partially sort only the first 10 elements vs. sorting the whole 23 elements with the sorting network (1,000,000 times with different shuffled arrays). Here is what I get:
std::nth_element 1158ms
network_sort23 487ms
That said, my computer has been running for a bit of time and is a bit slow, but the difference in performance is neat. I believe that this difference will remain the same when I restart my computer. I may try it later and let you know.
Regarding how these times were generated, I used a modified version of this benchmark from my cpp-sort library. The original sorting network and swap_if functions come from there as well, so you can be sure that they have been tested more than once :)
EDIT: here are the results now that I have restarted my computer. The network_sort23 version is still two times faster than std::nth_element:
std::nth_element 369ms
network_sort23 154ms
EDIT²: if all you need is the median, you can trivially delete the compare-exchange units that are not needed to compute the final value that will end up at position 11 (zero-based, i.e. the middle of the 23 sorted elements). The median-finding network of size 23 that follows is derived from a different size-23 sorting network than the previous one, and it yields slightly better results:
swap_if(first[0u], first[1u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[11u], compare);
swap_if(first[0u], first[2u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[8u], first[10u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[5u], first[9u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[3u], first[7u], compare);
swap_if(first[4u], first[8u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[7u], first[10u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[20u], first[22u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[17u], first[21u], compare);
swap_if(first[14u], first[18u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[15u], first[19u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[15u], first[20u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[0u], first[12u], compare);
swap_if(first[2u], first[14u], compare);
swap_if(first[4u], first[16u], compare);
swap_if(first[6u], first[18u], compare);
swap_if(first[8u], first[20u], compare);
swap_if(first[10u], first[22u], compare);
swap_if(first[2u], first[12u], compare);
swap_if(first[10u], first[20u], compare);
swap_if(first[4u], first[12u], compare);
swap_if(first[6u], first[14u], compare);
swap_if(first[8u], first[16u], compare);
swap_if(first[10u], first[18u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[10u], first[12u], compare);
swap_if(first[1u], first[13u], compare);
swap_if(first[3u], first[15u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[7u], first[19u], compare);
swap_if(first[9u], first[21u], compare);
swap_if(first[3u], first[13u], compare);
swap_if(first[11u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[11u], first[19u], compare);
swap_if(first[9u], first[13u], compare);
swap_if(first[11u], first[15u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[11u], first[12u], compare);
There are probably smarter ways to generate median-finding networks, but I don't think that extensive research has been done on the subject. Therefore, it's probably the best method you can use as of now. The result isn't awesome but it still uses 104 compare-exchange units instead of 118.
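The pruning step described above can be sketched mechanically: walk the network backwards, keep a compare-exchange unit only if one of its wires can still influence the wanted output, and mark both of its wires as needed when it is kept. The code below is an illustration under assumptions. All names are made up, and it uses a small bubble-sort network of size 5 instead of the size-23 networks from this answer, simply because that network is easy to check by hand.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Network = std::vector<std::pair<std::size_t, std::size_t>>;

// Bubble-sort network on n wires: correct but far from minimal.
Network bubble_network(std::size_t n)
{
    Network net;
    for (std::size_t pass = 0; pass + 1 < n; ++pass)
        for (std::size_t j = 0; j + 1 < n - pass; ++j)
            net.emplace_back(j, j + 1);
    return net;
}

// Keep only the units that can influence the value ending up on wire `target`.
Network prune_for_output(const Network& net, std::size_t n, std::size_t target)
{
    assert(target < n);
    std::vector<bool> needed(n, false);
    needed[target] = true;
    Network kept;
    for (auto it = net.rbegin(); it != net.rend(); ++it) {
        if (needed[it->first] || needed[it->second]) {
            // Both inputs of a kept unit feed the wanted output.
            needed[it->first] = true;
            needed[it->second] = true;
            kept.push_back(*it);
        }
    }
    std::reverse(kept.begin(), kept.end());
    return kept;
}

// Run a network on a copy of v and read the value on wire `target`.
template<typename T>
T apply_and_read(const Network& net, std::vector<T> v, std::size_t target)
{
    for (const auto& unit : net)
        if (v[unit.second] < v[unit.first])
            std::swap(v[unit.first], v[unit.second]);
    return v[target];
}
```

On the 10-unit bubble network of size 5 with target 2, only the very last unit is dropped; the gain depends heavily on the shape of the network you start from, which is why starting from a different size-23 network changed the result above.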
General idea
Looking at the source code of std::nth_element in MSVC2013, it seems that cases of N <= 32 are solved by insertion sort. It means that the STL implementors realized that doing randomized partitions would be slower for such sizes, despite the better asymptotics.
One of the ways to improve performance is to optimize the sorting algorithm. @Morwenn's answer shows how to sort 23 elements with a sorting network, which is known to be one of the fastest ways to sort small constant-sized arrays.
I'll investigate the other way, which is to calculate the median without a sorting algorithm. In fact, I won't permute the input array at all.
Since we are talking about small arrays, we need to implement some O(N^2) algorithm in the simplest way possible. Ideally, it should have no branches at all, or only well-predictable branches. Also, a simple structure of the algorithm could allow us to vectorize it, further improving its performance.
Algorithm
I have decided to follow the counting method, which was used here to accelerate small linear search. First of all, suppose that all the elements are different. Choose any element of the array: the number of elements less than it defines its position in the sorted array. We can iterate over all elements, and for each of them calculate the number of elements less than it. If the sorted index has the desired value, we can stop the algorithm.
Unfortunately, there may be equal elements in the general case. We'll have to make our algorithm significantly slower and more complex to handle them. Instead of calculating the unique sorted index of an element, we can calculate the interval of possible sorted indices for it. For any element, it is enough to count the number of elements less than it (L) and the number of elements equal to it (E, which includes the element itself); its sorted index then lies in the range [L, L+E). If this interval contains the desired sorted index (i.e. N/2), then we can stop the algorithm and return the considered element.
for (size_t i = 0; i < n; i++) {
auto x = arr[i];
//count number of "less" and "equal" elements
int cntLess = 0, cntEq = 0;
for (size_t j = 0; j < n; j++) {
cntLess += arr[j] < x;
cntEq += arr[j] == x;
}
//fast range checking from here: https://stackoverflow.com/a/17095534/556899
if ((unsigned int)(idx - cntLess) < cntEq)
return x;
}
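To make the loop above easy to try out, here is a minimal self-contained sketch of it wrapped in a function. The name select_by_counting and the use of std::size_t for the counters are my assumptions, not the code that was benchmarked. With unsigned counters, idx - cntLess wraps around to a huge value whenever idx < cntLess, so the single comparison covers both bounds of cntLess <= idx < cntLess + cntEq.

```cpp
#include <cassert>
#include <cstddef>

// Returns the element that would sit at sorted position idx (idx = n / 2 for the median).
unsigned short select_by_counting(const unsigned short* arr, std::size_t n, std::size_t idx)
{
    assert(idx < n);
    for (std::size_t i = 0; i < n; i++) {
        const unsigned short x = arr[i];
        // count the number of "less" and "equal" elements
        std::size_t cntLess = 0, cntEq = 0;
        for (std::size_t j = 0; j < n; j++) {
            cntLess += arr[j] < x;
            cntEq += arr[j] == x;
        }
        // x occupies sorted positions [cntLess, cntLess + cntEq)
        if (idx - cntLess < cntEq)
            return x;
    }
    return arr[0]; // not reached when idx < n
}
```

For the array {5, 3, 5, 1, 5, 3, 9} and idx = 3, the first element already qualifies: three elements are smaller than 5 and three are equal to it, so 5 covers sorted positions 3, 4 and 5.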
Vectorization
The constructed algorithm has only one branch, which is rather predictable: it fails in all cases, except for the only case when we stop the algorithm. The algorithm is easy to vectorize using 8 elements per SSE register. Since we'll have to access some elements after the last one, I'll assume that the input array is padded with max=2^15-1 values up to 24 or 32 elements.
The first way is to vectorize the inner loop by j. In this case the inner loop would be executed only 3 times, but two 8-wide reductions must be done after it is finished. They eat more time than the inner loop itself. As a result, such a vectorization is not very efficient.
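A rough sketch of this first variant, assuming SSE2 and an input that is padded to 24 values with 32767. The function names are made up, and the horizontal sum is only one possible way to write the two reductions, so this is not necessarily the benchmarked code. Because the 16-bit comparison intrinsics are signed, the sketch stores the values as int16_t and assumes they all fit in that range.

```cpp
#include <emmintrin.h>  // SSE2 intrinsics
#include <cassert>
#include <cstdint>

// Sum of the eight 16-bit lanes of v (one "8-wide reduction").
static inline int hsum_epi16(__m128i v)
{
    // Multiply by 1 and add adjacent pairs: eight 16-bit lanes become four 32-bit lanes.
    __m128i s = _mm_madd_epi16(v, _mm_set1_epi16(1));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, _MM_SHUFFLE(1, 0, 3, 2)));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, _MM_SHUFFLE(2, 3, 0, 1)));
    return _mm_cvtsi128_si32(s);
}

// arr must hold 24 readable values: n real ones followed by padding equal to 32767.
std::int16_t select_vectorized_by_j(const std::int16_t* arr, int n, int idx)
{
    assert(n > 0 && n <= 24 && idx >= 0 && idx < n);
    for (int i = 0; i < n; i++) {
        const __m128i x = _mm_set1_epi16(arr[i]);
        __m128i cntLess = _mm_setzero_si128();
        __m128i cntEq = _mm_setzero_si128();
        for (int j = 0; j < 24; j += 8) {  // three iterations
            const __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i*>(arr + j));
            // Comparison masks are 0 or -1 per lane, so subtracting them counts matches.
            cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(v, x));
            cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(v, x));
        }
        // The two reductions that dominate the cost of this variant.
        const int less = hsum_epi16(cntLess);
        const int eq = hsum_epi16(cntEq);
        if (static_cast<unsigned>(idx - less) < static_cast<unsigned>(eq))
            return arr[i];
    }
    return arr[0]; // not reached for valid input
}
```

The padding can only inflate the "equal" count of an element that is itself 32767, and such an element is the maximum, so the range check still gives the right answer.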
The second way is to vectorize the outer loop by i. In this case we process 8 elements x = arr[i] at once. For each pack, we compare it with each element arr[j] in the inner loop. After the inner loop we perform a vectorized range check for the whole pack of 8 elements. If any of them succeeds, we determine the exact number with simple scalar code (it eats little time anyway).
__m128i idxV = _mm_set1_epi16(idx);
for (size_t i = 0; i < n; i += 8) {
//load pack of 8 elements
auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
//count number of less/equal elements for each element in the pack
__m128i cntLess = _mm_setzero_si128();
__m128i cntEq = _mm_setzero_si128();
for (size_t j = 0; j < n; j++) {
__m128i vAll = _mm_set1_epi16(arr[j]);
cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
}
//perform range check for 8 elements at once
__m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
if (int bm = _mm_movemask_epi8(mask)) {
//range check succeeds for one of the elements, find and return it
for (int t = 0; t < 8; t++)
if (bm & (1 << (2*t)))
return arr[i + t];
}
}
Here we see the _mm_set1_epi16 intrinsic in the innermost loop. GCC seems to have some performance issues with it. Anyway, it is eating time on each innermost iteration, which can be reduced if we process 8 elements at once in the innermost loop too. In such a case we can do one vectorized load and 14 unpack instructions to obtain vAll for eight elements. Also, we'll have to write compare-and-count code for eight elements in the loop body, so it acts as 8x unrolling too. The resulting code is the fastest one; a link to it can be found below.
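One way to arrive at exactly 14 unpack instructions is a three-level tree: 2 unpacks on 16-bit lanes, then 4 on 32-bit lanes, then 8 on 64-bit lanes. The sketch below illustrates that idea with SSE2 intrinsics; the function name is made up and the linked code may arrange the unpacks differently.

```cpp
#include <emmintrin.h>  // SSE2 intrinsics
#include <cassert>
#include <cstdint>

// out[k] receives lane k of v replicated into all eight 16-bit lanes,
// i.e. the same value as _mm_set1_epi16(arr[j + k]) after one load of arr + j.
void broadcast_each_lane(__m128i v, __m128i out[8])
{
    // Level 1 (2 unpacks): duplicate every 16-bit lane.
    const __m128i a0 = _mm_unpacklo_epi16(v, v);  // v0 v0 v1 v1 v2 v2 v3 v3
    const __m128i a1 = _mm_unpackhi_epi16(v, v);  // v4 v4 v5 v5 v6 v6 v7 v7
    // Level 2 (4 unpacks): duplicate every 32-bit pair.
    const __m128i b0 = _mm_unpacklo_epi32(a0, a0);  // v0 x4, v1 x4
    const __m128i b1 = _mm_unpackhi_epi32(a0, a0);  // v2 x4, v3 x4
    const __m128i b2 = _mm_unpacklo_epi32(a1, a1);  // v4 x4, v5 x4
    const __m128i b3 = _mm_unpackhi_epi32(a1, a1);  // v6 x4, v7 x4
    // Level 3 (8 unpacks): duplicate every 64-bit half.
    out[0] = _mm_unpacklo_epi64(b0, b0);
    out[1] = _mm_unpackhi_epi64(b0, b0);
    out[2] = _mm_unpacklo_epi64(b1, b1);
    out[3] = _mm_unpackhi_epi64(b1, b1);
    out[4] = _mm_unpacklo_epi64(b2, b2);
    out[5] = _mm_unpackhi_epi64(b2, b2);
    out[6] = _mm_unpacklo_epi64(b3, b3);
    out[7] = _mm_unpackhi_epi64(b3, b3);
    // Debug sanity check: out[0] must be lane 0 of v in every position.
    assert(_mm_movemask_epi8(_mm_cmpeq_epi16(
               out[0], _mm_set1_epi16(static_cast<short>(_mm_extract_epi16(v, 0))))) == 0xFFFF);
}
```

Each out[k] can then be fed to the compare-and-count step in place of vAll, which gives the 8x unrolled loop body mentioned above.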
Comparison
I have benchmarked various solutions on an Ivy Bridge 3.4 GHz processor. Below you can see the total computation time in seconds for 2^23 ~= 8M calls (the first number). The second number is a checksum of the results.
Results on MSVC 2013 x64 (/O2):
memcpy only: 0.020
std::nth_element: 2.110 (1186136064)
network sort: 0.630 (1186136064) //solution by @Morwenn (I had to change swap_if)
trivial count: 2.266 (1186136064) //scalar algorithm (presented above)
vectorized count: 0.692 (1186136064) //vectorization by j
vectorized count (T): 0.602 (1186136064) //vectorization by i (presented above)
vectorized count (both): 0.450 (1186136064) //vectorization by i and j
Results on MinGW GCC 4.8.3 x64 (-O3 -msse4):
memcpy only: 0.016
std::nth_element: 1.981 (1095237632)
network sort: 0.531 (1095237632) //original swap_if used
trivial count: 1.482 (1095237632)
vectorized count: 0.655 (1095237632)
vectorized count (T): 2.668 (1095237632) //GCC generates some crap
vectorized count (both): 0.374 (1095237632)
As you see, the proposed vectorized algorithm for 23 16-bit elements is a bit faster than the sorting-based approach: it takes roughly 30% less time with both compilers (0.450 vs 0.630 and 0.374 vs 0.531). BTW, on an older CPU I see only a 5% time difference.
If you can guarantee that all elements are different, you can simplify the algorithm, making it even faster.
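As an illustration of that simplification, here is a scalar sketch under the assumption that all elements are distinct: the "equal" counter disappears and the range check turns into a plain equality test. The function name is made up, and the same change would apply to the vectorized versions.

```cpp
#include <cassert>
#include <cstddef>

// Valid only if all elements of arr are pairwise different.
unsigned short select_distinct(const unsigned short* arr, std::size_t n, std::size_t idx)
{
    assert(idx < n);
    for (std::size_t i = 0; i < n; i++) {
        const unsigned short x = arr[i];
        std::size_t cntLess = 0;
        for (std::size_t j = 0; j < n; j++)
            cntLess += arr[j] < x;
        // With distinct elements, cntLess is exactly the sorted index of x.
        if (cntLess == idx)
            return x;
    }
    return arr[0]; // not reached when the elements are distinct and idx < n
}
```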
The full code of all algorithms is available here, including all the testing code.