What's the fastest stride-3 gather instruction

2020-02-05 03:14发布

The question:

What is the most efficient sequence to generate a stride-3 gather of 32-bit elements from memory? If the memory is arranged as:

MEM = R0 G0 B0 R1 G1 B1 R2 G2 B2 R3 G3 B3 ...

We want to obtain three YMM registers where:

YMM0 = R0 R1 R2 R3 R4 R5 R6 R7
YMM1 = G0 G1 G2 G3 G4 G5 G6 G7
YMM2 = B0 B1 B2 B3 B4 B5 B6 B7

Motivation and discussion

The scalar C++ code is something like:

// Scalar reference implementation: visits 4096 interleaved (R, G, B)
// triples starting at Input and sums some_parallelizable_algorithm<T>
// over all of them.
template <typename T>
T Process(const T* Input) {
  T Result = 0;
  for (int Pixel = 0; Pixel != 4096; ++Pixel) {
    // Each pixel occupies three consecutive elements.
    const T* Triple = Input + 3 * Pixel;
    T R = Triple[0];
    T G = Triple[1];
    T B = Triple[2];
    Result += some_parallelizable_algorithm<T>(R, G, B);
  }
  return Result;
}

Let's say that some_parallelizable_algorithm was written in intrinsics and was tuned to the fastest possible implementation:

// Vectorized kernel (declaration only; its definition is not shown).
// Takes one __m256i per colour channel (R, G, B) and returns a __m256i.
template <typename T>
__m256i some_parallelizable_algorithm(__m256i R, __m256i G, __m256i B);

So the vector implementation for T=int32_t can be something like:

    // AVX2 specialization for 32-bit elements using hardware gathers.
    // Processes 8 pixels (24 consecutive int32 values) per iteration and
    // returns the sum of all lanes of the accumulated result.
    template <>
    int32_t Process<int32_t>(const int32_t* Input) {
      // Element offsets of one channel in 8 consecutive pixels. The stride of 3
      // must live in the indices: the gather scale is a byte multiplier that
      // may only be 1, 2, 4 or 8. setr lists lanes from lowest to highest, so
      // lane k receives pixel k.
      const __m256i Step = _mm256_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21);
      __m256i Result = _mm256_setzero_si256();
      for (int i = 0; i < 4096; i += 8) {
        // R = R0 R1 R2 R3 R4 R5 R6 R7
        __m256i R = _mm256_i32gather_epi32(Input + 3 * i, Step, 4);
        // G = G0 G1 G2 G3 G4 G5 G6 G7
        __m256i G = _mm256_i32gather_epi32(Input + 3 * i + 1, Step, 4);
        // B = B0 B1 B2 B3 B4 B5 B6 B7
        __m256i B = _mm256_i32gather_epi32(Input + 3 * i + 2, Step, 4);
        Result = _mm256_add_epi32(Result,
                                  some_parallelizable_algorithm<int32_t>(R, G, B));
      }
      // Horizontal reduction: fold 256 -> 128 bits, then 4 -> 2 -> 1 lanes.
      __m128i Sum = _mm_add_epi32(_mm256_castsi256_si128(Result),
                                  _mm256_extracti128_si256(Result, 1));
      Sum = _mm_add_epi32(Sum, _mm_shuffle_epi32(Sum, _MM_SHUFFLE(1, 0, 3, 2)));
      Sum = _mm_add_epi32(Sum, _mm_shuffle_epi32(Sum, _MM_SHUFFLE(2, 3, 0, 1)));
      return _mm_cvtsi128_si32(Sum);
    }

