Finding the most frequently occurring element in a vector of 8-bit integers with SSE

Posted 2020-06-16 03:29

Does anyone have any thoughts on how to calculate the mode (statistic) of a vector of 8-bit integers in SSE4.x? To clarify, this would be 16x8-bit values in a 128-bit register.

I want the result as a vector mask that selects the mode-valued elements, i.e. the result of _mm_cmpeq_epi8(v, set1(mode(v))), as well as the scalar value.
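For concreteness, a plain scalar reference of the interface I have in mind (illustration only; the helper name mode_ref and its tie-breaking rule are my own sketch, not requirements):

#include <emmintrin.h>

/* Scalar reference: returns the mode of the 16 bytes in v and writes the
   element-selection mask to *mask_out. Ties keep the value that reached
   the maximum count first. */
static unsigned char mode_ref(__m128i v, __m128i *mask_out)
{
    unsigned char b[16];
    _mm_storeu_si128((__m128i *)b, v);

    int count[256] = {0};
    unsigned char mode = b[0];
    for (int i = 0; i < 16; i++)
        if (++count[b[i]] > count[mode])
            mode = b[i];

    *mask_out = _mm_cmpeq_epi8(v, _mm_set1_epi8((char)mode));
    return mode;
}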


Providing some additional context: while the above problem is an interesting one to solve in its own right, I have been through most of the linear-time algorithms I can think of, and that complexity class wipes out any gains I could get from computing this number.

I hope to engage you all in searching for some deep magic here. It's possible that an approximation may be necessary to break this bound, e.g. "select a frequently occurring element" (N.B. not necessarily the most frequent one), which would still be of merit. A probabilistic answer would be usable, too.

SSE and x86 have some very interesting semantics. It may be worth exploring a superoptimization pass.

3 Answers
可以哭但决不认输i
Answer 2 · 2020-06-16 04:02

Sort the data in the register. Insertion sort can be done in 16 (15) steps by initializing the register to "infinity" (all bytes 0xFF, representing a monotonically decreasing array) and then inserting each new element into all possible places in parallel:

// Initialize: keep the lowest byte of the input, force all higher bytes to 0xFF
// e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF 78
__m128i sorted = _mm_or_si128(_mm_loadu_si128((__m128i const *)my_array),
                              _mm_set_epi8(-1,-1,-1,-1,-1,-1,-1,-1,
                                           -1,-1,-1,-1,-1,-1,-1, 0));

for (int i = 1; i < 16; ++i)
{
    // Trying to insert e.g. A0, we must shift everything one byte to the left
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF 78 00
    __m128i shifted = _mm_bslli_si128(sorted, 1);

    // Taking the MAX of shifted and 'A0 on all places'
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF A0 A0
    shifted = _mm_max_epu8(shifted, _mm_set1_epi8(my_array[i]));

    // and minimum of the shifted + original --
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF A0 78
    sorted = _mm_min_epu8(sorted, shifted);
}

Then compute the mask for vec[n+1] == vec[n], move the mask to a GPR, and use it to index a 32768-entry LUT that gives the location of the mode.
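A sketch of that step, under assumptions: the equality mask has 15 meaningful bits, and longest_run_lut (the name and its build loop are my own illustration) maps each pattern to the start of the longest run of equal neighbours.

#include <emmintrin.h>

// Illustration only: 'longest_run_lut' maps a 15-bit neighbour-equality
// pattern to the starting index of the longest run in the sorted vector.
static unsigned char longest_run_lut[1 << 15];

static void init_longest_run_lut(void)
{
    for (int bits = 0; bits < (1 << 15); ++bits) {
        int best_pos = 0, best_len = 1, pos = 0, len = 1;
        for (int n = 0; n < 15; ++n) {
            if (bits & (1 << n)) {          // sorted[n] == sorted[n+1]
                if (++len > best_len) { best_len = len; best_pos = pos; }
            } else {                        // run ends, next one starts at n+1
                len = 1; pos = n + 1;
            }
        }
        longest_run_lut[bits] = (unsigned char)best_pos;
    }
}

static unsigned char mode_of_sorted(__m128i sorted)
{
    __m128i shifted = _mm_srli_si128(sorted, 1);       // vec[n+1] at position n
    __m128i eq      = _mm_cmpeq_epi8(sorted, shifted); // vec[n] == vec[n+1]
    int     bits    = _mm_movemask_epi8(eq) & 0x7FFF;  // 15 meaningful bits

    unsigned char b[16];
    _mm_storeu_si128((__m128i *)b, sorted);
    return b[longest_run_lut[bits]];                   // the mode value
}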

In a real use case one probably wants to sort more than just one vector, i.e. sort 16 16-entry vectors at once:

__m128i input[16];      // not 1, but 16 vectors
transpose16x16(input);  // inplace vector transpose
sort(input);            // a 60-stage sorting network exists for 16 inputs
// linear search -- result in 'mode'
__m128i mode = input[0];
__m128i previous = mode;
__m128i count = _mm_setzero_si128();
__m128i max_count = _mm_setzero_si128();
for (int i = 1; i < 16; i++)
{
   __m128i current = input[i];
   // histogram count is off by one
   // if (current == previous) count++;
   //    else count = 0;
   // if (count > max_count)
   //    mode = current, max_count = count
   __m128i eq = _mm_cmpeq_epi8(previous, current);
   count = _mm_and_si128(_mm_sub_epi8(count, eq), eq);
   __m128i max_so_far = _mm_cmplt_epi8(max_count, count);
   mode = _mm_blendv_epi8(mode, current, max_so_far);
   max_count = _mm_max_epi8(max_count, count);
   previous = current;
}

The inner loop has an amortized cost of 7-8 instructions per result. Sorting typically takes 2 instructions per stage, and 16 results need 60 stages or 120 instructions -- about 7.5 instructions per result. (This still leaves the transpose as an exercise -- see the sketch below -- but it should be vastly cheaper than the sorting.)

So, this should be in the ballpark of 24 instructions per 8-bit result.
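For completeness, here is one way the transpose16x16 exercise could be sketched (my addition, not from the answer above): interleaving row i with row i+8 rotates the combined 8-bit row/column bit-index left by one, so four such rounds swap the row and column nibbles, which is exactly the transpose.

#include <emmintrin.h>

// Sketch of a 16x16 byte transpose: four rounds of byte interleaves.
static void transpose16x16(__m128i m[16])
{
    for (int round = 0; round < 4; ++round) {
        __m128i t[16];
        for (int i = 0; i < 8; ++i) {
            t[2*i]     = _mm_unpacklo_epi8(m[i], m[i + 8]);
            t[2*i + 1] = _mm_unpackhi_epi8(m[i], m[i + 8]);
        }
        for (int i = 0; i < 16; ++i)
            m[i] = t[i];
    }
}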

Deceive 欺骗
Answer 3 · 2020-06-16 04:04

A relatively simple brute-force SSEx approach is probably suitable here; see the code below. The idea is to byte-rotate the input vector v by 1 to 15 positions and compare each rotated vector with the original v for equality. To shorten the dependency chain and increase the instruction-level parallelism, two counters, sum1 and sum2, are used for the (vertical) sums of these equality results, since some architectures benefit from this.

Equal elements are counted as -1, so sum = sum1 + sum2 contains the total count, with values between -1 and -16. min_brc contains the horizontal minimum of sum broadcast to all elements, and mask = _mm_cmpeq_epi8(sum,min_brc) is the mask for the mode-valued elements requested as an intermediate result by the OP. The next few lines of the code extract the actual mode.

This solution is certainly faster than a scalar solution. Note that with AVX2 the upper 128-bit lane can be used to speed up the computation further; a sketch of that idea follows at the end of this answer.

It takes 20 cycles (throughput) to compute only the mask for the mode-valued elements. With the actual mode broadcast across the SSE register, it takes about 21.4 cycles.

Note the behaviour in the next example: [1, 1, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] returns mask=[-1,-1,-1,-1,0,0,...,0] and the mode value is 1, although 1 occurs as often as 3.

The code below is tested, but not thoroughly:

#include <stdio.h>
#include <x86intrin.h>
/*  gcc -O3 -Wall -m64 -march=nehalem mode_uint8.c   */
int print_vec_char(__m128i x);

__m128i mode_statistic(__m128i v){
    __m128i  sum2         = _mm_set1_epi8(-1);                    /* Each integer occurs at least one time */
    __m128i  v_rot1       = _mm_alignr_epi8(v,v,1);
    __m128i  v_rot2       = _mm_alignr_epi8(v,v,2);
    __m128i  sum1         =                   _mm_cmpeq_epi8(v,v_rot1);
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot2));

    __m128i  v_rot3       = _mm_alignr_epi8(v,v,3);
    __m128i  v_rot4       = _mm_alignr_epi8(v,v,4);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot3));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot4));

    __m128i  v_rot5       = _mm_alignr_epi8(v,v,5);
    __m128i  v_rot6       = _mm_alignr_epi8(v,v,6);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot5));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot6));

    __m128i  v_rot7       = _mm_alignr_epi8(v,v,7);
    __m128i  v_rot8       = _mm_alignr_epi8(v,v,8);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot7));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot8));

    __m128i  v_rot9       = _mm_alignr_epi8(v,v,9);
    __m128i  v_rot10      = _mm_alignr_epi8(v,v,10);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot9));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot10));

    __m128i  v_rot11      = _mm_alignr_epi8(v,v,11);
    __m128i  v_rot12      = _mm_alignr_epi8(v,v,12);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot11));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot12));

    __m128i  v_rot13      = _mm_alignr_epi8(v,v,13);
    __m128i  v_rot14      = _mm_alignr_epi8(v,v,14);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot13));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot14));

    __m128i  v_rot15      = _mm_alignr_epi8(v,v,15);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot15));
    __m128i  sum          = _mm_add_epi8(sum1,sum2);                      /* Sum contains values such as -1, -2 ,...,-16                                    */
                                                                          /* The next three instructions compute the horizontal minimum of sum */
    __m128i  sum_shft     = _mm_srli_epi16(sum,8);                        /* Shift right 8 bits, while shifting in zeros                                    */
    __m128i  min1         = _mm_min_epu8(sum,sum_shft);                   /* sum and sum_shft are considered as unsigned integers. sum_shft is zero at the odd positions and so is min1 */
    __m128i  min2         = _mm_minpos_epu16(min1);                       /* Byte 0 within min2 contains the horizontal minimum of sum                      */
    __m128i  min_brc      = _mm_shuffle_epi8(min2,_mm_setzero_si128());   /* Broadcast horizontal minimum                                                   */

    __m128i  mask         = _mm_cmpeq_epi8(sum,min_brc);                  /* Mask = -1 at the byte positions where the value of v is equal to the mode of v */

    /* comment next 4 lines out if there is no need to broadcast the mode value */
    int      bitmask      = _mm_movemask_epi8(mask);
    int      indx         = __builtin_ctz(bitmask);                            /* Index of mode                            */
    __m128i  v_indx       = _mm_set1_epi8(indx);                               /* Broadcast indx                           */
    __m128i  answer       = _mm_shuffle_epi8(v,v_indx);                        /* Broadcast mode to each element of answer */ 

/* Uncomment lines below to print intermediate results, to see how it works. */
//    printf("sum         = ");print_vec_char (sum           );
//    printf("sum_shft    = ");print_vec_char (sum_shft      );
//    printf("min1        = ");print_vec_char (min1          );
//    printf("min2        = ");print_vec_char (min2          );
//    printf("min_brc     = ");print_vec_char (min_brc       );
//    printf("mask        = ");print_vec_char (mask          );
//    printf("v_indx      = ");print_vec_char (v_indx        );
//    printf("answer      = ");print_vec_char (answer        );

             return answer;   /* or return mask, or return both ....    :) */
}


int main() {
    /* To test throughput set throughput_test to 1, otherwise 0    */
    /* Use e.g. perf stat -d ./a.out to test throughput           */
    #define throughput_test 0

    /* Different test vectors  */
    int i;
    char   x1[16] = {5, 2, 2, 7, 21, 4, 7, 7, 3, 9, 2, 5, 4, 3, 5, 5};
    char   x2[16] = {5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5};
    char   x3[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    char   x4[16] = {1, 2, 3, 2, 1, 6, 7, 8, 2, 2, 2, 3, 3, 2, 15, 16};
    char   x5[16] = {1, 1, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

    printf("\n15...0      =   15  14  13  12    11  10  9   8     7   6   5   4     3   2   1   0\n\n");

    __m128i  x_vec  = _mm_loadu_si128((__m128i*)x1);

    printf("x_vec       = ");print_vec_char(x_vec        );

    __m128i  y      = mode_statistic (x_vec);

    printf("answer      = ");print_vec_char(y         );


    #if throughput_test == 1
    __m128i  x_vec1  = _mm_loadu_si128((__m128i*)x1);
    __m128i  x_vec2  = _mm_loadu_si128((__m128i*)x2);
    __m128i  x_vec3  = _mm_loadu_si128((__m128i*)x3);
    __m128i  x_vec4  = _mm_loadu_si128((__m128i*)x4);
    __m128i  x_vec5  = _mm_loadu_si128((__m128i*)x5);
    __m128i  y1, y2, y3, y4, y5;
    __asm__ __volatile__ ( "vzeroupper" : : : );   /* Remove this line on non-AVX processors */
    for (i=0;i<100000000;i++){
        y1       = mode_statistic (x_vec1);
        y2       = mode_statistic (x_vec2);
        y3       = mode_statistic (x_vec3);
        y4       = mode_statistic (x_vec4);
        y5       = mode_statistic (x_vec5);
        x_vec1   = mode_statistic (y1    );
        x_vec2   = mode_statistic (y2    );
        x_vec3   = mode_statistic (y3    );
        x_vec4   = mode_statistic (y4    );
        x_vec5   = mode_statistic (y5    );
     }
    printf("mask mode   = ");print_vec_char(y1           );
    printf("mask mode   = ");print_vec_char(y2           );
    printf("mask mode   = ");print_vec_char(y3           );
    printf("mask mode   = ");print_vec_char(y4           );
    printf("mask mode   = ");print_vec_char(y5           );
    #endif

    return 0;
}



int print_vec_char(__m128i x){
    char v[16];
    _mm_storeu_si128((__m128i *)v,x);
    printf("%3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi\n",
           v[15],v[14],v[13],v[12],v[11],v[10],v[9],v[8],v[7],v[6],v[5],v[4],v[3],v[2],v[1],v[0]);
    return 0;
}

Output:

15...0      =   15  14  13  12    11  10  9   8     7   6   5   4     3   2   1   0

x_vec       =   5   5   3   4 |   5   2   9   3 |   7   7   4  21 |   7   2   2   5
sum         =  -4  -4  -2  -2 |  -4  -3  -1  -2 |  -3  -3  -2  -1 |  -3  -3  -3  -4
min_brc     =  -4  -4  -4  -4 |  -4  -4  -4  -4 |  -4  -4  -4  -4 |  -4  -4  -4  -4
mask        =  -1  -1   0   0 |  -1   0   0   0 |   0   0   0   0 |   0   0   0  -1
answer      =   5   5   5   5 |   5   5   5   5 |   5   5   5   5 |   5   5   5   5

The horizontal minimum is computed with Evgeny Kluev's method.
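As a hedged sketch of the AVX2 remark above (my addition, untested): vpalignr operates within each 128-bit lane, so the rotate-and-compare counting can handle two independent 16-byte inputs at once. The _mm_minpos_epu16 reduction has no 256-bit counterpart, so each lane would still be finished separately.

#include <immintrin.h>

/* Two independent 16-byte mode problems, one per 128-bit lane.
   Returns per-element counts in the range -1 .. -16, matching 'sum' above. */
static __m256i mode_counts_x2(__m256i v)
{
    __m256i sum = _mm256_set1_epi8(-1);   /* every value occurs at least once */
    /* _mm256_alignr_epi8 needs an immediate shift count, hence the macro unroll */
    #define STEP(i) \
        sum = _mm256_add_epi8(sum, _mm256_cmpeq_epi8(v, _mm256_alignr_epi8(v, v, i)))
    STEP(1);  STEP(2);  STEP(3);  STEP(4);  STEP(5);
    STEP(6);  STEP(7);  STEP(8);  STEP(9);  STEP(10);
    STEP(11); STEP(12); STEP(13); STEP(14); STEP(15);
    #undef STEP
    return sum;
}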

地球回转人心会变
Answer 4 · 2020-06-16 04:12

For performance comparison, here is scalar code: the main loop is not vectorized, but the table clearing and the tmp store are. (168 cycles per f() call on an FX-8150: 22M calls complete in 1.0002 seconds at 3.7 GHz.)

#include <x86intrin.h>

alignas(16) unsigned char tmp[16]; // extracted values land here (a single _mm_store_si128, which needs 16-byte alignment)
unsigned char table[256];          // counter table containing zeroes
char f(__m128i values)
{
    _mm_store_si128((__m128i *)tmp,values);
    int maxOccurence=0;
    int currentValue=0;
    for(int i=0;i<16;i++)
    {
        unsigned char ind=tmp[i];
        unsigned char t=table[ind];
        t++;
        if(t>maxOccurence)
        {
             maxOccurence=t;
             currentValue=ind;
        }
        table[ind]=t;
    }
    for(int i=0;i<256;i++)
        table[i]=0;
    return currentValue;
}

g++ 6.3 output:

f:                                      # @f
        movaps  %xmm0, tmp(%rip)
        movaps  %xmm0, -24(%rsp)
        xorl    %r8d, %r8d
        movq    $-15, %rdx
        movb    -24(%rsp), %sil
        xorl    %eax, %eax
        jmp     .LBB0_1
.LBB0_2:                                # %._crit_edge
        cmpl    %r8d, %esi
        cmovgel %esi, %r8d
        movb    tmp+16(%rdx), %sil
        incq    %rdx
.LBB0_1:                                # =>This Inner Loop Header: Depth=1
        movzbl  %sil, %edi
        movb    table(%rdi), %cl
        incb    %cl
        movzbl  %cl, %esi
        cmpl    %r8d, %esi
        cmovgl  %edi, %eax
        movb    %sil, table(%rdi)
        testq   %rdx, %rdx
        jne     .LBB0_2
        xorps   %xmm0, %xmm0
        movaps  %xmm0, table+240(%rip)
        movaps  %xmm0, table+224(%rip)
        movaps  %xmm0, table+208(%rip)
        movaps  %xmm0, table+192(%rip)
        movaps  %xmm0, table+176(%rip)
        movaps  %xmm0, table+160(%rip)
        movaps  %xmm0, table+144(%rip)
        movaps  %xmm0, table+128(%rip)
        movaps  %xmm0, table+112(%rip)
        movaps  %xmm0, table+96(%rip)
        movaps  %xmm0, table+80(%rip)
        movaps  %xmm0, table+64(%rip)
        movaps  %xmm0, table+48(%rip)
        movaps  %xmm0, table+32(%rip)
        movaps  %xmm0, table+16(%rip)
        movaps  %xmm0, table(%rip)
        movsbl  %al, %eax
        ret