Inverting a Binary Tree with Scala

The problem of inverting a binary tree got some hype after a widely shared tweet.

As a problem-solver, I was wondering how to approach this problem, as it seems to be a great application of structural recursion. In this post, let’s see how to solve it with functional programming and Scala.

Problem Definition

The problem can be specified as follows.

Given a Binary Tree t1:

     4
   /   \
  2     7
 / \   / \
1   3 6   9

We have to implement a function to transform t1 to a Binary Tree t2:

     4
   /   \
  7     2
 / \   / \
9   6 3   1

Thus the function invertTree essentially has the following signature, for any element type A.

invertTree: Tree[A] => Tree[A]

Solution

First, we define a Binary Tree ADT.

sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

To encode tree instances conveniently, we add the following methods to the companion object of Tree.

object Tree {
  def empty[A]: Tree[A] = EmptyTree
  def node[A](value: A, left: Tree[A] = empty, right: Tree[A] = empty): Tree[A] =
    Node(value, left, right)
}

As a result, once the members of the companion object are in scope (import Tree._), we can define an instance of a tree in the Scala REPL as follows.

scala> val t1 = node(4,
         node(2
           , node(1)
           , node(3)
         ) ,
         node(7
           , node(6)
           , node(9)
         )
       )
t1: Tree[Int] =
Node(4,
  Node(2,
    Node(1,EmptyTree,EmptyTree),
    Node(3,EmptyTree,EmptyTree)),
  Node(7,
    Node(6,EmptyTree,EmptyTree),
    Node(9,EmptyTree,EmptyTree)))

Next, to facilitate structural recursion, we define a fold function for the binary tree as follows:

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}
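To get a feel for how fold is used, here is a minimal sketch that derives two more functions from it: the height of a tree and an in-order listing of its values. The names FoldExamples, height, toInOrderList and sample are illustrative assumptions and not part of the original code; the ADT and fold are repeated only so that the block stands on its own.

```scala
object FoldExamples {
  sealed trait Tree[+A]
  case object EmptyTree extends Tree[Nothing]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  // Same fold as in the post: z handles EmptyTree, f combines the
  // folded left subtree, the node value, and the folded right subtree.
  def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
    case EmptyTree     => z
    case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
  }

  // Hypothetical helper: number of nodes on the longest root-to-leaf path.
  def height[A](t: Tree[A]): Int =
    fold(t, 0)((l, _, r) => 1 + math.max(l, r))

  // Hypothetical helper: values in left, root, right order.
  def toInOrderList[A](t: Tree[A]): List[A] =
    fold(t, List.empty[A])((l, x, r) => l ::: (x :: r))

  // A small three-node tree used for illustration.
  val sample: Tree[Int] =
    Node(2, Node(1, EmptyTree, EmptyTree), Node(3, EmptyTree, EmptyTree))
}
```

In both helpers the only things that change are the value returned for the empty tree and the combining function; the traversal itself lives entirely in fold.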

It lets us traverse the tree, transform it, and accumulate a result. For instance, we can define a function that counts the nodes of the tree in a generic manner:

def size[T](tree: Tree[T]): Int =
  fold(tree, 0: Int) { (l, x, r) => l + r + 1 }

scala> size(t1)
res11: Int = 7

We can also define a map function that applies a function f: A ⇒ B to the value of each Node. Note that map is always structure-preserving: it retains the shape the tree had before the application (unlike the aforementioned size function, which collapses the tree into a number) and performs only local transformations.

import Tree._

def map[A, B](tree: Tree[A])(f: A => B): Tree[B] =
  fold(tree, Tree.empty[B]) { (l, x, r) => Node(f(x), l, r) }

scala> map(t1)(x => x * 10)
res11: Tree[Int] =
Node(40,
  Node(20,
    Node(10,EmptyTree,EmptyTree),
    Node(30,EmptyTree,EmptyTree)),
  Node(70,
    Node(60,EmptyTree,EmptyTree),
    Node(90,EmptyTree,EmptyTree)))
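The structure-preserving claim about map can be checked with a small sketch: mapping should never change the number of nodes, and mapping with the identity function should return an equal tree. The object name MapChecks and the value sample are assumptions made for this illustration; the ADT, fold, size and map follow the definitions in the post, with EmptyTree ascribed to Tree[B] in place of Tree.empty[B] to keep the block short.

```scala
object MapChecks {
  sealed trait Tree[+A]
  case object EmptyTree extends Tree[Nothing]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
    case EmptyTree     => z
    case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
  }

  // Counts the nodes of the tree.
  def size[T](tree: Tree[T]): Int =
    fold(tree, 0)((l, _, r) => l + r + 1)

  // Rebuilds every node with f applied to its value; the shape is untouched.
  def map[A, B](tree: Tree[A])(f: A => B): Tree[B] =
    fold(tree, EmptyTree: Tree[B])((l, x, r) => Node(f(x), l, r))

  // Illustrative tree, deliberately unbalanced.
  val sample: Tree[Int] =
    Node(4, Node(2, Node(1, EmptyTree, EmptyTree), EmptyTree), EmptyTree)

  // Mapping keeps the node count.
  val sameSize: Boolean = size(map(sample)(_ * 10)) == size(sample)

  // Mapping with the identity function gives back an equal tree.
  val identityLaw: Boolean = map(sample)(x => x) == sample
}
```

Because Node is a case class, == compares trees structurally, which is what makes the identity check meaningful.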

As you may have guessed, we can similarly define the invertTree function in a generic manner as follows:

def invertTree[A](tree: Tree[A]): Tree[A] =
  fold(tree, Tree.empty[A]) { (leftT, value, rightT) => Node(value, rightT, leftT) }

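In essence, invertTree simply swaps the left subtree with the right subtree at every node. Since fold has already processed both subtrees by the time the function is applied, the swap happens at every level, and the generic fold yields the fully inverted tree.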

scala> invertTree(t1)
res12: Tree[Int] =
Node(4,
  Node(7,
    Node(9,EmptyTree,EmptyTree),
    Node(6,EmptyTree,EmptyTree)),
  Node(2,
    Node(3,EmptyTree,EmptyTree),
    Node(1,EmptyTree,EmptyTree)))
/*
     4
   /   \
  7     2
 / \   / \
9   6 3   1
*/
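A quick sanity check for any inversion function is that applying it twice should give back the original tree. The sketch below tests that on the example tree; the object name InvertCheck and the helper leaf are assumptions introduced here, and EmptyTree is ascribed to Tree[A] in place of Tree.empty[A] so the block needs no companion object.

```scala
object InvertCheck {
  sealed trait Tree[+A]
  case object EmptyTree extends Tree[Nothing]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
    case EmptyTree     => z
    case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
  }

  // Fold-based inversion, as in the post: swap the two folded subtrees.
  def invertTree[A](tree: Tree[A]): Tree[A] =
    fold(tree, EmptyTree: Tree[A])((leftT, value, rightT) => Node(value, rightT, leftT))

  // Hypothetical helper: a node without children.
  def leaf[A](value: A): Tree[A] = Node(value, EmptyTree, EmptyTree)

  // The tree t1 from the problem definition and its expected inversion t2.
  val t1: Tree[Int] = Node(4, Node(2, leaf(1), leaf(3)), Node(7, leaf(6), leaf(9)))
  val t2: Tree[Int] = Node(4, Node(7, leaf(9), leaf(6)), Node(2, leaf(3), leaf(1)))

  // Inverting once gives t2; inverting twice gives t1 back.
  val invertsToT2: Boolean = invertTree(t1) == t2
  val roundTrips: Boolean  = invertTree(invertTree(t1)) == t1
}
```

Passing on one tree does not prove the property in general, but it is a cheap regression test to keep next to the implementation.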

Neat, huh? By the way, this problem can be solved in several ways. This post specifically demonstrates structural recursion applied in a generic manner (e.g., with fold), which is the essence of #fp, imho ;).
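One of those other ways is plain pattern matching with explicit recursion, without going through fold. The sketch below is only an illustration for comparison; the names DirectInvert and invertDirect are assumptions and do not appear in the original code.

```scala
object DirectInvert {
  sealed trait Tree[+A]
  case object EmptyTree extends Tree[Nothing]
  case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  // Hypothetical alternative: recurse into both subtrees and swap them.
  def invertDirect[A](tree: Tree[A]): Tree[A] = tree match {
    case EmptyTree     => EmptyTree
    case Node(v, l, r) => Node(v, invertDirect(r), invertDirect(l))
  }

  // A small unbalanced tree and its inversion.
  val sample: Tree[Int] =
    Node(4, Node(2, Node(1, EmptyTree, EmptyTree), EmptyTree), Node(7, EmptyTree, EmptyTree))
  val inverted: Tree[Int] = invertDirect(sample)
}
```

The shape of this function is exactly the shape of fold with the combining step written inline, which is why the fold-based version is so short. Like the fold-based version, it is not tail-recursive, so a very deep tree could overflow the stack.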

If you have any questions or suggestions, or a different idea for solving this problem, please feel free to post a comment; I would highly appreciate it! Thanks for visiting!

References

  1. LeetCode OJ: Invert Binary Tree