Friday, September 30, 2011

Recursion III: Divide and Conquer, and Mergesort

I now feel that it's time to tackle more advanced applications of recursion, especially those that are used in real-world algorithms - such as mergesort. At this point, I only vaguely remember how mergesort works, apart from the fact that it uses the recursive divide-and-conquer technique. Before finally diving into my algorithms book, I want to try to make sorting work on my own - using recursion - and then compare it to the actual algorithms presented in the book.

My Mergesort
The basic operation of many in-place sorts is the swap, where two values in the array are exchanged, (hopefully) making progress toward the correct sort order. So if we're using recursion, the base case would involve doing a swap. This sorts two values for us, so we can split the whole array into small 2-element parts and sort each one individually at the deepest recursion level. Of course, that will not sort the whole array for us, but are we making progress toward the global order?

Consider this 8-element array (sticking to a power of 2 for simplicity): {5,7,8,4,6,5,3,2}. At the first level we can split it into {5,7,8,4} and {6,5,3,2}, and then further, e.g. in the left branch, into {5,7} and {8,4}. Sorting these two parts gives {5,7} and {4,8}. When we go up a level, the super-part {5,7,4,8} is still not sorted. However, since we know that the two halves are sorted, we only need to compare elements at the same position in the two parts, i.e. the first element of each part, then the second element of each part, and so on. In our example, we compare 5 with 4 and 7 with 8, swapping if necessary. This gives us {4,7,5,8}, which is still not sorted: we also need to compare (and here swap) the inner 2 elements, giving us the final sorted list {4,5,7,8}. This is the merge part of the sort: we merged the sorted {5,7} and {4,8} into a sorted {4,5,7,8}.

Now we can answer the question: was sorting the smallest sub-lists necessary? If we did not sort {5,7} and {8,4} before merging them, the merge on the upper level would not give us a sorted list. Starting from {5,7,8,4}, we compare 5 with 8 and 7 with 4, giving us {5,4,8,7}. Comparing the inner 4 and 8 doesn't change the order. The final list is not sorted, because 8 and 4 were never swapped at the deepest level.
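As a quick check of the two walkthroughs above, here is an iterative sketch of the same position-by-position compare-and-swap merge. The class and method names (CompareExchangeDemo, compareExchange, mergeSortedHalves) are made up for this illustration, and it assumes both halves have the same length h.

```java
class CompareExchangeDemo {
    // Swap a[i] and a[j] if they are out of order, so that a[i] <= a[j] afterwards.
    static void compareExchange(int[] a, int i, int j) {
        if (a[j] < a[i]) {
            int tmp = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }

    // Merge a[lo..lo+h) and a[lo+h..lo+2h), assuming both halves are already sorted.
    // Round r compares same-position cells of the halves shrunk by r cells from
    // the outside, which is what a recursive merge(a, i+1, n-1) call does.
    static void mergeSortedHalves(int[] a, int lo, int h) {
        for (int r = 0; r < h; r++)
            for (int k = 0; k < h - r; k++)
                compareExchange(a, lo + r + k, lo + h + k);
    }
}
```

Running mergeSortedHalves on {5,7,4,8} with h = 2 gives {4,5,7,8}, while running it on {5,7,8,4} leaves {5,4,8,7}, matching the two cases described above.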

Here is my implementation of this method:
      private void swap(Integer[] a, int i, int j){  
           if(a[j] < a[i]){  
                int temp = a[j];  
                a[j] = a[i];  
                a[i] = temp;  
           }  
      }  
        
      private void merge(Integer[] a, int i, int n){  
           if(n == 1)  
                swap(a, i, i+1);  
           else{  
                for (int k = 0; k < n; k++)  
                     swap(a, i+k, i+n+k);  
        
                int m = n-1;  
                merge(a, i+1, m);  
           }  
      }  
        
      private void mergeSort(Integer[] a, int i, int n){  
           if(n == 2){  
                swap(a, i, i+1);  
           }  
           else{  
                int m = n/2;  
                mergeSort(a, i, m);  
                mergeSort(a, i+m, m);  
                merge(a, i, m);  
           }  
      }  
I tested this code by generating a random array of integers, sorting one copy with my method and another copy with Java's built-in Arrays.sort(), and comparing the two results at the end:
      public static void main(String[] args) {  
           MergeSort ms = new MergeSort();  
             
           int n = 1024;  
           Integer a[] = new Integer[n];  
           Random r = new Random();  
           for (int i = 0; i < n; i++) {  
                a[i] = r.nextInt(Integer.MAX_VALUE);  
           }  
             
           Integer b[] = a.clone();  
           Integer c[] = a.clone();  
             
           ms.mergeSort(b, 0, a.length);  
           Arrays.sort(c);  
             
           System.out.print(Arrays.equals(b, c));  
      }  
The code passes multiple runs, but two problems remain:
1) The code does not work with array sizes other than powers of 2.
2) It throws a StackOverflowError for large inputs, e.g. 2^20 elements (the recursive merge alone goes about n/2 calls deep).

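The second problem can be worked around by raising the thread stack size to 8MB with the JVM option -Xss8192k. The first requires some extra thought. When splitting the array in mergeSort(), we take the ceiling of n/2, so that two halves of size m cover all n elements and every recursive call still has n >= 2 (with the floor, n = 3 would produce a call with n = 1, which never reaches the n == 2 base case). When n is odd, the right half can then reach one cell past the current range, so we also need to make sure j does not go out of bounds in the swap() function (for the right-most nodes on the right branch); a position past the end of the array is simply never swapped. The code works correctly with those modifications: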
      private void swap(Integer[] a, int i, int j){  
           if(j < a.length && a[j] < a[i]){  
                int temp = a[j];  
                a[j] = a[i];  
                a[i] = temp;  
           }  
      }  
        
      private void merge(Integer[] a, int i, int n){  
           if(n == 1)  
                swap(a, i, i+1);  
           else{  
                for (int k = 0; k < n; k++)  
                     swap(a, i+k, i+n+k);  
        
                int m = n-1;  
                merge(a, i+1, m);  
           }  
      }  
        
      private void mergeSort(Integer[] a, int i, int n){  
           if(n == 2){  
                swap(a, i, i+1);  
           }  
           else{  
                int m = (int)Math.ceil((double)n/2);  
                mergeSort(a, i, m);  
                mergeSort(a, i+m, m);  
                merge(a, i, m);  
           }  
      }  
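Since my original test only used 1024 elements, here is a quick harness (a sketch; the class name PaddedMergeSortCheck and the helper firstFailingLength are made up for this check) that runs static copies of the three padded methods against Arrays.sort for every length in a range, including odd ones. It uses (n + 1) / 2, which equals Math.ceil((double)n/2) for positive n. Lengths 0 and 1 are left out on purpose: with the ceiling, mergeSort(a, i, 1) calls itself with n = 1 again and never reaches the n == 2 base case.

```java
import java.util.Arrays;
import java.util.Random;

class PaddedMergeSortCheck {
    // Compare-and-swap with the bounds check: positions past the end of the
    // array behave like +infinity and are never swapped.
    static void swap(Integer[] a, int i, int j) {
        if (j < a.length && a[j] < a[i]) {
            Integer temp = a[j];
            a[j] = a[i];
            a[i] = temp;
        }
    }

    static void merge(Integer[] a, int i, int n) {
        if (n == 1) {
            swap(a, i, i + 1);
        } else {
            for (int k = 0; k < n; k++)
                swap(a, i + k, i + n + k);
            merge(a, i + 1, n - 1);
        }
    }

    // Requires n >= 2.
    static void mergeSort(Integer[] a, int i, int n) {
        if (n == 2) {
            swap(a, i, i + 1);
        } else {
            int m = (n + 1) / 2; // integer ceiling of n/2
            mergeSort(a, i, m);
            mergeSort(a, i + m, m);
            merge(a, i, m);
        }
    }

    // Checks every length in [minLen, maxLen] against Arrays.sort;
    // returns the first length that fails, or -1 if all pass.
    static int firstFailingLength(int minLen, int maxLen, long seed) {
        Random r = new Random(seed);
        for (int len = minLen; len <= maxLen; len++) {
            Integer[] a = new Integer[len];
            for (int t = 0; t < len; t++)
                a[t] = r.nextInt(1000);
            Integer[] expected = a.clone();
            Arrays.sort(expected);
            mergeSort(a, 0, len);
            if (!Arrays.equals(a, expected)) return len;
        }
        return -1;
    }
}
```

For example, firstFailingLength(2, 200, 42L) returning -1 means every length from 2 to 200 sorted correctly for that seed.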
The code takes a while to execute, even for a moderate input size of 30k. I suspected some inefficiency, particularly in the merge method. For input size 1024 (with a fixed Random seed), my method took almost 40 times longer than Java's sort. The JProfiler results show the problem:

[Figure: JProfiler results]

Each merge of two sorted halves of size m makes m + (m-1) + ... + 1 = m(m+1)/2 compare-and-swaps, so my merge step alone has the n^2 running time of bubble sort.

Comparing Solutions
Now it's time to delve into the book and compare my solution with the one provided there. I'm sure the actual mergesort algorithm will be much faster than mine. Let's see where exactly I went wrong.

It turns out that overall, the code structure and the idea I had were very close to the original algorithm. As I expected, the biggest problem was in the merge operation. Merging is done in O(n) time, with at most n-1 comparisons when merging two lists of overall size n. The lists are merged by repeatedly comparing the first remaining element of each list and appending the smaller one to a new, sorted list (instead of doing all the swapping in place). This greatly reduces the number of comparisons needed to merge two lists: the count grows linearly with n, not quadratically as in my version.
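To put numbers on the difference, here is a sketch that counts comparisons for both merges (class and method names are my own). compareExchangeMerge is my first merge with a counter added, and standardMerge is a plain two-pointer merge in the spirit of the book's description, not the book's code.

```java
class MergeComparisonCount {
    // My first merge with a counter: a[i..i+n) and a[i+n..i+2n) are sorted halves.
    // Returns the number of comparisons made: n + (n-1) + ... + 1.
    static long compareExchangeMerge(int[] a, int i, int n) {
        long count = 0;
        for (int k = 0; k < n; k++) {
            count++;
            if (a[i + n + k] < a[i + k]) {
                int tmp = a[i + k];
                a[i + k] = a[i + n + k];
                a[i + n + k] = tmp;
            }
        }
        if (n > 1)
            count += compareExchangeMerge(a, i + 1, n - 1);
        return count;
    }

    // Two-pointer merge of sorted arrays left and right into out.
    // One comparison per element output while both lists still have elements.
    static long standardMerge(int[] left, int[] right, int[] out) {
        long comparisons = 0;
        int i = 0, j = 0, k = 0;
        while (i < left.length && j < right.length) {
            comparisons++;
            if (left[i] <= right[j]) out[k++] = left[i++];
            else out[k++] = right[j++];
        }
        while (i < left.length) out[k++] = left[i++];   // copy leftovers, no comparisons
        while (j < right.length) out[k++] = right[j++];
        return comparisons;
    }
}
```

With two fully interleaved halves of 512 elements each (evens and odds), the compare-and-swap merge makes 512 * 513 / 2 = 131328 comparisons, while the two-pointer merge makes 1023.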

Fixing My Solution
Let's try to fix my algorithm to run as efficiently as the original mergesort algorithm:
      // Merge the sorted runs a[i..i+m) and a[i+m..i+n) through a temporary array
      private void merge(int[] a, int i, int m, int n){
           int[] temp = new int[n];

           int j = i;   //next index in left part
           int k = i+m; //next index in right part
           for (int l = 0; l < temp.length; l++) {
                if(j < i+m){
                     if(k < i+n){
                          if(a[j] <= a[k])
                               temp[l] = a[j++]; //pick first from left
                          else
                               temp[l] = a[k++]; //pick first from right
                     }
                     else
                          temp[l] = a[j++]; //right part exhausted: pick from left
                }
                else
                     temp[l] = a[k++]; //left part exhausted: pick from right
           }

           //copy the merged run back into place
           for (int l = 0; l < temp.length; l++)
                a[i+l] = temp[l];
      }

      // Sort the n elements starting at index i
      private void mergeSort(int[] a, int i, int n){
           if(n <= 1) //0 or 1 elements are already sorted (n == 0 would otherwise recurse forever)
                return;
           else{
                int m = n/2;
                mergeSort(a, i, m);
                mergeSort(a, i+m, n-m);
                merge(a, i, m, n);
           }
      }
The first thing you might notice is that I'm now using the primitive int type. I was curious whether the Integer wrapper would be significantly slower than the primitive type, and this page suggests using primitives in most situations. I also double-checked that my methods still modify the caller's array: arrays (including arrays of primitives) are objects in Java, and Java passes everything by value, so for an array the method receives a copy of the reference, which still points to the same array.
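A tiny demo of the point above (PassByValueDemo and its methods are hypothetical names): writing through the array parameter changes the caller's array, while assigning a new array to the parameter only changes the method's local copy of the reference.

```java
class PassByValueDemo {
    // Writes through the copied reference: the caller sees this change.
    static void fillFirst(int[] arr, int value) {
        arr[0] = value;
    }

    // Reassigns the local copy of the reference: the caller's array is untouched.
    static void replace(int[] arr) {
        arr = new int[] {-1, -1, -1};
        arr[0] = 99; // modifies only the new local array
    }
}
```

After fillFirst(x, 7) the caller's x[0] is 7, and a following replace(x) leaves x exactly as it was.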

Evaluation
I've tested my solution against Java's built-in sort for input sizes of 1 million and 10 million (with multiple random inputs). My code sorted every input correctly and took about 2.5 times longer than Java's sort (for primitive arrays, Arrays.sort uses a tuned quicksort; Java 7 switched to a dual-pivot quicksort). This is still much, much faster than my original version and consumes far less memory!

Tuesday, September 27, 2011

Recursion II: Permutations (or Let's Do Some Magix)

In the first post we looked at simple examples where an iterative solution seemed as natural as, if not more natural than, its recursive counterpart. Now that we've mastered the simplest examples, we can start looking at problems where recursion is the natural approach.

Magic Squares
One area of such problems is permutations, i.e. the possible orderings of a set of items. I'm still following a great tutorial from here, so we will look at Magic Squares first. The Magic Squares problem is to fill a square grid (2x2, 3x3, ... NxN) with the numbers 1 to N^2 such that the sum of each row, column and diagonal equals the magic constant (N^3+N)/2; e.g. for a 3x3 square we use the numbers 1...9 and every sum must equal 15.
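The constant itself follows from a short argument: the numbers 1..N^2 add up to N^2(N^2+1)/2, and the N rows must share that total equally, so each row sums to N(N^2+1)/2 = (N^3+N)/2. Here is a minimal integer-only sketch (magicConstant is my own name for it); the product N(N^2+1) is always even, so the integer division is exact.

```java
class MagicConstant {
    // Total of 1..n^2 is n^2 (n^2 + 1) / 2; split evenly over n rows
    // this gives n (n^2 + 1) / 2, which equals (n^3 + n) / 2.
    static int magicConstant(int n) {
        return n * (n * n + 1) / 2;
    }
}
```

This gives 15 for 3x3, 34 for 4x4 and 65 for 5x5.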

Before reading the solution, we can try solving the problem ourselves. The first thing to decide is how to represent the problem in code. To me, the most natural choice was a 1-dimensional array representing the whole square, since I figured it would be easier to recurse over 1 dimension.

Checking Candidate Solution
The first thing I did was write a check function that accepts a candidate magic square and evaluates its magic:
      private boolean checkAnswer(int[] square){  
             
           int size = (int)Math.sqrt(square.length);  
           int constant = ((int)Math.pow(size, 3) + size) / 2;  
             
           //check row sums  
           for (int i = 0; i < square.length; i+=size) {  
                int sum = 0;  
                for (int j = i; j < i+size; j++) {  
                     sum += square[j];  
                }  
                if(sum != constant) return false;  
           }  
             
           //check column sums  
           for (int i = 0; i < size; i++) {  
                int sum = 0;  
                for (int j = i; j < square.length; j+=size) {  
                     sum += square[j];  
                }  
                if(sum != constant) return false;  
           }  
             
           //check first diagonal sum  
           int sum = 0;  
           for (int i = 0; i < square.length; i+=size+1)   
                sum += square[i];  
           if(sum != constant) return false;  
             
           //check second diagonal sum  
           sum = 0;  
           for (int i = size-1; i < square.length-1; i+=size-1)   
                sum += square[i];  
           if(sum != constant) return false;  
             
           //all tests passed  
           return true;  
      }  
We first calculate the size of the square in 1 dimension, i.e. its row/column/diagonal length, by taking the square root of the total length of the input array. It would be wise to check here that the input length is indeed a perfect square, but I omitted it for simplicity. The magic constant is calculated using the formula from the wikipedia page. Note that the magic does not work for size 2x2; it is forever cursed...

Next, we calculate the appropriate sums. For row sums, we add up consecutive runs of size elements, e.g. for 3x3 it checks the indices 0+1+2, 3+4+5, 6+7+8. For column sums, we jump by size, starting from a different index for each column, e.g. for 3x3 we sum 0+3+6, 1+4+7, 2+5+8. For the first diagonal, we start at 0 and jump size+1 at a time to stay on the diagonal, e.g. 0+4+8 for 3x3. For the second diagonal, we start at the right edge (index size-1) and jump size-1 at a time, e.g. 2+4+6 for the 3x3 square.
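To double-check the strides, here is a sketch (SquareIndexGroups is a made-up name) that lists the array indices each check in checkAnswer() visits, using the same loop bounds: rows first, then columns, then the two diagonals.

```java
import java.util.ArrayList;
import java.util.List;

class SquareIndexGroups {
    // Index groups for a size x size square stored row by row.
    static List<List<Integer>> groups(int size) {
        int length = size * size;
        List<List<Integer>> result = new ArrayList<>();
        for (int start = 0; start < length; start += size) {  // rows: step 1
            List<Integer> row = new ArrayList<>();
            for (int j = start; j < start + size; j++) row.add(j);
            result.add(row);
        }
        for (int col = 0; col < size; col++) {                 // columns: step size
            List<Integer> column = new ArrayList<>();
            for (int j = col; j < length; j += size) column.add(j);
            result.add(column);
        }
        List<Integer> diag = new ArrayList<>();                  // first diagonal: step size+1
        for (int j = 0; j < length; j += size + 1) diag.add(j);
        result.add(diag);
        List<Integer> anti = new ArrayList<>();                  // second diagonal: step size-1
        for (int j = size - 1; j < length - 1; j += size - 1) anti.add(j);
        result.add(anti);
        return result;
    }
}
```

For size 3 the last two groups are [0, 4, 8] and [2, 4, 6]; for size 4 they are [0, 5, 10, 15] and [3, 6, 9, 12].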

I tested the code with the 3x3 example from the wikipedia page and made sure it returns false when some numbers are changed.

Baseline Recursive Solution
Now it's time to find the recursive solution. One way is to look at the search space as a tree, with the choice of 1..n^2 numbers as branches and the total number of cells to fill as the depth of the tree. This tree represents all possible candidate solutions to the problem. A tree is also a natural representation for recursion, so it shouldn't be too hard to implement from here. Here is my solution:
      private int[] solveSquare(int[] square, int i){  
           if(i == square.length){  
                if(checkAnswer(square))  
                     return square;  
                else  
                     return null;  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     square[i] = num;  
                     if(solveSquare(square, i+1) != null)  
                          return square;  
                }  
                return null;  
           }  
      }  
We fill one cell at a time, trying all possible numbers in each cell. This search through the solution space is depth-first, i.e. it fills cells all the way to the deepest level before backtracking and trying the next number one level up. So for the 3x3 case, it will try 111111111, then 111111112, and so on... until checkAnswer() returns true. Note that we only return one solution in this case: for the 3x3 example, running solveSquare(new int[9], 0) returns the array [2 8 5 8 5 2 5 2 8].

That interesting solution made me realize that this is too easy, and upon reading the wikipedia article in more detail, I found out that the numbers are "usually distinct integers". This should be easy to fix...
      private boolean numberUsed(int[] square, int i, int num){  
           for (int j = 0; j < i; j++)  
                if(square[j] == num) return true;  
             
           return false;  
      }  
        
      private int[] solveSquare(int[] square, int i){  
           if(i == square.length){  
                if(checkAnswer(square))  
                     return square;  
                else  
                     return null;  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     if(numberUsed(square, i, num)) continue;  
                     square[i] = num;  
                     if(solveSquare(square, i+1) != null)  
                          return square;  
                }  
                return null;  
           }  
      }  
We add a new helper function to check if a number has already been used at a preceding depth level. If so, we skip attempting that assignment. The resulting output for the 3x3 example I get is [2 7 6 9 5 1 4 3 8], which is consistent with the Lo Shu square described in the wikipedia article.
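The distinctness rule also shrinks the search tree dramatically. A rough sketch (hypothetical names) that counts the complete assignments each version may have to reach: N^2 choices in each of N^2 cells without the rule, versus (N^2)! orderings with it. Note that withRepeats overflows a long for N >= 4.

```java
class SearchSpaceSize {
    // Every one of the n^2 cells may take any of the n^2 values: (n^2)^(n^2).
    static long withRepeats(int n) {
        long cells = (long) n * n, leaves = 1;
        for (long c = 0; c < cells; c++) leaves *= cells;
        return leaves;
    }

    // With numberUsed() forbidding repeats: (n^2)! orderings of the numbers.
    static long distinctOnly(int n) {
        long cells = (long) n * n, leaves = 1;
        for (long c = 2; c <= cells; c++) leaves *= c;
        return leaves;
    }
}
```

For 3x3 this is 9^9 = 387420489 versus 9! = 362880, roughly a thousand times fewer complete candidates.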

Getting All Possible Solutions
I was interested to get all possible solutions, instead of just one, so I modified the code a little bit to take advantage of some Java-specific constructs and collect all possible solutions:
      private void solveSquare(Integer[] square, int i, List<Integer[]> solutions){  
           if(i == square.length){  
                if(checkAnswer(square))  
                     solutions.add(square.clone());  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     if(numberUsed(square, i, num)) continue;  
                     square[i] = num;  
                     solveSquare(square, i+1, solutions);  
                }  
           }  
      }  
        
      public static void main(String[] args) {  
           MagicSquare ms = new MagicSquare();  
             
           List<Integer[]> solutions = new LinkedList<Integer[]>();  
           ms.solveSquare(new Integer[9], 0, solutions);  
           for(Integer[] square : solutions){  
                for (int i = 0; i < square.length; i++)  
                     System.out.print(square[i] + " ");  
                System.out.println();  
           }  
      }  
The code looks simple this way too. I switched the square to the Integer wrapper type so that the arrays could be stored in a list (checkAnswer() and numberUsed() also have to be changed to accept Integer[]; in hindsight, a List<int[]> would have worked as well, since arrays are objects in Java). When we reach the deepest level and checkAnswer() succeeds, we add the solution to the list. It is important to clone it here, because the array is just an object reference, and its contents will be overwritten as the code goes on to try other combinations. I've also included the main method to show how the recursive function gets called and how the results are printed.
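Why the clone() matters can be seen in a small simulation (a hypothetical CloneDemo, not part of the solver): the same working array is added to the list, then overwritten as the "search" moves on.

```java
import java.util.ArrayList;
import java.util.List;

class CloneDemo {
    // Adds the working array twice, overwriting it in between like the search does.
    static List<Integer[]> collect(boolean cloneBeforeAdding) {
        List<Integer[]> solutions = new ArrayList<>();
        Integer[] square = {1, 2, 3};
        solutions.add(cloneBeforeAdding ? square.clone() : square);
        square[0] = 9; // the search reuses the same array for the next candidate
        solutions.add(cloneBeforeAdding ? square.clone() : square);
        return solutions;
    }
}
```

Without cloning, both list entries are the same array and the first "solution" now starts with 9; with cloning, the first entry keeps its original 1.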

The results I got for the 3x3 magic square are:
2 7 6 9 5 1 4 3 8
2 9 4 7 5 3 6 1 8
4 3 8 9 5 1 2 7 6
4 9 2 3 5 7 8 1 6
6 1 8 7 5 3 2 9 4
6 7 2 1 5 9 8 3 4
8 1 6 3 5 7 4 9 2
8 3 4 1 5 9 6 7 2
All of these can be obtained through rotations/reflections of the Lo Shu square, as described in the wikipedia article.
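The rotations/reflections claim is easy to verify with a sketch (names are mine; squares are stored row by row, as in the rest of this post) that generates all 4 rotations of the Lo Shu square and the mirror image of each.

```java
import java.util.ArrayList;
import java.util.List;

class SquareSymmetries {
    // Rotate an n x n square 90 degrees clockwise: new[r][c] = old[n-1-c][r].
    static int[] rotate(int[] sq, int n) {
        int[] out = new int[sq.length];
        for (int r = 0; r < n; r++)
            for (int c = 0; c < n; c++)
                out[r * n + c] = sq[(n - 1 - c) * n + r];
        return out;
    }

    // Mirror left-to-right: new[r][c] = old[r][n-1-c].
    static int[] mirror(int[] sq, int n) {
        int[] out = new int[sq.length];
        for (int r = 0; r < n; r++)
            for (int c = 0; c < n; c++)
                out[r * n + c] = sq[r * n + (n - 1 - c)];
        return out;
    }

    // The 4 rotations of the square plus the mirror image of each (8 arrays).
    static List<int[]> symmetries(int[] sq, int n) {
        List<int[]> result = new ArrayList<>();
        int[] current = sq;
        for (int turn = 0; turn < 4; turn++) {
            result.add(current);
            result.add(mirror(current, n));
            current = rotate(current, n);
        }
        return result;
    }
}
```

Calling symmetries(new int[]{2,7,6,9,5,1,4,3,8}, 3) yields the same 8 squares as the list above, in a different order.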

Evaluating the Solution
It's now time to cross-check our solution with the one provided in the online tutorial. As expected, the solution is very similar, with a few interesting differences:
  • The check for the answer is similarly done in a separate function; however, it is called inside the for loop, not as a base condition. I personally prefer having the base condition clearly defined at the beginning of the recursive function, even at the expense of elegance.
  • Tracking which numbers have been used is implemented differently and is more efficient than my version. The author cleverly uses another array to mark the numbers (represented by array index) that have already been used.
The author goes on to optimize their baseline solution with various tweaks. Before reading too far, I want to try some of my own tweaks first.

Optimization
My baseline solution takes 1641ms for 3x3. Using the used-number tracking method from the tutorial reduces this to 141ms, more than a factor of 10 improvement! Here is the updated code:
      private void solveSquare(Integer[] square, boolean[] used, int i, List<Integer[]> solutions){  
           if(i == square.length){  
                if(checkAnswer(square))  
                     solutions.add(square.clone());  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     if(used[num] == false)  
                          used[num] = true;  
                     else  
                          continue;  
                       
                     square[i] = num;  
                     solveSquare(square, used, i+1, solutions);  
                     used[num] = false;  
                }  
           }  
      }  
Notice that we reset the flag after the recursive call, effectively "unlocking" that number so that other branches (tried from the preceding level) can use it. The used array is indexed directly by the number, so it needs square.length + 1 entries.

The next optimization is to do some of the checking before we reach the bottom of the recursion. For example, we can check whether a row sums to the magic constant as soon as it is filled, before going any deeper. Here is my implementation of this optimization:
      private void solveSquare(Integer[] square, boolean[] used, int i, List<Integer[]> solutions){  
           if(i == square.length){  
                if(checkAnswer(square))  
                     solutions.add(square.clone());  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     square[i] = num;  
                       
                     if(used[num] == false){  
                          //check if row sums to magic  
                          if( (i+1) % size == 0){   
                               int sum = 0;  
                               for (int j = i-size+1; j <= i; j++)  
                                    sum += square[j];  
                               if(sum != constant) continue;  
                          }  
                          used[num] = true;  
                     }  
                     else  
                          continue;  
                       
                     solveSquare(square, used, i+1, solutions);  
                     used[num] = false;  
                }  
           }  
      }  
When the depth (variable i) reaches the end of the current row, i.e. i+1 is divisible by the size of the square, we check whether that row sums to the magic constant for the current assignment. I have two new fields (size and constant) that are initialized when the class is constructed, since they stay constant for any given magic square (I now use these in checkAnswer() too, instead of recomputing them every time).
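For reference, here is a self-contained sketch assembling the pieces of this post into one class: used-number tracking, the row check at the end of each row, and a full check at the leaves. It is a condensed rewrite rather than my exact code: the class name PrunedMagicSquare is made up, it uses int[] instead of Integer[], and it folds the row/column/diagonal loops of checkAnswer() into two loops.

```java
import java.util.ArrayList;
import java.util.List;

class PrunedMagicSquare {
    private final int size;
    private final int constant;

    PrunedMagicSquare(int size) {
        this.size = size;
        this.constant = size * (size * size + 1) / 2;
    }

    // Full check at the leaves: all rows, columns and both diagonals.
    private boolean checkAnswer(int[] sq) {
        for (int r = 0; r < size; r++) {
            int rowSum = 0, colSum = 0;
            for (int c = 0; c < size; c++) {
                rowSum += sq[r * size + c];
                colSum += sq[c * size + r];
            }
            if (rowSum != constant || colSum != constant) return false;
        }
        int d1 = 0, d2 = 0;
        for (int k = 0; k < size; k++) {
            d1 += sq[k * size + k];
            d2 += sq[k * size + (size - 1 - k)];
        }
        return d1 == constant && d2 == constant;
    }

    // Depth-first fill of cell i: skip used numbers, reject a completed row early.
    private void solve(int[] sq, boolean[] used, int i, List<int[]> solutions) {
        if (i == sq.length) {
            if (checkAnswer(sq)) solutions.add(sq.clone());
            return;
        }
        for (int num = 1; num <= sq.length; num++) {
            if (used[num]) continue;
            sq[i] = num;
            if ((i + 1) % size == 0) {                 // cell i closes a row
                int sum = 0;
                for (int j = i - size + 1; j <= i; j++) sum += sq[j];
                if (sum != constant) continue;         // prune this branch
            }
            used[num] = true;
            solve(sq, used, i + 1, solutions);
            used[num] = false;                         // free num for sibling branches
        }
    }

    List<int[]> solveAll() {
        List<int[]> solutions = new ArrayList<>();
        // used is indexed by the number itself (1..size*size), hence the +1
        solve(new int[size * size], new boolean[size * size + 1], 0, solutions);
        return solutions;
    }
}
```

new PrunedMagicSquare(3).solveAll() returns the same 8 squares listed earlier, starting with 2 7 6 9 5 1 4 3 8, since numbers are tried in increasing order.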

The 3x3 square solutions are now found in 0ms as reported by the timer, i.e. in under a millisecond: at least another two orders of magnitude faster than the 141ms version!

Monday, September 26, 2011

Beginning Recursion

Since I consider recursion to be one of my weaker areas, I'm going to start off by improving this particular aspect of my programming skills. To be honest, recursion is probably the area that I'm most uncomfortable with, so what better way to start off my development program?!

I just picked up a good book on algorithms from the library - "Algorithms" (Johnsonbaugh & Schaefer, 2004). Well, it has two 4-star reviews on Amazon, with reasonably reliable comments. The book does not have a specialized section dedicated to recursion. It comes up as part of divide-and-conquer algorithms and also in the introductory section on recurrence relations, but I am looking for a fundamental, self-contained section just on the recursive programming concept. I shall get back to the book a bit later to master the divide and conquer technique, once I'm completely comfortable with the basics of recursion.

Online, I found a tutorial on recursive programming: http://erwnerve.tripod.com/prog/recursion/recintro.htm. It has an interesting analogy/imagination technique that can be used to visualize and better understand recursion. The page also has a few good exercises that should help me improve.

The first exercise is calculating a factorial. I remember programming this at some point before, as Java does not have a built-in Math function for it. It is important to remember the base condition (aka end condition), so that the recursive calls stop and the (final) answer is returned. My code is as follows:
      private int factorial(int n){  
           if(n == 1)  
                return 1;  
           else  
                return factorial(n-1) * n;  
      }  
I ran some tests on arbitrary integers: 3, 10, 20, 50. 3 and 10 return correct results (checked against the Windows calculator), but 20 returns a negative number and 50 returns 0. An int can only hold factorials up to 12! = 479001600, so wrong results are expected for these inputs. To be honest, I was expecting Java to throw some kind of "out of range" exception, but Java's integer arithmetic silently wraps around on overflow.

Checking my solution against the one on the web page, the code is identical (except for the order of operands). Now let's do the exercises at the bottom of the page.
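To see the overflow behavior (and ways around it), here is a sketch with a few variants (FactorialOverflow and its methods are hypothetical names). It uses n <= 1 as the base case, so factorial(0) also returns 1 instead of recursing forever. Math.multiplyExact only exists since Java 8, which is newer than this post; BigInteger has been part of Java since 1.1.

```java
import java.math.BigInteger;

class FactorialOverflow {
    // Same recursion as above; int results silently wrap past 12!.
    static int factorialInt(int n) {
        return n <= 1 ? 1 : factorialInt(n - 1) * n;
    }

    // long holds factorials up to 20!; larger ones wrap as well.
    static long factorialLong(int n) {
        return n <= 1 ? 1L : factorialLong(n - 1) * n;
    }

    // Math.multiplyExact throws ArithmeticException instead of wrapping.
    static int factorialExact(int n) {
        return n <= 1 ? 1 : Math.multiplyExact(factorialExact(n - 1), n);
    }

    // Arbitrary precision: never overflows.
    static BigInteger factorialBig(int n) {
        return n <= 1 ? BigInteger.ONE : factorialBig(n - 1).multiply(BigInteger.valueOf(n));
    }
}
```

factorialInt(20) is negative and factorialInt(50) is 0, as observed above, while factorialLong(20) gives the correct 2432902008176640000 and factorialExact(13) throws.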

Printing Number Sequences
The first exercise is to "Write a function using Recursion to print numbers from n to 0". This is my code:
      private void printSequence(int n){  
           if(n == 0)  
                System.out.print(n);  
           else{  
                System.out.print(n + " ");  
                printSequence(n - 1);  
           }  
      }  
Of course, the obvious solution to this problem would be to use a for loop, but it's an interesting practice for recursion nonetheless...

The next problem is to "Write a function using Recursion to print numbers from 0 to n". This appears to be much less trivial than the previous exercise, mainly because we seem to need to keep track of both the current depth and the end condition, with only one function argument. Or so it seems... It turns out that we can use the "function return waiting" feature of recursion:
      private void printSequenceReverse(int n){  
           if(n == 0)  
                System.out.print(0 + " ");  
           else{  
                printSequenceReverse(n - 1);  
                System.out.print(n + " ");  
           }  
      }  
We reverse the order of the previous algorithm by making the recursive call first, and only printing after the call returns. So the algorithm goes all the way down to n == 0, prints "0", and then prints the value of n stored in each stack frame as the calls return.
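The two printing orders are easier to compare side by side as string-returning sketches (Sequences, down and up are hypothetical names): the only difference is whether n is placed before or after the result of the recursive call.

```java
class Sequences {
    // Current n first, then the rest: n, n-1, ..., 0.
    static String down(int n) {
        return n == 0 ? "0" : n + " " + down(n - 1);
    }

    // Rest first (built while the calls return), then n: 0, 1, ..., n.
    static String up(int n) {
        return n == 0 ? "0" : up(n - 1) + " " + n;
    }
}
```

down(3) returns "3 2 1 0" and up(3) returns "0 1 2 3".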

Reversing a String
Next, the third exercise asks us to "Write a function using Recursion to enter and display a string in reverse". It also asks that we not use arrays/strings, so I assume the author expects us to use memory pointers. Since we're using Java, let's just solve it using the String class:
      private void reverseString(String s){  
           if(s.length() <= 1) //also handles the empty string
                System.out.print(s);  
           else{  
                System.out.print(s.charAt(s.length() - 1));  
                reverseString(s.substring(0, s.length() - 1));  
           }  
      }  
This code prints the last character of the string first, then makes a recursive call with the last character removed. Finally, when the length reaches one, we're only left with the first character of the original string. Perhaps a cleaner version of this code can be written using what we've learned from the previous exercise:
      private void reverseString(String s){  
           if(s.length() <= 1) //also handles the empty string
                System.out.print(s);  
           else{  
                reverseString(s.substring(1));  
                System.out.print(s.charAt(0));  
           }  
      }  
This version goes to the deepest level before printing anything, using the one-argument substring() method, which returns the substring starting at the given index. So the first character printed will be the last character of the original string. It's great that we can apply what we've learned in the most basic examples right away!
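A string-returning variant of the second version (a sketch; ReverseString.reverse is a hypothetical name) uses the empty string as its base case, so every input, including "", is handled by the same rule: reverse the rest, then append the first character.

```java
class ReverseString {
    // reverse(s) = reverse(s without its first char) + first char
    static String reverse(String s) {
        if (s.isEmpty()) return s;
        return reverse(s.substring(1)) + s.charAt(0);
    }
}
```

reverse("hello") returns "olleh"; the recursion depth equals the string length, so very long strings would still hit the stack limit.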

Checking for Prime Numbers
Exercise 5 asks us to "Write a function using Recursion to check if a number n is prime". This means we just need to check whether the given number is divisible by any number between 2 and n-1. Should be simple enough to implement, with all the experience from the previous exercises! Definition of a prime number from wikipedia: "A natural number is called a prime number (or a prime) if it is greater than one and has no divisors other than 1 and itself". Here is my code:
      private boolean checkPrime(int n){  
           if(n > 1)  
                return _checkPrime(n, 2);  
           else  
                return false;  
      }  
      private boolean _checkPrime(int n, int m){  
           if(m == n)  
                return true;  
           else{  
                if(n % m == 0)  
                     return false;  
                else  
                     return _checkPrime(n, m+1);  
           }  
      }  

http://primes.utm.edu/lists/small/1000.txt

Note that there is some redundancy in this code, which would become apparent for very large numbers. Can you spot it? We actually only need to check divisibility up to the square root of the input; details and more interesting facts can be found here: http://en.wikipedia.org/wiki/Primality_test
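A sketch of the square-root cutoff (PrimeCheck and its methods are my own names): the recursion stops as soon as m*m exceeds n, because if n = a*b then the smaller factor is at most sqrt(n). The long cast keeps m*m from overflowing an int. As a side effect, the recursion depth drops from about n to about sqrt(n), which matters with Java's limited stack.

```java
class PrimeCheck {
    static boolean isPrime(int n) {
        return n > 1 && noDivisorFrom(n, 2);
    }

    // True if no m' with m <= m' and m'*m' <= n divides n.
    private static boolean noDivisorFrom(int n, int m) {
        if ((long) m * m > n) return true;   // passed sqrt(n): no divisor found
        if (n % m == 0) return false;
        return noDivisorFrom(n, m + 1);
    }
}
```

For example, 7919 (the last entry in the list of the first 1000 primes) is reported prime after only 87 recursive calls, and 7921 = 89 * 89 is correctly rejected at m = 89.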