First, this can be done because there are gather instructions for 32-bit elements, but there are none for 16-bit elements or 8-bit elements. Second, and more importantly, the gather instruction above should be entirely avoided for performance reasons. It is probably more efficient to use contiguous wide loads and shuffle the loaded values to obtain the R, G and B vectors.

    // Sketch of the load-and-shuffle variant: three contiguous 256-bit loads
    // per 8 pixels, followed by a (yet to be determined) shuffle sequence that
    // separates the channels. The `???` placeholders are the open question.
    template <>
    int32_t Process<int32_t>(const int32_t* Input) {
      __m256i Result = _mm256_setzero_si256();
      // Each iteration consumes 8 pixels, i.e. 24 consecutive int32 elements.
      for (int i = 0; i < 4096; i += 8) {
        // Offset in elements first, then reinterpret: casting before the
        // addition would step in 32-byte units instead of 4-byte elements.
        const __m256i* Chunk = reinterpret_cast<const __m256i*>(Input + 3 * i);
        __m256i Ld0 = _mm256_lddqu_si256(Chunk);      // R0 G0 B0 R1 G1 B1 R2 G2
        __m256i Ld1 = _mm256_lddqu_si256(Chunk + 1);  // B2 R3 G3 B3 R4 G4 B4 R5
        __m256i Ld2 = _mm256_lddqu_si256(Chunk + 2);  // G5 B5 R6 G6 B6 R7 G7 B7
        __m256i R = ???
        __m256i G = ???
        __m256i B = ???
        Result = _mm256_add_epi32(Result,
                                  some_parallelizable_algorithm<int32_t>(R, G, B));
      }
      // Here should be the less interesting part:
      // Perform a reduction on Result and return the result
    }

It seems that for power-of-2 strides (2, 4, ...) there are known methods using UNPCKL/UNPCKH, but for stride-3 accesses I could not find any references.

I am interested in solving this for T=int32_t, T=int16_t and T=int8_t, but to remain focused let's only discuss the first case.

2条回答
▲ chillily
2楼-- · 2020-02-05 03:32

The 8-bit integer case.

As already mentioned in the comments above, two-input shuffle instructions, such as vshufps, don't exist for 8-bit granularity. Hence, the 8-bit solution differs a bit from the 32-bit solution. Two different solutions are described below.


A straightforward approach is to group the 8-bit integers 'color by color (R G B)' with 6 vpblendvb-s, followed by a vpshufb permutation:

#include <stdio.h>
#include <x86intrin.h>
/*  gcc -O3 -Wall -m64 -march=broadwell stride_3.c   */
int __attribute__ ((noinline)) print_vec_char(__m256i x);

/* Demo: de-interleave 96 bytes (32 three-byte pixels) into three 32-byte
 * vectors using byte blends (vpblendvb) followed by one in-lane byte shuffle
 * (vpshufb) per output. The buffer is filled with 0..95, so output k must
 * contain exactly the values congruent to k modulo 3, in ascending order.
 * Returns 0 on success, 1 if the buffer cannot be allocated. */
