Is there a way to make this function faster? (C)

Posted 2020-07-06 06:58

Question:

I have C code that adds long numbers the same way a human does: if, for example, I have two arrays A[0..n-1] and B[0..n-1], the function computes C[0]=A[0]+B[0], C[1]=A[1]+B[1], ...

I need help making this function faster, even if the solution uses intrinsics.

My main problem is a really big dependency problem: iteration i+1 depends on the carry from iteration i, because I use base 10. So if A[0]=6 and B[0]=5, C[0] must be 1 and there is a carry of 1 into the next addition.

The fastest code I could write was this one:

void LongNumAddition1(unsigned char *Vin1, unsigned char *Vin2,
                      unsigned char *Vout, unsigned N) {
    for (int i = 0; i < N; i++) {
        Vout[i] = Vin1[i] + Vin2[i];
    } 

    unsigned char carry = 0;

    for (int i = 0; i < N; i++) {
        Vout[i] += carry;
        carry = Vout[i] / 10;
        Vout[i] = Vout[i] % 10;
    }
}

I also tried these approaches, which turned out to be slower:

void LongNumAddition1(unsigned char *Vin1, unsigned char *Vin2,
                      unsigned char *Vout, unsigned N) {
    unsigned char CARRY = 0;
    for (int i = 0; i < N; i++) {
        unsigned char R = Vin1[i] + Vin2[i] + CARRY;
        Vout[i] = R % 10; CARRY = R / 10;
    }
}

void LongNumAddition1(char *Vin1, char *Vin2, char *Vout, unsigned N) {
    char CARRY = 0;
    for (int i = 0; i < N; i++) {
        char R = Vin1[i] + Vin2[i] + CARRY;
        if (R <= 9) {
            Vout[i] = R;
            CARRY = 0;
        } else {
            Vout[i] = R - 10;
            CARRY = 1;
        }
    }
}

I searched Google and found some pseudocode similar to what I have implemented; GeeksforGeeks also has an implementation of this problem, but it is slower too.

Can you please help me?

Answer 1:

If you don't want to change the format of the data, you can try SIMD.

// #include <immintrin.h>  // needed only for the _mm_min_epu8 variant below (x86 SSE2)
#include <stddef.h>
#include <stdint.h>

typedef uint8_t u8x16 __attribute__((vector_size(16)));  // GCC/Clang vector extension

// Digits are least significant first. The vector loads/stores assume
// lhs, rhs and out are 16-byte aligned.
void add_digits(uint8_t *const lhs, uint8_t *const rhs, uint8_t *out, size_t n) {
    uint8_t carry = 0;
    for (size_t i = 0; i + 15 < n; i += 16) {
        u8x16 digits = *(u8x16 *)&lhs[i] + *(u8x16 *)&rhs[i] + (u8x16){carry};

        // Get carries and almost-carries
        u8x16 carries = digits >= 10; // true is -1
        u8x16 full = digits == 9;

        // Shift carries
        carry = carries[15] & 1;
        __uint128_t carries_i = ((__uint128_t)carries) << 8;
        carry |= __builtin_add_overflow((__uint128_t)full, carries_i, &carries_i);

        // Add to carry chains and wrap
        digits += (((u8x16)carries_i) ^ full) & 1;
        // faster: digits = (u8x16)_mm_min_epu8((__m128i)digits, (__m128i)(digits - 10));
        digits -= (digits >= 10) & 10;

        *(u8x16 *)&out[i] = digits;
    }
}

This is ~2 instructions per digit. You'll need to add code to handle the tail-end.
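The answer leaves the tail to the reader. A minimal sketch of a scalar tail loop, assuming add_digits is (hypothetically) changed to hand back its final carry, and that the vector loop stopped at start = n rounded down to a multiple of 16. The helper name add_digits_tail is illustrative, not part of the answer.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: finishes digits [start, n) after the 16-wide loop.
   `carry` is the carry out of the last vector iteration (0 or 1).
   Returns the carry out of the most significant digit. */
static unsigned add_digits_tail(const uint8_t *lhs, const uint8_t *rhs,
                                uint8_t *out, size_t start, size_t n,
                                unsigned carry) {
    for (size_t i = start; i < n; i++) {
        unsigned r = lhs[i] + rhs[i] + carry;  /* 0..19 for valid digits */
        carry = (r >= 10);
        out[i] = (uint8_t)(r - carry * 10);    /* subtract 10 only on carry */
    }
    return carry;
}
```

Because the tail is at most 15 digits, a plain branch-free scalar loop here costs little compared with the vectorized body.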


Here's a run-through of the algorithm.

First, we add our digits with our carry from the last iteration:

lhs           7   3   5   9   9   2
rhs           2   4   4   9   9   7
carry                             1
         + -------------------------
digits        9   7   9  18  18  10

We calculate which digits will produce carries (≥10) and which would propagate them (=9). Vector comparisons set every bit of a true lane, so true is -1 (0xFF in each byte).

carries       0   0   0  -1  -1  -1
full         -1   0  -1   0   0   0

We convert carries to an integer and shift it left by one digit (8 bits in the code; the diagrams below use 4 bits per digit to save space), and also convert full to an integer.

              _   _   _   _   _   _
carries_i  000000001111111111110000
full       111100001111000000000000

Now we can add these together to propagate carries. Note that only the lowest bit of each digit is meaningful.

              _   _   _   _   _   _
carries_i  111100011110111111110000
(relevant) ___1___1___0___1___1___0

There are two indicators to look out for:

  1. carries_i has its lowest bit set, and digit ≠ 9. There has been a carry into this square.

  2. carries_i has its lowest bit unset, and digit = 9. There has been a carry over this square, resetting the bit.

We calculate this with (((u8x16)carries_i) ^ full) & 1 and add the result to digits.

(c^f) & 1     0   1   1   1   1   0
digits        9   7   9  18  18  10
         + -------------------------
digits        9   8  10  19  19  10

Then we remove the 10s, which have all been carried already.

digits        9   8  10  19  19  10
(d≥10)&10     0   0  10  10  10  10
         - -------------------------
digits        9   8   0   9   9   0

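We also keep track of carries out, which can happen in two places: the top digit itself may be ≥10 (carry = carries[15] & 1), or a carry may propagate out of the top of the 128-bit addition (the result of __builtin_add_overflow).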
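For readers without GCC vector extensions, the same carry-propagation idea can be tried on 8 digits packed into one uint64_t. This is a portable sketch under stated assumptions: byte k holds digit k (least significant first), and the carries and full masks are built with an ordinary loop instead of SIMD compares, so it demonstrates the trick rather than the speed. The name add8_digits is hypothetical.

```c
#include <stdint.h>

/* Portable sketch of the carry-propagation trick, 8 digits per uint64_t.
   Byte k of each mask belongs to digit k (least significant digit first). */
static unsigned add8_digits(const uint8_t a[8], const uint8_t b[8],
                            uint8_t out[8], unsigned carry_in) {
    uint8_t d[8];
    uint64_t carries = 0, full = 0;
    for (int k = 0; k < 8; k++) {
        d[k] = (uint8_t)(a[k] + b[k] + (k == 0 ? carry_in : 0u));
        if (d[k] >= 10) carries |= (uint64_t)0xFF << (8 * k); /* produces a carry */
        if (d[k] == 9)  full    |= (uint64_t)0xFF << (8 * k); /* would pass one on */
    }
    /* Carry-out source 1: the top digit itself is >= 10. */
    unsigned carry_out = (d[7] >= 10);
    /* Move each carry to the digit above; byte 7's carry is shifted out. */
    uint64_t shifted = carries << 8;
    /* One integer add ripples carries through runs of 9s (0xFF bytes). */
    uint64_t sum = shifted + full;
    /* Carry-out source 2: the ripple ran off the top of the word. */
    carry_out |= (sum < shifted);
    for (int k = 0; k < 8; k++) {
        /* Lowest bit of (sum ^ full) says whether this digit gets +1. */
        unsigned bit = (unsigned)((sum ^ full) >> (8 * k)) & 1u;
        unsigned v = d[k] + bit;
        out[k] = (uint8_t)(v >= 10 ? v - 10 : v); /* remove the carried 10 */
    }
    return carry_out;
}
```

With a = 735992, b = 244997 and carry_in = 1 (digits stored least significant first), the output digits are 980990 with no carry out, matching the walk-through.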



Answer 2:

Candidates for speed improvement:

Optimizations

Be sure you have enabled your compiler's speed optimization settings (e.g. -O2 or -O3 with GCC or Clang).

restrict

The compiler does not know that writing to Vout[] does not change Vin1[] or Vin2[], which limits certain optimizations.

Use restrict to indicate that Vin1[] and Vin2[] are not affected by writes to Vout[].

// void LongNumAddition1(unsigned char  *Vin1, unsigned char *Vin2, unsigned char *Vout, unsigned N)
void LongNumAddition1(unsigned char * restrict Vin1, unsigned char * restrict Vin2,
   unsigned char * restrict Vout, unsigned N)

Note: this prohibits the caller from passing a Vout that overlaps Vin1 or Vin2.

const

Also use const to aid optimizations. const also allows const arrays to be passed as Vin1 and Vin2.

// void LongNumAddition1(unsigned char * restrict Vin1, unsigned char * restrict Vin2,
//    unsigned char * restrict Vout, unsigned N)
void LongNumAddition1(const unsigned char * restrict Vin1, 
   const unsigned char * restrict Vin2, 
   unsigned char * restrict Vout, 
   unsigned N)

unsigned

unsigned and int are the "go-to" types for integer math. Rather than unsigned char CARRY or char CARRY, use unsigned or uint_fast8_t from <inttypes.h>.

% alternative

Instead of / and %, use sum = a+b+carry; if (sum >= 10) { sum -= 10; carry = 1; } else carry = 0; as suggested by @pmg, or something similar.


Note: I would expect LongNumAddition1() to return the final carry.
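A minimal sketch that puts these suggestions together, assuming the question's one-digit-per-byte format with the least significant digit first and non-overlapping buffers (as restrict requires). It returns the final carry, as the note suggests; the name LongNumAddition2 is hypothetical.

```c
#include <stdint.h>

/* Sketch combining the suggestions: const + restrict pointers, a
   uint_fast8_t carry, compare-and-subtract instead of / and %, and the
   final carry as the return value. Digits are least significant first. */
unsigned LongNumAddition2(const unsigned char * restrict Vin1,
                          const unsigned char * restrict Vin2,
                          unsigned char * restrict Vout, unsigned N) {
    uint_fast8_t carry = 0;
    for (unsigned i = 0; i < N; i++) {
        unsigned sum = Vin1[i] + Vin2[i] + carry;
        if (sum >= 10) { sum -= 10; carry = 1; } else carry = 0;
        Vout[i] = (unsigned char)sum;
    }
    return carry;
}
```

For example, adding the digit arrays for 995 and 6 yields 001 (that is, 1001 once the returned carry of 1 is appended).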



Answer 3:

It is always rather pointless to discuss manual optimizations without a specific system in mind. If we assume you have some sort of mainstream 32-bit CPU with a data cache, an instruction cache and branch prediction, then:

  • Avoid the multiple loops. You should be able to merge them into one and thereby get a major performance boost. That way you don't have to touch the same memory area multiple times, and you reduce the total number of branches. Every i < N must be checked by the program, so reducing the number of checks should give better performance. This could also improve data caching.

  • Do all operations on the largest aligned word size supported. On a 32-bit CPU you should be able to make this algorithm work on 4 bytes at a time rather than byte by byte, for example by replacing the byte-by-byte assignments with 4-byte memcpy operations. That's how library-quality code does it.

  • Qualify the parameters properly. You really ought to be familiar with the concept of const correctness. Vin1 and Vin2 aren't changed, so they should be const, not just for the sake of performance but for program safety and readability/maintainability.

  • Similarly, if you can vouch that the parameters do not point at overlapping memory regions, you can restrict-qualify all the pointers.

  • Division is an expensive operation on many CPUs, so if it is possible to change the algorithm to get rid of / and %, do that. If the algorithm works byte by byte, you can sacrifice 256 bytes of memory to hold a look-up table.

    (This assumes that you can allocate such a look-up table in ROM without introducing wait-state dependencies etc.)

  • Changing carry to a 32-bit type may give better code on some systems and worse on others. When I tried this on x86_64, it gave slightly worse code, by one instruction (a very minor difference).
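A minimal sketch of the look-up-table idea from the list, merged into a single loop with const/restrict parameters. It assumes valid input digits 0..9, so a sum is at most 19; the table still has 256 entries so any byte value indexes it safely. The packed encoding (digit in the low nibble, carry in the high nibble) and the names add_lut, init_add_lut and LongNumAdditionLUT are hypothetical. Compilers often already turn division by the constant 10 into a multiply, so measuring is worthwhile before adopting this.

```c
#include <stdint.h>

/* Hypothetical 256-byte table: low nibble = s % 10, high nibble = s / 10.
   Only entries 0..19 matter when both inputs are valid digits 0..9. */
static uint8_t add_lut[256];

static void init_add_lut(void) {
    for (unsigned s = 0; s < 256; s++)
        add_lut[s] = (uint8_t)(((s / 10) << 4) | (s % 10));
}

/* One merged loop, no run-time / or %: each step is a table load.
   Digits are least significant first; returns the final carry. */
unsigned LongNumAdditionLUT(const uint8_t * restrict Vin1,
                            const uint8_t * restrict Vin2,
                            uint8_t * restrict Vout, unsigned N) {
    unsigned carry = 0;
    for (unsigned i = 0; i < N; i++) {
        /* The uint8_t cast keeps the index in range even for bad input. */
        uint8_t e = add_lut[(uint8_t)(Vin1[i] + Vin2[i] + carry)];
        Vout[i] = e & 0x0F;
        carry = e >> 4;
    }
    return carry;
}
```

init_add_lut() must run once before the first addition; placing the table in ROM, as the list mentions, would mean emitting it as a precomputed const array instead.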



Answer 4:

The first loop

for (int i = 0; i < N; i++) {
    Vout[i] = Vin1[i] + Vin2[i];
} 

is auto-vectorized by the compiler. But the next loop

for (int i = 0; i < N; i++) {
    Vout[i] += carry;
    carry = Vout[i] / 10;
    Vout[i] = Vout[i] % 10;
}

contains a loop-carried dependence, which essentially serializes the entire loop (consider adding 1 to 99999999999999999: it can only be computed sequentially, one digit at a time). Loop-carried dependence is one of the biggest headaches in modern computer science.

That's why the first version is faster: it is partially vectorized. This is not the case with any of the other versions.

How can the loop-carried dependence be avoided?

Computers, being base-2 devices, are notoriously bad at base-10 arithmetic. Not only does it waste space, it also creates artificial carry dependencies between every pair of adjacent digits.

If you can convert your data from a base-10 to a base-2 representation, it becomes easier for the machine to add two arrays, because it can perform binary addition of many bits in a single iteration. A well-performing representation could be, for example, uint64_t on a 64-bit machine. Note that streaming addition with carry is still problematic for SSE, but some options exist there as well.
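As a sketch of the binary representation suggested here, assume the number is stored as uint64_t limbs (base 2^64 digits), least significant limb first. Portable C can detect each carry with unsigned comparisons; whether a given compiler turns this into an ADC chain is not guaranteed. The name add_u64_limbs is hypothetical.

```c
#include <stddef.h>
#include <stdint.h>

/* Adds two n-limb numbers (base 2^64, least significant limb first).
   Returns the carry out of the top limb (0 or 1). */
uint64_t add_u64_limbs(const uint64_t *a, const uint64_t *b,
                       uint64_t *out, size_t n) {
    uint64_t carry = 0;
    for (size_t i = 0; i < n; i++) {
        uint64_t s = a[i] + carry;  /* wraps only if a[i] == UINT64_MAX and carry == 1 */
        uint64_t c1 = s < carry;    /* 1 if that wrap happened */
        s += b[i];
        uint64_t c2 = s < b[i];     /* 1 if adding b[i] wrapped */
        out[i] = s;
        carry = c1 | c2;            /* at most one of c1, c2 can be set */
    }
    return carry;
}
```

Each iteration still depends on the previous carry, but one step now handles about 19 decimal digits' worth of value instead of one.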

Unfortunately, it is still hard for C compilers to generate efficient loops with carry propagation. For this reason, libgmp, for example, implements bignum addition not in C but in assembly language, using the ADC (add with carry) instruction. By the way, libgmp could be a direct drop-in replacement for a lot of the bignum arithmetic functions in your project.



Answer 5:

To improve the speed of your bignum addition, you should pack more decimal digits into each array element. For example, you can use uint32_t instead of unsigned char and store 9 digits per element.

Another trick to improve performance is to avoid branches.

Here is a modified version of your code without branches:

void LongNumAddition1(const char *Vin1, const char *Vin2, char *Vout, unsigned N) {
    char carry = 0;
    for (unsigned i = 0; i < N; i++) {
        char r = Vin1[i] + Vin2[i] + carry;
        carry = (r >= 10);        // 1 if this digit overflows, else 0
        Vout[i] = r - carry * 10; // subtract 10 only when there was a carry
    }
}

Here is a modified version dealing with 9 digits at a time:

#include <stdint.h>

void LongNumAddition1(const uint32_t *Vin1, const uint32_t *Vin2, uint32_t *Vout, unsigned N) {
    uint32_t carry = 0;
    for (unsigned i = 0; i < N; i++) {
        uint32_t r = Vin1[i] + Vin2[i] + carry;  // at most 1999999999, fits in uint32_t
        carry = (r >= 1000000000);
        Vout[i] = r - carry * 1000000000;
    }
}

You can look at the code generated by gcc and clang on Godbolt's Compiler Explorer.

Here is a small test program:

#include <inttypes.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>

int LongNumConvert(const char *s, uint32_t *Vout, unsigned N) {
    unsigned i, len = strlen(s);
    uint32_t num = 0;
    if (len > N * 9)
        return -1;
    while (N * 9 > len + 8)
        Vout[--N] = 0;
    for (i = 0; i < len; i++) {
        num = num * 10 + (s[i] - '0');
        if ((len - i) % 9 == 1) {
            Vout[--N] = num;
            num = 0;
        }
    }
    return 0;
}

int LongNumPrint(FILE *fp, const uint32_t *Vout, unsigned N, const char *suff) {
    int len;
    while (N > 1 && Vout[N - 1] == 0)
        N--;
    len = fprintf(fp, "%"PRIu32"", Vout[--N]);
    while (N > 0)
        len += fprintf(fp, "%09"PRIu32"", Vout[--N]);
    if (suff)
        len += fprintf(fp, "%s", suff);
    return len;
}

void LongNumAddition(const uint32_t *Vin1, const uint32_t *Vin2,
                     uint32_t *Vout, unsigned N) {
    uint32_t carry = 0;
    for (unsigned i = 0; i < N; i++) {
        uint32_t r = Vin1[i] + Vin2[i] + carry;
        carry = (r >= 1000000000);
        Vout[i] = r - carry * 1000000000;
    }
}

int main(int argc, char *argv[]) {
    const char *sa = argc > 1 ? argv[1] : "123456890123456890123456890";
    const char *sb = argc > 2 ? argv[2] : "2035864230956204598237409822324";
#define NUMSIZE  111  // handle up to 999 digits
    uint32_t a[NUMSIZE], b[NUMSIZE], c[NUMSIZE];
    LongNumConvert(sa, a, NUMSIZE);
    LongNumConvert(sb, b, NUMSIZE);
    LongNumAddition(a, b, c, NUMSIZE);
    LongNumPrint(stdout, a, NUMSIZE, " + ");
    LongNumPrint(stdout, b, NUMSIZE, " = ");
    LongNumPrint(stdout, c, NUMSIZE, "\n");
    return 0;
}