Tail-recursive functions
Whenever you get an example, it's usually something that's trivial, because the function was tail-recursive, so you don't even need a stack:

```python
def fact(n, value=1):
    if n < 2:
        return value
    return fact(n-1, value*n)
```

That maps directly to:
```python
def fact(n):
    value = 1
    while True:
        if n < 2:
            return value
        n, value = n-1, value*n
```

You can merge the while and the if, at which point you realize you have a for loop in disguise. And then you can realize that, multiplication being commutative and associative and all that, you might as well turn the loop around, so you get:
```python
def fact(n):
    value = 1
    for i in range(2, n+1):
        value *= i
    return value
```

(You might also notice that this is just functools.reduce(operator.mul, range(2, n+1), 1), but if you're the kind of person who notices that and finds it more readable, you probably also rewrote the tail-recursive version into a recursive fold/reduce function, and all you had to do was find an iterative reduce function to replace it with.)

Continuation stacks
Your real program isn't tail-recursive. Either you didn't bother making that transformation because your language doesn't do tail call elimination (Python doesn't), or the whole reason you're switching from recursive to iterative in the first place is that you couldn't figure out a clean way to write your code tail-recursively.

So, now you need a stack. But what goes on the stack?
The most general answer to that is that you want continuations on the stack: what the rest of the function does with the result of each recursive call. That may sound scary, and in general it is… but in most practical cases, it's not.
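One way to make "what the rest of the function does" concrete, before moving to an explicit stack, is continuation-passing style. This is a swapped-in illustration, not the approach this post builds: the hypothetical fact_cps takes its continuation as an explicit function argument k. It still recurses in Python, so it doesn't get rid of the stack; it just turns each continuation into a value you can see.

```python
def fact_cps(n, k=lambda result: result):
    # k is the continuation: what the rest of the computation does with
    # the value this call produces. The default just returns it.
    if n < 2:
        return k(1)
    # The continuation handed to the recursive call is
    # "multiply the result by n, then carry on with k".
    return fact_cps(n - 1, lambda r: k(n * r))


print(fact_cps(5))  # 120
```

Each new continuation wraps the previous one, so the chain of closures plays the role of the stack; the versions that follow store those continuations (or just their data) in a list instead.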
Let's say you have this:
```python
def fact(n):
    if n < 2:
        return 1
    return n * fact(n-1)
```

What's the continuation? It's "return n * _", where that _ is the return value of the recursive call. You can write a function of one argument that does that. (What about the base case? Well, a function of one argument can always ignore its argument.) So, instead of storing continuations, you can just store functions:
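```python
def fact(n):
    stack = []
    while True:
        if n < 2:
            # base case: ignores its argument and returns 1
            stack.append(lambda _: 1)
            break
        # continuation of the recursive call: "return _ * n"
        stack.append(lambda _, n=n: _ * n)
        n -= 1  # move on to the next, smaller, recursive call
    value = None
    # unwind: apply the continuations innermost-first
    for frame in reversed(stack):
        value = frame(value)
    return value
```

(Notice the n=n in the second lambda. See the Python FAQ for an explanation, but basically it's to make sure we're building a function that uses the current value of n, instead of one that closes over the variable n.)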
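To see why the n=n matters, here is a small sketch that builds the same stack of lambdas with and without the default argument. The name fact_frames and its capture flag are hypothetical, added only for the comparison. Without the default, every lambda looks up n when it is finally called, and by then the loop has already counted n down to 1.

```python
def fact_frames(n, capture=True):
    stack = []
    while True:
        if n < 2:
            stack.append(lambda _: 1)
            break
        if capture:
            # default argument: the current value of n is stored now
            stack.append(lambda _, n=n: _ * n)
        else:
            # closure: n is looked up later, when the lambda runs
            stack.append(lambda _: _ * n)
        n -= 1
    value = None
    for frame in reversed(stack):
        value = frame(value)
    return value


print(fact_frames(4))                 # 24
print(fact_frames(4, capture=False))  # 1: every lambda multiplies by the final n, which is 1
```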
This is undeniably kind of ugly, but we can start simplifying it. If only the base case and the recursive call had the same form, we could factor out the whole function, right? Well, if we start with 1 instead of None, the base case can return _ * 1. And then, yes, we can factor out the whole function, and just store each n value on the stack:
```python
def fact(n):
    stack = []
    while True:
        if n < 2:
            stack.append(1)
            break
        stack.append(n)
        n -= 1
    value = 1
    for frame in reversed(stack):
        value = value * frame
    return value
```

But once we're doing this, why even store the 1? And once you take that out, the while loop is obviously a for loop over a range in disguise:
```python
def fact(n):
    stack = []
    for i in range(n, 1, -1):
        stack.append(i)
    value = 1
    for frame in reversed(stack):
        value *= frame
    return value
```

Now stack is obviously just list(range(n, 1, -1)), so we can skip the loop entirely:
```python
def fact(n):
    stack = list(range(n, 1, -1))
    value = 1
    for frame in reversed(stack):
        value *= frame
    return value
```

Now, we don't really care that it's a list, as long as it's something we can pass to reversed. In fact, why even call reversed on a backward range when we can just write a forward range directly?
```python
def fact(n):
    value = 1
    for frame in range(2, n+1):
        value *= frame
    return value
```

Not surprisingly, we ended up with the same function we got from the tail-recursive starting point.
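A quick sanity check of that claim, as a sketch: fact_tail, fact_loop and fact_reduce are hypothetical labels for the tail-recursive starting point, the loop we just derived, and the functools.reduce spelling mentioned earlier. They agree wherever the recursive version can run, but because CPython doesn't eliminate tail calls, only the iterative forms get past the recursion limit.

```python
import functools
import math
import operator
import sys


def fact_tail(n, value=1):
    # Tail-recursive starting point: still one Python frame per step.
    if n < 2:
        return value
    return fact_tail(n - 1, value * n)


def fact_loop(n):
    # Endpoint of the continuation-stack derivation.
    value = 1
    for frame in range(2, n + 1):
        value *= frame
    return value


def fact_reduce(n):
    # The same loop written as a fold.
    return functools.reduce(operator.mul, range(2, n + 1), 1)


# All three agree wherever the recursive one can run...
for k in range(50):
    assert fact_tail(k) == fact_loop(k) == fact_reduce(k) == math.factorial(k)

# ...but only the iterative forms survive a depth past the recursion limit.
deep = sys.getrecursionlimit() + 100
assert fact_loop(deep) == fact_reduce(deep) == math.factorial(deep)
try:
    fact_tail(deep)
except RecursionError:
    print("fact_tail: RecursionError at n =", deep)
```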
Interpreter stacks
Is there a way to do this in general without stacking up continuations? Of course there is. After all, an interpreter doesn't have to call itself recursively just to execute your recursive call (even if CPython does, Stackless doesn't…), and your CPU certainly isn't calling itself recursively to execute compiled recursive code.

Here's what a function call does: the caller pushes the "program counter" and the arguments onto the stack, then jumps to the callee. The callee pops them, computes the result, pushes the result, and jumps to the popped counter. The only issue is that the callee can have locals that shadow the caller's; you can handle that by also pushing all of your locals (not the post-transformation locals, which include the stack itself, just the set used by the recursive function).
This sounds like it might be hard to write without a goto, but you can always simulate goto with a loop around a state machine. So:
```python
import enum

State = enum.Enum('State', 'start cont done')

def fact(n):
    state = State.start
    stack = [(State.done, n)]
    while True:
        if state == State.start:
            pc, n = stack.pop()
            if n < 2:
                # return 1
                stack.append(1)
                state = pc
                continue
            # stash locals
            stack.append((pc, n))
            # call recursively
            stack.append((State.cont, n-1))
            state = State.start
            continue
        elif state == State.cont:
            # get return value
            retval = stack.pop()
            # restore locals
            pc, n = stack.pop()
            # return n * fact(n-1)
            stack.append(n * retval)
            state = pc
            continue
        elif state == State.done:
            retval = stack.pop()
            return retval
```

Beautiful, right? Well, we can find ways to simplify this. Let's start by using one of the tricks native-code compilers use: in addition to the stack, you've also got registers. As long as you've got enough registers, you can pass arguments in registers instead of on the stack, and you can return values in registers too. And we can just use local variables for the registers. So:
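```python
def fact(n):
    state = State.start
    pc = State.done  # "return address" register
    stack = []
    while True:
        if state == State.start:
            if n < 2:
                # return 1
                retval = 1
                state = pc
                continue
            # stash our return address and n, then "call" fact(n-1)
            stack.append((pc, n))
            pc, n, state = State.cont, n-1, State.start
        elif state == State.cont:
            # restore the caller's n; its saved pc is where we jump next
            state, n = stack.pop()
            # return n * fact(n-1)
            retval = n * retval
        elif state == State.done:
            return retval
```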
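The same recipe also covers functions that make more than one recursive call. Here is a sketch applying the register version to a recursive Fibonacci, which is not from the original and is only an illustration; FibState and its member names are hypothetical. Each point a call can return to gets its own state, and whatever that point still needs (the saved return address plus either n or the first result) is pushed before the call and popped after it.

```python
import enum

# Hypothetical states for this sketch: one per place a call can "return" to.
FibState = enum.Enum('FibState', 'start after_first after_second done')


def fib(n):
    # Iterative version of:
    #     def fib(n):
    #         if n < 2:
    #             return n
    #         return fib(n-1) + fib(n-2)
    state = FibState.start
    pc = FibState.done  # "return address" register
    retval = None       # return-value register
    stack = []
    while True:
        if state == FibState.start:
            if n < 2:
                # return n
                retval = n
                state = pc
                continue
            # stash our return address and n, then "call" fib(n-1)
            stack.append((pc, n))
            pc, n, state = FibState.after_first, n - 1, FibState.start
        elif state == FibState.after_first:
            # fib(n-1) is in retval; restore our frame
            pc, n = stack.pop()
            # stash our return address and fib(n-1), then "call" fib(n-2)
            stack.append((pc, retval))
            pc, n, state = FibState.after_second, n - 2, FibState.start
        elif state == FibState.after_second:
            # fib(n-2) is in retval; restore the saved fib(n-1)
            pc, first = stack.pop()
            # return fib(n-1) + fib(n-2)
            retval = first + retval
            state = pc
        elif state == FibState.done:
            return retval


print([fib(k) for k in range(10)])  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
```

Each resume state knows what shape of tuple was pushed before its call, so the frames need no tag beyond the saved pc.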