Ways to do modulo multiplication with primitive types

Posted 2020-01-30 11:20

Is there a way to compute, e.g., (853467 * 21660421200929) % 100000000000007 without BigInteger libraries? Note that each number fits into a 64-bit integer, but the multiplication result does not.

This solution seems inefficient:

int64_t mulmod(int64_t a, int64_t b, int64_t m) {
    if (b < a)
        std::swap(a, b);
    int64_t res = 0;
    for (int64_t i = 0; i < a; i++) {
        res += b;
        res %= m;
    }
    return res;
}

Tags: c++ algorithm
7 answers
劳资没心,怎么记你
#2 · 2020-01-30 12:02

a * b % m equals a * b - (a * b / m) * m
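As a worked instance of this identity with the numbers from the question (so m = 100000000000007, and the division is integer division):

```latex
\begin{align*}
a \cdot b &= 853467 \cdot 21660421200929 = 18486454701093270843 \\
\left\lfloor \frac{a \cdot b}{m} \right\rfloor &= 184864 \\
a \cdot b \bmod m &= 18486454701093270843 - 184864 \cdot 100000000000007 = 54701091976795
\end{align*}
```

The product exceeds 2^64 - 1 = 18446744073709551615, which is why it cannot be formed directly in a 64-bit integer.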

Use floating point arithmetic to approximate a * b / m. The approximation leaves a value small enough for normal 64 bit integer operations, for m up to 63 bits.

This method is limited by the significand of a double, which is usually 53 bits (52 stored explicitly).

uint64_t mod_mul_52(uint64_t a, uint64_t b, uint64_t m) {
    uint64_t c = (double)a * b / m - 1;
    uint64_t d = a * b - c * m;

    return d % m;
}

This method is limited by the significand of a long double, which is usually 64 bits or larger. The integer arithmetic is limited to 63 bits.

uint64_t mod_mul_63(uint64_t a, uint64_t b, uint64_t m) {
    uint64_t c = (long double)a * b / m - 1;
    uint64_t d = a * b - c * m;

    return d % m;
}

These methods require that a and b be less than m. To handle arbitrary a and b, add these lines before c is computed.

a = a % m;
b = b % m;

In both methods, the final % operation could be made conditional.

return d >= m ? d % m : d;
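A variant of the same idea, offered only as a sketch: instead of subtracting 1 from the estimated quotient (which makes the estimate negative before the cast when a * b < m), the remainder is read as a signed value and corrected by one multiple of m. The name mulmod_estimate is hypothetical. The sketch assumes m < 2^62 and that long double has a 64-bit significand; where long double is the same as double, the estimate is only reliable for much smaller operands.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical name. Sketch of the quotient-estimate method with a signed fix-up.
// Assumes long double has a 64-bit significand, so the estimated quotient
// is off by at most one in either direction.
std::uint64_t mulmod_estimate(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
    assert(m != 0 && m < (std::uint64_t(1) << 62));  // headroom for the +/- m fix-up
    a %= m;
    b %= m;
    // Floating-point estimate of floor(a * b / m).
    std::uint64_t q =
        static_cast<std::uint64_t>(static_cast<long double>(a) * b / m);
    // Both products wrap modulo 2^64, but their true difference lies in
    // [-m, 2m), so reading the wrapped result as signed recovers it.
    // (The unsigned-to-signed conversion wraps on mainstream compilers and is
    // guaranteed to from C++20 on.)
    std::int64_t r = static_cast<std::int64_t>(a * b - q * m);
    const std::int64_t sm = static_cast<std::int64_t>(m);
    if (r < 0) {
        r += sm;   // quotient was estimated one too high
    } else if (r >= sm) {
        r -= sm;   // quotient was estimated one too low
    }
    return static_cast<std::uint64_t>(r);
}
```

With the numbers from the question, mulmod_estimate(853467, 21660421200929, 100000000000007) should give 54701091976795.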
贼婆χ
#3 · 2020-01-30 12:06

I can suggest an improvement for your algorithm.

You currently compute a * b by adding b repeatedly and taking the modulo after each addition. It is better to add b * x each time, where x is chosen so that b * x does not overflow.

int64_t mulmod(int64_t a, int64_t b, int64_t m)
{
    a %= m;
    b %= m;

    int64_t x = 1;
    int64_t bx = b;

    while (x < a)
    {
        int64_t bb = bx * 2;
        if (bb <= bx)
            break; // overflow

        x *= 2;
        bx = bb;
    }

    int64_t ans = 0;

    for (; x < a; a -= x)
        ans = (ans + bx) % m;

    return (ans + a*b) % m;
}
霸刀☆藐视天下
#4 · 2020-01-30 12:12

An improvement to the repeated-doubling algorithm is to check how many bits can be processed at once without overflow. An early-exit check can be done on both arguments, which speeds up the (unlikely?) case where n is not prime.

e.g. 100000000000007 == 0x00005af3107a4007, which allows 16 (or 17) bits to be processed per iteration. With the example, the actual number of iterations is 3.

// just a conceptual routine
int get_leading_zeroes(uint64_t n)
{
   int a=0;
   while ((n & 0x8000000000000000) == 0) { a++; n<<=1; }
   return a;
}

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n)
{
     uint64_t result = 0;
     int N = get_leading_zeroes(n);
     uint64_t mask = (uint64_t(1) << N) - 1;  // 64-bit shift: plain 1<<N is an int shift and overflows for N >= 31
     a %= n;
     b %= n;  // Make sure all values are originally in the proper range?
     // n is not necessarily a prime -- so both a & b can end up being zero
     while (a>0 && b>0)
     {
         result = (result + (b & mask) * a) % n;  // no overflow
         b>>=N;
         a = (a << N) % n;
     }
     return result;
}
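The bit count and iteration count quoted for the example modulus can be checked with a small sketch. Both helper names are hypothetical, and the leading-zero loop is the same conceptual routine as above with a guard for zero added.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: portable leading-zero count for a nonzero 64-bit value.
int leading_zero_bits(std::uint64_t n) {
    assert(n != 0);  // the loop below would never terminate for 0
    int count = 0;
    while ((n & (std::uint64_t(1) << 63)) == 0) {
        ++count;
        n <<= 1;
    }
    return count;
}

// Hypothetical helper: how many chunks of `bits` bits it takes to shift
// every set bit out of b, i.e. an upper bound on the loop iterations.
int chunk_count(std::uint64_t b, int bits) {
    assert(bits > 0 && bits < 64);
    int chunks = 0;
    while (b != 0) {
        b >>= bits;
        ++chunks;
    }
    return chunks;
}
```

For n = 100000000000007 the first helper returns 17, and with b = 21660421200929 (a 45-bit value) the second returns 3, matching the figures quoted above.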
对你真心纯属浪费
5楼-- · 2020-01-30 12:13

You should use Russian Peasant multiplication. It uses repeated doubling to compute all the values (b*2^i)%m, and adds them in if the ith bit of a is set.

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) {
    uint64_t res = 0;
    while (a != 0) {
        if (a & 1) res = (res + b) % m;
        a >>= 1;
        b = (b << 1) % m;
    }
    return res;
}

It improves upon your algorithm because it takes O(log(a)) time, not O(a) time.

Caveats: unsigned, and works only if m is 63 bits or less.
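To make the caveat concrete, here is a sketch of the same loop with the 63-bit precondition stated as an assertion and b reduced up front, so that neither res + b nor b << 1 can exceed 2^64 - 1. The name mulmod_checked is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical name. Add-and-double loop with its precondition made explicit.
std::uint64_t mulmod_checked(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
    assert(m != 0 && (m >> 63) == 0);  // m must fit in 63 bits
    b %= m;                            // keeps res + b and b << 1 below 2^64
    std::uint64_t res = 0;
    while (a != 0) {
        if (a & 1) {
            res = (res + b) % m;       // add b * 2^i when bit i of a is set
        }
        a >>= 1;
        b = (b << 1) % m;              // next power-of-two multiple of b, reduced
    }
    return res;
}
```

For the numbers in the question, mulmod_checked(853467, 21660421200929, 100000000000007) should give 54701091976795.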

beautiful°
#6 · 2020-01-30 12:13

You could try something that breaks the multiplication up into additions:

// compute (a * b) % m:

unsigned int multmod(unsigned int a, unsigned int b, unsigned int m)
{
    unsigned int result = 0;

    a %= m;
    b %= m;

    while (b)
    {
        if (b % 2 != 0)
        {
            result = (result + a) % m;
        }

        a = (a * 2) % m;
        b /= 2;
    }

    return result;
}
beautiful°
#7 · 2020-01-30 12:17

Keith Randall's answer is good, but as he said, a caveat is that it works only if m is 63 bits or less.

Here is a modification which has two advantages:

  1. It works even if m is 64 bits.
  2. It doesn't need to use the modulo operation, which can be expensive on some processors.

(Note that the res -= m and temp_b -= m lines rely on 64-bit unsigned integer overflow in order to give the expected results. This should be fine since unsigned integer overflow is well-defined in C and C++. For this reason it's important to use unsigned integer types.)

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) {
    uint64_t res = 0;
    uint64_t temp_b;

    /* Only needed if b may be >= m */
    if (b >= m) {
        if (m > UINT64_MAX / 2u)
            b -= m;
        else
            b %= m;
    }

    while (a != 0) {
        if (a & 1) {
            /* Add b to res, modulo m, without overflow */
            if (b >= m - res) /* Equiv to if (res + b >= m), without overflow */
                res -= m;
            res += b;
        }
        a >>= 1;

        /* Double b, modulo m */
        temp_b = b;
        if (b >= m - b)       /* Equiv to if (2 * b >= m), without overflow */
            temp_b -= m;
        b += temp_b;
    }
    return res;
}
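The wrap-around step that the res -= m and temp_b -= m lines rely on can be looked at in isolation. The sketch below uses a hypothetical helper name, addmod, and assumes both inputs are already reduced modulo m.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: (x + y) % m for x, y < m, valid for any nonzero 64-bit m.
std::uint64_t addmod(std::uint64_t x, std::uint64_t y, std::uint64_t m) {
    assert(m != 0 && x < m && y < m);
    if (y >= m - x) {  // same test as x + y >= m, but cannot overflow
        x -= m;        // wraps around to x - m + 2^64
    }
    return x + y;      // the addition wraps back, leaving x + y - m
}
```

For example, with m = 18446744073709551557 (that is, 2^64 - 59), addmod(m - 1, m - 2, m) returns m - 3 even though the plain sum m - 1 + m - 2 does not fit in 64 bits.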