Tricks of the trade: Recursion to Iteration, Part 3: Recursive Data Structures

Tags: programming, recursion, iteration, python, recursion-to-iteration series, tail calls, data structures

This is the third article in a series on converting recursive algorithms into iterative algorithms. If any of what follows seems confusing, you may want to read the earlier articles first.

This is an extra article that I hadn’t planned. I’m writing it because in a comment on the previous article a reader asked me to show a less mathematical example and suggested tree traversal. So that’s the subject of this article: We’ll take a binary tree and flatten it into a list, first recursively, then iteratively.

The challenge

First, let’s define a binary tree to be either empty or given by a node having three parts: (1) a value, (2) a left subtree, and (3) a right subtree, where both of the subtrees are themselves binary trees. In Haskell, we might define it like so:

data BinaryTree a = Empty | Node a (BinaryTree a) (BinaryTree a)

In Python, which we’ll use for the rest of this article, we’ll say that None represents an empty tree and that the following class represents a node:

import collections
Node = collections.namedtuple('Node', 'val left right')

# some sample trees having various node counts
tree0 = None  # empty tree
tree1 = Node(5, None, None)
tree2 = Node(7, tree1, None)
tree3 = Node(7, tree1, Node(9, None, None))
tree4 = Node(2, None, tree3)
tree5 = Node(2, Node(1, None, None), tree3)
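The comment above says the sample trees have various node counts. In fact, each treeN has exactly N nodes. A small recursive helper can confirm this. The helper, size, is hypothetical and not part of the article's code, but it has the same two-case shape as the tree type:

```python
import collections

# Same representation as above: None is the empty tree.
Node = collections.namedtuple('Node', 'val left right')

tree0 = None
tree1 = Node(5, None, None)
tree2 = Node(7, tree1, None)
tree3 = Node(7, tree1, Node(9, None, None))
tree4 = Node(2, None, tree3)
tree5 = Node(2, Node(1, None, None), tree3)

def size(bst):
    """Count the nodes in a tree (hypothetical helper)."""
    # Empty case: no nodes.
    if bst is None:
        return 0
    # Node case: this node plus the nodes in both subtrees.
    return 1 + size(bst.left) + size(bst.right)

print([size(t) for t in (tree0, tree1, tree2, tree3, tree4, tree5)])
# [0, 1, 2, 3, 4, 5]
```

Note that some trees share subtrees (tree3, tree4, and tree5 all reuse tree1). Because namedtuples are immutable, that sharing is harmless.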

Let us now define a function that flattens a tree using an in-order traversal. The recursive definition is absurdly simple because the data type has only two cases to consider:

def flatten(bst):
    # empty case
    if bst is None:
        return []
    # node case
    return flatten(bst.left) + [bst.val] + flatten(bst.right)

A few tests to check that it does what we expect:

def check_flattener(f):
    assert f(tree0) == []
    assert f(tree1) == [5]
    assert f(tree2) == [5, 7]
    assert f(tree3) == [5, 7, 9]
    assert f(tree4) == [2, 5, 7, 9]
    assert f(tree5) == [1, 2, 5, 7, 9]
    print('ok')

check_flattener(flatten)  # ok

Our challenge for today is to convert flatten into an iterative version. Other than a new trick – partial evaluation – the transformation is straightforward, so I’ll move quickly.

Let’s do this!

Eliminating the first recursive call

First, let’s separate the base case from the incremental work:

def step(bst):
    return flatten(bst.left) + [bst.val] + flatten(bst.right)

def flatten(bst):
    if bst is None:
        return []
    return step(bst)

And let’s break the incremental work into smaller pieces to see what’s going on.

def step(bst):
    left = flatten(bst.left)
    left.append(bst.val)
    right = flatten(bst.right)
    left.extend(right)
    return left

def flatten(bst):
    if bst is None:
        return []
    return step(bst)
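Splitting the work into append and extend is safe because every call to flatten returns a fresh list that nobody else holds, so step may mutate it in place. The following sketch repeats the definitions so it can run on its own, and checks that two flattenings of the same tree produce equal but independent lists:

```python
import collections

Node = collections.namedtuple('Node', 'val left right')

def step(bst):
    left = flatten(bst.left)      # fresh list owned by this call
    left.append(bst.val)          # so mutating it in place is safe
    right = flatten(bst.right)
    left.extend(right)
    return left

def flatten(bst):
    if bst is None:
        return []                 # a new list on every call
    return step(bst)

tree1 = Node(5, None, None)
tree3 = Node(7, tree1, Node(9, None, None))
tree5 = Node(2, Node(1, None, None), tree3)

# Flattening twice gives equal, independent lists: no hidden sharing.
a, b = flatten(tree5), flatten(tree5)
print(a, a is b)  # [1, 2, 5, 7, 9] False
```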

Let’s try to get rid of the first recursive call by assuming that somebody has passed us its result via a secret argument left:

def step(bst, left=None):
    if left is None:
        left = flatten(bst.left)
    left.append(bst.val)
    right = flatten(bst.right)
    left.extend(right)
    return left

def flatten(bst):
    if bst is None:
        return []
    return step(bst)
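To see that the secret argument really does stand in for the first recursive call, we can pass step a deliberately bogus value for left. The marker string 'X' is only for illustration. step trusts whatever it is given and never looks at bst.left:

```python
import collections

Node = collections.namedtuple('Node', 'val left right')

def step(bst, left=None):
    if left is None:
        left = flatten(bst.left)  # computed only when nobody passed it in
    left.append(bst.val)
    right = flatten(bst.right)
    left.extend(right)
    return left

def flatten(bst):
    if bst is None:
        return []
    return step(bst)

tree3 = Node(7, Node(5, None, None), Node(9, None, None))

print(step(tree3))         # [5, 7, 9]   (left subtree flattened normally)
print(step(tree3, ['X']))  # ['X', 7, 9] (left subtree skipped entirely)
```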

And now we’ll make step return values that parallel its input arguments:

def step(bst, left=None):
    if left is None:
        left = flatten(bst.left)
    left.append(bst.val)
    right = flatten(bst.right)
    left.extend(right)
    return bst, left  # <-- add bst

def flatten(bst):
    if bst is None:
        return []
    return step(bst)[-1]  # <-- note [-1]

In the first recursive call, the transformation applied to bst is .left, so we want to apply the opposite transformation to bst in the returned values. And what’s the opposite of descending to a node’s left subtree? It’s ascending to the node’s parent. So we want something like this:

# this code does not work!

def step(bst, left=None):
    if left is None:
        left = flatten(bst.left)
    left.append(bst.val)
    right = flatten(bst.right)
    left.extend(right)
    return get_parent(bst), left  # <-- need get_parent

But we’re stuck. We can’t define get_parent because our tree data structure doesn’t keep track of parents, only children.

New plan: Maybe we can assume that someone has passed us the node’s parent and go from there?

But this plan hits the same brick wall: If we add a new argument to accept the parent, then, for parallelism, we must add a new return value to emit the transformed parent, which is the parent of the parent. But we can’t compute the parent of the parent because, as before, we have no way of implementing get_parent.

So we do what mathematicians do when their assumptions hit a brick wall: we strengthen our assumption! Now we assume that someone has passed us all of the parents, right up to the tree’s root. And that assumption gives us what we need:

def step(bst, parents, left=None):
    if left is None:
        left = flatten(bst.left)
    left.append(bst.val)
    right = flatten(bst.right)
    left.extend(right)
    return parents[-1], parents[:-1], left

Note that we’re using the Python stack convention for parents; thus the immediate parent of bst is given by the final element parents[-1].

As a simplification, we can eliminate the bst argument by considering it the final parent pushed onto the stack:

def step(parents, left=None):
    bst = parents.pop()  # <-- bst = top of parents stack
    if left is None:
        left = flatten(bst.left)
    left.append(bst.val)
    right = flatten(bst.right)
    left.extend(right)
    return parents, left

Now that step requires the parents stack as an argument, the base function must provide it:

def flatten(bst):
    if bst is None:
        return []
    parents = [bst]
    return step(parents)[-1]
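The stack-based step can still be called directly. In this self-contained sketch, calling it on the one-element stack [tree5] pops the root, flattens it, and returns the now-empty stack together with the result:

```python
import collections

Node = collections.namedtuple('Node', 'val left right')

def step(parents, left=None):
    bst = parents.pop()              # bst is the top of the parents stack
    if left is None:
        left = flatten(bst.left)
    left.append(bst.val)
    right = flatten(bst.right)
    left.extend(right)
    return parents, left

def flatten(bst):
    if bst is None:
        return []
    parents = [bst]
    return step(parents)[-1]         # [-1] selects left

tree1 = Node(5, None, None)
tree3 = Node(7, tree1, Node(9, None, None))
tree5 = Node(2, Node(1, None, None), tree3)

print(step([tree5]))  # ([], [1, 2, 5, 7, 9])
```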

But we still haven’t eliminated the first recursive call. To do that, we’ll need to pass the step function a value for its left argument, which will cause the recursive call to be skipped.

But we know what that value should be in only one case, the base case: when bst is None, left must be []. To reach that case from the tree’s root, where bst is definitely not None, we must iteratively replicate the normal recursive calls on bst.left until we step past the leftmost node into its empty left subtree. Then, to compute the desired result, we must reverse the trip, iterating the step function until we have returned to the tree’s root, at which point the parents stack will be empty:

def flatten(bst):
    # find initial conditions for secret-feature "left"
    left = []
    parents = []
    while bst is not None:
        parents.append(bst)
        bst = bst.left
    # iterate to compute the result
    while parents:
        parents, left = step(parents, left)
    return left
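Tracing the loop makes the division of labor visible. This sketch records the parents stack (as node values) and a copy of left after each call to step while flattening tree5. The trace parameter is hypothetical and exists only for this illustration:

```python
import collections

Node = collections.namedtuple('Node', 'val left right')

def step(parents, left=None):
    bst = parents.pop()
    if left is None:
        left = flatten(bst.left)
    left.append(bst.val)
    right = flatten(bst.right)     # still recursive
    left.extend(right)
    return parents, left

def flatten(bst, trace=None):
    left = []
    parents = []
    while bst is not None:         # descend to the leftmost node
        parents.append(bst)
        bst = bst.left
    while parents:                 # climb back up, one step per parent
        parents, left = step(parents, left)
        if trace is not None:      # inner calls from step pass no trace
            trace.append(([n.val for n in parents], list(left)))
    return left

tree1 = Node(5, None, None)
tree3 = Node(7, tree1, Node(9, None, None))
tree5 = Node(2, Node(1, None, None), tree3)

trace = []
flatten(tree5, trace)
for entry in trace:
    print(entry)
# ([2], [1])
# ([], [1, 2, 5, 7, 9])
```

The first step handles node 1 from the left spine. The second pops the root, and the still-recursive call on its right subtree produces [5, 7, 9] in one go.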

And just like that, one of the recursive calls has been transformed into iteration. We’re halfway to the finish line!

Eliminating the second recursive call

But we still have to eliminate that final recursive call to flatten, now sequestered in step. Let’s take a closer look at that function after we make its left argument required, since it is now always called with a value:

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    right = flatten(bst.right)
    left.extend(right)
    return parents, left

To get rid of the recursive call to flatten, we’re going to use a new trick: partial evaluation. Basically, we’re going to replace the call to flatten with the function body of flatten, after we rename all its variables to prevent conflicts. So let’s make a copy of flatten and suffix all its variables with 1:

def flatten1(bst1):
    left1 = []
    parents1 = []
    while bst1 is not None:
        parents1.append(bst1)
        bst1 = bst1.left
    while parents1:
        parents1, left1 = step(parents1, left1)
    return left1

And then let’s make its arguments and return values explicit:

    (bst1, ) = ARGUMENTS
    left1 = []
    parents1 = []
    while bst1 is not None:
        parents1.append(bst1)
        bst1 = bst1.left
    while parents1:
        parents1, left1 = step(parents1, left1)
    RETURNS = (left1, )

And then we’ll drop this expansion into step:

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    # -- begin partial evaluation --
    (bst1, ) = (bst.right, )
    left1 = []
    parents1 = []
    while bst1 is not None:
        parents1.append(bst1)
        bst1 = bst1.left
    while parents1:
        parents1, left1 = step(parents1, left1)
    (right, ) = (left1, )
    # -- end partial evaluation --
    left.extend(right)
    return parents, left
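The expanded step should behave exactly like the version that called flatten. Here is a self-contained sketch that checks this. Note that the inner loop still calls step recursively, so the code is not yet iterative:

```python
import collections

Node = collections.namedtuple('Node', 'val left right')

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    # -- begin partial evaluation of flatten(bst.right) --
    (bst1, ) = (bst.right, )
    left1 = []
    parents1 = []
    while bst1 is not None:
        parents1.append(bst1)
        bst1 = bst1.left
    while parents1:
        parents1, left1 = step(parents1, left1)
    (right, ) = (left1, )
    # -- end partial evaluation --
    left.extend(right)
    return parents, left

def flatten(bst):
    left = []
    parents = []
    while bst is not None:
        parents.append(bst)
        bst = bst.left
    while parents:
        parents, left = step(parents, left)
    return left

tree1 = Node(5, None, None)
tree3 = Node(7, tree1, Node(9, None, None))
tree4 = Node(2, None, tree3)
tree5 = Node(2, Node(1, None, None), tree3)

print(flatten(tree4), flatten(tree5))  # [2, 5, 7, 9] [1, 2, 5, 7, 9]
```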

Now we can eliminate code by fusion across the partial-evaluation boundary.

First up: left1. We can now see that this variable accumulates values that, in the end, get appended to left (via the return variable right). But we can just as well append those values to left directly, eliminating left1 inside the boundary and the call to left.extend(right) outside it:

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    # -- begin partial evaluation --
    (bst1, ) = (bst.right, )
    # left1 = []  # <-- eliminate and use left instead
    parents1 = []
    while bst1 is not None:
        parents1.append(bst1)
        bst1 = bst1.left
    while parents1:
        parents1, left = step(parents1, left)
    # (right, ) = (left, )  # <-- eliminated
    # -- end partial evaluation --
    # left.extend(right)  # <-- eliminated
    return parents, left
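After this fusion, a single list travels through every call and accumulates the whole result. The sketch below drops the scaffolding comments and checks this. It seeds left with [5], pretending that the left subtree of tree3 is already done, and confirms that step appends to that very same list object. The name acc is only illustrative:

```python
import collections

Node = collections.namedtuple('Node', 'val left right')

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    bst1 = bst.right
    parents1 = []
    while bst1 is not None:
        parents1.append(bst1)
        bst1 = bst1.left
    while parents1:
        parents1, left = step(parents1, left)  # appends straight into left
    return parents, left

tree3 = Node(7, Node(5, None, None), Node(9, None, None))

acc = [5]                        # pretend the left subtree is already done
parents, out = step([tree3], acc)
print(parents, out, out is acc)  # [] [5, 7, 9] True
```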

For this next fusion, we’re going to need to recall our base function to get the necessary outside scope:

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    # -- begin partial evaluation --
    (bst1, ) = (bst.right, )
    parents1 = []
    while bst1 is not None:
        parents1.append(bst1)
        bst1 = bst1.left
    while parents1:
        parents1, left = step(parents1, left)
    # -- end partial evaluation --
    return parents, left

def flatten(bst):
    left = []
    parents = []
    while bst is not None:
        parents.append(bst)
        bst = bst.left
    while parents:
        parents, left = step(parents, left)
    return left

When flatten calls step and the code within the partially evaluated region executes, it builds up a stack of nodes parents1 and then calls step iteratively to pop values off of that stack and process them. When it’s finished, control returns to step proper, which then returns to its caller, flatten, with the values (parents, left). But look at what flatten then does with parents: it calls step iteratively to pop values off of that stack and process them in exactly the same way.

So we can eliminate the while loop in step – and the recursive call! – by returning not parents but parents + parents1, which will make the while loop in flatten do the exact same work.

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    # -- begin partial evaluation --
    (bst1, ) = (bst.right, )
    parents1 = []
    while bst1 is not None:
        parents1.append(bst1)
        bst1 = bst1.left
    # while parents1:                            # <-- eliminated
    #     parents1, left = step(parents1, left)  #
    # -- end partial evaluation --
    return parents + parents1, left  # parents -> parents + parents1
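At this point step no longer calls itself or flatten, so the call depth stays constant no matter how deep the tree is. The following self-contained sketch checks this with a left-leaning chain that is deeper than CPython's default recursion limit of 1000. The chain, deep, is only a test input:

```python
import collections

Node = collections.namedtuple('Node', 'val left right')

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    bst1 = bst.right
    parents1 = []
    while bst1 is not None:       # left spine of the right subtree
        parents1.append(bst1)
        bst1 = bst1.left
    # No inner loop and no recursion: hand the new nodes to the caller,
    # on top of the existing stack, so flatten's loop processes them next.
    return parents + parents1, left

def flatten(bst):
    left = []
    parents = []
    while bst is not None:
        parents.append(bst)
        bst = bst.left
    while parents:
        parents, left = step(parents, left)
    return left

tree1 = Node(5, None, None)
tree3 = Node(7, tree1, Node(9, None, None))
tree5 = Node(2, Node(1, None, None), tree3)
print(flatten(tree5))  # [1, 2, 5, 7, 9]

# A left-leaning chain: each new node's left child is the previous chain.
deep = None
for v in range(3000):
    deep = Node(v, deep, None)
print(flatten(deep) == list(range(3000)))  # True
```

One cost remains: parents + parents1 builds a new list on every step. Appending directly to parents removes that copy.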

And then we can eliminate parents1 completely by taking the values we would have appended to it and appending them directly to parents:

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    # -- begin partial evaluation --
    (bst1, ) = (bst.right, )
    # parents1 = []  # <-- eliminated
    while bst1 is not None:
        parents.append(bst1)  # parents1 -> parents
        bst1 = bst1.left
    # -- end partial evaluation --
    return parents, left  # parents + parents1 -> parents

And now, once we remove our partial-evaluation scaffolding, our step function is looking simple again:

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    bst1 = bst.right
    while bst1 is not None:
        parents.append(bst1)
        bst1 = bst1.left
    return parents, left
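Because step is now a plain state transition, we can drive it by hand and watch the stack. This sketch starts from the left spine of tree5, which is what flatten's first loop would build. It then records the stack contents (bottom to top) after each step. The stacks list is only for illustration:

```python
import collections

Node = collections.namedtuple('Node', 'val left right')

def step(parents, left):
    bst = parents.pop()
    left.append(bst.val)
    bst1 = bst.right
    while bst1 is not None:
        parents.append(bst1)
        bst1 = bst1.left
    return parents, left

tree1 = Node(5, None, None)
tree3 = Node(7, tree1, Node(9, None, None))
tree5 = Node(2, Node(1, None, None), tree3)

# Initial conditions: the left spine of tree5, and nothing emitted yet.
parents, left = [tree5, tree5.left], []
stacks = []
while parents:
    parents, left = step(parents, left)
    stacks.append([n.val for n in parents])  # stack contents, bottom to top

print(stacks)  # [[2], [7, 5], [7], [9], []]
print(left)    # [1, 2, 5, 7, 9]
```

The second step is where the old recursive call used to be. Instead of recursing into the right subtree rooted at 7, step pushes that subtree's left spine (7, then 5) onto the shared stack.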

For the final leg of our journey – simplification – let’s inline the step logic back into the base function:

def flatten(bst):
    left = []
    parents = []
    while bst is not None:
        parents.append(bst)
        bst = bst.left
    while parents:
        parents, left = parents, left
        bst = parents.pop()
        left.append(bst.val)
        bst1 = bst.right
        while bst1 is not None:
            parents.append(bst1)
            bst1 = bst1.left
        parents, left = parents, left
    return left

Let’s eliminate the trivial argument-binding and return-value assignments:

def flatten(bst):
    left = []
    parents = []
    while bst is not None:
        parents.append(bst)
        bst = bst.left
    while parents:
        # parents, left = parents, left  # = no-op
        bst = parents.pop()
        left.append(bst.val)
        bst1 = bst.right
        while bst1 is not None:
            parents.append(bst1)
            bst1 = bst1.left
        # parents, left = parents, left  # = no-op
    return left

And, finally, factor out the duplicated while loop into a local function:

def flatten(bst):
    left = []
    parents = []
    def descend_left(bst):
        while bst is not None:
            parents.append(bst)
            bst = bst.left
    descend_left(bst)
    while parents:
        bst = parents.pop()
        left.append(bst.val)
        descend_left(bst.right)
    return left
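The explicit stack only ever holds nodes on the path from the root to the current node, so its size is bounded by the tree's height. This sketch adds a hypothetical stats parameter that records the peak stack size. It then compares a left-leaning chain, which pushes every node at once, with a right-leaning chain, which never holds more than one node:

```python
import collections

Node = collections.namedtuple('Node', 'val left right')

def flatten(bst, stats=None):
    left = []
    parents = []
    def descend_left(bst):
        while bst is not None:
            parents.append(bst)
            bst = bst.left
        if stats is not None:   # illustration only: record peak stack size
            stats['peak'] = max(stats.get('peak', 0), len(parents))
    descend_left(bst)
    while parents:
        bst = parents.pop()
        left.append(bst.val)
        descend_left(bst.right)
    return left

left_chain = None
for v in range(100):
    left_chain = Node(v, left_chain, None)    # 99 at the root, 0 deepest

right_chain = None
for v in reversed(range(100)):
    right_chain = Node(v, None, right_chain)  # 0 at the root, 99 deepest

left_stats, right_stats = {}, {}
print(flatten(left_chain, left_stats) == list(range(100)), left_stats)
# True {'peak': 100}
print(flatten(right_chain, right_stats) == list(range(100)), right_stats)
# True {'peak': 1}
```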

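And that’s it! We now have a tight, iterative version of our original function. It runs in time linear in the number of nodes, and its only stack is an explicit one that holds at most one node per level of the tree. Further, the code is close to idiomatic.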
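One practical payoff is that deep trees no longer exhaust Python's call stack. The sketch below builds a left-leaning chain a little deeper than the current recursion limit. On that chain, the original recursive definition (renamed flatten_recursive here) raises RecursionError, while the final iterative version succeeds:

```python
import collections
import sys

Node = collections.namedtuple('Node', 'val left right')

def flatten_recursive(bst):
    # The article's original recursive definition, renamed.
    if bst is None:
        return []
    return flatten_recursive(bst.left) + [bst.val] + flatten_recursive(bst.right)

def flatten(bst):
    # The final iterative version.
    left = []
    parents = []
    def descend_left(bst):
        while bst is not None:
            parents.append(bst)
            bst = bst.left
    descend_left(bst)
    while parents:
        bst = parents.pop()
        left.append(bst.val)
        descend_left(bst.right)
    return left

depth = sys.getrecursionlimit() + 100
deep = None
for v in range(depth):
    deep = Node(v, deep, None)   # each new node's left child is the old chain

try:
    flatten_recursive(deep)
    recursive_ok = True
except RecursionError:
    recursive_ok = False

print(recursive_ok)                         # False
print(flatten(deep) == list(range(depth)))  # True
```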

That’s it for this time. If you have any questions or comments, just hit me at @tmoertel or use the comment form below.

Thanks for reading!