Extract set bytes position from SIMD vector

Posted 2019-05-06 21:25

Question:

I run a bunch of computations using SIMD instructions. These instructions return a 16-byte vector as the result, named compare, with each byte being 0x00 or 0xff:

             0    1    2    3    4    5    6    7       14   15
compare : 0x00 0x00 0x00 0x00 0xff 0x00 0x00 0x00 ... 0x00 0xff

A byte set to 0xff means I need to run the function do_operation(i), with i being the position of that byte.

For instance, the compare vector above means I need to run this sequence of operations:

do_operation(4);
do_operation(15);
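To pin down the expected result, here is a plain scalar version of the same task in C. It is a sketch with a hypothetical function name, and it stores the positions in an array instead of calling do_operation so the result can be checked; any SIMD or BMI2 variant should produce the same positions in the same order.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar reference (hypothetical helper): record the index of every byte
   equal to 0xff and return how many were found. Bytes are assumed to be
   either 0x00 or 0xff. positions must have room for 16 entries. */
static int set_byte_positions_scalar(const uint8_t compare[16], int positions[16])
{
    int count = 0;
    for (int i = 0; i < 16; ++i) {
        assert(compare[i] == 0x00 || compare[i] == 0xff);
        if (compare[i] == 0xff)
            positions[count++] = i;   /* positions come out in ascending order */
    }
    return count;
}
```

For the vector above (bytes 4 and 15 set) it yields the count 2 and the positions 4 and 15.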

Here is the fastest solution I have come up with so far:

for(...) {
        //
        // SIMD computations
        //
        __m128i compare = ... // Result of SIMD computations

        // Extract high and low quadwords for compare vector
        std::uint64_t cmp_low = (_mm_cvtsi128_si64(compare));
        std::uint64_t cmp_high = (_mm_extract_epi64(compare, 1));

        //  Process low quadword 
        if (cmp_low) {
            const std::uint64_t low_possible_positions = 0x0706050403020100;
            const std::uint64_t match_positions = _pext_u64(
                    low_possible_positions, cmp_low);
            const int match_count = _popcnt64(cmp_low) / 8;
            const std::uint8_t* match_pos_array =
                    reinterpret_cast<const std::uint8_t*>(&match_positions);

            for (int i = 0; i < match_count; ++i) {
                do_operation(match_pos_array[i]);  // i-th matching byte position
            }
        }

        // Process high quadword (similarly)
        if (cmp_high) { 

            const std::uint64_t high_possible_positions = 0x0f0e0d0c0b0a0908;
            const std::uint64_t match_positions = _pext_u64(
                    high_possible_positions, cmp_high);
            const int match_count = _popcnt64(cmp_high) / 8;
            const std::uint8_t* match_pos_array =
                    reinterpret_cast<const std::uint8_t*>(&match_positions);

            for(int i = 0; i < match_count; ++i) {
                do_operation(match_pos_array[i]);  // i-th matching byte position
            }
        }
}

I start by extracting the first and second 64-bit integers of the 128-bit vector (cmp_low and cmp_high). Then I use popcount to compute the number of bytes set to 0xff (the number of bits set to 1, divided by 8). Finally, I use pext to pack the positions of the matching bytes into the low bytes, dropping the zeros, like this:

0x0706050403020100
0x000000ff00ff0000
        |
      PEXT
        |
0x0000000000000402
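The PEXT step can be reproduced without BMI2 hardware to check the diagram. The sketch below uses a hypothetical bit-by-bit helper, pext_u64_portable, as a slow stand-in for _pext_u64 (for illustration only, not for speed), and a hypothetical low_quadword_positions that mirrors the question's low-quadword branch. It decodes the packed result with shifts, which, unlike the reinterpret_cast in the question, does not depend on little-endian byte order.

```c
#include <assert.h>
#include <stdint.h>

/* Portable, slow stand-in for _pext_u64: gather the bits of src selected
   by mask and pack them, in order, into the low bits of the result. */
static uint64_t pext_u64_portable(uint64_t src, uint64_t mask)
{
    uint64_t result = 0;
    int out_bit = 0;
    for (int bit = 0; bit < 64; ++bit) {
        if (mask & (UINT64_C(1) << bit)) {
            if (src & (UINT64_C(1) << bit))
                result |= UINT64_C(1) << out_bit;
            ++out_bit;
        }
    }
    return result;
}

/* Decode the low quadword of the compare vector into byte positions, the
   same way the question does, and return the number of matches. */
static int low_quadword_positions(uint64_t cmp_low, int positions[8])
{
    const uint64_t low_possible_positions = UINT64_C(0x0706050403020100);
    uint64_t packed = pext_u64_portable(low_possible_positions, cmp_low);

    /* popcount / 8: each 0xff byte contributes exactly 8 set bits */
    int bits = 0;
    for (uint64_t m = cmp_low; m; m &= m - 1)
        ++bits;
    int count = bits / 8;
    assert(count <= 8);

    /* byte i of packed holds the i-th matching position */
    for (int i = 0; i < count; ++i)
        positions[i] = (int)((packed >> (8 * i)) & 0xff);
    return count;
}
```

With the mask from the diagram, pext_u64_portable returns 0x0402, and low_quadword_positions reports the positions 2 and 4.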

I would like to find a faster way to extract the positions of the bytes set to 0xff in the compare vector. More precisely, there are very often only 0, 1 or 2 bytes set to 0xff in the compare vector, and I would like to use this information to avoid some branches.

Answer 1:

Here's a quick outline of how you could reduce the number of tests:

  • First, use a function to project the lsb or msb of each byte of your 128-bit integer into a 16-bit value (for instance, SSE2 on x86 CPUs has the pmovmskb instruction for that, available in Intel, MS and GCC compilers through the _mm_movemask_epi8 intrinsic; GCC also has the builtin __builtin_ia32_pmovmskb128);

  • Then split that value into 4 nibbles;

  • Define a function to handle each possible value of a nibble (from 0 to 15) and put these functions in an array;

  • Finally, call the function indexed by each nibble (with an extra parameter indicating which nibble of the 16-bit value it is).
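A possible C sketch of this outline follows. It assumes the 16-bit mask has already been produced (for example by _mm_movemask_epi8), and, as a plainly named variant, it replaces the array of 16 hand-written functions with a table of precomputed bit offsets per nibble value; the dispatch on each nibble is the same idea. All names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* For each nibble value 0..15: how many bits are set, and at which offsets. */
typedef struct {
    uint8_t count;
    uint8_t offset[4];
} nibble_entry;

static nibble_entry nibble_table[16];

/* Fill the table once at startup. */
static void init_nibble_table(void)
{
    for (int v = 0; v < 16; ++v) {
        nibble_table[v].count = 0;
        for (int bit = 0; bit < 4; ++bit)
            if (v & (1 << bit))
                nibble_table[v].offset[nibble_table[v].count++] = (uint8_t)bit;
    }
}

/* Expand a 16-bit byte mask (bit i set <=> byte i was 0xff) into byte
   positions, one nibble at a time; 4 * n is the nibble's base position. */
static int positions_by_nibble(uint16_t mask, int positions[16])
{
    int count = 0;
    for (int n = 0; n < 4; ++n) {
        const nibble_entry *e = &nibble_table[(mask >> (4 * n)) & 0xF];
        for (int j = 0; j < e->count; ++j)
            positions[count++] = 4 * n + e->offset[j];
    }
    assert(count <= 16);
    return count;
}
```

Call init_nibble_table() once before the main loop; each mask then costs four table lookups plus one inner iteration per set byte.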



Answer 2:

Since in your case very often only 0, 1 or 2 bytes are set to 0xff in the compare vector, a short while-loop on the bitmask might be more efficient than a solution based on the pext instruction. See also my answer on a similar question.


/*
gcc -O3 -Wall -m64 -mavx2 -march=broadwell esbsimd.c
*/

#include <stdio.h>
#include <immintrin.h>

int do_operation(int i){           /* some arbitrary do_operation() */
   printf("i = %d\n",i);
   return 0;
}

int main(){

   __m128i compare = _mm_set_epi8(0xFF,0,0,0,  0,0,0,0, 0,0,0,0xFF, 0,0,0,0);   /* Take some random value for compare (bytes 4 and 15 set) */
   int           k = _mm_movemask_epi8(compare);

   while (k){
      int i=_tzcnt_u32(k);                                /* Count the number of trailing zero bits in k.  BMI1 instruction set, Haswell or newer. */
      do_operation(i);
      k=_blsr_u32(k);                                     /* Clear the lowest set bit in k.                                                        */
   }
   return 0;
}

/* 
Output:

i = 4
i = 15

*/
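If BMI1 is not available, or the code must also build on non-x86 targets, the same loop can be written with the GCC/Clang builtin __builtin_ctz in place of _tzcnt_u32, and k &= k - 1 in place of _blsr_u32; for nonzero k both compute the same values. A sketch, again collecting positions into an array (hypothetical function name) instead of calling do_operation:

```c
#include <assert.h>

/* Bit-scan loop over a movemask result: bit i of k set <=> byte i was 0xff.
   __builtin_ctz is a GCC/Clang builtin and is undefined for 0, which the
   loop condition rules out. positions must have room for 16 entries. */
static int positions_by_bitscan(unsigned k, int positions[16])
{
    assert(k <= 0xFFFFu); /* a 128-bit movemask yields at most 16 bits */
    int count = 0;
    while (k) {
        positions[count++] = __builtin_ctz(k); /* index of the lowest set bit */
        k &= k - 1;                            /* clear the lowest set bit   */
    }
    return count;
}
```

The loop runs once per set byte, so with the usual 0, 1 or 2 matches it does at most two iterations.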