The problem of **Inverting a Binary Tree** got some hype after a widely shared tweet.

As a problem-solver, I was wondering how to approach this problem, as it seems to be a great application of structural recursion. In this post, let’s see how to solve it with functional programming and Scala.

## Problem Definition

The problem can be specified as follows.

Given a Binary Tree `t1`:

```
     4
   /   \
  2     7
 / \   / \
1   3 6   9
```

We have to implement a function to transform `t1` to a Binary Tree `t2`:

```
     4
   /   \
  7     2
 / \   / \
9   6 3   1
```

Thus the function `invertTree` essentially has the following signature: `def invertTree[A](tree: Tree[A]): Tree[A]`.

## Solution

First, we define a Binary Tree ADT.

```scala
sealed trait Tree[+A]

case object EmptyTree extends Tree[Nothing]

case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]
```

To construct tree instances conveniently, we add the following methods to the companion object of `Tree`:

```scala
object Tree {
  def empty[A]: Tree[A] = EmptyTree

  def node[A](value: A, left: Tree[A] = empty, right: Tree[A] = empty): Tree[A] =
    Node(value, left, right)
}
```

As a result, we can define an instance of a tree in the Scala REPL as follows.

```scala
scala> import Tree._

scala> val t1 = node(4,
         node(2,
           node(1),
           node(3)),
         node(7,
           node(6),
           node(9)))
t1: Tree[Int] =
  Node(4,
    Node(2,
      Node(1,EmptyTree,EmptyTree),
      Node(3,EmptyTree,EmptyTree)),
    Node(7,
      Node(6,EmptyTree,EmptyTree),
      Node(9,EmptyTree,EmptyTree)))
```

Next, to facilitate structural recursion, we define a `fold` function for the binary tree:

```scala
def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}
```

It lets us traverse the tree, transform values, and accumulate a result. For instance, we can generically define a function that counts the number of nodes in the tree:

```scala
def size[T](tree: Tree[T]) =
  fold(tree, 0: Int) { (l, x, r) => l + r + 1 }

scala> size(t1)
res11: Int = 7
```
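The same `fold` can express other structural recursions. The sketch below is not from the original post: the helper names `sum`, `depth` and `toList` are hypothetical, and the block repeats the ADT and `fold` (using `EmptyTree` directly instead of the companion's `empty`) so it runs on its own as a Scala script. Note that `toList` yields an in-order traversal: left subtree, then the node's value, then the right subtree.

```scala
// Tree ADT and fold, repeated so this block is self-contained.
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

def node[A](value: A, left: Tree[A] = EmptyTree, right: Tree[A] = EmptyTree): Tree[A] =
  Node(value, left, right)

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

// Hypothetical helper: sum of all values in an Int tree.
def sum(t: Tree[Int]): Int =
  fold(t, 0)((l, x, r) => l + x + r)

// Hypothetical helper: number of nodes on the longest root-to-leaf path.
def depth[A](t: Tree[A]): Int =
  fold(t, 0)((l, _, r) => 1 + math.max(l, r))

// Hypothetical helper: in-order traversal (left subtree, value, right subtree).
def toList[A](t: Tree[A]): List[A] =
  fold(t, List.empty[A])((l, x, r) => l ::: (x :: r))

val t1 = node(4, node(2, node(1), node(3)), node(7, node(6), node(9)))
println(sum(t1))    // 32
println(depth(t1))  // 3
println(toList(t1)) // List(1, 2, 3, 4, 6, 7, 9)
```

Only the combining function and the starting value `z` change between helpers; the traversal itself lives entirely in `fold`.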

Also, we can define a `map` function that applies a function `f: A ⇒ B` to the `value` of each `Node`. Note that `map` is always structure-preserving: unlike the aforementioned `size` function, it retains the tree's existing shape and performs only local transformations on the values.

```scala
import Tree._

def map[A, B](tree: Tree[A])(f: A => B): Tree[B] =
  fold(tree, Tree.empty[B]) { (l, x, r) => Node(f(x), l, r) }

scala> map(t1)(x => x * 10)
res11: Tree[Int] =
  Node(40,
    Node(20,
      Node(10,EmptyTree,EmptyTree),
      Node(30,EmptyTree,EmptyTree)),
    Node(70,
      Node(60,EmptyTree,EmptyTree),
      Node(90,EmptyTree,EmptyTree)))
```
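The structure-preserving claim can be checked directly on `t1`. In this sketch (not from the original post; `shape` and `toList` are hypothetical helpers), `shape` replaces every value with `()`, so two trees compare equal under `shape` exactly when they have the same structure. `map` should leave `shape` unchanged and transform the in-order sequence of values element by element. The block repeats the ADT and `fold` so it runs on its own as a Scala script.

```scala
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

def node[A](value: A, left: Tree[A] = EmptyTree, right: Tree[A] = EmptyTree): Tree[A] =
  Node(value, left, right)

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

// map as in the post (EmptyTree used in place of Tree.empty): rebuild each node with f applied to its value.
def map[A, B](tree: Tree[A])(f: A => B): Tree[B] =
  fold(tree, EmptyTree: Tree[B])((l, x, r) => Node(f(x), l, r))

// Hypothetical helper: replace every value with (), keeping only the tree's shape.
def shape[A](t: Tree[A]): Tree[Unit] = map(t)(_ => ())

// Hypothetical helper: in-order list of values.
def toList[A](t: Tree[A]): List[A] =
  fold(t, List.empty[A])((l, x, r) => l ::: (x :: r))

val t1 = node(4, node(2, node(1), node(3)), node(7, node(6), node(9)))
val t10 = map(t1)(x => x * 10)

// Same shape before and after map...
assert(shape(t10) == shape(t1))
// ...and each value transformed in place.
assert(toList(t10) == toList(t1).map(_ * 10))
```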

As you may have guessed, we can define the `invertTree` function generically in the same way:

```scala
def invertTree[A](tree: Tree[A]): Tree[A] =
  fold(tree, Tree.empty[A]) { (leftT, value, rightT) => Node(value, rightT, leftT) }
```

In essence, `invertTree` simply swaps the `left` and `right` subtrees of every node, deriving the resulting tree with the generic `fold`.
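To see how the swap reaches every level, it may help to unfold `fold` by hand on a smaller tree. This sketch is an illustration rather than code from the original post: `small`, `swap` and `empty` are hypothetical names, and `swap` is the same combining function `invertTree` passes to `fold`. Each `assert` checks one step of the evaluation. Because `fold` processes both subtrees before calling `swap` at the root, the children it swaps are already inverted.

```scala
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

def node[A](value: A, left: Tree[A] = EmptyTree, right: Tree[A] = EmptyTree): Tree[A] =
  Node(value, left, right)

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

// The combining function used by invertTree: the right result goes left, and vice versa.
val swap: (Tree[Int], Int, Tree[Int]) => Tree[Int] = (l, x, r) => Node(x, r, l)
val empty: Tree[Int] = EmptyTree
val small = node(2, node(1), node(3))

// Step 1: at the root, fold recurses into both children, then applies swap.
assert(fold(small, empty)(swap) == swap(fold(node(1), empty)(swap), 2, fold(node(3), empty)(swap)))
// Step 2: at a leaf, both recursive calls return `empty`, so the leaf is rebuilt unchanged.
assert(fold(node(1), empty)(swap) == Node(1, empty, empty))
// Step 3: the root swaps its two (already processed) children.
assert(fold(small, empty)(swap) == Node(2, Node(3, empty, empty), Node(1, empty, empty)))
```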

```scala
scala> invertTree(t1)
res12: Tree[Int] =
  Node(4,
    Node(7,
      Node(9,EmptyTree,EmptyTree),
      Node(6,EmptyTree,EmptyTree)),
    Node(2,
      Node(3,EmptyTree,EmptyTree),
      Node(1,EmptyTree,EmptyTree)))

/*
     4
   /   \
  7     2
 / \   / \
9   6 3   1
*/
```
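Two properties make a handy sanity check for `invertTree`: inverting twice gives back the original tree, and the in-order traversal of the inverted tree is the reverse of the original's. The checks below are a sketch of my own, not from the original post; `toList` is a hypothetical in-order helper, and the block repeats the ADT and `fold` (with `EmptyTree: Tree[A]` in place of `Tree.empty[A]`) so it runs on its own as a Scala script.

```scala
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

def node[A](value: A, left: Tree[A] = EmptyTree, right: Tree[A] = EmptyTree): Tree[A] =
  Node(value, left, right)

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

// invertTree as in the post.
def invertTree[A](tree: Tree[A]): Tree[A] =
  fold(tree, EmptyTree: Tree[A])((leftT, value, rightT) => Node(value, rightT, leftT))

// Hypothetical helper: in-order traversal (left subtree, value, right subtree).
def toList[A](t: Tree[A]): List[A] =
  fold(t, List.empty[A])((l, x, r) => l ::: (x :: r))

val t1 = node(4, node(2, node(1), node(3)), node(7, node(6), node(9)))
val t2 = invertTree(t1)

// Inverting twice gives back the original tree.
assert(invertTree(t2) == t1)
// Mirroring reverses the in-order sequence of values.
assert(toList(t2) == toList(t1).reverse)
```

Since `t1` happens to be a binary search tree, the reversed in-order list means the mirrored tree lists its values from largest to smallest.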

Neat, huh? By the way, this problem can be solved in several ways. This post specifically demonstrates structural recursion in a generic manner (e.g., with `fold`), which is the essence of #fp, imho ;).
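For comparison, one of those other ways is plain pattern matching without `fold`. In the sketch below, `invertDirect` is a hypothetical name (not from the original post); it spells out by hand the recursion that `fold` otherwise hides, and the final `assert` checks that both versions agree on `t1`. The block repeats the ADT and `fold` so it runs on its own as a Scala script.

```scala
sealed trait Tree[+A]
case object EmptyTree extends Tree[Nothing]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

def node[A](value: A, left: Tree[A] = EmptyTree, right: Tree[A] = EmptyTree): Tree[A] =
  Node(value, left, right)

def fold[A, B](t: Tree[A], z: B)(f: (B, A, B) => B): B = t match {
  case EmptyTree     => z
  case Node(x, l, r) => f(fold(l, z)(f), x, fold(r, z)(f))
}

// fold-based version, as in the post.
def invertTree[A](tree: Tree[A]): Tree[A] =
  fold(tree, EmptyTree: Tree[A])((leftT, value, rightT) => Node(value, rightT, leftT))

// Hypothetical direct version: the same structural recursion, written explicitly.
def invertDirect[A](tree: Tree[A]): Tree[A] = tree match {
  case EmptyTree            => EmptyTree
  case Node(v, left, right) => Node(v, invertDirect(right), invertDirect(left))
}

val t1 = node(4, node(2, node(1), node(3)), node(7, node(6), node(9)))
assert(invertDirect(t1) == invertTree(t1))
```

The direct version is arguably easier to read for this one task, while the `fold` version reuses the same traversal as `size` and `map`. Neither is tail-recursive, so on a very deep, unbalanced tree either one could overflow the JVM call stack.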

If you have any questions, suggestions, or a different idea for solving this problem, please feel free to post it as a comment; I'd really appreciate it! Thanks for visiting!