chore: update stage0

This commit is contained in:
Leonardo de Moura 2020-10-15 14:31:37 -07:00
parent 6c6f3dca87
commit ef04995f0e
3 changed files with 4187 additions and 4596 deletions

View file

@ -1,3 +1,4 @@
#lang lean4
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
@ -9,9 +10,7 @@ import Lean.Elab.Command
import Lean.Elab.DefView
import Lean.Elab.PreDefinition
namespace Lean
namespace Elab
open MonadResolveName (getCurrNamespace getOpenDecls) -- HACK for old frontend
namespace Lean.Elab
/- DefView after elaborating the header. -/
structure DefViewElabHeader :=
@ -33,45 +32,42 @@ namespace Term
open Meta
private def checkModifiers (m₁ m₂ : Modifiers) : TermElabM Unit := do
unless (m₁.isUnsafe == m₂.isUnsafe) $
throwError "cannot mix unsafe and safe definitions";
unless (m₁.isNoncomputable == m₂.isNoncomputable) $
throwError "cannot mix computable and non-computable definitions";
unless (m₁.isPartial == m₂.isPartial) $
throwError "cannot mix partial and non-partial definitions";
pure ()
unless m₁.isUnsafe == m₂.isUnsafe do
throwError "cannot mix unsafe and safe definitions"
unless m₁.isNoncomputable == m₂.isNoncomputable do
throwError "cannot mix computable and non-computable definitions"
unless m₁.isPartial == m₂.isPartial do
throwError "cannot mix partial and non-partial definitions"
private def checkKinds (k₁ k₂ : DefKind) : TermElabM Unit := do
unless (k₁.isExample == k₂.isExample) $
throwError "cannot mix examples and definitions"; -- Reason: we should discard examples
unless (k₁.isTheorem == k₂.isTheorem) $
throwError "cannot mix theorems and definitions"; -- Reason: we will eventually elaborate theorems in `Task`s.
pure ()
unless k₁.isExample == k₂.isExample do
throwError "cannot mix examples and definitions" -- Reason: we should discard examples
unless k₁.isTheorem == k₂.isTheorem do
throwError "cannot mix theorems and definitions" -- Reason: we will eventually elaborate theorems in `Task`s.
private def check (prevHeaders : Array DefViewElabHeader) (newHeader : DefViewElabHeader) : TermElabM Unit := do
when (newHeader.kind.isTheorem && newHeader.modifiers.isUnsafe) $
throwError "'unsafe' theorems are not allowed";
when (newHeader.kind.isTheorem && newHeader.modifiers.isPartial) $
throwError "'partial' theorems are not allowed, 'partial' is a code generation directive";
when (newHeader.kind.isTheorem && newHeader.modifiers.isNoncomputable) $
throwError "'theorem' subsumes 'noncomputable', code is not generated for theorems";
when (newHeader.modifiers.isNoncomputable && newHeader.modifiers.isUnsafe) $
throwError "'noncomputable unsafe' is not allowed";
when (newHeader.modifiers.isNoncomputable && newHeader.modifiers.isPartial) $
throwError "'noncomputable partial' is not allowed";
when (newHeader.modifiers.isPartial && newHeader.modifiers.isUnsafe) $
throwError "'unsafe' subsumes 'partial'";
if newHeader.kind.isTheorem && newHeader.modifiers.isUnsafe then
throwError "'unsafe' theorems are not allowed"
if newHeader.kind.isTheorem && newHeader.modifiers.isPartial then
throwError "'partial' theorems are not allowed, 'partial' is a code generation directive"
if newHeader.kind.isTheorem && newHeader.modifiers.isNoncomputable then
throwError "'theorem' subsumes 'noncomputable', code is not generated for theorems"
if newHeader.modifiers.isNoncomputable && newHeader.modifiers.isUnsafe then
throwError "'noncomputable unsafe' is not allowed"
if newHeader.modifiers.isNoncomputable && newHeader.modifiers.isPartial then
throwError "'noncomputable partial' is not allowed"
if newHeader.modifiers.isPartial && newHeader.modifiers.isUnsafe then
throwError "'unsafe' subsumes 'partial'"
if h : 0 < prevHeaders.size then
let firstHeader := prevHeaders.get ⟨0, h⟩;
let firstHeader := prevHeaders.get ⟨0, h⟩
try
unless newHeader.levelNames == firstHeader.levelNames do
throwError "universe parameters mismatch"
checkModifiers newHeader.modifiers firstHeader.modifiers
checkKinds newHeader.kind firstHeader.kind
catch
(do
unless (newHeader.levelNames == firstHeader.levelNames) $
throwError "universe parameters mismatch";
checkModifiers newHeader.modifiers firstHeader.modifiers;
checkKinds newHeader.kind firstHeader.kind)
(fun ex => match ex with
| Exception.error ref msg => throw (Exception.error ref ("invalid mutually recursive definitions, " ++ msg))
| _ => throw ex)
| Exception.error ref msg => throw (Exception.error ref msg!"invalid mutually recursive definitions, {msg}")
| ex => throw ex
else
pure ()
@ -80,29 +76,28 @@ registerCustomErrorIfMVar type ref "failed to infer definition type"
private def elabFunType (ref : Syntax) (xs : Array Expr) (view : DefView) : TermElabM Expr := do
match view.type? with
| some typeStx => do
type ← elabType typeStx;
synthesizeSyntheticMVarsNoPostponing;
type ← instantiateMVars type;
registerFailedToInferDefTypeInfo type typeStx;
| some typeStx =>
let type ← elabType typeStx
synthesizeSyntheticMVarsNoPostponing
let type ← instantiateMVars type
registerFailedToInferDefTypeInfo type typeStx
mkForallFVars xs type
| none => do
let hole := mkHole ref;
type ← elabType hole;
registerFailedToInferDefTypeInfo type ref;
| none =>
let hole := mkHole ref
let type ← elabType hole
registerFailedToInferDefTypeInfo type ref
mkForallFVars xs type
private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) :=
views.foldlM
(fun (headers : Array DefViewElabHeader) (view : DefView) => withRef view.ref do
currNamespace ← getCurrNamespace;
currLevelNames ← getLevelNames;
⟨shortDeclName, declName, levelNames⟩ ← expandDeclId currNamespace currLevelNames view.declId view.modifiers;
applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration;
private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) := do
let headers := #[]
for view in views do
let newHeader ← withRef view.ref do
let ⟨shortDeclName, declName, levelNames⟩ ← expandDeclId (← getCurrNamespace) (← getLevelNames) view.declId view.modifiers
applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration
withLevelNames levelNames $ elabBinders view.binders.getArgs fun xs => do
let refForElabFunType := view.value;
type ← elabFunType refForElabFunType xs view;
let newHeader : DefViewElabHeader := {
let refForElabFunType := view.value
let type ← elabFunType refForElabFunType xs view
let newHeader := {
ref := view.ref,
modifiers := view.modifiers,
kind := view.kind,
@ -111,22 +106,20 @@ views.foldlM
levelNames := levelNames,
numParams := xs.size,
type := type,
valueStx := view.value
};
check headers newHeader;
pure $ headers.push newHeader)
#[]
valueStx := view.value : DefViewElabHeader }
check headers newHeader
pure newHeader
headers := headers.push newHeader
pure headers
private partial def withFunLocalDeclsAux {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : Nat → Array Expr → TermElabM α
| i, fvars =>
if h : i < headers.size then do
let header := headers.get ⟨i, h⟩;
withLocalDecl header.shortDeclName BinderInfo.auxDecl header.type fun fvar => withFunLocalDeclsAux (i+1) (fvars.push fvar)
private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : TermElabM α :=
let rec loop (i : Nat) (fvars : Array Expr) := do
if h : i < headers.size then
let header := headers.get ⟨i, h⟩
withLocalDecl header.shortDeclName BinderInfo.auxDecl header.type fun fvar => loop (i+1) (fvars.push fvar)
else
k fvars
private def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : TermElabM α :=
withFunLocalDeclsAux headers k 0 #[]
loop 0 #[]
/-
Recall that
@ -142,71 +135,69 @@ def declVal := declValSimple <|> declValEqns
-/
private def declValToTerm (declVal : Syntax) : MacroM Syntax :=
if declVal.isOfKind `Lean.Parser.Command.declValSimple then
pure $ declVal.getArg 1
pure declVal[1]
else if declVal.isOfKind `Lean.Parser.Command.declValEqns then
expandMatchAltsIntoMatch declVal (declVal.getArg 0)
expandMatchAltsIntoMatch declVal declVal[0]
else
Macro.throwError declVal "unexpected definition value"
private def elabFunValues (headers : Array DefViewElabHeader) : TermElabM (Array Expr) :=
headers.mapM fun header => withDeclName header.declName $ withLevelNames header.levelNames do
valStx ← liftMacroM $ declValToTerm header.valueStx;
let valStx ← liftMacroM $ declValToTerm header.valueStx
forallBoundedTelescope header.type header.numParams fun xs type => do
val ← elabTermEnsuringType valStx type;
let val ← elabTermEnsuringType valStx type
mkLambdaFVars xs val
private def collectUsed (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift)
: StateRefT CollectFVars.State TermElabM Unit := do
headers.forM fun header => collectUsedFVars header.type;
values.forM collectUsedFVars;
toLift.forM fun letRecToLift => do {
collectUsedFVars letRecToLift.type;
headers.forM fun header => collectUsedFVars header.type
values.forM collectUsedFVars
toLift.forM fun letRecToLift => do
collectUsedFVars letRecToLift.type
collectUsedFVars letRecToLift.val
}
private def removeUnusedVars (vars : Array Expr) (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift)
: TermElabM (LocalContext × LocalInstances × Array Expr) := do
(_, used) ← (collectUsed headers values toLift).run {};
let (_, used) ← (collectUsed headers values toLift).run {}
removeUnused vars used
private def withUsedWhen {α} (vars : Array Expr) (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift)
(cond : Bool) (k : Array Expr → TermElabM α) : TermElabM α :=
if cond then do
(lctx, localInsts, vars) ← removeUnusedVars vars headers values toLift;
(cond : Bool) (k : Array Expr → TermElabM α) : TermElabM α := do
if cond then
let (lctx, localInsts, vars) ← removeUnusedVars vars headers values toLift
withLCtx lctx localInsts $ k vars
else
k vars
private def isExample (views : Array DefView) : Bool :=
views.any fun view => view.kind.isExample
views.any (·.kind.isExample)
private def isTheorem (views : Array DefView) : Bool :=
views.any fun view => view.kind.isTheorem
views.any (·.kind.isTheorem)
private def instantiateMVarsAtHeader (header : DefViewElabHeader) : TermElabM DefViewElabHeader := do
type ← instantiateMVars header.type;
let type ← instantiateMVars header.type
pure { header with type := type }
private def instantiateMVarsAtLetRecToLift (toLift : LetRecToLift) : TermElabM LetRecToLift := do
type ← instantiateMVars toLift.type;
val ← instantiateMVars toLift.val;
let type ← instantiateMVars toLift.type
let val ← instantiateMVars toLift.val
pure { toLift with type := type, val := val }
private def typeHasRecFun (type : Expr) (funFVars : Array Expr) (letRecsToLift : List LetRecToLift) : Option FVarId :=
let occ? := type.find? fun e => match e with
| Expr.fvar fvarId _ => funFVars.contains e || letRecsToLift.any fun toLift => toLift.fvarId == fvarId
| _ => false;
| _ => false
match occ? with
| some (Expr.fvar fvarId _) => some fvarId
| _ => none
private def getFunName (fvarId : FVarId) (letRecsToLift : List LetRecToLift) : TermElabM Name := do
decl? ← findLocalDecl? fvarId;
match decl? with
match (← findLocalDecl? fvarId) with
| some decl => pure decl.userName
| none =>
/- Recall that the FVarId of nested let-recs are not in the current local context. -/
match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.fvarId == fvarId then some toLift.shortDeclName else none with
match letRecsToLift.findSome? fun toLift => if toLift.fvarId == fvarId then some toLift.shortDeclName else none with
| none => throwError "unknown function"
| some n => pure n
@ -216,12 +207,12 @@ In principle, this test can be improved. We could perform it after we separate t
However, this extra complication doesn't seem worth it.
-/
private def checkLetRecsToLiftTypes (funVars : Array Expr) (letRecsToLift : List LetRecToLift) : TermElabM Unit :=
letRecsToLift.forM fun toLift => do
letRecsToLift.forM fun toLift =>
match typeHasRecFun toLift.type funVars letRecsToLift with
| none => pure ()
| some fvarId => do
fnName ← getFunName fvarId letRecsToLift;
throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously")
let fnName ← getFunName fvarId letRecsToLift
throwErrorAt! toLift.ref "invalid type in 'let rec', it uses '{fnName}' which is being defined simultaneously"
namespace MutualClosure
@ -274,28 +265,26 @@ Note that `g` is not a free variable at `(let g : B := ?m₂; body)`. We recover
`f` depends on `g` because it contains `m₂`
-/
private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift)
: UsedFVarsMap :=
let sectionVarSet := sectionVars.foldl (fun (s : NameSet) (var : Expr) => s.insert var.fvarId!) {};
let usedFVarMap := mainFVarIds.foldl
(fun (usedFVarMap : UsedFVarsMap) mainFVarId =>
usedFVarMap.insert mainFVarId sectionVarSet)
{};
letRecsToLift.foldl
(fun (usedFVarMap : UsedFVarsMap) toLift =>
let state := Lean.collectFVars {} toLift.val;
let state := Lean.collectFVars state toLift.type;
let set := state.fvarSet;
/- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId
for the associated let-rec because we need this information to compute the fixpoint later. -/
let mvarIds := (toLift.val.collectMVars {}).result;
let set := mvarIds.foldl
(fun (set : NameSet) (mvarId : MVarId) =>
match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with
| some fvarId => set.insert fvarId
| none => set)
set;
usedFVarMap.insert toLift.fvarId set)
usedFVarMap
: UsedFVarsMap := do
let sectionVarSet := {}
for var in sectionVars do
sectionVarSet := sectionVarSet.insert var.fvarId!
let usedFVarMap := {}
for mainFVarId in mainFVarIds do
usedFVarMap := usedFVarMap.insert mainFVarId sectionVarSet
for toLift in letRecsToLift do
let state := Lean.collectFVars {} toLift.val
let state := Lean.collectFVars state toLift.type
let set := state.fvarSet
/- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId
for the associated let-rec because we need this information to compute the fixpoint later. -/
let mvarIds := (toLift.val.collectMVars {}).result
for mvarId in mvarIds do
match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with
| some fvarId => set := set.insert fvarId
| none => pure ()
usedFVarMap := usedFVarMap.insert toLift.fvarId set
pure usedFVarMap
/-
The let-recs may invoke each other. Example:
@ -324,10 +313,10 @@ structure State :=
abbrev M := ReaderT (List FVarId) $ StateM State
private def isModified : M Bool := do s ← get; pure s.modified
private def isModified : M Bool := do pure (← get).modified
private def resetModified : M Unit := modify fun s => { s with modified := false }
private def markModified : M Unit := modify fun s => { s with modified := true }
private def getUsedFVarsMap : M UsedFVarsMap := do s ← get; pure s.usedFVarsMap
private def getUsedFVarsMap : M UsedFVarsMap := do pure (← get).usedFVarsMap
private def modifyUsedFVars (f : UsedFVarsMap → UsedFVarsMap) : M Unit := modify fun s => { s with usedFVarsMap := f s.usedFVarsMap }
-- merge s₂ into s₁
@ -337,16 +326,16 @@ s₂.foldM
if s₁.contains k then
pure s₁
else do
markModified;
markModified
pure $ s₁.insert k)
s₁
private def updateUsedVarsOf (fvarId : FVarId) : M Unit := do
usedFVarsMap ← getUsedFVarsMap;
let usedFVarsMap ← getUsedFVarsMap
match usedFVarsMap.find? fvarId with
| none => pure ()
| some fvarIds => do
fvarIdsNew ← fvarIds.foldM
| some fvarIds =>
let fvarIdsNew ← fvarIds.foldM
(fun (fvarIdsNew : NameSet) (fvarId' : FVarId) =>
if fvarId == fvarId' then
pure fvarIdsNew
@ -357,18 +346,19 @@ match usedFVarsMap.find? fvarId with
not in the context of the let-rec associated with fvarId.
We filter these out-of-context free variables later. -/
| some otherFVarIds => merge fvarIdsNew otherFVarIds)
fvarIds;
fvarIds
modifyUsedFVars fun usedFVars => usedFVars.insert fvarId fvarIdsNew
private partial def fixpoint : Unit → M Unit
| _ => do
resetModified;
letRecFVarIds ← read;
letRecFVarIds.forM updateUsedVarsOf;
whenM isModified $ fixpoint ()
resetModified
let letRecFVarIds ← read
letRecFVarIds.forM updateUsedVarsOf
if (← isModified) then
fixpoint ()
def run (letRecFVarIds : List FVarId) (usedFVarsMap : UsedFVarsMap) : UsedFVarsMap :=
let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap };
let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap }
s.usedFVarsMap
end FixPoint
@ -377,23 +367,23 @@ abbrev FreeVarMap := NameMap (Array FVarId)
private def mkFreeVarMap
(mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId)
(recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap :=
let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift;
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap;
letRecsToLift.foldl
(fun (freeVarMap : FreeVarMap) toLift =>
let lctx := toLift.lctx;
let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!;
let fvarIds := fvarIdsSet.fold
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
if lctx.contains fvarId && !recFVarIds.contains fvarId then
fvarIds.push fvarId
else
fvarIds)
#[];
freeVarMap.insert toLift.fvarId fvarIds)
{}
(recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap := do
let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap
let freeVarMap := {}
for toLift in letRecsToLift do
let lctx := toLift.lctx
let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!
let fvarIds := fvarIdsSet.fold
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
if lctx.contains fvarId && !recFVarIds.contains fvarId then
fvarIds.push fvarId
else
fvarIds)
#[]
freeVarMap := freeVarMap.insert toLift.fvarId fvarIds
pure freeVarMap
structure ClosureState :=
(newLocalDecls : Array LocalDecl := #[])
@ -405,9 +395,9 @@ private def pickMaxFVar? (lctx : LocalContext) (fvarIds : Array FVarId) : Option
fvarIds.getMax? fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
private def preprocess (e : Expr) : TermElabM Expr := do
e ← instantiateMVars e;
let e ← instantiateMVars e
-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect.
liftM $ check e;
Meta.check e
pure e
/- Push free variables in `s` to `toProcess` if they are not already there. -/
@ -418,47 +408,47 @@ s.fvarSet.fold
private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName : Name) (type : Expr) (bi := BinderInfo.default)
: StateRefT ClosureState TermElabM (Array FVarId) := do
type ← liftM $ preprocess type;
let type ← preprocess type
modify fun s => { s with
newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) fvarId userName type bi,
exprArgs := s.exprArgs.push (mkFVar fvarId)
};
}
pure $ pushNewVars toProcess (collectFVars {} type)
private partial def mkClosureForAux : Array FVarId → StateRefT ClosureState TermElabM Unit
| toProcess => do
lctx ← getLCtx;
let lctx ← getLCtx
match pickMaxFVar? lctx toProcess with
| none => pure ()
| some fvarId => do
trace `Elab.definition.mkClosure fun _ => "toProcess: " ++ (toProcess.map mkFVar) ++ ", maxVar: " ++ mkFVar fvarId;
let toProcess := toProcess.erase fvarId;
localDecl ← getLocalDecl fvarId;
| some fvarId =>
trace[Elab.definition.mkClosure]! "toProcess: {toProcess.map mkFVar}, maxVar: {mkFVar fvarId}"
let toProcess := toProcess.erase fvarId
let localDecl ← getLocalDecl fvarId
match localDecl with
| LocalDecl.cdecl _ _ userName type bi => do
toProcess ← pushLocalDecl toProcess fvarId userName type bi;
| LocalDecl.cdecl _ _ userName type bi =>
let toProcess ← pushLocalDecl toProcess fvarId userName type bi
mkClosureForAux toProcess
| LocalDecl.ldecl _ _ userName type val _ => do
zetaFVarIds ← getZetaFVarIds;
if !zetaFVarIds.contains fvarId then do
| LocalDecl.ldecl _ _ userName type val _ =>
let zetaFVarIds ← getZetaFVarIds
if !zetaFVarIds.contains fvarId then
/- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/
toProcess ← pushLocalDecl toProcess fvarId userName type;
let toProcess ← pushLocalDecl toProcess fvarId userName type
mkClosureForAux toProcess
else do
else
/- Dependent let-decl. -/
type ← liftM $ preprocess type;
val ← liftM $ preprocess val;
let type ← preprocess type
let val ← preprocess val
modify fun s => { s with
newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) fvarId userName type val false,
/- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of fvarId
at `newLocalDecls` and `localDecls` -/
newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val),
localDecls := s.localDecls.map (replaceFVarIdAtLocalDecl fvarId val)
};
}
mkClosureForAux (pushNewVars toProcess (collectFVars (collectFVars {} type) val))
private partial def mkClosureFor (freeVars : Array FVarId) (localDecls : Array LocalDecl) : TermElabM ClosureState := do
(_, s) ← (mkClosureForAux freeVars).run { localDecls := localDecls };
let (_, s) ← (mkClosureForAux freeVars).run { localDecls := localDecls }
pure { s with
newLocalDecls := s.newLocalDecls.reverse,
newLetDecls := s.newLetDecls.reverse,
@ -470,16 +460,16 @@ structure LetRecClosure :=
(toLift : LetRecToLift)
private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) : TermElabM LetRecClosure := do
let lctx := toLift.lctx;
let lctx := toLift.lctx
withLCtx lctx toLift.localInstances do
lambdaTelescope toLift.val fun xs val => do
type ← instantiateForall toLift.type xs;
lctx ← getLCtx;
s ← mkClosureFor freeVars $ xs.map fun x => lctx.get! x.fvarId!;
let type := Closure.mkForall s.localDecls $ Closure.mkForall s.newLetDecls type;
let val := Closure.mkLambda s.localDecls $ Closure.mkLambda s.newLetDecls val;
let c := mkAppN (Lean.mkConst toLift.declName) s.exprArgs;
assignExprMVar toLift.mvarId c;
let type ← instantiateForall toLift.type xs
let lctx ← getLCtx
let s ← mkClosureFor freeVars $ xs.map fun x => lctx.get! x.fvarId!
let type := Closure.mkForall s.localDecls $ Closure.mkForall s.newLetDecls type
let val := Closure.mkLambda s.localDecls $ Closure.mkLambda s.newLetDecls val
let c := mkAppN (Lean.mkConst toLift.declName) s.exprArgs
assignExprMVar toLift.mvarId c
pure ⟨s.newLocalDecls, c, { toLift with val := val, type := type }⟩
private def mkLetRecClosures (letRecsToLift : List LetRecToLift) (freeVarMap : FreeVarMap) : TermElabM (List LetRecClosure) :=
@ -508,9 +498,9 @@ def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHea
: TermElabM (Array PreDefinition) :=
mainHeaders.size.foldM
(fun i (preDefs : Array PreDefinition) => do
let header := mainHeaders.get! i;
val ← mkLambdaFVars sectionVars (mainVals.get! i);
type ← mkForallFVars sectionVars header.type;
let header := mainHeaders[i]
let val ← mkLambdaFVars sectionVars mainVals[i]
let type ← mkForallFVars sectionVars header.type
pure $ preDefs.push {
kind := header.kind,
declName := header.declName,
@ -524,8 +514,8 @@ mainHeaders.size.foldM
def pushLetRecs (preDefs : Array PreDefinition) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDefinition :=
letRecClosures.foldl
(fun (preDefs : Array PreDefinition) (c : LetRecClosure) =>
let type := Closure.mkForall c.localDecls c.toLift.type;
let val := Closure.mkLambda c.localDecls c.toLift.val;
let type := Closure.mkForall c.localDecls c.toLift.type
let val := Closure.mkLambda c.localDecls c.toLift.val
preDefs.push {
kind := kind,
declName := c.toLift.declName,
@ -555,29 +545,29 @@ def getModifiersForLetRecs (mainHeaders : Array DefViewElabHeader) : Modifiers :
def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) (mainVals : Array Expr) (letRecsToLift : List LetRecToLift)
: TermElabM (Array PreDefinition) := do
-- Store in recFVarIds the fvarId of every function being defined by the mutual block.
let mainFVarIds := mainFVars.map Expr.fvarId!;
let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds;
let mainFVarIds := mainFVars.map Expr.fvarId!
let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds
-- Compute the set of free variables (excluding `recFVarIds`) for each let-rec.
mctx ← getMCtx;
let freeVarMap := mkFreeVarMap mctx sectionVars mainFVarIds recFVarIds letRecsToLift;
resetZetaFVarIds;
let mctx ← getMCtx
let freeVarMap := mkFreeVarMap mctx sectionVars mainFVarIds recFVarIds letRecsToLift
resetZetaFVarIds
withTrackingZeta do
-- By checking `toLift.type` and `toLift.val` we populate `zetaFVarIds`. See comments at `src/Lean/Meta/Closure.lean`.
letRecsToLift.forM fun toLift => withLCtx toLift.lctx toLift.localInstances do { liftM $ check toLift.type; liftM $ check toLift.val };
letRecClosures ← mkLetRecClosures letRecsToLift freeVarMap;
-- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations.
mainVals ← mainVals.mapM instantiateMVars;
mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader;
letRecClosures ← letRecClosures.mapM fun closure => do { toLift ← instantiateMVarsAtLetRecToLift closure.toLift; pure { closure with toLift := toLift } };
-- Replace fvarIds for functions being defined with closed terms
let r := insertReplacementForMainFns {} sectionVars mainHeaders mainFVars;
let r := insertReplacementForLetRecs r letRecClosures;
let mainVals := mainVals.map r.apply;
let mainHeaders := mainHeaders.map fun h => { h with type := r.apply h.type };
let letRecClosures := letRecClosures.map fun c => { c with toLift := { c.toLift with type := r.apply c.toLift.type, val := r.apply c.toLift.val } };
let letRecKind := getKindForLetRecs mainHeaders;
let letRecMods := getModifiersForLetRecs mainHeaders;
pushMain (pushLetRecs #[] letRecClosures letRecKind letRecMods) sectionVars mainHeaders mainVals
-- By checking `toLift.type` and `toLift.val` we populate `zetaFVarIds`. See comments at `src/Lean/Meta/Closure.lean`.
letRecsToLift.forM fun toLift => withLCtx toLift.lctx toLift.localInstances do Meta.check toLift.type; Meta.check toLift.val
let letRecClosures ← mkLetRecClosures letRecsToLift freeVarMap
-- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations.
let mainVals ← mainVals.mapM instantiateMVars
let mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader
let letRecClosures ← letRecClosures.mapM fun closure => do pure { closure with toLift := (← instantiateMVarsAtLetRecToLift closure.toLift) }
-- Replace fvarIds for functions being defined with closed terms
let r := insertReplacementForMainFns {} sectionVars mainHeaders mainFVars
let r := insertReplacementForLetRecs r letRecClosures
let mainVals := mainVals.map r.apply
let mainHeaders := mainHeaders.map fun h => { h with type := r.apply h.type }
let letRecClosures := letRecClosures.map fun c => { c with toLift := { c.toLift with type := r.apply c.toLift.type, val := r.apply c.toLift.val } }
let letRecKind := getKindForLetRecs mainHeaders
let letRecMods := getModifiersForLetRecs mainHeaders
pushMain (pushLetRecs #[] letRecClosures letRecKind letRecMods) sectionVars mainHeaders mainVals
end MutualClosure
@ -589,34 +579,34 @@ else
[]
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
scopeLevelNames ← getLevelNames;
headers ← elabHeaders views;
let allUserLevelNames := getAllUserLevelNames headers;
let scopeLevelNames ← getLevelNames
let headers ← elabHeaders views
let allUserLevelNames := getAllUserLevelNames headers
withFunLocalDecls headers fun funFVars => do
values ← elabFunValues headers;
Term.synthesizeSyntheticMVarsNoPostponing;
if isExample views then pure ()
else do
values ← values.mapM instantiateMVars;
headers ← headers.mapM instantiateMVarsAtHeader;
letRecsToLift ← getLetRecsToLift;
letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift;
checkLetRecsToLiftTypes funFVars letRecsToLift;
let values ← elabFunValues headers
Term.synthesizeSyntheticMVarsNoPostponing
if isExample views then
pure ()
else
let values ← values.mapM instantiateMVars
let headers ← headers.mapM instantiateMVarsAtHeader
let letRecsToLift ← getLetRecsToLift
let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
checkLetRecsToLiftTypes funFVars letRecsToLift
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift;
preDefs ← levelMVarToParamPreDecls preDefs;
preDefs ← instantiateMVarsAtPreDecls preDefs;
preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames;
let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
let preDefs ← levelMVarToParamPreDecls preDefs
let preDefs ← instantiateMVarsAtPreDecls preDefs
let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
addPreDefinitions preDefs
end Term
namespace Command
def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
views ← ds.mapM fun d => do {
modifiers ← elabModifiers (d.getArg 0);
mkDefView modifiers (d.getArg 1)
};
let views ← ds.mapM fun d => do
let modifiers ← elabModifiers d[0]
mkDefView modifiers d[1]
runTermElabM none fun vars => Term.elabMutualDef vars views
end Command

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff