Thursday, February 2, 2012

Bitmap Sort

In Jon Bentley's Programming Pearls book, the first Column introduces a sorting problem. As we learn more about the problem and clearly define its constraints, the solution transitions from a disk-based Merge Sort to an efficient Bitmap Sort.

This algorithm uses a bitmap (or a bit vector) to represent a finite set of distinct integers. For example, if we have the integer range 0-5, we can represent any subset of it using a 6-bit array, e.g.
[2,3,5] becomes 0 0 1 1 0 1
[1,3,4] becomes 0 1 0 1 1 0

In order to sort an array of distinct integers, we first initialize a bit array whose size corresponds to the range and fill it with zeroes (in Java, this is the default value). Then we go through the input array and, for each input integer, set the corresponding bit in our bitmap to 1. Finally, we scan through the bitmap and output the number for each bit that is set to 1. Since we scan the bitmap in order, we print all the original integers in sorted order.
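As a quick aside (not from the book or the original post), the same three steps can be sketched with Java's built-in java.util.BitSet, which packs bits into an array of longs internally, before hand-rolling anything:

      // Minimal sketch of bitmap sort using java.util.BitSet.
      // Assumes the inputs are distinct and lie in the range [0, range).
      public int[] bitSetSort(int[] a, int range) {
           java.util.BitSet bits = new java.util.BitSet(range);   // all bits start at 0
           for (int x : a)
                bits.set(x);                                      // mark each input value
           int k = 0;
           for (int i = bits.nextSetBit(0); i >= 0; i = bits.nextSetBit(i + 1))
                a[k++] = i;                                       // set bits come back in increasing order
           return a;
      }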

Implementation
This sounds like a very simple algorithm that also runs in linear time (in the size of the input plus the range)! However, when I started to implement this algorithm in Java, a number of implementation details popped up and the faults of the algorithm also became apparent.

Most languages I know of won't let you simply define a bit array -- the smallest primitive type is usually the 8-bit byte. For a range of N values, we need a byte array of N/8 elements (rounded up). For each input number i, we set bit (i mod 8) of element floor(i / 8) of the bitmap. For example, if we had the range 0-15 (N=16) and the input integer was 13, we would set bit 5 of element 1 (both counted from zero).

When we iterate through the bitmap, we go through each bit of each element and print out the number i * 8 + j for every bit j that is set in element i.
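Written out as stand-alone helpers purely for illustration (the full method below inlines both operations), the two bit manipulations look like this:

      // Illustrative helpers (not part of the original code): set and test bit i of a byte-array bitmap.
      public void setBit(byte[] bitmap, int i) {
           bitmap[i / 8] |= 1 << (i % 8);
      }

      public boolean testBit(byte[] bitmap, int i) {
           return (bitmap[i / 8] & (1 << (i % 8))) != 0;
      }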

This is my Java code:
      public int[] sortDistinctIntegers(int[] a, int min, int max){  
           int N = (max-min) / 8 + 1;  
           byte[] bitmap = new byte[N]; //initialized to 0  
             
           for(int i = 0; i < a.length; i++)  
                bitmap[(a[i]-min)/8] |= 1 << ((a[i]-min) % 8);   // offset by min so the smallest value maps to bit 0
             
           int k = 0;  
           for(int i = 0; i < N; i++){  
                for(int j = 0; j < 8; j++){  
                     if((bitmap[i] & (1 << j)) > 0){  
                          a[k] = i * 8 + j + min;  
                          k++;  
                     }  
                }  
           }  
             
           return a;  
      }  

In order to save space, I "print" the sorted numbers back into the original array, keeping the counter k. Also notice that I've defined the range by the min and max parameters, and I use min as the offset both when setting bits in the bitmap and when writing back the sorted integers. Note that this also makes the method work for negative integers.
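A quick usage sketch (my own example; BitmapSort is the class name also assumed in the benchmark further down):

      BitmapSort bs = new BitmapSort();
      int[] a = {12, -3, 7, 0, 5};
      bs.sortDistinctIntegers(a, -3, 12);   // uses a (12 - (-3)) / 8 + 1 = 2-byte bitmap
      // a is now {-3, 0, 5, 7, 12}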

Shortcomings
This is where the shortcomings become apparent. First, we need to know the range in advance. Even when we do, we cannot use this method for 64-bit longs, since the range of the long type is far larger than a single bitmap can cover: the largest allowable size for an array in Java is roughly the maximum integer value. The maximum coverable range is therefore about (2^31 - 1) * 8 numbers -- a byte array with the maximum number of elements, each element representing 8 numbers. Of course, we can increase the effective range by adding more dimensions.
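As a rough sketch of that idea (my own illustration, not code from the post), a second dimension could look like this, with each row of the outer array covering CHUNK * 8 values:

      // Sketch: a two-dimensional bitmap covers rows * CHUNK * 8 values, pushing the
      // reachable range past what a single Java array can hold. CHUNK is an arbitrary choice here.
      static final int CHUNK = 1 << 24;   // bytes per row

      public void setBit2D(byte[][] bitmap, long offset) {   // offset = value - min, kept as a long
           long byteIndex = offset / 8;
           bitmap[(int)(byteIndex / CHUNK)][(int)(byteIndex % CHUNK)] |= 1 << (int)(offset % 8);
      }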

Actually, the sortDistinctIntegers code above in its current form will not work for the maximum range between Integer.MIN_VALUE and Integer.MAX_VALUE because the (max-min) calculation overflows an int. In order to fix that, we need to hack the expression so that the intermediate value is evaluated as a long and converted back to an int after the division by 8 (this works because the division puts the result back into integer range), so it will look like this:
 int N = (int)(((long)max-min) / 8 + 1);  
We will also need to do this in the bitmap location calculation, which will then look like this:
 bitmap[(int)(((long)a[i]-min)/8)] |= 1 << (int)(((long)a[i]-min) % 8);  

Second, even if we know the range, this method will be inefficient for sparse input (a very wide range containing only a small number of elements) -- and not just in space, since it takes a while to initialize and fill a 2-billion-element byte array.

Performance Analysis
To evaluate the performance of this algorithm, I compare it to Java's built-in sort function. My testing code is as follows:
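      // Uses java.util.Arrays, java.util.HashSet, java.util.Random and java.util.Set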
      public static void main(String[] args) {  
           BitmapSort bs = new BitmapSort();  
             
           Random gen = new Random();  
           Set<Integer> numbers = new HashSet<Integer>();  
           int input[];  
             
           for(int N = 10; N <= 2000000000L; N*=10){
                System.out.print(N + ",");  
                  
                while(numbers.size() < N){  
                     numbers.add(gen.nextInt());  
                }  
                  
                input = new int[N];  
                int i = 0;  
                for(Integer num : numbers){  
                     input[i] = num;  
                     i++;  
                }  
                  
                int temp[] = Arrays.copyOf(input, N);   // copy so Arrays.sort does not modify the original input
                long t1 = System.currentTimeMillis();  
                Arrays.sort(temp);  
                System.out.print((System.currentTimeMillis() - t1) + ",");  
                  
                t1 = System.currentTimeMillis();  
                bs.sortDistinctIntegers(input, Integer.MIN_VALUE, Integer.MAX_VALUE);  
                System.out.println(System.currentTimeMillis() - t1);  
           }  
      }  

I use a HashSet to keep distinct integers generated from the Random class and then time Java's sort execution and my Bitmap Sort implementation.

The chart below shows my results:

Note that the code did not execute for input size larger than 10 million as the HashSet structure ate all the memory in the generation phase (even with 6GB allocation).



Java's sort is much faster for the small sizes, but its running time starts to climb steeply as the input size reaches the millions. I have a feeling that Bitmap Sort might overtake Java's default sorting algorithm for larger input sizes (although I'm not sure whether Java switches algorithms at a certain input size).

This experiment was of course the worst case for Bitmap Sort -- integers spread across the maximum possible range. However, if the range is much better defined (as it would be in a lot of practical situations), the algorithm should perform much better.
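To test that, the only change needed in the generator is to bound the random values; roughly like this (the post doesn't show the exact change, so this is my sketch):

           while(numbers.size() < N){
                numbers.add(gen.nextInt(N * 2));   // values fall in [0, 2N), so the bitmap spans only about 2N bits
           }
           // ...
           bs.sortDistinctIntegers(input, 0, N * 2);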

Running the experiment with numbers generated in the range between 0 and N * 2 gives the following results:



In this case, Bitmap Sort outperforms Java's system sort for all input sizes and looks like it scales better too!

2 comments:

  1. I think your algorithm is wrong,
    min = 98
    max = 100
    n = (max-min)/8 + 1 = 2/8 + 1 = 1
    bitmap = new byte[n]

    consider a[i] = 99 (98<99<100)

    Now in your for loop
    bitmap[a[i]/8] = bitmap[99/8] = bitmap[12]

    ArrayIndexOutOfBoundsException....

  2. FYI: C++'s vector<bool> is usually implemented as a specialized bit-packed class that is better suited for this kind of thing.
