(** * CPS Conversion *)
Require Import Eqdep.
Require Import String.
Require Import List.
Require Import Omega.
Require Import Recdef.
Require Import Arith.
Require Import Div2.
Require Import Program.
Set Implicit Arguments.
Unset Automatic Introduction.
(** We need to be able to generate fresh variable names, so these
functions help us do "gensym": they render a natural-number counter
as a binary string that can be appended to a name prefix. *)
Local Open Scope string_scope.
(** [digit2string n] is the least significant binary digit of [n]:
    ["1"] when [n] is odd and ["0"] when it is even. Each recursive call
    strips two successors, so the recursion is structural. *)
Fixpoint digit2string (n:nat) : string :=
  match n with
  | S (S m) => digit2string m
  | S O => "1"
  | O => "0"
  end.
(** Here, I'm using the [Program] and [measure] mechanism to prove that
[nat2string] terminates. The [measure n] indicates that we mean for
[n] to be strictly smaller on each recursive call. [Program] generates
a proof obligation for us to show that this is the case for each
recursive call, which we then knock off in tactic mode. *)
(* [nat2string n] renders [n] in binary, most significant bit first;
   for example [nat2string 6] is ["110"]. The catch-all case recurses on
   [div2 n] and appends the last digit of [n] via [digit2string]. *)
Program Fixpoint nat2string (n:nat) { measure n }: string :=
match n with
| 0 => "0"
| S 0 => "1"
| _ => nat2string (div2 n) ++ digit2string n
end.
(* The only obligation is [div2 n < n] for the catch-all case, where [n]
   is known to be neither [0] nor [1]. [lt_div2] reduces it to [0 < n],
   which [omega] discharges. [Defined] (rather than [Qed]) leaves the
   obligation transparent -- presumably so that [Eval compute] can
   reduce calls to [nat2string]. *)
