From bce7eadfbc399ef9510bfdf1b19e369b737dbbfa Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 23 Aug 2022 08:50:59 -0700 Subject: [PATCH] feat: add `LCNF/Check.lean` --- src/Lean/Compiler/LCNF/Basic.lean | 6 +- src/Lean/Compiler/LCNF/Check.lean | 161 ++++++++++++++++++++++++++ src/Lean/Compiler/LCNF/InferType.lean | 8 +- src/Lean/Compiler/LCNF/ToExpr.lean | 8 +- 4 files changed, 179 insertions(+), 4 deletions(-) create mode 100644 src/Lean/Compiler/LCNF/Check.lean diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index 780c317b12..b135a2e487 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -73,7 +73,11 @@ structure Decl where -/ type : Expr /-- - The value of the declaration, usually changes as it progresses + Parameters. + -/ + params : Array Param + /-- + The body of the declaration, usually changes as it progresses through compiler passes. -/ value : Code diff --git a/src/Lean/Compiler/LCNF/Check.lean b/src/Lean/Compiler/LCNF/Check.lean new file mode 100644 index 0000000000..6d73785526 --- /dev/null +++ b/src/Lean/Compiler/LCNF/Check.lean @@ -0,0 +1,161 @@ +/- +Copyright (c) 2022 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Compiler.LCNF.InferType + +namespace Lean.Compiler.LCNF + +namespace Check +open InferType + +structure Context where + /-- Join points that are in scope. -/ + jps : FVarIdSet := {} + /-- Variables and local functions in scope -/ + vars : FVarIdSet := {} + +structure State where + /-- All free variables found -/ + all : FVarIdHashSet + +abbrev CheckM := ReaderT Context $ StateRefT State InferTypeM + +def checkFVar (fvarId : FVarId) : CheckM Unit := + unless (← read).vars.contains fvarId do + throwError "invalid out of scope free variable {fvarId.name}" + +def checkApp (f : Expr) (args : Array Expr) : CheckM Unit := do + unless f.isConst || f.isFVar do + throwError "unexpected function application, function must be a constant or free variable{indentExpr (mkAppN f args)}" + if f.isFVar then + checkFVar f.fvarId! + let mut fType ← inferType f + let mut j := 0 + for i in [:args.size] do + let arg := args[i]! + if fType.isAnyType then + return () + fType := fType.headBeta + let (d, b) ← + match fType with + | .forallE _ d b _ => pure (d, b) + | _ => + fType := fType.instantiateRevRange j i args |>.headBeta + match fType with + | .forallE _ d b _ => j := i; pure (d, b) + | _ => + if fType.isAnyType then return () + throwError "function expected at{indentExpr (mkAppN f args)}\narrow type expected{indentExpr fType}" + let argType ← inferType arg + let expectedType := d.instantiateRevRange j i args + unless compatibleTypes argType expectedType do + throwError "type mismatch at LCNF application{indentExpr (mkAppN f args)}\nargument {arg} has type{indentExpr argType}\nbut is expected to have type{indentExpr expectedType}" + unless isTypeFormerType expectedType || expectedType.erased do + unless arg.isFVar do + throwError "invalid LCNF application{indentExpr (mkAppN f args)}\nargument{indentExpr arg}\nmust be a free variable" + checkFVar arg.fvarId! + fType := b + +def checkExpr (e : Expr) : CheckM Unit := + match e with + | .lit _ => pure () + | .app .. => checkApp e.getAppFn e.getAppArgs + | .proj _ _ (.fvar fvarId) => checkFVar fvarId + | .mdata _ (.fvar fvarId) => checkFVar fvarId + | .const _ _ => pure () -- TODO: check number of universe level parameters + | .fvar fvarId => checkFVar fvarId + | _ => throwError "unexpected expression at LCNF{indentExpr e}" + +def checkJpInScope (jp : FVarId) : CheckM Unit := do + unless (← read).jps.contains jp do + /- + We cannot jump to join points defined out of the scope of a local function declaration. + For example, the following is an invalid LCNF. + ``` + jp_1 := fun x => ... -- Some join point + let f := fun y => -- Local function declaration. + ... + jp_1 _x.n -- jump to a join point that is not in the scope of `f`. + ``` + -/ + throwError "invalid jump to out of scope join point" + +def checkLetDecl (letDecl : LetDecl) : CheckM Unit := do + checkExpr letDecl.value + let valueType ← inferType letDecl.value + unless compatibleTypes letDecl.type valueType do + throwError "type mismatch at `{letDecl.binderName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr letDecl.type}" + +def addFVarId (fvarId : FVarId) : CheckM Unit := do + if (← get).all.contains fvarId then + throwError "invalid LCNF, free variables are not unique `{fvarId.name}`" + modify fun s => { s with all := s.all.insert fvarId } + +@[inline] def withFVarId (fvarId : FVarId) (x : CheckM α) : CheckM α := do + addFVarId fvarId + withReader (fun ctx => { ctx with vars := ctx.vars.insert fvarId }) x + +@[inline] def withJp (fvarId : FVarId) (x : CheckM α) : CheckM α := do + addFVarId fvarId + withReader (fun ctx => { ctx with jps := ctx.jps.insert fvarId }) x + +@[inline] def withParams (params : Array Param) (x : CheckM α) : CheckM α := do + params.forM (addFVarId ·.fvarId) + withReader (fun ctx => { ctx with vars := params.foldl (init := ctx.vars) fun vars p => vars.insert p.fvarId }) + x + +mutual + +partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do + let type ← withParams funDecl.params do + mkForallParams funDecl.params (← check funDecl.value) + unless compatibleTypes funDecl.type type do + throwError "type mismatch at `{funDecl.binderName}`, value has type{indentExpr type}\nbut is expected to have type{indentExpr funDecl.type}" + +partial def checkCases (c : Cases) : CheckM Expr := do + let mut ctorNames : NameSet := {} + let mut hasDefault := false + checkFVar c.discr + let discrType ← inferFVarType c.discr + let .const declName _ := discrType.headBeta.getAppFn | throwError "unexpected LCNF discriminant type {discrType}" + unless c.typeName == declName do + throwError "invalid LCNF `{c.typeName}.casesOn`, discriminant has type{indentExpr discrType}" + for alt in c.alts do + let type ← + match alt with + | .default k => hasDefault := true; check k + | .alt ctorName params k => + if ctorNames.contains ctorName then + throwError "invalid LCNF `cases`, alternative `{ctorName}` occurs more than once" + ctorNames := ctorNames.insert ctorName + let .ctorInfo val ← getConstInfo ctorName | throwError "invalid LCNF `cases`, `{ctorName}` is not a constructor name" + unless val.induct == c.typeName do + throwError "invalid LCNF `cases`, `{ctorName}` is not a constructor of `{c.typeName}`" + unless params.size == val.numFields do + throwError "invalid LCNF `cases`, `{ctorName}` has # {val.numFields} fields, but alternative has # {params.size} alternatives" + -- TODO: check whether the ctor field types as parameter types match. + withParams params do check k + unless compatibleTypes type c.resultType do + throwError "type mismatch at LCNF `cases` alternative\nhas type{indentExpr type}\nbut is expected to have type{indentExpr c.resultType}" + return c.resultType + +partial def check (code : Code) : CheckM Expr := do + match code with + | .let decl k => checkLetDecl decl; withFVarId decl.fvarId do check k + | .fun decl k => + -- Remark: local function declarations should not jump to out of scope join points + withReader (fun ctx => { ctx with jps := {} }) do checkFunDecl decl + withFVarId decl.fvarId do check k + | .jp decl k => checkFunDecl decl; withJp decl.fvarId do check k + | .cases c => checkCases c + | .jmp fvarId args => checkJpInScope fvarId; checkApp (.fvar fvarId) args; code.inferType + | .return fvarId => checkFVar fvarId; code.inferType + | .unreach .. => code.inferType + +end + +end Check + +end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/InferType.lean b/src/Lean/Compiler/LCNF/InferType.lean index 6c8ad93e59..784a069c7a 100644 --- a/src/Lean/Compiler/LCNF/InferType.lean +++ b/src/Lean/Compiler/LCNF/InferType.lean @@ -21,14 +21,18 @@ def getLocalDecl (fvarId : FVarId) : InferTypeM LocalDecl := do | some localDecl => return localDecl | none => LCNF.getLocalDecl fvarId -def mkForallFVars (xs : Array Expr) (b : Expr) : InferTypeM Expr := - let b := b.abstract xs +def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr := + let b := type.abstract xs xs.size.foldRevM (init := b) fun i b => do let x := xs[i]! let .cdecl _ _ n ty _ ← getLocalDecl x.fvarId! | unreachable! let ty := ty.abstractRange i xs; return .forallE n ty b .default +def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr := + let xs := params.map fun p => .fvar p.fvarId + mkForallFVars xs type |>.run {} + @[inline] def withLocalDecl (binderName : Name) (type : Expr) (binderInfo : BinderInfo) (k : Expr → InferTypeM α) : InferTypeM α := do let fvarId ← mkFreshFVarId withReader (fun lctx => lctx.mkLocalDecl fvarId binderName type binderInfo) do diff --git a/src/Lean/Compiler/LCNF/ToExpr.lean b/src/Lean/Compiler/LCNF/ToExpr.lean index 0583e23dea..515e6bb2d7 100644 --- a/src/Lean/Compiler/LCNF/ToExpr.lean +++ b/src/Lean/Compiler/LCNF/ToExpr.lean @@ -63,6 +63,9 @@ where else k +def run (x : M α) (offset : Nat := 0) (levelMap : LevelMap := {}) : α := + x |>.run offset |>.run' levelMap + end ToExpr open ToExpr @@ -89,6 +92,9 @@ partial def Code.toExprM (code : Code) : M Expr := do return mkAppN (mkConst `cases) (#[← c.discr.toExprM] ++ alts) def Code.toExpr (code : Code) : Expr := - code.toExprM |>.run 0 |>.run' {} + run code.toExprM + +def Decl.toExpr (decl : Decl) : Expr := + run do withParams decl.params do mkLambdaM decl.params (← decl.value.toExprM) end Lean.Compiler.LCNF \ No newline at end of file