int main() {

    char *m;
    int i;
    /* blnd1 selects the bytes whose in-lane index is 1 mod 3, blnd2 those with index 2 mod 3. */
    __m256i blnd1 = _mm256_set_epi8(0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,     0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0);
    __m256i blnd2 = _mm256_set_epi8(0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,     0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0);
    /* Per-lane permutations that sort the blended bytes into ascending source order. */
    __m256i p0    = _mm256_set_epi8(13,10,7,4,1, 14,11,8,5,2, 15,12,9,6,3,0,    13,10,7,4,1, 14,11,8,5,2, 15,12,9,6,3,0);
    __m256i p1    = _mm256_set_epi8(14,11,8,5,2, 15,12,9,6,3,0, 13,10,7,4,1,    14,11,8,5,2, 15,12,9,6,3,0, 13,10,7,4,1);
    __m256i p2    = _mm256_set_epi8(15,12,9,6,3,0, 13,10,7,4,1, 14,11,8,5,2,    15,12,9,6,3,0, 13,10,7,4,1, 14,11,8,5,2);
            m     = _mm_malloc(96,32);
    if (m == NULL) {
        fprintf(stderr, "_mm_malloc failed\n");
        return 1;
    }
    for(i = 0; i < 96; i++) m[i] = i;

//        printf("m_lo  ");print_vec_char(_mm256_load_si256((__m256i*)&m[0]));printf("m_mid ");print_vec_char(_mm256_load_si256((__m256i*)&m[32]));printf("m_hi  ");print_vec_char(_mm256_load_si256((__m256i*)&m[64]));printf("\n");
//        m_lo   31  30  29  28 |  27  26  25  24 |  23  22  21  20 |  19  18  17  16 ||  15  14  13  12 |  11  10   9   8 |   7   6   5   4 |   3   2   1   0 
//        m_mid  63  62  61  60 |  59  58  57  56 |  55  54  53  52 |  51  50  49  48 ||  47  46  45  44 |  43  42  41  40 |  39  38  37  36 |  35  34  33  32 
//        m_hi   95  94  93  92 |  91  90  89  88 |  87  86  85  84 |  83  82  81  80 ||  79  78  77  76 |  75  74  73  72 |  71  70  69  68 |  67  66  65  64 

    /* Lower lanes take the first 48 bytes, upper lanes the second 48 bytes, so
       both lanes see the same R/G/B phase and can share masks and shuffles. */
    __m256i t0 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i*)&m[0]));
    __m256i t1 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i*)&m[16]));
    __m256i t2 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i*)&m[32]));

            t0 = _mm256_inserti128_si256(t0,_mm_loadu_si128((__m128i*)&m[48]),1);
            t1 = _mm256_inserti128_si256(t1,_mm_loadu_si128((__m128i*)&m[64]),1);
            t2 = _mm256_inserti128_si256(t2,_mm_loadu_si128((__m128i*)&m[80]),1);

//        printf("t0  ");print_vec_char(t0);printf("t1  ");print_vec_char(t1);printf("t2  ");print_vec_char(t2);printf("\n");
//        t0   63  62  61  60 |  59  58  57  56 |  55  54  53  52 |  51  50  49  48 ||  15  14  13  12 |  11  10   9   8 |   7   6   5   4 |   3   2   1   0 
//        t1   79  78  77  76 |  75  74  73  72 |  71  70  69  68 |  67  66  65  64 ||  31  30  29  28 |  27  26  25  24 |  23  22  21  20 |  19  18  17  16 
//        t2   95  94  93  92 |  91  90  89  88 |  87  86  85  84 |  83  82  81  80 ||  47  46  45  44 |  43  42  41  40 |  39  38  37  36 |  35  34  33  32 

    /* Group colour by colour: after the two blends every byte of u_k belongs
       to channel k, but the bytes are not yet in pixel order. */
    __m256i u0 = _mm256_blendv_epi8( _mm256_blendv_epi8(t0,t1,blnd2), t2,blnd1);
    __m256i u1 = _mm256_blendv_epi8( _mm256_blendv_epi8(t1,t2,blnd2), t0,blnd1);
    __m256i u2 = _mm256_blendv_epi8( _mm256_blendv_epi8(t2,t0,blnd2), t1,blnd1);

//        printf("u0  ");print_vec_char(u0);printf("u1  ");print_vec_char(u1);printf("u2  ");print_vec_char(u2);printf("\n");
//        u0   63  78  93  60 |  75  90  57  72 |  87  54  69  84 |  51  66  81  48 ||  15  30  45  12 |  27  42   9  24 |  39   6  21  36 |   3  18  33   0 
//        u1   79  94  61  76 |  91  58  73  88 |  55  70  85  52 |  67  82  49  64 ||  31  46  13  28 |  43  10  25  40 |   7  22  37   4 |  19  34   1  16 
//        u2   95  62  77  92 |  59  74  89  56 |  71  86  53  68 |  83  50  65  80 ||  47  14  29  44 |  11  26  41   8 |  23  38   5  20 |  35   2  17  32 

    /* Restore pixel order inside each 128-bit lane. */
            t0 = _mm256_shuffle_epi8(u0,p0);
            t1 = _mm256_shuffle_epi8(u1,p1);
            t2 = _mm256_shuffle_epi8(u2,p2);

          printf("t0  ");print_vec_char(t0);printf("t1  ");print_vec_char(t1);printf("t2  ");print_vec_char(t2);printf("\n");
