Tuesday, September 27, 2011

Recursion II: Permutations (or Let's Do Some Magix)

In the first post we looked at simple examples where an iterative solution seemed as natural as (if not more natural than) its recursive counterpart. Now that we've mastered the simplest examples, we can start looking at problems where recursion is the natural approach.

Magic Squares
One area of such problems is permutations, i.e. the possible orderings of a set of items. I'm still following a great tutorial from here, so we will look at Magic Squares first. The Magic Squares problem is to fill a square grid (2x2, 3x3, ... NxN) with the numbers 1 to N^2 such that the numbers in each row, each column and both main diagonals add up to the magic constant (N^3+N)/2. For example, a 3x3 square uses the numbers 1...9 and every sum must equal 15.
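A quick way to see where the constant comes from: the numbers 1 to N^2 add up to a fixed total, and since the N rows split the square into N parts with equal sums, each row must carry exactly 1/N of that total:

```latex
\sum_{k=1}^{N^2} k = \frac{N^2(N^2+1)}{2}
\quad\Longrightarrow\quad
M = \frac{1}{N}\cdot\frac{N^2(N^2+1)}{2} = \frac{N(N^2+1)}{2} = \frac{N^3+N}{2}
```

For N = 3 the numbers 1...9 sum to 45, and 45/3 = 15.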

Before reading the solution, we can try solving the problem ourselves. The first thing to decide is how to represent the problem in code. To me, the most sensible choice was a 1-dimensional array representing the whole square, since I figured recursion would be easier in one dimension.
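As a small illustration of that 1-D layout (a sketch; the class GridIndex and its helpers index() and toGrid() are hypothetical, not from the tutorial), cell (row, col) of an N x N square is stored at position row*N + col, so the grid can be recovered whenever we want to print it:

```java
// Hypothetical helpers illustrating the row-major 1-D layout of an N x N square.
class GridIndex {
    // Position of cell (row, col) in the flat array of a size x size square.
    static int index(int row, int col, int size) {
        return row * size + col;
    }

    // Renders a flat array as text, one line per row of the square.
    static String toGrid(int[] square, int size) {
        StringBuilder sb = new StringBuilder();
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                if (col > 0) sb.append(' ');
                sb.append(square[index(row, col, size)]);
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        int[] loShu = {4, 9, 2, 3, 5, 7, 8, 1, 6}; // one orientation of the Lo Shu square
        System.out.print(toGrid(loShu, 3));
        // prints:
        // 4 9 2
        // 3 5 7
        // 8 1 6
    }
}
```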

Checking a Candidate Solution
The first thing I did was to write a check function that accepts a candidate magic square and evaluates its magic:
      private boolean checkAnswer(int[] square){  
             
           int size = (int)Math.sqrt(square.length);  
           int constant = ((int)Math.pow(size, 3) + size) / 2;  
             
           //check row sums  
           for (int i = 0; i < square.length; i+=size) {  
                int sum = 0;  
                for (int j = i; j < i+size; j++) {  
                     sum += square[j];  
                }  
                if(sum != constant) return false;  
           }  
             
           //check column sums  
           for (int i = 0; i < size; i++) {  
                int sum = 0;  
                for (int j = i; j < square.length; j+=size) {  
                     sum += square[j];  
                }  
                if(sum != constant) return false;  
           }  
             
           //check first diagonal sum  
           int sum = 0;  
           for (int i = 0; i < square.length; i+=size+1)   
                sum += square[i];  
           if(sum != constant) return false;  
             
           //check second diagonal sum  
           sum = 0;  
           for (int i = size-1; i < square.length-1; i+=size-1)   
                sum += square[i];  
           if(sum != constant) return false;  
             
           //all tests passed  
           return true;  
      }  
We first calculate the side length of the square (i.e. the length of each row, column and diagonal) by taking the square root of the length of the input array. It would be wise to check that the input array really is square, but I omitted that here for simplicity. The magic constant is calculated using the formula from the Wikipedia page. Note that the magic does not work for a 2x2 square; it is forever cursed...
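The omitted squareness check could look something like this minimal sketch (the class SquareSize and its methods are hypothetical). It rounds the square root instead of truncating it, and computes the constant as N(N^2+1)/2 in integer arithmetic, which is the same value as (N^3+N)/2 without going through Math.pow():

```java
// Hypothetical helpers for the validation step that checkAnswer() skips.
class SquareSize {
    // Returns the side length of a flat square array, or throws if the
    // array length is not a perfect square.
    static int sideLength(int[] square) {
        // Round rather than truncate, so a sqrt result like 2.9999999
        // cannot silently become 2.
        int size = (int) Math.round(Math.sqrt(square.length));
        if (size * size != square.length) {
            throw new IllegalArgumentException(
                "length " + square.length + " is not a perfect square");
        }
        return size;
    }

    // Magic constant N(N^2 + 1)/2; the product is always even, so the
    // integer division is exact.
    static int magicConstant(int size) {
        return size * (size * size + 1) / 2;
    }

    public static void main(String[] args) {
        int size = sideLength(new int[9]);
        System.out.println(size + " " + magicConstant(size)); // prints "3 15"
    }
}
```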

Next, we calculate the sums. For the rows, we sum consecutive segments of length size, e.g. for 3x3 it checks the sums of array indices 0+1+2, 3+4+5 and 6+7+8. For the columns, we jump in steps of size, starting from a different cell for each column, e.g. for 3x3 we sum indices 0+3+6, 1+4+7 and 2+5+8. For the first diagonal, we start at 0 and jump size+1 at a time to stay on the diagonal, e.g. 0+4+8 for 3x3. For the second diagonal, we start at index size-1 (the top-right corner) and jump size-1 at a time, e.g. 2+4+6 for the 3x3 square.
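To see the index arithmetic concretely, this sketch (hypothetical class IndexWalk) lists the array indices each sum visits, using the same starting points, strides and loop bounds as checkAnswer(). It also shows why the anti-diagonal loop stops at square.length - 1: for 3x3 the next stride after 6 would land on index 8, the bottom-right corner, which is not on that diagonal.

```java
import java.util.ArrayList;
import java.util.List;

// Lists the flat-array indices visited by each sum in checkAnswer().
class IndexWalk {
    static List<Integer> row(int r, int size) {
        List<Integer> idx = new ArrayList<>();
        for (int j = r * size; j < r * size + size; j++) idx.add(j); // consecutive cells
        return idx;
    }

    static List<Integer> column(int c, int size) {
        List<Integer> idx = new ArrayList<>();
        for (int j = c; j < size * size; j += size) idx.add(j); // stride of size
        return idx;
    }

    static List<Integer> mainDiagonal(int size) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < size * size; i += size + 1) idx.add(i); // stride size+1
        return idx;
    }

    static List<Integer> antiDiagonal(int size) {
        List<Integer> idx = new ArrayList<>();
        // stride size-1 from the top-right corner; the "- 1" bound excludes the last cell
        for (int i = size - 1; i < size * size - 1; i += size - 1) idx.add(i);
        return idx;
    }

    public static void main(String[] args) {
        int size = 3;
        for (int r = 0; r < size; r++) System.out.println("row " + r + ": " + row(r, size));
        for (int c = 0; c < size; c++) System.out.println("column " + c + ": " + column(c, size));
        System.out.println("diagonal: " + mainDiagonal(size));      // [0, 4, 8]
        System.out.println("anti-diagonal: " + antiDiagonal(size)); // [2, 4, 6]
    }
}
```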

I tested the code with the 3x3 example from the Wikipedia page and made sure it returns false when some of the numbers are changed.
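A minimal test along those lines might look like the following. Here checkAnswer() is copied into a hypothetical static class CheckDemo so the snippet runs on its own, and the "changed" square simply swaps the first two cells of the Lo Shu square, which keeps the first row at 15 but breaks the columns.

```java
// Standalone copy of checkAnswer() (made static), plus a tiny test driver.
class CheckDemo {
    static boolean checkAnswer(int[] square) {
        int size = (int) Math.sqrt(square.length);
        int constant = ((int) Math.pow(size, 3) + size) / 2;

        // rows: consecutive runs of size cells
        for (int i = 0; i < square.length; i += size) {
            int sum = 0;
            for (int j = i; j < i + size; j++) sum += square[j];
            if (sum != constant) return false;
        }
        // columns: start at each column, stride by size
        for (int i = 0; i < size; i++) {
            int sum = 0;
            for (int j = i; j < square.length; j += size) sum += square[j];
            if (sum != constant) return false;
        }
        // main diagonal: stride size + 1 from the top-left corner
        int sum = 0;
        for (int i = 0; i < square.length; i += size + 1) sum += square[i];
        if (sum != constant) return false;
        // anti-diagonal: stride size - 1 from the top-right corner
        sum = 0;
        for (int i = size - 1; i < square.length - 1; i += size - 1) sum += square[i];
        return sum == constant;
    }

    public static void main(String[] args) {
        int[] loShu = {4, 9, 2, 3, 5, 7, 8, 1, 6};
        System.out.println(checkAnswer(loShu));  // true

        int[] broken = loShu.clone();
        broken[0] = 9;                           // swap the first two cells
        broken[1] = 4;
        System.out.println(checkAnswer(broken)); // false: first column now sums to 20
    }
}
```

Note that this check says nothing about whether the numbers are distinct: an array such as {1, 9, 5, 9, 5, 1, 5, 1, 9} passes as well.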

Baseline Recursive Solution
Now it's time to find the recursive solution. One way is to view the search space as a tree, with the choice among the numbers 1..N^2 as the branches and the number of cells to fill as the depth of the tree. This tree represents all possible candidate solutions to the problem. A tree is also a natural fit for recursion, so it shouldn't be too hard to implement from here. Here is my solution:
      private int[] solveSquare(int[] square, int i){  
           if(i == square.length){  
                if(checkAnswer(square))  
                     return square;  
                else  
                     return null;  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     square[i] = num;  
                     if(solveSquare(square, i+1) != null)  
                          return square;  
                }  
                return null;  
           }  
      }  
We fill one cell at a time, trying every possible number in each cell. The search through the solution space is depth first: the last cell cycles through all of its values before the cell in front of it moves on, and so on up the tree. So for the 3x3 case, it will try 111111111, then 111111112, and so on, until checkAnswer() returns true. Note that we only return one solution in this case: for the 3x3 example, running solveSquare(new int[9], 0) returns the array [1 9 5 9 5 1 5 1 9].
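To get a feel for the size of this tree: with N^2 choices for each of the N^2 cells, the full tree has (N^2)^(N^2) complete candidates at the bottom. The search stops at the first solution, so it does not necessarily visit all of them, but this is the worst case. A small sketch (the class TreeSize is hypothetical) to compute it:

```java
// Number of complete candidates (leaves) in the search tree that allows
// repeated numbers: N^2 choices for each of the N^2 cells.
class TreeSize {
    static long leaves(int size) {
        int cells = size * size;
        long count = 1;
        for (int k = 0; k < cells; k++) {
            count = Math.multiplyExact(count, (long) cells); // throws ArithmeticException on overflow
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(leaves(3)); // 387420489, i.e. 9^9
    }
}
```

For a 4x4 square the count is 16^16 = 2^64, which no longer even fits in a long, so this brute-force approach clearly does not scale.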

That interesting solution made me realize that this was too easy, and on reading the Wikipedia article in more detail, I found out that the numbers are "usually distinct integers". This should be easy to fix...
      private boolean numberUsed(int[] square, int i, int num){  
           for (int j = 0; j < i; j++)  
                if(square[j] == num) return true;  
             
           return false;  
      }  
        
      private int[] solveSquare(int[] square, int i){  
           if(i == square.length){  
                if(checkAnswer(square))  
                     return square;  
                else  
                     return null;  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     if(numberUsed(square, i, num)) continue;  
                     square[i] = num;  
                     if(solveSquare(square, i+1) != null)  
                          return square;  
                }  
                return null;  
           }  
      }  
We add a new helper function that checks whether a number has already been used at a preceding depth level (i.e. in an earlier cell). If so, we skip that assignment. The resulting output I get for the 3x3 example is [2 7 6 9 5 1 4 3 8], which is consistent with the Lo Shu square described in the Wikipedia article.
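Stripped of the magic check, this version of the search is simply a recursive permutation generator, which ties back to the permutations mentioned at the start. A minimal sketch (the class Permutations is hypothetical; numberUsed() is the same linear scan as above) that collects every ordering of 1..n:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Recursive permutation generator with the same shape as solveSquare():
// fill position i with every number not already used in positions 0..i-1.
class Permutations {
    static void permute(int[] perm, int i, List<int[]> out) {
        if (i == perm.length) {
            out.add(perm.clone()); // copy, because perm is reused for later orderings
            return;
        }
        for (int num = 1; num <= perm.length; num++) {
            if (numberUsed(perm, i, num)) continue;
            perm[i] = num;
            permute(perm, i + 1, out);
        }
    }

    // Linear scan of the already-filled prefix, as in numberUsed() above.
    static boolean numberUsed(int[] perm, int i, int num) {
        for (int j = 0; j < i; j++)
            if (perm[j] == num) return true;
        return false;
    }

    public static void main(String[] args) {
        List<int[]> all = new ArrayList<>();
        permute(new int[3], 0, all);
        for (int[] p : all) System.out.println(Arrays.toString(p));
        // [1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]
    }
}
```

For n = 9 this produces 9! = 362,880 orderings, compared with 9^9 = 387,420,489 candidates when repeats are allowed.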

Getting All Possible Solutions
I was interested in getting all possible solutions instead of just one, so I modified the code slightly to take advantage of some Java-specific constructs and collect all of them:
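      // Needs java.util.List and java.util.LinkedList imported. For this to compile,
      // checkAnswer() and numberUsed() must also be changed to take Integer[] instead of int[].
      private void solveSquare(Integer[] square, int i, List<Integer[]> solutions){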
           if(i == square.length){  
                if(checkAnswer(square))  
                     solutions.add(square.clone());  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     if(numberUsed(square, i, num)) continue;  
                     square[i] = num;  
                     solveSquare(square, i+1, solutions);  
                }  
           }  
      }  
        
      public static void main(String[] args) {  
           MagicSquare ms = new MagicSquare();  
             
           List<Integer[]> solutions = new LinkedList<Integer[]>();  
           ms.solveSquare(new Integer[9], 0, solutions);  
           for(Integer[] square : solutions){  
                for (int i = 0; i < square.length; i++)  
                     System.out.print(square[i] + " ");  
                System.out.println();  
           }  
      }  
The code stays simple this way too. All we had to do was switch the arrays to the Integer wrapper type (changing checkAnswer() and numberUsed() to accept Integer[] as well) and collect the solutions in a list. Strictly speaking, the wrapper isn't required: arrays are objects in Java, so a List<int[]> would work just as well. When we reach the deepest level and checkAnswer() returns true, we add the solution to the list. It is important to clone it here, because the list only stores a reference to the array, and the contents of that single working array keep changing as the code goes on to try other combinations. I've also included the main method to show how the recursive function is started and its results printed.
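To see why the clone() matters, here is a small hypothetical demo (class CloneDemo) that keeps overwriting one working array, the way the solver does, and records it either by reference or as a copy. It also stores plain int[] arrays in a List, which Java allows because arrays are objects.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Records a reused working array three times, by reference or by copy.
class CloneDemo {
    static List<int[]> record(boolean copy) {
        int[] work = new int[2];
        List<int[]> saved = new ArrayList<>();
        for (int k = 1; k <= 3; k++) {
            work[0] = k;
            work[1] = k * 10;
            saved.add(copy ? work.clone() : work); // without clone, all entries alias work
        }
        return saved;
    }

    public static void main(String[] args) {
        for (int[] a : record(false)) System.out.print(Arrays.toString(a) + " ");
        System.out.println(); // [3, 30] [3, 30] [3, 30]
        for (int[] a : record(true)) System.out.print(Arrays.toString(a) + " ");
        System.out.println(); // [1, 10] [2, 20] [3, 30]
    }
}
```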

The results I got for the 3x3 magic square are:
2 7 6 9 5 1 4 3 8
2 9 4 7 5 3 6 1 8
4 3 8 9 5 1 2 7 6
4 9 2 3 5 7 8 1 6
6 1 8 7 5 3 2 9 4
6 7 2 1 5 9 8 3 4
8 1 6 3 5 7 4 9 2
8 3 4 1 5 9 6 7 2
These eight squares are exactly the rotations and reflections of the Lo Shu square, as described in the Wikipedia article.
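This is easy to confirm in code. The sketch below (hypothetical class Symmetries) applies the four rotations of the Lo Shu square, each with and without a left-right mirror, and sorts the eight results lexicographically. That is the order the solver finds them in, since it tries numbers in ascending order, and the output matches the list above line for line.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Generates the 8 symmetries (4 rotations x optional mirror) of a flat n x n square.
class Symmetries {
    // Rotate 90 degrees clockwise: new (r, c) takes old (n-1-c, r).
    static int[] rotate(int[] sq, int n) {
        int[] out = new int[sq.length];
        for (int r = 0; r < n; r++)
            for (int c = 0; c < n; c++)
                out[r * n + c] = sq[(n - 1 - c) * n + r];
        return out;
    }

    // Mirror left-right: new (r, c) takes old (r, n-1-c).
    static int[] mirror(int[] sq, int n) {
        int[] out = new int[sq.length];
        for (int r = 0; r < n; r++)
            for (int c = 0; c < n; c++)
                out[r * n + c] = sq[r * n + (n - 1 - c)];
        return out;
    }

    static List<int[]> all(int[] sq, int n) {
        List<int[]> result = new ArrayList<>();
        int[] cur = sq;
        for (int k = 0; k < 4; k++) {
            result.add(cur);
            result.add(mirror(cur, n));
            cur = rotate(cur, n);
        }
        // lexicographic order, the same order the recursive search produces
        result.sort((a, b) -> {
            for (int i = 0; i < a.length; i++)
                if (a[i] != b[i]) return Integer.compare(a[i], b[i]);
            return 0;
        });
        return result;
    }

    public static void main(String[] args) {
        int[] loShu = {4, 9, 2, 3, 5, 7, 8, 1, 6};
        for (int[] s : all(loShu, 3)) System.out.println(Arrays.toString(s));
    }
}
```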

Evaluating the Solution
It's now time to cross-check our solution with the one provided in the online tutorial. As expected, the two are very similar, with a few interesting differences:
  • The answer check is likewise done in a separate function; however, it is called inside the for loop rather than as a base condition. I personally prefer having the base condition clearly defined at the beginning of the recursive function, even at the expense of elegance.
  • Tracking which numbers have been used is implemented differently, and more efficiently than in my version. The author cleverly uses a second array to mark the numbers (represented by the array index) that have already been used, so checking a number takes constant time instead of a scan over the cells filled so far.
The author goes on to optimize this baseline solution with various tweaks. Without reading too far ahead, I want to try some tweaks of my own first.

Optimization
My baseline solution takes 1641ms for 3x3. Using the used-number tracking method from the tutorial reduces this to 141ms, more than a factor of 10 improvement! Here is the updated code:
      // used[] needs square.length + 1 entries, since num runs from 1 to square.length
      private void solveSquare(Integer[] square, boolean[] used, int i, List<Integer[]> solutions){
           if(i == square.length){  
                if(checkAnswer(square))  
                     solutions.add(square.clone());  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     if(used[num] == false)  
                          used[num] = true;  
                     else  
                          continue;  
                       
                     square[i] = num;  
                     solveSquare(square, used, i+1, solutions);  
                     used[num] = false;  
                }  
           }  
      }  
Notice that we reset the flag after the recursive call, effectively "unlocking" that number so it can be used again on a different branch (tried from the preceding level).

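The next optimization is to do some of the checking before we reach the bottom of the recursion. For example, we can check whether each row sums to the magic constant as soon as the row is filled, before going any deeper. Here is my implementation of this optimization:
      // size and constant are instance fields set once at construction;
      // used[] needs square.length + 1 entries, since num runs from 1 to square.length
      private void solveSquare(Integer[] square, boolean[] used, int i, List<Integer[]> solutions){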
           if(i == square.length){  
                if(checkAnswer(square))  
                     solutions.add(square.clone());  
           }  
           else{  
                for (int num = 1; num <= square.length; num++) {  
                     square[i] = num;  
                       
                     if(used[num] == false){  
                          //check if row sums to magic  
                          if( (i+1) % size == 0){   
                               int sum = 0;  
                               for (int j = i-size+1; j <= i; j++)  
                                    sum += square[j];  
                               if(sum != constant) continue;  
                          }  
                          used[num] = true;  
                     }  
                     else  
                          continue;  
                       
                     solveSquare(square, used, i+1, solutions);  
                     used[num] = false;  
                }  
           }  
      }  
When the depth (variable i) reaches the end of the current row, i.e. when i+1 is divisible by the size of the square, we check whether that row's sum works for the current candidate assignment. The two new variables, size and constant, are fields initialized when the class is constructed, since they stay the same for any given magic square (I now use them in checkAnswer() too, instead of recomputing them every time).

The 3x3 square solutions are now found in 0ms, i.e. too quickly for a millisecond timer to register, so the row check gives a further improvement of a factor of 140 or more over the 141ms version!
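Wall-clock timings at this scale depend heavily on the machine, JVM warm-up and timer resolution, so another way to see the effect of pruning is to count how many complete candidates reach checkAnswer(). The sketch below restructures the code above into one self-contained class. The class name MagicSquareSolver, the leaves counter and the pruneRows flag are hypothetical additions for illustration; it uses int[] instead of Integer[], and checkAnswer() is rewritten with a single loop for both diagonals.

```java
import java.util.ArrayList;
import java.util.List;

// Optimized solver: used[] tracking plus optional row-sum pruning, with a
// counter of how many complete candidates reach checkAnswer().
class MagicSquareSolver {
    private final int size;
    private final int constant;
    private final boolean pruneRows;
    long leaves = 0; // complete candidates handed to checkAnswer()

    MagicSquareSolver(int size, boolean pruneRows) {
        this.size = size;
        this.constant = size * (size * size + 1) / 2;
        this.pruneRows = pruneRows;
    }

    List<int[]> solve() {
        List<int[]> solutions = new ArrayList<>();
        int[] square = new int[size * size];
        boolean[] used = new boolean[size * size + 1]; // index 0 is unused
        solve(square, used, 0, solutions);
        return solutions;
    }

    private void solve(int[] square, boolean[] used, int i, List<int[]> solutions) {
        if (i == square.length) {
            leaves++;
            if (checkAnswer(square)) solutions.add(square.clone());
            return;
        }
        for (int num = 1; num <= square.length; num++) {
            if (used[num]) continue;
            square[i] = num;
            // at the last cell of a row, prune if that row misses the constant
            if (pruneRows && (i + 1) % size == 0 && rowSum(square, i) != constant) continue;
            used[num] = true;
            solve(square, used, i + 1, solutions);
            used[num] = false; // backtrack: free num for other branches
        }
    }

    // Sum of the row that ends at index i.
    private int rowSum(int[] square, int i) {
        int sum = 0;
        for (int j = i - size + 1; j <= i; j++) sum += square[j];
        return sum;
    }

    private boolean checkAnswer(int[] square) {
        for (int r = 0; r < size; r++)
            if (rowSum(square, r * size + size - 1) != constant) return false;
        for (int c = 0; c < size; c++) {
            int sum = 0;
            for (int j = c; j < square.length; j += size) sum += square[j];
            if (sum != constant) return false;
        }
        int diag = 0, anti = 0;
        for (int k = 0; k < size; k++) {
            diag += square[k * size + k];
            anti += square[k * size + (size - 1 - k)];
        }
        return diag == constant && anti == constant;
    }

    public static void main(String[] args) {
        for (boolean prune : new boolean[] {false, true}) {
            MagicSquareSolver s = new MagicSquareSolver(3, prune);
            int found = s.solve().size();
            System.out.println("pruneRows=" + prune + ": " + found
                + " solutions, " + s.leaves + " complete candidates checked");
        }
        // pruneRows=false: 8 solutions, 362880 complete candidates checked
        // pruneRows=true: 8 solutions, 2592 complete candidates checked
    }
}
```

Without pruning every permutation reaches the bottom (9! = 362,880). With the row check, only orderings whose rows each sum to 15 survive. There are just two ways to split 1..9 into three triples summing to 15, giving 2 x 3! x (3!)^3 = 2,592 candidates.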