Next Obligation. apply lt_div2. omega. Defined.
(** This module defines a little core lambda calculus with unit and pairs. *)
Module EXP.
(** Object-level variable names are plain strings. *)
Definition var := string.
(** Abstract syntax of the source language: variables, one-argument
lambdas, application, unit, pairs, and the two pair projections. *)
Inductive exp : Type :=
| Var_e : var -> exp
| Lam_e : var -> exp -> exp
| App_e : exp -> exp -> exp
| Unit_e : exp
| Pair_e : exp -> exp -> exp
| Fst_e : exp -> exp
| Snd_e : exp -> exp.
(** Here's a nice trick for making the construction of [exp] values a
little more readable. We're going to use higher-order abstract
syntax so that we don't have to write [Lam_e x (Var_e x)] for the
identity, but rather something like [lam (fun x => x)]. The idea
is that [lam] will generate a fresh string to use for the variable,
say ["v"], and then instantiate the meta-level function with
[Var_e "v"]. But of course, we need to operate under a state monad
to be able to generate fresh variables. *)
(** A [G] is an expression generator: given the current value of a
fresh-name counter, it returns the updated counter and an [exp]. *)
Definition G := nat -> (nat * exp).
(** [gen g] runs the generator [g] with a new counter initialized to [0],
then throws away the state at the end. *)
Definition gen (g:G) := snd (g 0).
(** The rest of these definitions provide generators for all of the
abstract syntax constructors, plus some notation to make it all
a lot more readable. Generators with sub-generators thread the
counter through them from left to right. *)
Definition var_g (x:string) : G :=
fun n => (n, Var_e x).
(** Here is where we are leveraging higher-order abstract syntax.
We generate the fresh name ["x" ++ nat2string n] from our current
state, updating the state from [n] to [1 + n]. We then invoke the
meta-level function [f] on an appropriate variable generator to get
out the body of the lambda, specialized to our fresh variable, which
we then bind using the [Lam_e] constructor.
The name must encode the whole counter. If it used only the counter's
last binary digit, counters [0] and [2] would both be named ["x0"], and
a nested binder would capture an outer variable (as in [compose]). *)
Definition lam_g (f: G -> G) : G :=
fun n => let x := "x" ++ nat2string n in
let (n', e) := f (var_g x) (1 + n) in
(n', Lam_e x e).
(** Now this notation makes it very Haskell-like to write our
object-level functions. *)
Notation "x ==> e" := (lam_g (fun x => e))
(right associativity, at level 70).
(** The rest of the definitions are straightforward. *)
Definition app_g (g1 g2:G) : G :=
fun n => let (n1,e1) := g1 n in
let (n2,e2) := g2 n1 in
(n2, App_e e1 e2).
Infix "@" := app_g (left associativity, at level 61).
Definition unit_g : G := fun n => (n, Unit_e).
Notation "'()'" := unit_g.
Definition pair_g (g1 g2:G) : G :=
fun n => let (n1,e1) := g1 n in let (n2,e2) := g2 n1 in (n2,Pair_e e1 e2).
Notation "[ x ; y ]" := (pair_g x y).
Definition fst_g (g:G) : G := fun n => let (n1,e) := g n in (n1,Fst_e e).
Definition snd_g (g:G) : G := fun n => let (n1,e) := g n in (n1,Snd_e e).
Notation "'#1' x" := (fst_g x) (at level 65).
Notation "'#2' x" := (snd_g x) (at level 65).
(** Some examples using our embedded syntax. *)
Definition identity := (x ==> x).
Definition apply := (f ==> x ==> f @ x).
Definition twice := (f ==> x ==> f @ f @ x).
Definition compose := (f ==> g ==> x ==> f @ g @ x).
Definition pair := (f ==> g ==> [f @ () ; g @ ()]).
Definition eta_pair := (x ==> [#1 x ; #2 x]).
(** We have to run the generators using [gen] to get actual [exp] values. *)
Eval compute in gen identity.
Eval compute in gen apply.
Eval compute in gen twice.
Eval compute in gen pair.
Eval compute in gen compose.
Eval compute in gen eta_pair.
End EXP.
(** This module defines CPS conversion as a source-to-source transformation.
This is the kind of thing you will see in the literature (cf. Danvy and
Filinski), but it isn't really what most compilers do, as we'll see below. *)
Module SOURCE_TO_SOURCE.
Import EXP.
(** Our translation is complicated by the fact that we need to generate
fresh names (for continuations) as we do the CPS conversion. In addition,
CPS conversion is itself written in a continuation-passing style! So
we are actually building a combined monad here that has both state
(for fresh name generation) and a continuation. Note that in Haskell,
we would probably use a monad transformer to build this. But you have
to be careful: you get two very different behaviors if you compose the
state and continuation monads versus the continuation and state monads.
Anyway, [M] is parameterized by the kind of value that the local computation
returns ([A]) as well as the type that the entire computation returns ([ans]).
It takes a continuation [k] and a number [n] (for state) as arguments and
returns a new state and [ans]. Conceptually, we run the local computation,
threading the state through to get an [A] value, and then feed the [A] value,
along with the threaded state, to [k] to get out the final state and answer.
*)
Definition K(A ans:Type) := A -> nat -> (nat * ans).
Definition M(A ans:Type) := K A ans -> nat -> (nat * ans).
Definition Ret(A ans:Type)(v:A) : M A ans := fun k n => k v n.
Definition Bind(A B ans:Type)(c:M A ans)(f:A -> M B ans) : M B ans :=
fun k n => c (fun v n' => f v k n') n.
Notation "'ret' x" := (Ret x) (at level 75).
Notation "x <- c ; f" := (Bind c (fun x => f))
(right associativity, at level 84, c at next level).
(** Generate a fresh variable: the prefix [x] followed by the current
counter in binary. The name goes to the continuation, and the counter
is incremented. *)
Definition freshVar {ans} (x:string) : M var ans :=
fun k n => k (x ++ nat2string n) (1 + n).
(** At key points in the CPS translation, we need to get our hands on the
meta-level continuation. But this is exactly what [callcc] is supposed
to do for us and indeed, that's what happens here. *)
Definition callcc {A ans}(f : K A ans -> M A ans) :=
fun k n => f k k n.
(** Some notation will make this a little easier to work with. *)
Notation "'letcc' x 'in' e" := (callcc (fun x => e))
(right associativity, at level 84).
(** At other points in the CPS translation, we will want to abandon the current
meta-level continuation (because we've already captured it elsewhere) and
this is what [throw] lets us do. We just immediately return an answer and
ignore the continuation. *)
Definition throw{A ans}(a:ans) : M A ans :=
fun k n => (n,a).
(** Although we don't use it here, we could also provide facilities to
capture or change the current state. [callcs] is the capturing one,
a "call-with-current-state" operation. [callcs f] passes the current
state [n] to [f] and then runs [f n] from that same, unchanged state.
Its arguments are taken in the order [k n] so that it has type [M A ans]
and can be sequenced with [Bind] like every other operation here. *)
Definition callcs {A ans}(f : nat -> M A ans) : M A ans :=
fun k n => f n k n.
Notation "'letcs' x 'in' e" := (callcs (fun x => e))
(right associativity, at level 84).
(** Finally, the [run] operation takes a continuation [c] (of type [K B A])
which has already been given its local return value of type [B], and
runs the continuation to get out the [A]. But it must thread the state
through this computation, so we get out an [M] computation. We'll see
below how this, in conjunction with [callcc], is used to reify a
meta-level continuation into an object-level one. *)
Definition run{A ans}(c : nat -> nat * A) : M A ans :=
fun k n => let (n',v) := c n in k v n'.
(** Now that we have set up all of our monadic operations, the translation
is pretty straightforward except for the [App_e] and [Lam_e] cases.
For [App_e e1 e2], we translate [e1] and [e2] to values [v1] and [v2],
and then call [v1] with [v2] followed by an object-level continuation
that computes the rest of the program once the call has produced its result.
We start by capturing the current meta-level continuation using [letcc].
Ignoring the state, this meta-level continuation [k] is essentially a
Coq function from [exp]s to [exp]s (i.e., a context!). We can turn it into
an [exp] by generating a fresh variable [x] and by applying the meta-level
continuation to [Var_e x]. But we have to make sure we thread the state
through, hence the use of [run] to do the application. Finally, because
we've already accounted for the rest of the program by reifying the
current continuation, we [throw] the application of [v1] to [v2] and
the continuation to avoid returning the whole thing again. In short,
when we reified the meta-level continuation, we made a copy of the "stack".
But we only need one copy, so we can throw away the other one.
For [Lam_e x e], we need to transform the function to take an extra
fresh argument [c]. Instead of invoking the current continuation, when
translating the body [e], we want it to invoke [c]. So again, we [run]
the translation, but feed it an initial meta-level continuation which
applies the object-level continuation. Finally, we return the transformed
lambda to the current meta-level continuation.
*)
Fixpoint ET (e:exp) : M exp exp :=
match e with
| Var_e x => ret Var_e x
| Unit_e => ret Unit_e
| Pair_e e1 e2 =>
v1 <- ET e1 ;
v2 <- ET e2 ;
ret (Pair_e v1 v2)
| Fst_e e =>
v <- ET e ;
ret (Fst_e v)
| Snd_e e =>
v <- ET e ;
ret (Snd_e v)
| App_e e1 e2 =>
v1 <- ET e1 ;
v2 <- ET e2 ;
letcc k in
x <- freshVar "_v" ;
e <- run (k (Var_e x)) ;
throw (App_e (App_e v1 v2) (Lam_e x e))
| Lam_e x e =>
c <- freshVar "_k" ;
e' <- run (ET e (fun v n => (n,App_e (Var_e c) v))) ;
ret (Lam_e x (Lam_e c e'))
end.
(** We kick off the entire translation by calling [ET] and passing
it an initial continuation that corresponds to the identity function,
and an initial state. *)
Definition cps (g:G) : exp :=
snd (ET (gen g) (fun v n => (n,v)) 0).
Eval compute in cps identity.
Eval compute in cps apply.
Eval compute in cps twice.
Eval compute in cps pair.
Eval compute in cps compose.
Eval compute in cps eta_pair.
End SOURCE_TO_SOURCE.
(** This module defines a more refined intermediate language that captures
the real structure of CPS translation. It is much closer to what a
compiler uses, where we distinguish "small" values (called [operand]s
here) from "large" values that we want to bind to variables. This form
of CPS intermediate language corresponds pretty closely to what was used
in the SML/NJ compiler and facilitates lots of optimization.
Notice the structure: none of the operations supports nested expressions.
They only manipulate operands (which are either variables or constants).
Everything else has to be bound to a variable using a "let".
Note that a [cexp] is a sequence of let-declarations that is terminated
in one of three ways: (1) a function call, which passes in the argument and
the continuation, (2) a function "return", which just invokes a continuation
with the result, and (3) program termination.
Finally, note that you could always embed this language into the original
one, using an application of a lambda to encode each let.
*)
Module CS153.
Import EXP.
(** Operands are the "small" values: variables and the unit constant. *)
Inductive operand :=
| Var_c : var -> operand
| Unit_c : operand.
(** Declarations are the "large" values, which must be let-bound.
[Lam_c x c e] is a function with argument [x], continuation argument [c]
and body [e]. [Kont_c x e] is a continuation with argument [x]. The
remaining forms build and project pairs of operands.
A [cexp] is a chain of [Let_c] bindings ending in one of three forms:
a call [App_c f arg k], a return [Ret_c k v] to a continuation, or
[Exit_c v]. *)
Inductive decl :=
| Lam_c : var -> var -> cexp -> decl
| Kont_c : var -> cexp -> decl
| Pair_c : operand -> operand -> decl
| Fst_c : operand -> decl
| Snd_c : operand -> decl
with cexp :=
| App_c : operand -> operand -> operand -> cexp
| Ret_c : operand -> operand -> cexp
| Exit_c : operand -> cexp
| Let_c : var -> decl -> cexp -> cexp.
(** This will make it a little easier to read the output of the CPS translation. *)
Notation "'Let' x := e1 'in' e2" := (Let_c x e1 e2)
(right associativity, at level 70).
Notation "$ x" := (Var_c x) (at level 0).
(* Make the type parameters [A] and [B] of [inl] and [inr] implicit;
   both constructors are used in [step_fn] and [eval] below. *)
Implicit Arguments inl [A B].
Implicit Arguments inr [A B].
(** I want to show you the operational semantics for this language and point out
that it no longer requires a stack. In particular, notice that I'm able to
define the "step" function without appealing to any form of recursion. *)
(** An environment is an association list from variables to values. *)
Definition env(A:Type) := list (var * A).
(** [lookup env x] returns the value of the first binding of [x] in [env],
or [None]. New bindings are consed onto the front, so they shadow older
ones. *)
Fixpoint lookup A (env:env A) (x:var) : option A :=
match env with
| nil => None
| (y,v)::env' => if string_dec x y then Some v else lookup env' x
end.
(** Runtime values are unit, pairs, and closures. A closure is either a
function [Lam_v x k e p] or a continuation [Kont_v x e p], and each
captures its defining environment [p]. *)
Inductive value : Type :=
| Unit_v : value
| Pair_v : value -> value -> value
| Lam_v : var -> var -> cexp -> env value -> value
| Kont_v : var -> cexp -> env value -> value.
(** The outcome of evaluation is either a final value or [TypeError].
[TypeError] is produced when evaluation gets stuck, and also by [eval]
when it runs out of fuel. *)
Inductive answer : Type :=
| Value : value -> answer
| TypeError : answer.
(** [evalop op p] evaluates an operand in environment [p]. An unbound
variable yields [TypeError]. *)
Definition evalop (op:operand) (p: env value) : answer :=
match op with
| Var_c x => match lookup p x with | Some v => Value v | None => TypeError end
| Unit_c => Value Unit_v
end.
(** [evaldecl d p] evaluates a declaration. Lambdas and continuations
become closures over the current environment [p]. Pair operations need
every operand to evaluate to a value, and projections need that value to
be a pair; otherwise the result is [TypeError]. *)
Definition evaldecl (d:decl) (p:env value) : answer :=
match d with
| Lam_c x k e => Value (Lam_v x k e p)
| Kont_c x e => Value (Kont_v x e p)
| Pair_c op1 op2 =>
match evalop op1 p, evalop op2 p with
| Value v1, Value v2 => Value (Pair_v v1 v2)
| _, _ => TypeError
end
| Fst_c op =>
match evalop op p with
| Value (Pair_v v1 v2) => Value v1
| _ => TypeError
end
| Snd_c op =>
match evalop op p with
| Value (Pair_v v1 v2) => Value v2
| _ => TypeError
end
end.
(** [step_fn e p] performs one step of execution. It returns [inl a] when
the program is finished with answer [a] (an exit or a type error), and
otherwise [inr (e', p')] with the next expression and environment.
- A call extends the callee closure's environment with its argument [x]
  and its continuation parameter [k].
- A return extends the continuation closure's environment with its
  parameter.
- A [Let_c] extends the current environment. *)
Definition step_fn (e : cexp) (p : env value) : answer + (cexp * env value) :=
match e with
| Exit_c op => inl (evalop op p)
| App_c op1 op2 op3 =>
match evalop op1 p, evalop op2 p, evalop op3 p with
| Value (Lam_v x k e p'), Value v2, Value v3 => inr (e, (k,v3)::(x,v2)::p')
| _, _, _ => inl TypeError
end
| Ret_c op1 op2 =>
match evalop op1 p, evalop op2 p with
| Value (Kont_v x e p'), Value v => inr (e, (x,v)::p')
| _, _ => inl TypeError
end
| Let_c x d e =>
match evaldecl d p with
| Value v => inr (e, (x,v)::p)
| _ => inl TypeError
end
end.
(** [eval n e p] iterates [step_fn] for at most [n] steps. If the fuel runs
out before the program finishes, the result is [TypeError].
NOTE(review): callers cannot tell fuel exhaustion apart from a genuine
type error. *)
Fixpoint eval(n:nat)(e:cexp)(p:env value) : answer :=
match n with
| 0 => TypeError
| S n => match step_fn e p with
| inl a => a
| inr (e,p) => eval n e p
end
end.
(** The translation to this form is largely the same except that we are
forced to let-bind some things that previously could be nested. *)
(* The same state-plus-continuation monad and operations as in the
   source-to-source translation, redefined for this module. *)
Definition M(A ans:Type) := (A -> nat -> (nat * ans)) -> nat -> (nat * ans).
Definition Ret{A ans:Type}(v:A) : M A ans := fun k n => k v n.
Definition Bind{A B ans:Type}(c:M A ans)(f:A -> M B ans) : M B ans :=
fun k n => c (fun v n' => f v k n') n.
Notation "'ret' x" := (Ret x) (at level 75).
Notation "x <- c ; f" := (Bind c (fun x => f))
(right associativity, at level 84, c at next level).
(* Fresh name: the prefix [s] followed by the counter in binary. *)
Definition freshVar{ans:Type}(s:string) : M var ans :=
fun k n => k (s ++ nat2string n) (1 + n).
(* Pass the current meta-level continuation to [f]. *)
Definition callcc {A ans}(f : (A -> nat -> nat * ans) -> M A ans) :=
fun k n => f k k n.
Notation "'letcc' x 'in' e" := (callcc (fun x => e))
(right associativity, at level 84).
(* Run a state-threaded computation, then feed its result to [k]. *)
Definition run{A ans}(c : nat -> nat * A) : M A ans :=
fun k n => let (n',v) := c n in k v n'.
(* Return [a] as the final answer, discarding the continuation. *)
Definition throw{A ans}(a:ans) : M A ans :=
fun k n => (n,a).
(** Notice that our continuation is expecting us to feed it an [operand] so
it can build a [cexp], but all we have is a [decl]. So we create a
fresh variable [x], and let-bind the declaration to [x], feeding our
continuation [Var_c x] to generate the entire expression. *)
Definition letbind (d:decl) : M operand cexp :=
x <- freshVar "_x" ;
letcc k in
e <- run (k (Var_c x)) ;
throw (Let_c x d e).
(** So other than the let-binding, the translation is largely the same. *)
Fixpoint ET (e:exp) : M operand cexp :=
match e with
| Var_e x => ret Var_c x
| Unit_e => ret Unit_c
| Pair_e e1 e2 =>
v1 <- ET e1 ;
v2 <- ET e2 ;
letbind (Pair_c v1 v2)
| Fst_e e =>
v <- ET e ;
letbind (Fst_c v)
| Snd_e e =>
v <- ET e ;
letbind (Snd_c v)
| App_e e1 e2 =>
v1 <- ET e1 ;
v2 <- ET e2 ;
letcc k in
x <- freshVar "_v" ;
e <- run (k (Var_c x)) ;
c <- letbind (Kont_c x e) ;
throw (App_c v1 v2 c)
| Lam_e x e =>
c <- freshVar "_c" ;
e' <- run (ET e (fun v n => (n,Ret_c (Var_c c) v))) ;
letbind (Lam_c x c e')
end.
(** [cps g] translates the generated expression starting from counter [0].
The final operand is handed to [Exit_c], which ends the program. *)
Definition cps (g:G) : cexp :=
snd (ET (gen g) (fun v n => (n, Exit_c v)) 0).
Eval compute in cps identity.
Eval compute in cps apply.
Eval compute in cps twice.
Eval compute in cps pair.
Eval compute in cps compose.
Eval compute in eval 100 (cps ((x ==> x) @ ())) nil.
Eval compute in eval 100 (cps (twice @ identity @ ())) nil.
End CS153.
(** Exercises:
- Add Callcc_e and Throw_e to the source language and extend the translations
to handle these things. You shouldn't need to modify the target languages.
- Add booleans, a conditional ([If_e]) expression, and short-circuited [And_e]
and [Or_e] expressions to the language and extend one of the translations to
cover these new constructs. Be careful! If you're not careful, you'll end up
making a copy of the stack. And of course by short-circuited, I mean that
you should only end up evaluating an expression if it's needed. You get
more credit for avoiding unnecessary intermediate continuations.
- Set up typing rules (after the fact) for expressions (e.g., [G |- e : t]).
Then set up a type translation [T : type -> type] which captures the invariants
of the CPS translation. In particular, if [|- e : t], then
we should have that [|- cps e : T t]. I don't expect you to prove this
(because it's a pain in the ass with the fresh variable generation) but
I do expect you to be able to set up the definitions and theorems and reason
through the argument on paper informally.
- Extra credit: formally prove the cps translation respects your type translation.
- Super Extra credit: formally prove that the cps translation respects the semantics
of the original program.
*)