Library CPS

CPS Conversion

Require Import Eqdep.
Require Import String.
Require Import List.
Require Import Omega.
Require Import Recdef.
Require Import Arith.
Require Import Div2.
Require Import Program.
Set Implicit Arguments.
Unset Automatic Introduction.

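We need to be able to generate fresh variable names, so these functions help us do "gensym". Note that digit2string n returns only the last binary digit of n (its parity); nat2string below uses it to build the full binary numeral.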
Local Open Scope string_scope.
Fixpoint digit2string (n:nat) : string :=
  match n with
    | 0 => "0"
    | S 0 => "1"
    | S (S n) => digit2string n
  end.

Here, I'm using the Program and measure mechanism to prove that nat2string terminates. The measure n indicates that we mean for n to be strictly smaller on each recursive call. Program generates a proof obligation for us to show that this is the case for each recursive call, which we then knock off in tactic mode.
Program Fixpoint nat2string (n:nat) { measure n }: string :=
  match n with
    | 0 => "0"
    | S 0 => "1"
    | _ => nat2string (div2 n) ++ digit2string n
  end.
Next Obligation. apply lt_div2. omega. Defined.
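If you'd rather avoid Program and its measure obligation, one alternative is to recurse on an explicit fuel argument. The sketch below is not part of the original development: parity_digit, nat2string_fuel and nat2string' are hypothetical names, and it assumes a recent Coq standard library (Nat.div2 from the prelude). Fuel decreases structurally, and S n units always suffice because div2 n < n whenever n >= 2.

```coq
Require Import Coq.Strings.String.
Local Open Scope string_scope.

(* Same as digit2string above: the last binary digit (parity) of n. *)
Fixpoint parity_digit (n : nat) : string :=
  match n with
  | 0 => "0"
  | S 0 => "1"
  | S (S n') => parity_digit n'
  end.

(* Hypothetical fuel-based variant of nat2string. Recursion is on fuel,
   so Coq accepts it without a separate termination proof. *)
Fixpoint nat2string_fuel (fuel n : nat) : string :=
  match fuel with
  | 0 => ""  (* out of fuel; never reached from nat2string' *)
  | S fuel' =>
      match n with
      | 0 => "0"
      | S 0 => "1"
      | _ => nat2string_fuel fuel' (Nat.div2 n) ++ parity_digit n
      end
  end.

(* S n units of fuel are enough, since div2 n < n for n >= 2. *)
Definition nat2string' (n : nat) : string := nat2string_fuel (S n) n.
```

The trade-off is a slightly less direct definition in exchange for a function that reduces by plain computation, with no well-founded recursion to unfold.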

This module defines a little core lambda calculus with unit and pairs.
Module EXP.
  Definition var := string.

  Inductive exp : Type :=
  | Var_e : var -> exp
  | Lam_e : var -> exp -> exp
  | App_e : exp -> exp -> exp
  | Unit_e : exp
  | Pair_e : exp -> exp -> exp
  | Fst_e : exp -> exp
  | Snd_e : exp -> exp.

Here's a nice trick for making the construction of exp values a little more readable. We're going to use higher-order abstract syntax so that we don't have to write Lam_e x (Var_e x) for the identity, but rather something like lam (fun x => x). The idea is that lam will generate a fresh string to use for the variable, say "v", and then instantiate the meta-level function with Var_e "v". But of course, we need to operate under a state monad to be able to generate fresh variables.

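A G is an expression generator: given the current value of the fresh-name counter, it returns the updated counter along with an expression.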
  Definition G := nat -> (nat * exp).

gen g runs the generator g with a new counter initialized to 0, then throws away the state at the end.
  Definition gen (g:G) := snd (g 0).

The rest of these definitions provide generators for all of the abstract syntax constructors, plus some notation to make it all a lot more readable.
  Definition var_g (x:string) : G :=
    fun n => (n, Var_e x).

Here is where we are leveraging higher order abstract syntax. We generate the fresh name using our current state, updating the state from n to 1 + n. We then invoke the meta-level function f on an appropriate variable generator to get out the body of the lambda, specialized to our fresh variable which we then bind using the Lam_e constructor.
  Definition lam_g (f: G -> G) :=
    (* Use the full binary numeral: digit2string alone would reuse names. *)
    fun n => let x := "x" ++ nat2string n in
             let (n', e) := f (var_g x) (1 + n) in
               (n', Lam_e x e).
This notation lets us write our object-level functions in a very Haskell-like style.
  Notation "x ==> e" := (lam_g (fun x => e))
    (right associativity, at level 70).
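To see how the counter threads through nested lam_g calls, here is a minimal self-contained sketch. To keep it short it assumes var := nat, using the counter itself as the variable name instead of a string, and it keeps only variables, lambdas and applications; otherwise var_g, lam_g, app_g and gen mirror the definitions above. Each binder gets a distinct name because lam_g hands 1 + n to the body it generates.

```coq
(* Sketch only: variables are plain counters instead of strings. *)
Definition var := nat.

Inductive exp : Type :=
| Var_e : var -> exp
| Lam_e : var -> exp -> exp
| App_e : exp -> exp -> exp.

Definition G := nat -> (nat * exp).
Definition gen (g : G) : exp := snd (g 0).
Definition var_g (x : var) : G := fun n => (n, Var_e x).

(* The fresh name is the current counter; the body starts at S n. *)
Definition lam_g (f : G -> G) : G :=
  fun n => let (n', e) := f (var_g n) (S n) in (n', Lam_e n e).

Definition app_g (g1 g2 : G) : G :=
  fun n => let (n1, e1) := g1 n in
           let (n2, e2) := g2 n1 in
           (n2, App_e e1 e2).

(* f, g, x each receive a different name: 0, 1 and 2. *)
Definition compose : G :=
  lam_g (fun f => lam_g (fun g => lam_g (fun x => app_g f (app_g g x)))).
```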

The rest of the definitions are straightforward.
  Definition app_g (g1 g2:G) :=
    fun n => let (n1,e1) := g1 n in
             let (n2,e2) := g2 n1 in
               (n2, App_e e1 e2).
  Infix "@" := app_g (left associativity, at level 61).
  Definition unit_g : G := fun n => (n, Unit_e).
  Notation "'()'" := unit_g.
  Definition pair_g (g1 g2:G) : G :=
    fun n => let (n1,e1) := g1 n in let (n2,e2) := g2 n1 in (n2,Pair_e e1 e2).
  Notation "[ x ; y ]" := (pair_g x y).
  Definition fst_g (g:G) := fun n => let (n1,e) := g n in (n1,Fst_e e).
  Definition snd_g (g:G) := fun n => let (n1,e) := g n in (n1,Snd_e e).
  Notation "'#1' x" := (fst_g x) (at level 65).
  Notation "'#2' x" := (snd_g x) (at level 65).

Some examples using our embedded syntax.
  Definition identity := (x ==> x).

  Definition apply := (f ==> x ==> f @ x).

  (* @ is left-associative, so f @ f @ x would parse as (f @ f) @ x. *)
  Definition twice := (f ==> x ==> f @ (f @ x)).

  Definition compose := (f ==> g ==> x ==> f @ (g @ x)).

  Definition pair := (f ==> g ==> [f @ () ; g @ ()]).

  Definition eta_pair := (x ==> [#1 x ; #2 x]).

We have to run the generators using gen to get actual exp values.
  Eval compute in gen identity.

  Eval compute in gen apply.

  Eval compute in gen twice.

  Eval compute in gen pair.

  Eval compute in gen compose.

  Eval compute in gen eta_pair.
End EXP.

This module defines CPS conversion as a source-to-source transformation. This is the kind of thing you will see in the literature (cf. Danvy and Filinski), but it isn't really what most compilers do, as we'll see below.
Module SOURCE_TO_SOURCE.
  Import EXP.

Our translation is complicated by the fact that we need to generate fresh names (for continuations) as we do the CPS conversion. In addition, CPS conversion is itself written in a continuation-passing style! So we are actually building a combined monad here that has both state (for fresh name generation) and a continuation. Note that in Haskell, we would probably use a monad transformer to build this. But you have to be careful: you get two very different behaviors if you compose the state and continuation monads versus the continuation and state monads.

Anyway, M is parameterized by the type A of the value that the local computation returns, as well as by the type ans of the answer that the entire computation returns. It takes a continuation k and a number n (for state) as arguments and returns a new state paired with an ans. Conceptually, we run the local computation, threading the state through to get an A value, and then feed the A value, along with the threaded state, to k to get out the final state and answer.
  Definition K(A ans:Type) := A -> nat -> (nat * ans).
  Definition M(A ans:Type) := K A ans -> nat -> (nat * ans).
  (* Braces make ans implicit even though it appears only in the result type. *)
  Definition Ret{A ans:Type}(v:A) : M A ans := fun k n => k v n.
  Definition Bind{A B ans:Type}(c:M A ans)(f:A -> M B ans) : M B ans :=
    fun k n => c (fun v n' => f v k n') n.
  Notation "'ret' x" := (Ret x) (at level 75).
  Notation "x <- c ; f" := (Bind c (fun x => f))
    (right associativity, at level 84, c at next level).
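Because Ret and Bind are just functions over continuations and counters, the usual monad laws hold by computation alone (unfolding the definitions plus eta-conversion), so reflexivity proves them. The sketch restates K, M, Ret and Bind so that it stands alone; the lemma names are hypothetical.

```coq
Definition K (A ans : Type) := A -> nat -> nat * ans.
Definition M (A ans : Type) := K A ans -> nat -> nat * ans.
Definition Ret {A ans : Type} (v : A) : M A ans := fun k n => k v n.
Definition Bind {A B ans : Type} (c : M A ans) (f : A -> M B ans) : M B ans :=
  fun k n => c (fun v n' => f v k n') n.

(* Each law holds definitionally: unfold, beta-reduce, eta-convert. *)
Lemma bind_ret_l : forall (A B ans : Type) (v : A) (f : A -> M B ans),
  Bind (Ret v) f = f v.
Proof. intros. reflexivity. Qed.

Lemma bind_ret_r : forall (A ans : Type) (c : M A ans),
  Bind c (fun v => Ret v) = c.
Proof. intros. reflexivity. Qed.

Lemma bind_assoc : forall (A B C ans : Type) (c : M A ans)
  (f : A -> M B ans) (g : B -> M C ans),
  Bind (Bind c f) g = Bind c (fun x => Bind (f x) g).
Proof. intros. reflexivity. Qed.
```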

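Generate a fresh variable by appending the current counter (rendered by nat2string) to the prefix x, then incrementing the counter.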
  Definition freshVar {ans} (x:string) : M var ans :=
    fun k n => k (x ++ nat2string n) (1 + n).
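Here is a small sketch of freshVar threading the counter through two allocations. It is self-contained, so it restates the monad and substitutes show_nat, a hypothetical fuel-based stand-in for nat2string that produces the same binary digits; two_names is likewise an illustrative name.

```coq
Require Import Coq.Strings.String.
Local Open Scope string_scope.

(* Hypothetical stand-in for nat2string: binary digits, fuel-bounded. *)
Fixpoint bits (fuel n : nat) : string :=
  match fuel with
  | 0 => ""
  | S fuel' =>
      match n with
      | 0 => "0"
      | 1 => "1"
      | _ => bits fuel' (Nat.div2 n) ++ (if Nat.even n then "0" else "1")
      end
  end.
Definition show_nat (n : nat) : string := bits (S n) n.

Definition K (A ans : Type) := A -> nat -> nat * ans.
Definition M (A ans : Type) := K A ans -> nat -> nat * ans.
Definition Ret {A ans : Type} (v : A) : M A ans := fun k n => k v n.
Definition Bind {A B ans : Type} (c : M A ans) (f : A -> M B ans) : M B ans :=
  fun k n => c (fun v n' => f v k n') n.

(* As in the text: name = prefix ++ counter, then bump the counter. *)
Definition freshVar {ans : Type} (x : string) : M string ans :=
  fun k n => k (x ++ show_nat n) (S n).

(* Allocate two names in sequence and return them as a pair. *)
Definition two_names : M (string * string) (string * string) :=
  Bind (freshVar "_v") (fun a =>
  Bind (freshVar "_k") (fun b =>
  Ret (a, b))).
```

Run from counter 0 with the identity continuation, this yields the final counter 2 and the names "_v0" and "_k1".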

At key points in the CPS translation, we need to get our hands on the meta-level continuation. That is exactly what callcc is for: callcc f passes the current continuation k to f as an ordinary value, and also uses k as the continuation of the computation that f returns.
  Definition callcc {A ans}(f : K A ans -> M A ans) :=
    fun k n => f k k n.
Some notation will make this a little easier to work with.
  Notation "'letcc' x 'in' e" := (callcc (fun x => e))
    (right associativity, at level 84).
At other points in the CPS translation, we will want to abandon the current meta-level continuation (because we've already captured it elsewhere) and this is what throw lets us do. We just immediately return an answer and ignore the continuation.
  Definition throw{A ans}(a:ans) : M A ans :=
    fun k n => (n,a).

Although we don't use it here, we could also provide facilities to capture or change the current state. Here is a "call-with-current-state" operation, which captures it.
  Definition callcs {A ans}(f : nat -> M A ans) :=
    fun k n => f n k n.
  Notation "'letcs' x 'in' e" := (callcs (fun x => e))
    (right associativity, at level 84).

Finally, the run operation takes a continuation c (of type K B A) that has already been given its local return value of type B, and runs it to get out the A. Because it must thread the state through this computation, the result is an M computation. We'll see below how run, in conjunction with callcc, is used to reify a meta-level continuation into an object-level one.
  Definition run{A ans}(c : nat -> nat * A) : M A ans :=
    fun k n => let (n',v) := c n in k v n'.
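The pattern "letcc k in e <- run (k v) ; throw (...)" is the heart of both translations, so it may help to watch it on a toy answer type. In this hypothetical sketch the answer type is string instead of exp: wrap_rest captures the continuation k, runs it on the placeholder "hole" to reify the rest of the computation as a string, then abandons k and returns that string wrapped in "wrap(...)". Notations are omitted and Bind is written out.

```coq
Require Import Coq.Strings.String.
Local Open Scope string_scope.

Definition K (A ans : Type) := A -> nat -> nat * ans.
Definition M (A ans : Type) := K A ans -> nat -> nat * ans.
Definition Bind {A B ans : Type} (c : M A ans) (f : A -> M B ans) : M B ans :=
  fun k n => c (fun v n' => f v k n') n.
Definition callcc {A ans : Type} (f : K A ans -> M A ans) : M A ans :=
  fun k n => f k k n.
Definition run {A ans : Type} (c : nat -> nat * A) : M A ans :=
  fun k n => let (n', v) := c n in k v n'.
Definition throw {A ans : Type} (a : ans) : M A ans :=
  fun _ n => (n, a).

(* Hypothetical toy: capture k, reify "the rest of the computation"
   by running k on a placeholder, then ignore k and return a wrapped
   answer instead. *)
Definition wrap_rest : M string string :=
  callcc (fun k =>
    Bind (run (k "hole")) (fun rest =>
    throw ("wrap(" ++ rest ++ ")"))).
```

With the outer continuation fun v n => (n, v ++ "!"), the result is "wrap(hole!)": the outer continuation ran exactly once, inside the wrapper, just as the rest of the program ends up inside Lam_e x e in the App_e case.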

  Fixpoint ET (e:exp) : M exp exp :=
    match e with
      | Var_e x => ret Var_e x
      | Unit_e => ret Unit_e
      | Pair_e e1 e2 =>
        v1 <- ET e1 ;
        v2 <- ET e2 ;
        ret (Pair_e v1 v2)
      | Fst_e e =>
        v <- ET e ;
        ret (Fst_e v)
      | Snd_e e =>
        v <- ET e ;
        ret (Snd_e v)
      | App_e e1 e2 =>
        v1 <- ET e1 ;
        v2 <- ET e2 ;
        letcc k in
        x <- freshVar "_v" ;
        e <- run (k (Var_e x)) ;
        throw (App_e (App_e v1 v2) (Lam_e x e))
      | Lam_e x e =>
        c <- freshVar "_k" ;
        e' <- run (ET e (fun v n => (n,App_e (Var_e c) v))) ;
        ret (Lam_e x (Lam_e c e'))
    end.

We kick off the entire translation by calling ET and passing it an initial continuation that corresponds to the identity function, and an initial state.
  Definition cps (g:G) : exp :=
    snd (ET (gen g) (fun v n => (n,v)) 0).

  Eval compute in cps identity.

  Eval compute in cps apply.

  Eval compute in cps twice.

  Eval compute in cps pair.

  Eval compute in cps compose.

  Eval compute in cps eta_pair.
End SOURCE_TO_SOURCE.
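For readers on a recent Coq, here is a self-contained sketch of the source-to-source ET restricted to variables, lambdas and applications (unit and pairs are dropped for brevity), together with outputs it produces. show_nat is a hypothetical fuel-based stand-in for nat2string; the rest mirrors the definitions in SOURCE_TO_SOURCE, with the notations expanded into explicit Bind calls. The nested-application example shows that the inner call is emitted first and the outer call lives inside its continuation.

```coq
Require Import Coq.Strings.String.
Local Open Scope string_scope.

(* Hypothetical stand-in for nat2string: binary digits, fuel-bounded. *)
Fixpoint bits (fuel n : nat) : string :=
  match fuel with
  | 0 => ""
  | S fuel' =>
      match n with
      | 0 => "0"
      | 1 => "1"
      | _ => bits fuel' (Nat.div2 n) ++ (if Nat.even n then "0" else "1")
      end
  end.
Definition show_nat (n : nat) : string := bits (S n) n.

(* Source language cut down to variables, lambdas and applications. *)
Inductive exp : Type :=
| Var_e : string -> exp
| Lam_e : string -> exp -> exp
| App_e : exp -> exp -> exp.

Definition K (A ans : Type) := A -> nat -> nat * ans.
Definition M (A ans : Type) := K A ans -> nat -> nat * ans.
Definition Ret {A ans : Type} (v : A) : M A ans := fun k n => k v n.
Definition Bind {A B ans : Type} (c : M A ans) (f : A -> M B ans) : M B ans :=
  fun k n => c (fun v n' => f v k n') n.
Definition freshVar {ans : Type} (x : string) : M string ans :=
  fun k n => k (x ++ show_nat n) (S n).
Definition callcc {A ans : Type} (f : K A ans -> M A ans) : M A ans :=
  fun k n => f k k n.
Definition run {A ans : Type} (c : nat -> nat * A) : M A ans :=
  fun k n => let (n', v) := c n in k v n'.
Definition throw {A ans : Type} (a : ans) : M A ans := fun _ n => (n, a).

Fixpoint ET (e : exp) : M exp exp :=
  match e with
  | Var_e x => Ret (Var_e x)
  | App_e e1 e2 =>
      Bind (ET e1) (fun v1 =>
      Bind (ET e2) (fun v2 =>
      callcc (fun k =>
      Bind (freshVar "_v") (fun x =>
      (* reify the meta-level continuation k as the body of a lambda *)
      Bind (run (k (Var_e x))) (fun body =>
      throw (App_e (App_e v1 v2) (Lam_e x body)))))))
  | Lam_e x body =>
      Bind (freshVar "_k") (fun c =>
      Bind (run (ET body (fun v n => (n, App_e (Var_e c) v)))) (fun body' =>
      Ret (Lam_e x (Lam_e c body'))))
  end.

Definition cps (e : exp) : exp := snd (ET e (fun v n => (n, v)) 0).
```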

This module defines a more refined intermediate language that captures the real structure of CPS translation. It is much closer to what a compiler uses where we make a distinction between "small" values (called operands here) versus "large" values that we want to bind to variables. This form of CPS intermediate language corresponds pretty closely to what was used in the SML/NJ compiler and facilitates lots of optimization.

Notice the structure: none of the operations supports nested expressions. They only manipulate operands (which are either variables or constants). Everything else has to be bound to a variable using a "let".

Note that a cexp is a sequence of let-declarations terminated in one of three ways: (1) a function call, which passes in the argument and the continuation; (2) a function "return", which just invokes a continuation with the result; and (3) program termination.

Finally, note that you could always embed this language into the original one by encoding each let as the application of a lambda.
Module CS153.
  Import EXP.
  Inductive operand :=
  | Var_c : var -> operand
  | Unit_c : operand.
  Inductive decl :=
  | Lam_c : var -> var -> cexp -> decl
  | Kont_c : var -> cexp -> decl
  | Pair_c : operand -> operand -> decl
  | Fst_c : operand -> decl
  | Snd_c : operand -> decl
  with cexp :=
  | App_c : operand -> operand -> operand -> cexp
  | Ret_c : operand -> operand -> cexp
  | Exit_c : operand -> cexp
  | Let_c : var -> decl -> cexp -> cexp.

This will make it a little easier to read the output of the CPS translation.
  Notation "'Let' x := e1 'in' e2" := (Let_c x e1 e2)
    (right associativity, at level 70).
  Notation "$ x" := (Var_c x) (at level 0).
  Implicit Arguments inl [A B].
  Implicit Arguments inr [A B].

I want to show you the operational semantics for this language and point out that it no longer requires a stack. In particular, notice that I'm able to define the step function (step_fn) without any recursion of its own; only the environment lookup recurses.
  Definition env(A:Type) := list (var * A).
  Fixpoint lookup A (env:env A) (x:var) : option A :=
    match env with
      | nil => None
      | (y,v)::env' => if string_dec x y then Some v else lookup env' x
    end.

  Inductive value : Type :=
  | Unit_v : value
  | Pair_v : value -> value -> value
  | Lam_v : var -> var -> cexp -> env value -> value
  | Kont_v : var -> cexp -> env value -> value.

  Inductive answer : Type :=
  | Value : value -> answer
  | TypeError : answer.

  Definition evalop (op:operand) (p: env value) : answer :=
    match op with
      | Var_c x => match lookup p x with | Some v => Value v | None => TypeError end
      | Unit_c => Value Unit_v
    end.

  Definition evaldecl (d:decl) (p:env value) : answer :=
    match d with
      | Lam_c x k e => Value (Lam_v x k e p)
      | Kont_c x e => Value (Kont_v x e p)
      | Pair_c op1 op2 =>
        match evalop op1 p, evalop op2 p with
          | Value v1, Value v2 => Value (Pair_v v1 v2)
          | _, _ => TypeError
        end
      | Fst_c op =>
        match evalop op p with
          | Value (Pair_v v1 v2) => Value v1
          | _ => TypeError
        end
      | Snd_c op =>
        match evalop op p with
          | Value (Pair_v v1 v2) => Value v2
          | _ => TypeError
        end
    end.

  Definition step_fn (e : cexp) (p : env value) : answer + (cexp * env value) :=
    match e with
      | Exit_c op => inl (evalop op p)
      | App_c op1 op2 op3 =>
        match evalop op1 p, evalop op2 p, evalop op3 p with
          | Value (Lam_v x k e p'), Value v2, Value v3 => inr (e, (k,v3)::(x,v2)::p')
          | _, _, _ => inl TypeError
        end
      | Ret_c op1 op2 =>
        match evalop op1 p, evalop op2 p with
          | Value (Kont_v x e p'), Value v => inr (e, (x,v)::p')
          | _, _ => inl TypeError
        end
      | Let_c x d e =>
        match evaldecl d p with
          | Value v => inr (e, (x,v)::p)
          | _ => inl TypeError
        end
    end.

  Fixpoint eval(n:nat)(e:cexp)(p:env value) : answer :=
    match n with
      | 0 => TypeError
      | S n => match step_fn e p with
                 | inl a => a
                 | inr (e,p) => eval n e p
               end
    end.
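To see the machine run without a stack, here is a self-contained sketch of the semantics restricted to functions and continuations (pair declarations are omitted), together with a hand-written CPS program for (fun x => x) (). The program id_unit is hypothetical and written by hand rather than produced by cps. As in eval above, running out of fuel is reported as TypeError.

```coq
Require Import Coq.Strings.String.
Local Open Scope list_scope.
Local Open Scope string_scope.

Definition var := string.

Inductive operand : Type :=
| Var_c : var -> operand
| Unit_c : operand.

Inductive decl : Type :=
| Lam_c : var -> var -> cexp -> decl
| Kont_c : var -> cexp -> decl
with cexp : Type :=
| App_c : operand -> operand -> operand -> cexp
| Ret_c : operand -> operand -> cexp
| Exit_c : operand -> cexp
| Let_c : var -> decl -> cexp -> cexp.

Inductive value : Type :=
| Unit_v : value
| Lam_v : var -> var -> cexp -> list (var * value) -> value
| Kont_v : var -> cexp -> list (var * value) -> value.

Definition env := list (var * value).

Inductive answer : Type :=
| Value : value -> answer
| TypeError : answer.

Fixpoint lookup (p : env) (x : var) : option value :=
  match p with
  | nil => None
  | (y, v) :: p' => if string_dec x y then Some v else lookup p' x
  end.

Definition evalop (op : operand) (p : env) : answer :=
  match op with
  | Var_c x => match lookup p x with Some v => Value v | None => TypeError end
  | Unit_c => Value Unit_v
  end.

(* Closures capture the current environment. *)
Definition evaldecl (d : decl) (p : env) : answer :=
  match d with
  | Lam_c x k e => Value (Lam_v x k e p)
  | Kont_c x e => Value (Kont_v x e p)
  end.

(* One machine step; note there is no recursive call here. *)
Definition step_fn (e : cexp) (p : env) : answer + (cexp * env) :=
  match e with
  | Exit_c op => inl (evalop op p)
  | App_c op1 op2 op3 =>
      match evalop op1 p, evalop op2 p, evalop op3 p with
      | Value (Lam_v x k body p'), Value v2, Value v3 =>
          inr (body, (k, v3) :: (x, v2) :: p')
      | _, _, _ => inl TypeError
      end
  | Ret_c op1 op2 =>
      match evalop op1 p, evalop op2 p with
      | Value (Kont_v x body p'), Value v => inr (body, (x, v) :: p')
      | _, _ => inl TypeError
      end
  | Let_c x d body =>
      match evaldecl d p with
      | Value v => inr (body, (x, v) :: p)
      | TypeError => inl TypeError
      end
  end.

Fixpoint eval (fuel : nat) (e : cexp) (p : env) : answer :=
  match fuel with
  | 0 => TypeError  (* out of fuel, conflated with TypeError as above *)
  | S fuel' =>
      match step_fn e p with
      | inl a => a
      | inr (e', p') => eval fuel' e' p'
      end
  end.

(* Hand-written CPS form of (fun x => x) (). *)
Definition id_unit : cexp :=
  Let_c "f" (Lam_c "x" "c" (Ret_c (Var_c "c") (Var_c "x")))
  (Let_c "k" (Kont_c "r" (Exit_c (Var_c "r")))
  (App_c (Var_c "f") Unit_c (Var_c "k"))).
```

The program needs five steps (two lets, the call, the return and the exit), so fuel 3 runs out before it finishes.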

The translation to this form is largely the same except that we are forced to let-bind some things that previously could be nested.
  Definition M(A ans:Type) := (A -> nat -> (nat * ans)) -> nat -> (nat * ans).
  Definition Ret{A ans:Type}(v:A) : M A ans := fun k n => k v n.
  Definition Bind{A B ans:Type}(c:M A ans)(f:A -> M B ans) : M B ans :=
    fun k n => c (fun v n' => f v k n') n.
  Notation "'ret' x" := (Ret x) (at level 75).
  Notation "x <- c ; f" := (Bind c (fun x => f))
    (right associativity, at level 84, c at next level).

  Definition freshVar{ans:Type}(s:string) : M var ans :=
    fun k n => k (s ++ nat2string n) (1 + n).

  Definition callcc {A ans}(f : (A -> nat -> nat * ans) -> M A ans) :=
    fun k n => f k k n.
  Notation "'letcc' x 'in' e" := (callcc (fun x => e))
    (right associativity, at level 84).
  Definition run{A ans}(c : nat -> nat * A) : M A ans :=
    fun k n => let (n',v) := c n in k v n'.
  Definition throw{A ans}(a:ans) : M A ans :=
    fun k n => (n,a).

Notice that our continuation is expecting us to feed it an operand so it can build a cexp, but all we have is a decl. So we create a fresh variable x, and let-bind the declaration to x, feeding our continuation Var_c x to generate the entire expression.
  Definition letbind (d:decl) : M operand cexp :=
    x <- freshVar "_x" ;
    letcc k in
    e <- run (k (Var_c x)) ;
    throw (Let_c x d e).
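Here is a toy sketch of letbind on its own, showing that two let-bound declarations come out nested in evaluation order, with the final operand flowing to the initial continuation. To keep it short it assumes var := nat (the counter is the name), a decl type with only pairs, and a cexp type with only Exit_c and Let_c; these cut-down types are hypothetical stand-ins for the ones in CS153.

```coq
(* Cut-down, hypothetical target language: names are counters. *)
Definition var := nat.
Inductive operand : Type := Var_c : var -> operand | Unit_c : operand.
Inductive decl : Type := Pair_c : operand -> operand -> decl.
Inductive cexp : Type :=
| Exit_c : operand -> cexp
| Let_c : var -> decl -> cexp -> cexp.

Definition K (A ans : Type) := A -> nat -> nat * ans.
Definition M (A ans : Type) := K A ans -> nat -> nat * ans.
Definition Bind {A B ans : Type} (c : M A ans) (f : A -> M B ans) : M B ans :=
  fun k n => c (fun v n' => f v k n') n.
Definition callcc {A ans : Type} (f : K A ans -> M A ans) : M A ans :=
  fun k n => f k k n.
Definition run {A ans : Type} (c : nat -> nat * A) : M A ans :=
  fun k n => let (n', v) := c n in k v n'.
Definition throw {A ans : Type} (a : ans) : M A ans := fun _ n => (n, a).

(* Here the fresh name is just the counter itself. *)
Definition freshVar {ans : Type} : M var ans := fun k n => k n (S n).

(* Same shape as letbind in the text. *)
Definition letbind (d : decl) : M operand cexp :=
  Bind freshVar (fun x =>
  callcc (fun k =>
    Bind (run (k (Var_c x))) (fun e =>
    throw (Let_c x d e)))).

(* Two declarations; the second refers to the operand of the first. *)
Definition two_lets : M operand cexp :=
  Bind (letbind (Pair_c Unit_c Unit_c)) (fun v1 =>
  letbind (Pair_c v1 v1)).
```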

So other than the let-binding, the translation is largely the same.
  Fixpoint ET (e:exp) : M operand cexp :=
    match e with
      | Var_e x => ret Var_c x
      | Unit_e => ret Unit_c
      | Pair_e e1 e2 =>
        v1 <- ET e1 ;
        v2 <- ET e2 ;
        letbind (Pair_c v1 v2)
      | Fst_e e =>
        v <- ET e ;
        letbind (Fst_c v)
      | Snd_e e =>
        v <- ET e ;
        letbind (Snd_c v)
      | App_e e1 e2 =>
        v1 <- ET e1 ;
        v2 <- ET e2 ;
        letcc k in
        x <- freshVar "_v" ;
        e <- run (k (Var_c x)) ;
        c <- letbind (Kont_c x e) ;
        throw (App_c v1 v2 c)
      | Lam_e x e =>
        c <- freshVar "_c" ;
        e' <- run (ET e (fun v n => (n,Ret_c (Var_c c) v))) ;
        letbind (Lam_c x c e')
    end.

  Definition cps (g:G) : cexp :=
    snd (ET (gen g) (fun v n => (n, Exit_c v)) 0).

  Eval compute in cps identity.

  Eval compute in cps apply.

  Eval compute in cps twice.

  Eval compute in cps pair.

  Eval compute in cps compose.

  Eval compute in eval 100 (cps ((x ==> x) @ ())) nil.
  Eval compute in eval 100 (cps (twice @ identity @ ())) nil.
End CS153.

Exercises:
  • Add Callcc_e and Throw_e to the source language and extend the translations to handle these things. You shouldn't need to modify the target languages.

  • Add booleans, a conditional (If_e) expression, and short-circuited And_e and Or_e expressions to the language and extend one of the translations to cover these new constructs. Be careful! If you're not careful, you'll end up making a copy of the stack. And of course by short-circuited, I mean that you should only end up evaluating an expression if it's needed. You get more credit for avoiding unnecessary intermediate continuations.

  • Set up typing rules (after the fact) for expressions (e.g., G |- e : t). Then set up a type translation T : type -> type which captures the invariants of the CPS translation. In particular, if |- e : t, then we should have that |- cps e : T t. I don't expect you to prove this (because it's a pain in the ass with the fresh variable generation), but I do expect you to be able to set up the definitions and theorems and reason through the argument informally on paper.

  • Extra credit: formally prove the cps translation respects your type translation.

  • Super Extra credit: formally prove that the cps translation respects the semantics of the original program.