From 35a81e80d6af024d58cdaf74a4a04ba36b4bf549 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 6 Sep 2020 08:54:38 -0700 Subject: [PATCH] refactor: move `PreDeclaration` (now `PreDefinition`) to its own module --- src/Lean/Elab/MutualDef.lean | 149 ++++--------------------------- src/Lean/Elab/PreDefinition.lean | 140 +++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 133 deletions(-) create mode 100644 src/Lean/Elab/PreDefinition.lean diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index aa51e31688..04ce9472af 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -7,6 +7,7 @@ import Lean.Meta.Closure import Lean.Meta.Check import Lean.Elab.Command import Lean.Elab.DefView +import Lean.Elab.PreDefinition namespace Lean namespace Elab @@ -204,14 +205,6 @@ letRecsToLift.forM fun toLift => do fnName ← getFunName fvarId letRecsToLift; throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously") -structure PreDeclaration := -(kind : DefKind) -(lparams : List Name) -(modifiers : Modifiers) -(declName : Name) -(type : Expr) -(value : Expr) - namespace MutualClosure /- A mapping from FVarId to Set of FVarIds. -/ @@ -485,14 +478,14 @@ e.replace fun e => match e with | _ => none | _ => none -def pushMain (preDecls : Array PreDeclaration) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr) - : TermElabM (Array PreDeclaration) := +def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr) + : TermElabM (Array PreDefinition) := mainHeaders.size.foldM - (fun i (preDecls : Array PreDeclaration) => do + (fun i (preDefs : Array PreDefinition) => do let header := mainHeaders.get! i; val ← mkLambdaFVars sectionVars (mainVals.get! i); type ← mkForallFVars sectionVars header.type; - pure $ preDecls.push { + pure $ preDefs.push { kind := header.kind, declName := header.declName, lparams := [], -- we set it later @@ -500,14 +493,14 @@ mainHeaders.size.foldM type := type, value := val }) - preDecls + preDefs -def pushLetRecs (preDecls : Array PreDeclaration) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDeclaration := +def pushLetRecs (preDefs : Array PreDefinition) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDefinition := letRecClosures.foldl - (fun (preDecls : Array PreDeclaration) (c : LetRecClosure) => + (fun (preDefs : Array PreDefinition) (c : LetRecClosure) => let type := Closure.mkForall c.localDecls c.toLift.type; let val := Closure.mkLambda c.localDecls c.toLift.val; - preDecls.push { + preDefs.push { kind := kind, declName := c.toLift.declName, lparams := [], -- we set it later @@ -515,7 +508,7 @@ letRecClosures.foldl type := type, value := val }) - preDecls + preDefs def getKindForLetRecs (mainHeaders : Array DefViewElabHeader) : DefKind := if mainHeaders.any fun h => h.kind.isTheorem then DefKind.«theorem» @@ -534,7 +527,7 @@ def getModifiersForLetRecs (mainHeaders : Array DefViewElabHeader) : Modifiers : - `letRecsToLift`: The let-rec's definitions that need to be lifted -/ def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) (mainVals : Array Expr) (letRecsToLift : List LetRecToLift) - : TermElabM (Array PreDeclaration) := do + : TermElabM (Array PreDefinition) := do -- Store in recFVarIds the fvarId of every function being defined by the mutual block. let mainFVarIds := mainFVars.map Expr.fvarId!; let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds; @@ -569,112 +562,6 @@ if h : 0 < headers.size then else [] -private def instantiateMVarsAtPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) := -preDecls.mapM fun preDecl => do - type ← instantiateMVars preDecl.type; - value ← instantiateMVars preDecl.value; - pure { preDecl with type := type, value := value } - -private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do -nextIdx ← get; -(e, nextIdx) ← liftM $ levelMVarToParam e nextIdx; -set nextIdx; -pure e - -private def levelMVarToParamPreDeclsAux (preDecls : Array PreDeclaration) : StateRefT Nat TermElabM (Array PreDeclaration) := -preDecls.mapM fun preDecl => do - type ← levelMVarToParamExpr preDecl.type; - value ← levelMVarToParamExpr preDecl.value; - pure { preDecl with type := type, value := value } - -private def levelMVarToParamPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) := -(levelMVarToParamPreDeclsAux preDecls).run' 1 - -private def collectLevelParamsExpr (e : Expr) : StateM CollectLevelParams.State Unit := do -modify fun s => collectLevelParams s e - -private def getLevelParamsPreDecls (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) := -let (_, s) := StateT.run - (preDecls.forM fun preDecl => do { - collectLevelParamsExpr preDecl.type; - collectLevelParamsExpr preDecl.value }) - {}; -match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with -| Except.error msg => throwError msg -| Except.ok levelParams => pure levelParams - -private def shareCommon (preDecls : Array PreDeclaration) : Array PreDeclaration := -let result : Std.ShareCommonM (Array PreDeclaration) := - preDecls.mapM fun preDecl => do { - type ← Std.withShareCommon preDecl.type; - value ← Std.withShareCommon preDecl.value; - pure { preDecl with type := type, value := value } - }; -result.run - -private def fixLevelParams (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDeclaration) := do -let preDecls := shareCommon preDecls; -lparams ← getLevelParamsPreDecls preDecls scopeLevelNames allUserLevelNames; -let us := lparams.map mkLevelParam; -let fixExpr (e : Expr) : Expr := - e.replace fun c => match c with - | Expr.const declName _ _ => if preDecls.any fun preDecl => preDecl.declName == declName then some $ Lean.mkConst declName us else none - | _ => none; -pure $ preDecls.map fun preDecl => - { preDecl with - type := fixExpr preDecl.type, - value := fixExpr preDecl.value, - lparams := lparams } - -private def applyAttributesOf (preDecls : Array PreDeclaration) (applicationTime : AttributeApplicationTime) : TermElabM Unit := do -preDecls.forM fun preDecl => applyAttributes preDecl.declName preDecl.modifiers.attrs applicationTime - -private def addAndCompileNonRec (preDecl : PreDeclaration) : TermElabM Unit := do -env ← getEnv; -let decl := - match preDecl.kind with - | DefKind.«example» => unreachable! - | DefKind.«theorem» => - Declaration.thmDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value } - | DefKind.«opaque» => - Declaration.opaqueDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value, - isUnsafe := preDecl.modifiers.isUnsafe } - | DefKind.«abbrev» => - Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value, - hints := ReducibilityHints.«abbrev», isUnsafe := preDecl.modifiers.isUnsafe } - | DefKind.«def» => - Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value, - hints := ReducibilityHints.regular (getMaxHeight env preDecl.value + 1), - isUnsafe := preDecl.modifiers.isUnsafe }; -ensureNoUnassignedMVars decl; -addDecl decl; -applyAttributesOf #[preDecl] AttributeApplicationTime.afterTypeChecking; -compileDecl decl; -applyAttributesOf #[preDecl] AttributeApplicationTime.afterCompilation; -pure () - -private def addAndCompileAsUnsafe (preDecls : Array PreDeclaration) : TermElabM Unit := do -let decl := Declaration.mutualDefnDecl $ preDecls.toList.map fun preDecl => { - name := preDecl.declName, - lparams := preDecl.lparams, - type := preDecl.type, - value := preDecl.value, - isUnsafe := true, - hints := ReducibilityHints.opaque - }; -ensureNoUnassignedMVars decl; -addDecl decl; -applyAttributesOf preDecls AttributeApplicationTime.afterTypeChecking; -compileDecl decl; -applyAttributesOf preDecls AttributeApplicationTime.afterCompilation; -pure () - -private def partitionNonRec (preDecls : Array PreDeclaration) : Array PreDeclaration × Array PreDeclaration := -preDecls.partition fun predDecl => - Option.isNone $ predDecl.value.find? fun c => match c with - | Expr.const declName _ _ => preDecls.any fun preDecl => preDecl.declName == declName - | _ => false - def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do scopeLevelNames ← getLevelNames; headers ← elabHeaders views; @@ -690,15 +577,11 @@ withFunLocalDecls headers fun funFVars => do letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift; checkLetRecsToLiftTypes funFVars letRecsToLift; withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do - preDecls ← MutualClosure.main vars headers funFVars values letRecsToLift; - preDecls ← levelMVarToParamPreDecls preDecls; - preDecls ← instantiateMVarsAtPreDecls preDecls; - preDecls ← fixLevelParams preDecls scopeLevelNames allUserLevelNames; - preDecls.forM fun preDecl => trace `Elab.definition.body fun _ => preDecl.declName ++ " : " ++ preDecl.type ++ " :=" ++ Format.line ++ preDecl.value; - let (preDeclsNonRec, preDecls) := partitionNonRec preDecls; - preDeclsNonRec.forM addAndCompileNonRec; - -- TODO - addAndCompileAsUnsafe preDecls + preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift; + preDefs ← levelMVarToParamPreDecls preDefs; + preDefs ← instantiateMVarsAtPreDecls preDefs; + preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames; + addPreDefinitions preDefs end Term namespace Command diff --git a/src/Lean/Elab/PreDefinition.lean b/src/Lean/Elab/PreDefinition.lean new file mode 100644 index 0000000000..dfa92ec88e --- /dev/null +++ b/src/Lean/Elab/PreDefinition.lean @@ -0,0 +1,140 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Elab.Term +import Lean.Elab.DefView + +namespace Lean +namespace Elab +open Meta +open Term + +/- +A (potentially recursive) definition. +The elaborator converts it into Kernel definitions using many different strategies. +-/ +structure PreDefinition := +(kind : DefKind) +(lparams : List Name) +(modifiers : Modifiers) +(declName : Name) +(type : Expr) +(value : Expr) + +def instantiateMVarsAtPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) := +preDefs.mapM fun preDecl => do + type ← instantiateMVars preDecl.type; + value ← instantiateMVars preDecl.value; + pure { preDecl with type := type, value := value } + +private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do +nextIdx ← get; +(e, nextIdx) ← liftM $ levelMVarToParam e nextIdx; +set nextIdx; +pure e + +private def levelMVarToParamPreDeclsAux (preDefs : Array PreDefinition) : StateRefT Nat TermElabM (Array PreDefinition) := +preDefs.mapM fun preDecl => do + type ← levelMVarToParamExpr preDecl.type; + value ← levelMVarToParamExpr preDecl.value; + pure { preDecl with type := type, value := value } + +def levelMVarToParamPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) := +(levelMVarToParamPreDeclsAux preDefs).run' 1 + +private def collectLevelParamsExpr (e : Expr) : StateM CollectLevelParams.State Unit := do +modify fun s => collectLevelParams s e + +private def getLevelParamsPreDecls (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) := +let (_, s) := StateT.run + (preDefs.forM fun preDecl => do { + collectLevelParamsExpr preDecl.type; + collectLevelParamsExpr preDecl.value }) + {}; +match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with +| Except.error msg => throwError msg +| Except.ok levelParams => pure levelParams + +private def shareCommon (preDefs : Array PreDefinition) : Array PreDefinition := +let result : Std.ShareCommonM (Array PreDefinition) := + preDefs.mapM fun preDecl => do { + type ← Std.withShareCommon preDecl.type; + value ← Std.withShareCommon preDecl.value; + pure { preDecl with type := type, value := value } + }; +result.run + +def fixLevelParams (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDefinition) := do +let preDefs := shareCommon preDefs; +lparams ← getLevelParamsPreDecls preDefs scopeLevelNames allUserLevelNames; +let us := lparams.map mkLevelParam; +let fixExpr (e : Expr) : Expr := + e.replace fun c => match c with + | Expr.const declName _ _ => if preDefs.any fun preDecl => preDecl.declName == declName then some $ Lean.mkConst declName us else none + | _ => none; +pure $ preDefs.map fun preDecl => + { preDecl with + type := fixExpr preDecl.type, + value := fixExpr preDecl.value, + lparams := lparams } + +private def applyAttributesOf (preDefs : Array PreDefinition) (applicationTime : AttributeApplicationTime) : TermElabM Unit := do +preDefs.forM fun preDecl => applyAttributes preDecl.declName preDecl.modifiers.attrs applicationTime + +private def addAndCompileNonRec (preDecl : PreDefinition) : TermElabM Unit := do +env ← getEnv; +let decl := + match preDecl.kind with + | DefKind.«example» => unreachable! + | DefKind.«theorem» => + Declaration.thmDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value } + | DefKind.«opaque» => + Declaration.opaqueDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value, + isUnsafe := preDecl.modifiers.isUnsafe } + | DefKind.«abbrev» => + Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value, + hints := ReducibilityHints.«abbrev», isUnsafe := preDecl.modifiers.isUnsafe } + | DefKind.«def» => + Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value, + hints := ReducibilityHints.regular (getMaxHeight env preDecl.value + 1), + isUnsafe := preDecl.modifiers.isUnsafe }; +ensureNoUnassignedMVars decl; +addDecl decl; +applyAttributesOf #[preDecl] AttributeApplicationTime.afterTypeChecking; +compileDecl decl; +applyAttributesOf #[preDecl] AttributeApplicationTime.afterCompilation; +pure () + +private def addAndCompileAsUnsafe (preDefs : Array PreDefinition) : TermElabM Unit := do +let decl := Declaration.mutualDefnDecl $ preDefs.toList.map fun preDecl => { + name := preDecl.declName, + lparams := preDecl.lparams, + type := preDecl.type, + value := preDecl.value, + isUnsafe := true, + hints := ReducibilityHints.opaque + }; +ensureNoUnassignedMVars decl; +addDecl decl; +applyAttributesOf preDefs AttributeApplicationTime.afterTypeChecking; +compileDecl decl; +applyAttributesOf preDefs AttributeApplicationTime.afterCompilation; +pure () + +private def partitionNonRec (preDefs : Array PreDefinition) : Array PreDefinition × Array PreDefinition := +preDefs.partition fun predDecl => + Option.isNone $ predDecl.value.find? fun c => match c with + | Expr.const declName _ _ => preDefs.any fun preDecl => preDecl.declName == declName + | _ => false + +def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := do +preDefs.forM fun preDecl => trace `Elab.definition.body fun _ => preDecl.declName ++ " : " ++ preDecl.type ++ " :=" ++ Format.line ++ preDecl.value; +let (preDefsNonRec, preDefs) := partitionNonRec preDefs; +preDefsNonRec.forM addAndCompileNonRec; +-- TODO +addAndCompileAsUnsafe preDefs + +end Elab +end Lean