//        t0   93  90  87  84 |  81  78  75  72 |  69  66  63  60 |  57  54  51  48 ||  45  42  39  36 |  33  30  27  24 |  21  18  15  12 |   9   6   3   0 
//        t1   94  91  88  85 |  82  79  76  73 |  70  67  64  61 |  58  55  52  49 ||  46  43  40  37 |  34  31  28  25 |  22  19  16  13 |  10   7   4   1 
//        t2   95  92  89  86 |  83  80  77  74 |  71  68  65  62 |  59  56  53  50 ||  47  44  41  38 |  35  32  29  26 |  23  20  17  14 |  11   8   5   2 

    /* Memory from _mm_malloc must be released with _mm_free, not free(). */
    _mm_free(m);
    return 0;
}



/* Prints the 32 bytes of x as signed decimals, most significant byte first.
 * Groups of four bytes are separated by " | " and the two 128-bit halves by
 * " || "; the line ends with " \n". Always returns 0. */
int __attribute__ ((noinline)) print_vec_char(__m256i x){
    char bytes[32];
    _mm256_storeu_si256((__m256i *)bytes, x);
    for (int idx = 31; idx >= 0; --idx) {
        const char *sep;
        if (idx == 16)           sep = " || ";   /* boundary between the halves */
        else if (idx == 0)       sep = " \n";    /* end of the vector           */
        else if (idx % 4 == 0)   sep = " | ";    /* boundary between dwords     */
        else                     sep = " ";
        printf("%3hhi", bytes[idx]);
        fputs(sep, stdout);
    }
    return 0;
}

Instruction summary:

    3 vmovdqu
    3 vinserti128-load
    6 vpblendvb   
    3 vpshufb


Unfortunately, the vpblendvb instruction is often relatively slow: on Intel Skylake vpblendvb has a throughput of one per cycle, and on AMD Ryzen and Intel Haswell the throughput is only one per two cycles. Skylake-X has a fast byte blend, vpblendmb (throughput of three per cycle for 256-bit vectors), although on Skylake-X one might be more interested in a solution that works with 512-bit vectors instead of 256-bit.


An alternative is to combine vpshufb with vshufps, as suggested in @Peter Cordes' comments above. In the code below the data is loaded as 12-byte chunks. Altogether more instructions are needed than in the first solution. Nevertheless, the performance of this second solution is probably better than that of the first, depending on the surrounding code and the microarchitecture.

#include <stdio.h>
#include <x86intrin.h>
/*  gcc -O3 -Wall -m64 -march=broadwell stride_3.c   */
int __attribute__ ((noinline)) print_vec_char(__m256i x);

/* Integer-typed wrapper around vshufps: shuffles the 32-bit elements of a and
 * b exactly like _mm256_shuffle_ps, with casts so that __m256i can be used.
 * Implemented as a macro so that imm always reaches the intrinsic as a
 * compile-time constant (required for the 8-bit immediate, also in
 * unoptimized builds) and so that no out-of-line definition is needed, which
 * a plain C99 `inline` function would not provide. */
#define _mm256_shufps_epi32(a, b, imm) \
    _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), (imm)))

/* Demo: de-interleave 96 bytes (32 three-byte pixels) into three 32-byte
 * vectors using 12-byte chunk loads, one vpshufb per chunk pair, and cheap
 * 32-bit blends/shuffles (vpblendd, vshufps) instead of vpblendvb. The buffer
 * is filled with 0..95, so output k must contain exactly the values congruent
 * to k modulo 3, in ascending order.
 * Returns 0 on success, 1 if the buffer cannot be allocated. */
