(** * CPS Conversion *)
Require Import Eqdep.
Require Import String.
Require Import List.
Require Import Omega.
Require Import Recdef.
Require Import Arith.
Require Import Div2.
Require Import Program.
Set Implicit Arguments.
Unset Automatic Introduction.

(** We need to be able to generate fresh variable names, so these
    functions provide some help doing "gensym" for us. *)
Local Open Scope string_scope.

(** [digit2string n] peels [2] off [n] until it reaches [0] or [1], so it
    returns ["0"] when [n] is even and ["1"] when [n] is odd; that is, the
    low-order binary digit of [n]. *)
Fixpoint digit2string (n:nat) : string :=
  match n with
  | 0 => "0"
  | S 0 => "1"
  | S (S n) => digit2string n
  end.

(** Here, I'm using the [Program] and [measure] mechanism to prove
    that [nat2string] terminates.  The [measure n] indicates that
    we mean for [n] to be strictly smaller on each recursive call.
    [Program] generates a proof obligation for us to show that this is
    the case for each recursive call, which we then knock off in tactic
    mode.

    [nat2string n] renders [n] in binary: the digits of [div2 n]
    followed by the low-order digit of [n].  Distinct numbers therefore
    get distinct strings, which is what fresh-name generation needs. *)
Program Fixpoint nat2string (n:nat) { measure n }: string :=
  match n with
  | 0 => "0"
  | S 0 => "1"
  | _ => nat2string (div2 n) ++ digit2string n
  end.
(* Termination obligation for the recursive call on [div2 n].
   NOTE(review): presumably [lt_div2] leaves a positivity side condition
   on [n] that [omega] closes from the context -- confirm.  The proof ends
   with [Defined] rather than [Qed], presumably so that [nat2string] can
   still reduce under [Eval compute]. *)
Next Obligation. apply lt_div2. omega. Defined.

(** This module defines a little core lambda calculus with unit and
    pairs. *)
Module EXP.
  (** Object-level variables are plain strings. *)
  Definition var := string.

  (** Abstract syntax: variables, one-argument lambdas, application, the
      unit value, pairs, and the two pair projections. *)
  Inductive exp : Type :=
  | Var_e : var -> exp
  | Lam_e : var -> exp -> exp
  | App_e : exp -> exp -> exp
  | Unit_e : exp
  | Pair_e : exp -> exp -> exp
  | Fst_e : exp -> exp
  | Snd_e : exp -> exp.

  (** Here's a nice trick for making the construction of [exp] values
      a little more readable.  We're going to use higher-order abstract
      syntax so that we don't have to write [Lam_e x (Var_e x)] for
      the identity, but rather something like [lam (fun x => x)].
      The idea is that [lam] will generate a fresh string to use for the
      variable, say ["v"], and then instantiate the meta-level function
      with [Var_e "v"].  But of course, we need to operate under a
      state monad to be able to generate fresh variables. *)

  (** A [G] is an expression generator: given the current value of the
      name counter it returns the updated counter paired with the
      expression it built. *)
  Definition G := nat -> (nat * exp).
