feat: add LCNF/Check.lean
This commit is contained in:
parent
766afdd0bc
commit
bce7eadfbc
4 changed files with 179 additions and 4 deletions
|
|
@ -73,7 +73,11 @@ structure Decl where
|
|||
-/
|
||||
type : Expr
|
||||
/--
|
||||
The value of the declaration, usually changes as it progresses
|
||||
Parameters.
|
||||
-/
|
||||
params : Array Param
|
||||
/--
|
||||
The body of the declaration, usually changes as it progresses
|
||||
through compiler passes.
|
||||
-/
|
||||
value : Code
|
||||
|
|
|
|||
161
src/Lean/Compiler/LCNF/Check.lean
Normal file
161
src/Lean/Compiler/LCNF/Check.lean
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.LCNF.InferType
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
namespace Check
|
||||
open InferType
|
||||
|
||||
structure Context where
|
||||
/-- Join points that are in scope. -/
|
||||
jps : FVarIdSet := {}
|
||||
/-- Variables and local functions in scope -/
|
||||
vars : FVarIdSet := {}
|
||||
|
||||
structure State where
|
||||
/-- All free variables found -/
|
||||
all : FVarIdHashSet
|
||||
|
||||
abbrev CheckM := ReaderT Context $ StateRefT State InferTypeM
|
||||
|
||||
def checkFVar (fvarId : FVarId) : CheckM Unit :=
|
||||
unless (← read).vars.contains fvarId do
|
||||
throwError "invalid out of scope free variable {fvarId.name}"
|
||||
|
||||
def checkApp (f : Expr) (args : Array Expr) : CheckM Unit := do
|
||||
unless f.isConst || f.isFVar do
|
||||
throwError "unexpected function application, function must be a constant or free variable{indentExpr (mkAppN f args)}"
|
||||
if f.isFVar then
|
||||
checkFVar f.fvarId!
|
||||
let mut fType ← inferType f
|
||||
let mut j := 0
|
||||
for i in [:args.size] do
|
||||
let arg := args[i]!
|
||||
if fType.isAnyType then
|
||||
return ()
|
||||
fType := fType.headBeta
|
||||
let (d, b) ←
|
||||
match fType with
|
||||
| .forallE _ d b _ => pure (d, b)
|
||||
| _ =>
|
||||
fType := fType.instantiateRevRange j i args |>.headBeta
|
||||
match fType with
|
||||
| .forallE _ d b _ => j := i; pure (d, b)
|
||||
| _ =>
|
||||
if fType.isAnyType then return ()
|
||||
throwError "function expected at{indentExpr (mkAppN f args)}\narrow type expected{indentExpr fType}"
|
||||
let argType ← inferType arg
|
||||
let expectedType := d.instantiateRevRange j i args
|
||||
unless compatibleTypes argType expectedType do
|
||||
throwError "type mismatch at LCNF application{indentExpr (mkAppN f args)}\nargument {arg} has type{indentExpr argType}\nbut is expected to have type{indentExpr expectedType}"
|
||||
unless isTypeFormerType expectedType || expectedType.erased do
|
||||
unless arg.isFVar do
|
||||
throwError "invalid LCNF application{indentExpr (mkAppN f args)}\nargument{indentExpr arg}\nmust be a free variable"
|
||||
checkFVar arg.fvarId!
|
||||
fType := b
|
||||
|
||||
def checkExpr (e : Expr) : CheckM Unit :=
|
||||
match e with
|
||||
| .lit _ => pure ()
|
||||
| .app .. => checkApp e.getAppFn e.getAppArgs
|
||||
| .proj _ _ (.fvar fvarId) => checkFVar fvarId
|
||||
| .mdata _ (.fvar fvarId) => checkFVar fvarId
|
||||
| .const _ _ => pure () -- TODO: check number of universe level parameters
|
||||
| .fvar fvarId => checkFVar fvarId
|
||||
| _ => throwError "unexpected expression at LCNF{indentExpr e}"
|
||||
|
||||
def checkJpInScope (jp : FVarId) : CheckM Unit := do
|
||||
unless (← read).jps.contains jp do
|
||||
/-
|
||||
We cannot jump to join points defined out of the scope of a local function declaration.
|
||||
For example, the following is an invalid LCNF.
|
||||
```
|
||||
jp_1 := fun x => ... -- Some join point
|
||||
let f := fun y => -- Local function declaration.
|
||||
...
|
||||
jp_1 _x.n -- jump to a join point that is not in the scope of `f`.
|
||||
```
|
||||
-/
|
||||
throwError "invalid jump to out of scope join point"
|
||||
|
||||
def checkLetDecl (letDecl : LetDecl) : CheckM Unit := do
|
||||
checkExpr letDecl.value
|
||||
let valueType ← inferType letDecl.value
|
||||
unless compatibleTypes letDecl.type valueType do
|
||||
throwError "type mismatch at `{letDecl.binderName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr letDecl.type}"
|
||||
|
||||
def addFVarId (fvarId : FVarId) : CheckM Unit := do
|
||||
if (← get).all.contains fvarId then
|
||||
throwError "invalid LCNF, free variables are not unique `{fvarId.name}`"
|
||||
modify fun s => { s with all := s.all.insert fvarId }
|
||||
|
||||
@[inline] def withFVarId (fvarId : FVarId) (x : CheckM α) : CheckM α := do
|
||||
addFVarId fvarId
|
||||
withReader (fun ctx => { ctx with vars := ctx.vars.insert fvarId }) x
|
||||
|
||||
@[inline] def withJp (fvarId : FVarId) (x : CheckM α) : CheckM α := do
|
||||
addFVarId fvarId
|
||||
withReader (fun ctx => { ctx with jps := ctx.jps.insert fvarId }) x
|
||||
|
||||
@[inline] def withParams (params : Array Param) (x : CheckM α) : CheckM α := do
|
||||
params.forM (addFVarId ·.fvarId)
|
||||
withReader (fun ctx => { ctx with vars := params.foldl (init := ctx.vars) fun vars p => vars.insert p.fvarId })
|
||||
x
|
||||
|
||||
mutual
|
||||
|
||||
partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do
|
||||
let type ← withParams funDecl.params do
|
||||
mkForallParams funDecl.params (← check funDecl.value)
|
||||
unless compatibleTypes funDecl.type type do
|
||||
throwError "type mismatch at `{funDecl.binderName}`, value has type{indentExpr type}\nbut is expected to have type{indentExpr funDecl.type}"
|
||||
|
||||
partial def checkCases (c : Cases) : CheckM Expr := do
|
||||
let mut ctorNames : NameSet := {}
|
||||
let mut hasDefault := false
|
||||
checkFVar c.discr
|
||||
let discrType ← inferFVarType c.discr
|
||||
let .const declName _ := discrType.headBeta.getAppFn | throwError "unexpected LCNF discriminant type {discrType}"
|
||||
unless c.typeName == declName do
|
||||
throwError "invalid LCNF `{c.typeName}.casesOn`, discriminant has type{indentExpr discrType}"
|
||||
for alt in c.alts do
|
||||
let type ←
|
||||
match alt with
|
||||
| .default k => hasDefault := true; check k
|
||||
| .alt ctorName params k =>
|
||||
if ctorNames.contains ctorName then
|
||||
throwError "invalid LCNF `cases`, alternative `{ctorName}` occurs more than once"
|
||||
ctorNames := ctorNames.insert ctorName
|
||||
let .ctorInfo val ← getConstInfo ctorName | throwError "invalid LCNF `cases`, `{ctorName}` is not a constructor name"
|
||||
unless val.induct == c.typeName do
|
||||
throwError "invalid LCNF `cases`, `{ctorName}` is not a constructor of `{c.typeName}`"
|
||||
unless params.size == val.numFields do
|
||||
throwError "invalid LCNF `cases`, `{ctorName}` has # {val.numFields} fields, but alternative has # {params.size} alternatives"
|
||||
-- TODO: check whether the ctor field types as parameter types match.
|
||||
withParams params do check k
|
||||
unless compatibleTypes type c.resultType do
|
||||
throwError "type mismatch at LCNF `cases` alternative\nhas type{indentExpr type}\nbut is expected to have type{indentExpr c.resultType}"
|
||||
return c.resultType
|
||||
|
||||
partial def check (code : Code) : CheckM Expr := do
|
||||
match code with
|
||||
| .let decl k => checkLetDecl decl; withFVarId decl.fvarId do check k
|
||||
| .fun decl k =>
|
||||
-- Remark: local function declarations should not jump to out of scope join points
|
||||
withReader (fun ctx => { ctx with jps := {} }) do checkFunDecl decl
|
||||
withFVarId decl.fvarId do check k
|
||||
| .jp decl k => checkFunDecl decl; withJp decl.fvarId do check k
|
||||
| .cases c => checkCases c
|
||||
| .jmp fvarId args => checkJpInScope fvarId; checkApp (.fvar fvarId) args; code.inferType
|
||||
| .return fvarId => checkFVar fvarId; code.inferType
|
||||
| .unreach .. => code.inferType
|
||||
|
||||
end
|
||||
|
||||
end Check
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
@ -21,14 +21,18 @@ def getLocalDecl (fvarId : FVarId) : InferTypeM LocalDecl := do
|
|||
| some localDecl => return localDecl
|
||||
| none => LCNF.getLocalDecl fvarId
|
||||
|
||||
def mkForallFVars (xs : Array Expr) (b : Expr) : InferTypeM Expr :=
|
||||
let b := b.abstract xs
|
||||
def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
|
||||
let b := type.abstract xs
|
||||
xs.size.foldRevM (init := b) fun i b => do
|
||||
let x := xs[i]!
|
||||
let .cdecl _ _ n ty _ ← getLocalDecl x.fvarId! | unreachable!
|
||||
let ty := ty.abstractRange i xs;
|
||||
return .forallE n ty b .default
|
||||
|
||||
def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr :=
|
||||
let xs := params.map fun p => .fvar p.fvarId
|
||||
mkForallFVars xs type |>.run {}
|
||||
|
||||
@[inline] def withLocalDecl (binderName : Name) (type : Expr) (binderInfo : BinderInfo) (k : Expr → InferTypeM α) : InferTypeM α := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
withReader (fun lctx => lctx.mkLocalDecl fvarId binderName type binderInfo) do
|
||||
|
|
|
|||
|
|
@ -63,6 +63,9 @@ where
|
|||
else
|
||||
k
|
||||
|
||||
def run (x : M α) (offset : Nat := 0) (levelMap : LevelMap := {}) : α :=
|
||||
x |>.run offset |>.run' levelMap
|
||||
|
||||
end ToExpr
|
||||
|
||||
open ToExpr
|
||||
|
|
@ -89,6 +92,9 @@ partial def Code.toExprM (code : Code) : M Expr := do
|
|||
return mkAppN (mkConst `cases) (#[← c.discr.toExprM] ++ alts)
|
||||
|
||||
def Code.toExpr (code : Code) : Expr :=
|
||||
code.toExprM |>.run 0 |>.run' {}
|
||||
run code.toExprM
|
||||
|
||||
def Decl.toExpr (decl : Decl) : Expr :=
|
||||
run do withParams decl.params do mkLambdaM decl.params (← decl.value.toExprM)
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
Loading…
Add table
Reference in a new issue