refactor: move PreDeclaration (now PreDefinition) to its own module
This commit is contained in:
parent
d0993d07a1
commit
35a81e80d6
2 changed files with 156 additions and 133 deletions
|
|
@ -7,6 +7,7 @@ import Lean.Meta.Closure
|
|||
import Lean.Meta.Check
|
||||
import Lean.Elab.Command
|
||||
import Lean.Elab.DefView
|
||||
import Lean.Elab.PreDefinition
|
||||
|
||||
namespace Lean
|
||||
namespace Elab
|
||||
|
|
@ -204,14 +205,6 @@ letRecsToLift.forM fun toLift => do
|
|||
fnName ← getFunName fvarId letRecsToLift;
|
||||
throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously")
|
||||
|
||||
structure PreDeclaration :=
|
||||
(kind : DefKind)
|
||||
(lparams : List Name)
|
||||
(modifiers : Modifiers)
|
||||
(declName : Name)
|
||||
(type : Expr)
|
||||
(value : Expr)
|
||||
|
||||
namespace MutualClosure
|
||||
|
||||
/- A mapping from FVarId to Set of FVarIds. -/
|
||||
|
|
@ -485,14 +478,14 @@ e.replace fun e => match e with
|
|||
| _ => none
|
||||
| _ => none
|
||||
|
||||
def pushMain (preDecls : Array PreDeclaration) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
|
||||
: TermElabM (Array PreDeclaration) :=
|
||||
def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
|
||||
: TermElabM (Array PreDefinition) :=
|
||||
mainHeaders.size.foldM
|
||||
(fun i (preDecls : Array PreDeclaration) => do
|
||||
(fun i (preDefs : Array PreDefinition) => do
|
||||
let header := mainHeaders.get! i;
|
||||
val ← mkLambdaFVars sectionVars (mainVals.get! i);
|
||||
type ← mkForallFVars sectionVars header.type;
|
||||
pure $ preDecls.push {
|
||||
pure $ preDefs.push {
|
||||
kind := header.kind,
|
||||
declName := header.declName,
|
||||
lparams := [], -- we set it later
|
||||
|
|
@ -500,14 +493,14 @@ mainHeaders.size.foldM
|
|||
type := type,
|
||||
value := val
|
||||
})
|
||||
preDecls
|
||||
preDefs
|
||||
|
||||
def pushLetRecs (preDecls : Array PreDeclaration) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDeclaration :=
|
||||
def pushLetRecs (preDefs : Array PreDefinition) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDefinition :=
|
||||
letRecClosures.foldl
|
||||
(fun (preDecls : Array PreDeclaration) (c : LetRecClosure) =>
|
||||
(fun (preDefs : Array PreDefinition) (c : LetRecClosure) =>
|
||||
let type := Closure.mkForall c.localDecls c.toLift.type;
|
||||
let val := Closure.mkLambda c.localDecls c.toLift.val;
|
||||
preDecls.push {
|
||||
preDefs.push {
|
||||
kind := kind,
|
||||
declName := c.toLift.declName,
|
||||
lparams := [], -- we set it later
|
||||
|
|
@ -515,7 +508,7 @@ letRecClosures.foldl
|
|||
type := type,
|
||||
value := val
|
||||
})
|
||||
preDecls
|
||||
preDefs
|
||||
|
||||
def getKindForLetRecs (mainHeaders : Array DefViewElabHeader) : DefKind :=
|
||||
if mainHeaders.any fun h => h.kind.isTheorem then DefKind.«theorem»
|
||||
|
|
@ -534,7 +527,7 @@ def getModifiersForLetRecs (mainHeaders : Array DefViewElabHeader) : Modifiers :
|
|||
- `letRecsToLift`: The let-rec's definitions that need to be lifted
|
||||
-/
|
||||
def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) (mainVals : Array Expr) (letRecsToLift : List LetRecToLift)
|
||||
: TermElabM (Array PreDeclaration) := do
|
||||
: TermElabM (Array PreDefinition) := do
|
||||
-- Store in recFVarIds the fvarId of every function being defined by the mutual block.
|
||||
let mainFVarIds := mainFVars.map Expr.fvarId!;
|
||||
let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds;
|
||||
|
|
@ -569,112 +562,6 @@ if h : 0 < headers.size then
|
|||
else
|
||||
[]
|
||||
|
||||
private def instantiateMVarsAtPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) :=
|
||||
preDecls.mapM fun preDecl => do
|
||||
type ← instantiateMVars preDecl.type;
|
||||
value ← instantiateMVars preDecl.value;
|
||||
pure { preDecl with type := type, value := value }
|
||||
|
||||
private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do
|
||||
nextIdx ← get;
|
||||
(e, nextIdx) ← liftM $ levelMVarToParam e nextIdx;
|
||||
set nextIdx;
|
||||
pure e
|
||||
|
||||
private def levelMVarToParamPreDeclsAux (preDecls : Array PreDeclaration) : StateRefT Nat TermElabM (Array PreDeclaration) :=
|
||||
preDecls.mapM fun preDecl => do
|
||||
type ← levelMVarToParamExpr preDecl.type;
|
||||
value ← levelMVarToParamExpr preDecl.value;
|
||||
pure { preDecl with type := type, value := value }
|
||||
|
||||
private def levelMVarToParamPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) :=
|
||||
(levelMVarToParamPreDeclsAux preDecls).run' 1
|
||||
|
||||
private def collectLevelParamsExpr (e : Expr) : StateM CollectLevelParams.State Unit := do
|
||||
modify fun s => collectLevelParams s e
|
||||
|
||||
private def getLevelParamsPreDecls (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) :=
|
||||
let (_, s) := StateT.run
|
||||
(preDecls.forM fun preDecl => do {
|
||||
collectLevelParamsExpr preDecl.type;
|
||||
collectLevelParamsExpr preDecl.value })
|
||||
{};
|
||||
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
|
||||
| Except.error msg => throwError msg
|
||||
| Except.ok levelParams => pure levelParams
|
||||
|
||||
private def shareCommon (preDecls : Array PreDeclaration) : Array PreDeclaration :=
|
||||
let result : Std.ShareCommonM (Array PreDeclaration) :=
|
||||
preDecls.mapM fun preDecl => do {
|
||||
type ← Std.withShareCommon preDecl.type;
|
||||
value ← Std.withShareCommon preDecl.value;
|
||||
pure { preDecl with type := type, value := value }
|
||||
};
|
||||
result.run
|
||||
|
||||
private def fixLevelParams (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDeclaration) := do
|
||||
let preDecls := shareCommon preDecls;
|
||||
lparams ← getLevelParamsPreDecls preDecls scopeLevelNames allUserLevelNames;
|
||||
let us := lparams.map mkLevelParam;
|
||||
let fixExpr (e : Expr) : Expr :=
|
||||
e.replace fun c => match c with
|
||||
| Expr.const declName _ _ => if preDecls.any fun preDecl => preDecl.declName == declName then some $ Lean.mkConst declName us else none
|
||||
| _ => none;
|
||||
pure $ preDecls.map fun preDecl =>
|
||||
{ preDecl with
|
||||
type := fixExpr preDecl.type,
|
||||
value := fixExpr preDecl.value,
|
||||
lparams := lparams }
|
||||
|
||||
private def applyAttributesOf (preDecls : Array PreDeclaration) (applicationTime : AttributeApplicationTime) : TermElabM Unit := do
|
||||
preDecls.forM fun preDecl => applyAttributes preDecl.declName preDecl.modifiers.attrs applicationTime
|
||||
|
||||
private def addAndCompileNonRec (preDecl : PreDeclaration) : TermElabM Unit := do
|
||||
env ← getEnv;
|
||||
let decl :=
|
||||
match preDecl.kind with
|
||||
| DefKind.«example» => unreachable!
|
||||
| DefKind.«theorem» =>
|
||||
Declaration.thmDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value }
|
||||
| DefKind.«opaque» =>
|
||||
Declaration.opaqueDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
|
||||
isUnsafe := preDecl.modifiers.isUnsafe }
|
||||
| DefKind.«abbrev» =>
|
||||
Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
|
||||
hints := ReducibilityHints.«abbrev», isUnsafe := preDecl.modifiers.isUnsafe }
|
||||
| DefKind.«def» =>
|
||||
Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
|
||||
hints := ReducibilityHints.regular (getMaxHeight env preDecl.value + 1),
|
||||
isUnsafe := preDecl.modifiers.isUnsafe };
|
||||
ensureNoUnassignedMVars decl;
|
||||
addDecl decl;
|
||||
applyAttributesOf #[preDecl] AttributeApplicationTime.afterTypeChecking;
|
||||
compileDecl decl;
|
||||
applyAttributesOf #[preDecl] AttributeApplicationTime.afterCompilation;
|
||||
pure ()
|
||||
|
||||
private def addAndCompileAsUnsafe (preDecls : Array PreDeclaration) : TermElabM Unit := do
|
||||
let decl := Declaration.mutualDefnDecl $ preDecls.toList.map fun preDecl => {
|
||||
name := preDecl.declName,
|
||||
lparams := preDecl.lparams,
|
||||
type := preDecl.type,
|
||||
value := preDecl.value,
|
||||
isUnsafe := true,
|
||||
hints := ReducibilityHints.opaque
|
||||
};
|
||||
ensureNoUnassignedMVars decl;
|
||||
addDecl decl;
|
||||
applyAttributesOf preDecls AttributeApplicationTime.afterTypeChecking;
|
||||
compileDecl decl;
|
||||
applyAttributesOf preDecls AttributeApplicationTime.afterCompilation;
|
||||
pure ()
|
||||
|
||||
private def partitionNonRec (preDecls : Array PreDeclaration) : Array PreDeclaration × Array PreDeclaration :=
|
||||
preDecls.partition fun predDecl =>
|
||||
Option.isNone $ predDecl.value.find? fun c => match c with
|
||||
| Expr.const declName _ _ => preDecls.any fun preDecl => preDecl.declName == declName
|
||||
| _ => false
|
||||
|
||||
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
|
||||
scopeLevelNames ← getLevelNames;
|
||||
headers ← elabHeaders views;
|
||||
|
|
@ -690,15 +577,11 @@ withFunLocalDecls headers fun funFVars => do
|
|||
letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift;
|
||||
checkLetRecsToLiftTypes funFVars letRecsToLift;
|
||||
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
|
||||
preDecls ← MutualClosure.main vars headers funFVars values letRecsToLift;
|
||||
preDecls ← levelMVarToParamPreDecls preDecls;
|
||||
preDecls ← instantiateMVarsAtPreDecls preDecls;
|
||||
preDecls ← fixLevelParams preDecls scopeLevelNames allUserLevelNames;
|
||||
preDecls.forM fun preDecl => trace `Elab.definition.body fun _ => preDecl.declName ++ " : " ++ preDecl.type ++ " :=" ++ Format.line ++ preDecl.value;
|
||||
let (preDeclsNonRec, preDecls) := partitionNonRec preDecls;
|
||||
preDeclsNonRec.forM addAndCompileNonRec;
|
||||
-- TODO
|
||||
addAndCompileAsUnsafe preDecls
|
||||
preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift;
|
||||
preDefs ← levelMVarToParamPreDecls preDefs;
|
||||
preDefs ← instantiateMVarsAtPreDecls preDefs;
|
||||
preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames;
|
||||
addPreDefinitions preDefs
|
||||
|
||||
end Term
|
||||
namespace Command
|
||||
|
|
|
|||
140
src/Lean/Elab/PreDefinition.lean
Normal file
140
src/Lean/Elab/PreDefinition.lean
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Elab.Term
|
||||
import Lean.Elab.DefView
|
||||
|
||||
namespace Lean
|
||||
namespace Elab
|
||||
open Meta
|
||||
open Term
|
||||
|
||||
/-
|
||||
A (potentially recursive) definition.
|
||||
The elaborator converts it into Kernel definitions using many different strategies.
|
||||
-/
|
||||
structure PreDefinition :=
|
||||
(kind : DefKind)
|
||||
(lparams : List Name)
|
||||
(modifiers : Modifiers)
|
||||
(declName : Name)
|
||||
(type : Expr)
|
||||
(value : Expr)
|
||||
|
||||
def instantiateMVarsAtPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) :=
|
||||
preDefs.mapM fun preDecl => do
|
||||
type ← instantiateMVars preDecl.type;
|
||||
value ← instantiateMVars preDecl.value;
|
||||
pure { preDecl with type := type, value := value }
|
||||
|
||||
private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do
|
||||
nextIdx ← get;
|
||||
(e, nextIdx) ← liftM $ levelMVarToParam e nextIdx;
|
||||
set nextIdx;
|
||||
pure e
|
||||
|
||||
private def levelMVarToParamPreDeclsAux (preDefs : Array PreDefinition) : StateRefT Nat TermElabM (Array PreDefinition) :=
|
||||
preDefs.mapM fun preDecl => do
|
||||
type ← levelMVarToParamExpr preDecl.type;
|
||||
value ← levelMVarToParamExpr preDecl.value;
|
||||
pure { preDecl with type := type, value := value }
|
||||
|
||||
def levelMVarToParamPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) :=
|
||||
(levelMVarToParamPreDeclsAux preDefs).run' 1
|
||||
|
||||
private def collectLevelParamsExpr (e : Expr) : StateM CollectLevelParams.State Unit := do
|
||||
modify fun s => collectLevelParams s e
|
||||
|
||||
private def getLevelParamsPreDecls (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) :=
|
||||
let (_, s) := StateT.run
|
||||
(preDefs.forM fun preDecl => do {
|
||||
collectLevelParamsExpr preDecl.type;
|
||||
collectLevelParamsExpr preDecl.value })
|
||||
{};
|
||||
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
|
||||
| Except.error msg => throwError msg
|
||||
| Except.ok levelParams => pure levelParams
|
||||
|
||||
private def shareCommon (preDefs : Array PreDefinition) : Array PreDefinition :=
|
||||
let result : Std.ShareCommonM (Array PreDefinition) :=
|
||||
preDefs.mapM fun preDecl => do {
|
||||
type ← Std.withShareCommon preDecl.type;
|
||||
value ← Std.withShareCommon preDecl.value;
|
||||
pure { preDecl with type := type, value := value }
|
||||
};
|
||||
result.run
|
||||
|
||||
def fixLevelParams (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDefinition) := do
|
||||
let preDefs := shareCommon preDefs;
|
||||
lparams ← getLevelParamsPreDecls preDefs scopeLevelNames allUserLevelNames;
|
||||
let us := lparams.map mkLevelParam;
|
||||
let fixExpr (e : Expr) : Expr :=
|
||||
e.replace fun c => match c with
|
||||
| Expr.const declName _ _ => if preDefs.any fun preDecl => preDecl.declName == declName then some $ Lean.mkConst declName us else none
|
||||
| _ => none;
|
||||
pure $ preDefs.map fun preDecl =>
|
||||
{ preDecl with
|
||||
type := fixExpr preDecl.type,
|
||||
value := fixExpr preDecl.value,
|
||||
lparams := lparams }
|
||||
|
||||
private def applyAttributesOf (preDefs : Array PreDefinition) (applicationTime : AttributeApplicationTime) : TermElabM Unit := do
|
||||
preDefs.forM fun preDecl => applyAttributes preDecl.declName preDecl.modifiers.attrs applicationTime
|
||||
|
||||
private def addAndCompileNonRec (preDecl : PreDefinition) : TermElabM Unit := do
|
||||
env ← getEnv;
|
||||
let decl :=
|
||||
match preDecl.kind with
|
||||
| DefKind.«example» => unreachable!
|
||||
| DefKind.«theorem» =>
|
||||
Declaration.thmDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value }
|
||||
| DefKind.«opaque» =>
|
||||
Declaration.opaqueDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
|
||||
isUnsafe := preDecl.modifiers.isUnsafe }
|
||||
| DefKind.«abbrev» =>
|
||||
Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
|
||||
hints := ReducibilityHints.«abbrev», isUnsafe := preDecl.modifiers.isUnsafe }
|
||||
| DefKind.«def» =>
|
||||
Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
|
||||
hints := ReducibilityHints.regular (getMaxHeight env preDecl.value + 1),
|
||||
isUnsafe := preDecl.modifiers.isUnsafe };
|
||||
ensureNoUnassignedMVars decl;
|
||||
addDecl decl;
|
||||
applyAttributesOf #[preDecl] AttributeApplicationTime.afterTypeChecking;
|
||||
compileDecl decl;
|
||||
applyAttributesOf #[preDecl] AttributeApplicationTime.afterCompilation;
|
||||
pure ()
|
||||
|
||||
private def addAndCompileAsUnsafe (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
let decl := Declaration.mutualDefnDecl $ preDefs.toList.map fun preDecl => {
|
||||
name := preDecl.declName,
|
||||
lparams := preDecl.lparams,
|
||||
type := preDecl.type,
|
||||
value := preDecl.value,
|
||||
isUnsafe := true,
|
||||
hints := ReducibilityHints.opaque
|
||||
};
|
||||
ensureNoUnassignedMVars decl;
|
||||
addDecl decl;
|
||||
applyAttributesOf preDefs AttributeApplicationTime.afterTypeChecking;
|
||||
compileDecl decl;
|
||||
applyAttributesOf preDefs AttributeApplicationTime.afterCompilation;
|
||||
pure ()
|
||||
|
||||
private def partitionNonRec (preDefs : Array PreDefinition) : Array PreDefinition × Array PreDefinition :=
|
||||
preDefs.partition fun predDecl =>
|
||||
Option.isNone $ predDecl.value.find? fun c => match c with
|
||||
| Expr.const declName _ _ => preDefs.any fun preDecl => preDecl.declName == declName
|
||||
| _ => false
|
||||
|
||||
def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
preDefs.forM fun preDecl => trace `Elab.definition.body fun _ => preDecl.declName ++ " : " ++ preDecl.type ++ " :=" ++ Format.line ++ preDecl.value;
|
||||
let (preDefsNonRec, preDefs) := partitionNonRec preDefs;
|
||||
preDefsNonRec.forM addAndCompileNonRec;
|
||||
-- TODO
|
||||
addAndCompileAsUnsafe preDefs
|
||||
|
||||
end Elab
|
||||
end Lean
|
||||
Loading…
Add table
Reference in a new issue