Tricks of the trade: Recursion to Iteration, Part 1: The Simple Method, secret features, and accumulators

Alternative title: I wish Python had tail-call elimination.

Recursive programming is powerful because it maps so easily to proof by induction, making it easy to design algorithms and prove them correct.

But recursion is poorly supported by many popular programming languages. Drop a large input into a recursive algorithm in Python, and you’ll probably hit the runtime’s recursion limit. Raise the limit, and you may run out of stack space and segfault.
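Here is a minimal sketch of the first failure mode, assuming CPython with its default recursion limit (typically 1000). The `factorial` below is the naive recursive version that the rest of this article converts; `too_deep` and `hit_limit` are names invented for the illustration.

```python
import sys

def factorial(n):
    # Naive recursion: one stack frame per value of n.
    if n < 2:
        return 1
    return n * factorial(n - 1)

# Ask for more frames than the interpreter currently allows.
too_deep = sys.getrecursionlimit() + 100

try:
    factorial(too_deep)
    hit_limit = False
except RecursionError:
    hit_limit = True

print(hit_limit)  # True: the call chain exceeded the limit
```

Raising the limit with `sys.setrecursionlimit` only moves the wall; it does not give the interpreter more stack to work with.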

These are not happy outcomes.

Therefore, an important trick of the trade is knowing how to translate recursive algorithms into iterative algorithms. That way, you can design, prove, and initially code your algorithms in the almighty realm of recursion. Then, after you’ve got things just the way you want them, you can translate your algorithms into equivalent iterative forms through a series of mechanical steps. You can prove your cake and run it in Python, too.

This topic – turning recursion into iteration – is fascinating enough that I’m going to do a series of posts on it. Tail calls, trampolines, continuation-passing style – and more. Lots of good stuff.

For today, though, let’s just look at one simple method and one supporting trick.

The Simple Method

This translation method works on many simple recursive functions. When it works, it works well, and the results are lean and fast. I generally try it first and consider more complicated methods only when it fails.

In a nutshell:

  1. Study the function.
  2. Convert all recursive calls into tail calls. (If you can’t, stop. Try another method.)
  3. Introduce a one-shot loop around the function body.
  4. Convert tail calls into continue statements.
  5. Tidy up.

An important property of this method is that it’s incrementally correct – after every step you have a function that’s equivalent to the original. So if you have unit tests, you can run them after each and every step to make sure you didn’t make a mistake.
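If you don't have unit tests handy, a tiny comparison helper gets you most of the way. The sketch below is one way to set that up, not part of the method itself: `agrees` is a hypothetical name, and `math.factorial` stands in as the trusted reference for the factorial example worked through next.

```python
import math

def agrees(reference, candidate, inputs):
    """Return True if candidate matches reference on every input."""
    return all(candidate(x) == reference(x) for x in inputs)

def factorial(n):
    if n < 2:
        return 1
    return n * factorial(n - 1)

# Re-run this check after each transformation step, swapping in the
# newest version of the function as the candidate.
print(agrees(math.factorial, factorial, range(20)))  # True
```

A spot check over a small range proves nothing in general, but it catches the usual slip-ups (a dropped argument, a swapped update) right at the step where they happen.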

Let’s see the method in action.

Example: factorial

With a function this simple, we could probably go straight to the iterative version without using any techniques, just a little noggin power. But the point here is to develop a mechanical process that we can trust when our functions aren’t so simple or our noggins aren’t so powered. So we’re going to work on a really simple function so that we can focus on the process.

Ready? Then let’s show these guys how cyber-commandos get it done! Mark IV style!

  1. Study the original function.

    def factorial(n):
        if n < 2:
            return 1
        return n * factorial(n - 1)

    Nothing scary here. Just one recursive call. We can do this!

  2. Convert recursive calls into tail calls.

    def factorial1a(n, acc=1):
        if n < 2:
            return 1 * acc
        return factorial1a(n - 1, acc * n)

    (If this step seemed confusing, see the Bonus Explainer at the end of the article for the “secret feature” trick behind the step.)

  3. Introduce a one-shot loop around the function body. You want while True: body ; break.

    def factorial1b(n, acc=1):
        while True:
            if n < 2:
                return 1 * acc
            return factorial1b(n - 1, acc * n)
            break

    Yes, I know putting a break after a return is crazy, but do it anyway. Clean-up comes later. For now, we’re going by the numbers.

  4. Replace all recursive tail calls f(x=x1, y=y1, ...) with (x, y, ...) = (x1, y1, ...); continue. Be sure to update all arguments in the assignment.

    def factorial1c(n, acc=1):
        while True:
            if n < 2:
                return 1 * acc
            (n, acc) = (n - 1, acc * n)
            continue
            break

    For this step, I copy the original function’s argument list, parentheses and all, and paste it over the return keyword. Less chance to screw something up that way. It’s all about being mechanical.

  5. Tidy the code and make it more idiomatic.

    def factorial1d(n, acc=1):
        while n > 1:
            (n, acc) = (n - 1, acc * n)
        return acc

    Okay. This step is less about mechanics and more about style. Eliminate the cruft. Simplify. Make it sparkle.

  6. That’s it. You’re finished!
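Step 4's warning about updating all the arguments in one assignment deserves a concrete look. A tuple assignment evaluates its whole right-hand side before binding anything, which matches how the arguments of the original tail call were evaluated. The sketch below contrasts it with a deliberately broken variant (`factorial_sequential`, a name made up for this illustration) that updates `n` first and then uses the new `n` to compute `acc`.

```python
def factorial1c(n, acc=1):
    while True:
        if n < 2:
            return 1 * acc
        # Both right-hand expressions see the OLD n and acc.
        (n, acc) = (n - 1, acc * n)
        continue
        break

def factorial_sequential(n, acc=1):
    while True:
        if n < 2:
            return 1 * acc
        n = n - 1
        acc = acc * n  # BUG: this n has already been decremented
        continue
        break

print(factorial1c(5))           # 120
print(factorial_sequential(5))  # 24
```

The broken version multiplies by 4, 3, 2, 1 instead of 5, 4, 3, 2. Pasting the argument list over the `return` keyword, as described in step 4, sidesteps this whole class of mistake.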

What have we gained?

We just did five steps of work to convert our original, recursive factorial function into the equivalent, iterative factorial1d function. If our programming language had supported tail-call elimination, we could have stopped at step two with factorial1a. But nooooooo… We had to press on, all the way through step five, because we’re using Python. It’s almost enough to make a cyber-commando punch a kitten.

No, the work wasn’t hard, but it was still work. So what did it buy us?

To see what it bought us, let’s look inside the Python run-time environment. We’ll use the Online Python Tutor’s visualizer to observe the build-up of stack frames as factorial, factorial1a, and factorial1d each compute the factorial of 5.

This is very cool, so don’t miss the link: Visualize It! (ProTip: Open it in a new tab.)

Click the “Forward” button to step through the execution of the functions. Notice what happens in the Frames column. When factorial is computing the factorial of 5, five frames build up on the stack. Not a coincidence.

Same thing for our tail-recursive factorial1a. (You’re right. It is tragic.)

But not for our iterative wonder factorial1d. It uses just one stack frame, over and over, until it’s done. That’s economy!

That’s why we did the work. Economy. We converted O(n) stack use into O(1) stack use. When n could be large, that savings matters. It could be the difference between getting an answer and getting a segfault.
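If you'd rather measure than click, the frame counts can be reproduced in code. This is a rough sketch that assumes CPython, since it leans on `sys._getframe`, an implementation detail. The helpers `current_depth` and `frames_used` and the `depths` list are hypothetical names, and each function is instrumented with one extra line that records how deep the stack is.

```python
import sys

def current_depth():
    """Count the frames on the call stack, including this one."""
    depth = 0
    frame = sys._getframe()
    while frame is not None:
        depth += 1
        frame = frame.f_back
    return depth

depths = []  # stack depths recorded by the instrumented functions

def factorial(n):
    depths.append(current_depth())
    if n < 2:
        return 1
    return n * factorial(n - 1)

def factorial1a(n, acc=1):
    depths.append(current_depth())
    if n < 2:
        return 1 * acc
    return factorial1a(n - 1, acc * n)

def factorial1d(n, acc=1):
    depths.append(current_depth())
    while n > 1:
        depths.append(current_depth())  # same depth on every pass
        (n, acc) = (n - 1, acc * n)
    return acc

def frames_used(func, n):
    """Peak number of frames func stacks on top of its caller."""
    depths.clear()
    baseline = current_depth()  # measured from this frame
    func(n)
    return max(depths) - baseline

print(frames_used(factorial, 5))    # 5
print(frames_used(factorial1a, 5))  # 5
print(frames_used(factorial1d, 5))  # 1
```

The subtraction cancels out whatever frames sit below `frames_used`, so the numbers hold whether you run this as a script or from inside a test runner.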

Not-so-simple cases

Okay, so we tackled factorial. But that was an easy one. What if your function isn’t so easy? Then it’s time for more advanced methods.

That’s our topic for next time.

Until then, keep your brain recursive and your Python code iterative.

Bonus Explainer: Using secret features for tail-call conversion

In step 2 of The Simple Method, we converted the recursive call in this code:

def factorial(n):
    if n < 2:
        return 1
    return n * factorial(n - 1)  # <-- right here!

into the recursive tail call in this code:

def factorial(n, acc=1):
    if n < 2:
        return 1 * acc
    return factorial(n - 1, acc * n)  # <-- oh, yeah!

This conversion is easy once you get the hang of it, but the first few times you see it, it seems like magic. So let’s take this one step by step.

First, the problem. We want to get rid of the n * in the following code:

    return n * factorial(n - 1)

That n * stands between our recursive call to factorial and the return keyword. In other words, the code is actually equivalent to the following:

    x = factorial(n - 1)
    result = n * x
    return result

That is, our code has to call the factorial function, await its result (x), and then do something with that result (multiply it by n) before it can return its result. It’s that pesky intermediate doing something we must get rid of. We want nothing but the recursive call to factorial in the return statement.

So how do we get rid of that multiplication?

Here’s the trick. We extend our function with a multiplication feature and use it to do the multiplication for us.

Shh! It’s a secret feature, though, just for us. Nobody else.

In essence, every call to our extended function will not only compute a factorial but also (secretly) multiply that factorial by whatever extra value we give it. The variables that hold these extra values are often called “accumulators,” so I use the name acc here as a nod to tradition.

So here’s our function, freshly extended:

def factorial(n, acc=1):
    if n < 2:
        return acc * 1
    return acc * n * factorial(n - 1)

See what I did to add the secret multiplication feature? Two things.

First, I took the original function and added an additional argument acc, the multiplier. Note that it defaults to 1 so that it has no effect until we give it some other value (since \(1 \cdot x = x\)). Gotta keep the secret secret, after all.

Second, I changed every single return statement from return {whatever} to return acc * {whatever}. Whenever our function would have returned x, it now returns acc * x. And that’s it. Our secret feature is done! And it’s trivial to prove correct. (In fact, we just did prove it correct! Re-read the second sentence.)

Both changes were entirely mechanical, hard to screw up, and, together, default to doing nothing. These are the properties you want when adding secret features to your functions.

Okay. Now we have a function that computes the factorial of n and, secretly, multiplies it by acc.
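That contract fits in one line: calling the extended function with `n` and `acc` gives `acc` times the ordinary factorial of `n`. Here is a quick sketch that spot-checks it, using `math.factorial` as an independent reference (my choice for the check) and the hypothetical name `factorial_ext` to keep the extended function distinct from the original.

```python
import math

def factorial_ext(n, acc=1):
    if n < 2:
        return acc * 1
    return acc * n * factorial_ext(n - 1)

# The secret feature: the result is always acc times the plain factorial.
for n in range(8):
    for acc in (1, 2, 10):
        assert factorial_ext(n, acc) == acc * math.factorial(n)

print(factorial_ext(5))     # 120
print(factorial_ext(5, 2))  # 240
```

With the default `acc=1`, callers can't tell the difference from the original function, which is exactly the point.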

Now let’s return to that troublesome line, but in our newly extended function:

    return acc * n * factorial(n - 1)

It computes the factorial of n - 1 and then multiplies it by acc * n. But wait! We don’t need to do that multiplication ourselves. Not anymore. Now we can ask our extended factorial function to do it for us, using the secret feature.

So we can rewrite the troublesome line as

    return factorial(n - 1, acc * n)

And that’s a tail call!

So our tail-call version of the factorial function is this:

def factorial(n, acc=1):
    if n < 2:
        return acc * 1
    return factorial(n - 1, acc * n)

And now that all our recursive calls are tail calls – there was only the one – this function is easy to convert into iterative form using The Simple Method described in the main article.

Let’s review the Secret Feature trick for making recursive calls into tail calls. By the numbers:

  1. Find a recursive call that’s not a tail call.
  2. Identify what work is being done between that call and its return statement.
  3. Extend the function with a secret feature to do that work, as controlled by a new accumulator argument with a default value that causes it to do nothing.
  4. Use the secret feature to eliminate the old work.
  5. You’ve now got a tail call!
  6. Repeat until all recursive calls are tail calls.
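The same recipe works when the leftover work is something other than multiplication. As a sketch, here it is applied to a made-up function, `sum_to`, where the work is an addition, so the accumulator defaults to 0 (the value that makes addition do nothing). The `_ext` and `_tail` suffixes are just labels for the intermediate stages.

```python
def sum_to(n):
    """Sum of the integers 1..n (0 when n < 1)."""
    if n < 1:
        return 0
    return n + sum_to(n - 1)  # steps 1-2: the "n +" happens after the call

def sum_to_ext(n, acc=0):
    # Step 3: secret feature. Every return now adds acc to its result.
    if n < 1:
        return acc + 0
    return acc + n + sum_to_ext(n - 1)

def sum_to_tail(n, acc=0):
    # Steps 4-5: hand "acc + n" to the secret feature instead of adding
    # it ourselves. The recursive call is now a tail call.
    if n < 1:
        return acc + 0
    return sum_to_tail(n - 1, acc + n)

print(sum_to(10), sum_to_ext(10), sum_to_tail(10))  # 55 55 55
print(sum_to_tail(10, 5))                           # 60
```

From `sum_to_tail`, The Simple Method's loop-and-continue steps apply exactly as they did for factorial.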

With a little practice, it becomes second nature. So…

Exercise: Get your practice on!

Your mission is to get rid of the recursion in the following function. Feel like you can handle it? Then just fork the exercise repo and do your stuff to exercise1.py.

def find_val_or_next_smallest(bst, x):
    """Get the greatest value <= x in a binary search tree.

    Returns None if no such value can be found.

    """
    if bst is None:
        return None
    elif bst.val == x:
        return x
    elif bst.val > x:
        return find_val_or_next_smallest(bst.left, x)
    else:
        right_best = find_val_or_next_smallest(bst.right, x)
        if right_best is None:
            return bst.val
        return right_best
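If you want to poke at the function outside the exercise repo, a minimal stand-in tree node is enough. The `Node` class below is an assumption on my part (the repo's own class may differ); all the function needs is `val`, `left`, and `right` attributes. The sample tree and expected answers are made up for this sketch, and no solution is given away.

```python
class Node:
    """Hypothetical stand-in for the exercise repo's tree node."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def find_val_or_next_smallest(bst, x):
    """Get the greatest value <= x in a binary search tree.

    Returns None if no such value can be found.

    """
    if bst is None:
        return None
    elif bst.val == x:
        return x
    elif bst.val > x:
        return find_val_or_next_smallest(bst.left, x)
    else:
        right_best = find_val_or_next_smallest(bst.right, x)
        if right_best is None:
            return bst.val
        return right_best

#        8
#      /   \
#     3     10
#    / \      \
#   1   6      14
tree = Node(8, Node(3, Node(1), Node(6)), Node(10, None, Node(14)))

# Query -> expected answer; your iterative version should agree.
cases = {6: 6, 7: 6, 0: None, 9: 8, 100: 14}
for x, expected in cases.items():
    assert find_val_or_next_smallest(tree, x) == expected
```

Once your iterative version is written, run it through the same `cases` loop and compare.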

Have fun!