int main() {

    char *m;
    int i;
    /* Each mask turns one 12-byte chunk (4 pixels) into three dwords, one per
       channel, plus one zeroed dword (-1 index); the dword slots rotate from
       p0 to p3 so that the later 32-bit blends can merge the chunks. */
    __m256i p0    = _mm256_set_epi8(-1,-1,-1,-1, 11,8,5,2, 10,7,4,1, 9,6,3,0,     -1,-1,-1,-1, 11,8,5,2, 10,7,4,1, 9,6,3,0);
    __m256i p1    = _mm256_set_epi8(11,8,5,2, 10,7,4,1, 9,6,3,0, -1,-1,-1,-1,     11,8,5,2, 10,7,4,1, 9,6,3,0, -1,-1,-1,-1);
    __m256i p2    = _mm256_set_epi8(10,7,4,1, 9,6,3,0, -1,-1,-1,-1, 11,8,5,2,     10,7,4,1, 9,6,3,0,-1, -1,-1,-1, 11,8,5,2);
    __m256i p3    = _mm256_set_epi8(9,6,3,0, -1,-1,-1,-1, 11,8,5,2, 10,7,4,1,     9,6,3,0, -1,-1,-1,-1, 11,8,5,2, 10,7,4,1);

            m     = _mm_malloc(96+4,32);   /* 4 extra dummy bytes to avoid errors with _mm_loadu_si128((__m128i*)&m[84]) . Otherwise use maskload instead of standard load */
    if (m == NULL) {
        fprintf(stderr, "_mm_malloc failed\n");
        return 1;
    }
    for(i = 0; i < 96; i++) m[i] = i;
    for(i = 96; i < 96+4; i++) m[i] = 0;   /* define the padding bytes that the last 16-byte load reads */

//        printf("m_lo  ");print_vec_char(_mm256_load_si256((__m256i*)&m[0]));printf("m_mid ");print_vec_char(_mm256_load_si256((__m256i*)&m[32]));printf("m_hi  ");print_vec_char(_mm256_load_si256((__m256i*)&m[64]));printf("\n");
//        m_lo   31  30  29  28 |  27  26  25  24 |  23  22  21  20 |  19  18  17  16 ||  15  14  13  12 |  11  10   9   8 |   7   6   5   4 |   3   2   1   0 
//        m_mid  63  62  61  60 |  59  58  57  56 |  55  54  53  52 |  51  50  49  48 ||  47  46  45  44 |  43  42  41  40 |  39  38  37  36 |  35  34  33  32 
//        m_hi   95  94  93  92 |  91  90  89  88 |  87  86  85  84 |  83  82  81  80 ||  79  78  77  76 |  75  74  73  72 |  71  70  69  68 |  67  66  65  64 

    /* 16-byte loads at a 12-byte pitch: only the low 12 bytes of each load are used. */
    __m256i t0 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i*)&m[0]));
    __m256i t1 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i*)&m[12]));
    __m256i t2 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i*)&m[24]));
    __m256i t3 = _mm256_castsi128_si256(_mm_loadu_si128((__m128i*)&m[36]));

            t0 = _mm256_inserti128_si256(t0,_mm_loadu_si128((__m128i*)&m[48]),1);
            t1 = _mm256_inserti128_si256(t1,_mm_loadu_si128((__m128i*)&m[60]),1);
            t2 = _mm256_inserti128_si256(t2,_mm_loadu_si128((__m128i*)&m[72]),1);
            t3 = _mm256_inserti128_si256(t3,_mm_loadu_si128((__m128i*)&m[84]),1);   /* Use a masked load (_mm_maskload_epi32) here if m[99] is not a valid address */

//        printf("t0  ");print_vec_char(t0);printf("t1  ");print_vec_char(t1);printf("t2  ");print_vec_char(t2);printf("t3  ");print_vec_char(t3);printf("\n");
//        t0   63  62  61  60 |  59  58  57  56 |  55  54  53  52 |  51  50  49  48 ||  15  14  13  12 |  11  10   9   8 |   7   6   5   4 |   3   2   1   0 
//        t1   75  74  73  72 |  71  70  69  68 |  67  66  65  64 |  63  62  61  60 ||  27  26  25  24 |  23  22  21  20 |  19  18  17  16 |  15  14  13  12 
//        t2   87  86  85  84 |  83  82  81  80 |  79  78  77  76 |  75  74  73  72 ||  39  38  37  36 |  35  34  33  32 |  31  30  29  28 |  27  26  25  24 
//        t3    0   0   0   0 |  95  94  93  92 |  91  90  89  88 |  87  86  85  84 ||  51  50  49  48 |  47  46  45  44 |  43  42  41  40 |  39  38  37  36 

    /* Split every chunk into per-channel dwords (4 pixels each). */
            t0 = _mm256_shuffle_epi8(t0,p0);
            t1 = _mm256_shuffle_epi8(t1,p1);
            t2 = _mm256_shuffle_epi8(t2,p2);
            t3 = _mm256_shuffle_epi8(t3,p3);

