Calculating pow(a,b) mod n

Posted 2019-01-01 00:58

I want to calculate a^b mod n for use in RSA decryption. My code (below) returns incorrect answers. What is wrong with it?

unsigned long int decrypt2(int a,int b,int n)
{
    unsigned long int res = 1;

    for (int i = 0; i < (b / 2); i++)
    {
        res *= ((a * a) % n);
        res %= n;
    }

    if (b % n == 1)
        res *=a;

    res %=n;
    return res;
}

Tags: c++ c algorithm
12 answers
大哥的爱人
Reply #2 · 2019-01-01 01:29

To calculate pow(a,b) % n for RSA decryption, the best algorithm I have come across is the one from Primality Testing 1), shown below:

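int modulo(int a, int b, int n){
    // Right-to-left binary exponentiation: returns (a^b) % n.
    // Assumes a >= 0, b >= 0, n > 0 (a negative a gives a negative result).
    long long x=1, y=a;
    while (b > 0) {
        if (b%2 == 1) {
            x = (x*y) % n; // multiplying with base
        }
        y = (y*y) % n; // squaring the base
        b /= 2;
    }
    return x % n;
}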

See the reference below for more details.


1) "Primality Testing: Non-deterministic Algorithms", topcoder

宁负流年不负卿
Reply #3 · 2019-01-01 01:29

I'm using this function:

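#include <math.h>

// Only works for tiny inputs: pow() returns a double, which cannot hold
// base^exp exactly once it exceeds 2^53, and converting a double larger
// than INT_MAX to int is undefined behavior.
int CalculateMod(int base, int exp, int mod){
    int result;
    result = (int) pow(base,exp);
    result = result % mod;
    return result;
}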

I cast the result because pow returns a double, and the % operator needs integer operands. In any case, RSA decryption should use only integers.

深知你不懂我心
Reply #4 · 2019-01-01 01:30

  1. A key problem with OP's code is a * a. Both operands are int, so the product is computed in int and overflows (undefined behavior) once a exceeds about sqrt(INT_MAX). The type of res is irrelevant: a * a is evaluated as an int before it is ever combined with res.

    The solution is to ensure either:

    • the multiplication is done with math twice as wide as the operands, or
    • the modulus n satisfies n*n <= type_MAX + 1.
  2. There is no reason to return a wider type than the type of the modulus, as the result is always representable in that type.

    // unsigned long int decrypt2(int a,int b,int n)
    int decrypt2(int a,int b,int n)
    
  3. Using unsigned math is certainly more suitable for OP's RSA goals.


#include <limits.h>

// (a^b)%n
// n != 0

// Test if unsigned long long has at least twice as many value bits as unsigned
#if ULLONG_MAX/UINT_MAX  - 1 > UINT_MAX
unsigned decrypt2(unsigned a, unsigned b, unsigned n) {
  unsigned long long result = 1u % n;  // Ensure result < n, even when n==1
  while (b > 0) {
    if (b & 1) result = (result * a) % n;
    a = (1ULL * a * a) %n;
    b >>= 1;
  }
  return (unsigned) result;
}

#else
unsigned decrypt2(unsigned a, unsigned b, unsigned n) {
  // Detect if  UINT_MAX + 1 < n*n
  if (UINT_MAX/n < n-1) {
    // Placeholder: not defined here; a version using wider math must be supplied.
    return TBD_code_with_wider_math(a,b,n);
  }
  a %= n;
  unsigned result = 1u % n;
  while (b > 0) {
    if (b & 1) result = (result * a) % n;
    a = (a * a) % n;
    b >>= 1;
  }
  return result;
}

#endif
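The second branch above leaves TBD_code_with_wider_math as an unimplemented placeholder. One possible way to fill that gap, sketched here as an assumption and not as part of the answer, is to avoid needing a wider type at all: replace each multiplication with a shift-and-add modular multiply (a different technique from the answer's), so that no intermediate value ever reaches n. The names addmod, mulmod and powmod_narrow are hypothetical, and the tests assume a 32-bit unsigned.

```c
/* Hypothetical helper: (x + y) % n for x, y < n, without overflowing unsigned. */
unsigned addmod(unsigned x, unsigned y, unsigned n) {
    /* x + y >= n exactly when x >= n - y; n - y > 0 because y < n */
    return (x >= n - y) ? x - (n - y) : x + y;
}

/* Hypothetical helper: (x * y) % n by double-and-add.
   Every intermediate stays below n, so no wider type is needed. */
unsigned mulmod(unsigned x, unsigned y, unsigned n) {
    unsigned r = 0;
    x %= n;
    while (y > 0) {
        if (y & 1u)
            r = addmod(r, x, n);   /* add x * (current bit of y) */
        x = addmod(x, x, n);       /* x doubles for the next bit */
        y >>= 1;
    }
    return r;
}

/* Hypothetical: (a^b) % n for any n != 0, using only unsigned arithmetic. */
unsigned powmod_narrow(unsigned a, unsigned b, unsigned n) {
    unsigned result = 1u % n;      /* result < n, even when n == 1 */
    a %= n;
    while (b > 0) {
        if (b & 1u)
            result = mulmod(result, a, n);
        a = mulmod(a, a, n);
        b >>= 1;
    }
    return result;
}
```

The trade-off is speed: each mulmod costs up to one addmod pair per bit of y (about 32 for a 32-bit unsigned), so this is only worth using when no type twice as wide as unsigned is available.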
素衣白纱
Reply #5 · 2019-01-01 01:31

Computing the full power first is very costly (and the value quickly grows too large for any built-in integer type), so you can apply the following logic to simplify the decryption.

From here,

Now say we want to encrypt the message m = 7, with public key e = 3, n = 33 and private key d = 7:
c = m^e mod n = 7^3 mod 33 = 343 mod 33 = 13.
Hence the ciphertext is c = 13.

To check decryption, we compute
m' = c^d mod n = 13^7 mod 33 = 7.
Note that we don't have to calculate the full value of 13^7 here. We can make use of the fact that if a = b x c, then
a mod n = ((b mod n) x (c mod n)) mod n,
so we can break a potentially large number into its factors and combine the results of easier, smaller calculations to get the final value.

One way of calculating m' is as follows.
Note that any exponent can be expressed as a sum of powers of 2. So first compute 13^2, 13^4, 13^8, ... by repeatedly squaring successive values modulo 33: 13^2 = 169 ≡ 4, 13^4 ≡ 4 x 4 = 16, 13^8 ≡ 16 x 16 = 256 ≡ 25.
Then, since 7 = 4 + 2 + 1, we have m' = 13^7 = 13^(4+2+1) = 13^4 x 13^2 x 13^1
≡ 16 x 4 x 13 = 832 ≡ 7 (mod 33).
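The hand calculation above can be sketched in C as follows. This is a minimal illustration assuming 32-bit operands with 64-bit intermediates (so a product of two values below 2^32 cannot overflow); the name powmod is hypothetical. It walks the bits of the exponent from least to most significant, squaring as in the list of 13^2, 13^4, 13^8 above, and multiplies a square into the result whenever the corresponding power of 2 appears in the exponent.

```c
#include <stdint.h>

/* Hypothetical helper: computes (base^e) mod n for n != 0 using
   right-to-left square-and-multiply, the method worked by hand above. */
uint32_t powmod(uint32_t base, uint32_t e, uint32_t n) {
    uint64_t result = 1u % n;   /* 1 % n keeps result < n, even when n == 1 */
    uint64_t sq = base % n;     /* base^1, then base^2, base^4, ... mod n */
    while (e > 0) {
        if (e & 1u)             /* this power of 2 is part of e: multiply it in */
            result = result * sq % n;
        sq = sq * sq % n;       /* next repeated square; both factors < 2^32 */
        e >>= 1;
    }
    return (uint32_t)result;
}
```

For the example, powmod(13, 7, 33) multiplies in 13^1, 13^2 ≡ 4 and 13^4 ≡ 16, matching 7 = 4 + 2 + 1, and returns 7; powmod(7, 3, 33) gives the ciphertext 13.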

伤终究还是伤i
Reply #6 · 2019-01-01 01:32
#include <cmath>
...
static_cast<int>(std::pow(a,b))%n

But my best guess is that you are overflowing int (i.e., the result of the power is too large for an int). I had the same problem when writing the exact same function.

皆成旧梦
Reply #7 · 2019-01-01 01:40

The only actual logic error that I see is this line:

if (b % n == 1)

which should be this:

if (b % 2 == 1)

But your overall design is problematic: your function performs O(b) multiplications and modulus operations, but your use of b / 2 and a * a implies that you were aiming to perform O(log b) operations (which is usually how modular exponentiation is done).
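As a sketch of what the O(log b) version could look like: the identity the original loop was reaching for, a^b = (a*a)^(b/2) * a^(b%2), can be applied recursively instead of only once, which needs just O(log b) multiplications. The name decrypt2_log and the switch to uint32_t parameters with a 64-bit intermediate (so a * a cannot overflow) are assumptions for illustration, not the original signature.

```c
#include <stdint.h>

/* Hypothetical: (a^b) % n for n != 0, using a^b = (a*a)^(b/2) * a^(b%2)
   recursively, so the recursion depth is about log2(b). */
uint32_t decrypt2_log(uint32_t a, uint32_t b, uint32_t n) {
    if (b == 0)
        return 1u % n;                          /* a^0 = 1, reduced (0 when n == 1) */
    uint64_t a_mod = a % n;
    /* 64-bit product: (n-1)^2 < 2^64, so squaring cannot overflow */
    uint32_t sq = (uint32_t)(a_mod * a_mod % n);
    uint64_t res = decrypt2_log(sq, b / 2, n);  /* (a^2)^(b/2) mod n */
    if (b % 2 == 1)                             /* note: % 2, not % n */
        res = res * a_mod % n;
    return (uint32_t)res;
}
```

For example, decrypt2_log(13, 7, 33) returns 7, and decrypt2_log(2, 10, 1000) returns 24 (1024 mod 1000).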
