refactor: move PreDeclaration (now PreDefinition) to its own module

This commit is contained in:
Leonardo de Moura 2020-09-06 08:54:38 -07:00
parent d0993d07a1
commit 35a81e80d6
2 changed files with 156 additions and 133 deletions

View file

@ -7,6 +7,7 @@ import Lean.Meta.Closure
import Lean.Meta.Check
import Lean.Elab.Command
import Lean.Elab.DefView
import Lean.Elab.PreDefinition
namespace Lean
namespace Elab
@ -204,14 +205,6 @@ letRecsToLift.forM fun toLift => do
fnName ← getFunName fvarId letRecsToLift;
throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously")
structure PreDeclaration :=
(kind : DefKind)
(lparams : List Name)
(modifiers : Modifiers)
(declName : Name)
(type : Expr)
(value : Expr)
namespace MutualClosure
/- A mapping from FVarId to Set of FVarIds. -/
@ -485,14 +478,14 @@ e.replace fun e => match e with
| _ => none
| _ => none
def pushMain (preDecls : Array PreDeclaration) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
: TermElabM (Array PreDeclaration) :=
def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
: TermElabM (Array PreDefinition) :=
mainHeaders.size.foldM
(fun i (preDecls : Array PreDeclaration) => do
(fun i (preDefs : Array PreDefinition) => do
let header := mainHeaders.get! i;
val ← mkLambdaFVars sectionVars (mainVals.get! i);
type ← mkForallFVars sectionVars header.type;
pure $ preDecls.push {
pure $ preDefs.push {
kind := header.kind,
declName := header.declName,
lparams := [], -- we set it later
@ -500,14 +493,14 @@ mainHeaders.size.foldM
type := type,
value := val
})
preDecls
preDefs
def pushLetRecs (preDecls : Array PreDeclaration) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDeclaration :=
def pushLetRecs (preDefs : Array PreDefinition) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDefinition :=
letRecClosures.foldl
(fun (preDecls : Array PreDeclaration) (c : LetRecClosure) =>
(fun (preDefs : Array PreDefinition) (c : LetRecClosure) =>
let type := Closure.mkForall c.localDecls c.toLift.type;
let val := Closure.mkLambda c.localDecls c.toLift.val;
preDecls.push {
preDefs.push {
kind := kind,
declName := c.toLift.declName,
lparams := [], -- we set it later
@ -515,7 +508,7 @@ letRecClosures.foldl
type := type,
value := val
})
preDecls
preDefs
def getKindForLetRecs (mainHeaders : Array DefViewElabHeader) : DefKind :=
if mainHeaders.any fun h => h.kind.isTheorem then DefKind.«theorem»
@ -534,7 +527,7 @@ def getModifiersForLetRecs (mainHeaders : Array DefViewElabHeader) : Modifiers :
- `letRecsToLift`: The let-rec's definitions that need to be lifted
-/
def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) (mainVals : Array Expr) (letRecsToLift : List LetRecToLift)
: TermElabM (Array PreDeclaration) := do
: TermElabM (Array PreDefinition) := do
-- Store in recFVarIds the fvarId of every function being defined by the mutual block.
let mainFVarIds := mainFVars.map Expr.fvarId!;
let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds;
@ -569,112 +562,6 @@ if h : 0 < headers.size then
else
[]
private def instantiateMVarsAtPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) :=
preDecls.mapM fun preDecl => do
type ← instantiateMVars preDecl.type;
value ← instantiateMVars preDecl.value;
pure { preDecl with type := type, value := value }
private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do
nextIdx ← get;
(e, nextIdx) ← liftM $ levelMVarToParam e nextIdx;
set nextIdx;
pure e
private def levelMVarToParamPreDeclsAux (preDecls : Array PreDeclaration) : StateRefT Nat TermElabM (Array PreDeclaration) :=
preDecls.mapM fun preDecl => do
type ← levelMVarToParamExpr preDecl.type;
value ← levelMVarToParamExpr preDecl.value;
pure { preDecl with type := type, value := value }
private def levelMVarToParamPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) :=
(levelMVarToParamPreDeclsAux preDecls).run' 1
private def collectLevelParamsExpr (e : Expr) : StateM CollectLevelParams.State Unit := do
modify fun s => collectLevelParams s e
private def getLevelParamsPreDecls (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) :=
let (_, s) := StateT.run
(preDecls.forM fun preDecl => do {
collectLevelParamsExpr preDecl.type;
collectLevelParamsExpr preDecl.value })
{};
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
| Except.error msg => throwError msg
| Except.ok levelParams => pure levelParams
private def shareCommon (preDecls : Array PreDeclaration) : Array PreDeclaration :=
let result : Std.ShareCommonM (Array PreDeclaration) :=
preDecls.mapM fun preDecl => do {
type ← Std.withShareCommon preDecl.type;
value ← Std.withShareCommon preDecl.value;
pure { preDecl with type := type, value := value }
};
result.run
private def fixLevelParams (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDeclaration) := do
let preDecls := shareCommon preDecls;
lparams ← getLevelParamsPreDecls preDecls scopeLevelNames allUserLevelNames;
let us := lparams.map mkLevelParam;
let fixExpr (e : Expr) : Expr :=
e.replace fun c => match c with
| Expr.const declName _ _ => if preDecls.any fun preDecl => preDecl.declName == declName then some $ Lean.mkConst declName us else none
| _ => none;
pure $ preDecls.map fun preDecl =>
{ preDecl with
type := fixExpr preDecl.type,
value := fixExpr preDecl.value,
lparams := lparams }
private def applyAttributesOf (preDecls : Array PreDeclaration) (applicationTime : AttributeApplicationTime) : TermElabM Unit := do
preDecls.forM fun preDecl => applyAttributes preDecl.declName preDecl.modifiers.attrs applicationTime
private def addAndCompileNonRec (preDecl : PreDeclaration) : TermElabM Unit := do
env ← getEnv;
let decl :=
match preDecl.kind with
| DefKind.«example» => unreachable!
| DefKind.«theorem» =>
Declaration.thmDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value }
| DefKind.«opaque» =>
Declaration.opaqueDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
isUnsafe := preDecl.modifiers.isUnsafe }
| DefKind.«abbrev» =>
Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
hints := ReducibilityHints.«abbrev», isUnsafe := preDecl.modifiers.isUnsafe }
| DefKind.«def» =>
Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
hints := ReducibilityHints.regular (getMaxHeight env preDecl.value + 1),
isUnsafe := preDecl.modifiers.isUnsafe };
ensureNoUnassignedMVars decl;
addDecl decl;
applyAttributesOf #[preDecl] AttributeApplicationTime.afterTypeChecking;
compileDecl decl;
applyAttributesOf #[preDecl] AttributeApplicationTime.afterCompilation;
pure ()
private def addAndCompileAsUnsafe (preDecls : Array PreDeclaration) : TermElabM Unit := do
let decl := Declaration.mutualDefnDecl $ preDecls.toList.map fun preDecl => {
name := preDecl.declName,
lparams := preDecl.lparams,
type := preDecl.type,
value := preDecl.value,
isUnsafe := true,
hints := ReducibilityHints.opaque
};
ensureNoUnassignedMVars decl;
addDecl decl;
applyAttributesOf preDecls AttributeApplicationTime.afterTypeChecking;
compileDecl decl;
applyAttributesOf preDecls AttributeApplicationTime.afterCompilation;
pure ()
private def partitionNonRec (preDecls : Array PreDeclaration) : Array PreDeclaration × Array PreDeclaration :=
preDecls.partition fun predDecl =>
Option.isNone $ predDecl.value.find? fun c => match c with
| Expr.const declName _ _ => preDecls.any fun preDecl => preDecl.declName == declName
| _ => false
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
scopeLevelNames ← getLevelNames;
headers ← elabHeaders views;
@ -690,15 +577,11 @@ withFunLocalDecls headers fun funFVars => do
letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift;
checkLetRecsToLiftTypes funFVars letRecsToLift;
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
preDecls ← MutualClosure.main vars headers funFVars values letRecsToLift;
preDecls ← levelMVarToParamPreDecls preDecls;
preDecls ← instantiateMVarsAtPreDecls preDecls;
preDecls ← fixLevelParams preDecls scopeLevelNames allUserLevelNames;
preDecls.forM fun preDecl => trace `Elab.definition.body fun _ => preDecl.declName ++ " : " ++ preDecl.type ++ " :=" ++ Format.line ++ preDecl.value;
let (preDeclsNonRec, preDecls) := partitionNonRec preDecls;
preDeclsNonRec.forM addAndCompileNonRec;
-- TODO
addAndCompileAsUnsafe preDecls
preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift;
preDefs ← levelMVarToParamPreDecls preDefs;
preDefs ← instantiateMVarsAtPreDecls preDefs;
preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames;
addPreDefinitions preDefs
end Term
namespace Command

View file

@ -0,0 +1,140 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Elab.Term
import Lean.Elab.DefView
namespace Lean
namespace Elab
open Meta
open Term
/-
A (potentially recursive) definition.
The elaborator converts it into Kernel definitions using many different strategies.
-/
structure PreDefinition :=
(kind : DefKind)
(lparams : List Name)
(modifiers : Modifiers)
(declName : Name)
(type : Expr)
(value : Expr)
def instantiateMVarsAtPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) :=
preDefs.mapM fun preDecl => do
type ← instantiateMVars preDecl.type;
value ← instantiateMVars preDecl.value;
pure { preDecl with type := type, value := value }
private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do
nextIdx ← get;
(e, nextIdx) ← liftM $ levelMVarToParam e nextIdx;
set nextIdx;
pure e
private def levelMVarToParamPreDeclsAux (preDefs : Array PreDefinition) : StateRefT Nat TermElabM (Array PreDefinition) :=
preDefs.mapM fun preDecl => do
type ← levelMVarToParamExpr preDecl.type;
value ← levelMVarToParamExpr preDecl.value;
pure { preDecl with type := type, value := value }
def levelMVarToParamPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) :=
(levelMVarToParamPreDeclsAux preDefs).run' 1
private def collectLevelParamsExpr (e : Expr) : StateM CollectLevelParams.State Unit := do
modify fun s => collectLevelParams s e
private def getLevelParamsPreDecls (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) :=
let (_, s) := StateT.run
(preDefs.forM fun preDecl => do {
collectLevelParamsExpr preDecl.type;
collectLevelParamsExpr preDecl.value })
{};
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
| Except.error msg => throwError msg
| Except.ok levelParams => pure levelParams
private def shareCommon (preDefs : Array PreDefinition) : Array PreDefinition :=
let result : Std.ShareCommonM (Array PreDefinition) :=
preDefs.mapM fun preDecl => do {
type ← Std.withShareCommon preDecl.type;
value ← Std.withShareCommon preDecl.value;
pure { preDecl with type := type, value := value }
};
result.run
def fixLevelParams (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDefinition) := do
let preDefs := shareCommon preDefs;
lparams ← getLevelParamsPreDecls preDefs scopeLevelNames allUserLevelNames;
let us := lparams.map mkLevelParam;
let fixExpr (e : Expr) : Expr :=
e.replace fun c => match c with
| Expr.const declName _ _ => if preDefs.any fun preDecl => preDecl.declName == declName then some $ Lean.mkConst declName us else none
| _ => none;
pure $ preDefs.map fun preDecl =>
{ preDecl with
type := fixExpr preDecl.type,
value := fixExpr preDecl.value,
lparams := lparams }
private def applyAttributesOf (preDefs : Array PreDefinition) (applicationTime : AttributeApplicationTime) : TermElabM Unit := do
preDefs.forM fun preDecl => applyAttributes preDecl.declName preDecl.modifiers.attrs applicationTime
private def addAndCompileNonRec (preDecl : PreDefinition) : TermElabM Unit := do
env ← getEnv;
let decl :=
match preDecl.kind with
| DefKind.«example» => unreachable!
| DefKind.«theorem» =>
Declaration.thmDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value }
| DefKind.«opaque» =>
Declaration.opaqueDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
isUnsafe := preDecl.modifiers.isUnsafe }
| DefKind.«abbrev» =>
Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
hints := ReducibilityHints.«abbrev», isUnsafe := preDecl.modifiers.isUnsafe }
| DefKind.«def» =>
Declaration.defnDecl { name := preDecl.declName, lparams := preDecl.lparams, type := preDecl.type, value := preDecl.value,
hints := ReducibilityHints.regular (getMaxHeight env preDecl.value + 1),
isUnsafe := preDecl.modifiers.isUnsafe };
ensureNoUnassignedMVars decl;
addDecl decl;
applyAttributesOf #[preDecl] AttributeApplicationTime.afterTypeChecking;
compileDecl decl;
applyAttributesOf #[preDecl] AttributeApplicationTime.afterCompilation;
pure ()
private def addAndCompileAsUnsafe (preDefs : Array PreDefinition) : TermElabM Unit := do
let decl := Declaration.mutualDefnDecl $ preDefs.toList.map fun preDecl => {
name := preDecl.declName,
lparams := preDecl.lparams,
type := preDecl.type,
value := preDecl.value,
isUnsafe := true,
hints := ReducibilityHints.opaque
};
ensureNoUnassignedMVars decl;
addDecl decl;
applyAttributesOf preDefs AttributeApplicationTime.afterTypeChecking;
compileDecl decl;
applyAttributesOf preDefs AttributeApplicationTime.afterCompilation;
pure ()
private def partitionNonRec (preDefs : Array PreDefinition) : Array PreDefinition × Array PreDefinition :=
preDefs.partition fun predDecl =>
Option.isNone $ predDecl.value.find? fun c => match c with
| Expr.const declName _ _ => preDefs.any fun preDecl => preDecl.declName == declName
| _ => false
def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := do
preDefs.forM fun preDecl => trace `Elab.definition.body fun _ => preDecl.declName ++ " : " ++ preDecl.type ++ " :=" ++ Format.line ++ preDecl.value;
let (preDefsNonRec, preDefs) := partitionNonRec preDefs;
preDefsNonRec.forM addAndCompileNonRec;
-- TODO
addAndCompileAsUnsafe preDefs
end Elab
end Lean