The problem of inverting a binary tree got some hype after a widely shared tweet.
As a problem-solver, I was wondering how to approach this problem, as it seems to be a great application of structural recursion. In this post, let’s see how to solve it with functional programming and Scala.
Problem Definition:
The problem can be specified as follows.
Given a Binary Tree t1:
```
     4
   /   \
  2     7
 / \   / \
1   3 6   9
```
We have to implement a function to transform t1 to a Binary Tree t2:
```
     4
   /   \
  7     2
 / \   / \
9   6 3   1
```
Thus the function invertTree essentially has the following signature: def invertTree[A](tree: Tree[A]): Tree[A], where Tree is the binary tree type defined below.
Solution
First, we define a Binary Tree ADT.
```scala
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]
```
To make tree instances convenient to write, we add the following helper methods to the companion object of Tree.
```scala
object Tree {
  def empty[A]: Tree[A] = EmptyTree
  def node[A](value: A, left: Tree[A] = empty, right: Tree[A] = empty): Tree[A] =
    Node(value, left, right)
}
```
As a result, we can define an instance of a tree in the Scala REPL as follows.
```scala
scala> import Tree._

scala> val t1 = node(4,
         node(2, node(1), node(3)),
         node(7, node(6), node(9)))
t1: Tree[Int] =
  Node(4,
    Node(2,
      Node(1,EmptyTree,EmptyTree),
      Node(3,EmptyTree,EmptyTree)),
    Node(7,
      Node(6,EmptyTree,EmptyTree),
      Node(9,EmptyTree,EmptyTree)))
```
Next, to support structural recursion, we define a fold function for the binary tree as follows:
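```scala
def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  // an empty tree is replaced by the seed value z
  case EmptyTree     => z
  // fold both subtrees first, then combine their results with the node's value
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}
```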
It allows us to traverse the tree, transform it, and accumulate a result. For instance, we can define a function that counts the number of nodes in the tree (its size) in a generic manner:
```scala
def size[T](tree: Tree[T]): Int =
  fold(tree, 0: Int) { (l, x, r) => l + r + 1 }

scala> size(t1)
res11: Int = 7
```
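The same fold can express other whole-tree computations. As a sketch, the block below defines height, inOrder and rebuild on top of the post's Tree ADT and fold. These three names, and the leaf helper, are hypothetical and do not come from the post. The ADT and fold are repeated so the snippet runs on its own, for example as a Scala script or via :paste in the REPL. rebuild folds with the tree's own constructors. It illustrates that fold replaces EmptyTree with z and each Node with f.

```scala
// Tree ADT and fold, repeated from the post so this snippet is self-contained.
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

// Hypothetical helper for building leaves concisely.
def leaf[A](x: A): Tree[A] = Node(x, EmptyTree, EmptyTree)

// Height: an empty tree has height 0; a node sits one level above its taller subtree.
def height[A](t: Tree[A]): Int =
  fold(t, 0) { (l, _, r) => 1 + math.max(l, r) }

// In-order traversal: the left subtree's values, then the node's value, then the right subtree's.
def inOrder[A](t: Tree[A]): List[A] =
  fold(t, List.empty[A]) { (l, x, r) => l ++ (x :: r) }

// Folding with the constructors themselves rebuilds an equal tree.
def rebuild[A](t: Tree[A]): Tree[A] =
  fold(t, EmptyTree: Tree[A]) { (l, x, r) => Node(x, l, r) }

val t1: Tree[Int] =
  Node(4, Node(2, leaf(1), leaf(3)), Node(7, leaf(6), leaf(9)))

println(height(t1))        // 3
println(inOrder(t1))       // List(1, 2, 3, 4, 6, 7, 9)
println(rebuild(t1) == t1) // true
```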
Also, we can define a map function that applies a function f: A => B to the value of each Node. Note that map is always structure-preserving: the resulting tree has exactly the same shape as the input (unlike the aforementioned size function, which collapses the tree into a single number), and only the values are transformed locally.
```scala
import Tree._

def map[A, B](tree: Tree[A])(f: A => B): Tree[B] =
  fold(tree, Tree.empty[B]) { (l, x, r) => Node(f(x), l, r) }

scala> map(t1)(x => x * 10)
res11: Tree[Int] =
  Node(40,
    Node(20,
      Node(10,EmptyTree,EmptyTree),
      Node(30,EmptyTree,EmptyTree)),
    Node(70,
      Node(60,EmptyTree,EmptyTree),
      Node(90,EmptyTree,EmptyTree)))
```
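The structure-preservation claim can be checked mechanically. In the sketch below, shape is a hypothetical helper (not from the post) that maps every value to (). Comparing shape(a) == shape(b) therefore compares two trees while ignoring their values. The ADT, fold and map are repeated from the post so the snippet runs on its own, with EmptyTree standing in for Tree.empty. The leaf helper is also hypothetical.

```scala
// Tree ADT and fold, repeated from the post so this snippet is self-contained.
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

// Hypothetical helper for building leaves concisely.
def leaf[A](x: A): Tree[A] = Node(x, EmptyTree, EmptyTree)

// map as in the post, using EmptyTree directly instead of Tree.empty.
def map[A, B](tree: Tree[A])(f: A => B): Tree[B] =
  fold(tree, EmptyTree: Tree[B]) { (l, x, r) => Node(f(x), l, r) }

// Hypothetical helper: erase every value so that only the tree's shape remains.
def shape[A](t: Tree[A]): Tree[Unit] = map(t)(_ => ())

val t1: Tree[Int] =
  Node(4, Node(2, leaf(1), leaf(3)), Node(7, leaf(6), leaf(9)))
val t10 = map(t1)(x => x * 10)

println(shape(t10) == shape(t1)) // true: same shape, only the values changed
println(map(t1)(x => x) == t1)   // true: mapping the identity changes nothing
```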
As you may have guessed, we can define the invertTree function in the same generic manner:
```scala
def invertTree[A](tree: Tree[A]): Tree[A] =
  fold(tree, Tree.empty[A]) { (leftT, value, rightT) => Node(value, rightT, leftT) }
```
In essence, invertTree simply swaps the left and right subtrees at every node, and the generic fold takes care of the recursion that builds the resulting tree.
```scala
scala> invertTree(t1)
res12: Tree[Int] =
  Node(4,
    Node(7,
      Node(9,EmptyTree,EmptyTree),
      Node(6,EmptyTree,EmptyTree)),
    Node(2,
      Node(3,EmptyTree,EmptyTree),
      Node(1,EmptyTree,EmptyTree)))

/*
     4
   /   \
  7     2
 / \   / \
9   6 3   1
*/
```
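Two properties make a handy sanity check for invertTree:

* Inverting twice should give back the original tree.
* The in-order traversal of the inverted tree should be the reverse of the original's.

The sketch below checks both on t1. inOrder and leaf are hypothetical helpers, not from the post. The ADT and fold are repeated so the snippet is self-contained, and EmptyTree stands in for Tree.empty.

```scala
// Tree ADT and fold, repeated from the post so this snippet is self-contained.
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

// Hypothetical helper for building leaves concisely.
def leaf[A](x: A): Tree[A] = Node(x, EmptyTree, EmptyTree)

// invertTree as in the post: swap the folded left and right subtrees at each node.
def invertTree[A](tree: Tree[A]): Tree[A] =
  fold(tree, EmptyTree: Tree[A]) { (leftT, value, rightT) => Node(value, rightT, leftT) }

// Hypothetical helper: in-order traversal (left, value, right) via fold.
def inOrder[A](t: Tree[A]): List[A] =
  fold(t, List.empty[A]) { (l, x, r) => l ++ (x :: r) }

val t1: Tree[Int] =
  Node(4, Node(2, leaf(1), leaf(3)), Node(7, leaf(6), leaf(9)))
val t2 = invertTree(t1)

println(invertTree(t2) == t1)               // true: inverting twice is a no-op
println(inOrder(t2) == inOrder(t1).reverse) // true: in-order traversal is reversed
```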
Neat, huh? By the way, this problem can be solved in several ways. This post specifically demonstrates structural recursion in a generic manner (e.g., with fold), which is the essence of #fp, imho ;).
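For comparison, one such alternative is to write the recursion by hand with pattern matching instead of going through fold. In the sketch below, invertDirect and leaf are hypothetical names. invertDirect performs the same swap at every node, and on t1 it agrees with the fold-based invertTree. The trade-off is that the fold version factors the recursion out once, so size, map and invertTree each supply only the per-node step.

```scala
// Tree ADT and fold, repeated from the post so this snippet is self-contained.
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

// Hypothetical helper for building leaves concisely.
def leaf[A](x: A): Tree[A] = Node(x, EmptyTree, EmptyTree)

// Fold-based version, as in the post (EmptyTree in place of Tree.empty).
def invertTree[A](tree: Tree[A]): Tree[A] =
  fold(tree, EmptyTree: Tree[A]) { (leftT, value, rightT) => Node(value, rightT, leftT) }

// Hypothetical alternative: explicit structural recursion, swapping children at each node.
def invertDirect[A](t: Tree[A]): Tree[A] = t match {
  case EmptyTree     => EmptyTree
  case Node(x, l, r) => Node(x, invertDirect(r), invertDirect(l))
}

val t1: Tree[Int] =
  Node(4, Node(2, leaf(1), leaf(3)), Node(7, leaf(6), leaf(9)))

println(invertDirect(t1) == invertTree(t1)) // true
```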
If you have any questions or suggestions, or a different idea for solving this problem, please feel free to post it as a comment; I highly appreciate that! Thanks for visiting!