(** [gen g] runs the generator [g] with a new counter initialized to [0], then throws away the state at the end. *) Definition gen (g:G) := snd (g 0). (** The rest of these definitions provide generators for all of the abstract syntax constructors, plus some notation to make it all a lot more readable. *) Definition var_g (x:string) : G := fun n => (n, Var_e x). (** Here is where we are leveraging higher order abstract syntax. We generate the fresh name using our current state, updating the state from [n] to [1 + n]. We then invoke the meta-level function [f] on an appropriate variable generator to get out the body of the lambda, specialized to our fresh variable which we then bind using the [Lam_e] constructor. *) Definition lam_g (f: G -> G) := fun n => let x := "x" ++ digit2string n in let (n', e) := f (var_g x) (1 + n) in (n', Lam_e x e). (** Now this notation makes it very Haskell-like to write our object- level functions. *) Notation "x ==> e" := (lam_g (fun x => e)) (right associativity, at level 70). (** The rest of the definitions are straightforward. *) Definition app_g (g1 g2:G) := fun n => let (n1,e1) := g1 n in let (n2,e2) := g2 n1 in (n2, App_e e1 e2). Infix "@" := app_g (left associativity, at level 61). Definition unit_g : G := fun n => (n, Unit_e). Notation "'()'" := unit_g. Definition pair_g (g1 g2:G) : G := fun n => let (n1,e1) := g1 n in let (n2,e2) := g2 n1 in (n2,Pair_e e1 e2). Notation "[ x ; y ]" := (pair_g x y). Definition fst_g (g:G) := fun n => let (n1,e) := g n in (n1,Fst_e e). Definition snd_g (g:G) := fun n => let (n1,e) := g n in (n1,Snd_e e). Notation "'#1' x" := (fst_g x) (at level 65). Notation "'#2' x" := (snd_g x) (at level 65). (** Some examples using our embedded syntax. *) Definition identity := (x ==> x). Definition apply := (f ==> x ==> f @ x). Definition twice := (f ==> x ==> f @ f @ x). Definition compose := (f ==> g ==> x ==> f @ g @ x). Definition pair := (f ==> g ==> [f @ () ; g @ ()]). 
Definition eta_pair := (x ==> [#1 x ; #2 x]). (** We have to run the generators using [gen] to get actual [exp] values. *) Eval compute in gen identity. Eval compute in gen apply. Eval compute in gen twice. Eval compute in gen pair. Eval compute in gen compose. Eval compute in gen eta_pair. End EXP. (** This module defines CPS conversion as a source-to-source transformation. This is the kind of thing you will see in the literature (c.f., Danvy and Filinski) but isn't really what most compilers do as we'll see below. *) Module SOURCE_TO_SOURCE. Import EXP. (** Our translation is complicated by the fact that we need to generate fresh names (for continuations) as we do the CPS conversion. In addition, CPS conversion is itself written in a continuation-passing style! So we are actually building a combined monad here that has both state (for fresh name generation) and a continuation. Note that in Haskell, we would probably use a monad transformer to build this. But you have to be careful: you get two very different behaviors if you compose the state and continuation monads versus the continuation and state monads. Anyway, [M] is parameterized by the kind of value that the local computation returns ([A]) as well as the type that the entire computation returns [ans]. It takes a continuation [k] and a number [n] (for state) as arguments and returns a new state and [ans]. Conceptually, we run the local computation, threading the state through to get an [A] value, and then feed the [A] value, along with the threaded state to [k] to get out the final state and answer. *) Definition K(A ans:Type) := A -> nat -> (nat * ans). Definition M(A ans:Type) := K A ans -> nat -> (nat * ans). Definition Ret(A ans:Type)(v:A) : M A ans := fun k n => k v n. Definition Bind(A B ans:Type)(c:M A ans)(f:A -> M B ans) : M B ans := fun k n => c (fun v n' => f v k n') n. Notation "'ret' x" := (Ret x) (at level 75). 
Notation "x <- c ; f" := (Bind c (fun x => f)) (right associativity, at level 84, c at next level). (** Generate a fresh variable *) Definition freshVar {ans} (x:string) : M var ans := fun k n => k (x ++ nat2string n) (1 + n). (** At key points in the CPS translation, we need to get our hands on the meta-level continuation. But this is exactly what [callcc] is supposed to do for us and indeed, that's what happens here. *) Definition callcc {A ans}(f : K A ans -> M A ans) := fun k n => f k k n. (** Some notation will make this a little easier to work with. *) Notation "'letcc' x 'in' e" := (callcc (fun x => e)) (right associativity, at level 84). (** At other points in the CPS translation, we will want to abandon the current meta-level continuation (because we've already captured it elsewhere) and this is what [throw] lets us do. We just immediately return an answer and ignore the continuation. *) Definition throw{A ans}(a:ans) : M A ans := fun k n => (n,a). (** Although we don't use it here, we could also provide a facility to capture the current state, or change the current state. This is a "call-with-current-state" operation. *) Definition callcs {A ans}(f : nat -> M A ans) := fun n k => f n k n. Notation "'letcs' x 'in' e" := (callcs (fun x => e)) (right associativity, at level 84). (** Finally, the [run] operation takes a continuation [c] (of type [K B A]) which has already been given its local return value of type [B], and runs the continuation to get out the [A]. But it must thread the state through this computation, so we get out an [M] computation. We'll see below how this, in conjunction with [callcc], are used to reify a meta-level continuation into an object-level one. *) Definition run{A ans}(c : nat -> nat * A) : M A ans := fun k n => let (n',v) := c n in k v n'. (* Now that we have set up all of our monadic operations, the translation is pretty straightforward except for the [App_e] and [Lam_e] cases. 
For [App_e e1 e2] we would like to pass in [e2] followed by a
     continuation which computes the rest of the program, once [e1] has
     calculated its result.  We start by capturing the current meta-level
     continuation using [letcc].  Ignoring the state, this meta-level
     continuation [k] is essentially a Coq function from [exp]s to [exp]s
     (i.e., a context!).  We can turn it into an [exp] by generating a
     fresh variable [x] and by applying the meta-level continuation to
     [Var_e x].  But we have to make sure we thread the state through,
     hence the use of [run] to do the application.  Finally, because we've
     already accounted for the rest of the program by reifying the current
     continuation, we [throw] the application of [e1] to [e2] and the
     continuation to avoid returning the whole thing again.  In short,
     when we reified the meta-level continuation, we made a copy of the
     "stack".  But we only need one copy, so we can throw away the other
     one.

     For [Lam_e x e], we need to transform the function to take an extra
     fresh argument [c].  Instead of invoking the current continuation,
     when translating the body [e], we want it to invoke [c].  So again,
     we [run] the translation, but feed it an initial meta-level
     continuation which applies the object-level continuation.  Finally,
     we return the transformed lambda to the current meta-level
     continuation. *)

  (** [ET e] is the CPS translation of [e] as a computation in [M]: the
      value it passes to its continuation is the (small) expression that
      stands for the result of [e]. *)
  Fixpoint ET (e:exp) : M exp exp :=
    match e with
    | Var_e y => ret (Var_e y)
    | Unit_e => ret Unit_e
    | Pair_e left right =>
        l <- ET left ;
        r <- ET right ;
        ret (Pair_e l r)
    | Fst_e arg =>
        a <- ET arg ;
        ret (Fst_e a)
    | Snd_e arg =>
        a <- ET arg ;
        ret (Snd_e a)
    | App_e fn arg =>
        f <- ET fn ;
        a <- ET arg ;
        (* Reify the rest of the translation as an object-level lambda
           over a fresh result variable, then abandon the meta-level
           continuation. *)
        letcc rest in
        r <- freshVar "_v" ;
        after <- run (rest (Var_e r)) ;
        throw (App_e (App_e f a) (Lam_e r after))
    | Lam_e y body =>
        kv <- freshVar "_k" ;
        (* The body returns by applying the continuation variable. *)
        body' <- run (ET body (fun res st => (st, App_e (Var_e kv) res))) ;
        ret (Lam_e y (Lam_e kv body'))
    end.

  (** We kick off the entire translation by calling [ET] and passing it an initial continuation that corresponds to the identity function, and an initial state.
 *)
  (* [cps g] generates the source term from [g], translates it with the
     identity continuation starting from counter [0], and drops the final
     counter. *)
  Definition cps (g:G) : exp := snd (ET (gen g) (fun v n => (n,v)) 0).

  (* Show the translation of each of the example terms. *)
  Eval compute in cps identity.
  Eval compute in cps apply.
  Eval compute in cps twice.
  Eval compute in cps pair.
  Eval compute in cps compose.
  Eval compute in cps eta_pair.
End SOURCE_TO_SOURCE.

(** This module defines a more refined intermediate language that
    captures the real structure of CPS translation.  It is much closer
    to what a compiler uses where we make a distinction between "small"
    values (called [operands] here) versus "large" values that we want
    to bind to variables.  This form of CPS intermediate language
    corresponds pretty closely to what was used in the SML/NJ compiler
    and facilitates lots of optimization.

    Notice the structure: none of the operations supports nested
    expressions.  They only manipulate operands (which are either
    variables or constants).  Everything else has to be bound to a
    variable using a "let".

    Note that a [cexp] is a sequence of let-declarations that is
    terminated in one of three ways: (1) a function call which passes in
    the argument and the continuation, (2) a function "return" which just
    invokes a continuation with the result, and (3) program termination.

    Finally, note that you could always embed this language into the
    original one, using an application of a lambda to encode the let's. *)
Module CS153.
  Import EXP.

  (* Small values: a variable or the unit constant. *)
  Inductive operand :=
  | Var_c : var -> operand
  | Unit_c : operand.

  (* [decl] lists the things that can be let-bound.  [Lam_c x k e] is a
     function with parameter [x], continuation parameter [k] and body [e];
     [Kont_c x e] is a continuation with parameter [x] and body [e]; the
     rest build or project pairs of operands.
     [cexp] lists the terminated let-sequences.  [App_c f a k] calls [f]
     on argument [a] with continuation [k]; [Ret_c k v] invokes the
     continuation [k] on [v]; [Exit_c v] ends the program with [v];
     [Let_c x d e] binds [x] to [d] in [e]. *)
  Inductive decl :=
  | Lam_c : var -> var -> cexp -> decl
  | Kont_c : var -> cexp -> decl
  | Pair_c : operand -> operand -> decl
  | Fst_c : operand -> decl
  | Snd_c : operand -> decl
  with cexp :=
  | App_c : operand -> operand -> operand -> cexp
  | Ret_c : operand -> operand -> cexp
  | Exit_c : operand -> cexp
  | Let_c : var -> decl -> cexp -> cexp.

  (** This will make it a little easier to read the output of the CPS
      translation. *)
  Notation "'Let' x := e1 'in' e2" := (Let_c x e1 e2)
    (right associativity, at level 70).
  Notation "$ x" := (Var_c x) (at level 0).

  (* Make both type arguments of [inl] implicit. *)
  Implicit Arguments inl [A B].
Implicit Arguments inr [A B]. (** I want to show you the operational semantics for this language and point out that it no longer requires a stack. In particular, notice that I'm able to define the "step" function without appealing to any form of recursion. *) Definition env(A:Type) := list (var * A). Fixpoint lookup A (env:env A) (x:var) : option A := match env with | nil => None | (y,v)::env' => if string_dec x y then Some v else lookup env' x end. Inductive value : Type := | Unit_v : value | Pair_v : value -> value -> value | Lam_v : var -> var -> cexp -> env value -> value | Kont_v : var -> cexp -> env value -> value. Inductive answer : Type := | Value : value -> answer | TypeError : answer. Definition evalop (op:operand) (p: env value) : answer := match op with | Var_c x => match lookup p x with | Some v => Value v | None => TypeError end | Unit_c => Value Unit_v end. Definition evaldecl (d:decl) (p:env value) : answer := match d with | Lam_c x k e => Value (Lam_v x k e p) | Kont_c x e => Value (Kont_v x e p) | Pair_c op1 op2 => match evalop op1 p, evalop op2 p with | Value v1, Value v2 => Value (Pair_v v1 v2) | _, _ => TypeError end | Fst_c op => match evalop op p with | Value (Pair_v v1 v2) => Value v1 | _ => TypeError end | Snd_c op => match evalop op p with | Value (Pair_v v1 v2) => Value v2 | _ => TypeError end end. Definition step_fn (e : cexp) (p : env value) : answer + (cexp * env value) := match e with | Exit_c op => inl (evalop op p) | App_c op1 op2 op3 => match evalop op1 p, evalop op2 p, evalop op3 p with | Value (Lam_v x k e p'), Value v2, Value v3 => inr (e, (k,v3)::(x,v2)::p') | _, _, _ => inl TypeError end | Ret_c op1 op2 => match evalop op1 p, evalop op2 p with | Value (Kont_v x e p'), Value v => inr (e, (x,v)::p') | _, _ => inl TypeError end | Let_c x d e => match evaldecl d p with | Value v => inr (e, (x,v)::p) | _ => inl TypeError end end. 
Fixpoint eval(n:nat)(e:cexp)(p:env value) : answer := match n with | 0 => TypeError | S n => match step_fn e p with | inl a => a | inr (e,p) => eval n e p end end. (** The translation to this form is largely the same except that we are forced to let-bind some things that previously could be nested. *) Definition M(A ans:Type) := (A -> nat -> (nat * ans)) -> nat -> (nat * ans). Definition Ret{A ans:Type}(v:A) : M A ans := fun k n => k v n. Definition Bind{A B ans:Type}(c:M A ans)(f:A -> M B ans) : M B ans := fun k n => c (fun v n' => f v k n') n. Notation "'ret' x" := (Ret x) (at level 75). Notation "x <- c ; f" := (Bind c (fun x => f)) (right associativity, at level 84, c at next level). Definition freshVar{ans:Type}(s:string) : M var ans := fun k n => k (s ++ nat2string n) (1 + n). Definition callcc {A ans}(f : (A -> nat -> nat * ans) -> M A ans) := fun k n => f k k n. Notation "'letcc' x 'in' e" := (callcc (fun x => e)) (right associativity, at level 84). Definition run{A ans}(c : nat -> nat * A) : M A ans := fun k n => let (n',v) := c n in k v n'. Definition throw{A ans}(a:ans) : M A ans := fun k n => (n,a). (** Notice that our continuation is expecting us to feed it an [operand] so it can build a [cexp], but all we have is a [decl]. So we create a fresh variable [x], and let-bind the declaration to [x], feeding our continuation [Var_c x] to generate the entire expression. *) Definition letbind (d:decl) : M operand cexp := x <- freshVar "_x" ; letcc k in e <- run (k (Var_c x)) ; throw (Let_c x d e). (** So other than the let-binding, the translation is largely the same. 
*) Fixpoint ET (e:exp) : M operand cexp := match e with | Var_e x => ret Var_c x | Unit_e => ret Unit_c | Pair_e e1 e2 => v1 <- ET e1 ; v2 <- ET e2 ; letbind (Pair_c v1 v2) | Fst_e e => v <- ET e ; letbind (Fst_c v) | Snd_e e => v <- ET e ; letbind (Snd_c v) | App_e e1 e2 => v1 <- ET e1 ; v2 <- ET e2 ; letcc k in x <- freshVar "_v" ; e <- run (k (Var_c x)) ; c <- letbind (Kont_c x e) ; throw (App_c v1 v2 c) | Lam_e x e => c <- freshVar "_c" ; e' <- run (ET e (fun v n => (n,Ret_c (Var_c c) v))) ; letbind (Lam_c x c e') end. Definition cps (g:G) : cexp := snd (ET (gen g) (fun v n => (n, Exit_c v)) 0). Eval compute in cps identity. Eval compute in cps apply. Eval compute in cps twice. Eval compute in cps pair. Eval compute in cps compose. Eval compute in eval 100 (cps ((x ==> x) @ ())) nil. Eval compute in eval 100 (cps (twice @ identity @ ())) nil. End CS153. (** Exercises: - Add Callcc_e and Throw_e to the source language and extend the translations to handle these things. You shouldn't need to modify the target languages. - Add booleans, a conditional ([If_e]) expression, and short-circuited [And_e] and [Or_e] expressions to the language and extend one of the translations to cover these new constructs. Be careful! If you're not careful, you'll end up making a copy of the stack. And of course by short-circuited, I mean that you should only end up evaluating an expression if it's needed. You get more credit for avoiding unnecessary intermediate continuations. - Set up typing rules (after the fact) for expressions (e.g., [G |- e ; t]). Then set up a type translation [T : type -> type] which captures the invariants of the CPS translation. In particular, if [|- e : t], then we should have that [|- cps e : T t]. I don't expect you to prove this (because it's a pain in the ass with the fresh variable generation) but I do expect you be able to set up the definitions and theorems and reason through the argument on paper informally. 
    - Extra credit: formally prove that the CPS translation respects your
      type translation.

    - Super extra credit: formally prove that the CPS translation
      respects the semantics of the original program. *)