//        printf("t0  ");print_vec_char(t0);printf("t1  ");print_vec_char(t1);printf("t2  ");print_vec_char(t2);printf("t3  ");print_vec_char(t3);printf("\n");
//        t0    0   0   0   0 |  59  56  53  50 |  58  55  52  49 |  57  54  51  48 ||   0   0   0   0 |  11   8   5   2 |  10   7   4   1 |   9   6   3   0 
//        t1   71  68  65  62 |  70  67  64  61 |  69  66  63  60 |   0   0   0   0 ||  23  20  17  14 |  22  19  16  13 |  21  18  15  12 |   0   0   0   0 
//        t2   82  79  76  73 |  81  78  75  72 |   0   0   0   0 |  83  80  77  74 ||  34  31  28  25 |  33  30  27  24 |   0   0   0   0 |  35  32  29  26 
//        t3   93  90  87  84 |   0   0   0   0 |  95  92  89  86 |  94  91  88  85 ||  45  42  39  36 |   0   0   0   0 |  47  44  41  38 |  46  43  40  37 

    /* Merge chunk pairs dword-wise: u0/u1 collect the first and third channel, u2/u3 the second. */
    __m256i u0 = _mm256_blend_epi32(t0,t1,0b10101010);
    __m256i u1 = _mm256_blend_epi32(t2,t3,0b10101010);
    __m256i u2 = _mm256_blend_epi32(t0,t1,0b01010101);
    __m256i u3 = _mm256_blend_epi32(t2,t3,0b01010101);

//        printf("u0  ");print_vec_char(u0);printf("u1  ");print_vec_char(u1);printf("u2  ");print_vec_char(u2);printf("u3  ");print_vec_char(u3);printf("\n");
//        u0   71  68  65  62 |  59  56  53  50 |  69  66  63  60 |  57  54  51  48 ||  23  20  17  14 |  11   8   5   2 |  21  18  15  12 |   9   6   3   0 
//        u1   93  90  87  84 |  81  78  75  72 |  95  92  89  86 |  83  80  77  74 ||  45  42  39  36 |  33  30  27  24 |  47  44  41  38 |  35  32  29  26 
//        u2    0   0   0   0 |  70  67  64  61 |  58  55  52  49 |   0   0   0   0 ||   0   0   0   0 |  22  19  16  13 |  10   7   4   1 |   0   0   0   0 
//        u3   82  79  76  73 |   0   0   0   0 |   0   0   0   0 |  94  91  88  85 ||  34  31  28  25 |   0   0   0   0 |   0   0   0   0 |  46  43  40  37 

    /* Final dword selection: one instruction per output channel. */
            t0 = _mm256_blend_epi32(u0,u1,0b11001100);
            t1 = _mm256_shufps_epi32(u2,u3,0b00111001);
            t2 = _mm256_shufps_epi32(u0,u1,0b01001110);

        printf("t0  ");print_vec_char(t0);printf("t1  ");print_vec_char(t1);printf("t2  ");print_vec_char(t2);printf("\n");
//        t0   93  90  87  84 |  81  78  75  72 |  69  66  63  60 |  57  54  51  48 ||  45  42  39  36 |  33  30  27  24 |  21  18  15  12 |   9   6   3   0 
//        t1   94  91  88  85 |  82  79  76  73 |  70  67  64  61 |  58  55  52  49 ||  46  43  40  37 |  34  31  28  25 |  22  19  16  13 |  10   7   4   1 
//        t2   95  92  89  86 |  83  80  77  74 |  71  68  65  62 |  59  56  53  50 ||  47  44  41  38 |  35  32  29  26 |  23  20  17  14 |  11   8   5   2 

    /* Memory from _mm_malloc must be released with _mm_free, not free(). */
    _mm_free(m);
    return 0;
}



