Special Pairs with sum as Prime Number

Posted 2020-07-27 02:18

Question:

A number N is given in the range 1 <= N <= 10^50. A function F(x) is defined as the sum of all digits of a number x. We have to find the number of special pairs (x, y) such that:
1. 0 <= x, y <= N
2. F(x) + F(y) is prime
The pairs (x, y) and (y, x) are counted only once. Print the answer modulo 1000000007 (10^9 + 7).

My approach:
The maximum digit sum in the given range is 450 (a 50-digit number made of all 9s gives 9 * 50 = 450). So we can create a 451*451 2-D array and store, for every pair of digit sums, whether their sum is prime.
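Because only the sum i + j matters, a 1-D table over 0..900 carries the same information as the 451*451 array. A minimal sketch using a sieve of Eratosthenes; `prime_table` and `IS_PRIME` are illustrative names, not taken from the thread:

```python
def prime_table(limit):
    """Sieve of Eratosthenes: is_prime[t] is True exactly when t is prime, 0 <= t <= limit."""
    is_prime = [True] * (limit + 1)
    is_prime[0] = False
    if limit >= 1:
        is_prime[1] = False
    p = 2
    while p * p <= limit:
        if is_prime[p]:
            # Cross out the multiples of p, starting at p * p
            for multiple in range(p * p, limit + 1, p):
                is_prime[multiple] = False
        p += 1
    return is_prime


# F(x) + F(y) is at most 450 + 450 = 900
IS_PRIME = prime_table(900)
print(IS_PRIME[449], IS_PRIME[450])  # True False
```

A pair of digit sums (i, j) is then "good" exactly when `IS_PRIME[i + j]` is True.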
The issue I am facing is counting the pairs (x, y) for a given N in linear time; obviously we cannot loop up to 10^50 to enumerate them. Can someone suggest an approach, or a formula (if one exists), to count all the pairs in linear time?

Answer 1:

You can create a 451*451 2-D array and store, for every pair (i, j), whether i + j is prime. If you also know how many numbers in [0, N] have F(x) = i and how many have F(x) = j, then whenever i + j is prime, the state (i, j) contributes the product of those two counts to the result.
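Writing c_i for the number of integers x with 0 <= x <= N and F(x) = i (notation introduced here for illustration), the number of ordered pairs (x, y) whose digit sums add up to a prime is

$$
\text{ordered} = \sum_{\substack{0 \le i,\, j \le 450 \\ i + j \ \text{prime}}} c_i \, c_j .
$$

For example, with N = 2 the numbers 0, 1, 2 give c_0 = c_1 = c_2 = 1, and the digit-sum pairs with a prime total are (0, 2), (2, 0), (1, 1), (1, 2) and (2, 1), so the ordered count is 5.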

So what you need is, for every i, the number of integers x in [0, N] with F(x) = i.

This is a standard digit DP problem.

Digit DP that counts how many numbers have F(x) = i:

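#include <cstring>   // memset
#include <string>
using namespace std;

const int MOD = 1000000007;
string given;        // N as a decimal string: N can be as large as 10^50
int i;               // target digit sum, set by the caller before each run
int DP[51][2][452];  // memo table; every entry must be -1 before each run (memset)

// Counts numbers 0 <= x <= N whose digit sum equals i.
// pos:   index of the digit being chosen
// small: 1 if the digits chosen so far already make x smaller than N's prefix
// sum:   digit sum of the digits chosen so far
int digitDp(int pos, int small, int sum)
{
    if (pos == (int)given.size())
        return sum == i ? 1 : 0;
    if (DP[pos][small][sum] != -1) return DP[pos][small][sum];
    int res = 0;
    if (small)
    {
        // x is already below N, so any digit 0..9 may follow
        for (int j = 0; j <= 9; j++) res = (res + digitDp(pos + 1, small, sum + j)) % MOD;
    }
    else
    {
        // x still matches N's prefix: the digit is bounded by N's digit
        int hi = given[pos] - '0';
        for (int j = 0; j <= hi; j++)
        {
            if (j == hi) res = (res + digitDp(pos + 1, small, sum + j)) % MOD;
            else         res = (res + digitDp(pos + 1, 1, sum + j)) % MOD;
        }
    }
    return DP[pos][small][sum] = res;
}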

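digitDp(0, 0, 0) returns how many numbers x with 0 <= x <= N have F(x) = i, modulo 1000000007. Shorter numbers are covered because leading zeros are allowed, and N itself is included through the path that stays tight to the last digit.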
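For small N, the DP's counts can be cross-checked against direct enumeration. A minimal Python sketch; the helper names are hypothetical, and enumeration is only practical for small N:

```python
from collections import Counter


def digit_sum(x):
    """F(x): sum of the decimal digits of x."""
    return sum(int(d) for d in str(x))


def digit_sum_counts_brute(n):
    """Map each digit sum s to how many x in [0, n] have F(x) == s."""
    return Counter(digit_sum(x) for x in range(n + 1))


counts = digit_sum_counts_brute(20)
print(counts[2])  # 3: the numbers 2, 11 and 20
```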

So we call this function for every i from 0 to 451, resetting the memo table before each call, and store the results in an array:

int res[452];
for(i=0;i<=451;i++){
  memset(DP,-1,sizeof DP);
  res[i]=digitDp(0,0,0);
}
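Calling the DP once per target sum repeats almost all of the work. As an alternative (a swapped-in variant, not the code in this answer), all the counts can be produced in one left-to-right pass over the digits of N, keeping an array indexed by digit sum for the prefixes that are already below N. A minimal Python sketch, assuming N is given as a decimal string with N <= 10^50; `digit_sum_counts` is a hypothetical name:

```python
MOD = 1_000_000_007
MAX_SUM = 450  # largest F(x) for 0 <= x <= 10**50


def digit_sum_counts(n_str):
    """Return cnt with cnt[s] = #{x : 0 <= x <= N, F(x) == s} mod MOD.

    Processes the digits of N left to right. free[s] counts prefixes that
    are already smaller than N's prefix and have digit sum s; the single
    prefix still equal to N's prefix is tracked by its digit sum tight_sum.
    """
    free = [0] * (MAX_SUM + 1)
    tight_sum = 0
    for ch in n_str:
        d = int(ch)
        new_free = [0] * (MAX_SUM + 1)
        # A prefix that is already smaller stays smaller with any digit 0..9
        for s, ways in enumerate(free):
            if ways:
                for digit in range(10):
                    if s + digit <= MAX_SUM:
                        new_free[s + digit] = (new_free[s + digit] + ways) % MOD
        # The tight prefix drops below N when it takes a digit smaller than d
        for digit in range(d):
            new_free[tight_sum + digit] = (new_free[tight_sum + digit] + 1) % MOD
        free = new_free
        tight_sum += d
    free[tight_sum] = (free[tight_sum] + 1) % MOD  # x == N itself
    return free


cnt = digit_sum_counts("20")
print(cnt[:4])  # [1, 2, 3, 2]
```

Each digit of N touches at most 451 * 10 entries, so the whole table costs on the order of 51 * 451 * 10 updates instead of one memoized search per target sum.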

Now test each pair (i, j):

int answer=0;
// isprime(s): the prime lookup from the question; it must cover sums up to 902
for(int k=0;k<=451;k++){
    for(int j=0;j<=451;j++){
        if(isprime(k+j)){
            answer=((long long)answer+(long long)res[k]*(long long)res[j])%1000000007;
        }
    }
}

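Finally, the result is answer / 2, because the loop counts every unordered pair {x, y} twice, once as (x, y) and once as (y, x). Since answer has already been reduced modulo 1000000007, the division by 2 must be done by multiplying by the modular inverse of 2, which is 500000004; plain integer division gives a wrong result whenever the reduced value is odd.

There is one exception, i = j = 1: a pair (x, x) with F(x) = 1 has the prime sum 2, but the loop counts it only once. So, assuming x = y is allowed, add res[1] to answer before halving.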
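The combination step with both corrections can be sketched in Python as follows. This is a minimal sketch, assuming pairs with x = y are allowed and counted once; the input list `cnt` (per-digit-sum counts from any of the methods above) and the function name `count_special_pairs` are hypothetical:

```python
MOD = 1_000_000_007
INV2 = pow(2, MOD - 2, MOD)  # inverse of 2 modulo the prime MOD (Fermat), 500000004


def count_special_pairs(cnt, is_prime):
    """Number of unordered pairs {x, y} in [0, N] with F(x) + F(y) prime, mod MOD.

    cnt[s]:      how many x in [0, N] have F(x) == s (already reduced mod MOD)
    is_prime[t]: primality of t for every t <= 2 * (len(cnt) - 1)
    Assumes a pair with x == y is allowed and counted once.
    """
    ordered = 0
    for i, ci in enumerate(cnt):
        for j, cj in enumerate(cnt):
            if is_prime[i + j]:
                ordered = (ordered + ci * cj) % MOD
    # (x, x) has the prime sum 2 * F(x) only when F(x) == 1. Those pairs were
    # counted once, every other unordered pair twice, so add them back first.
    diagonal = cnt[1] if len(cnt) > 1 else 0
    return (ordered + diagonal) * INV2 % MOD


# N = 2: the numbers 0, 1, 2 have digit sums 0, 1, 2
print(count_special_pairs([1, 1, 1], [False, False, True, True, False]))  # 3
```

For N = 2 the ordered count is 5 and one of those pairs is (1, 1), so the unordered count is (5 + 1) / 2 = 3: {0, 2}, {1, 1} and {1, 2}.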


Answer 2:

Here's the same solution in Python, which makes the code easier to read and understand.

primes = set([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997])
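DP = []      # memo table DP[pos][small][final], reset for every target sum k
given = ''   # N as a decimal string (up to 51 digits)
k = 0        # digit sum currently being counted

def memset():
    # Reset every memo entry to -1 ("not computed yet")
    global DP
    DP = [[[-1 for _ in range(452)] for _ in range(2)] for _ in range(51)]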

def digitDp(pos, small, final):
    # Count ways to finish the current prefix into a number <= N whose digit
    # sum equals k. small == 1 means the prefix is already below N's prefix.
    if pos == len(given):
        return 1 if final == k else 0

    if DP[pos][small][final] != -1:
        return DP[pos][small][final]

    res = 0
    if small:
        # Already below N: any digit 0..9 may follow
        for i in range(10):
            res = (res + digitDp(pos + 1, small, final + i)) % 1000000007
    else:
        # Still equal to N's prefix: the digit is bounded by N's digit at pos
        hi = int(given[pos])
        for i in range(hi + 1):
            if i == hi:
                res = (res + digitDp(pos + 1, small, final + i)) % 1000000007
            else:
                res = (res + digitDp(pos + 1, 1, final + i)) % 1000000007

    DP[pos][small][final] = res
    return res


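def main():
    global k, given
    result = [0] * 452  # result[s] = count of x in [0, N] with F(x) == s, mod 1e9+7

    given = input().strip()
    for k in range(452):
        memset()
        result[k] = digitDp(0, 0, 0)

    # Ordered pairs (x, y) whose digit sums add up to a prime
    answer = 0
    for i in range(452):
        for j in range(452):
            if (i + j) in primes:
                answer = (answer + result[i] * result[j]) % 1000000007

    # Pairs (x, x) with F(x) == 1 (sum 2) were counted once, all others twice:
    # add them back, then halve with the modular inverse of 2 (500000004).
    print((answer + result[1]) * 500000004 % 1000000007)

main()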
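For small N, the program's output can be checked against direct enumeration of unordered pairs x <= y. A minimal sketch, assuming pairs with x = y count once; all names here are hypothetical, and the enumeration is O(N^2), so only tiny N are practical. For N = 9 it gives 19: the ordered count is 37, and (37 + 1) / 2 = 19, whereas halving the ordered count alone would give 18.

```python
def digit_sum(x):
    """F(x): sum of the decimal digits of x."""
    return sum(int(d) for d in str(x))


def is_prime(t):
    """Trial-division primality test; fine for the small sums used here."""
    if t < 2:
        return False
    d = 2
    while d * d <= t:
        if t % d == 0:
            return False
        d += 1
    return True


def brute_force_pairs(n):
    """Count unordered pairs x <= y in [0, n] with F(x) + F(y) prime (x == y allowed)."""
    sums = [digit_sum(x) for x in range(n + 1)]
    return sum(
        1
        for x in range(n + 1)
        for y in range(x, n + 1)
        if is_prime(sums[x] + sums[y])
    )


print(brute_force_pairs(2))  # 3
print(brute_force_pairs(9))  # 19
```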

Thanks to @mahbubcseju for providing the solution to this problem.