Inverting a Binary Tree with Scala

The problem of inverting a binary tree got some hype after a widely shared tweet about it.

As a problem-solver, I was wondering how to approach this problem, as it seems to be a great application of structural recursion. In this post, let’s see how to solve it with functional programming and Scala.

Problem Definition:

The problem can be specified as follows.

Given a Binary Tree t1:

     4
   /   \
  2     7
 / \   / \
1   3 6   9

We have to implement a function to transform t1 to a Binary Tree t2:

     4
   /   \
  7     2
 / \   / \
9   6 3   1

Thus the function invertTree essentially has the following signature.

invertTree: Tree => Tree

Solution

First, we define a Binary Tree ADT.

sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A , left: Tree[A], right: Tree[A]) extends Tree[A]

In order to conveniently encode tree instances, we add the following methods to the companion object of Tree.

object Tree {
  def empty[A]: Tree[A] = EmptyTree
  def node[A](value: A, left: Tree[A] = empty, right: Tree[A] = empty): Tree[A] =
    Node(value, left, right)
}

As a result, after importing Tree._, we can define an instance of a tree in the Scala REPL as follows.

scala> val t1 = node(4,
                  node(2
                    , node(1)
                    , node(3)
                  ),
                  node(7
                    , node(6)
                    , node(9)
                  )
                )
t1: Tree[Int] =
Node(4,
  Node(2,
    Node(1,EmptyTree,EmptyTree),
    Node(3,EmptyTree,EmptyTree)),
  Node(7,
    Node(6,EmptyTree,EmptyTree),
    Node(9,EmptyTree,EmptyTree)))

Next, in order to facilitate structural recursion, we define a fold function for the binary tree as follows:

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

It allows us to traverse the tree, perform transformations, and accumulate the result. For instance, we can define a function that counts the number of nodes in the tree in a generic manner:

def size[T](tree: Tree[T]) =
  fold(tree, 0: Int){(l, x, r) => l + r + 1}

scala> size(t1)
res11: Int = 7

Also, we can define a map function that applies a function f: A ⇒ B to the value of each Node. Note that map is always structure-preserving: unlike the aforementioned size function, it retains the shape the tree had before the application and performs only local transformations.

import Tree._

def map[A, B](tree: Tree[A])(f: A => B): Tree[B] =
  fold(tree, Tree.empty[B]){(l, x, r) => Node(f(x), l, r)}

scala> map(t1)(x => x * 10)
res11: Tree[Int] =
Node(40,
  Node(20,
    Node(10,EmptyTree,EmptyTree),
    Node(30,EmptyTree,EmptyTree)),
  Node(70,
    Node(60,EmptyTree,EmptyTree),
    Node(90,EmptyTree,EmptyTree)))

As you have guessed, we can similarly define the invertTree function in a generic manner as follows:

def invertTree[A](tree: Tree[A]): Tree[A] =
  fold(tree, Tree.empty[A]){(leftT, value, rightT) => Node(value, rightT, leftT)}

In essence, invertTree simply swaps the left subtree with the right subtree at every node, and thus derives the resultant tree with the generic fold.

scala> invertTree(t1)
res12: Tree[Int] =
Node(4,
  Node(7,
    Node(9,EmptyTree,EmptyTree),
    Node(6,EmptyTree,EmptyTree)),
  Node(2,
    Node(3,EmptyTree,EmptyTree),
    Node(1,EmptyTree,EmptyTree)))
/*
     4
   /   \
  7     2
 / \   / \
9   6 3   1
*/

Neat..uh? By the way, this problem can be solved in several ways. This post particularly demonstrates the application of structural recursion in a generic manner (e.g., with fold), which is the essence of #fp, imho ;).
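As one of those other ways, here is a minimal sketch of the same idea written as direct structural recursion, without fold. It is in Java (the language used for the UVa submissions elsewhere on this blog) rather than Scala, and every name in it (TreeInvertSketch, Node, leaf, invert, inorder, example) is made up for illustration. A null reference plays the role of EmptyTree.

```java
import java.util.ArrayList;
import java.util.List;

final class TreeInvertSketch {
    /** Immutable binary tree node; a null reference stands in for EmptyTree. */
    static final class Node {
        final int value;
        final Node left;
        final Node right;

        Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    /** Convenience constructor for a node without children. */
    static Node leaf(int value) {
        return new Node(value, null, null);
    }

    /** Returns a new tree in which the children of every node are swapped. */
    static Node invert(Node t) {
        if (t == null) {
            return null;                                  // empty tree stays empty
        }
        return new Node(t.value, invert(t.right), invert(t.left));
    }

    /** In-order traversal, used here only to inspect the result. */
    static List<Integer> inorder(Node t) {
        List<Integer> out = new ArrayList<>();
        collect(t, out);
        return out;
    }

    private static void collect(Node t, List<Integer> out) {
        if (t == null) {
            return;
        }
        collect(t.left, out);
        out.add(t.value);
        collect(t.right, out);
    }

    /** Builds the example tree t1 from this post. */
    static Node example() {
        return new Node(4,
                new Node(2, leaf(1), leaf(3)),
                new Node(7, leaf(6), leaf(9)));
    }
}
```

Since t1 happens to be a search tree, an in-order walk of the inverted tree visits the values in descending order, which makes for a quick sanity check.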

If you have any question or suggestion, or a different idea to solve this problem, please feel free to post it as a comment; I highly appreciate that! Thanks for visiting!


Constructing a balanced Binary Search Tree from a sorted List in O(N) time

This post discusses an O(n) algorithm that constructs a balanced binary search tree (BST) from a sorted list. For instance, consider that we are given a sorted list: [1,2,3,4,5,6,7]. We have to construct a balanced BST as follows.

        4
      /   \
    2       6
   / \     / \
  1   3   5   7

To do so, we use the following definition of a tree, described in the book Scala By Example.

abstract class IntSet
case object Empty extends IntSet
case class NonEmpty(elem: Int, left: IntSet, right: IntSet) extends IntSet

One straightforward approach would be to repeatedly split the given list at its median (middle) element, and then to construct a balanced BST recursively from the two halves. Since every split has to walk the list, the complexity of such an approach is O(n log n), where n is the number of elements in the given list.

A better algorithm constructs the balanced BST while iterating over the list only once. It begins with the leaf nodes and constructs the tree in a bottom-up manner. As such, it avoids the repeated searches and achieves a better runtime complexity (i.e., O(n), where n is the total number of elements in the given list). The following Scala code outlines this algorithm, which effectively converts a list ls to an IntSet, a balanced BST:

def toTree(ls: List[Int]): IntSet = {
  // Builds a tree from the first n elements of ls and returns it
  // together with the elements that have not been consumed yet.
  def toTreeAux(ls: List[Int], n: Int): (List[Int], IntSet) = {
    if (n <= 0)
      (ls, Empty)
    else {
      val (lls, lt) = toTreeAux(ls, n / 2)          // construct left sub-tree
      val x :: xs = lls                             // extract root node
      val (xr, rt) = toTreeAux(xs, n - n / 2 - 1)   // construct right sub-tree
      (xr, NonEmpty(x, lt, rt))                     // construct tree
    }
  }
  val (_, tree) = toTreeAux(ls, ls.length)
  tree
}
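To make the "iterate the list only once" claim easier to see, the same bottom-up construction can be restated with an explicit cursor. The following is only a sketch in Java (the language used for the UVa submissions elsewhere on this blog), not a translation endorsed by the book: an Iterator takes the place of the remaining list that toTreeAux threads through its results, null stands in for Empty, and the names SortedListToBst, Node, build, height and inorder are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

final class SortedListToBst {
    /** Immutable tree node; a null reference stands in for Empty. */
    static final class Node {
        final int elem;
        final Node left;
        final Node right;

        Node(int elem, Node left, Node right) {
            this.elem = elem;
            this.left = left;
            this.right = right;
        }
    }

    /** Builds a balanced BST from a sorted list, reading each element exactly once. */
    static Node toTree(List<Integer> sorted) {
        return build(sorted.iterator(), sorted.size());
    }

    /** Builds a tree out of the next n elements delivered by the iterator. */
    private static Node build(Iterator<Integer> it, int n) {
        if (n <= 0) {
            return null;
        }
        Node left = build(it, n / 2);              // first n/2 elements go left
        int root = it.next();                      // the next element is the root
        Node right = build(it, n - n / 2 - 1);     // the remaining elements go right
        return new Node(root, left, right);
    }

    /** Number of nodes on the longest root-to-leaf path (0 for the empty tree). */
    static int height(Node t) {
        return t == null ? 0 : 1 + Math.max(height(t.left), height(t.right));
    }

    /** In-order traversal; for a BST this yields the elements in sorted order. */
    static List<Integer> inorder(Node t) {
        List<Integer> out = new ArrayList<>();
        collect(t, out);
        return out;
    }

    private static void collect(Node t, List<Integer> out) {
        if (t == null) {
            return;
        }
        collect(t.left, out);
        out.add(t.elem);
        collect(t.right, out);
    }
}
```

Each call to build consumes exactly one element with it.next(), so the total work is proportional to the length of the list. For [1,2,3,4,5,6,7] the root is 4 with children 2 and 6, matching the picture above.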

Any comment or query regarding this post is highly appreciated. Thanks.

UVa 10706. Number Sequence with F#

This post describes an algorithm to solve the UVa 10706: Number Sequence problem from UVa OJ. Before outlining the algorithm, we first give an overview of this problem in the next section.

Interpretation

This is a straightforward problem which, however, requires careful inspection. Our objective is to write a program that computes the value $S[i]$ of a sequence $S$, which is comprised of the number groups $S_1 S_2 \ldots S_k$. Each $S_k$ consists of the positive integers from 1 to $k$ written one after another, and thus the first 80 digits of $S$ are:

11212312341234512345612345671234567812345678912345678910123456789101112345678910

It is imperative to note that the value of $i$ has the following range: $1 \le i \le 2147483647$. Realize that the maximum value of $i$ is indeed Int32.MaxValue. In the provided test cases, the values of $i$ are stated. We simply have to compute $S[i]$; therefore, we are required to write a function that has the following signature: $f: i \to S[i]$. That is:

f: 10 --> S.[10] = 4
f: 8  --> S.[8]  = 2
f: 3  --> S.[3]  = 2

Note also that the maximum number of test cases is 25. Question for the reader: what is the maximum possible value of $k$ for the specified sequence $S$, which is constituted of $S_1 S_2 \ldots S_k$?

We hope that by this time it is clear where we are headed. In the next section, we outline an algorithm to solve this problem. But before moving further, we recommend that you try to solve it yourself first and then read the rest of the post. By the way, we believe that the solution presented next can be further optimized (or that an even better solution can be derived). We highly welcome your feedback regarding this.

Algorithm

Recall that $S$ is constituted of the number groups $S_1 S_2 \ldots S_k$. In order to identify the digit located at the $i$-th position of $S$, we first determine the number group $S_k$ that digit is associated with, and then we figure out the corresponding digit within $S_k$.

To do so, we first compute the length of each number group $S_k$. Consider the number groups up to $S_{10}$:

112123123412345123456…12345678910

As mentioned, this sequence is constituted of $S_1 S_2 \ldots S_{10}$, as shown below (with their respective lengths).

1                    --> 1 
1 2                  --> 2 
1 2 3                --> 3
1 2 3 4              --> 4
1 2 3 4 5            --> 5
1 2 3 4 5 6          --> 6
1 2 3 4 5 6 7        --> 7
1 2 3 4 5 6 7 8      --> 8
1 2 3 4 5 6 7 8 9    --> 9
1 2 3 4 5 6 7 8 9 10 --> 11

It implies that the length of $S_k$ can be computed from the length of $S_{k-1}$ as shown below.

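$$\operatorname{len}(S_k) = \operatorname{len}(S_{k-1}) + \lfloor \log_{10} k \rfloor + 1, \qquad \operatorname{len}(S_0) = 0 \tag{1}$$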

Using Eq.1, we calculate the cumulative sum of each number group as follows.

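$$C_k = C_{k-1} + \operatorname{len}(S_k) = \sum_{j=1}^{k} \operatorname{len}(S_j), \qquad C_0 = 0 \tag{2}$$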

Why do we calculate the cumulative sum? The purpose is to be able to simply run a binary search to determine, in logarithmic time, which number group the $i$-th digit belongs to. For example, consider i=40.

1                    --> 1 
1 2                  --> 3
1 2 3                --> 6 
1 2 3 4              --> 10
1 2 3 4 5            --> 15
1 2 3 4 5 6          --> 21
1 2 3 4 5 6 7        --> 28
1 2 3 4 5 6 7 8      --> 36 
1 2 3 4 5 6 7 8 9    --> 45
1 2 3 4 5 6 7 8 9 10 --> 56

Using binary search, we can find out that $S_9$ contains the 40th digit. Then, using a linear search, we can simply derive the first four digits of $S_9$ to eventually find the corresponding digit, 4.

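$$36 < 40 \le 45 \;\Rightarrow\; \text{digit 40 of } S \text{ is digit } 40 - 36 = 4 \text{ of } S_9 = 123456789, \text{ namely } 4$$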

Similarly, for i=55, we can figure out that the digit is indeed 1.
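The steps above (group lengths, cumulative sums, a binary search for the group, then a linear scan inside it) can be sketched as follows. This is an illustrative sketch in Java (the language used for the UVa submissions elsewhere on this blog), not the F# implementation of this post, and it leaves out the input and output handling a submission would need. The names NumberSequenceSketch, CUM, digits and digitAt are assumptions made for the example.

```java
import java.util.ArrayList;
import java.util.List;

final class NumberSequenceSketch {
    /** Largest position we must answer for (Int32.MaxValue). */
    private static final long MAX_I = Integer.MAX_VALUE;

    /** CUM[k] = number of digits in S_1 S_2 ... S_k, with CUM[0] = 0. */
    private static final long[] CUM = buildCumulative();

    /** Number of decimal digits of a positive integer. */
    private static int digits(int n) {
        return String.valueOf(n).length();
    }

    private static long[] buildCumulative() {
        List<Long> cum = new ArrayList<>();
        cum.add(0L);
        long len = 0;     // length of the current group S_k
        long total = 0;   // cumulative length up to and including S_k
        for (int k = 1; total < MAX_I; k++) {
            len += digits(k);   // len(S_k) = len(S_{k-1}) + digits of k
            total += len;       // C_k = C_{k-1} + len(S_k)
            cum.add(total);
        }
        long[] out = new long[cum.size()];
        for (int j = 0; j < out.length; j++) {
            out[j] = cum.get(j);
        }
        return out;
    }

    /** Returns the digit at 1-based position i of S, for 1 <= i <= Int32.MaxValue. */
    static int digitAt(long i) {
        if (i < 1 || i > MAX_I) {
            throw new IllegalArgumentException("position out of range: " + i);
        }
        // Binary search for the smallest k with CUM[k] >= i.
        int lo = 1;
        int hi = CUM.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (CUM[mid] >= i) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        long offset = i - CUM[lo - 1];   // 1-based position inside S_lo
        // Linear scan over the numbers 1, 2, ..., lo that make up S_lo.
        for (int n = 1; n <= lo; n++) {
            int d = digits(n);
            if (offset <= d) {
                return String.valueOf(n).charAt((int) offset - 1) - '0';
            }
            offset -= d;
        }
        throw new IllegalStateException("unreachable for valid positions");
    }
}
```

With this sketch, digitAt(40) lands in the ninth group with offset 4 and yields 4, while digitAt(55) lands in the tenth group with offset 10 and yields 1, in line with the two worked examples.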

Implementation

Next, we outline an F# implementation of the stated algorithm. We start with computing the cumulative sums of the lengths.

UVa 136. Ugly Numbers

This blog post is about UVa 136: Ugly Numbers, a trivial but interesting UVa problem. The crux involves computing the 1500th ugly number, where an ugly number is defined as a number whose only prime factors are 2, 3 or 5 (1 is included by convention). The following illustrates the sequence of ugly numbers:

1,2,3,4,5,6,8,9,10,12,15...

Using F#, we can derive the 1500th ugly number in F#’s REPL as follows.

seq{1L..System.Int64.MaxValue}
|> Seq.filter (isUglyNumber)
|> Seq.take 1500
|> Seq.last

In this context, the primary function that determines whether a number is an ugly number or not, isUglyNumber, is outlined as follows. As we can see, it is a naive algorithm that can be further optimized using memoization (as listed here).

// Assumed helper: it is not shown in this listing; the obvious definition is used here.
let isFactor (n:int64) (d:int64) : bool = n % d = 0L

(*
 * :a:int64 -> b:bool
 *)
let isUglyNumber (x:int64) :bool =
    // Repeatedly divides n by the factors 2, 3 and 5; x is ugly if we reach 1.
    let rec checkFactors (n_org:int64) (n:int64) lfactors =
        match n with
        | 1L -> true
        | _ ->
            match lfactors with
            | [] -> false
            | d::xs ->
                if isFactor n d then
                    checkFactors n_org (n/d) lfactors
                elif n > d then
                    checkFactors n_org n xs
                else
                    false
    checkFactors x x [2L;3L;5L]

After computing the 1500th ugly number in this manner, we submit the Java source code listed below. For the complete source code, please visit this gist. Alternatively, this script is also available at tryfsharp.org for further introspection.

class Main {
    public static void main(String[] args) {
        System.out.println("The 1500'th ugly number is 859963392.");
    }
}
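Filtering every integer up to the answer is simple but slow. A common alternative, sketched below, generates the ugly numbers directly in increasing order by multiplying earlier ugly numbers by 2, 3 and 5. Note that this is a different technique from the memoized version mentioned above, and the names UglyNumbersSketch and nthUgly are assumptions made for this example.

```java
final class UglyNumbersSketch {
    /**
     * Returns the n-th ugly number (1-based, n >= 1), building the sequence
     * in increasing order from previously generated ugly numbers.
     */
    static long nthUgly(int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be at least 1");
        }
        long[] ugly = new long[n];
        ugly[0] = 1;                 // 1 is the first ugly number by convention
        int i2 = 0;                  // index of the next ugly number to multiply by 2
        int i3 = 0;                  // ... by 3
        int i5 = 0;                  // ... by 5
        for (int k = 1; k < n; k++) {
            long by2 = ugly[i2] * 2;
            long by3 = ugly[i3] * 3;
            long by5 = ugly[i5] * 5;
            long next = Math.min(by2, Math.min(by3, by5));
            ugly[k] = next;
            // Advance every pointer that produced `next`, so duplicates
            // such as 6 = 2 * 3 = 3 * 2 are emitted only once.
            if (next == by2) i2++;
            if (next == by3) i3++;
            if (next == by5) i5++;
        }
        return ugly[n - 1];
    }
}
```

Calling nthUgly(1500) takes 1499 loop iterations and gives the same value that the submitted program prints.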

Please leave a comment in case of any question or improvement of this implementation. Thanks.

UVa 371. Ackermann Function

The underlying concepts of UVa 371: Ackermann Functions have been discussed in great detail in our post on the Collatz problem. In this post, we simply outline an ad-hoc algorithm as a solution to this problem as follows.

import java.io.PrintWriter;
import java.util.Scanner;

public class Main {
    private final Scanner in;
    private final PrintWriter out;
    private final static int _MaxValue = 1000000;
    private final static long[] memo = new long[_MaxValue];

    public Main() {
        in = new Scanner(System.in);
        out = new PrintWriter(System.out, true);
    }

    public Main(Scanner in, PrintWriter out) {
        this.in = in;
        this.out = out;
    }

    private static long[] getInts(String input) {
        String[] ints = input.trim().split(" ");
        long[] rets = new long[2];
        rets[0] = Long.parseLong(ints[0]);
        rets[1] = Long.parseLong(ints[1]);
        return rets;
    }

    private void solveAckermannProblem(long from, long to) {
        long maxValue = from;
        long maxLength = 0;
        for (long i = from; i <= to; i++) {
            long length = computeCycleLength(nextAckermannNumber(i));
            if (maxLength < length) {
                maxValue = i;
                maxLength = length;
            }
        }
        out.println(String
                .format("Between %d and %d, %d generates the longest sequence of %d values.",
                        from, to, maxValue, maxLength));
    }

    private static long computeCycleLength(long n) {
        if (n == 0)
            return 0;
        if (n == 1)
            return 1;
        if (n < _MaxValue && memo[(int) n] != 0)
            return memo[(int) n];
        // computing length of Ackermann sequence
        long len = 1 + computeCycleLength(nextAckermannNumber(n));
        if (n < _MaxValue) // storing it in cache
            memo[(int) n] = len;
        return len;
    }

    public static long nextAckermannNumber(long n) {
        if (n % 2 == 0)
            return n / 2;
        else
            return n * 3 + 1;
    }

    public void run() {
        while (in.hasNextLine()) {
            long[] range = getInts(in.nextLine());
            if ((range[0] == 0) && (range[1] == 0))
                break;
            long from = Math.min(range[0], range[1]);
            long to = Math.max(range[0], range[1]);
            solveAckermannProblem(from, to);
        }
    }

    public static void main(String[] args) {
        Main solver = new Main();
        solver.run();
    }
}

Please leave a comment if you have any question regarding this problem or implementation. Thanks.


See Also

see Collatz Problem a.k.a. 3n+1 Problem.
see UVa 100. 3n+1 Problem.