/* Dumps x as 32 signed byte values, highest byte first: four-byte groups are
 * joined by " | ", the upper and lower 128-bit halves by " || ", and the
 * output is terminated by " \n". Always returns 0. */
int __attribute__ ((noinline)) print_vec_char(__m256i x){
    char lanes[32];
    _mm256_storeu_si256((__m256i *)lanes, x);
    for (int half = 1; half >= 0; --half) {
        for (int group = 3; group >= 0; --group) {
            /* q points at the least significant byte of this 4-byte group */
            const char *q = &lanes[16 * half + 4 * group];
            printf("%3hhi %3hhi %3hhi %3hhi", q[3], q[2], q[1], q[0]);
            if (group > 0) {
                fputs(" | ", stdout);
            }
        }
        fputs(half ? " || " : " \n", stdout);
    }
    return 0;
}

Instruction summary:

    4 vmovdqu
    4 vinserti128-load
    4 vpshufb
    5 vpblendd    (vpblendd is much faster than vpblendvb on most cpu architectures)
    2 vshufps

It is easy to adapt the ideas of these methods to the 16-bit case.

查看更多
Bombasti
3楼-- · 2020-02-05 03:40

This article from Intel describes how to do exactly the 3x8 case that you want.

That article covers the float case. If you want int32, you'll need to cast the outputs since there's no integer version of _mm256_shuffle_ps().

enter image description here

Copying their solution verbatim:

float *p;  // address of first vector
__m128 *m = (__m128*) p;

// Memory holds x0 y0 z0 x1 | y1 z1 x2 y2 | z2 x3 y3 z3 | x4 ... as six
// 128-bit chunks. Each register pairs chunk k (lower half) with chunk k+3
// (upper half), so both halves see the same x/y/z phase.
__m256 m03 = _mm256_insertf128_ps(_mm256_castps128_ps256(m[0]), m[3], 1);
__m256 m14 = _mm256_insertf128_ps(_mm256_castps128_ps256(m[1]), m[4], 1);
__m256 m25 = _mm256_insertf128_ps(_mm256_castps128_ps256(m[2]), m[5], 1);

// Intermediate pairs, per 128-bit half:
//   xy_hi = x2 y2 x3 y3   (upper x's and y's)
//   yz_lo = y0 z0 y1 z1   (lower y's and z's)
__m256 xy_hi = _mm256_shuffle_ps(m14, m25, _MM_SHUFFLE( 2,1,3,2));
__m256 yz_lo = _mm256_shuffle_ps(m03, m14, _MM_SHUFFLE( 1,0,2,1));

// Final selection: one shuffle per output component.
__m256 x = _mm256_shuffle_ps(m03  , xy_hi, _MM_SHUFFLE( 2,0,3,0));  // x0 x1 x2 x3 | x4 x5 x6 x7
__m256 y = _mm256_shuffle_ps(yz_lo, xy_hi, _MM_SHUFFLE( 3,1,2,0));  // y0 y1 y2 y3 | y4 y5 y6 y7
__m256 z = _mm256_shuffle_ps(yz_lo, m25  , _MM_SHUFFLE( 3,0,3,1));  // z0 z1 z2 z3 | z4 z5 z6 z7

So this is 11 instructions. (6 loads, 5 shuffles)


In the general case, it's possible to do an S x W transpose in O(S*log(W)) instructions. Where:

  • S is the stride
  • W is the SIMD width

Assuming the existence of 2-vector permutes and half-vector insert-loads, then the formula becomes:

(S x W load-permute)  <=  S * (lg(W) + 1) instructions

Ignoring reg-reg moves. For degenerate cases like the 3 x 4, it may be possible to do better.

