Can't get the radix sort algorithm to work in

Posted 2019-08-30 09:38

Question:

Given n 32-bit integers (assume that they are positive), you want to sort them by first looking at their shift most significant bits, where shift is a given number of bits, and then recursively sorting each bucket created by sorting the integers on those bits.

So if shift is 2, then you will first look at the two most significant bits in each 32 bit integer and then apply counting sort. Finally, from the groups that you will get, you recurse on each group and start sorting the numbers of each group by looking at the third and the fourth most significant bit. You do this recursively.

My code is as follows:

void radix_sortMSD(int start, int end, 
          int shift, int currentDigit, int input[])
{

    if(end <= start+1 || currentDigit>=32) return;

    /*
     find total amount of buckets
     which is basically 2^(shift)
    */
    long long int numberOfBuckets = (1UL<<shift);

    /*
     initialize a temporary array 
     that will hold the sorted input array
     after finding the values of each bucket.   
    */

    int tmp[end];

   /*
     Allocate memory for the buckets.
   */
   int *buckets = new int[numberOfBuckets + 1];

   /*
       initialize the buckets,
        we don't care about what's 
     happening in position numberOfBuckets+1
   */
   for(int p=0;p<numberOfBuckets + 1;p++)
         buckets[p] = 0;

   //update the buckets
   for (int p = start; p < end; p++)
      buckets[((input[p] >> (32 - currentDigit - shift)) 
                &   (numberOfBuckets-1)) + 1]++;

   //find the accumulative sum
   for(int p = 1; p < numberOfBuckets + 1; p++)
       buckets[p] += buckets[p-1];

   //sort the input array input and store it in array tmp   
   for (int p = start; p < end; p++){ 
    tmp[buckets[((input[p] >> (32 - currentDigit- shift)) 
            & (numberOfBuckets-1))]++] = input[p];
    }

   //copy all the elements in array tmp to array input
   for(int p = start; p < end; p++)
          input[p] = tmp[p];

   //recurse on all the groups that have been created
   for(int p=0;p<numberOfBuckets;p++){
       radix_sortMSD(start+buckets[p], 
       start+buckets[p+1], shift, currentDigit+shift, input);
    }

    //free the memory of the buckets
    delete[] buckets;
}

  int main()
  {

        int a[] = {1, 3, 2, 1, 4, 8, 4, 3};
        int n = sizeof(a)/sizeof(int);
        radix_sortMSD(0,n, 2,0,a);
        return 0;
   }

I can imagine only two issues in this code.

The first issue is whether I actually get the correct bits of the integers in every iteration. I assume that currentDigit = 0 means I am at bit 32 (the most significant bit) of the integer. To get the next shift bits from position currentDigit, I shift right by 32 - currentDigit - shift places and then apply the AND operation to keep the shift least significant bits, which are exactly the bits that I want.
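The shift-and-mask step described above can be checked on its own. This is a minimal sketch, assuming a hypothetical helper named extractDigit (not part of the code above), unsigned 32-bit values, and currentDigit + shift <= 32:

```cpp
#include <cstdint>

// Hypothetical helper for illustration only.
// Returns the `shift`-bit digit that starts `currentDigit` bits below the
// most significant bit of a 32-bit value.
// Assumes 0 < shift < 32 and currentDigit + shift <= 32.
std::uint32_t extractDigit(std::uint32_t value, int currentDigit, int shift)
{
    const std::uint32_t mask = (std::uint32_t{1} << shift) - 1;
    return (value >> (32 - currentDigit - shift)) & mask;
}
```

With shift = 2, the value 8 (binary 1000) yields digit 2 at currentDigit = 28 and digit 0 at currentDigit = 30. The values in main are all below 16, so every digit is 0 until currentDigit reaches 28.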

The second issue is the recursion. I do not think that I recurse on the right groups, but since I do not know whether the first issue is actually handled correctly, I cannot say more about this at the moment.

Any feedback on this would be appreciated.

Thank you in advance.

EDIT: added main function to show how my radix function is called.

Answer 1:

Another update: the code is now a template on the array element type, and the tmp array is passed in as a parameter. The copy steps were eliminated, and a helper function was added that returns the buffer the sorted data ends up in (input and tmp swap roles at every level, so this depends on whether the number of levels is odd or even). Tested with 4 million 64-bit unsigned integers, it works, but it's slow. The fastest time was achieved with numberOfBits = 4. numberOfBits no longer has to exactly divide the number of bits per element.

To explain why MSD first is slow, I'll use a card sorter analogy. Imagine you have 1,000 cards, each with 3 digits, 000 to 999, in random order. Normally (least significant digit first) you run the cards through the sorter on the 3rd digit, ending up with 100 cards in each of the bins: bin 0 holds the cards with a "0" in that digit, ..., bin 9 holds the cards with a "9". You then concatenate the cards from bin 0 to bin 9 and run them through the sorter again on the 2nd digit, and again on the 1st digit, resulting in a sorted set of cards. That's 3 runs with 1,000 cards on each run, so a total of 3,000 cards went through the sorter.

Now start with the randomly ordered cards again, and sort by the 1st digit. You can't concatenate the sets, because cards with higher 1st digits but lower 2nd digits end up out of order. So now you have to do 10 runs with 100 cards each. This results in 100 sets of 10 cards each, which you run again through the sorter, resulting in 1,000 sets of 1 card each, and the cards are now sorted. So the number of cards run through the sorter is still 3,000, same as above, but you had to do 111 runs (1 with the 1,000-card set, 10 with 100-card sets, 100 with 10-card sets).
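The run counts in the analogy can be reproduced with a few lines of arithmetic. This is a rough sketch with hypothetical helper names (lsdRuns, msdRuns, cardsThroughSorter), assuming every bin is non-empty at every level, as in the 000 to 999 example:

```cpp
#include <cstddef>

// Hypothetical helpers illustrating the card-sorter analogy.

// LSD first: one run over the whole deck per digit.
std::size_t lsdRuns(std::size_t digits)
{
    return digits;
}

// MSD first: 1 run on the full deck, then one run per set at each deeper
// level, where the number of sets grows by a factor of `radix` per digit.
std::size_t msdRuns(std::size_t radix, std::size_t digits)
{
    std::size_t runs = 0;
    std::size_t sets = 1;
    for (std::size_t d = 0; d < digits; ++d) {
        runs += sets;
        sets *= radix;
    }
    return runs;
}

// In both orders every card passes through the sorter once per digit.
std::size_t cardsThroughSorter(std::size_t cards, std::size_t digits)
{
    return cards * digits;
}
```

For radix 10 and 3 digits, msdRuns gives 1 + 10 + 100 = 111 while lsdRuns gives 3, and both orders move 3,000 cards in total. The extra cost of MSD first in this model comes from the per-run overhead, not from the number of cards handled.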

#include <cstddef>  // size_t

template <typename T>
void RadixSortMSD(size_t start, size_t end, 
          size_t numberOfBits, size_t currentBit, T input[], T tmp[])
{
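    // Only an empty range can stop early: without a copy step, even a single
    // element must pass through every level to end up in the final buffer.
    if((end - start) < 1)
        return;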

    // adjust numberOfBits if currentBit close to end element
    if((currentBit + numberOfBits) > (8*sizeof(T)))
        numberOfBits = (8*sizeof(T)) - currentBit;

    // set numberOfBuckets
    size_t numberOfBuckets = (size_t)1 << numberOfBits;
    size_t bitMask = numberOfBuckets - 1;
    size_t shift = (8*sizeof(T)) - currentBit - numberOfBits;

    // create bucket info
    size_t *buckets = new size_t[numberOfBuckets+1];
    for(size_t p = 0; p < numberOfBuckets+1; p++)
        buckets[p] = 0;
    for(size_t p = start; p < end; p++)
        buckets[(input[p] >> shift) & bitMask]++;
    size_t m = start;
    for(size_t p = 0; p < numberOfBuckets+1; p++){
        size_t n = buckets[p];
        buckets[p] = m;
        m += n;
    }

    //sort the input array input and store it in array tmp   
    for (size_t p = start; p < end; p++){ 
        tmp[buckets[(input[p] >> shift) & bitMask]++] = input[p];
    }

    // restore bucket info
    for(size_t p = numberOfBuckets; p > 0; p--)
        buckets[p] = buckets[p-1];
    buckets[0] = start;

    // advance current bit
    currentBit += numberOfBits;
    if(currentBit < (8*sizeof(T))){
        //recurse on all the groups that have been created
        for(size_t p=0; p < numberOfBuckets; p++){
            RadixSortMSD(buckets[p], buckets[p+1],
                numberOfBits, currentBit, tmp, input);
        }
    }

    //free buckets
    delete[] buckets;
    return;
}

template <typename T>
T * RadixSort(T *pData, T *pTmp, size_t n)
{
    size_t numberOfBits = 4;
    RadixSortMSD(0, n, numberOfBits, 0, pData, pTmp);
    // return the pointer to the sorted data
    if((((8*sizeof(T))+numberOfBits-1)/numberOfBits)&1)
        return pTmp;
    else
        return pData;
}
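
For comparison, the copy-back layout used in the question can be kept as well. The sketch below is not the code above. It is a minimal variant under stated assumptions: the input is a std::vector of unsigned 32-bit values, shift divides 32, and the name radixSortMsdSketch is hypothetical. It differs from the question's version in two places: tmp is indexed relative to start when copying back, and the bucket boundaries are kept in a separate array so the scatter step does not overwrite them before the recursion uses them.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: MSD radix sort of input[start, end), `shift` bits per level.
// Assumes 0 < shift < 32 and that shift divides 32.
void radixSortMsdSketch(std::size_t start, std::size_t end, int shift,
                        int currentDigit, std::vector<std::uint32_t>& input)
{
    if (end <= start + 1 || currentDigit >= 32)
        return;

    const std::size_t numberOfBuckets = std::size_t{1} << shift;
    const std::uint32_t mask = (std::uint32_t{1} << shift) - 1;
    const int rightShift = 32 - currentDigit - shift;

    // offsets[d] = number of elements in the range whose digit is < d,
    // so bucket d occupies [start + offsets[d], start + offsets[d + 1]).
    std::vector<std::size_t> offsets(numberOfBuckets + 1, 0);
    for (std::size_t p = start; p < end; ++p)
        ++offsets[((input[p] >> rightShift) & mask) + 1];
    for (std::size_t d = 1; d <= numberOfBuckets; ++d)
        offsets[d] += offsets[d - 1];

    // Scatter with a working copy of the offsets so the boundaries survive.
    std::vector<std::size_t> next(offsets.begin(), offsets.end() - 1);
    std::vector<std::uint32_t> tmp(end - start);
    for (std::size_t p = start; p < end; ++p)
        tmp[next[(input[p] >> rightShift) & mask]++] = input[p];

    // tmp is indexed from 0, so the copy back needs the start offset.
    for (std::size_t p = start; p < end; ++p)
        input[p] = tmp[p - start];

    // Recurse on every bucket, including bucket 0.
    for (std::size_t d = 0; d < numberOfBuckets; ++d)
        radixSortMsdSketch(start + offsets[d], start + offsets[d + 1],
                           shift, currentDigit + shift, input);
}
```

Calling radixSortMsdSketch(0, a.size(), 2, 0, a) on the array from the question's main leaves a sorted in place. The sketch allocates two temporary vectors per call, so it is meant to show the indexing, not to be fast.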