How can I make branchless code?

Published 2020-02-17 06:50

Question:

Related to this answer: https://stackoverflow.com/a/11227902/4714970

The answer above mentions that you can avoid branch mispredictions by avoiding branches altogether.

The user demonstrates this by replacing:

if (data[c] >= 128)
{
    sum += data[c];
}

With:

int t = (data[c] - 128) >> 31;
sum += ~t & data[c];

How are these two equivalent (for the specific data set, not strictly equivalent)?

What are some general ways I can do similar things in similar situations? Would it always be by using >> and ~?

Answer 1:

int t = (data[c] - 128) >> 31;

The trick here is that if data[c] >= 128, then data[c] - 128 is nonnegative, otherwise it is negative. The highest bit in an int, the sign bit, is 1 if and only if that number is negative. >> is a shift that extends the sign bit, so shifting right by 31 makes the whole result 0 if it used to be nonnegative, and all 1 bits (which represents -1) if it used to be negative. So t is 0 if data[c] >= 128, and -1 otherwise. ~t switches these possibilities, so ~t is -1 if data[c] >= 128, and 0 otherwise.

x & (-1) is always equal to x, and x & 0 is always equal to 0. So sum += ~t & data[c] increases sum by 0 if data[c] < 128, and by data[c] otherwise.
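To make the equivalence concrete, here is a minimal sketch that runs both forms over the same array. The class and method names (`BranchlessSum`, `sumBranchy`, `sumMasked`) are made up for illustration, and the sample values are assumed to lie in 0..255 so that `data[c] - 128` cannot overflow.

```java
class BranchlessSum {
    // Straightforward version with a conditional branch.
    static long sumBranchy(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            if (data[c] >= 128) {
                sum += data[c];
            }
        }
        return sum;
    }

    // Mask version from the question. Assumes data[c] - 128 never overflows
    // (true for the 0..255 values used here).
    static long sumMasked(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            int t = (data[c] - 128) >> 31; // 0 if data[c] >= 128, otherwise -1
            sum += ~t & data[c];           // data[c] if data[c] >= 128, otherwise 0
        }
        return sum;
    }

    public static void main(String[] args) {
        int[] data = {0, 127, 128, 255, 64, 200};
        System.out.println(sumBranchy(data)); // prints 583
        System.out.println(sumMasked(data));  // prints 583
    }
}
```

Only 128, 255 and 200 pass the threshold, so both methods return 583.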

Many of these tricks can be applied elsewhere. This trick can certainly be generally applied to have a number be 0 if and only if one value is greater than or equal to another value, and -1 otherwise, and you can mess with it some more to get <=, <, and so on. Bit twiddling like this is a common approach to making mathematical operations branch-free, though it's certainly not always going to be built out of the same operations; ^ (xor) and | (or) also come into play sometimes.
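As a rough illustration of that generalization, the sketch below derives masks for the other comparisons from the same subtract-and-shift idea, then uses xor to build a branch-free min and max. All names are hypothetical, the masks here are -1 when the comparison holds (the opposite convention to `t` above), and every helper assumes `a - b` (or `b - a`) does not overflow.

```java
class CompareMasks {
    // Each mask is -1 (all one bits) when the comparison holds and 0 otherwise.
    // Assumption: the subtraction does not overflow.
    static int ltMask(int a, int b) { return (a - b) >> 31; }    // a <  b
    static int geMask(int a, int b) { return ~((a - b) >> 31); } // a >= b
    static int gtMask(int a, int b) { return (b - a) >> 31; }    // a >  b
    static int leMask(int a, int b) { return ~((b - a) >> 31); } // a <= b

    // b ^ (a ^ b) == a, so the mask decides whether the xor "swaps" b for a.
    static int min(int a, int b) { return b ^ ((a ^ b) & ltMask(a, b)); }
    static int max(int a, int b) { return a ^ ((a ^ b) & ltMask(a, b)); }

    public static void main(String[] args) {
        System.out.println(min(3, 7));  // prints 3
        System.out.println(max(-5, 2)); // prints 2
    }
}
```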



Answer 2:

While Louis Wasserman's answer is correct, I want to show you a more general (and much clearer) way to write branchless code. You can just use the ?: operator:

    int t = data[c];
    sum += (t >= 128 ? t : 0);

The JIT compiler can see from the execution profile that the condition is poorly predicted here. In such cases it is smart enough to replace the conditional branch with a conditional move instruction:

    mov    0x10(%r14,%rbp,4),%r9d  ; load R9d from array
    cmp    $0x80,%r9d              ; compare with 128
    cmovl  %r8d,%r9d               ; if less, move R8d (which is 0) to R9d

You can verify for yourself that this version runs equally fast on both sorted and unsorted arrays.
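One crude way to check this is sketched below. The class name, array size, and round count are arbitrary choices for illustration, and hand-rolled timing like this is only indicative: whether the JIT actually emits a conditional move depends on the JVM and its profile, and a harness such as JMH would give more trustworthy numbers.

```java
class TernarySumDemo {
    // The ternary form from this answer.
    static long sumTernary(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            int t = data[c];
            sum += (t >= 128 ? t : 0);
        }
        return sum;
    }

    // Runs the sum repeatedly and returns the elapsed time in nanoseconds.
    static long timeNanos(int[] data, int rounds) {
        long sink = 0;
        long start = System.nanoTime();
        for (int i = 0; i < rounds; i++) {
            sink += sumTernary(data);
        }
        long elapsed = System.nanoTime() - start;
        if (sink == 42) System.out.println(); // keeps the result live so the loop is not optimized away
        return elapsed;
    }

    public static void main(String[] args) {
        // 32768 pseudo-random values in 0..255 (fixed seed for repeatability).
        int[] unsorted = new java.util.Random(0).ints(32768, 0, 256).toArray();
        int[] sorted = unsorted.clone();
        java.util.Arrays.sort(sorted);

        // Warm-up rounds so the JIT has a chance to compile the loop.
        timeNanos(unsorted, 2000);
        timeNanos(sorted, 2000);

        System.out.println("unsorted ns: " + timeNanos(unsorted, 2000));
        System.out.println("sorted ns:   " + timeNanos(sorted, 2000));
    }
}
```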



Answer 3:

Branchless code typically means evaluating all possible outcomes of a conditional statement, each with a weight of either 0 or 1, so that Sum{ weight_i } = 1. Most of the calculations are essentially discarded. Some optimization can result from the fact that E_i doesn't have to be correct when the corresponding weight w_i (or mask m_i) is zero.

  result = (w_0 * E_0) + (w_1 * E_1) + ... + (w_n * E_n)    ;; or
  result = (m_0 & E_0) | (m_1 & E_1) | ... | (m_n & E_n)

where m_i stands for a bitmask.
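A small sketch of the mask form with two outcomes, applied to the running example: both E_0 (leave the sum alone) and E_1 (add x) are computed, and complementary masks pick one. The names are hypothetical, and the mask derivation assumes `x - 128` does not overflow.

```java
class MaskSelect {
    // Evaluates both outcomes and selects one with complementary bitmasks.
    static int addIfAtLeast128(int sum, int x) {
        int m1 = ~((x - 128) >> 31); // -1 if x >= 128, otherwise 0
        int m0 = ~m1;                // exactly one of m0, m1 is all ones
        int e0 = sum;                // outcome when x <  128
        int e1 = sum + x;            // outcome when x >= 128 (computed regardless)
        return (m0 & e0) | (m1 & e1);
    }

    public static void main(String[] args) {
        System.out.println(addIfAtLeast128(10, 200)); // prints 210
        System.out.println(addIfAtLeast128(10, 5));   // prints 10
    }
}
```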

Speed can also be gained by processing the E_i in parallel and then collapsing the results horizontally.

This contrasts with the semantics of if (a) b; else c; or its ternary shorthand a ? b : c, where only one expression out of [b, c] is evaluated.

Thus the ternary operator is no magic bullet for branchless code. A decent compiler produces branchless code equally well for

t = data[n];
if (t >= 128) sum+=t;

compiling it to, for example:

    movl    -4(%rdi,%rdx), %ecx
    leal    (%rax,%rcx), %esi
    addl    $-128, %ecx
    cmovge  %esi, %eax

Variations of branchless code include expressing the problem through other branchless non-linear functions, such as ABS, if available on the target machine.

e.g.

 2 * min(a,b) = a + b - ABS(a - b),
 2 * max(a,b) = a + b + ABS(a - b)
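The two identities can be sketched in Java as follows. The class name is made up, the arithmetic is widened to `long` so that `a + b` and `a - b` cannot overflow for `int` inputs, and whether `Math.abs` itself compiles to branch-free machine code is an assumption that depends on the JVM and target CPU.

```java
class AbsMinMax {
    // 2 * min(a, b) = a + b - |a - b|; the right-hand side is always even,
    // so dividing by 2 is exact.
    static int min(int a, int b) {
        long s = (long) a + b;
        long d = Math.abs((long) a - b);
        return (int) ((s - d) / 2);
    }

    // 2 * max(a, b) = a + b + |a - b|
    static int max(int a, int b) {
        long s = (long) a + b;
        long d = Math.abs((long) a - b);
        return (int) ((s + d) / 2);
    }

    public static void main(String[] args) {
        System.out.println(min(3, 7)); // prints 3
        System.out.println(max(3, 7)); // prints 7
    }
}
```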

or even:

 ABS(x) = sqrt(x*x)      ;; caveat -- this is "probably" not efficient

In addition to >> and ~, it may be equally beneficial to use bool and !bool instead of (int >> 31), which is possibly non-portable: in C, right-shifting a negative int is implementation-defined. Likewise, if the condition evaluates to 0 or 1, one can generate a working mask by negation:

-[0, 1] = [0, 0xffffffff]  in 2's complement representation
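A sketch of the negation trick in Java. Java does not convert `boolean` to `int` implicitly, so this example derives the 0/1 flag arithmetically with an unsigned shift; the names are hypothetical and `x - 128` is assumed not to overflow.

```java
class NegateToMask {
    // 1 if x < 128, otherwise 0 (the sign bit moved down to bit 0).
    static int isBelow128(int x) {
        return (x - 128) >>> 31;
    }

    // Keeps x when x >= 128 and returns 0 otherwise.
    static int keepIfAtLeast128(int x) {
        int atLeast = isBelow128(x) ^ 1; // plays the role of !bool: 1 if x >= 128
        int mask = -atLeast;             // -0 = 0, -1 = 0xffffffff in two's complement
        return mask & x;
    }

    public static void main(String[] args) {
        System.out.println(keepIfAtLeast128(200)); // prints 200
        System.out.println(keepIfAtLeast128(5));   // prints 0
    }
}
```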