Is it possible to check if any of 2 sets of 3 ints share an element, faster than 9 comparisons?

Posted 2019-04-19 11:19

Question:

int eq3(int a, int b, int c, int d, int e, int f){
    return a == d || a == e || a == f 
        || b == d || b == e || b == f 
        || c == d || c == e || c == f;
}

This function receives 6 ints and returns true if any of the first 3 ints is equal to any of the last 3. Is there any bitwise-hack-style trick to make it faster?

Answer 1:

Expanding on dawg's SSE comparison method, you can combine the results of the comparisons using a vector OR, and move a mask of the compare results back to an integer to test for 0 / non-zero.

Also, you can get data into vectors more efficiently (although it's still pretty clunky to get many separate integers into vectors when they're live in registers to start with, rather than sitting in memory).

You should avoid store-forwarding stalls that result from doing three small stores and one big load.
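For reference, here is a minimal sketch of the pattern to avoid (my own illustration, not part of the answer's code): the wide load can't be forwarded from the three narrow stores still sitting in the store buffer, so it has to wait for them to commit to cache.

#include <stdint.h>
#include <emmintrin.h>

// Hypothetical anti-pattern sketch: three 4-byte stores feeding one 16-byte load.
static __m128i slow_gather_abc(int32_t a, int32_t b, int32_t c){
    int32_t buf[4] = {0};
    buf[0] = a; buf[1] = b; buf[2] = c;         // three narrow stores
    return _mm_loadu_si128((__m128i *)buf);     // one wide load: store-forwarding stall
}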

///// UNTESTED ////////
#include <immintrin.h>
int eq3(int a, int b, int c, int d, int e, int f){

    // Use _mm_set to let the compiler worry about getting integers into vectors
    // Use -mtune=intel or gcc will make bad code, though :(
    __m128i abcc = _mm_set_epi32(0,c,b,a);  // args go from high to low position in the vector
    // masking off the high bits of the result-mask to avoid false positives
    // is cheaper than repeating c (to do the same compare twice)

    __m128i dddd = _mm_set1_epi32(d);
    __m128i eeee = _mm_set1_epi32(e);

    dddd = _mm_cmpeq_epi32(dddd, abcc);
    eeee = _mm_cmpeq_epi32(eeee, abcc);  // per element: 0(unequal) or -1(equal)
    __m128i combined = _mm_or_si128(dddd, eeee);

    __m128i ffff = _mm_set1_epi32(f);
    ffff = _mm_cmpeq_epi32(ffff, abcc);
    combined = _mm_or_si128(combined, ffff);

    // results of all the compares are ORed together.  All zero only if there were no hits
    unsigned equal_mask = _mm_movemask_epi8(combined);
    equal_mask &= 0x0fff;  // the high 32b element could have false positives
    return equal_mask;
    // return !!equal_mask if you want to force it to 0 or 1
    // the mask tells you whether it was a, b, or c that had a hit

    // movmskps would return a mask of just 4 bits, one for each 32b element, but might have a bypass delay on Nehalem.
    // actually, pmovmskb apparently runs in the float domain on Nehalem anyway, according to Agner Fog's table >.<
}

This compiles to pretty nice asm, pretty similar between clang and gcc, but clang's -fverbose-asm puts nice comments on the shuffles. Only 19 instructions including the ret, with a decent amount of parallelism from separate dependency chains. With -msse4.1, or -mavx, it saves another couple of instructions. (But probably doesn't run any faster)

With clang, dawg's version is about twice the size. With gcc, something bad happens and it's horrible (over 80 instructions. Looks like a gcc optimization bug, since it looks worse than just a straightforward translation of the source). Even clang's version spends so long getting data into / out of vector regs that it might be faster to just do the comparisons branchlessly and OR the truth values together.

This compiles to decent code:

// 8bit variable doesn't help gcc avoid partial-register stalls even with -mtune=core2 :/
int eq3_scalar(int a, int b, int c, int d, int e, int f){
    char retval = (a == d) | (a == e) | (a == f)
         | (b == d) | (b == e) | (b == f)
         | (c == d) | (c == e) | (c == f);
    return retval;
}

Play around with how to get the data from the caller into vector regs. If the groups of three are coming from memory, then prob. passing pointers so a vector load can get them from their original location is best. Going through integer registers on the way to vectors sucks (higher latency, more uops), but if your data is already live in regs it's a loss to do integer stores and then vector loads. gcc is dumb and follows the AMD optimization guide's recommendation to bounce through memory, even though Agner Fog says he's found that's not worth it even on AMD CPUs. It's definitely worse on Intel, and apparently a wash or maybe still worse on AMD, so it's definitely the wrong choice for -mtune=generic. Anyway...
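As a rough, untested sketch of the pointer idea (eq3_ptr is a made-up name, and it assumes each group of three is followed by at least one more readable int so the 16-byte load stays in bounds):

#include <emmintrin.h>

int eq3_ptr(const int *abc, const int *def){
    __m128i v   = _mm_loadu_si128((const __m128i *)abc);   // a b c (+ one padding int)
    __m128i hit = _mm_or_si128(
                      _mm_or_si128(_mm_cmpeq_epi32(v, _mm_set1_epi32(def[0])),
                                   _mm_cmpeq_epi32(v, _mm_set1_epi32(def[1]))),
                      _mm_cmpeq_epi32(v, _mm_set1_epi32(def[2])));
    return _mm_movemask_epi8(hit) & 0x0fff;   // mask off the padding element
}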


It's also possible to do 8 of our 9 compares with just two packed-vector compares.

The 9th can be done with an integer compare, and have its truth value ORed with the vector result. On some CPUs (esp. AMD, and maybe Intel Haswell and later) not transferring one of the 6 integers to vector regs at all might be a win. Mixing three integer branchless-compares in with the vector shuffles / compares would interleave them nicely.

These vector comparisons can be set up by using shufps on integer data (since it can combine data from two source registers). That's fine on most CPUs, but requires a lot of annoying casting when using intrinsics instead of actual asm. Even if there is a bypass delay, it's not a bad tradeoff vs. something like punpckldq and then pshufd.
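With intrinsics, that casting boilerplate looks something like this (my own sketch; SHUFPS_EPI32 is a made-up name, and the immediate has to be a compile-time constant, hence a macro):

#include <emmintrin.h>

// shufps on integer data: cast to float vectors, shuffle, cast back.
// The casts compile to no instructions; they just keep the type system happy.
#define SHUFPS_EPI32(x, y, imm) \
    _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(x), _mm_castsi128_ps(y), (imm)))

// e.g. with a in element 0 of va and b in element 0 of vb,
// SHUFPS_EPI32(va, vb, 0) builds the aabb vector from the diagram below.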

aabb    ccab
====    ====
dede    deff

c==f

with asm something like:

#### untested
# pretend a is in eax, and so on
movd     xmm0, eax
movd     xmm1, ebx
movd     xmm2, ecx

shl      rdx, 32
#mov     edi, edi     # zero the upper 32 of rdi if needed, or use shld instead of OR if you don't care about AMD CPUs
or       rdx, rdi     # de in an integer register.
movq     xmm3, rdx    # de, aka (d<<32)|e
# in 32bit code, use a vector shuffle of some sort to do this in a vector reg, or:
#pinsrd  xmm3, edi, 1  # SSE4.1, and 2 uops (same as movd+shuffle)
#movd    xmm4, edi    # e

movd     xmm5, esi    # f

shufps   xmm0, xmm1, 0            #  xmm0=aabb  (low dword = a; my notation is backwards from left/right vector-shift perspective)
shufps   xmm5, xmm3, 0b01000000   # xmm5 = ffde  
punpcklqdq xmm3, xmm3            # broadcast: xmm3=dede
pcmpeqd  xmm3, xmm0              # xmm3: aabb == dede

# spread these instructions out between vector instructions, if you aren't branching
xor      edx,edx
cmp      esi, ecx     # c == f
#je   .found_match    # if there's one of the 9 that's true more often, make it this one.  Branch mispredicts suck, though
sete     dl

shufps   xmm0, xmm2, 0b00001000  # xmm0 = abcc
pcmpeqd  xmm0, xmm5              # abcc == ffde

por      xmm0, xmm3
pmovmskb eax, xmm0    # will have bits set if cmpeq found any equal elements
or       eax, edx     # combine vector and scalar compares
jnz  .found_match
# or record the result instead of branching on it
setnz    dl

This is also 19 instructions (not counting the final jcc / setcc), but one of them is an xor-zeroing idiom, and there are other simple integer instructions. (Shorter encoding, some can run on port6 on Haswell+ which can't handle vector instructions). There might be a longer dep chain due to the chain of shuffles that builds abcc.



Answer 2:

Assuming you're expecting a high rate of false results you could make a quick "pre-check" to quickly isolate such cases:

If a bit in a is set that isn't set in any of d, e and f then a cannot be equal to any of these.

Thus something like

int pre_eq3(int a, int b, int c, int d, int e, int f){
    int const mask = ~(d | e | f);
    if ((a & mask) && (b & mask) && (c & mask)) {
         return false;
    }
    return eq3(a, b, c, d, e, f);
}

could speed it up (8 operations instead of 17, but much more costly if the result will actually be true). If mask == 0 then of course this won't help.


This can be further improved if, with high probability, a & b & c has some bits set:

int pre_eq3(int a, int b, int c, int d, int e, int f){
    int const mask = ~(d | e | f);
    if ((a & b & c) & mask) {
        return false;
    }
    if ((a & mask) && (b & mask) && (c & mask)) {
         return false;
    }
    return eq3(a, b, c, d, e, f);
}

Now if all of a, b and c have bits set where none of d, e and f have any bits set, we're out pretty fast.



Answer 3:

If you want a bitwise version, look at xor. If you xor two numbers that are the same, the result is 0. Otherwise, a bit is set in the result wherever exactly one of the operands has that bit set; for example, 1000 xor 0100 is 1100.
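A minimal sketch of that idea applied to the original function (my own illustration; whether it actually beats the plain == version is up to the compiler):

int eq3_xor(int a, int b, int c, int d, int e, int f){
    // x ^ y is zero exactly when x == y, so !(x ^ y) is the equality test
    return !(a ^ d) | !(a ^ e) | !(a ^ f)
         | !(b ^ d) | !(b ^ e) | !(b ^ f)
         | !(c ^ d) | !(c ^ e) | !(c ^ f);
}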

The code you have will likely cause at least one pipeline flush (from a mispredicted branch), but apart from that it will be fine performance-wise.



Answer 4:

I think using SSE is probably worth investigating.

It has been 20 years since I wrote any SSE, and this is not benchmarked, but something like:

#include <stdint.h>
#include <emmintrin.h>  // SSE2: __m128i, _mm_cmpeq_epi32, _mm_load_si128
int cmp3(int32_t a, int32_t b, int32_t c, int32_t d, int32_t e, int32_t f){
    // returns -1 if any of a,b,c is eq to any of d,e,f
    // returns 0 if all a,b,c != d,e,f
    int32_t __attribute__ ((aligned(16))) vec1[4];
    int32_t __attribute__ ((aligned(16))) vec2[4];
    int32_t __attribute__ ((aligned(16))) vec3[4];
    int32_t __attribute__ ((aligned(16))) vec4[4];
    int32_t __attribute__ ((aligned(16))) r1[4];
    int32_t __attribute__ ((aligned(16))) r2[4];
    int32_t __attribute__ ((aligned(16))) r3[4];

    // fourth word is don't-care (never read below)
    vec1[0]=a;
    vec1[1]=b;
    vec1[2]=c;

    vec2[0]=vec2[1]=vec2[2]=d;
    vec3[0]=vec3[1]=vec3[2]=e;
    vec4[0]=vec4[1]=vec4[2]=f;

    __m128i v1 = _mm_load_si128((__m128i *)vec1);
    __m128i v2 = _mm_load_si128((__m128i *)vec2);
    __m128i v3 = _mm_load_si128((__m128i *)vec3);
    __m128i v4 = _mm_load_si128((__m128i *)vec4);

    // any(a,b,c) == d? 
    __m128i vcmp1 = _mm_cmpeq_epi32(v1, v2);
    // any(a,b,c) == e?
    __m128i vcmp2 = _mm_cmpeq_epi32(v1, v3);
    // any(a,b,c) == f?
    __m128i vcmp3 = _mm_cmpeq_epi32(v1, v4);

    _mm_store_si128((__m128i *)r1, vcmp1);
    _mm_store_si128((__m128i *)r2, vcmp2);
    _mm_store_si128((__m128i *)r3, vcmp3);

    // bit or the first three of each result.
    // might be better with SSE mask, but I don't remember how!
    return r1[0] | r1[1] | r1[2] |
           r2[0] | r2[1] | r2[2] |
           r3[0] | r3[1] | r3[2];
}
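As the comment in the code hints, the three stores and the scalar ORs could be replaced by a movemask, as in the first answer; a rough, untested sketch (it returns a nonzero mask rather than -1, and masks off the uninitialized fourth lane):

#include <emmintrin.h>

static int any_of_first3(__m128i vcmp1, __m128i vcmp2, __m128i vcmp3){
    __m128i any = _mm_or_si128(_mm_or_si128(vcmp1, vcmp2), vcmp3);
    return _mm_movemask_epi8(any) & 0x0fff;   // nonzero iff one of the first 3 lanes matched
}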

If done correctly, SSE with no branches should be 4x to 8x faster.



Answer 5:

If your compiler/architecture supports vector extensions (like clang and gcc) you can use something like:

#ifdef __SSE2__
#include <immintrin.h>
#elif defined __ARM_NEON
#include <arm_neon.h>
#elif defined __ALTIVEC__
#include <altivec.h>
//#elif ... TODO more architectures
#endif

static int hastrue128(void *x){
#ifdef __SSE2__
    return _mm_movemask_epi8(*(__m128i*)x);
#elif defined __ARM_NEON
    return vaddlvq_u8(*(uint8x16_t*)x);
#elif defined __ALTIVEC__
typedef __UINT32_TYPE__ v4si __attribute__ ((__vector_size__ (16), aligned(4), __may_alias__));
    return vec_any_ne(*(v4si*)x,(v4si){0});
#else
    int *y = x;
    return y[0]|y[1]|y[2]|y[3];
#endif
}

//if inputs will always be aligned to 16 add an aligned attribute
//otherwise ensure they are at least aligned to 4
int cmp3(  int* a  ,  int* b ){
typedef __INT32_TYPE__ i32x4 __attribute__ ((__vector_size__ (16), aligned(4), __may_alias__));
    i32x4 x = *(i32x4*)a, cmp, tmp, y0 = y0^y0, y1 = y0, y2 = y0;
    //start vectors off at 0 and add the int to each element for optimization
    //it adds the int to each element, but since we started it at zero,
    //a good compiler (not ICC at -O3) will skip the xor and add and just broadcast/whatever
    y0 += b[0];
    y1 += b[1];
    y2 += b[2];
    cmp =  x == y0;
    tmp =  x == y1; //ppc complains if we don't use temps here
    cmp |= tmp;
    tmp =  x ==y2;
    cmp |= tmp;
    //now hack off the end since we only need 3
    cmp &= (i32x4){0xffffffff,0xffffffff,0xffffffff,0};
    return hastrue128(&cmp);
}

int cmp4(  int* a  ,  int* b ){
typedef __INT32_TYPE__ i32x4 __attribute__ ((__vector_size__ (16), aligned(4), __may_alias__));
    i32x4 x = *(i32x4*)a, cmp, tmp, y0 = y0^y0, y1 = y0, y2 = y0, y3 = y0;
    y0 += b[0];
    y1 += b[1];
    y2 += b[2];
    y3 += b[3];
    cmp =  x == y0;
    tmp =  x == y1; //ppc complains if we don't use temps here
    cmp |= tmp;
    tmp =  x ==y2;
    cmp |= tmp;
    tmp =  x ==y3;
    cmp |= tmp;
    return hastrue128(&cmp);
}

On arm64 this compiles to the following branchless code:

cmp3:
        ldr     q2, [x0]
        adrp    x2, .LC0
        ld1r    {v1.4s}, [x1]
        ldp     w0, w1, [x1, 4]
        dup     v0.4s, w0
        cmeq    v1.4s, v2.4s, v1.4s
        dup     v3.4s, w1
        ldr     q4, [x2, #:lo12:.LC0]
        cmeq    v0.4s, v2.4s, v0.4s
        cmeq    v2.4s, v2.4s, v3.4s
        orr     v0.16b, v1.16b, v0.16b
        orr     v0.16b, v0.16b, v2.16b
        and     v0.16b, v0.16b, v4.16b
        uaddlv h0,v0.16b
        umov    w0, v0.h[0]
        uxth    w0, w0
        ret
cmp4:
        ldr     q2, [x0]
        ldp     w2, w0, [x1, 4]
        dup     v0.4s, w2
        ld1r    {v1.4s}, [x1]
        dup     v3.4s, w0
        ldr     w1, [x1, 12]
        dup     v4.4s, w1
        cmeq    v1.4s, v2.4s, v1.4s
        cmeq    v0.4s, v2.4s, v0.4s
        cmeq    v3.4s, v2.4s, v3.4s
        cmeq    v2.4s, v2.4s, v4.4s
        orr     v0.16b, v1.16b, v0.16b
        orr     v0.16b, v0.16b, v3.16b
        orr     v0.16b, v0.16b, v2.16b
        uaddlv h0,v0.16b
        umov    w0, v0.h[0]
        uxth    w0, w0
        ret

And on ICC x86_64 -march=skylake it produces the following branchless code:

cmp3:
        vmovdqu   xmm2, XMMWORD PTR [rdi]                       #27.24
        vpbroadcastd xmm0, DWORD PTR [rsi]                      #34.17
        vpbroadcastd xmm1, DWORD PTR [4+rsi]                    #35.17
        vpcmpeqd  xmm5, xmm2, xmm0                              #34.17
        vpbroadcastd xmm3, DWORD PTR [8+rsi]                    #37.16
        vpcmpeqd  xmm4, xmm2, xmm1                              #35.17
        vpcmpeqd  xmm6, xmm2, xmm3                              #37.16
        vpor      xmm7, xmm4, xmm5                              #36.5
        vpor      xmm8, xmm6, xmm7                              #38.5
        vpand     xmm9, xmm8, XMMWORD PTR __$U0.0.0.2[rip]      #40.5
        vpmovmskb eax, xmm9                                     #11.12
        ret                                                     #41.12
cmp4:
        vmovdqu   xmm3, XMMWORD PTR [rdi]                       #46.24
        vpbroadcastd xmm0, DWORD PTR [rsi]                      #51.17
        vpbroadcastd xmm1, DWORD PTR [4+rsi]                    #52.17
        vpcmpeqd  xmm6, xmm3, xmm0                              #51.17
        vpbroadcastd xmm2, DWORD PTR [8+rsi]                    #54.16
        vpcmpeqd  xmm5, xmm3, xmm1                              #52.17
        vpbroadcastd xmm4, DWORD PTR [12+rsi]                   #56.16
        vpcmpeqd  xmm7, xmm3, xmm2                              #54.16
        vpor      xmm8, xmm5, xmm6                              #53.5
        vpcmpeqd  xmm9, xmm3, xmm4                              #56.16
        vpor      xmm10, xmm7, xmm8                             #55.5
        vpor      xmm11, xmm9, xmm10                            #57.5
        vpmovmskb eax, xmm11                                    #11.12
        ret

And it even works on ppc64 with altivec, though the generated code is definitely suboptimal:

cmp3:
        lwa 10,4(4)
        lxvd2x 33,0,3
        vspltisw 11,-1
        lwa 9,8(4)
        vspltisw 12,0
        xxpermdi 33,33,33,2
        lwa 8,0(4)
        stw 10,-32(1)
        addi 10,1,-80
        stw 9,-16(1)
        li 9,32
        stw 8,-48(1)
        lvewx 0,10,9
        li 9,48
        xxspltw 32,32,3
        lvewx 13,10,9
        li 9,64
        vcmpequw 0,1,0
        lvewx 10,10,9
        xxsel 32,44,43,32
        xxspltw 42,42,3
        xxspltw 45,45,3
        vcmpequw 13,1,13
        vcmpequw 1,1,10
        xxsel 45,44,43,45
        xxsel 33,44,43,33
        xxlor 32,32,45
        xxlor 32,32,33
        vsldoi 1,12,11,12
        xxland 32,32,33
        vcmpequw. 0,0,12
        mfcr 3,2
        rlwinm 3,3,25,1
        cntlzw 3,3
        srwi 3,3,5
        blr
cmp4:
        lwa 10,8(4)
        lxvd2x 33,0,3
        vspltisw 10,-1
        lwa 9,12(4)
        vspltisw 11,0
        xxpermdi 33,33,33,2
        lwa 7,0(4)
        lwa 8,4(4)
        stw 10,-32(1)
        addi 10,1,-96
        stw 9,-16(1)
        li 9,32
        stw 7,-64(1)
        stw 8,-48(1)
        lvewx 0,10,9
        li 9,48
        xxspltw 32,32,3
        lvewx 13,10,9
        li 9,64
        xxspltw 45,45,3
        vcmpequw 13,1,13
        xxsel 44,43,42,45
        lvewx 13,10,9
        li 9,80
        vcmpequw 0,1,0
        xxspltw 45,45,3
        xxsel 32,43,42,32
        vcmpequw 13,1,13
        xxlor 32,32,44
        xxsel 45,43,42,45
        lvewx 12,10,9
        xxlor 32,32,45
        xxspltw 44,44,3
        vcmpequw 1,1,12
        xxsel 33,43,42,33
        xxlor 32,32,33
        vcmpequw. 0,0,11
        mfcr 3,2
        rlwinm 3,3,25,1
        cntlzw 3,3
        srwi 3,3,5
        blr

As you can see from the generated asm, there is still a little room for improvement, but it will compile on risc-v, mips, ppc and other architecture+compiler combinations that support vector extensions.
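For completeness, a hedged usage sketch adapting the question's six-int interface to cmp3's pointer interface (my own illustration; the arrays are padded to four ints so the 16-byte load stays in bounds, and cmp3 already masks off the fourth lane):

int eq3(int a, int b, int c, int d, int e, int f){
    int abc[4] = {a, b, c, 0};
    int def[4] = {d, e, f, 0};
    return cmp3(abc, def) != 0;
}

Note that this bounces the values through memory, which is exactly what the first answer recommends avoiding when they are already live in registers.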