If I have 8 packed 32-bit floating point numbers (__m256), what's the fastest way to extract the horizontal sum of all 8 elements? Similarly, how can I obtain the horizontal maximum and minimum? In other words, what's the best implementation for the following C++ functions?
float sum(__m256 x); ///< returns sum of all 8 elements
float max(__m256 x); ///< returns the maximum of all 8 elements
float min(__m256 x); ///< returns the minimum of all 8 elements
Quickly jotted down here (and hence untested):
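float sum(__m256 x) {
    __m128 hi = _mm256_extractf128_ps(x, 1);  // elements 4..7
    __m128 lo = _mm256_extractf128_ps(x, 0);  // elements 0..3
    lo = _mm_add_ps(hi, lo);                  // four partial sums
    hi = _mm_movehl_ps(hi, lo);               // bring partial sums 2,3 down to 0,1
    lo = _mm_add_ps(hi, lo);                  // two partial sums in elements 0,1
    hi = _mm_shuffle_ps(lo, lo, 1);           // element 1 into element 0
    lo = _mm_add_ss(hi, lo);                  // total in element 0
    return _mm_cvtss_f32(lo);
}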
For min/max, replace _mm_add_ps and _mm_add_ss with _mm_max_* or _mm_min_*.
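As a sketch of that substitution, the function below runs the same shuffle sequence as sum() with max in place of add, and a second one does the same with min. The names hmax8 and hmin8, the pointer-taking signatures and the AVX_FN macro are assumptions made for this illustration (the macro only lets the snippet build on GCC/Clang without -mavx). NaN inputs are not considered.

```cpp
#include <immintrin.h>
#include <cassert>

// Illustration-only helper: enables AVX per function on GCC/Clang when -mavx is not given.
#if defined(__GNUC__) && !defined(__AVX__)
#  define AVX_FN __attribute__((target("avx")))
#else
#  define AVX_FN
#endif

// Horizontal maximum of the 8 floats at p (hypothetical name).
AVX_FN float hmax8(const float* p) {
    assert(p != nullptr);
    __m256 x  = _mm256_loadu_ps(p);
    __m128 hi = _mm256_extractf128_ps(x, 1);  // elements 4..7
    __m128 lo = _mm256_castps256_ps128(x);    // elements 0..3
    lo = _mm_max_ps(hi, lo);                  // four candidates
    hi = _mm_movehl_ps(hi, lo);               // candidates 2,3 down to 0,1
    lo = _mm_max_ps(hi, lo);                  // two candidates in elements 0,1
    hi = _mm_shuffle_ps(lo, lo, 1);           // element 1 into element 0
    lo = _mm_max_ss(hi, lo);                  // maximum in element 0
    return _mm_cvtss_f32(lo);
}

// Horizontal minimum of the 8 floats at p (hypothetical name).
AVX_FN float hmin8(const float* p) {
    assert(p != nullptr);
    __m256 x  = _mm256_loadu_ps(p);
    __m128 hi = _mm256_extractf128_ps(x, 1);
    __m128 lo = _mm256_castps256_ps128(x);
    lo = _mm_min_ps(hi, lo);
    hi = _mm_movehl_ps(hi, lo);
    lo = _mm_min_ps(hi, lo);
    hi = _mm_shuffle_ps(lo, lo, 1);
    lo = _mm_min_ss(hi, lo);                  // minimum in element 0
    return _mm_cvtss_f32(lo);
}
```

Called on an array such as {3, -1, 7.5, 0, 2, 9, -4, 1}, hmax8 yields 9 and hmin8 yields -4.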
Note that this is a lot of work for a few operations; AVX isn't really intended to do horizontal operations efficiently. If you can batch up this work for multiple vectors, then more efficient solutions are possible.
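One possible way to batch the work, sketched under my own assumptions rather than taken from the answer above: reduce eight vectors at once, so that a single result vector holds all eight horizontal sums. The version below uses six vhaddps, two vperm2f128 and one vaddps for the eight sums. The name hsum_rows8, the row-major 8x8 input layout and the AVX_FN macro are hypothetical choices for this illustration.

```cpp
#include <immintrin.h>
#include <cassert>

// Illustration-only helper: enables AVX per function on GCC/Clang when -mavx is not given.
#if defined(__GNUC__) && !defined(__AVX__)
#  define AVX_FN __attribute__((target("avx")))
#else
#  define AVX_FN
#endif

// m points to 8 rows of 8 floats (row-major); out[i] receives the sum of row i.
AVX_FN void hsum_rows8(const float* m, float* out) {
    assert(m != nullptr && out != nullptr);
    __m256 r[8];
    for (int i = 0; i < 8; ++i) r[i] = _mm256_loadu_ps(m + 8 * i);

    // First level: pairwise sums, two rows interleaved per vector.
    __m256 s01 = _mm256_hadd_ps(r[0], r[1]);
    __m256 s23 = _mm256_hadd_ps(r[2], r[3]);
    __m256 s45 = _mm256_hadd_ps(r[4], r[5]);
    __m256 s67 = _mm256_hadd_ps(r[6], r[7]);

    // Second level: each 128-bit half holds 4-element partial sums of four rows.
    __m256 t0123 = _mm256_hadd_ps(s01, s23);
    __m256 t4567 = _mm256_hadd_ps(s45, s67);

    // Gather the lower-half partial sums of all rows, and likewise the upper-half ones.
    __m256 lo = _mm256_permute2f128_ps(t0123, t4567, 0x20);
    __m256 hi = _mm256_permute2f128_ps(t0123, t4567, 0x31);

    _mm256_storeu_ps(out, _mm256_add_ps(lo, hi));  // out[i] = sum of row i
}
```

Whether this is actually faster than eight separate reductions depends on the surrounding code and the CPU, so it is worth measuring.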
While Stephen Canon's answer is probably ideal for finding the horizontal maximum/minimum, I think a better solution can be found for the horizontal sum.
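float horizontal_add (__m256 a) {
    __m256 t1 = _mm256_hadd_ps(a,a);          // pairwise sums within each 128-bit half
    __m256 t2 = _mm256_hadd_ps(t1,t1);        // each half now holds its own 4-element sum
    __m128 t3 = _mm256_extractf128_ps(t2,1);  // upper-half sum
    __m128 t4 = _mm_add_ss(_mm256_castps256_ps128(t2),t3);  // lower + upper
    return _mm_cvtss_f32(t4);
}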
I tried to write code that avoids mixing AVX and non-AVX instructions. The horizontal sum of an AVX register containing floats can be done AVX-only with
- 1x vperm2f128,
- 2x vshufps and
- 3x vaddps,
resulting in a register where all entries contain the sum of all elements in the original register.
// permute
// 4, 5, 6, 7, 0, 1, 2, 3
// add
// 0+4, 1+5, 2+6, 3+7, 4+0, 5+1, 6+2, 7+3
// shuffle
// 1+5, 0+4, 3+7, 2+6, 5+1, 4+0, 7+3, 6+2
// add
// 1+5+0+4, 0+4+1+5, 3+7+2+6, 2+6+3+7,
// 5+1+4+0, 4+0+5+1, 7+3+6+2, 6+2+7+3
// shuffle
// 3+7+2+6, 2+6+3+7, 1+5+0+4, 0+4+1+5,
// 7+3+6+2, 6+2+7+3, 5+1+4+0, 4+0+5+1
// add
// 3+7+2+6+1+5+0+4, 2+6+3+7+0+4+1+5, 1+5+0+4+3+7+2+6, 0+4+1+5+2+6+3+7,
// 7+3+6+2+5+1+4+0, 6+2+7+3+4+0+5+1, 5+1+4+0+7+3+6+2, 4+0+5+1+6+2+7+3
static inline __m256 hsums(__m256 const& v)
{
    auto x = _mm256_permute2f128_ps(v, v, 1);              // swap the 128-bit halves
    auto y = _mm256_add_ps(v, x);
    x = _mm256_shuffle_ps(y, y, _MM_SHUFFLE(2, 3, 0, 1));  // swap neighbours within each half
    x = _mm256_add_ps(x, y);
    y = _mm256_shuffle_ps(x, x, _MM_SHUFFLE(1, 0, 3, 2));  // swap pairs within each half
    return _mm256_add_ps(x, y);                            // every element holds the total
}
Obtaining the value is then easy by using _mm256_castps256_ps128 and _mm_cvtss_f32:
static inline float hadd(__m256 const& v)
{
    return _mm_cvtss_f32(_mm256_castps256_ps128(hsums(v)));
}
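Because hsums leaves the total in every element, the result can also be consumed directly as a vector. As a hedged illustration, the sketch below divides a vector by its own sum without extracting a scalar first. The function normalize8, its pointer-based signature and the AVX_FN macro are assumptions made for this example. hsums is repeated so the snippet is self-contained, and it takes its argument by value here.

```cpp
#include <immintrin.h>
#include <cassert>

// Illustration-only helper: enables AVX per function on GCC/Clang when -mavx is not given.
#if defined(__GNUC__) && !defined(__AVX__)
#  define AVX_FN __attribute__((target("avx")))
#else
#  define AVX_FN
#endif

// Same operation sequence as hsums above: every element of the result holds the total.
AVX_FN static inline __m256 hsums(__m256 v) {
    __m256 x = _mm256_permute2f128_ps(v, v, 1);
    __m256 y = _mm256_add_ps(v, x);
    x = _mm256_shuffle_ps(y, y, _MM_SHUFFLE(2, 3, 0, 1));
    x = _mm256_add_ps(x, y);
    y = _mm256_shuffle_ps(x, x, _MM_SHUFFLE(1, 0, 3, 2));
    return _mm256_add_ps(x, y);
}

// Hypothetical example: out[i] = in[i] / (in[0] + ... + in[7]).
AVX_FN void normalize8(const float* in, float* out) {
    assert(in != nullptr && out != nullptr);
    __m256 v = _mm256_loadu_ps(in);
    __m256 total = hsums(v);                        // broadcast sum, no scalar extraction
    _mm256_storeu_ps(out, _mm256_div_ps(v, total));
}
```

With the input {1, 1, 2, 4, 8, 16, 16, 16}, whose sum is 64, each output element is the corresponding input divided by 64.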
I did some basic benchmarks against the other solutions with __rdtscp and did not find one to be superior in terms of mean CPU cycle count on my Intel i5-2500K.
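For readers who want to repeat such a measurement, here is a rough sketch of a timing loop around __rdtscp. It is not the harness used for the numbers in this answer. The helper mean_ticks, the function hsum8 and the AVX_FN macro are hypothetical names for this illustration. Note that the timestamp counter ticks at a fixed reference rate, so the result only approximates core cycles, and it includes the load and the call overhead.

```cpp
#ifdef _MSC_VER
#  include <intrin.h>
#else
#  include <x86intrin.h>
#endif
#include <immintrin.h>
#include <cassert>

// Illustration-only helper: enables AVX per function on GCC/Clang when -mavx is not given.
#if defined(__GNUC__) && !defined(__AVX__)
#  define AVX_FN __attribute__((target("avx")))
#else
#  define AVX_FN
#endif

// One candidate under test: the extract/add/shuffle reduction (hypothetical wrapper).
AVX_FN float hsum8(const float* p) {
    __m256 x  = _mm256_loadu_ps(p);
    __m128 hi = _mm256_extractf128_ps(x, 1);
    __m128 lo = _mm256_castps256_ps128(x);
    lo = _mm_add_ps(hi, lo);
    hi = _mm_movehl_ps(hi, lo);
    lo = _mm_add_ps(hi, lo);
    hi = _mm_shuffle_ps(lo, lo, 1);
    lo = _mm_add_ss(hi, lo);
    return _mm_cvtss_f32(lo);
}

// Mean timestamp-counter ticks per call of fn over reps calls.
double mean_ticks(float (*fn)(const float*), const float* data, int reps) {
    assert(reps > 0);
    volatile float sink = 0.0f;               // keeps the calls from being optimised away
    unsigned int aux = 0;
    unsigned long long t0 = __rdtscp(&aux);
    for (int i = 0; i < reps; ++i) sink = fn(data);
    unsigned long long t1 = __rdtscp(&aux);
    (void)sink;
    return static_cast<double>(t1 - t0) / reps;
}
```

A call such as mean_ticks(hsum8, data, 1000000), repeated for each candidate, gives comparable averages as long as all candidates share the same signature.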
Looking at Agner Fog's instruction tables, I found (for Sandy Bridge processors):

                    µops   lat.   1/tp   count
this:
    vperm2f128        1      2      1      1
    vaddps            1      3      1      3
    vshufps           1      1      1      2
    sum               6     13      6      6
Z boson:
    vhaddps           3      5      2      2
    vextractf128      1      2      1      1
    addss             1      3      1      1
    sum               8     15      6      4
Stephen Canon:
    vextractf128      1      2      1      1
    addps             1      3      1      2
    movhlps           1      1      1      1
    shufps            1      1      1      1
    addss             1      3      1      1
    sum               6     13      6      6

(Each sum row adds the per-instruction value multiplied by its count.)
To me, none of these is clearly superior: the values are rather similar, and I cannot foresee whether instruction count, µop count, latency or throughput matters most.
Edit, note: the potential problem I assumed to exist in the following does not actually exist.
I suspected that, if having the result in the ymm register is sufficient, my hsums could be useful, as it doesn't require vzeroupper to prevent a state-switching penalty and can thus interleave / execute concurrently with other AVX computations using different registers without introducing some kind of sequence point.
// Note: the anonymous struct and the brace initialisers rely on compiler extensions.
union ymm {
    __m256 m256;
    struct {
        __m128 m128lo;
        __m128 m128hi;
    };
};
union ymm result = {1,2,3,4,5,6,7,8};
__m256 a = {9,10,11,12,13,14,15,16};
result.m256 = _mm256_add_ps (result.m256, a);
// Three horizontal adds: afterwards element 0 of m128lo holds the sum of all 8 elements.
result.m128lo = _mm_hadd_ps (result.m128lo, result.m128hi);
result.m128lo = _mm_hadd_ps (result.m128lo, result.m128hi);
result.m128lo = _mm_hadd_ps (result.m128lo, result.m128hi);
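The same three-hadd reduction can be written without the union, which avoids relying on compiler extensions. The following is only a sketch under my own assumptions: the name hsum8_hadd, the pointer-taking signature and the AVX_FN macro are hypothetical, and the two halves are obtained with a cast and an extract instead of union members.

```cpp
#include <immintrin.h>
#include <cassert>

// Illustration-only helper: enables AVX per function on GCC/Clang when -mavx is not given.
#if defined(__GNUC__) && !defined(__AVX__)
#  define AVX_FN __attribute__((target("avx")))
#else
#  define AVX_FN
#endif

// Sum of the 8 floats at p using three 128-bit horizontal adds (hypothetical name).
AVX_FN float hsum8_hadd(const float* p) {
    assert(p != nullptr);
    __m256 v  = _mm256_loadu_ps(p);
    __m128 lo = _mm256_castps256_ps128(v);    // v0..v3
    __m128 hi = _mm256_extractf128_ps(v, 1);  // v4..v7
    lo = _mm_hadd_ps(lo, hi);  // {v0+v1, v2+v3, v4+v5, v6+v7}
    lo = _mm_hadd_ps(lo, hi);  // element 0: v0+v1+v2+v3, element 1: v4+v5+v6+v7
    lo = _mm_hadd_ps(lo, hi);  // element 0: sum of all eight
    return _mm_cvtss_f32(lo);  // only element 0 is meaningful
}
```

For the element-wise sum of the two vectors in the snippet above, {10, 12, 14, 16, 18, 20, 22, 24}, this returns 136.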