Here's the 3 x 16 load-transpose with AVX512: (6 loads, 3 shuffles, 6 blends)

// Loads 48 consecutive floats (16 interleaved triples) from T and de-interleaves
// them: r0 receives elements 0,3,6,...,45, r1 elements 1,4,...,46 and r2
// elements 2,5,...,47. Uses 6 loads, 3 two-source shuffles and 6 blends.
// T needs no particular alignment: every load is an unaligned load.
// The numeric comments show the source index held by each lane after a stage.
FORCE_INLINE void transpose_f32_16x3_forward_AVX512(
    const float T[48],
    __m512& r0, __m512& r1, __m512& r2
){
    __m512 a0, a1, a2;

    //   0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
    //  16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
    //  32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47

    // Lower halves: 256-bit chunks 0..2; upper halves: chunks 3..5.
    // The upper halves go through _mm256_loadu_ps as well, instead of
    // dereferencing T as __m256, which would assume 32-byte alignment.
    a0 = _mm512_castps256_ps512(_mm256_loadu_ps(T +  0));
    a1 = _mm512_castps256_ps512(_mm256_loadu_ps(T +  8));
    a2 = _mm512_castps256_ps512(_mm256_loadu_ps(T + 16));
    a0 = _mm512_insertf32x8(a0, _mm256_loadu_ps(T + 24), 1);
    a1 = _mm512_insertf32x8(a1, _mm256_loadu_ps(T + 32), 1);
    a2 = _mm512_insertf32x8(a2, _mm256_loadu_ps(T + 40), 1);

    //   0  1  2  3  4  5  6  7 24 25 26 27 28 29 30 31
    //   8  9 10 11 12 13 14 15 32 33 34 35 36 37 38 39
    //  16 17 18 19 20 21 22 23 40 41 42 43 44 45 46 47

    // Stage 1: exchange groups of four elements.
    r0 = _mm512_mask_blend_ps(0xf0f0, a0, a1);
    r1 = _mm512_permutex2var_ps(a0, _mm512_setr_epi32(  4,  5,  6,  7, 16, 17, 18, 19, 12, 13, 14, 15, 24, 25, 26, 27), a2);
    r2 = _mm512_mask_blend_ps(0xf0f0, a1, a2);

    //   0  1  2  3 12 13 14 15 24 25 26 27 36 37 38 39
    //   4  5  6  7 16 17 18 19 28 29 30 31 40 41 42 43
    //   8  9 10 11 20 21 22 23 32 33 34 35 44 45 46 47

    // Stage 2: exchange pairs of elements.
    a0 = _mm512_mask_blend_ps(0xcccc, r0, r1);
    a1 = _mm512_shuffle_ps(r0, r2, 78);
    a2 = _mm512_mask_blend_ps(0xcccc, r1, r2);

    //   0  1  6  7 12 13 18 19 24 25 30 31 36 37 42 43
    //   2  3  8  9 14 15 20 21 26 27 32 33 38 39 44 45
    //   4  5 10 11 16 17 22 23 28 29 34 35 40 41 46 47

    // Stage 3: exchange single elements.
    r0 = _mm512_mask_blend_ps(0xaaaa, a0, a1);
    r1 = _mm512_permutex2var_ps(a0, _mm512_setr_epi32(  1,  16,  3, 18,  5, 20,  7, 22,  9, 24, 11, 26, 13, 28, 15, 30), a2);
    r2 = _mm512_mask_blend_ps(0xaaaa, a1, a2);

    //   0  3  6  9 12 15 18 21 24 27 30 33 36 39 42 45
    //   1  4  7 10 13 16 19 22 25 28 31 34 37 40 43 46
    //   2  5  8 11 14 17 20 23 26 29 32 35 38 41 44 47
}

The inverse 3 x 16 transpose-store will be left as an exercise to the reader.

The pattern is not at all trivial to see since S = 3 is somewhat degenerate. But if you can see the pattern, you'll be able to generalize this to any odd integer S as well as any power-of-two W.

查看更多
登录 后发表回答