Related to this answer: https://stackoverflow.com/a/11227902/4714970
The linked answer mentions how you can avoid branch prediction failures by avoiding branches altogether. The user demonstrates this by replacing:
if (data[c] >= 128)
{
    sum += data[c];
}
With:
int t = (data[c] - 128) >> 31;
sum += ~t & data[c];
How are these two equivalent (for the specific data set, not strictly equivalent)? What are some general ways I can do similar things in similar situations? Would it always be by using >> and ~?
int t = (data[c] - 128) >> 31;
The trick here is that if data[c] >= 128, then data[c] - 128 is nonnegative, and otherwise it is negative. The highest bit in an int, the sign bit, is 1 if and only if that number is negative. >> is a shift that extends the sign bit, so shifting right by 31 makes the whole result 0 if it used to be nonnegative, and all 1 bits (which represents -1) if it used to be negative. So t is 0 if data[c] >= 128, and -1 otherwise. ~t switches these possibilities, so ~t is -1 if data[c] >= 128, and 0 otherwise.
x & (-1) is always equal to x, and x & 0 is always equal to 0. So sum += ~t & data[c] increases sum by 0 if data[c] < 128, and by data[c] otherwise.
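Putting the pieces together, a minimal C sketch of the two loops might look like this (the function names are mine; the shift assumes the arithmetic right shift of negative values that mainstream compilers provide, and that the values are small, like the 0..255 bytes in the linked question, so data[c] - 128 cannot overflow):

#include <stddef.h>

/* Branchy version: the CPU has to predict the comparison for every element. */
long sum_branchy(const int *data, size_t n)
{
    long sum = 0;
    for (size_t c = 0; c < n; c++) {
        if (data[c] >= 128)
            sum += data[c];
    }
    return sum;
}

/* Branchless version: t is 0 when data[c] >= 128 and -1 (all ones) otherwise,
 * so ~t & data[c] is either data[c] or 0.  Assumes arithmetic right shift. */
long sum_branchless(const int *data, size_t n)
{
    long sum = 0;
    for (size_t c = 0; c < n; c++) {
        int t = (data[c] - 128) >> 31;
        sum += ~t & data[c];
    }
    return sum;
}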
Many of these tricks can be applied elsewhere. This trick can certainly be applied generally to get a number that is 0 if and only if one value is greater than or equal to another value, and -1 otherwise, and you can mess with it some more to get <=, <, and so on. Bit twiddling like this is a common approach to making mathematical operations branch-free, though it's certainly not always going to be built out of the same operations; ^ (xor) and | (or) also come into play sometimes.
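For example, here is a small sketch (the helper names are my own, not from the answer) of how the same sign-bit trick rearranges into masks for the other comparisons, under the same caveats about arithmetic right shift and no overflow in the subtraction:

/* -1 (all ones) if a < b, 0 otherwise */
static inline int mask_lt(int a, int b) { return (a - b) >> 31; }

/* -1 if a >= b, 0 otherwise */
static inline int mask_ge(int a, int b) { return ~((a - b) >> 31); }

/* -1 if a <= b, 0 otherwise: a <= b is the same as !(b < a) */
static inline int mask_le(int a, int b) { return ~((b - a) >> 31); }

With these, the loop body from the question is just sum += mask_ge(data[c], 128) & data[c];.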
While Louis Wasserman's answer is correct, I want to show you a more general (and much clearer) way to write branchless code. You can just use the ? : operator:
int t = data[c];
sum += (t >= 128 ? t : 0);
The JIT compiler sees from the execution profile that the condition is poorly predicted here. In such cases the compiler is smart enough to replace the conditional branch with a conditional move instruction:
mov 0x10(%r14,%rbp,4),%r9d ; load R9d from array
cmp $0x80,%r9d ; compare with 128
cmovl %r8d,%r9d ; if less, move R8d (which is 0) to R9d
You can verify for yourself that this version runs equally fast for both sorted and unsorted arrays.
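That answer is about the JVM's JIT, but the same effect is easy to observe with an ahead-of-time compiler as well; here is a rough C timing sketch (illustrative only, not a rigorous benchmark) that times the ternary version on the same data before and after sorting:

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#define N    (1 << 20)
#define REPS 100

static int cmp_int(const void *a, const void *b)
{
    int x = *(const int *)a, y = *(const int *)b;
    return (x > y) - (x < y);
}

/* Sums the elements >= 128 using the ternary form shown above. */
static long run(const int *data)
{
    long sum = 0;
    for (int i = 0; i < N; i++) {
        int t = data[i];
        sum += (t >= 128 ? t : 0);
    }
    return sum;
}

int main(void)
{
    static int data[N];
    long s1 = 0, s2 = 0;

    for (int i = 0; i < N; i++)
        data[i] = rand() % 256;

    clock_t t0 = clock();
    for (int r = 0; r < REPS; r++)
        s1 += run(data);                    /* unsorted */
    clock_t t1 = clock();

    qsort(data, N, sizeof data[0], cmp_int);

    clock_t t2 = clock();
    for (int r = 0; r < REPS; r++)
        s2 += run(data);                    /* sorted */
    clock_t t3 = clock();

    printf("unsorted: %ld in %.3fs, sorted: %ld in %.3fs\n",
           s1, (double)(t1 - t0) / CLOCKS_PER_SEC,
           s2, (double)(t3 - t2) / CLOCKS_PER_SEC);
    return 0;
}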
Branchless code typically means evaluating all possible outcomes E_i of a conditional statement, each weighted by a weight w_i from the set [0, 1] so that Sum{ w_i } = 1. Most of the computed results are essentially discarded. Some optimization can result from the fact that E_i doesn't have to be correct when the corresponding weight w_i (or mask m_i) is zero.
result = (w_0 * E_0) + (w_1 * E_1) + ... + (w_n * E_n) ;; or
result = (m_0 & E_0) | (m_1 & E_1) | ... | (m_n & E_n)
where m_i stands for a bitmask.
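As a concrete sketch of the mask form with just two outcomes (the function name is my own):

/* Branch-free two-way select: returns e0 when cond is non-zero, e1 otherwise.
 * m0 is all ones when cond holds and m1 is its complement, so exactly one of
 * the two AND terms survives the OR (assuming 2's complement). */
static inline int select2(int cond, int e0, int e1)
{
    int m0 = -(cond != 0);   /* 0 or -1 (all ones) */
    int m1 = ~m0;
    return (m0 & e0) | (m1 & e1);
}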
Speed can also be gained through parallel processing of the E_i followed by a horizontal collapse.
This contrasts with the semantics of if (a) b; else c; or its ternary shorthand a ? b : c, where only one of the expressions b and c is evaluated.
Thus the ternary operator is no magic bullet for branchless code. A decent compiler produces branchless code equally well for
t = data[n];
if (t >= 128) sum+=t;
as for the ternary form shown earlier, compiling both to something like:
movl -4(%rdi,%rdx), %ecx
leal (%rax,%rcx), %esi
addl $-128, %ecx
cmovge %esi, %eax
Variations of branchless code include expressing the problem through other branchless non-linear functions, such as ABS, if available on the target machine.
e.g.
2 * min(a,b) = a + b - ABS(a - b),
2 * max(a,b) = a + b + ABS(a - b)
or even:
ABS(x) = sqrt(x*x) ;; caveat -- this is "probably" not efficient
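A quick C sketch of those identities (the names are mine; it assumes a + b and a - b don't overflow, and uses abs(), which many targets can compute without a branch):

#include <stdlib.h>   /* abs */

/* 2*min(a,b) = a + b - ABS(a - b), 2*max(a,b) = a + b + ABS(a - b) */
static inline int min_branchless(int a, int b) { return (a + b - abs(a - b)) / 2; }
static inline int max_branchless(int a, int b) { return (a + b + abs(a - b)) / 2; }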
In addition to >> and ~, it may be equally beneficial to use bool and !bool instead of the (possibly undefined) (int >> 31). Likewise, if a condition evaluates to [0, 1], one can generate a working mask with negation:
-[0, 1] = [0, 0xffffffff] in 2's complement representation
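For instance, the running example can be written with such a boolean-derived mask (a sketch of my own, not code from the answer; in C the comparison yields 0 or 1, and negating it in unsigned arithmetic is well defined):

#include <stddef.h>

long sum_ge128_boolmask(const int *data, size_t n)
{
    long sum = 0;
    for (size_t c = 0; c < n; c++) {
        unsigned mask = -(unsigned)(data[c] >= 128);   /* 0 or 0xffffffff */
        sum += (int)(mask & (unsigned)data[c]);        /* data[c] or 0    */
    }
    return sum;
}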