How to get data out of AVX registers?

Posted 2020-04-02 09:31

Question:

Using MSVC 2013 and AVX 1, I've got 8 floats in a register:

__m256 foo = _mm256_fmadd_ps(a,b,c);

Now I want to call inline void print(float) {...} for all 8 floats. It looks like the Intel AVX intrinsics would make this rather complicated:

print(_castu32_f32(_mm256_extract_epi32(foo, 0)));
print(_castu32_f32(_mm256_extract_epi32(foo, 1)));
print(_castu32_f32(_mm256_extract_epi32(foo, 2)));
// ...

but MSVC doesn't even provide either of those two intrinsics. Sure, I could write the values back to memory and load from there, but I suspect that at the assembly level there's no need to spill a register.

Bonus Q: I'd of course like to write

for (int i = 0; i != 8; ++i)
    print(_castu32_f32(_mm256_extract_epi32(foo, i)));

but _mm256_extract_epi32 requires a compile-time-constant index, and MSVC doesn't unroll the loop to provide one. How do I write a loop over the 8 x 32-bit floats in __m256 foo?

Answer 1:

Careful: _mm256_fmadd_ps isn't part of AVX1. FMA3 has its own feature bit, and was only introduced on Intel with Haswell. AMD introduced FMA3 with Piledriver (AVX1+FMA4+FMA3, no AVX2).
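If you need to detect it at runtime under MSVC, the FMA feature bit is CPUID leaf 1, ECX bit 12. A minimal sketch (has_fma3 is just an illustrative name; a robust check would also verify AVX and OSXSAVE support):

#include <intrin.h>

bool has_fma3() {
    int regs[4];
    __cpuid(regs, 1);                     // CPUID leaf 1
    return (regs[2] & (1 << 12)) != 0;    // regs[2] = ECX; bit 12 = FMA
}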


At the asm level, if you want to get eight 32-bit elements into integer registers, it is actually faster to store to the stack and then do scalar loads: pextrd is a 2-uop instruction on Intel SnB-family and AMD Bulldozer-family (and on Nehalem and Silvermont, which don't support AVX).

The only CPU where vextractf128 + 2xmovd + 6xpextrd isn't terrible is AMD Jaguar. (cheap pextrd, and only one load port.) (See Agner Fog's insn tables)

A wide aligned store can forward to overlapping narrow loads. (Of course, you can use movd to get the low element, so you have a mix of load port and ALU port uops).
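A rough intrinsics-level sketch of that mix (to_ints is a hypothetical helper, assuming integer data in a __m256i):

#include <immintrin.h>
#include <stdint.h>

void to_ints(__m256i v, uint32_t out[8]) {
    alignas(32) uint32_t buf[8];
    _mm256_store_si256((__m256i*)buf, v);                   // one wide aligned store
    out[0] = _mm_cvtsi128_si32(_mm256_castsi256_si128(v));  // movd: low element via an ALU port
    for (int i = 1; i < 8; ++i)
        out[i] = buf[i];                                    // overlapping narrow reloads (store-forwarded)
}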


That said, you're extracting floats by using an integer extract and then type-punning the bits back to a float. That seems horrible.

What you actually need is each float in the low element of its own xmm register. vextractf128 is obviously the way to start, bringing element 4 to the bottom of a new xmm reg. Then 6x AVX shufps can easily get the other three elements of each half. (Or movshdup and movhlps have shorter encodings: no immediate byte).

7 shuffle uops are worth considering vs. 1 store and 7 load uops, but not if you were going to spill the vector for a function call anyway.
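A sketch of that all-shuffle unpack (unpack8 is a hypothetical helper; movshdup / movhlps are used for the shorter encodings mentioned above):

#include <immintrin.h>

void unpack8(__m256 v, __m128 e[8]) {
    __m128 lo = _mm256_castps256_ps128(v);    // elements 0..3: free, no instruction
    __m128 hi = _mm256_extractf128_ps(v, 1);  // elements 4..7 (shuffle uop #1)
    e[0] = lo;
    e[1] = _mm_movehdup_ps(lo);               // elem 1 in the low slot (no immediate byte)
    e[2] = _mm_movehl_ps(lo, lo);             // elem 2
    e[3] = _mm_shuffle_ps(lo, lo, 3);         // elem 3
    e[4] = hi;
    e[5] = _mm_movehdup_ps(hi);               // elem 5
    e[6] = _mm_movehl_ps(hi, hi);             // elem 6
    e[7] = _mm_shuffle_ps(hi, hi, 3);         // elem 7: 7 shuffle uops total
}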


ABI considerations:

You're on Windows, where xmm6-15 are call-preserved (only the low 128 bits; the upper halves of ymm6-15 are call-clobbered). This is yet another reason to start with vextractf128.

In the SysV ABI, all the xmm / ymm / zmm registers are call-clobbered, so every print() function requires a spill/reload. The only sane thing to do there is store to memory and call print with the original vector (i.e. print the low element, because it will ignore the rest of the register). Then movss xmm0, [rsp+4] and call print on the 2nd element, etc.

It does you no good to get all 8 floats nicely unpacked into 8 vector regs, because they'd all have to be spilled separately anyway before the first function call!
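So for the print case, the sane pattern is a sketch like this (print8 is a hypothetical wrapper around the question's print(float); answer 4 below fleshes out the same idea):

#include <immintrin.h>
void print(float);

void print8(__m256 v) {
    alignas(32) float buf[8];
    _mm256_store_ps(buf, v);                          // spill once, before the first call
    print(_mm_cvtss_f32(_mm256_castps256_ps128(v)));  // element 0: free cast, no reload needed
    for (int i = 1; i < 8; ++i)
        print(buf[i]);                                // reload each element; the calls clobber v anyway
}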



Answer 2:

Assuming you only have AVX (i.e. no AVX2) then you could do something like this:

#include <immintrin.h>  // _MM_EXTRACT_FLOAT needs SSE4.1

float extract_float(const __m128 v, const int i)
{
    float x;
    _MM_EXTRACT_FLOAT(x, v, i);  // i must become a compile-time constant after inlining
    return x;
}

void print(const __m128 v)
{
    print(extract_float(v, 0));
    print(extract_float(v, 1));
    print(extract_float(v, 2));
    print(extract_float(v, 3));
}

void print(const __m256 v)
{
    print(_mm256_extractf128_ps(v, 0));
    print(_mm256_extractf128_ps(v, 1));
}

However I think I would probably just use a union:

union U256f {
    __m256 v;
    float a[8];
};

void print(const __m256 v)
{
    const U256f u = { v };

    for (int i = 0; i < 8; ++i)
        print(u.a[i]);
}
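Usage then ties straight back to the question (demo is a hypothetical function; a, b, c are assumed __m256 inputs, and _mm256_fmadd_ps needs FMA3 as noted in answer 1):

void demo(__m256 a, __m256 b, __m256 c)
{
    __m256 foo = _mm256_fmadd_ps(a, b, c);  // FMA3, not plain AVX1
    print(foo);                             // prints all 8 floats via the union
}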


Answer 3:

    float valueAVX(__m256 a, int i){

        float ret = 0;
        switch (i){

            case 0:
//                 a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)      ( a3, a2, a1, a0 )
// cvtss_f32             a0 

                ret = _mm_cvtss_f32(_mm256_extractf128_ps(a, 0));
                break;
            case 1: {
//                     a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)     lo = ( a3, a2, a1, a0 )
// shuffle(lo, lo, 1)      ( - , a3, a2, a1 )
// cvtss_f32                 a1 
                __m128 lo = _mm256_extractf128_ps(a, 0);
                ret = _mm_cvtss_f32(_mm_shuffle_ps(lo, lo, 1));
            }
                break;
            case 2: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)   lo = ( a3, a2, a1, a0 )
// movehl(lo, lo)        ( - , - , a3, a2 )
// cvtss_f32               a2 
                __m128 lo = _mm256_extractf128_ps(a, 0);
                ret = _mm_cvtss_f32(_mm_movehl_ps(lo, lo));
            }
                break;
            case 3: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)   lo = ( a3, a2, a1, a0 )
// shuffle(lo, lo, 3)    ( - , - , - , a3 )
// cvtss_f32               a3 
                __m128 lo = _mm256_extractf128_ps(a, 0);                    
                ret = _mm_cvtss_f32(_mm_shuffle_ps(lo, lo, 3));
            }
                break;

            case 4:
//                 a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)      ( a7, a6, a5, a4 )
// cvtss_f32             a4 
                ret = _mm_cvtss_f32(_mm256_extractf128_ps(a, 1));
                break;
            case 5: {
//                     a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)     hi = ( a7, a6, a5, a4 )
// shuffle(hi, hi, 1)      ( - , a7, a6, a5 )
// cvtss_f32                 a5 
                __m128 hi = _mm256_extractf128_ps(a, 1);
                ret = _mm_cvtss_f32(_mm_shuffle_ps(hi, hi, 1));
            }
                break;
            case 6: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)   hi = ( a7, a6, a5, a4 )
// movehl(hi, hi)        ( - , - , a7, a6 )
// cvtss_f32               a6 
                __m128 hi = _mm256_extractf128_ps(a, 1);
                ret = _mm_cvtss_f32(_mm_movehl_ps(hi, hi));
            }
                break;
            case 7: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)   hi = ( a7, a6, a5, a4 )
// shuffle(hi, hi, 3)    ( - , - , - , a7 )
// cvtss_f32               a7 
                __m128 hi = _mm256_extractf128_ps(a, 1);
                ret = _mm_cvtss_f32(_mm_shuffle_ps(hi, hi, 3));
            }
                break;
        }

        return ret;
    }


Answer 4:

(Unfinished answer. Posting anyway in case it helps anyone, or in case I come back to it. Generally if you need to interface with scalar that you can't vectorize, it's not bad to just store a vector to a local array, and then reload it one element at a time.)

See my other answer for asm details. This answer is about the C++ side of things.


void foo(__m256 v) {
    alignas(32) float vecbuf[8];   // 32-byte aligned array allows aligned store
                                   // avoiding the risk of cache-line splits
    _mm256_store_ps(vecbuf, v);

    float v0 = _mm_cvtss_f32(_mm256_castps256_ps128(v));  // the bottom of the register
    float v1 = vecbuf[1];
    float v2 = vecbuf[2];
    // ... and so on for v3 .. v7, or loop over vecbuf[i].
    // If you do need all 8 elements one at a time, this is a good way.
}

Or loop over vecbuf[i]. A vector store can forward to a scalar reload of one of its elements, so this only adds about 6 cycles of latency, and multiple reloads can be in flight at once. (So it's very good for throughput on modern CPUs with 2/clock load throughput.)

Note that I avoided reloading the low element; the low element of a vector in a register already is a scalar float. _mm_cvtss_f32( _mm256_castps256_ps128(v) ) is simply how you keep the compiler's type system happy; it compiles to zero asm instructions and so it's literally free (barring missed-optimization bugs). (See Intel's intrinsics guide.) XMM registers are the low 128 bits of the corresponding YMM register, and scalar float / double are the low 32 or 64 bits of an XMM register. (Garbage in the upper half doesn't matter.)

Extracting the first element with that free cast gives out-of-order exec something to do while waiting for the rest to arrive. You might consider shuffling to get a 2nd element with vunpckhps or vmovhlps on the low 128 bits, so you have 2 elements ready quickly, if that helps fill the latency bubble.
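For example (a sketch; first_two is a hypothetical helper, and vunpckhps grabs element 2 without touching memory):

#include <immintrin.h>

void first_two(__m256 v, float* v0, float* v2) {
    __m128 lo = _mm256_castps256_ps128(v);         // free, no instruction
    *v0 = _mm_cvtss_f32(lo);                       // ready as soon as v is
    *v2 = _mm_cvtss_f32(_mm_unpackhi_ps(lo, lo));  // vunpckhps: element 2 with one shuffle
}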

In GNU C/C++, you can index a vector type like an array, with v[1] or even a variable index like v[i]. The compiler will choose between shuffle or store/reload.

But this isn't portable to MSVC which defines __m256 in terms of a union with some named members.
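For illustration, both non-portable spellings side by side (third_elem is a hypothetical helper; m256_f32 is MSVC's named member):

#include <immintrin.h>

float third_elem(__m256 v) {
#ifdef __GNUC__
    return v[2];            // GNU C vector indexing (gcc/clang)
#else
    return v.m256_f32[2];   // MSVC's named union member
#endif
}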

Storing to an array and reloading is portable, and compilers can sometimes even optimize it into a shuffle. (If you don't want that, check the generated asm.)

e.g. clang optimizes a function that just returns vecbuf[1] into a simple vshufps. https://godbolt.org/z/tHJH_V


If you actually want to add up all the elements of a vector into a scalar total, shuffle and SIMD add; see Fastest way to do horizontal float vector sum on x86.

(Same for multiply, min, max or other associative reductions over the elements of a single vector. Of course if you have multiple vectors, do vertical ops down to one vector, like _mm256_add_ps(v1,v2))
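The 256-bit version of that idiom looks like this sketch (hsum256_ps is a hypothetical name; narrow to 128 bits first, then shuffle + add):

#include <immintrin.h>

float hsum256_ps(__m256 v) {
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    lo = _mm_add_ps(lo, hi);             // vertical add of the two halves
    __m128 shuf = _mm_movehdup_ps(lo);   // (a1, a1, a3, a3)
    __m128 sums = _mm_add_ps(lo, shuf);  // (a0+a1, -, a2+a3, -)
    shuf = _mm_movehl_ps(shuf, sums);    // move a2+a3 down to element 0
    sums = _mm_add_ss(sums, shuf);       // (a0+a1) + (a2+a3)
    return _mm_cvtss_f32(sums);
}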


If you use Agner Fog's Vector Class Library, its wrapper classes overload operator[] to work exactly the way you'd expect, even for non-constant args. This often compiles to a store/reload, but it makes the code easy to write in C++. With optimization enabled, you'll probably get decent results. (Except the low element might get stored/reloaded instead of just getting used in place, so you might want to special-case vec[0] into _mm_cvtss_f32(vec) or something.)
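A sketch of the usage (nth is a hypothetical function; Vec8f is VCL's 8-float wrapper around __m256):

#include "vectorclass.h"

float nth(Vec8f v, int i) {
    return v[i];  // the library picks store/reload or a shuffle
}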

(VCL used to be licensed under the GPL, but the current version is now a simple Apache license.)

See also my github repo with mostly-untested changes to Agner's VCL, to generate better code for some functions.


There's a _MM_EXTRACT_FLOAT wrapper macro, but it's weird and only defined with SSE4.1. I think it's intended to go with SSE4.1 extractps (which can extract the binary representation of a float into an integer register, or store it to memory). gcc does compile it into an FP shuffle when the destination is a float, though. Be careful that other compilers don't compile it to an actual extractps instruction if you want the result as a float, because that's not what extractps does. (That is what insertps does, but a simpler FP shuffle would take fewer instruction bytes: e.g. shufps with AVX is great.)

It's weird because it takes 3 args: _MM_EXTRACT_FLOAT(dest, src_m128, idx), so you can't even use it as an initializer for a float local.


To loop over a vector

gcc will unroll a loop like that for you, but only with -O1 or higher. At -O0, it will give you an error message.

#include <smmintrin.h>  // SSE4.1, for _MM_EXTRACT_FLOAT

float bad_hsum(__m128 & fv) {
    float sum = 0;
    for (int i=0 ; i<4 ; i++) {
        float f;
        _MM_EXTRACT_FLOAT(f, fv, i);  // works only with -O1 or higher
        sum += f;
    }
    return sum;
}