Software Engineering 19/02/2020

Advanced recursion techniques in F#

What is a recursive function?

The way information flows through a program depends on how functions are connected together. For example, function A can call function B, which in turn calls function C, and so on. However, situations sometimes arise where you need to perform one operation multiple times, and in those cases recursive functions can be beneficial. An imperative solution to this problem is to use a for loop; however, this requires mutable state. As mutable state can introduce subtle bugs, we’ll avoid anything mutable in this blog post.

Here is a naive functional implementation (all examples in this post will be in F#) of how to add one to every integer in a list:

let rec addOne (input : int list) : int list =
    match input with
    | [] -> []
    | x :: xs -> (x + 1) :: addOne xs

We can see that this adds one to the current element and then calls itself with the remainder of the list, until it is presented with an empty list. Each call then returns its result (starting with the empty list), building the new list by prepending its transformed element to the list returned by the call it made.
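To make the “return its result” step concrete, here is a hand expansion of addOne [ 1; 2; 3 ], written as comments next to the same definition (result is just an illustrative name). Every (x + 1) :: ... in the expansion is still waiting on the call inside it, which is exactly the pending work that piles up on the call stack.

```fsharp
// The naive, non-tail-recursive addOne from above.
let rec addOne (input : int list) : int list =
    match input with
    | [] -> []
    | x :: xs -> (x + 1) :: addOne xs

// Expanding addOne [ 1; 2; 3 ] by hand. Each (x + 1) :: ... cannot be
// built until the recursive call inside it has returned:
//   addOne [ 1; 2; 3 ]
//   2 :: addOne [ 2; 3 ]
//   2 :: (3 :: addOne [ 3 ])
//   2 :: (3 :: (4 :: addOne []))
//   2 :: (3 :: (4 :: []))
//   [ 2; 3; 4 ]
let result = addOne [ 1; 2; 3 ]
```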

Perfect. Except for one small problem.

Stack Overflow Exceptions

There is a situation here that can cause the dreaded StackOverflowException. This can occur because every time a function is called, a new frame is pushed onto the call stack, recording (among other things) where execution should resume in the calling function. The call stack is useful for multiple reasons, primarily so that the computer knows what to do after completing the work in the function it is currently executing. However, in this case it will cause a problem for very long lists.

Consider a list containing around 400,000 integers. Passing this to addOne will result in 400,000 nested recursive calls to addOne, each of which has to wait for the next one to return before it can finish. The call stack has a fixed size, which cannot grow dynamically. Therefore, with enough nested calls, the memory set aside for the call stack is exhausted by the frames of all the pending addOne calls. This results in a StackOverflowException. This problem can be seen by running the following code (the let bindings just make the order of operations explicit):

let rec addOne (input : int list) : int list =
    match input with
    | [] -> []
    | x :: xs ->
        let head = (x + 1)
        let tail = addOne xs
        head :: tail

// This will cause a StackOverflowException!
let ohNo = [ 0 .. 400000 ] |> addOne

Tail recursion

Because programmers are smart people, stack overflow in recursion is a solved problem. The rest of this blog post is going to explore two techniques you can use to avoid the problem in your code. There is also an exploration of how to use the type system to convert some coding errors from being correctness bugs to compile-time errors. Almost all solutions to the StackOverflowException problem exploit tail recursion: making sure that every recursive call is the very last thing the function does on its code path, so that there is no work left to do with the call’s result.

Tail recursion resolves the StackOverflowException by allowing the compiler to perform a mechanical transformation, turning the recursive call into a goto statement (a jump back to the start of the function) rather than a function call. A goto statement is not a function call, so it does not grow the stack.
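As a rough picture of that mechanical transformation, here is a hypothetical tail-recursive sumTo (not from the post) next to a hand-written loop of roughly the shape the compiler produces for it. The real compiled code differs in detail, and the mutation in the loop is only there to show that shape; it is not a style this post recommends.

```fsharp
// Tail recursive: on the recursive path, the call to sumTo is the very last
// thing that happens, so nothing needs to be remembered on the stack.
// Assumes n >= 0.
let rec sumTo (n : int) (acc : int64) : int64 =
    if n = 0 then acc
    else sumTo (n - 1) (acc + int64 n)

// Roughly what that recursion becomes: the parameters turn into loop
// variables and the recursive call turns into "jump back to the top".
let sumToLoop (n0 : int) (acc0 : int64) : int64 =
    let mutable n = n0
    let mutable acc = acc0
    while n <> 0 do
        acc <- acc + int64 n
        n <- n - 1
    acc

let viaRecursion = sumTo 100 0L   // 5050L
let viaLoop = sumToLoop 100 0L    // 5050L
```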

The addOne example above is not tail recursive because in the non-empty list case x :: xs -> ..., the recursive call let tail = addOne xs is followed by more work: prepending head to the list that the call returns (head :: tail).

Accumulation tail recursion

The first (and I think easier to understand) method to make a function tail recursive is to use an accumulator. That is, we add an extra parameter to the function, which accumulates the final result after each call to the function.

Using an accumulator in the addOne example results in this code:

let addOne (input : int list) : int list =
    let rec addOneInner (input : int list) (acc : int list) : int list = 
        match input with
        | [] -> acc
        | x :: xs -> addOneInner xs ((x + 1) :: acc)
    addOneInner input [] |> List.rev

There are a few things going on here, so let’s walk through what is happening:

Firstly, there is an inner recursive function addOneInner rather than the outer function being recursive. This means that the signature of addOne is unchanged from before: val addOne : int list -> int list. Callers of addOne don’t need to worry about any accumulators. The recursive addOneInner function is similar to addOne in the non-tail recursive case, but also takes an accumulator and does the following with it:

  • When passed an empty list, return the accumulator (which has all of the results in it).
  • When passed a non-empty list, add one to the current value and prepend it to the accumulator. Finally make a recursive call with the updated accumulator and the tail of the list.

The outer addOne function then takes the result of the recursive addOneInner function (which is the final accumulation of all the values) and reverses the list. This reversal is required because the accumulator is built up from the start of the list backwards, i.e. the first element from input is stored in acc, then the second element is stored before it, and so on. The backward list is an artefact of how immutable linked lists are constructed (adding to the front of a list is cheap, adding to the end is not), not a feature of tail recursion. If you don’t find it obvious why we need to do the reversal here, I would definitely recommend trying out this example and seeing what happens if we just return the accumulator at the end.
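Taking up that suggestion, here is a sketch of addOne with the List.rev removed (addOneNoRev is a hypothetical name used only for this comparison), with the accumulator at each step written out as comments:

```fsharp
// addOne without the final List.rev.
let addOneNoRev (input : int list) : int list =
    let rec addOneInner (input : int list) (acc : int list) : int list =
        match input with
        | [] -> acc
        | x :: xs -> addOneInner xs ((x + 1) :: acc)
    addOneInner input []

// How the accumulator evolves for [ 1; 2; 3 ]:
//   addOneInner [ 1; 2; 3 ] []
//   addOneInner [ 2; 3 ]    [ 2 ]
//   addOneInner [ 3 ]       [ 3; 2 ]
//   addOneInner []          [ 4; 3; 2 ]
let backwards = addOneNoRev [ 1; 2; 3 ]   // [ 4; 3; 2 ]
let forwards = List.rev backwards         // [ 2; 3; 4 ]
```

Each new element is placed in front of the ones already accumulated, so the last element processed ends up first.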

Brilliant. The StackOverflow issue has been resolved and the following code will succeed at runtime:

let addOne (input : int list) : int list =
    let rec addOneInner (input : int list) (acc : int list) : int list = 
        match input with
        | [] -> acc
        | x :: xs -> addOneInner xs ((x + 1) :: acc)
    addOneInner input [] |> List.rev

// No StackOverflows here!
let yay = [ 0 .. 400000 ] |> addOne

In the initial addOne example, the problem was with the call stack growing too large, as it has a limited size. Our solution has moved the problem off the stack and onto the heap (which has notionally unlimited size). On each iteration of the function, the call stack is not growing, because the compiler has optimised it to be a goto, and the heap is storing the accumulator, which can grow as large as required.
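Nothing in the accumulator version is specific to adding one. As a sketch (mapTailRec is a hypothetical helper, not something defined in the post or in FSharp.Core), the same pattern can take the transformation as a parameter, and addOne then becomes a one-liner:

```fsharp
// Hypothetical generic helper: the accumulator pattern from addOne, with the
// "+ 1" replaced by an arbitrary function f.
let mapTailRec (f : 'a -> 'b) (input : 'a list) : 'b list =
    let rec mapInner (remaining : 'a list) (acc : 'b list) : 'b list =
        match remaining with
        // Everything has been transformed; acc is backwards, so reverse it.
        | [] -> List.rev acc
        // Transform the head, prepend it to the accumulator, and recurse.
        | x :: xs -> mapInner xs (f x :: acc)
    mapInner input []

// addOne expressed with the helper.
let addOne (input : int list) : int list = mapTailRec (fun x -> x + 1) input
```

Reversing inside the [] case rather than in the caller is just a design choice; it keeps the reversal next to the code that made it necessary.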

Continuation Passing Style

So we can use the accumulator trick and never have to worry about stack overflows again. Right? Well, almost.

Unfortunately, there are situations where a plain old accumulator just won’t work. These occur when you need to call the recursive function more than once in order to get a result.

For example with trees:

type 'a Tree =
  | Leaf of 'a
  | Node of 'a Tree * 'a Tree

Here is a function, which is not tail recursive, that finds the maximum value in a tree:

let rec findMax (tree : int Tree) : int =
    match tree with
    | Leaf i -> i
    | Node (l, r) -> System.Math.Max(findMax l, findMax r)

It is not possible to make this tail recursive with a simple accumulator: that requires the recursive call to findMax to be the last thing the function does, and to be the only recursive call, but in the Node case findMax has to be called twice.

The solution to this situation is called Continuation Passing Style (CPS), and is very similar to the accumulator method; however, rather than accumulating a simple object, a function (or continuation) is accumulated.

let findMax (tree : int Tree) : int =
    // The continuation here can have a generic type, see later in this post for more details.
    let rec findMaxInner (tree : int Tree) (continuation : int -> int) : int =
        match tree with
        | Leaf i -> i |> continuation
        | Node (left, right) ->
            findMaxInner left (fun lMax ->
                findMaxInner right (fun rMax ->
                    System.Math.Max(lMax, rMax) |> continuation
                )
            )
    findMaxInner tree id

So what’s going on here?

Earlier, we saw that the call stack was keeping track of what to do when the current function finishes executing. In CPS, we avoid this requirement by storing “what to do next” in a continuation function. Each code path in a CPS function will then either:

  • Add extra instructions to the continuation, and recurse.
  • Execute the continuation with a result.

In the example above, the continuation function will return an integer (i.e. the largest number seen) and it is constructed by combining instructions comparing the maxima of the two trees in each node. Let’s take a closer look at each case in the pattern match:

  • In the Leaf case, we do not need to do any recursing, so we can execute the stored instructions (i.e. the continuation) passing in the leaf’s value.
  • In the Node case, there are two trees that need exploring (left and right).
    • To handle this, construct an outer continuation (with an inner continuation inside it), which is passed recursively back to findMaxInner.
    • The outer continuation (fun lMax -> ...) is an instruction to be executed when the maximum value of the left tree has been found, and is passed to findMaxInner with the left tree.
    • The inner continuation (fun rMax -> ...) is an instruction to be executed when the maximum value of the right tree has been found, and is passed to findMaxInner with the right tree.
    • The inner continuation only executes once the maxima of both the left and right trees have been discovered, at which point the maximum of the two can be trivially calculated.
    • Note that the result of Math.Max is then piped into continuation, as we may still have other instructions to run (which are stored here).

The important thing to see above is that all code paths (including inside the continuations) end with either:

  • A recursive call to findMaxInner with a continuation
  • An execution of continuation
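The same two rules can be applied to the list example from the start of the post. The accumulator version is simpler there, so this is only an illustration: a sketch of addOne in CPS (addOneCps and loop are hypothetical names), where the continuation holds “put x + 1 on the front of whatever comes back”:

```fsharp
// CPS version of the original addOne. Every code path ends in either a
// recursive call to loop or a call to continuation.
let addOneCps (input : int list) : int list =
    let rec loop (input : int list) (continuation : int list -> int list) : int list =
        match input with
        | [] -> [] |> continuation
        | x :: xs ->
            // Once the rest of the list has been transformed, put x + 1 on
            // the front of it, then carry on with whatever came before.
            loop xs (fun tail -> (x + 1) :: tail |> continuation)
    loop input id

let transformed = addOneCps [ 1; 2; 3 ]   // [ 2; 3; 4 ]
```

Unlike the accumulator version, no List.rev is needed: the continuations are unwound in the opposite order to the one in which they were built, which restores the original order.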

Here’s a small example of a tree we can pass to findMax:

let tree : int Tree = Node (Leaf 2, Leaf 1)

When passed into findMax, the following code is executed. You can put any of these lines into a REPL and they’ll all evaluate to 2, provided findMaxInner is defined at the top level rather than inside findMax:

findMaxInner (Node (Leaf 2, Leaf 1)) id
findMaxInner (Leaf 2) (fun lMax -> findMaxInner (Leaf 1) (fun rMax -> System.Math.Max(lMax, rMax) |> id))
// 2 is now passed into the continuation
findMaxInner (Leaf 1) (fun rMax -> System.Math.Max(2, rMax) |> id)
// 1 is now passed into the continuation
System.Math.Max(2, 1) |> id
2 |> id
2

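At every stage, all we’re doing is calling findMaxInner (or a continuation) with updated arguments, and each of those calls is the last thing the calling code does. It’s this feature that allows the compiler to avoid putting function calls on the call stack: direct recursion can become a simple goto statement, and calls to continuations can be emitted as .NET tail calls. (The F# compiler emits tail calls by default in Release builds, but Debug builds turn them off by default, so very deep inputs may still overflow there.)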
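Putting the pieces together, here is a self-contained sketch that runs the CPS findMax (copied from above) on the small tree and on a deliberately lopsided one. leftLeaning is a hypothetical helper that uses List.fold, which does not grow the stack, to build a left-leaning tree of a given depth. The depth of 1000 keeps the example safe anywhere; whether much deeper trees succeed depends on the compiler emitting tail calls (on by default in Release builds, off by default in Debug builds).

```fsharp
type 'a Tree =
  | Leaf of 'a
  | Node of 'a Tree * 'a Tree

// The CPS findMax from above.
let findMax (tree : int Tree) : int =
    let rec findMaxInner (tree : int Tree) (continuation : int -> int) : int =
        match tree with
        | Leaf i -> i |> continuation
        | Node (left, right) ->
            findMaxInner left (fun lMax ->
                findMaxInner right (fun rMax ->
                    System.Math.Max(lMax, rMax) |> continuation
                )
            )
    findMaxInner tree id

// Hypothetical helper building Node (... Node (Node (Leaf 0, Leaf 1), Leaf 2) ..., Leaf depth).
let leftLeaning (depth : int) : int Tree =
    List.fold (fun acc i -> Node (acc, Leaf i)) (Leaf 0) [ 1 .. depth ]

let small = findMax (Node (Leaf 2, Leaf 1))   // 2
let deep = findMax (leftLeaning 1000)         // 1000
```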

Using types to prevent coding errors

There is an easy error to make when writing CPS functions. (It is so easy to do that I actually made this mistake when initially writing this blog post!)

Here’s the example above with the error introduced; see if you can spot it (this function will type check and run without throwing, but won’t always return the maximum value in a Tree):

// Buggy code!
let findMax (tree : int Tree) : int =
    let rec findMaxInner (tree : int Tree) (continuation : int -> int) : int =
        match tree with
        | Leaf i -> i |> continuation
        | Node (l, r) ->
            findMaxInner l (fun lMax ->
                findMaxInner r (fun rMax ->
                    System.Math.Max(lMax, rMax)
                )
            )
    findMaxInner tree id

As mentioned in the previous section, all code paths (including inside the continuations) end with either:

  • A recursive call to findMaxInner with a continuation
  • An execution of continuation

In this example, the innermost continuation never calls the original continuation function; it just returns the result of System.Math.Max. The problem here is that if continuation contained instructions to execute after this function has completed, which it often will, we won’t have executed them!
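One reason this bug is easy to miss is that the small tree used earlier, Node (Leaf 2, Leaf 1), still gives the right answer. Here is a sketch comparing the buggy version with the correct one on a tree whose maximum sits in a right subtree that is only reached through the dropped continuation (findMaxBuggy and findMaxCorrect are hypothetical names):

```fsharp
type 'a Tree =
  | Leaf of 'a
  | Node of 'a Tree * 'a Tree

// Buggy: the innermost continuation returns Math.Max directly and never
// calls continuation, so any work stored in continuation is dropped.
let findMaxBuggy (tree : int Tree) : int =
    let rec findMaxInner (tree : int Tree) (continuation : int -> int) : int =
        match tree with
        | Leaf i -> i |> continuation
        | Node (l, r) ->
            findMaxInner l (fun lMax ->
                findMaxInner r (fun rMax ->
                    System.Math.Max(lMax, rMax)))
    findMaxInner tree id

// Correct: the result of Math.Max is piped into continuation.
let findMaxCorrect (tree : int Tree) : int =
    let rec findMaxInner (tree : int Tree) (continuation : int -> int) : int =
        match tree with
        | Leaf i -> i |> continuation
        | Node (l, r) ->
            findMaxInner l (fun lMax ->
                findMaxInner r (fun rMax ->
                    System.Math.Max(lMax, rMax) |> continuation))
    findMaxInner tree id

// Leaf 9 is only visited from the continuation that findMaxBuggy drops
// while it is still working on the left subtree.
let tree = Node (Node (Leaf 1, Leaf 2), Leaf 9)
let buggy = findMaxBuggy tree      // 2
let correct = findMaxCorrect tree  // 9
```

The buggy version returns the maximum of the left subtree only: the instruction “now look at the right subtree” lived in the continuation it never called.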

This is a coding error that can be easily avoided by using types to ensure the continuation function is being called. The trick is to change the type of the continuation function and the return type of findMaxInner.

We currently have
val findMaxInner : int Tree -> (int -> int) -> int

But we can ensure that the continuation is never ignored by introducing a new generic type:
val findMaxInner : int Tree -> (int -> 'ret) -> 'ret

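With this definition, the only way to create a return value (of type 'ret) is by calling continuation (or by making a recursive call, which itself returns a 'ret). So we have to call it, or the compiler will complain! The concrete type of 'ret is fixed in the call to findMaxInner. Therefore in the implementation above, using id as the initial continuation fixes the type of 'ret to be int.

Ensuring that continuation is always called moves a correctness bug (a subtle bug that does not throw exceptions, but does return incorrect values) into a compile-time problem that is reported before the program runs. One caveat: when 'ret is introduced only through a type annotation like this, F# reports it being forced to a concrete type as warning FS0064 (“less generic than indicated by the type annotations”) rather than as an error, so treat that warning as an error if you want the build to fail.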

Also, keep in mind that the continuation is just a set of instructions to be executed. So in the (int -> int) definition, having a concrete return type of int forces the continuation to be: “Give me an int and I’ll do something that generates an int”, whereas in the (int -> 'ret) definition, having a generic return type of 'ret allows the continuation to be: “Give me an int and I’ll run the next calculation”, which is much less restrictive.
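To see the less restrictive (int -> 'ret) version in action, here is a sketch with the recursive function lifted to the top level under a hypothetical name, findMaxThen, so that it can be called with continuations of different result types. Passing id gives the plain maximum, while a formatting continuation makes 'ret a string:

```fsharp
type 'a Tree =
  | Leaf of 'a
  | Node of 'a Tree * 'a Tree

// CPS findMax with a generic result type 'ret. Every path ends by calling
// continuation or by recursing, and the recursive calls also return a 'ret.
let rec findMaxThen (tree : int Tree) (continuation : int -> 'ret) : 'ret =
    match tree with
    | Leaf i -> i |> continuation
    | Node (left, right) ->
        findMaxThen left (fun lMax ->
            findMaxThen right (fun rMax ->
                System.Math.Max(lMax, rMax) |> continuation))

let tree = Node (Node (Leaf 1, Leaf 2), Leaf 9)

// Passing id fixes 'ret to int for this call...
let asInt : int = findMaxThen tree id
// ...while this continuation "runs the next calculation" and produces a string.
let asText : string = findMaxThen tree (fun m -> sprintf "max = %d" m)
```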

Continuation.sequence

There is one tool still missing from our toolkit. In the Tree example above, there are two calls to findMaxInner. What if the data structure is a rose tree (i.e. a tree that can contain any number of branches at each node)?

type 'a RoseTree =
  | Leaf of 'a
  | Node of 'a * 'a RoseTree list

Now, in the Node case, one recursive call is needed for every inner rose tree in the list, but the length of the list isn’t known at compile time. One (bad) solution is to do a pattern match for different sized lists and only support lists below a certain length:

let findMax (roseTree : int RoseTree) : int =
    let rec findMaxInner (roseTree : int RoseTree) (continuation : int -> 'ret) : 'ret =
        match roseTree with
        | Leaf i
        | Node (i, [])  -> i |> continuation
        | Node (i, [ x ]) ->
            findMaxInner x (fun xMax ->
                System.Math.Max(i, xMax) |> continuation
            )
        | Node (i, [ x; y ]) ->
            findMaxInner x (fun xMax ->
                findMaxInner y (fun yMax ->
                    System.Math.Max(i, System.Math.Max(xMax, yMax)) |> continuation
                )
            )
        | _ -> failwith "Nodes with lists longer than 2 are not supported"
    findMaxInner roseTree id

This function is tail recursive, but obviously isn’t great as it will throw when any Node in the structure has a list with more than two elements.

There must be a better way!

With Continuation.sequence, a small helper that we will write ourselves (it is not part of FSharp.Core), it is possible to turn a (('a -> 'ret) -> 'ret) list, along with a final continuation 'a list -> 'ret, into a 'ret.

Or more concretely in our case:
val sequence : ((int -> 'ret) -> 'ret) list -> (int list -> 'ret) -> 'ret

If we have this function, we can re-write findMax as follows:

let findMax (roseTree : int RoseTree) : int =
    // Continuation.sequence (defined below) must come earlier in the file than this function.
    let rec findMaxInner (roseTree : int RoseTree) (continuation : int -> 'ret) : 'ret =
        match roseTree with
        | Leaf i ->
            i |> continuation
        | Node (i : int, xs : int RoseTree list) ->
            // One partially applied recursive call per child tree.
            let continuations : ((int -> 'ret) -> 'ret) list = xs |> List.map findMaxInner
            // What to do once the maxima of all the child trees are known.
            let finalContinuation (maxValues : int list) : 'ret = List.max (i :: maxValues) |> continuation
            Continuation.sequence continuations finalContinuation
    findMaxInner roseTree id

So Continuation.sequence needs to do the following:

  1. Take a list of partially applied calls to the outer recursive function findMaxInner (i.e. a list of (int -> 'ret) -> 'ret functions)
  2. Take a final continuation (set of instructions) to perform on an int list (i.e. the function we should use when we have values rather than recursions)
  3. Thread all the recursions together to end up with a continuation with all values in scope, so that we can call the final continuation.

The implementation of Continuation.sequence is as follows:

[<RequireQualifiedAccess>]
module Continuation =
    let rec sequence<'a, 'ret> (recursions : (('a -> 'ret) -> 'ret) list) (finalContinuation : 'a list -> 'ret) : 'ret =
        match recursions with
        | [] -> [] |> finalContinuation
        | recurse :: recurses ->
            recurse (fun ret ->
                sequence recurses (fun rets ->
                    ret :: rets |> finalContinuation
                )
            )

Like the original CPS function, this can look quite daunting, so let’s step through what’s going on. In the empty list case, there are no 'a values to access, so call finalContinuation with an empty list. In the case where there are one or more recursion functions, the logic is as follows:

  • Pass a continuation (fun ret -> ...) to the first recurse function (in our example we’re passing the continuation to findMaxInner so that it can tell us the max value for one tree at this node).
  • Inside this continuation, make a recursive call to sequence with the rest of the findMaxInner recursions and an inner continuation (fun rets -> ...) (this will find the max values for all the other trees at this node).
  • The inner continuation (fun rets -> ...) now has access to the maxValues for all the trees at this node ret :: rets, so we can pass them to finalContinuation (which in the findMax example will pass the values to List.max)
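Before plugging it into findMax, it may help to run Continuation.sequence on its own. In this sketch the “recursions” are hypothetical stand-ins that immediately pass a fixed value to their continuation. This shows that the final continuation receives every value, in the original order, and that 'ret (here int list, then string) need not match 'a:

```fsharp
[<RequireQualifiedAccess>]
module Continuation =
    let rec sequence<'a, 'ret> (recursions : (('a -> 'ret) -> 'ret) list) (finalContinuation : 'a list -> 'ret) : 'ret =
        match recursions with
        | [] -> [] |> finalContinuation
        | recurse :: recurses ->
            recurse (fun ret ->
                sequence recurses (fun rets ->
                    ret :: rets |> finalContinuation
                )
            )

// Stand-in "recursions" that each just hand a fixed value to their continuation.
let recursions : ((int -> int list) -> int list) list =
    [ (fun k -> k 10); (fun k -> k 20); (fun k -> k 30) ]

// The final continuation sees all the values, in their original order.
let collected : int list = Continuation.sequence recursions id   // [ 10; 20; 30 ]

// Here 'a is int but 'ret is string.
let summary : string =
    Continuation.sequence [ (fun k -> k 1); (fun k -> k 2) ] (fun xs -> sprintf "sum = %d" (List.sum xs))
```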

As with the original CPS example, let’s look at a simple example:

let roseTree : int RoseTree = Node (1, [ Leaf 2; Leaf 3 ])

When passed into findMax, the following code is executed. You can put any of these lines into a REPL and they’ll all evaluate to 3, provided findMaxInner is defined at the top level rather than inside findMax:

findMax roseTree
findMaxInner (Node (1,[Leaf 2; Leaf 3])) id
Continuation.sequence [findMaxInner (Leaf 2); findMaxInner (Leaf 3)] (fun maxOfInnerTrees -> List.max (1 :: maxOfInnerTrees) |> id)
findMaxInner (Leaf 2) (fun ret -> Continuation.sequence [findMaxInner (Leaf 3)] (fun rets -> ret :: rets |> (fun maxOfInnerTrees -> List.max (1 :: maxOfInnerTrees) |> id)))
Continuation.sequence [findMaxInner (Leaf 3)] (fun rets -> 2 :: rets |> (fun maxOfInnerTrees -> List.max (1 :: maxOfInnerTrees) |> id))
findMaxInner (Leaf 3) (fun ret -> Continuation.sequence [] (fun rets -> ret :: rets |> (fun rets -> 2 :: rets |> (fun maxOfInnerTrees -> List.max (1 :: maxOfInnerTrees) |> id))))
Continuation.sequence [] (fun rets' -> 3 :: rets' |> (fun rets -> 2 :: rets |> (fun maxOfInnerTrees -> List.max (1 :: maxOfInnerTrees) |> id)))
3 :: [] |> (fun rets -> 2 :: rets |> (fun maxOfInnerTrees -> List.max (1 :: maxOfInnerTrees) |> id))
2 :: 3 :: [] |> (fun maxOfInnerTrees -> List.max (1 :: maxOfInnerTrees) |> id)
List.max (1 :: 2 :: 3 :: []) |> id
3 |> id
3

In each execution, all that occurs is a function invocation with a continuation that represents what to do when the function has completed. There are two functions we call alternately here, findMaxInner and Continuation.sequence, but because every call is the last thing its caller does, the compiler can still emit them as tail calls (a goto for direct recursion, a .NET tail call otherwise) instead of growing the call stack, trading call stack space for heap space (which is where the continuations are stored).
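Because F# requires a definition to appear before it is used, Continuation must come before findMax in a real file. Here is a self-contained sketch with everything in that order (the same code as above, lightly condensed), including a node with more children than the pattern-matching version supported:

```fsharp
type 'a RoseTree =
  | Leaf of 'a
  | Node of 'a * 'a RoseTree list

// Defined first so that findMax below can use it.
[<RequireQualifiedAccess>]
module Continuation =
    let rec sequence<'a, 'ret> (recursions : (('a -> 'ret) -> 'ret) list) (finalContinuation : 'a list -> 'ret) : 'ret =
        match recursions with
        | [] -> [] |> finalContinuation
        | recurse :: recurses ->
            recurse (fun ret ->
                sequence recurses (fun rets -> ret :: rets |> finalContinuation))

let findMax (roseTree : int RoseTree) : int =
    let rec findMaxInner (roseTree : int RoseTree) (continuation : int -> 'ret) : 'ret =
        match roseTree with
        | Leaf i -> i |> continuation
        | Node (i, xs) ->
            // One partially applied recursive call per child tree.
            let continuations = xs |> List.map findMaxInner
            Continuation.sequence continuations (fun maxValues -> List.max (i :: maxValues) |> continuation)
    findMaxInner roseTree id

let example = findMax (Node (1, [ Leaf 2; Leaf 3 ]))   // 3
// Five children: more than the two-child limit of the pattern-matching version.
let wide = findMax (Node (0, [ for i in 1 .. 5 -> Node (i, [ Leaf (10 * i) ]) ]))   // 50
```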

Summary

As a recap, here is everything that was covered in this blog post.

  1. We learnt what recursion is.
  2. We learnt why tail recursion is important.
  3. We learnt how to use an accumulator to prevent stack overflows.
  4. We learnt how to do Continuation Passing Style recursion when accumulators aren’t sufficient.
  5. We learnt how to use Continuation.sequence to handle lists of continuation functions.

With any luck, you will now know enough to avoid some stack overflows due to recursion in your code. If you have enjoyed this post and would like to work with these sorts of concepts on a day-to-day basis, why not consider applying for a job with us?

Chris Arnott – Developer

Stay up to-date with G-Research
