Increase set of numbers so that XOR sum is 0

Posted 2019-04-12 12:11

I need some help with a problem that I have reduced to the following. I have N 30-bit numbers such that the combined XOR of all of them is non-zero. I need to add a non-negative (0 or more) value to each of the N numbers so that the combined XOR of the new numbers becomes 0, under the constraint that the total added value (not the number of additions) is minimized.

For example, take the three numbers (01010)₂, (01011)₂ and (01100)₂ (N = 3). Their combined XOR is (01101)₂. We could add values as follows:

  • (01010)₂ + (00001)₂ = (01011)₂ : (+1)
  • (01011)₂ + (10000)₂ = (11011)₂ : (+16)
  • (01100)₂ + (00100)₂ = (10000)₂ : (+4)

Now, the total XOR of the new numbers is 0, and the total addition is 21 (=+1+16+4). This total addition value has to be minimized (there could be a better distribution which reduces this total, but this is just an example).
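As a quick sanity check of the arithmetic in this example, the following C++14 sketch verifies at compile time that the original XOR is (01101)₂ = 13, that the three additions produce 11, 27 and 16, that these XOR to 0, and that the total addition is 21. The names `xor_all`, `before`, `added` and `after` are illustrative only.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// XOR of all entries of a fixed-size array, usable in constant expressions (C++14).
template <std::size_t N>
constexpr std::uint32_t xor_all(const std::array<std::uint32_t, N>& v) {
    std::uint32_t x = 0;
    for (std::size_t i = 0; i < N; ++i) x ^= v[i];
    return x;
}

constexpr std::array<std::uint32_t, 3> before = {{0b01010, 0b01011, 0b01100}};  // 10, 11, 12
constexpr std::array<std::uint32_t, 3> added  = {{0b00001, 0b10000, 0b00100}};  // +1, +16, +4
constexpr std::array<std::uint32_t, 3> after  = {{before[0] + added[0],
                                                  before[1] + added[1],
                                                  before[2] + added[2]}};

static_assert(xor_all(before) == 0b01101, "original XOR is (01101)2 = 13");
static_assert(after[0] == 0b01011 && after[1] == 0b11011 && after[2] == 0b10000,
              "new numbers are 11, 27 and 16");
static_assert(xor_all(after) == 0, "the new XOR is 0");
static_assert(added[0] + added[1] + added[2] == 21, "total addition is 21");
```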

These numbers are 30 bits each, so the numbers could be large, and N <= 15. I would really appreciate it if someone could show some efficient way to solve this. I suspect a DP solution is possible, but I could not formulate anything.
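For tiny inputs, a brute-force search gives a reference answer against which any faster idea can be checked. The sketch below is not from the thread: `brute_force_min_addition`, `search_finals` and `limit` are hypothetical names, and it assumes that some optimal solution keeps the final values of all numbers except the last below `limit`. It fixes the final values of the first N-1 numbers and forces the last one to equal their XOR, so its cost grows like limit^(N-1) and it is only usable for very small N and values.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical exhaustive reference solver for tiny inputs (not from the thread).
// ASSUMPTION: some optimal solution keeps the final values of a[0..n-2] below `limit`.

// Tries every final value b in [a[i], limit) for positions 0..n-2; the last number is
// then forced to equal the XOR of those final values so that the total XOR is 0.
static void search_finals(const std::vector<long long>& a, long long limit, std::size_t i,
                          long long xr, long long cost, long long& best) {
    if (best >= 0 && cost >= best) return;  // remaining additions are never negative
    const std::size_t last = a.size() - 1;
    if (i == last) {
        if (xr >= a[last]) {  // the last number may only be increased, never decreased
            const long long total = cost + (xr - a[last]);
            if (best < 0 || total < best) best = total;
        }
        return;
    }
    for (long long b = a[i]; b < limit; ++b)
        search_finals(a, limit, i + 1, xr ^ b, cost + (b - a[i]), best);
}

// Minimum total addition that makes the XOR of all numbers 0, or -1 if none is found.
long long brute_force_min_addition(const std::vector<long long>& a, long long limit) {
    assert(!a.empty());
    long long best = -1;
    search_finals(a, limit, 0, 0, 0, best);
    return best;
}
```

For the example above, `brute_force_min_addition({10, 11, 12}, 64)` returns 19 (reached for instance by 10, 11 → 16, 12 → 26), which is below the 21 of the hand-made distribution.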

Thanks!

2 Answers
骚年 ilove
Post #2 · 2019-04-12 12:47

Algorithm:

Find k, the position of the most significant bit of the xor-sum of the given numbers, counting from bit 0 (k = 3 in your example, since the xor-sum is (01101)₂). Determine whether all the given numbers have bit k set (as in your example) or not.
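A minimal C++ sketch of this first step, counting bit positions from 0, which gives k = 3 for the example numbers 10, 11 and 12 (their xor-sum is 13). `msb_of_xor` and `all_have_bit` are illustrative names, not from the answer.

```cpp
#include <cassert>
#include <vector>

// 0-based index of the most significant set bit of the xor-sum of `a`,
// or -1 if the xor-sum is already 0. Assumes non-negative inputs.
int msb_of_xor(const std::vector<long long>& a) {
    long long s = 0;
    for (long long v : a) s ^= v;
    int k = -1;
    while (s > 0) {
        ++k;
        s >>= 1;
    }
    return k;
}

// True when every number has bit k set (the first case in the answer).
bool all_have_bit(const std::vector<long long>& a, int k) {
    assert(k >= 0);
    for (long long v : a)
        if (((v >> k) & 1) == 0) return false;
    return true;
}
```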

If they do, then you must increase two of the given numbers so that their most significant bit ends up at position k+1. To determine which, brute-force all pairs of numbers: increase one of them until it becomes 2^(k+1) and the other until the xor-sum equals 0. Then choose the best pair.
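The pair step could be sketched as follows. It follows the answer's description rather than an algorithm verified to be optimal; `raise_pair_cost` is a hypothetical name, k is a 0-based bit index (so 2^(k+1) is the value of the next higher bit), and skipping any move that would require decreasing a number, for example when some number already exceeds 2^(k+1), is an added assumption that the answer does not discuss.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// "All numbers have bit k set" branch, as described in the answer: for every ordered
// pair (i, j), raise a[i] to 2^(k+1), then raise a[j] to the XOR of all the other
// (already updated) numbers, which makes the total XOR 0.
// Returns the cheapest total addition found, or -1 if no pair works.
// ASSUMPTION: moves that would decrease a number are skipped.
long long raise_pair_cost(const std::vector<long long>& a, int k) {
    assert(k >= 0);
    const long long top = 1LL << (k + 1);
    long long total_xor = 0;
    for (long long v : a) total_xor ^= v;
    long long best = -1;
    for (std::size_t i = 0; i < a.size(); ++i) {
        if (a[i] > top) continue;  // would need a decrease
        for (std::size_t j = 0; j < a.size(); ++j) {
            if (j == i) continue;
            // XOR of everything except a[j], with a[i] replaced by top.
            const long long others = total_xor ^ a[i] ^ top ^ a[j];
            if (others < a[j]) continue;  // would need a decrease
            const long long cost = (top - a[i]) + (others - a[j]);
            if (best < 0 || cost < best) best = cost;
        }
    }
    return best;
}
```

On the example, `raise_pair_cost({10, 11, 12}, 3)` returns 19, e.g. 11 → 16 (+5) and 12 → 26 (+14).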

If they don't, then you only have to increase one of the given numbers that has a 0 at bit k. To determine which, brute-force all such numbers, increasing each until the xor-sum equals 0, and choose the best one.

To determine how much one of the numbers should be increased such that the xor-sum of all becomes 0, compute the xor-sum of all the other numbers and subtract from it the number that must be increased.
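Combining the last two paragraphs, the single-number branch reduces to a cost of (S xor a_i) - a_i, where S is the xor-sum of all numbers, because S xor a_i is exactly the xor-sum of the others. A minimal sketch, with `raise_one_cost` as a hypothetical name and k again a 0-based bit index; like the pair sketch, it follows the answer's description and does not by itself establish optimality.

```cpp
#include <cassert>
#include <vector>

// Single-number branch as described in the answer: raise one a[i] whose bit k is 0 to
// the xor-sum of all the other numbers, which is S ^ a[i] when S is the xor-sum of all.
// Returns the cheapest such increase, or -1 if no number has a 0 at bit k.
long long raise_one_cost(const std::vector<long long>& a, int k) {
    assert(k >= 0);
    long long s = 0;
    for (long long v : a) s ^= v;
    long long best = -1;
    for (long long v : a) {
        if ((v >> k) & 1) continue;      // only numbers with a 0 at bit k
        const long long target = s ^ v;  // xor-sum of all the other numbers
        // When k is the top bit of s, target equals v above bit k and has bit k set,
        // so target > v; the check below is only a guard for other inputs.
        if (target < v) continue;
        const long long cost = target - v;
        if (best < 0 || cost < best) best = cost;
    }
    return best;
}
```

For instance, with 8, 6 and 3 (xor-sum 13, k = 3), raising 6 to 8 xor 3 = 11 costs 5 and raising 3 to 8 xor 6 = 14 costs 11, so the function returns 5.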

ら.Afraid
Post #3 · 2019-04-12 12:57

Nice problem:)

I have come up with an approach which runs in O(n * 2^n * 31 * n). For n = 15 that is about 228,556,800 operations, which is a bit slow for a single test case. Here are the details:
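The operation count quoted above is n * 2^n * 31 * n evaluated at n = 15; as a small check, 15 * 32768 * 31 * 15 = 228,556,800 (`dp_operations` is just an illustrative name):

```cpp
#include <cstdint>

// n * 2^n * 31 * n, the operation count quoted in the answer.
constexpr std::int64_t dp_operations(std::int64_t n) {
    return n * (std::int64_t{1} << n) * 31 * n;
}

static_assert(dp_operations(15) == 228556800, "15 * 32768 * 31 * 15");
```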

I use a DP approach (memoization). A state is defined as (int mask, int pos):

  • mask

    0 <= mask < 2^n. If 2^i & mask > 0, number i has already been increased at a higher bit, so all of its bits at positions <= pos can be considered zero.

  • pos

    the bit position currently being checked, going from high to low
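Since mask < 2^n and pos only takes the values 0 to 31, each state fits in a single array index, mask * 32 + pos. The sketch below is a hypothetical alternative to the `map<PII, ll>` used in the code, not part of the answer: for n = 15 the table has 2^15 * 32 = 1,048,576 entries (8 MB of `long long`), and each lookup becomes constant time instead of logarithmic. The sentinel -2 marks states not yet computed, so that -1 can keep the code's meaning of "no valid answer".

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical flat memo table for the DP states (mask, pos), with mask in [0, 2^n)
// and pos in [0, 31]. Entries start at -2 ("not computed yet"), so -1 can keep the
// meaning "no valid answer" used by the answer's code.
struct FlatMemo {
    int n;
    std::vector<long long> table;

    explicit FlatMemo(int n_) : n(n_), table((std::size_t{1} << n_) * 32, -2LL) {}

    // Packs (mask, pos) into one index: 32 consecutive slots per mask.
    static std::size_t index(int mask, int pos) {
        return static_cast<std::size_t>(mask) * 32 + static_cast<std::size_t>(pos);
    }

    long long& at(int mask, int pos) {
        assert(mask >= 0 && mask < (1 << n) && pos >= 0 && pos < 32);
        return table[index(mask, pos)];
    }
};
```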

We go from the highest bit to the lowest. At each position we count the numbers not in mask that have the current bit set; call this count one_cnt. If

  • one_cnt is even

    the XOR at the current pos is already zero, so we just move to (mask, pos - 1)

  • one_cnt is odd

    if one_cnt equals n (full odd), we treat the state as bad and do nothing. Otherwise we iterate over the numbers that have a zero at pos (including the ones already in mask) and try placing a one there.
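The two quantities these rules depend on can be sketched as standalone helpers (hypothetical names; they mirror the counting loop and the `added` expression in the code): `count_ones` counts only numbers not in mask, and `place_cost` is 2^pos for a number that was already added, otherwise the distance from a[i] up to the next value that has bit pos set and zeros below it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Number of numbers NOT in mask that have bit `pos` set. Numbers already in mask are
// treated as having all bits <= pos equal to zero, so they never contribute here.
int count_ones(const std::vector<long long>& a, int mask, int pos) {
    int cnt = 0;
    for (std::size_t i = 0; i < a.size(); ++i)
        if (!((mask >> i) & 1) && ((a[i] >> pos) & 1)) ++cnt;
    return cnt;
}

// Cost of placing a 1 at bit `pos` of number i (mirrors the `added` expression):
//  - already in mask: its bits <= pos are zero, so it costs exactly 2^pos;
//  - otherwise a[i] must have a 0 at pos, and it is raised to the next value that has
//    bit pos set and zeros below it, i.e. by 2^pos - (a[i] mod 2^(pos+1)).
long long place_cost(const std::vector<long long>& a, int mask, int i, int pos) {
    const long long p = 1LL << pos;
    if ((mask >> i) & 1) return p;
    assert(((a[i] >> pos) & 1) == 0);
    return p - a[i] % (2 * p);
}
```

For the example numbers 10, 11 and 12 with an empty mask, `count_ones` at pos = 3 gives 3, which equals n (the "full odd" case), while placing a 1 at pos = 4 in 12 costs 16 - 12 = 4.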

Note that when one_cnt is full odd we treat the state as bad, because we don't want to increase any number at pos + 1, which could affect an already processed state (violating the DP principle).

But this misses cases such as arr = [1, 1, 1], where a solution does exist. So we do some extra computing:

Starting from the highest bit position, as long as the current position and all positions above it contain an even number of one bits, we iterate over the numbers with a zero at the current position, set that bit to 1 in one of them (clearing its lower bits), then run the memoization from that position and update the result.

For example, if arr = [1, 1, 1], we may check [2, 1, 1], [1, 2, 1] and [1, 1, 2].
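The candidate arrays for this extra pass can be generated as in the sketch below; it covers only the candidate generation, not the surrounding loop over bit positions, and `Candidate` and `raise_at_bit` are hypothetical names. Each number with a 0 at bit i is raised to the next value that has bit i set and all lower bits cleared, so the amount added is 2^i - (a mod 2^i). This sketch keeps any bits above i; for [1, 1, 1] and i = 1 it yields [2, 1, 1], [1, 2, 1] and [1, 1, 2], each with an addition of 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One candidate of the extra pass: the modified array and the amount that was added.
struct Candidate {
    std::vector<long long> arr;
    long long added;
};

// For every number with a 0 at bit i, raise it to the next value that has bit i set and
// all lower bits cleared (bits above i are kept). The increase is 2^i - (a[j] mod 2^i).
std::vector<Candidate> raise_at_bit(const std::vector<long long>& a, int i) {
    assert(i >= 0 && i < 62);
    const long long p = 1LL << i;
    std::vector<Candidate> out;
    for (std::size_t j = 0; j < a.size(); ++j) {
        if ((a[j] >> i) & 1) continue;  // bit i already set: not a candidate
        Candidate c{a, p - a[j] % p};
        c.arr[j] = (a[j] & ~(p - 1)) | p;  // clear bits below i, then set bit i
        out.push_back(c);
    }
    return out;
}
```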

Hope I've explained it well.

I will update the solution if I come up with faster approach :)

Here is the code:

#include <iostream>
#include <cstdio>
#include <vector>
#include <map>
#include <utility>
#include <algorithm>

using namespace std;

#define mp make_pair
#define two(i) (1LL<<(i))

typedef long long ll;
typedef pair<int, int> PII;

int n;
vector<ll>  arr;
ll ans;
map<PII, ll> M;

void update(ll & ret, ll tmp) {
    if (tmp == -1) return;
    if (ret == -1) ret = tmp;
    ret = min(ret, tmp);
}

/*
 * memoization(mask, pos)
 * Args:
 *  mask: if 2^i is in mask, arr[i] has already been raised at a higher bit, and all of
 *        its bits <= pos can be considered zero.
 *  pos: the bit position currently being checked, going from high to low
 * Return:
 *  -1 if no valid answer exists, otherwise the minimum total addition
 *  (a long long, since the sums can exceed the range of int)
 */
ll memoization(int mask, int pos) {

    if (pos < 0) {
        return 0;
    }

    PII state = mp(mask, pos);
    if (M.find(state) != M.end()) {
        return M[state];
    }

    ll &ret = M[state];
    ret = -1;

    int one_cnt = 0;
    for (int i = 0; i < n; i++) {
        if ( !(mask & two(i)) && 
                (two(pos) & arr[i])) {
            one_cnt ++;
        }
    }

    if (one_cnt % 2 == 0) { // even, xor on this pos equals zero
        ret = memoization(mask, pos - 1);
    } else {
        if (one_cnt == n)  { // full odd: bad state, do nothing
            // pass
        } else { // not full odd: choose a number with a zero at pos and place a 1 there
            for (int i = 0; i < n; i++) {
                if ((mask & two(i))  // number i has been added before, so it has a zero at pos
                        || !(two(pos) & arr[i])  // or number i has a zero at pos and hasn't been added before
                        ) {
                    ll candi = memoization(mask | two(i), pos - 1);
                    ll added = mask & two(i) ? two(pos)  // number i has been added before, so placing a 1 costs exactly two(pos)
                        // number i hasn't been added before: raise it to the next value with bit pos set;
                        // only bits in [0 .. pos] change
                        : two(pos) - arr[i] % two(pos + 1);
                    if (candi >= 0)  // legal result
                        update(ret,  candi + added);
                }
            }
        }
    }

    return ret;
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("g.in", "r", stdin);
#endif
    while (cin >> n) {
        arr.clear();
        for (int i = 0; i < n; i++) {
            ll val;
            cin >> val;
            arr.push_back(val);
        }

        ll max_val = arr[0];
        for (int i = 1; i < n; i++) max_val = max(max_val, arr[i]);

        int max_pos = 0;
        while (max_val) max_pos ++, max_val >>= 1;
        max_pos ++;

        //no adjust
        M.clear();
        ans = memoization(0, 31);

        bool even_bit = true;
        for (int i = max_pos; i >= 0; i--) {
            int one_cnt = 0;

            for (int j = 0; j < n; j++) one_cnt += (two(i) & arr[j]) > 0;
            even_bit &= one_cnt % 2 == 0;

            if (even_bit) {
                for (int j = 0; j < n; j++) {
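                    // arr[j] has a 0 at bit i: try setting that bit to 1
                    if (!(two(i) & arr[j])) {
                        ll backup = arr[j];
                        // only bits <= i matter from here on: set bit i and clear the bits below it
                        arr[j] = two(i);

                        // all higher positions contain an even number of one bits, so we just start from the current pos i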
                        M.clear();
                        ll candi = memoization(0, i);
                        ll added = two(i) - backup % two(i);
                        if (candi >= 0) 
                            update(ans, candi + added);

                        arr[j] = backup;
                    }
                }
            }
        }
        cout << ans << endl;
    }

    return 0;
}