# Inverting a Binary Tree with Scala

The problem of inverting a binary tree got some hype after a tweet about it.

As a problem-solver, I was wondering how to approach this problem, as it seems to be a great application of structural recursion. In this post, let’s see how to solve it with functional programming and Scala.

## Problem Definition

The problem can be specified as follows.

Given a Binary Tree `t1`:

         4
       /   \
      2     7
     / \   / \
    1   3 6   9

We have to implement a function to transform `t1` to a Binary Tree `t2`:

         4
       /   \
      7     2
     / \   / \
    9   6 3   1

Thus, the function `invertTree` essentially has the following signature: `invertTree: Tree => Tree`.

## Solution

First, we define a Binary Tree ADT.

    sealed trait Tree[+A]
    case object EmptyTree extends Tree[Nothing]
    case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

To encode tree instances conveniently, we add the following methods to the companion object of `Tree`.

    object Tree {
      def empty[A]: Tree[A] = EmptyTree
      def node[A](value: A, left: Tree[A] = empty, right: Tree[A] = empty): Tree[A] =
        Node(value, left, right)
    }

As a result, we can define an instance of a tree in the Scala REPL as follows.

    scala> import Tree._   // brings `node` and `empty` into scope

    scala> val t1 = node(4,
         |            node(2, node(1), node(3)),
         |            node(7, node(6), node(9)))
    t1: Tree[Int] = Node(4,
      Node(2, Node(1,EmptyTree,EmptyTree), Node(3,EmptyTree,EmptyTree)),
      Node(7, Node(6,EmptyTree,EmptyTree), Node(9,EmptyTree,EmptyTree)))

Next, to facilitate structural recursion, we define a `fold` function for the binary tree as follows:

    def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
      case EmptyTree     => z
      case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
    }

It lets us traverse the tree, perform transformations, and accumulate the result. For instance, we can define a function that counts the number of nodes in the tree in a generic manner:

    def size[T](tree: Tree[T]) = fold(tree, 0: Int) { (l, x, r) => l + r + 1 }

    scala> size(t1)
    res11: Int = 7
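To get a feel for how far the same `fold` stretches, here is a minimal sketch of two more functions written against it. The names `height`, `inOrder`, `leaf`, and the wrapper object `FoldExamples` are my own illustrative choices, not part of the code above; the ADT and `fold` are repeated only so that the snippet compiles on its own.

```scala
object FoldExamples {
  // Same ADT and fold as in the post, repeated so the snippet is self-contained.
  sealed trait Tree[+A]
  case object EmptyTree extends Tree[Nothing]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
    case EmptyTree     => z
    case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
  }

  // Hypothetical helper: number of nodes on the longest root-to-leaf path.
  def height[A](tree: Tree[A]): Int =
    fold(tree, 0) { (l, _, r) => 1 + math.max(l, r) }

  // Hypothetical helper: values in in-order (left subtree, node, right subtree).
  def inOrder[A](tree: Tree[A]): List[A] =
    fold(tree, List.empty[A]) { (l, x, r) => l ::: (x :: r) }

  // Hypothetical shorthand for a node without children.
  def leaf[A](a: A): Tree[A] = Node(a, EmptyTree, EmptyTree)

  // The example tree t1 from the problem definition.
  val t1: Tree[Int] = Node(4, Node(2, leaf(1), leaf(3)), Node(7, leaf(6), leaf(9)))
}
```

With these definitions, `FoldExamples.height(FoldExamples.t1)` evaluates to `3` and `FoldExamples.inOrder(FoldExamples.t1)` to `List(1, 2, 3, 4, 6, 7, 9)`. Only the seed value and the combining function change; the traversal itself stays inside `fold`.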

We can also define a `map` function that applies a function `f: A ⇒ B` to the `value` of each `Node`. Note that `map` is always structure-preserving: it retains the shape the tree had before the application (unlike the aforementioned `size` function, which collapses the tree into a number) and performs only local transformations.

    import Tree._

    def map[A, B](tree: Tree[A])(f: A => B): Tree[B] =
      fold(tree, Tree.empty[B]) { (l, x, r) => Node(f(x), l, r) }

    scala> map(t1)(x => x * 10)
    res11: Tree[Int] = Node(40,
      Node(20, Node(10,EmptyTree,EmptyTree), Node(30,EmptyTree,EmptyTree)),
      Node(70, Node(60,EmptyTree,EmptyTree), Node(90,EmptyTree,EmptyTree)))
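The structure-preserving claim can be checked with a small sketch. The idea is to erase the values of a tree and compare what is left. The helper `shape`, the shorthand `leaf`, and the wrapper object `MapShapeCheck` are assumptions of mine for illustration; `map` is the same definition as above, written against `EmptyTree` directly so the snippet needs no companion object.

```scala
object MapShapeCheck {
  // Same ADT and fold as in the post, repeated so the snippet is self-contained.
  sealed trait Tree[+A]
  case object EmptyTree extends Tree[Nothing]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
    case EmptyTree     => z
    case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
  }

  def map[A, B](tree: Tree[A])(f: A => B): Tree[B] =
    fold(tree, EmptyTree: Tree[B]) { (l, x, r) => Node(f(x), l, r) }

  // Hypothetical helper: replace every value with (), keeping only the shape.
  def shape[A](tree: Tree[A]): Tree[Unit] = map(tree)(_ => ())

  // Hypothetical shorthand for a node without children.
  def leaf[A](a: A): Tree[A] = Node(a, EmptyTree, EmptyTree)

  val t1: Tree[Int] = Node(4, Node(2, leaf(1), leaf(3)), Node(7, leaf(6), leaf(9)))

  // Mapping changes the values but not the shape.
  val sameShape: Boolean = shape(map(t1)(_ * 10)) == shape(t1)

  // Mapping with the identity function gives back an equal tree.
  val identityLaw: Boolean = map(t1)(x => x) == t1
}
```

Both `sameShape` and `identityLaw` come out `true` for `t1`. This is a spot check on one tree, of course, not a proof.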

As you may have guessed, we can define the `invertTree` function in a similarly generic manner:

    def invertTree[A](tree: Tree[A]): Tree[A] =
      fold(tree, Tree.empty[A]) { (leftT, value, rightT) => Node(value, rightT, leftT) }

In essence, `invertTree` simply swaps the `left` subtree with the `right` subtree at every node, and the generic `fold` takes care of assembling the resulting tree.

    scala> invertTree(t1)
    res12: Tree[Int] = Node(4,
      Node(7, Node(9,EmptyTree,EmptyTree), Node(6,EmptyTree,EmptyTree)),
      Node(2, Node(3,EmptyTree,EmptyTree), Node(1,EmptyTree,EmptyTree)))

    /*
         4
       /   \
      7     2
     / \   / \
    9   6 3   1
    */

Neat, huh? By the way, this problem can be solved in several ways. This post specifically demonstrates the application of structural recursion in a generic manner (e.g., with `fold`), which is the essence of #fp, imho ;).
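One of those other ways is plain pattern matching, with the recursion spelled out by hand instead of going through `fold`. Below is a minimal sketch of that variant; `invertDirect`, `leaf`, and the wrapper object `DirectInvert` are names I made up for this illustration.

```scala
object DirectInvert {
  // Same ADT as in the post, repeated so the snippet is self-contained.
  sealed trait Tree[+A]
  case object EmptyTree extends Tree[Nothing]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  // Explicit structural recursion: invert both subtrees, then swap them.
  def invertDirect[A](tree: Tree[A]): Tree[A] = tree match {
    case EmptyTree     => EmptyTree
    case Node(x, l, r) => Node(x, invertDirect(r), invertDirect(l))
  }

  // Hypothetical shorthand for a node without children.
  def leaf[A](a: A): Tree[A] = Node(a, EmptyTree, EmptyTree)

  val t1: Tree[Int] = Node(4, Node(2, leaf(1), leaf(3)), Node(7, leaf(6), leaf(9)))

  // Inverting twice should give back the original tree.
  val roundTrip: Boolean = invertDirect(invertDirect(t1)) == t1
}
```

For `t1`, `invertDirect` produces the same tree as the `fold`-based version, and `roundTrip` is `true`. The trade-off is that every new function written this way has to repeat the traversal, whereas `fold` captures it once.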

If you have any questions, suggestions, or a different idea for solving this problem, please feel free to post a comment; I would highly appreciate it! Thanks for visiting!