# Inference: Simple Type Inference.

By G. Morrisett, strongly inspired by C. McBride's JFP 13(6) article.
Require Import Eqdep.
Require Import String.
Require Import List.
Require Import Omega.
Require Import Recdef.
Set Implicit Arguments.
Unset Automatic Introduction.
Axiom proof_irr : forall (P:Prop) (x y:P), x = y.

Here, we're going to develop a type-inference algorithm for a small, simply-typed language and prove its correctness. As before, we will try to use the "Chlipala" style of proof development to keep things robust to extension and change. We will also show how dependent types can become a crucial tool for constructing the termination argument of a function, and how Coq allows us to step outside of simple structural induction when we need to. In particular, we will see that writing unification, the key subroutine in type inference, requires a much more delicate treatment than one might expect.

# Abstract Syntax

We'll represent type variables using natural numbers. This will give us an easy way to generate fresh type variables when we need them.
Definition tvar := nat.

For this little language, types will only include type variables, arrow types, and [nat].
Definition tvar_eq_dec (t1 t2:tvar) : {t1=t2} + {t1<>t2} := eq_nat_dec t1 t2.
Inductive type :=
| Tvar_t : tvar -> type
| Arrow_t : type -> type -> type
| Nat_t : type.

Our terms will have variables, numbers, lambdas, and applications. We'll use strings to represent program variables.
Definition var := string.
Inductive exp : Set :=
| Var_e : var -> exp
| Num_e : nat -> exp
| Lam_e : var -> exp -> exp
| App_e : exp -> exp -> exp.

A context associates variables with types. We'll use it to define our declarative typing relation.
Definition ctxt := list(var * type).

Now we can define our typing relation.
Reserved Notation "G |-- e ; t" (at level 80).

Inductive hasType : ctxt -> exp -> type -> Prop :=
| Var_ht : forall G x t, In (x,t) G -> G |-- Var_e x ; t
| Num_ht : forall G n, G |-- Num_e n ; Nat_t
| Lam_ht : forall G x e t1 t2,
((x,t1)::G) |-- e ; t2 -> G |-- Lam_e x e ; Arrow_t t1 t2
| App_ht : forall G e1 e2 t1 t2,
G |-- e1 ; (Arrow_t t1 t2) -> G |-- e2 ; t1 -> G |-- (App_e e1 e2) ; t2

where "G |-- e ; t" := (hasType G e t).
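As a sanity check on these rules, the following sketch derives that applying the identity function to a number has type Nat_t. It assumes the definitions above are loaded. The example name is ours, and string literals carry %string because this file never opens string_scope.

```coq
(* nil |-- (fun x => x) 3 ; Nat_t *)
Example id_app_has_type :
  nil |-- App_e (Lam_e "x"%string (Var_e "x"%string)) (Num_e 3) ; Nat_t.
Proof.
  (* Choose Nat_t as the argument type t1 up front, so no evars remain. *)
  apply App_ht with (t1 := Nat_t).
  (* First premise: the lambda has type Arrow_t Nat_t Nat_t. *)
  apply Lam_ht. apply Var_ht. simpl. left. reflexivity.
  (* Second premise: nil |-- Num_e 3 ; Nat_t. *)
  apply Num_ht.
Qed.
```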

Equality on types is decidable, and we can convince Coq to generate the code for us with decide equality.
Definition eq_type (t1 t2:type) : {t1 = t2} + {t1 <> t2}.
pose tvar_eq_dec. decide equality.
Defined.

A type context is just a list of type variables. We will use a tctxt to track which type variables are in scope, and thus which types are well-formed.
Definition tctxt := list tvar.

We'll now define a simplification tactic which takes care of some busywork that we would otherwise have to do in all of our proofs. I've broken this into a simple tactic s and a looping tactic mysimp which repeats s, does some simplification, and tries auto with arith. In general, I've chosen to break apart basic constructors such as /\, \/, and *, and to decompose comparisons such as tvar_eq_dec. I've also put in some tactics to simplify equalities on pairs, and equalities on options.
Ltac s :=
match goal with
| [ H : _ /\ _ |- _] => destruct H
| [ H : _ \/ _ |- _] => destruct H
| [ |- context[tvar_eq_dec ?a ?b] ] => destruct (tvar_eq_dec a b) ; subst ; try congruence
| [ |- context[eq_nat_dec ?a ?b] ] => destruct (eq_nat_dec a b) ; subst ; try congruence
| [ x : (tvar * type)%type |- _ ] => let t := fresh "t" in destruct x as [x t]
| [ x : (var * type)%type |- _ ] => let t := fresh "t" in destruct x as [x t]
| [ H : (_,_) = (_,_) |- _] => inversion H ; clear H
| [ H : Some _ = Some _ |- _] => inversion H ; clear H
| [ H : Some _ = None |- _] => congruence
| [ H : None = Some _ |- _] => congruence
| [ |- _ /\ _] => split
| [ H : ex _ |- _] => destruct H
| [ H : context[string_dec ?a ?b] |- _] => destruct (string_dec a b) ; subst ; try congruence
end.

Ltac mysimp := (repeat (simpl; s)) ; simpl; auto with arith.

Membership in a type context.
Fixpoint Mem (D:tctxt) (x:tvar) : Prop :=
match D with
| nil => False
| h::t => if tvar_eq_dec h x then True else Mem t x
end.

Membership is decidable. Note that here we use the refine tactic to construct the skeleton of the function, leaving the proofs to be filled in. We then chain refine with an application of the mysimp tactic, which discharges all of the remaining proof obligations.
Definition member (D:tctxt) (x:tvar) : {Mem D x} + {~Mem D x}.
refine (fix member D x : {Mem D x} + {~Mem D x} :=
match D with
| nil => right _ _
| h::t => match tvar_eq_dec h x with
| left Heq => left _ _
| right Hneq =>
match member t x with
| left H => left _ _
| right H => right _ _
end
end
end) ; mysimp.
Defined.
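To watch member compute, the sketch below projects its sumbool result to a boolean and checks it by reflexivity. It assumes the definitions above are loaded, and the example names are ours. The checks rely on member (which ends in Defined above) and eq_nat_dec being transparent, so that they reduce.

```coq
(* 1 is found after skipping 0. *)
Example member_finds :
  (if member (0::1::nil) 1 then true else false) = true.
Proof. reflexivity. Qed.

(* 2 does not appear in the context. *)
Example member_misses :
  (if member (0::1::nil) 2 then true else false) = false.
Proof. reflexivity. Qed.
```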

Well-formed types.
Fixpoint WfType (D:tctxt) (t:type) : Prop :=
match t with
| Tvar_t x => Mem D x
| Arrow_t t1 t2 => WfType D t1 /\ WfType D t2
| Nat_t => True
end.
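Two small checks of WfType follow, assuming the definitions above are loaded (the example names are ours). compute unfolds WfType and Mem down to True and False, after which tauto closes the goal.

```coq
(* Tvar_t 0 is in scope in (0::nil), so the arrow type is well-formed. *)
Example wf_arrow : WfType (0::nil) (Arrow_t (Tvar_t 0) Nat_t).
Proof. compute. tauto. Qed.

(* With nothing in scope, a type variable is not well-formed. *)
Example not_wf_unbound : ~ WfType nil (Tvar_t 0).
Proof. compute. tauto. Qed.
```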

A predicate for checking whether a variable x occurs free in a type t.
Fixpoint Occurs (x:tvar) (t:type) : Prop :=
match t with
| Tvar_t y => if tvar_eq_dec x y then True else False
| Arrow_t t1 t2 => Occurs x t1 \/ Occurs x t2
| Nat_t => False
end.

Occurs is decidable. Once again, we take advantage of the refine tactic.
Definition occurs : forall (x:tvar) (t:type), {Occurs x t} + {~Occurs x t}.
refine (fix occurs x (t:type) {struct t} : {Occurs x t} + {~Occurs x t} :=
match t return {Occurs x t} + {~Occurs x t} with
| Tvar_t y => match tvar_eq_dec x y with
| left Heq => left _ _
| right Hneq => right _ _
end
| Nat_t => right _ _
| Arrow_t t1 t2 =>
match occurs x t1 with
| left Hocc1 => left _ _
| right Hnocc1 =>
match occurs x t2 with
| left Hocc2 => left _ _
| right Hnocc2 => right _ _
end
end
end) ; mysimp ; firstorder.
Defined.
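A quick sketch relating the predicate and its decision procedure, assuming the definitions above are loaded (the example names are ours). The first example proves an occurrence, and the second checks by computation that occurs reports no occurrence.

```coq
(* 0 occurs in the result position of the arrow. *)
Example occurs_right : Occurs 0 (Arrow_t Nat_t (Tvar_t 0)).
Proof. compute. tauto. Qed.

(* The decision procedure agrees that 1 does not occur. *)
Example occurs_decides :
  (if occurs 1 (Arrow_t Nat_t (Tvar_t 0)) then true else false) = false.
Proof. reflexivity. Qed.
```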

Substitute t1 for x in t2.
Fixpoint sub (t1:type) (x:tvar) (t2:type) : type :=
match t2 with
| Tvar_t y => if tvar_eq_dec x y then t1 else Tvar_t y
| Arrow_t ta tb => Arrow_t (sub t1 x ta) (sub t1 x tb)
| Nat_t => Nat_t
end.
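Because the argument order of sub is easy to misread, here is a small illustrative check, assuming the definitions above are loaded (the example name is ours).

```coq
(* sub t1 x t2 replaces each Tvar_t x in t2 by t1; other variables are untouched. *)
Example sub_replaces_var :
  sub Nat_t 0 (Arrow_t (Tvar_t 0) (Tvar_t 1)) = Arrow_t Nat_t (Tvar_t 1).
Proof. reflexivity. Qed.
```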

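Next we show that substitution eliminates a variable. Write D - x for the context remove x D, defined just below. The lemma SubstRemove says that if t is well-formed with respect to D, x is in D, and u is well-formed with respect to D - x, then substituting u for x in t yields a type that is well-formed with respect to D - x. In short, we eliminate the variable x.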
Fixpoint remove (x:tvar) (D:tctxt) : tctxt :=
match D with
| nil => nil
| y::rest => if tvar_eq_dec y x then (remove x rest) else y::(remove x rest)
end.
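One detail worth seeing concretely is that remove deletes every copy of the variable, not just the first. The following is a sketch assuming the definitions above are loaded (the example name is ours).

```coq
(* Both occurrences of 1 disappear; the order of the rest is preserved. *)
Example remove_all_copies : remove 1 (0::1::2::1::nil) = 0::2::nil.
Proof. reflexivity. Qed.
```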

Lemma SubstRemove' : forall t x D, Mem D t -> x <> t -> Mem (remove x D) t.
induction D ; mysimp.
Qed. Hint Resolve SubstRemove'.

Lemma SubstRemove :
forall t x D, WfType D t -> Mem D x -> forall u, WfType (remove x D) u ->
WfType (remove x D) (sub u x t).
induction t ; simpl ; intros ; mysimp.
Qed. Hint Resolve SubstRemove.

A substitution maps type variables to types. We represent it as an association list of bindings.
Definition substitution := list(tvar*type).

The support is simply the domain of the substitution.
Definition support(s:substitution) : list tvar := List.map (@fst tvar type) s.

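Well-formed substitutions. Each binding (x,t) requires x to be in D, and requires both t and the rest of the substitution to be well-formed with respect to D - x. This rules out binding the same type variable twice, and ensures that x appears neither in t nor in any later binding.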
Fixpoint WfSubst (D:tctxt) (s:substitution) : Prop :=
match s with
| nil => True
| (x,t)::rest => Mem D x /\ WfType (remove x D) t /\ WfSubst (remove x D) rest
end.
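The following sketch shows both sides of the definition, assuming the definitions above are loaded (the example names are ours). A chain of bindings is accepted when each type only mentions variables still in scope. Binding a variable to a type that mentions that same variable is rejected, which is exactly what the occurs check will guarantee.

```coq
(* Bind 0 to Tvar_t 1 (1 is still in scope), then 1 to Nat_t. *)
Example wf_chain : WfSubst (0::1::nil) ((0, Tvar_t 1)::(1, Nat_t)::nil).
Proof. compute. tauto. Qed.

(* After removing 0, Tvar_t 0 is no longer well-formed. *)
Example not_wf_cycle : ~ WfSubst (0::nil) ((0, Tvar_t 0)::nil).
Proof. compute. tauto. Qed.
```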

Apply a substitution to a type, performing its bindings in order from left to right.
Fixpoint substs (s:substitution) (t:type) : type :=
match s with
| nil => t
| (x,u)::rest => substs rest (sub u x t)
end.
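Since substs applies its bindings one after another, order matters. The following is a sketch assuming the definitions above are loaded (the example names are ours).

```coq
(* 0 becomes Tvar_t 1, and the later binding then turns it into Nat_t. *)
Example substs_in_order :
  substs ((0, Tvar_t 1)::(1, Nat_t)::nil) (Tvar_t 0) = Nat_t.
Proof. reflexivity. Qed.

(* Reversed: the binding for 1 runs first, when no Tvar_t 1 exists yet. *)
Example substs_order_matters :
  substs ((1, Nat_t)::(0, Tvar_t 1)::nil) (Tvar_t 0) = Tvar_t 1.
Proof. reflexivity. Qed.
```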

Remove a list of type variables from a type context.
Fixpoint minus (D:tctxt) (xs:list tvar) :=
match xs with
| nil => D
| x::xs => remove x (minus D xs)
end.
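Putting support and minus together gives the context that remains after a substitution is applied. The following is a sketch assuming the definitions above are loaded (the example names are ours).

```coq
(* The support of a substitution is the list of its bound variables. *)
Example support_example :
  support ((0, Nat_t)::(2, Tvar_t 1)::nil) = 0::2::nil.
Proof. reflexivity. Qed.

(* Removing 0 and 2 from (0::1::2::nil) leaves only 1. *)
Example minus_example : minus (0::1::2::nil) (2::0::nil) = 1::nil.
Proof. reflexivity. Qed.
```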

Lemma RemoveComm : forall x y D, remove x (remove y D) = remove y (remove x D).
induction D ; mysimp.
Qed. Hint Immediate RemoveComm.

Lemma MinusRemove : forall D2 D1 x, minus (remove x D1) D2 = remove x (minus D1 D2).
induction D2 ; mysimp ; intros ; rewrite IHD2 ; auto.
Qed.

If s is a substitution and t a type that are well-formed with respect to D, then applying s to t yields a type well-formed with respect to D - support s.
Lemma SubstsRemove :
forall s D, WfSubst D s -> forall t,WfType D t -> WfType (minus D (support s)) (substs s t).
Proof.
induction s ; mysimp ; intros ; mysimp.
generalize (IHs (remove a D)) ; rewrite MinusRemove ; intros ; apply H3 ; auto.
Qed. Hint Resolve SubstsRemove.

Lemma MinusApp(D:tctxt) s t t0 :
minus D (support (s ++ (t,t0)::nil)) = remove t (minus D (support s)).
Proof.
induction s ; mysimp ; intros ; mysimp ; rewrite IHs ; auto.
Qed.

Equations for multi-substitution over types.
Lemma SubstArrow (s:substitution) (t1 t2:type) :
substs s (Arrow_t t1 t2) = Arrow_t (substs s t1) (substs s t2).
Proof.
induction s ; mysimp.
Qed.

Lemma SubstNat (s:substitution) : substs s Nat_t = Nat_t.
induction s ; mysimp.
Qed.

Lemma SubstEnd s x u t : substs (s ++ (x,u)::nil) t = sub u x (substs s t).
induction s ; mysimp.
Qed.

Lemma SubstAppend s2 s1 t : substs (s1 ++ s2) t = substs s2 (substs s1 t).
induction s2 ; intros ; simpl. rewrite <- app_nil_end ; auto.
assert (s1 ++ a :: s2 = (s1 ++ (a::nil)) ++ s2).
rewrite app_ass ; auto. rewrite H. destruct a. rewrite (IHs2 (s1 ++ (t0,t1)::nil)).
rewrite <- SubstEnd. auto.
Qed.

Lemma SubstsNilId(t:type) : substs nil t = t.
induction t ; auto.
Qed.

# Unification

We would like to write a unification procedure that takes two types t1 and t2 and produces an (optional) substitution s such that when we apply s to t1 and t2, it equates them. We might try to write the following:

Fixpoint unify (t1 t2:type) : option substitution :=
if eq_type t1 t2 then (Some nil) else
match t1, t2 with
| Tvar_t x, t2 =>
if occurs x t2 then None else Some ((x,t2)::nil)
| t1, Tvar_t x =>
if occurs x t1 then None else Some ((x,t1)::nil)
| Arrow_t t11 t12, Arrow_t t21 t22 =>
match unify t11 t21 with
| None => None
| Some s1 =>
match unify (substs s1 t12) (substs s1 t22) with
| None => None
| Some s2 => Some (s1 ++ s2)
end
end
| _, _ => None
end.

Unfortunately, Coq will reject this because the recursive call to unify (substs s1 t12) (substs s1 t22) isn't structurally decreasing. Indeed, after substitution, t12 and t22 could be bigger than what we started with.

So does this algorithm terminate? The answer is yes, but the argument depends crucially on the occurs check. Once we map a type variable x to a type t, x is eliminated from the types we are still unifying, but only if x does not occur in t.

Thus, to argue that unify terminates, informally, we need to show on each recursive call that either the type gets smaller, or else the number of type variables that can occur free in the types we are unifying has gotten smaller.

Unfortunately, Coq doesn't make it easy to write functions that have this kind of termination argument. We're going to take advantage of the Wellfounded library, which allows us to build recursive computations as long as we can show that the recursive calls respect a well-founded relation, that is, one with no infinite descending chains. In our case, we will pick an order on pairs (D,t), where D is a list of type variables and t is a type. The order compares pairs lexicographically: (D1,t1) is below (D2,t2) if D1 is shorter than D2, or if D1 and D2 are the same context and t1 has a smaller height than t2.

I should note that you can sometimes use the Function or Program constructs to convince Coq to do this work for you. However, in my experience reasoning about the generated code is harder. So I prefer to do this kind of thing by hand.
Require Import Relation_Operators.
Require Import Transitive_Closure.
Require Import Wellfounded.Lexicographic_Product.
A tpair is really just a tctxt*type, but the lexicographic ordering provided by the library supports and expects a generalized dependent pair.
Definition tpair := sigT (fun _:tctxt => type).
Definition get_ctxt (Dt : tpair) : tctxt := let (D,_) := Dt in D.
Definition get_type (Dt : tpair) : type := let (_,t) := Dt in t.
Definition make_tpair (D:tctxt) (t:type) : tpair := existT _ D t.

Definition mmax (n m:nat) := if le_lt_dec n m then m else n.

Fixpoint height(t:type) : nat :=
match t with
| Arrow_t ta tb => 1 + mmax (height ta) (height tb)
| _ => 0
end.
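The sketch below computes a couple of heights, assuming the definitions above are loaded (the example names are ours). The second example illustrates the termination worry from earlier: substituting for a variable can make a type taller, so height alone cannot serve as the measure.

```coq
(* 1 + mmax (1 + mmax 0 0) 0 = 2 *)
Example height_nested : height (Arrow_t (Arrow_t Nat_t Nat_t) Nat_t) = 2.
Proof. reflexivity. Qed.

(* Tvar_t 0 has height 0, but its image under the substitution has height 1. *)
Example substs_grows_height :
  height (Tvar_t 0) = 0 /\
  height (substs ((0, Arrow_t Nat_t Nat_t)::nil) (Tvar_t 0)) = 1.
Proof. split; reflexivity. Qed.
```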

Lexicographic ordering on tctxt*type pairs.
Definition tpair_lt : tpair -> tpair -> Prop :=
lexprod tctxt (fun _ => type)
(fun (x y:tctxt) => length x < length y)
(fun (x:tctxt) (t u:type) => height t < height u).

A proof that the ordering is well-founded: This takes advantage of the Wellfounded library which already has results for natural numbers, and lexicographical orderings.
Definition well_founded_tpair_lt : well_founded tpair_lt :=
@wf_lexprod tctxt (fun _:tctxt => type) (fun (x y:tctxt) => length x < length y)
(fun (x:tctxt) (t u:type) => height t < height u)
(well_founded_ltof tctxt (@length tvar))
(fun _ => well_founded_ltof type height).
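Here are two illustrative instances of tpair_lt, one per constructor of the library's lexprod. This is a sketch that assumes the definitions above are loaded, and the example names are ours. With left_lex, the context shrinks, so the pair is smaller even though the type grows from height 0 to height 2. With right_lex, the context is identical and the type has a smaller height.

```coq
(* Context shrinks from two variables to one; the type's height may grow. *)
Example lex_shrink_ctxt :
  tpair_lt (make_tpair (0::nil) (Arrow_t Nat_t (Arrow_t Nat_t Nat_t)))
           (make_tpair (0::1::nil) Nat_t).
Proof. unfold tpair_lt, make_tpair. eapply left_lex. simpl. omega. Qed.

(* Same context; the type's height drops from 1 to 0. *)
Example lex_same_ctxt :
  tpair_lt (make_tpair nil Nat_t) (make_tpair nil (Arrow_t Nat_t Nat_t)).
Proof. unfold tpair_lt, make_tpair. eapply right_lex. simpl. omega. Qed.
```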

One consequence of defining unification in terms of this well-founded ordering is that we need to make sure we are always manipulating well-formed types and substitutions with respect to a type context D, else we won't be able to argue that the number of type variables gets smaller when we do a substitution.

So we begin by defining a well-formed substitution indexed by a type context D: a substitution packaged with a proof that it is well-formed with respect to D. This is what unification ultimately returns, which lets us keep the recursion going.
Definition wf_subst(D:tctxt) := sigT (fun s:substitution => WfSubst D s).

Gluing well-formed substitutions together: first appending a single binding, then appending a whole second substitution.
Lemma WfSubstLast x t (s:substitution) (D:tctxt) : WfSubst D s ->
Mem (minus D (support s)) x -> WfType (remove x (minus D (support s))) t ->
WfSubst D (s ++ (x,t)::nil).
Proof.
induction s ; simpl ; intros ; mysimp.
apply (IHs (remove a D)) ; auto ; rewrite MinusRemove ; auto.
Qed. Hint Resolve WfSubstLast.

Lemma AppCons(A:Type) : forall (s1 s2:list A) x, s1 ++ x::s2 = (s1 ++ x::nil) ++ s2.
intros ; rewrite app_ass ; auto.
Qed.

Lemma WfSubstAppend(D:tctxt)(s2 s1:substitution) :
WfSubst D s1 -> WfSubst (minus D (support s1)) s2 -> WfSubst D (s1 ++ s2).
Proof.
induction s2 ; simpl ; intros. rewrite <- app_nil_end ; auto.
mysimp. rewrite AppCons. apply IHs2. auto. rewrite MinusApp ; auto.
Qed. Hint Resolve WfSubstAppend.

Facts about lengths of typing contexts needed to discharge verification conditions below.
Lemma LenRem : forall x D, ~ Mem D x -> length (remove x D) = length D.
induction D ; auto ; mysimp ; tauto.
Qed.

Lemma LengthRemove : forall x D, Mem D x -> length (remove x D) < length D.
induction D ; simpl ; try tauto ; mysimp ; intros. destruct (member D x) ; auto with arith ;
match goal with [ H : ~Mem D x |- _ ] =>
rewrite (LenRem _ _ H) ; auto with arith
end.
Qed.

Lemma RemoveLte : forall x D, length(remove x D) <= length D.
induction D ; mysimp.
Qed.

Lemma MinusLte : forall D xs, length(minus D xs) <= length D.
induction xs ; mysimp ; pose (RemoveLte a (minus D xs)) ; omega.
Qed.

Lemma MemRemv : forall x y D, x <> y -> Mem (remove x D) y = Mem D y.
induction D ; mysimp.
Qed.

Lemma MemMinus' : forall a x D, Mem (remove x D) a -> x <> a.
induction D ; mysimp.
Qed.

Lemma MemMinus x s D : Mem D x -> WfSubst (remove x D) s -> Mem (minus D (support s)) x.
induction s ; mysimp ; intros ; mysimp.
rewrite <- MinusRemove. eapply IHs. assert (x <> a) ; [ eapply MemMinus' ; auto | idtac ].
eauto. rewrite MemRemv in * ; auto. rewrite RemoveComm ; auto.
Qed. Hint Resolve MemMinus.

This is a critical little lemma that connects the occurs check with well-formedness. It tells us that when x doesn't occur in t but t is well-formed with respect to D, then t is also well-formed with respect to D - x.
Lemma OccursWf x D t : WfType D t -> ~Occurs x t -> WfType (remove x D) t.
induction t ; mysimp ; try tauto.
Qed. Hint Resolve OccursWf.

Lemmas about the height of types.
Lemma HeightLeftLess : forall t1 t2, height t1 < height (Arrow_t t1 t2).
intros ; simpl ; unfold mmax ; destruct (le_lt_dec (height t1) (height t2)) ;
auto with arith.
Qed. Hint Resolve HeightLeftLess.

Lemma HeightRightLess : forall t1 t2, height t2 < height (Arrow_t t1 t2).
intros ; simpl ; unfold mmax ; destruct (le_lt_dec (height t1) (height t2)) ;
auto with arith.
Qed. Hint Resolve HeightRightLess.

This lemma shows that, for well-formed substitutions s1 and s2, if s1 equates two types, then s1++s2 does as well. This means that we can always extend a substitution with additional constraints and it will continue to equate the types.
Lemma SubstExtends :
forall s1 D, WfSubst D s1 ->
forall s2, WfSubst (minus D (support s1)) s2 ->
forall t1 t2, substs s1 t1 = substs s1 t2 ->
substs (s1 ++ s2) t1 = substs (s1 ++ s2) t2.
Proof.
induction s1 ; simpl ; intros ; subst ; auto. mysimp.
eapply IHs1 ; eauto. rewrite MinusRemove ; auto.
Qed.

If the type is smaller and the context is unchanged, then the pair is smaller.
Lemma TpairLtArrowLeft D t11 t12 :
tpair_lt (make_tpair D t11) (make_tpair D (Arrow_t t11 t12)).
Proof.
unfold tpair_lt ; intros. eapply right_lex ; auto.
Qed.

Lemma TpairLtArrowRight D t11 t12 :
tpair_lt (make_tpair D t12) (make_tpair D (Arrow_t t11 t12)).
Proof.
unfold tpair_lt ; intros. eapply right_lex ; auto.
Qed.

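After applying a well-formed substitution s, the pair gets smaller: either s is empty and we are left with the smaller type t2, or s eliminates at least one variable, so the context shrinks.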
Lemma TpairLtSub s D t1 t2 :
WfSubst D s -> tpair_lt (make_tpair (minus D (support s)) (substs s t2))
(make_tpair D (Arrow_t t1 t2)).
Proof.
intros. destruct s ; simpl ; mysimp. apply TpairLtArrowRight.
simpl in * ; mysimp. pose (LengthRemove p (minus D (support s))).
assert (Mem (minus D (support s)) p). auto.
eapply left_lex. pose (l H2). pose (MinusLte D (support s)). omega.
Qed.

This is an abbreviation that will help write unification.
Definition unify_return_type(tp:tpair) :=
forall t2, WfType (get_ctxt tp) (get_type tp) -> WfType (get_ctxt tp) t2 ->
option (wf_subst (get_ctxt tp)).

The main unification loop body. It is parameterized by a tpair (D,t1) and by a function unify that may be invoked on any tpair smaller than (D,t1). The body returns a function that, given t2 and proofs that t1 and t2 are well-formed with respect to D, returns an optional substitution that is well-formed with respect to D.
Definition unify_body (tp : tpair)
(unify : forall (tp2 : tpair),
tpair_lt tp2 tp -> unify_return_type tp2)
: unify_return_type tp.
intros tp unify t2.
destruct tp as [D t1].
refine (

match eq_type t1 t2 return WfType D t1 -> WfType D t2 -> option (wf_subst D) with
| left Heq => fun H1 H2 => Some (@existT substitution _ nil I)
| right Hneq =>

match t1 as t1', t2 as t2' return
t1 = t1' -> t2 = t2' -> WfType D t1' -> WfType D t2' -> option (wf_subst D)
with
| Arrow_t t11 t12, Arrow_t t21 t22 =>
fun H0 H1 H2 H3 =>

match unify (make_tpair D t11) _ t21 _ _ with
| None => None
| Some (existT s1 Hs1wf) =>

match unify (make_tpair (minus D (support s1))
(substs s1 t12)) _ (substs s1 t22) _ _ with
| None => None
| Some (existT s2 Hs2wf) =>

Some (@existT _ _ (s1 ++ s2) _)
end
end
| Tvar_t x, t2 => fun H0 H1 H2 H3 =>

match occurs x t2 with
| left H4 => None
| right H4 => Some (existT _ ((x,t2)::nil) _)
end
| t1, Tvar_t x => fun H0 H1 H2 H3 =>
match occurs x t1 with

| left H4 => None
| right H4 => Some (existT _ ((x,t1)::nil) _)
end
| _, _ => fun _ _ _ _ => None
end (refl_equal t1) (refl_equal t2)
end
) ; simpl in * ; mysimp ; subst ;
(apply TpairLtArrowLeft || apply TpairLtSub ; auto).
Defined.

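We tie the knot for unify using Fix_F, which demands an accessibility proof for our initial tp. But this is easy, since we've already shown that all tpairs are accessible.
Definition unify_tp (tp:tpair) (H:Acc tpair_lt tp) : unify_return_type tp :=
@Fix_F tpair tpair_lt unify_return_type unify_body tp H.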
Finally, we can define unify the way we really wanted to by currying the initial type pair.
Definition unify D t := unify_tp (well_founded_tpair_lt (make_tpair D t)).

# Proving Unification Correct.

We've managed to write a terminating unify function that Coq accepts. Now we need an induction principle so that we can reason about it. We'll start by building a generic induction principle for tpairs, again simply using the Wellfounded library.
Lemma tpair_ind :
forall (P: tpair -> Prop),
(forall tp, (forall tp2, tpair_lt tp2 tp -> P tp2) -> P tp) ->
forall tp, P tp.
Proof.
intros. apply (@Acc_ind tpair tpair_lt P) ; auto. apply well_founded_tpair_lt.
Qed.

This is the statement of correctness for unify: if it returns a substitution s, then applying s to the two types yields the same type.
Definition unify_equates (tp:tpair) : Prop :=
forall t2 H1 H2,
match unify_tp (well_founded_tpair_lt tp) t2 H1 H2 with
| None => True
| Some (existT s H3) => substs s (get_type tp) = substs s t2
end.

Lemma AccTpair D1 t1 D2 t2 H :
Acc_inv (well_founded_tpair_lt (existT (fun _ : tctxt => type) D1 t1))
(make_tpair D2 t2) H = well_founded_tpair_lt (make_tpair D2 t2).
Proof.
intros. apply proof_irr.
Qed.

Lemma SubOcc : forall t x u, ~Occurs x u -> sub t x u = u.
induction u ; mysimp ; intros ; try firstorder. congruence.
Qed.

Ltac usimp := repeat (idtac ;
match goal with
| [ |- context[occurs ?a ?b] ] => destruct (occurs a b) ; mysimp
| [ H : ~ Occurs ?x ?t |- context[sub _ ?x ?t] ] => rewrite SubOcc ; auto
| [ |- context[Acc_inv _ _ _] ] => rewrite AccTpair
| [ HU : unify_equates _ |-
context[Fix_F unify_return_type unify_body ?X ?Y ?Ha ?Hb]] =>
let H := fresh "H" in
let y := fresh "y" in
generalize (HU Y Ha Hb) ; unfold unify_tp ;
assert (H : exists x, Fix_F unify_return_type unify_body X Y Ha Hb = x) ;
[ eauto | destruct H as [y H]] ; rewrite H ; destruct y ; auto
| [ w : wf_subst _ |- _ ] => destruct w
end).

The main lemma -- showing that unify is correct. Notice that we factored out much of the reasoning to a helper tactic usimp. We could probably do a better job of writing this so that it's robust to change.
Lemma UnifyEquates_tp :
forall tp, unify_equates tp.
Proof.
apply (tpair_ind unify_equates) ; intros ; unfold unify_equates ;
destruct tp as [D t1] ; simpl ; intros ; unfold unify_tp ;
rewrite <- Fix_F_eq ; simpl ;
destruct (eq_type t1 t2) ; subst ; auto ; destruct t1 ; usimp ; auto ;
destruct t2 ; auto ; usimp ; mysimp.
pose (H (make_tpair D t1_1) (TpairLtArrowLeft _ _ _)) ; usimp ; intro.
assert (unify_equates (make_tpair (minus D (support x)) (substs x t1_2))).
apply H. eapply TpairLtSub ; auto. usimp ; auto. intros ; mysimp.
simpl in *. repeat rewrite SubstArrow. repeat rewrite <- SubstAppend in *.
rewrite (SubstExtends _ D w _ w0 _ _ H3). congruence.
Qed.

A slightly nicer version of the correctness result for unification.
Theorem UnifyEquates : forall D t1 t2 H1 H2 s H3,
unify D t1 t2 H1 H2 = Some (@existT _ _ s H3) -> substs s t1 = substs s t2.
Proof.
intros ; pose (UnifyEquates_tp (make_tpair D t1) t2 H1 H2) ;
unfold unify in H ; rewrite H in y ; auto.
Qed.

# Type-Inference in Two Pieces: Constraint Generation and Constraint Solving

Now we're going to build a type-checker that takes advantage of unification. Usually, the type-checker eagerly unifies types as it goes, but for our purposes, it will be a little easier (and cleaner) to break inference into two pieces: constraint generation, and constraint solving.

Constraint generation is actually quite a straightforward process over the abstract syntax of an expression. But we need to be able to (a) track a set of equality constraints, and (b) generate fresh type variables. To encapsulate this sort of state, we will write our type-checker using a combination of a state monad and the option monad. The state will allow us to generate fresh type variables and to record the generated equational constraints, which must be satisfied in order for the term to type-check. The option will be used to signal when a type error occurs, such as an unbound variable.
Record state := mkState {
st_next_tvar : tvar ;
st_D : list tvar ;
st_constraints : list (type * type)
}.

Our monad definition: a computation of type M A is a function that, given a state record, returns either None (failure) or a pair of a new state record and a value of type A.
Definition M(A:Type) := state -> option(state * A).

The return for the monad -- I think of this as lifting a pure value into an effectful computation.
Definition ret(A:Type)(x:A) : M A := fun s => Some(s,x).

The bind operation for the monad -- this is just sequential composition of two effectful computations. Notice that the state that comes out of the first computation gets fed into the second computation. So the user never has to worry about plumbing the state into functions.
Definition bind(A B:Type)(c1:M A)(c2: A -> M B) : M B :=
fun s1 =>
match c1 s1 with
| None => None
| Some (s2,v) => c2 v s2
end.

Some handy notation for bind lets us mimic Haskell-style "do-notation".
Notation "x <- c1 ; c2" := (bind c1 (fun x => c2))
(right associativity, at level 84, c1 at next level).

For our specific monad, we have a failure operation which just returns None.
Definition fail(A:Type) : M A := fun s => None.
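The sketch below shows the monad in action, assuming the definitions above are loaded (the example names are ours). ret passes the state through unchanged, and a fail anywhere in a bind chain makes the whole computation return None.

```coq
(* Sequencing two rets: the state is threaded through untouched. *)
Example bind_ret :
  (x <- ret 3 ; ret (x + 1)) (mkState 0 nil nil) = Some (mkState 0 nil nil, 4).
Proof. reflexivity. Qed.

(* fail short-circuits: the continuation is never run. *)
Example bind_fail :
  (x <- fail nat ; ret (x + 1)) (mkState 0 nil nil) = None.
Proof. reflexivity. Qed.
```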

To generate a fresh variable, we increment the counter and add the variable to our list of generated type variables.
Definition fresh_tvar : M type :=
fun s =>
match s with
mkState n ts c => Some (mkState (1+n) (ts ++ n::nil) c, Tvar_t n)
end.

This just adds a pair of types that are meant to be equated to our list of constraints.
Definition add_constr(t1 t2:type) : M unit :=
fun s =>
match s with
mkState n ts c => Some (mkState n ts ((t1,t2)::c), tt)
end.
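The sketch below runs fresh_tvar followed by add_constr from a state whose counter is already 2, an arbitrary choice for illustration. It assumes the definitions above are loaded, and the example name is ours. The fresh variable is Tvar_t 2, the counter advances to 3, the number 2 is appended to the list of type variables, and the constraint is consed onto the constraint list.

```coq
Example fresh_then_constrain :
  (t <- fresh_tvar ; _ <- add_constr t Nat_t ; ret t) (mkState 2 (0::1::nil) nil)
  = Some (mkState 3 (0::1::2::nil) ((Tvar_t 2, Nat_t)::nil), Tvar_t 2).
Proof. reflexivity. Qed.
```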

This is a monad command that tries to look up the type of a variable in a context, failing if the variable isn't found.
Fixpoint look(x:var)(G:ctxt) : M type :=
match G with
| nil => fail _
| (y,t)::rest => if string_dec x y then ret t else look x rest
end.
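The following sketch shows look in action, assuming the definitions above are loaded (the example names are ours). look scans the context from the front, returns the first matching binding, and fails on an unbound variable. String literals use %string since this file does not open string_scope, and the checks rely on string_dec reducing by computation.

```coq
(* "y" is found after skipping "x"; the state is unchanged. *)
Example look_found :
  look "y"%string (("x"%string, Nat_t)::("y"%string, Tvar_t 0)::nil) (mkState 0 nil nil)
  = Some (mkState 0 nil nil, Tvar_t 0).
Proof. reflexivity. Qed.

(* An unbound variable makes the computation fail. *)
Example look_missing :
  look "z"%string (("x"%string, Nat_t)::nil) (mkState 0 nil nil) = None.
Proof. reflexivity. Qed.
```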

Finally, we can generate the constraints with this nice little definition. Note that if we tried to eagerly unify the constraints, then we'd need to use a much more complicated definition to track the fact that the context and types are well-formed with respect to the list of generated type variables. This is possible, but quite a bit more verbose. In general, I find that in Coq, it's best to avoid dependent types if you can, and rely on "after-the-fact" proving. But as with unify, sometimes there's no avoiding it.
Fixpoint gen_constraints(G:ctxt)(e:exp) : M type :=
match e with
| Var_e x => look x G
| Num_e _ => ret Nat_t
| Lam_e x e =>
t1 <- fresh_tvar ;
t2 <- gen_constraints ((x,t1)::G) e ;
ret (Arrow_t t1 t2)
| App_e e1 e2 =>
t1 <- gen_constraints G e1 ;
t2 <- gen_constraints G e2 ;
t <- fresh_tvar ;
_ <- add_constr t1 (Arrow_t t2 t) ;
ret t
end.
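As a worked example, consider generating constraints for (fun x => x) 3 from the empty state. This allocates Tvar_t 0 for x and Tvar_t 1 for the result of the application. It also records a single constraint equating the function's type with Arrow_t Nat_t (Tvar_t 1). The sketch below checks this by computation, assuming the definitions above are loaded (the example name is ours).

```coq
Example gen_id_app :
  gen_constraints nil (App_e (Lam_e "x"%string (Var_e "x"%string)) (Num_e 3))
                  (mkState 0 nil nil)
  = Some (mkState 2 (0::1::nil)
            ((Arrow_t (Tvar_t 0) (Tvar_t 0), Arrow_t Nat_t (Tvar_t 1))::nil),
          Tvar_t 1).
Proof. reflexivity. Qed.
```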

# Generated Constraints are Well-Formed

Now we want to show that the constraints generated by gen_constraints are well-formed with respect to the list of type variables D that we accumulate in the state. This will allow us to call unify on all of the pairs of constraints we accumulated.

Various lemmas about weakening well-formedness proofs.
Lemma RemoveDist x D1 D2 : remove x (D1 ++ D2) = (remove x D1) ++ (remove x D2).
Proof.
induction D1 ; mysimp ; intros ; rewrite IHD1 ; auto.
Qed.

Lemma WfTypeEnd t a D : WfType D t -> WfType (D ++ a::nil) t.
Proof.
induction t ; mysimp ; intros ; mysimp ; generalize H ; induction D ; simpl ;
[ tauto | mysimp].
Qed. Hint Resolve WfTypeEnd.

Lemma WfTypeWeaken t D2 D1 : WfType D1 t -> WfType (D1 ++ D2) t.
Proof.
induction D2 ; intros ; [ rewrite <- app_nil_end ; auto | rewrite AppCons ; auto ].
Qed. Hint Resolve WfTypeWeaken.

Lemma WfTypeRemoveWeaken t x D : WfType (remove x D) t -> WfType D t.
Proof.
induction t ; auto ; simpl ; [ induction D ; simpl ; mysimp | intros ; mysimp ;
firstorder ].
Qed. Hint Resolve WfTypeRemoveWeaken.

Lemma WfTypeMinusWeaken t D2 D1 : WfType (minus D1 D2) t -> WfType D1 t.
Proof.
induction D2 ; eauto.
Qed.

Lemma MemWeaken : forall a D1 D2, Mem D1 a -> Mem (D1 ++ D2) a.
induction D1 ; simpl in * ; intros ; [ tauto | mysimp ].
Qed.

Lemma WfSubstWeaken s D1 D2 : WfSubst D1 s -> WfSubst (D1 ++ D2) s.
Proof.
induction s ; auto ; simpl ; intros ; mysimp ; [ apply MemWeaken |
rewrite RemoveDist ; apply WfTypeWeaken | rewrite RemoveDist ] ; auto.
Qed.

This defines the notion of a well-formed variable context G with respect to a type variable context D.
Fixpoint WfCtxt(D:tctxt)(G:ctxt) : Prop :=
match G with
| nil => True
| (x,t)::rest => WfType D t /\ WfCtxt D rest
end.

Lemma WfCtxtWeaken(D1 D2:tctxt)(G:ctxt) : WfCtxt D1 G -> WfCtxt (D1 ++ D2) G.
Proof.
induction G ; auto ; mysimp ; intros ; mysimp.
Qed. Hint Resolve WfCtxtWeaken.

And this defines the notion of a well-formed list of constraints with respect to a type variable context D.
Fixpoint WfConstr(D:tctxt)(cs:list(type*type)) : Prop :=
match cs with
| nil => True
| (t1,t2)::rest => WfType D t1 /\ WfType D t2 /\ WfConstr D rest
end.

Lemma WfConstrWeaken D1 cs : WfConstr D1 cs -> forall D2, WfConstr (D1 ++ D2) cs.
Proof.
induction cs. auto. mysimp ; destruct a ; intros ; mysimp.
Qed.

Lemma MemEnd : forall D x, Mem (D ++ x::nil) x.
induction D ; intros ; mysimp.
Qed. Hint Resolve MemEnd.

Lemma MemMid : forall x D1 D2, Mem ((D1 ++ x::nil)++D2) x.
induction D1 ; intros ; mysimp.
Qed. Hint Resolve MemMid.

To solve the constraints, we need to call unify. But unify demands that it be given well-formed types with respect to some typing context D in order to ensure termination. So this is a big lemma that establishes the well-formedness of the constraints generated by gen_constraints. It tells us that if we start with a well-formed context G with respect to the type context st_D (which is in the initial state), and start with well-formed constraints in the state, then gen_constraints will only extend st_D and the constraints in the state, and the constraints will continue to be well-formed, and the resulting type will be well-formed. We need all of these pieces to get the induction to go through.
Lemma GenWf : forall e G s1 s2 t,
gen_constraints G e s1 = Some (s2, t) ->
WfCtxt (st_D s1) G ->
WfConstr (st_D s1) (st_constraints s1) ->
(exists D2, (st_D s2) = (st_D s1) ++ D2) /\
(exists c2, (st_constraints s2) = c2 ++ (st_constraints s1)) /\
WfConstr (st_D s2) (st_constraints s2) /\
WfType (st_D s2) t.
Proof.
Ltac gen_simp :=
repeat subst ; unfold bind, fresh_tvar, ret, fail in * ;
match goal with
| [ IH : forall _ _ _ _, gen_constraints _ ?e _ = _ -> _,
H : context[gen_constraints ?G ?e ?s] |- _ ] =>
generalize (IH G s) ; clear IH ;
destruct (gen_constraints G e s) ; intros ; try congruence
| [ p : (_ * _)%type |- _ ] => destruct p
| [ H : forall _ _, Some _ = Some _ -> _ |- _ ] =>
generalize (H _ _ (refl_equal _)) ; clear H ; intro H
| [ s : state |- _ ] => destruct s ; simpl in *
| [ H : exists _, _ |- _ ] => destruct H ; simpl in *
| [ H : Some _ = Some _ |- _] => inversion H ; clear H ; subst
| [ H1 : WfCtxt ?D1 ?G,
H2 : WfConstr (?D1 ++ ?D2) ?cs,
H3 : WfCtxt (?D1 ++ ?D2) ?G -> WfConstr (?D1 ++ ?D2) ?cs -> _ |- _] =>
generalize (H3 (WfCtxtWeaken _ _ _ H1) H2) ; clear H3 ; intros ; mysimp ; subst ;
simpl ; repeat rewrite app_ass ; eauto
| [ H1 : WfCtxt ?D ?G,
H2 : WfConstr ?D ?cs,
H3 : WfCtxt ?D ?G -> WfConstr ?D ?cs -> _ |- _ ] =>
generalize (H3 H1 H2) ; clear H3 ; intros ; mysimp ; subst ; eauto
| [ H : _ -> _ -> ?P |- _ ] =>
assert P; [ eapply H | idtac ] ; mysimp ; subst ; eauto ;
[ eapply WfConstrWeaken ; auto | rewrite app_ass ; eauto ] ; fail
| [ |- exists _, ?p :: ?x1 ++ ?x ++ ?s = _ ++ ?s ] =>
exists (p :: x1 ++ x) ; simpl ; rewrite app_ass ; eauto
| [ H : context[string_dec ?v1 ?v2] |- _ ] => destruct (string_dec v1 v2)
| [ H : None = Some _ |- _ ] => congruence
| [ H : _ /\ _ |- _ ] => destruct H
end.
induction e ; simpl ; intros ; gen_simp ;
match goal with
| [ H : look _ ?G _ = _ |- _ ] =>
induction G ; simpl in * ; unfold fail in * ; try congruence ;
repeat gen_simp ; [ repeat split ; auto ; try (exists nil ; auto ;
rewrite <- app_nil_end ; auto) ; auto
| apply IHG ; tauto]
| _ => idtac
end ; repeat gen_simp ; mysimp ;
try (exists nil ; auto ; rewrite <- app_nil_end) ; auto ; rewrite <- app_ass ; auto ;
rewrite <- app_ass ; eapply WfConstrWeaken ; eauto.
Qed.

This walks over a list of well-formed constraints. It solves the tail first, then unifies the head constraint after applying the tail's substitution to it, and returns the combined substitution. The result is guaranteed to be well-formed with respect to D.
Definition unify_constraints D cs : WfConstr D cs -> option (wf_subst D).
  refine (fun D =>
    fix unify_cs (cs:list(type*type)) : WfConstr D cs -> option (wf_subst D) :=
      match cs with
        | nil => fun H => Some (@existT substitution _ nil I)
        | (t1,t2)::rest => fun H =>
          match unify_cs rest _ with
            | None => None
            | Some (existT s1 Hs1) =>
              match unify (minus D (support s1)) (substs s1 t1) (substs s1 t2) _ _ with
                | None => None
                | Some (existT s2 Hs2) => Some (@existT substitution _ (s1 ++ s2) _)
              end
          end
      end) ; simpl in * ; mysimp.
Defined.
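For intuition, here is a proof-free sketch of the computation that unify_constraints performs. It is a standalone approximation, not the development's code: the names ty, subst1, substs, occurs, unify and unify_cs are local to the sketch, the WfConstr and wf_subst invariants are dropped, and a fuel argument stands in for the dependent-type termination argument that the real unify needs. It assumes a Coq version recent enough to provide Nat.eqb and ListNotations.

```coq
Require Import List.
Import ListNotations.

(* Hypothetical, proof-free sketch of unify / unify_constraints.
   Fuel replaces the development's termination argument, and the
   WfConstr / wf_subst invariants are not tracked. *)
Module UnifySketch.
  Inductive ty : Type :=
  | TVar : nat -> ty
  | Arrow : ty -> ty -> ty
  | TNat : ty.

  (* Replace variable x by u everywhere in t. *)
  Fixpoint subst1 (x : nat) (u t : ty) {struct t} : ty :=
    match t with
    | TVar y => if Nat.eqb x y then u else TVar y
    | Arrow t1 t2 => Arrow (subst1 x u t1) (subst1 x u t2)
    | TNat => TNat
    end.

  (* Bindings are applied left to right, so
     substs (s1 ++ s2) t = substs s2 (substs s1 t). *)
  Fixpoint substs (s : list (nat * ty)) (t : ty) : ty :=
    match s with
    | [] => t
    | (x, u) :: rest => substs rest (subst1 x u t)
    end.

  Fixpoint occurs (x : nat) (t : ty) : bool :=
    match t with
    | TVar y => Nat.eqb x y
    | Arrow t1 t2 => orb (occurs x t1) (occurs x t2)
    | TNat => false
    end.

  (* Returns None on a constructor clash, on an occurs-check failure,
     or when the fuel runs out. *)
  Fixpoint unify (fuel : nat) (t1 t2 : ty) : option (list (nat * ty)) :=
    match fuel with
    | 0 => None
    | S fuel' =>
      match t1, t2 with
      | TNat, TNat => Some []
      | TVar x, TVar y => if Nat.eqb x y then Some [] else Some [(x, TVar y)]
      | TVar x, t => if occurs x t then None else Some [(x, t)]
      | t, TVar x => if occurs x t then None else Some [(x, t)]
      | Arrow a1 b1, Arrow a2 b2 =>
        match unify fuel' a1 a2 with
        | None => None
        | Some s1 =>
          (* Apply s1 before unifying the codomains. *)
          match unify fuel' (substs s1 b1) (substs s1 b2) with
          | None => None
          | Some s2 => Some (s1 ++ s2)
          end
        end
      | _, _ => None
      end
    end.

  (* Like unify_constraints: solve the tail first, then unify the
     head pair under that partial solution and compose. *)
  Fixpoint unify_cs (fuel : nat) (cs : list (ty * ty)) {struct cs}
    : option (list (nat * ty)) :=
    match cs with
    | [] => Some []
    | (t1, t2) :: rest =>
      match unify_cs fuel rest with
      | None => None
      | Some s1 =>
        match unify fuel (substs s1 t1) (substs s1 t2) with
        | None => None
        | Some s2 => Some (s1 ++ s2)
        end
      end
    end.

  Example two_constraints :
    unify_cs 10 [(TVar 1, TVar 0); (TVar 0, TNat)] = Some [(0, TNat); (1, TNat)].
  Proof. reflexivity. Qed.
End UnifySketch.
```

In this sketch a None result conflates genuine failure with running out of fuel; the development's dependently typed unify avoids that ambiguity by establishing termination instead.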

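Finally, we can write the type-checker! It first generates constraints with gen_constraints, then unifies them with unify_constraints, and finally applies the resulting substitution to the type computed by gen_constraints, returning the result. The well-formedness proof that unify_constraints demands is obtained from GenWf, applied to the equation H that the dependent match records.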
Definition type_check(e:exp) : option type :=
  let x := gen_constraints nil e (mkState 0 nil nil) in
  match x as x' return (x = x') -> option type with
    | None => fun _ => None
    | Some (mkState _ D cs, t) =>
      fun H =>
        match unify_constraints D cs (proj1 (proj2 (proj2 (GenWf _ _ _ H I I)))) with
          | None => None
          | Some (existT s Hs) => Some (substs s t)
        end
  end (refl_equal x).
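The `match x as x' return (x = x') -> option type with ... end (refl_equal x)` idiom in type_check is often called the convoy pattern: each branch receives an equation recording what gen_constraints returned, and that equation is what GenWf needs in order to produce the WfConstr proof. The sketch below isolates the idiom over abstract parameters; Exp, Ty, Constr, Subst, WF, gen, gen_wf, solve_cs and subst_ty are hypothetical illustrative names, not the development's definitions.

```coq
Module ConvoySketch.
  Section Pipeline.
    (* Hypothetical stand-ins: gen ~ gen_constraints, gen_wf ~ GenWf,
       WF ~ WfConstr D, solve_cs ~ unify_constraints, subst_ty ~ substs. *)
    Variables Exp Ty Constr Subst : Type.
    Variable WF : list Constr -> Prop.
    Variable gen : Exp -> option (list Constr * Ty).
    (* Well-formedness can only be derived from the equation gen e = ... *)
    Variable gen_wf : forall e cs t, gen e = Some (cs, t) -> WF cs.
    (* The solver demands a proof that its input is well-formed. *)
    Variable solve_cs : forall cs, WF cs -> option Subst.
    Variable subst_ty : Subst -> Ty -> Ty.

    Definition type_check_sketch (e : Exp) : option Ty :=
      (* Generalize over the result of gen e while remembering the equation. *)
      match gen e as r return gen e = r -> option Ty with
      | None => fun _ => None
      | Some (cs, t) => fun H =>
          match solve_cs cs (gen_wf e cs t H) with
          | None => None
          | Some s => Some (subst_ty s t)
          end
      end eq_refl.
  End Pipeline.
End ConvoySketch.
```

With a plain `match gen e with`, nothing inside the `Some` branch would connect cs back to gen e, so there would be no way to call gen_wf.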

# Correctness of the Type Checker

Ltac tc_simp :=
  match goal with
    | [ s : state |- _ ] => destruct s ; simpl in *
    | [ H : (_ * _)%type |- _ ] => destruct H
    | [ IH : forall _ _ _ _, gen_constraints _ ?e _ = _ -> _,
        H : context[gen_constraints ?G ?e ?s] |- _ ] =>
      generalize (IH G s) ; clear IH ; destruct (gen_constraints G e s) ;
      try congruence
    | [ H : forall _ _, Some _ = Some _ -> _ |- _ ] =>
      generalize (H _ _ (refl_equal _)) ; clear H ; intro H
  end.

Lemma GenConstrExtends : forall e G s1 s2 t,
gen_constraints G e s1 = Some (s2,t) ->
exists cs2, (st_constraints s2) = cs2 ++ (st_constraints s1).
Proof.
induction e ; simpl ; intros ; unfold bind, fresh_tvar, fail, ret in * ; mysimp ; subst ;
match goal with
| [ H : look _ ?G _ = _ |- _] =>
exists nil ; simpl ; induction G ; simpl in * ; unfold fail, ret in * ;
try congruence ; mysimp
| _ => idtac
end ;
repeat tc_simp ; mysimp ; subst ; intros ; repeat tc_simp ; mysimp ; simpl in * ; subst ;
eauto ;
match goal with
| [ |- exists _ , ?x = _ ++ ?x ] => exists nil
| [ |- exists _, ?a::?b::?c++?d++ _ = _ ] => exists (a::b::c++d) ; rewrite <- app_ass
| [ |- exists _, ?a::?c++?d++ _ = _ ] => exists (a::c++d) ; rewrite <- app_ass
end ; auto.
Qed.

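Dropping constraints preserves unifiers: if a substitution equates every pair in c2 ++ c1, it also equates every pair in the suffix c1. Together with GenConstrExtends, this lets a solution to all of the generated constraints be reused for the constraints generated for a sub-term.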
Lemma TCAddConstr : forall s c1 c2,
(forall t1 t2, In (t1,t2) (c2 ++ c1) -> substs s t1 = substs s t2) ->
(forall t1 t2, In (t1,t2) c1 -> substs s t1 = substs s t2).
Proof.
induction c2 ; auto ; mysimp.
Qed.
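TCAddConstr is an instance of a generic fact about lists: a property that holds on every element of l2 ++ l1 holds on every element of the suffix l1. A standalone version (the name forall_in_app_suffix is hypothetical) makes the single reasoning step explicit, namely that membership in the suffix implies membership in the append, via the standard library's in_or_app.

```coq
Require Import List.

(* Generic form of TCAddConstr: P plays the role of
   "substs s t1 = substs s t2" for a constraint (t1,t2). *)
Lemma forall_in_app_suffix :
  forall (A : Type) (P : A -> Prop) (l1 l2 : list A),
    (forall x, In x (l2 ++ l1) -> P x) ->
    forall x, In x l1 -> P x.
Proof.
  intros A P l1 l2 H x Hx.
  apply H.
  (* In x (l2 ++ l1) follows from In x l1. *)
  apply in_or_app. right. exact Hx.
Qed.
```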

Lifting substitutions to contexts: substs_ctxt applies a substitution to every type in a context, leaving the variable names alone.
Fixpoint substs_ctxt (s:substitution) (G:ctxt) : ctxt :=
  match G with
    | nil => nil
    | (x,t)::rest => (x,substs s t)::(substs_ctxt s rest)
  end.

The declarative typing relation is closed under substitution: if G |-- e ; t, then applying any substitution s to both G and t preserves the typing.
Lemma HasTypeSubst : forall s G e t, G |-- e ; t -> (substs_ctxt s G) |-- e ; substs s t.
Proof.
induction 1 ; intros ; try rewrite SubstNat in * ;
try rewrite SubstArrow ; econstructor ; auto.
induction G ; simpl in * ; mysimp.
rewrite <- SubstArrow ; eauto. auto.
Qed. Hint Resolve HasTypeSubst.

Lemmas about applying substitutions to (term-)variable contexts.
Lemma SubstCtxtAppend G s2 s1 : substs_ctxt (s1++s2) G = substs_ctxt s2 (substs_ctxt s1 G).
induction G ; mysimp ; intros ; rewrite SubstAppend ; rewrite IHG ; auto.
Qed.

Lemma HasTypeAppend : forall G e t s1 s2, substs_ctxt s1 G |-- e ; substs s1 t ->
substs_ctxt (s1 ++ s2) G |-- e ; substs (s1 ++ s2) t.
Proof.
intros. rewrite SubstCtxtAppend. rewrite SubstAppend. eapply HasTypeSubst ; auto.
Qed.

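We can add in "later" bits of a substitution with no effect on a type: if t is well-formed with respect to minus D (support s), so that it mentions none of the variables in support s, then substs s t = t.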
Lemma SubstIdem : forall t D s, WfType (minus D (support s)) t -> substs s t = t.
Proof.
induction t ; simpl in * ; intros. induction s ; auto ; destruct a ; simpl in *.
destruct (eq_nat_dec t0 t). subst. assert False. clear IHs. generalize H.
generalize (minus D (support s)). induction t0. auto. mysimp. contradiction.
mysimp. apply IHs. rewrite MemRemv in *; auto. rewrite SubstArrow ; simpl in H ; mysimp ;
rewrite (IHt1 D s); auto ; rewrite (IHt2 D s) ; auto. rewrite SubstNat ; auto.
Qed.
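A small computational illustration of SubstIdem, using standalone, hypothetical copies of the type syntax and of a substitution function that applies bindings left to right (the same convention SubstAppend states for the real substs). Bindings for variables that a type does not mention leave it unchanged, while a binding for a variable it does mention changes it.

```coq
Require Import List.
Import ListNotations.

Module SubstIdemSketch.
  (* Hypothetical stand-ins for the development's type and substs. *)
  Inductive ty : Type :=
  | TVar : nat -> ty
  | Arrow : ty -> ty -> ty
  | TNat : ty.

  (* Replace variable x by u everywhere in t. *)
  Fixpoint subst1 (x : nat) (u t : ty) {struct t} : ty :=
    match t with
    | TVar y => if Nat.eqb x y then u else TVar y
    | Arrow t1 t2 => Arrow (subst1 x u t1) (subst1 x u t2)
    | TNat => TNat
    end.

  (* Apply the bindings of s one after another, left to right. *)
  Fixpoint substs (s : list (nat * ty)) (t : ty) : ty :=
    match s with
    | [] => t
    | (x, u) :: rest => substs rest (subst1 x u t)
    end.

  (* The bound variables (2 and 3) are disjoint from those of the type
     (0 and 1), so the substitution has no effect. *)
  Example untouched :
    substs [(2, TNat); (3, Arrow TNat TNat)] (Arrow (TVar 0) (TVar 1))
    = Arrow (TVar 0) (TVar 1).
  Proof. reflexivity. Qed.
End SubstIdemSketch.
```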

Lemma InWeaken : forall (A:Type) (x:A) S1 S2, In x S1 -> In x (S2 ++ S1).
Proof.
induction S2; mysimp ; intros ; try contradiction ; mysimp.
Qed. Hint Resolve InWeaken.

Any substitution that satisfies the constraints produced by gen_constraints yields a valid typing. This is the key correctness lemma for the type-checker: if gen_constraints G e s1 returns a state s2 and a type t, and s equates every pair in st_constraints s2, then applying s to the context G and to t gives a valid typing for e.
Lemma TC2corr1 : forall e G s1 s2 t,
gen_constraints G e s1 = Some (s2,t) ->
forall s, (forall t1 t2, In (t1,t2) (st_constraints s2) -> substs s t1 = substs s t2) ->
substs_ctxt s G |-- e ; substs s t.
Proof.
induction e ; simpl ; intros ; unfold bind, fresh_tvar, add_constr, ret, fail in * ;
mysimp ; subst ;
repeat match goal with
| [ |- _ |-- Var_e _ ; _ ] =>
constructor ; induction G ; simpl in * ; unfold ret, fail in * ; mysimp
| [ |- _ |-- Num_e _ ; _ ] => rewrite SubstNat ; constructor
| [ |- _ |-- Lam_e _ _ ; _ ] =>
repeat (repeat tc_simp ; intros ; mysimp ; subst) ;
rewrite SubstArrow ; econstructor ; eauto
| [ H : context[gen_constraints ?G ?e1 ?s1] |- _] =>
generalize (GenConstrExtends e1 G s1) ; tc_simp ; tc_simp ; tc_simp ; tc_simp
| _ => intros ; repeat (repeat tc_simp ; intros ; mysimp ; simpl in * ; subst)
end ;
match goal with
| [ H1 : forall _, _ -> _ |-- ?e1 ; substs _ ?t0,
H2 : forall _, _ -> _ |-- ?e2 ; substs _ ?t1,
H0 : forall _ _, _ -> substs _ _ = substs _ _ |-
_ |-- App_e ?e1 ?e2 ; substs ?s ?t ] =>
econstructor ; [ idtac | eauto ] ;
rewrite <- SubstArrow ;
let H := fresh "H" in
assert (H : substs s t0 = substs s (Arrow_t t1 t)) ;
try (apply H0 ; tauto ; fail) ; rewrite <- H ; eapply H1 ; eapply TCAddConstr ;
intros ; eapply H0 ; right ; eauto
end.
Qed.

Whenever unify_constraints succeeds, the substitution it returns equates every pair in the constraint list.
Lemma UCSEquates : forall cs D (H: WfConstr D cs),
match @unify_constraints D cs H with
| None => True
| Some (existT s Hs) =>
forall t1 t2, In (t1,t2) cs -> substs s t1 = substs s t2
end.
Proof.
induction cs ; simpl ; try tauto. destruct a. intros. mysimp.
generalize (IHcs D w1). destruct (unify_constraints D cs w1).
destruct w2. intros.
match goal with
| [ |- context[unify ?D ?t1 ?t2 ?H1 ?H2] ] =>
generalize (@UnifyEquates D t1 t2 H1 H2) ; destruct (unify D t1 t2 H1 H2)
end. intros. destruct w3. generalize (H0 _ _ (refl_equal _)) ; clear H0 ; intro.
intros. mysimp. subst. repeat rewrite SubstAppend. auto.
repeat rewrite SubstAppend. rewrite (H _ _ H1). auto. auto. auto.
Qed.

The type-checker is sound with respect to the hasType relation: if type_check e returns Some t, then e has type t in the empty context.
Theorem type_check_correct : forall e t, type_check e = Some t -> nil |-- e ; t.
Proof.
unfold type_check. intros e t. generalize (GenWf e nil (mkState 0 nil nil)).
generalize (TC2corr1 e nil (mkState 0 nil nil)).
destruct (gen_constraints nil e (mkState 0 nil nil)) ; intros ; mysimp.
repeat tc_simp. generalize (a _ _ (refl_equal _) I I). intros. mysimp. simpl in *. subst.
match goal with
| [ H : context[unify_constraints ?a ?b ?c] |- _ ] =>
generalize (@UCSEquates b a c) ; destruct (unify_constraints a b c)
end ; intros ; mysimp. destruct w. mysimp.
Qed.

# Potential Exercises

1. Add new constructs to the types and the expressions of the language. For example, you could add Plus_e and Minus_e operations on Nat_t's, or booleans and an If_e-expression, or unit, etc. See what proofs break and try to generalize them so that they (a) cover the new cases, and (b) will hopefully cover more future cases.

2. Some of the lemmas towards the end are not appropriately "Adam-ized", that is, written in the heavily automated Chlipala style used elsewhere in this development. Rewrite them so that they are.

3. We only established the soundness of the type-checker. Prove its completeness. In particular, you should show that unification finds a minimal substitution that satisfies the set of constraints.