If an SSE/AVX register's value is such that all its bytes are either 0 or 1, is there any way to efficiently get the indices of all nonzero elements?
For example, if xmm value is
| r0=0 | r1=1 | r2=0 | r3=1 | r4=0 | r5=1 | r6=0 |...| r14=0 | r15=1 |
the result should be something like (1, 3, 5, ... , 15). The result should be placed in another __m128i variable or a char[16] array.
If it helps, we can assume that the register's value is such that all bytes are either 0 or some constant nonzero value (not necessarily 1).
I am wondering whether there is an instruction for that, or preferably a C/C++ intrinsic, in any SSE or AVX instruction set.
EDIT 1:
It was correctly observed by @zx485 that the original question was not clear enough. I was looking for a "consecutive" solution.
The example 0 1 0 1 0 1 0 1... above should result in one of the following:
- If we assume that indices start from 1, then 0 would be a termination byte and the result might be
002 004 006 008 010 012 014 016 000 000 000 000 000 000 000 000
- If we assume that a negative byte is a termination byte, the result might be
001 003 005 007 009 011 013 015 0xFF 0xFF 0xFF 0xFF 0xFF 0xFF 0xFF 0xFF
- Anything that gives us consecutive bytes which we can interpret as the indices of the nonzero elements in the original value
EDIT 2:
Indeed, as @harold and @Peter Cordes suggest in the comments to the original post, one possible solution is to create a mask first (e.g. with pmovmskb) and check the nonzero indices there. But that will lead to a loop.
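For reference, the mask-then-loop idea could look roughly like the sketch below. It uses only SSE2. The function name nonzero_indices_loop is made up for illustration, and __builtin_ctz is a GCC/Clang builtin, so other compilers would need a substitute.

```c
#include <emmintrin.h>  /* SSE2 */
#include <stdint.h>

/* Hypothetical helper: writes the indices of the nonzero bytes of x to out[]
   (packed from out[0]) and returns how many indices were written. */
static int nonzero_indices_loop(__m128i x, uint8_t out[16])
{
    /* Bytes equal to zero become 0xFF; pmovmskb collects the top bit of each byte. */
    unsigned zero_mask = (unsigned)_mm_movemask_epi8(
        _mm_cmpeq_epi8(x, _mm_setzero_si128()));
    unsigned mask = ~zero_mask & 0xFFFFu;   /* bit i set <=> byte i is nonzero */
    int n = 0;

    while (mask != 0) {
        out[n++] = (uint8_t)__builtin_ctz(mask);  /* index of the lowest set bit */
        mask &= mask - 1;                         /* clear that bit */
    }
    return n;
}
```

The loop runs once per nonzero byte, which is exactly the data-dependent iteration I would like to avoid.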
Your question was unclear as to whether you want the result array to be "compressed". By "compressed" I mean that the result should be consecutive. So, for example, for 0 1 0 1 0 1 0 1... there are two possibilities:
Non-consecutive:
XMM0: 000 001 000 003 000 005 000 007 000 009 000 011 000 013 000 015
Consecutive:
XMM0: 001 003 005 007 009 011 013 015 000 000 000 000 000 000 000 000
One problem with the consecutive approach is: how do you decide whether a 0 is index 0 or a termination value?
I'm offering a simple solution to the first, non-consecutive approach, which should be quite fast:
.data
ddqZeroToFifteen db 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
ddqTestValue db 0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1
.code
movdqa xmm0, xmmword ptr [ddqTestValue]
pxor xmm1, xmm1 ; zero XMM1
pcmpeqb xmm0, xmm1 ; bytes equal to zero become 0xFF (-1), all others 0
pandn xmm0, xmmword ptr [ddqZeroToFifteen] ; (NOT xmm0) AND indices: keep the index where the byte was nonzero
Just for the sake of completeness: the second, consecutive approach is not covered in this answer.
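If you prefer intrinsics, the same three-instruction sequence can be sketched in C as below. This is only a transcription of the assembly above, assuming SSE2; the function name nonzero_indices_sparse is made up for illustration.

```c
#include <emmintrin.h>  /* SSE2 */

/* Hypothetical name; mirrors the pxor / pcmpeqb / pandn sequence. */
static __m128i nonzero_indices_sparse(__m128i x)
{
    /* Byte i holds the value i (same role as ddqZeroToFifteen). */
    const __m128i iota = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7,
                                       8, 9, 10, 11, 12, 13, 14, 15);
    /* 0xFF where the input byte is zero, 0x00 elsewhere. */
    __m128i is_zero = _mm_cmpeq_epi8(x, _mm_setzero_si128());
    /* (~is_zero) & iota: the index survives only where the byte was nonzero. */
    return _mm_andnot_si128(is_zero, iota);
}
```

As with the assembly, a nonzero byte at position 0 yields 0, which cannot be told apart from "no element" without looking at the input again.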
Updated answer: the new solution is slightly more efficient.
You can do this without a loop by using the pext instruction from the Bit Manipulation Instruction Set 2 (BMI2), in combination with a few other SSE instructions.
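To see why pext does the job, here is a portable model of what the instruction computes. It is only an illustration of the semantics, not something to use for speed, and the name pext_model is made up. The idea is that the source operand holds the sixteen 4-bit indices 0..15 and the mask has a 0xF nibble for every nonzero input byte, so the extracted nibbles are exactly the wanted indices, packed at the low end.

```c
#include <stdint.h>

/* Portable model of pext: gathers the bits of src selected by mask
   into the low bits of the result, preserving their order. */
static uint64_t pext_model(uint64_t src, uint64_t mask)
{
    uint64_t result = 0;
    int out_bit = 0;

    while (mask != 0) {
        uint64_t lowest = mask & (~mask + 1);  /* isolate the lowest set mask bit */
        if (src & lowest)
            result |= (uint64_t)1 << out_bit;  /* copy the selected source bit */
        out_bit++;
        mask &= mask - 1;                      /* clear that mask bit */
    }
    return result;
}
```

For an input with bytes 0, 1, 2, 3, 4, 5, 7, 12 and 15 nonzero, the nibble mask is 0xF00F0000F0FFFFFF, and pext_model(0xFEDCBA9876543210, 0xF00F0000F0FFFFFF) gives 0xFC7543210, i.e. the nibbles F, C, 7, 5, 4, 3, 2, 1, 0.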
/*
gcc -O3 -Wall -m64 -mavx2 -march=broadwell ind_nonz_avx.c
*/
#include <stdio.h>
#include <immintrin.h>
#include <stdint.h>
__m128i nonz_index(__m128i x){
    /* Set some constants that will (hopefully) be hoisted out of a loop after inlining. */
    uint64_t indx_const = 0xFEDCBA9876543210; /* 16 4-bit integers, all possible indices from 0 to 15 */
    __m128i cntr = _mm_set_epi8(64,60,56,52,48,44,40,36,32,28,24,20,16,12,8,4);
    __m128i pshufbcnst = _mm_set_epi8(0x80,0x80,0x80,0x80,0x80,0x80,0x80,0x80, 0x0E,0x0C,0x0A,0x08,0x06,0x04,0x02,0x00);
    __m128i cnst0F = _mm_set1_epi8(0x0F);

    __m128i msk = _mm_cmpeq_epi8(x,_mm_setzero_si128()); /* Generate 16x8 bit mask: 0xFF where the input byte is zero. */
    msk = _mm_srli_epi64(msk,4); /* Pack 16x8 bit mask to 16x4 bit mask. */
    msk = _mm_shuffle_epi8(msk,pshufbcnst); /* Pack 16x8 bit mask to 16x4 bit mask, continued: keep the even bytes. */
    uint64_t msk64 = ~ _mm_cvtsi128_si64x(msk); /* Move to general purpose register and invert 16x4 bit mask. */

    /* Compute the termination-byte mask nonzmsk separately. */
    int64_t nnz64 = _mm_popcnt_u64(msk64); /* Count the set bits in msk64: 4 per nonzero input byte. */
    __m128i nnz = _mm_set1_epi8(nnz64); /* May generate vmovd + vpbroadcastb if AVX2 is enabled. */
    __m128i nonzmsk = _mm_cmpgt_epi8(cntr,nnz); /* nonzmsk is a mask of the form 0xFF, 0xFF, ..., 0xFF, 0, 0, ...,0 to mark the output positions without an index */

    uint64_t indx64 = _pext_u64(indx_const,msk64); /* parallel bits extract. pext shuffles indx_const such that indx64 contains the nnz64/4 4-bit indices that we want. */
    __m128i indx = _mm_cvtsi64x_si128(indx64); /* Use a few integer instructions to unpack 4-bit integers to 8-bit integers. */
    __m128i indx_024 = indx; /* Even indices. */
    __m128i indx_135 = _mm_srli_epi64(indx,4); /* Odd indices. */
    indx = _mm_unpacklo_epi8(indx_024,indx_135); /* Merge odd and even indices. */
    indx = _mm_and_si128(indx,cnst0F); /* Mask out the high bits 4,5,6,7 of every byte. */

    return _mm_or_si128(indx,nonzmsk); /* Merge indx with nonzmsk. */
}
int main(){
    int i;
    char w[16],xa[16];
    __m128i x;

    /* Example with bytes 15, 12, 7, 5, 4, 3, 2, 1, 0 set. */
    x = _mm_set_epi8(1,0,0,1, 0,0,0,0, 1,0,1,1, 1,1,1,1);

    /* Other examples. */
    /*
    x = _mm_set_epi8(1,1,1,1, 1,1,1,1, 1,1,1,1, 1,1,1,1);
    x = _mm_set_epi8(0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0);
    x = _mm_set_epi8(1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0);
    x = _mm_set_epi8(0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,1);
    */

    __m128i indices = nonz_index(x);
    _mm_storeu_si128((__m128i *)w,indices);
    _mm_storeu_si128((__m128i *)xa,x);

    printf("counter 15..0 ");for (i=15;i>-1;i--) printf(" %2d ",i); printf("\n\n");
    printf("example xmm: ");for (i=15;i>-1;i--) printf(" %2d ",xa[i]); printf("\n");
    printf("result in dec ");for (i=15;i>-1;i--) printf(" %2hhd ",w[i]); printf("\n");
    printf("result in hex ");for (i=15;i>-1;i--) printf(" %2hhX ",w[i]); printf("\n");

    return 0;
}
It takes about five instructions to get 0xFF (the termination byte) at the unwanted positions.
Note that a function nonz_index that returns the indices and only the position of the termination byte, without actually inserting the termination byte(s), would be much cheaper to compute and might be just as suitable in a particular application. The position of the first termination byte is nnz64>>2.
The result is:
$ ./a.out
counter 15..0 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0
example xmm: 1 0 0 1 0 0 0 0 1 0 1 1 1 1 1 1
result in dec -1 -1 -1 -1 -1 -1 -1 15 12 7 5 4 3 2 1 0
result in hex FF FF FF FF FF FF FF F C 7 5 4 3 2 1 0
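If you want to cross-check the output on other inputs, a plain scalar reference is handy. The sketch below is not part of the SIMD solution; nonz_index_ref is a made-up name. It writes the packed indices followed by 0xFF termination bytes and returns the number of indices, which is also the position of the first termination byte.

```c
#include <stdint.h>

/* Hypothetical scalar reference: out[] receives the indices of the nonzero
   bytes of in[], packed from position 0, with 0xFF in the remaining positions.
   Returns the number of indices written. */
static int nonz_index_ref(const uint8_t in[16], uint8_t out[16])
{
    int n = 0;

    for (int i = 0; i < 16; i++)
        if (in[i] != 0)
            out[n++] = (uint8_t)i;   /* append the index of a nonzero byte */

    for (int i = n; i < 16; i++)
        out[i] = 0xFF;               /* termination bytes */

    return n;
}
```

Storing the input and the result of nonz_index with _mm_storeu_si128 and comparing the 16 result bytes against out[] is enough to test all the corner cases listed in main.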
The pext instruction is supported on Intel Haswell processors or newer.