diff --git a/src/Lean/Compiler/Check.lean b/src/Lean/Compiler/Check.lean index 0b48a6ca53..26164ec623 100644 --- a/src/Lean/Compiler/Check.lean +++ b/src/Lean/Compiler/Check.lean @@ -6,14 +6,15 @@ Authors: Leonardo de Moura import Lean.Compiler.InferType import Lean.Compiler.Util -namespace Lean.Compiler.InferType +namespace Lean.Compiler +open InferType /-! Type checker for LCNF expressions -/ -structure Config where - terminalCasesOnly : Bool := false +structure Check.Config where + terminalCasesOnly : Bool := true def lambdaBoundedTelescope (e : Expr) (n : Nat) (k : Array Expr → Expr → InferTypeM α) : InferTypeM α := go e n #[] @@ -26,7 +27,7 @@ where withLocalDecl n (d.instantiateRev xs) bi fun x => go b i (xs.push x) | _ => throwError "lambda expected" -partial def check (e : Expr) (cfg : Config := {}) : InferTypeM Unit := do +partial def check (e : Expr) (cfg : Check.Config := {}) : InferTypeM Unit := do checkBlock e #[] where checkBlock (e : Expr) (xs : Array Expr) : InferTypeM Unit := do @@ -95,4 +96,4 @@ where unless compatibleTypes expectedType altBodyType do throwError "type mismatch at LCNF `cases` alternative{indentExpr altBody}\nhas type{indentExpr altBodyType}\nbut is expected to have type{indentExpr expectedType}" -end Lean.Compiler.InferType \ No newline at end of file +end Lean.Compiler \ No newline at end of file diff --git a/src/Lean/Compiler/Decl.lean b/src/Lean/Compiler/Decl.lean index 44454408b2..06ec342d04 100644 --- a/src/Lean/Compiler/Decl.lean +++ b/src/Lean/Compiler/Decl.lean @@ -61,8 +61,8 @@ def toDecl (declName : Name) : CoreM Decl := do let value ← toLCNF value -- TODO: uncomment return { name := declName, type, value } -def Decl.check (decl : Decl) : CoreM Unit := do - InferType.check decl.value {} { lctx := {} } +def Decl.check (decl : Decl) (cfg : Check.Config := {}): CoreM Unit := do + Compiler.check decl.value cfg { lctx := {} } let valueType ← InferType.inferType decl.value { lctx := {} } unless compatibleTypes decl.type valueType do throwError "declaration type mismatch at `{decl.name}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr decl.type}" diff --git a/src/Lean/Compiler/InferType.lean b/src/Lean/Compiler/InferType.lean index 76c43cd8b7..13ccdff61f 100644 --- a/src/Lean/Compiler/InferType.lean +++ b/src/Lean/Compiler/InferType.lean @@ -155,6 +155,9 @@ instance : MonadInferType InferType.InferTypeM where export MonadInferType (inferType) +instance [MonadLift m n] [MonadInferType m] : MonadInferType n where + inferType e := liftM (inferType e : m _) + def getLevel [Monad m] [MonadInferType m] [MonadError m] (type : Expr) : m Level := do match (← inferType type) with | .sort u => return u diff --git a/src/Lean/Compiler/Main.lean b/src/Lean/Compiler/Main.lean index ad786bc16e..5095517b43 100644 --- a/src/Lean/Compiler/Main.lean +++ b/src/Lean/Compiler/Main.lean @@ -28,19 +28,18 @@ where let info ← getConstInfo declName Meta.isProp info.type <||> Meta.isTypeFormerType info.type -def checkpoint (step : Name) (decls : Array Decl) : CoreM Unit := do +def checkpoint (step : Name) (decls : Array Decl) (cfg : Check.Config := {}): CoreM Unit := do trace[Meta.debug] "After {step}" for decl in decls do withOptions (fun opts => opts.setBool `pp.motives.pi false) do trace[Meta.debug] "{decl.name} := {decl.value}" - decl.check + decl.check cfg def compile (declNames : Array Name) : CoreM Unit := do let declNames ← declNames.filterM shouldGenerateCode let decls ← declNames.mapM toDecl - checkpoint `init decls - -- TODO: uncomment - -- let decls ← decls.mapM (·.terminalCases) - -- checkpoint `terminalCases decls + checkpoint `init decls { terminalCasesOnly := false } + let decls ← decls.mapM (·.terminalCases) + checkpoint `terminalCases decls end Lean.Compiler diff --git a/src/Lean/Compiler/TerminalCases.lean b/src/Lean/Compiler/TerminalCases.lean index b1067004d5..31e2aed4ae 100644 --- a/src/Lean/Compiler/TerminalCases.lean +++ b/src/Lean/Compiler/TerminalCases.lean @@ -6,8 +6,7 @@ Authors: Leonardo de Moura import Lean.Meta.Check import Lean.Compiler.Util import Lean.Compiler.Decl - -#exit -- TODO: port file to new LCNF format +import Lean.Compiler.CompilerM namespace Lean.Compiler @@ -29,8 +28,8 @@ partial def visitAlt (e : Expr) (numParams : Nat) : M Expr := do partial def visitCases (casesInfo : CasesInfo) (cases : Expr) : M Expr := do let mut args := cases.getAppArgs if let some jp := (← read).jp? then - let .forallE _ _ b _ ← inferType' jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow, see `visitLet` - args ← liftMetaM <| updateMotive casesInfo args b + let .forallE _ _ b _ ← inferType jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow, see `visitLet` + args := casesInfo.updateResultingType args b for i in casesInfo.altsRange, numParams in casesInfo.altNumParams do args ← args.modifyM i (visitAlt · numParams) return mkAppN cases.getAppFn args @@ -47,34 +46,15 @@ partial def visitLet (e : Expr) (fvars : Array Expr) : M Expr := do let type := type.instantiateRev fvars let mut value := value.instantiateRev fvars if let some casesInfo ← isCasesApp? value then - let (bodyAbst, safeJp) ← withNewScope do + let bodyAbst ← withNewScope do let x ← mkLocalDecl binderName type let body ← visitLet body (fvars.push x) let body ← mkLetUsingScope body - let bodyType ← inferType body let bodyAbst := body.abstract #[x] - if (bodyType.abstract #[x]).hasLooseBVars then - /- - We cannot eliminate this nonterminal `cases` because the resulting type of the joinpoint - depends on `x`. We have to wait until we perform erasure to do it. - -/ - return (bodyAbst, false) - else if !(← liftMetaM <| Meta.isTypeCorrect body) then - /- - We cannot eliminate this nonterminal `cases` because the joinpoint is not type correct. - This can happen because we abstracted `x`. - We have to wait until we perform erasure to do it. - Remark: we can skip this test if we set `nonDep` properly. - -/ - return (bodyAbst, false) - else - return (bodyAbst, true) - if !safeJp then - return .letE binderName type value bodyAbst nonDep - else - let jp ← mkJpDecl (.lam binderName type bodyAbst .default) - withReader (fun _ => { jp? := some jp }) do - visitCases casesInfo value + return bodyAbst + let jp ← mkJpDecl (.lam binderName type bodyAbst .default) + withReader (fun _ => { jp? := some jp }) do + visitCases casesInfo value else if value.isLambda then value ← visitLambda value @@ -91,10 +71,10 @@ partial def visitLet (e : Expr) (fvars : Array Expr) : M Expr := do else return e | some jp => - let .forallE _ d _ _ ← inferType' jp | unreachable! + let .forallE _ d _ _ ← inferType jp | unreachable! if isLcUnreachable e then mkLcUnreachable d - else if (← isDefEq (← inferType e) d) then + else if compatibleTypes (← inferType e) d then let x ← mkAuxLetDecl e return mkApp jp x else if let some x := isLcCast? e then @@ -110,7 +90,7 @@ end end TerminalCases /-- -(Try to) ensure all `casesOn` and `matcher` applications are terminal. +Ensure all `casesOn` and `matcher` applications are terminal. -/ def Decl.terminalCases (decl : Decl) : CoreM Decl := do return { decl with value := (← TerminalCases.visitLambda decl.value |>.run {} |>.run' { nextIdx := (← getMaxLetVarIdx decl.value) + 1 }) } diff --git a/src/Lean/Compiler/Util.lean b/src/Lean/Compiler/Util.lean index e450bd8ddd..048ca754d3 100644 --- a/src/Lean/Compiler/Util.lean +++ b/src/Lean/Compiler/Util.lean @@ -89,6 +89,14 @@ def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do def CasesInfo.geNumDiscrs (casesInfo : CasesInfo) : Nat := casesInfo.discrsRange.stop - casesInfo.discrsRange.start +def CasesInfo.updateResultingType (casesInfo : CasesInfo) (casesArgs : Array Expr) (typeNew : Expr) : Array Expr := + casesArgs.modify casesInfo.motivePos fun motive => go motive +where + go (e : Expr) : Expr := + match e with + | .lam n b d bi => .lam n b (go d) bi + | _ => typeNew + def isCasesApp? (e : Expr) : CoreM (Option CasesInfo) := do let .const declName _ := e.getAppFn | return none if let some info ← getCasesInfo? declName then