Dynamic Programming approach for a subset sum

Posted 2019-02-19 10:40

Question:

Given the following input:

10 4 3 5 5 7

Where

10 = Total Score

4 = 4 players

3 = Score by player 1

5 = Score by player 2

5 = Score by player 3

7 = Score by player 4

I need to print the players whose combined score adds up to the total. So the output can be 1 4, because player 1 + player 4 = 3 + 7 = 10, or the output can be 2 3, because player 2 + player 3 = 5 + 5 = 10.

So it is quite similar to the subset sum problem. I am relatively new to dynamic programming, but for the past 3 days I have been getting help on Stack Overflow, reading dynamic programming tutorials and watching a few videos online. This is the code I have come up with so far:

class Test
{
    public static void main (String[] args) throws java.lang.Exception
    {
        int[] test = {3,5,5,7};
        getSolution(test,4,10);
    }

    //pass total score, #of players (size) and the actual scores by each player(arr)
    public static int getSolution(int[] arr,int size, int total){


        int W = total;
        int n = size;
        int[][] myArray = new int[W+1][size+1];

        for(int i = 0; i<size+1; i++)
        {
            myArray[i][0] = 1;
        }
        for(int j =1; j<W+1; j++)
        {
            myArray[0][j] = 0;
        }

        for(int i =1; i<size+1; i++)
        {
            for(int x=1; x<W+1; x++)
            {
                if(arr[i] < x)
                {
                    myArray[i][x] = myArray[i-1][x];
                }
                else
                {
                    myArray[i][x] = myArray[i-1][x-arr[i]];
                }
            }
        }

        return myArray[n][W];
    }

}

For some reason I am not getting the expected result. I have been trying to find the bug for the past 7+ hours without any success. I would highly appreciate it if someone could help me fix the issue and get the desired result.

Also, please forgive my English; it is not my first language.

Update: I do not need to print all possible combinations that equal the score. I can print any one combination that equals the score and that will be fine.

Answer 1:

Here's the super naive solution that simply generates a power set on your input array and then iterates over each set to see if the sum satisfies the given total. I hacked it together with code already available on StackOverflow.

O(2^n) in time and space. Gross.

You can use a Set to store all indices into your array, then generate all subsets of those indices (the power set), and then use each set of indices to go back into your array and get the values.

Input

  • Target: 10
  • Values: [3, 5, 5, 7]

Code:

import java.util.*;
import java.lang.*;
import java.io.*;

class SubsetSum
{
    public static <T> Set<Set<T>> powerSet(Set<T> originalSet)
    {
        Set<Set<T>> sets = new HashSet<Set<T>>();
        if (originalSet.isEmpty()) 
        {
            sets.add(new HashSet<T>());
            return sets;
        }
        List<T> list = new ArrayList<T>(originalSet);
        T head = list.get(0);
        Set<T> rest = new HashSet<T>(list.subList(1, list.size())); 
        for (Set<T> set : powerSet(rest))
        {
            Set<T> newSet = new HashSet<T>();
            newSet.add(head);
            newSet.addAll(set);
            sets.add(newSet);
            sets.add(set);
        }       
        return sets;
    }

    public static void main(String[] args)
    {
        Set<Integer> mySet = new HashSet<Integer>();
        int[] arr={3, 5, 5, 7};
        int target = 10;
        int numVals = 4;
        for(int i=0;i<numVals;++i)
        {
            mySet.add(i);
        }
        System.out.println("Solutions: ");
        for (Set<Integer> s : powerSet(mySet)) 
        {
            int sum = 0;
            for (Integer e : s)
            {
                sum += arr[e];
            }
            if (sum == target)
            {
                String soln = "[ ";
                for (Integer e : s)
                {
                    soln += arr[e];
                    soln += " ";
                }
                soln += "]";

                System.out.println(soln);
            }
        }
    }
}

Output

Solutions:
[ 5 5 ]
[ 3 7 ]


Once you understand this, perhaps you are ready to begin branch and bound or approximation approaches.
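Because the power set above is built over indices rather than values, the same enumeration can also be written with a bitmask, where bit i of a counter means "player i + 1 is in the subset". A minimal sketch under that assumption (the class and method names are hypothetical, and it assumes fewer than 31 players so that `1 << n` fits in an int); it collects 1-based player numbers, as the question asks, instead of scores:

```java
import java.util.ArrayList;
import java.util.List;

public class BitmaskSubsetSum {
    // Returns every subset of player numbers (1-based) whose scores add up to target.
    // Enumerating indices, not values, keeps the two 5s as separate players.
    static List<List<Integer>> allSolutions(int[] scores, int target) {
        List<List<Integer>> solutions = new ArrayList<>();
        int n = scores.length; // assumption: n < 31, so 1 << n does not overflow
        for (int mask = 0; mask < (1 << n); mask++) {
            int sum = 0;
            List<Integer> players = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                if ((mask & (1 << i)) != 0) { // bit i set: player i + 1 is included
                    sum += scores[i];
                    players.add(i + 1);
                }
            }
            if (sum == target) {
                solutions.add(players);
            }
        }
        return solutions;
    }

    public static void main(String[] args) {
        int[] scores = {3, 5, 5, 7};
        System.out.println(allSolutions(scores, 10)); // [[2, 3], [1, 4]]
    }
}
```

It is still exponential (2^n masks, each checked in O(n)), so it only suits small inputs, but it avoids building the nested HashSets.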



Answer 2:

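public List<Integer> findSubsetWithSum(int[] score, int totalScore)
{
    // Needs: import java.util.ArrayList; import java.util.List;
    int players = score.length;

    // cameFrom[s]: the smaller sum that s was reached from (-1 = not reachable yet)
    // pickedPlayer[s]: 1-based number of the player added to reach s
    int[] cameFrom = new int[totalScore+1];
    int[] pickedPlayer = new int[totalScore+1];
    for (int s = 0; s <= totalScore; s++)
    {
        cameFrom[s] = -1;
        pickedPlayer[s] = -1;
    }
    cameFrom[0] = 0;
    for (int p = 0; p < players; p++)
    {
        // Walk the sums downwards so each player is used at most once, and only
        // record the first way a sum is reached so existing chains never change
        for (int s = totalScore; s >= score[p]; s--)
        {
            if (cameFrom[s] < 0 && cameFrom[s - score[p]] >= 0)
            {
                cameFrom[s] = s - score[p];
                pickedPlayer[s] = p + 1;
            }
        }
    }
    // Follow the back-pointers from totalScore down to 0 (empty list if unreachable)
    List<Integer> picked = new ArrayList<Integer>();
    for (int s = totalScore; s > 0 && cameFrom[s] >= 0; s = cameFrom[s])
    {
        picked.add(pickedPlayer[s]);
    }
    return picked;
}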


Answer 3:

Your problem is in this part of the code

            if(arr[i] < x)
            {
                myArray[i][x] = myArray[i-1][x];
            }
            else
            {
                myArray[i][x] = myArray[i-1][x-arr[i]];
            }

You have two situations to handle:

  1. (inside the if) A subset with this sum was already found among the earlier players. In this case you need to carry the previous result forward to the next row.
  2. (inside the else) After subtracting, the result may become false even though the previous result was true, so you need to carry that result as well.
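Put as a recurrence, the two cases are combined with a logical OR rather than chosen by an if/else. Here D[i][x] (true if some subset of the first i scores sums to x) and a_i (the score of player i, which is arr[i-1] in 0-based Java) are notation introduced for this sketch:

```latex
\begin{aligned}
D[i][0] &= \text{true}, \qquad D[0][x] = \text{false} \quad (x > 0),\\
D[i][x] &= D[i-1][x] \;\lor\; \bigl(x \ge a_i \land D[i-1][x - a_i]\bigr).
\end{aligned}
```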

Why? Take [3, 34, 4, 12, 5, 2] as an example.

Do not forget that DP relies on the optimal substructure property. To find whether the sum 9 can be made, we have to solve all the sums before it, 1 to 8. That is exactly what you are doing by declaring W+1 rows. So when we compute the sum 7, the first three values [3, 34, 4] already give a result (3 + 4 = 7), and we need to carry that result to the next level.
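A small sketch using the example set above: it builds the table for sums up to 9 and shows the sum-7 result from row 3 being carried down to the last row. The class and method names are hypothetical; the table is laid out as rows = items, columns = sums.

```java
public class CarryForwardDemo {
    // reach[i][s] is true if some subset of the first i values sums to s.
    static boolean[][] buildTable(int[] values, int maxSum) {
        int n = values.length;
        boolean[][] reach = new boolean[n + 1][maxSum + 1];
        for (int i = 0; i <= n; i++) {
            reach[i][0] = true; // the empty subset always sums to 0
        }
        for (int i = 1; i <= n; i++) {
            int v = values[i - 1]; // i-th value (the array is 0-based)
            for (int s = 1; s <= maxSum; s++) {
                reach[i][s] = reach[i - 1][s]; // carry the previous row's result forward
                if (s >= v && reach[i - 1][s - v]) {
                    reach[i][s] = true; // or reach s by taking the i-th value
                }
            }
        }
        return reach;
    }

    public static void main(String[] args) {
        int[] values = {3, 34, 4, 12, 5, 2};
        boolean[][] reach = buildTable(values, 9);
        System.out.println("sum 7, first 3 values: " + reach[3][7]); // 3 + 4
        System.out.println("sum 7, all 6 values:   " + reach[6][7]); // carried forward
        System.out.println("sum 9, all 6 values:   " + reach[6][9]); // e.g. 4 + 5
    }
}
```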

So you need to modify the previous code to this:

            myArray[i][x] = myArray[i-1][x]; // carrying previous result
            if (x >= arr[i])
            {
                if (myArray[i][x] == 1) {
                    myArray[i][x] = 1;
                }
                else {
                    myArray[i][x] = myArray[i-1][x-arr[i]];
                }
            }

You also have an array indexing issue. Your i and x both start from 1, so you never consider index 0, which is actually your first player (and arr[i] runs past the end of the array when i reaches size). You need to use the arr[i-1] value instead.

So the further update looks like this:

            myArray[i][x] = myArray[i-1][x]; // carrying previous result
            if (x >= arr[i-1])
            {
                if (myArray[i][x] == 1) {
                    myArray[i][x] = 1;
                }
                else {
                    myArray[i][x] = myArray[i-1][x-arr[i-1]];
                }
            }

So the final program looks like this:

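    public boolean findSolution(int[] scores, int total) {
        int players = scores.length;

        // myArray[player][score] is true if some subset of the first
        // `player` players adds up to `score`
        boolean[][] myArray = new boolean[players + 1][total + 1];

        // A score of 0 is always reachable (pick nobody)
        for (int player = 0; player <= players; player++) {
            myArray[player][0] = true;
        }
        // With no players, no positive score is reachable (Java already defaults to false)
        for (int score = 1; score <= total; score++) {
            myArray[0][score] = false;
        }
        for (int player = 1; player <= players; player++) {
            for (int score = 1; score <= total; score++) {
                // Skip this player: carry the previous result
                myArray[player][score] = myArray[player - 1][score];
                // Use this player: scores[player - 1], since the array is 0-based
                if (score >= scores[player - 1]) {
                    myArray[player][score] = myArray[player][score]
                            || myArray[player - 1][score - scores[player - 1]];
                }
            }
        }
        return myArray[players][total];
    }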

Now, to print the result, look at the true values in the matrix. It shouldn't be difficult to work out which entries are set and when they were set; print those indices to get the result.
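A minimal sketch of that walk-back, assuming the same table layout as findSolution (rows are players, columns are scores); the class and method names are hypothetical. Starting at the bottom-right cell: if the score was already reachable without player p, player p is skipped; otherwise player p must have been used, so it is recorded and its score is subtracted.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class PrintPlayers {
    // Same table as findSolution: table[p][s] is true if some subset of
    // the first p players adds up to s.
    static boolean[][] buildTable(int[] scores, int total) {
        int n = scores.length;
        boolean[][] table = new boolean[n + 1][total + 1];
        for (int p = 0; p <= n; p++) {
            table[p][0] = true;
        }
        for (int p = 1; p <= n; p++) {
            for (int s = 1; s <= total; s++) {
                table[p][s] = table[p - 1][s]
                        || (s >= scores[p - 1] && table[p - 1][s - scores[p - 1]]);
            }
        }
        return table;
    }

    // Returns 1-based player numbers, or an empty list if no subset works.
    static List<Integer> pickPlayers(int[] scores, int total) {
        boolean[][] table = buildTable(scores, total);
        List<Integer> picked = new ArrayList<>();
        if (!table[scores.length][total]) {
            return picked;
        }
        int s = total;
        for (int p = scores.length; p > 0 && s > 0; p--) {
            if (!table[p - 1][s]) {   // s is not reachable without player p
                picked.add(p);
                s -= scores[p - 1];
            }
        }
        Collections.reverse(picked);  // ascending player order
        return picked;
    }

    public static void main(String[] args) {
        int[] scores = {3, 5, 5, 7};
        StringBuilder out = new StringBuilder();
        for (int p : pickPlayers(scores, 10)) {
            if (out.length() > 0) {
                out.append(' ');
            }
            out.append(p);
        }
        System.out.println(out); // 2 3
    }
}
```

For the question's input this prints 2 3, which is one of the accepted outputs.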



Answer 4:

I might try a few things.

First, you're passing in an array with 4 values but later you say there are only three players. I think that's causing some of your difficulties.

The fastest way I can think of to program this is something like the following:

public static void getSolution(int[] array, int desiredTotal) {
    // Try every player as the first member of a pair
    for (int firstIndex = 0; firstIndex < array.length - 1; ++firstIndex) {
        getSolutionWith(array, desiredTotal, firstIndex);
    }
}

public static void getSolutionWith(int[] array, int desiredTotal, int firstIndex) {
    // Score the second player would need so the pair reaches the total
    int lookFor = desiredTotal - array[firstIndex];
    for (int secondIndex = firstIndex + 1; secondIndex < array.length; ++secondIndex) {
        if (array[secondIndex] == lookFor) {
            // Print 1-based player numbers
            System.out.printf("%d %d%n", firstIndex + 1, secondIndex + 1);
        }
    }
}

I haven't tested this code, so it might not be perfect. Basically, you start at position 0 (player 1) and then look at every later player to see whether the first player's score plus the second player's score equals your total. If so, you print them. Note that this only finds pairs of players, not larger groups.
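The idea of fixing one player and scanning the rest can be extended from pairs to groups of any size with recursive include/skip backtracking. A minimal sketch, assuming non-negative scores (used to prune once the remaining total goes below zero); the class and method names are hypothetical, and it returns the first subset found as 1-based player numbers:

```java
import java.util.ArrayList;
import java.util.List;

public class AnySizeSearch {
    // Returns the first subset found, or an empty list if none exists.
    // Exponential in the worst case, like the power-set approach.
    static List<Integer> findAny(int[] scores, int total) {
        List<Integer> picked = new ArrayList<>();
        return search(scores, 0, total, picked) ? picked : new ArrayList<>();
    }

    private static boolean search(int[] scores, int index, int remaining, List<Integer> picked) {
        if (remaining == 0) {
            return true;                       // current picks add up exactly
        }
        if (remaining < 0 || index == scores.length) {
            return false;                      // overshot (scores assumed >= 0) or out of players
        }
        // Option 1: include player index + 1
        picked.add(index + 1);
        if (search(scores, index + 1, remaining - scores[index], picked)) {
            return true;
        }
        picked.remove(picked.size() - 1);      // undo the choice (backtrack)
        // Option 2: skip this player
        return search(scores, index + 1, remaining, picked);
    }

    public static void main(String[] args) {
        System.out.println(findAny(new int[]{3, 5, 5, 7}, 10)); // [1, 4]
    }
}
```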