chore: update stage0
This commit is contained in:
parent
6c6f3dca87
commit
ef04995f0e
3 changed files with 4187 additions and 4596 deletions
426
stage0/src/Lean/Elab/MutualDef.lean
generated
426
stage0/src/Lean/Elab/MutualDef.lean
generated
|
|
@ -1,3 +1,4 @@
|
|||
#lang lean4
|
||||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
|
@ -9,9 +10,7 @@ import Lean.Elab.Command
|
|||
import Lean.Elab.DefView
|
||||
import Lean.Elab.PreDefinition
|
||||
|
||||
namespace Lean
|
||||
namespace Elab
|
||||
open MonadResolveName (getCurrNamespace getOpenDecls) -- HACK for old frontend
|
||||
namespace Lean.Elab
|
||||
|
||||
/- DefView after elaborating the header. -/
|
||||
structure DefViewElabHeader :=
|
||||
|
|
@ -33,45 +32,42 @@ namespace Term
|
|||
open Meta
|
||||
|
||||
private def checkModifiers (m₁ m₂ : Modifiers) : TermElabM Unit := do
|
||||
unless (m₁.isUnsafe == m₂.isUnsafe) $
|
||||
throwError "cannot mix unsafe and safe definitions";
|
||||
unless (m₁.isNoncomputable == m₂.isNoncomputable) $
|
||||
throwError "cannot mix computable and non-computable definitions";
|
||||
unless (m₁.isPartial == m₂.isPartial) $
|
||||
throwError "cannot mix partial and non-partial definitions";
|
||||
pure ()
|
||||
unless m₁.isUnsafe == m₂.isUnsafe do
|
||||
throwError "cannot mix unsafe and safe definitions"
|
||||
unless m₁.isNoncomputable == m₂.isNoncomputable do
|
||||
throwError "cannot mix computable and non-computable definitions"
|
||||
unless m₁.isPartial == m₂.isPartial do
|
||||
throwError "cannot mix partial and non-partial definitions"
|
||||
|
||||
private def checkKinds (k₁ k₂ : DefKind) : TermElabM Unit := do
|
||||
unless (k₁.isExample == k₂.isExample) $
|
||||
throwError "cannot mix examples and definitions"; -- Reason: we should discard examples
|
||||
unless (k₁.isTheorem == k₂.isTheorem) $
|
||||
throwError "cannot mix theorems and definitions"; -- Reason: we will eventually elaborate theorems in `Task`s.
|
||||
pure ()
|
||||
unless k₁.isExample == k₂.isExample do
|
||||
throwError "cannot mix examples and definitions" -- Reason: we should discard examples
|
||||
unless k₁.isTheorem == k₂.isTheorem do
|
||||
throwError "cannot mix theorems and definitions" -- Reason: we will eventually elaborate theorems in `Task`s.
|
||||
|
||||
private def check (prevHeaders : Array DefViewElabHeader) (newHeader : DefViewElabHeader) : TermElabM Unit := do
|
||||
when (newHeader.kind.isTheorem && newHeader.modifiers.isUnsafe) $
|
||||
throwError "'unsafe' theorems are not allowed";
|
||||
when (newHeader.kind.isTheorem && newHeader.modifiers.isPartial) $
|
||||
throwError "'partial' theorems are not allowed, 'partial' is a code generation directive";
|
||||
when (newHeader.kind.isTheorem && newHeader.modifiers.isNoncomputable) $
|
||||
throwError "'theorem' subsumes 'noncomputable', code is not generated for theorems";
|
||||
when (newHeader.modifiers.isNoncomputable && newHeader.modifiers.isUnsafe) $
|
||||
throwError "'noncomputable unsafe' is not allowed";
|
||||
when (newHeader.modifiers.isNoncomputable && newHeader.modifiers.isPartial) $
|
||||
throwError "'noncomputable partial' is not allowed";
|
||||
when (newHeader.modifiers.isPartial && newHeader.modifiers.isUnsafe) $
|
||||
throwError "'unsafe' subsumes 'partial'";
|
||||
if newHeader.kind.isTheorem && newHeader.modifiers.isUnsafe then
|
||||
throwError "'unsafe' theorems are not allowed"
|
||||
if newHeader.kind.isTheorem && newHeader.modifiers.isPartial then
|
||||
throwError "'partial' theorems are not allowed, 'partial' is a code generation directive"
|
||||
if newHeader.kind.isTheorem && newHeader.modifiers.isNoncomputable then
|
||||
throwError "'theorem' subsumes 'noncomputable', code is not generated for theorems"
|
||||
if newHeader.modifiers.isNoncomputable && newHeader.modifiers.isUnsafe then
|
||||
throwError "'noncomputable unsafe' is not allowed"
|
||||
if newHeader.modifiers.isNoncomputable && newHeader.modifiers.isPartial then
|
||||
throwError "'noncomputable partial' is not allowed"
|
||||
if newHeader.modifiers.isPartial && newHeader.modifiers.isUnsafe then
|
||||
throwError "'unsafe' subsumes 'partial'"
|
||||
if h : 0 < prevHeaders.size then
|
||||
let firstHeader := prevHeaders.get ⟨0, h⟩;
|
||||
let firstHeader := prevHeaders.get ⟨0, h⟩
|
||||
try
|
||||
unless newHeader.levelNames == firstHeader.levelNames do
|
||||
throwError "universe parameters mismatch"
|
||||
checkModifiers newHeader.modifiers firstHeader.modifiers
|
||||
checkKinds newHeader.kind firstHeader.kind
|
||||
catch
|
||||
(do
|
||||
unless (newHeader.levelNames == firstHeader.levelNames) $
|
||||
throwError "universe parameters mismatch";
|
||||
checkModifiers newHeader.modifiers firstHeader.modifiers;
|
||||
checkKinds newHeader.kind firstHeader.kind)
|
||||
(fun ex => match ex with
|
||||
| Exception.error ref msg => throw (Exception.error ref ("invalid mutually recursive definitions, " ++ msg))
|
||||
| _ => throw ex)
|
||||
| Exception.error ref msg => throw (Exception.error ref msg!"invalid mutually recursive definitions, {msg}")
|
||||
| ex => throw ex
|
||||
else
|
||||
pure ()
|
||||
|
||||
|
|
@ -80,29 +76,28 @@ registerCustomErrorIfMVar type ref "failed to infer definition type"
|
|||
|
||||
private def elabFunType (ref : Syntax) (xs : Array Expr) (view : DefView) : TermElabM Expr := do
|
||||
match view.type? with
|
||||
| some typeStx => do
|
||||
type ← elabType typeStx;
|
||||
synthesizeSyntheticMVarsNoPostponing;
|
||||
type ← instantiateMVars type;
|
||||
registerFailedToInferDefTypeInfo type typeStx;
|
||||
| some typeStx =>
|
||||
let type ← elabType typeStx
|
||||
synthesizeSyntheticMVarsNoPostponing
|
||||
let type ← instantiateMVars type
|
||||
registerFailedToInferDefTypeInfo type typeStx
|
||||
mkForallFVars xs type
|
||||
| none => do
|
||||
let hole := mkHole ref;
|
||||
type ← elabType hole;
|
||||
registerFailedToInferDefTypeInfo type ref;
|
||||
| none =>
|
||||
let hole := mkHole ref
|
||||
let type ← elabType hole
|
||||
registerFailedToInferDefTypeInfo type ref
|
||||
mkForallFVars xs type
|
||||
|
||||
private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) :=
|
||||
views.foldlM
|
||||
(fun (headers : Array DefViewElabHeader) (view : DefView) => withRef view.ref do
|
||||
currNamespace ← getCurrNamespace;
|
||||
currLevelNames ← getLevelNames;
|
||||
⟨shortDeclName, declName, levelNames⟩ ← expandDeclId currNamespace currLevelNames view.declId view.modifiers;
|
||||
applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration;
|
||||
private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) := do
|
||||
let headers := #[]
|
||||
for view in views do
|
||||
let newHeader ← withRef view.ref do
|
||||
let ⟨shortDeclName, declName, levelNames⟩ ← expandDeclId (← getCurrNamespace) (← getLevelNames) view.declId view.modifiers
|
||||
applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration
|
||||
withLevelNames levelNames $ elabBinders view.binders.getArgs fun xs => do
|
||||
let refForElabFunType := view.value;
|
||||
type ← elabFunType refForElabFunType xs view;
|
||||
let newHeader : DefViewElabHeader := {
|
||||
let refForElabFunType := view.value
|
||||
let type ← elabFunType refForElabFunType xs view
|
||||
let newHeader := {
|
||||
ref := view.ref,
|
||||
modifiers := view.modifiers,
|
||||
kind := view.kind,
|
||||
|
|
@ -111,22 +106,20 @@ views.foldlM
|
|||
levelNames := levelNames,
|
||||
numParams := xs.size,
|
||||
type := type,
|
||||
valueStx := view.value
|
||||
};
|
||||
check headers newHeader;
|
||||
pure $ headers.push newHeader)
|
||||
#[]
|
||||
valueStx := view.value : DefViewElabHeader }
|
||||
check headers newHeader
|
||||
pure newHeader
|
||||
headers := headers.push newHeader
|
||||
pure headers
|
||||
|
||||
private partial def withFunLocalDeclsAux {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : Nat → Array Expr → TermElabM α
|
||||
| i, fvars =>
|
||||
if h : i < headers.size then do
|
||||
let header := headers.get ⟨i, h⟩;
|
||||
withLocalDecl header.shortDeclName BinderInfo.auxDecl header.type fun fvar => withFunLocalDeclsAux (i+1) (fvars.push fvar)
|
||||
private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : TermElabM α :=
|
||||
let rec loop (i : Nat) (fvars : Array Expr) := do
|
||||
if h : i < headers.size then
|
||||
let header := headers.get ⟨i, h⟩
|
||||
withLocalDecl header.shortDeclName BinderInfo.auxDecl header.type fun fvar => loop (i+1) (fvars.push fvar)
|
||||
else
|
||||
k fvars
|
||||
|
||||
private def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : TermElabM α :=
|
||||
withFunLocalDeclsAux headers k 0 #[]
|
||||
loop 0 #[]
|
||||
|
||||
/-
|
||||
Recall that
|
||||
|
|
@ -142,71 +135,69 @@ def declVal := declValSimple <|> declValEqns
|
|||
-/
|
||||
private def declValToTerm (declVal : Syntax) : MacroM Syntax :=
|
||||
if declVal.isOfKind `Lean.Parser.Command.declValSimple then
|
||||
pure $ declVal.getArg 1
|
||||
pure declVal[1]
|
||||
else if declVal.isOfKind `Lean.Parser.Command.declValEqns then
|
||||
expandMatchAltsIntoMatch declVal (declVal.getArg 0)
|
||||
expandMatchAltsIntoMatch declVal declVal[0]
|
||||
else
|
||||
Macro.throwError declVal "unexpected definition value"
|
||||
|
||||
private def elabFunValues (headers : Array DefViewElabHeader) : TermElabM (Array Expr) :=
|
||||
headers.mapM fun header => withDeclName header.declName $ withLevelNames header.levelNames do
|
||||
valStx ← liftMacroM $ declValToTerm header.valueStx;
|
||||
let valStx ← liftMacroM $ declValToTerm header.valueStx
|
||||
forallBoundedTelescope header.type header.numParams fun xs type => do
|
||||
val ← elabTermEnsuringType valStx type;
|
||||
let val ← elabTermEnsuringType valStx type
|
||||
mkLambdaFVars xs val
|
||||
|
||||
private def collectUsed (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift)
|
||||
: StateRefT CollectFVars.State TermElabM Unit := do
|
||||
headers.forM fun header => collectUsedFVars header.type;
|
||||
values.forM collectUsedFVars;
|
||||
toLift.forM fun letRecToLift => do {
|
||||
collectUsedFVars letRecToLift.type;
|
||||
headers.forM fun header => collectUsedFVars header.type
|
||||
values.forM collectUsedFVars
|
||||
toLift.forM fun letRecToLift => do
|
||||
collectUsedFVars letRecToLift.type
|
||||
collectUsedFVars letRecToLift.val
|
||||
}
|
||||
|
||||
private def removeUnusedVars (vars : Array Expr) (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift)
|
||||
: TermElabM (LocalContext × LocalInstances × Array Expr) := do
|
||||
(_, used) ← (collectUsed headers values toLift).run {};
|
||||
let (_, used) ← (collectUsed headers values toLift).run {}
|
||||
removeUnused vars used
|
||||
|
||||
private def withUsedWhen {α} (vars : Array Expr) (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift)
|
||||
(cond : Bool) (k : Array Expr → TermElabM α) : TermElabM α :=
|
||||
if cond then do
|
||||
(lctx, localInsts, vars) ← removeUnusedVars vars headers values toLift;
|
||||
(cond : Bool) (k : Array Expr → TermElabM α) : TermElabM α := do
|
||||
if cond then
|
||||
let (lctx, localInsts, vars) ← removeUnusedVars vars headers values toLift
|
||||
withLCtx lctx localInsts $ k vars
|
||||
else
|
||||
k vars
|
||||
|
||||
private def isExample (views : Array DefView) : Bool :=
|
||||
views.any fun view => view.kind.isExample
|
||||
views.any (·.kind.isExample)
|
||||
|
||||
private def isTheorem (views : Array DefView) : Bool :=
|
||||
views.any fun view => view.kind.isTheorem
|
||||
views.any (·.kind.isTheorem)
|
||||
|
||||
private def instantiateMVarsAtHeader (header : DefViewElabHeader) : TermElabM DefViewElabHeader := do
|
||||
type ← instantiateMVars header.type;
|
||||
let type ← instantiateMVars header.type
|
||||
pure { header with type := type }
|
||||
|
||||
private def instantiateMVarsAtLetRecToLift (toLift : LetRecToLift) : TermElabM LetRecToLift := do
|
||||
type ← instantiateMVars toLift.type;
|
||||
val ← instantiateMVars toLift.val;
|
||||
let type ← instantiateMVars toLift.type
|
||||
let val ← instantiateMVars toLift.val
|
||||
pure { toLift with type := type, val := val }
|
||||
|
||||
private def typeHasRecFun (type : Expr) (funFVars : Array Expr) (letRecsToLift : List LetRecToLift) : Option FVarId :=
|
||||
let occ? := type.find? fun e => match e with
|
||||
| Expr.fvar fvarId _ => funFVars.contains e || letRecsToLift.any fun toLift => toLift.fvarId == fvarId
|
||||
| _ => false;
|
||||
| _ => false
|
||||
match occ? with
|
||||
| some (Expr.fvar fvarId _) => some fvarId
|
||||
| _ => none
|
||||
|
||||
private def getFunName (fvarId : FVarId) (letRecsToLift : List LetRecToLift) : TermElabM Name := do
|
||||
decl? ← findLocalDecl? fvarId;
|
||||
match decl? with
|
||||
match (← findLocalDecl? fvarId) with
|
||||
| some decl => pure decl.userName
|
||||
| none =>
|
||||
/- Recall that the FVarId of nested let-recs are not in the current local context. -/
|
||||
match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.fvarId == fvarId then some toLift.shortDeclName else none with
|
||||
match letRecsToLift.findSome? fun toLift => if toLift.fvarId == fvarId then some toLift.shortDeclName else none with
|
||||
| none => throwError "unknown function"
|
||||
| some n => pure n
|
||||
|
||||
|
|
@ -216,12 +207,12 @@ In principle, this test can be improved. We could perform it after we separate t
|
|||
However, this extra complication doesn't seem worth it.
|
||||
-/
|
||||
private def checkLetRecsToLiftTypes (funVars : Array Expr) (letRecsToLift : List LetRecToLift) : TermElabM Unit :=
|
||||
letRecsToLift.forM fun toLift => do
|
||||
letRecsToLift.forM fun toLift =>
|
||||
match typeHasRecFun toLift.type funVars letRecsToLift with
|
||||
| none => pure ()
|
||||
| some fvarId => do
|
||||
fnName ← getFunName fvarId letRecsToLift;
|
||||
throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously")
|
||||
let fnName ← getFunName fvarId letRecsToLift
|
||||
throwErrorAt! toLift.ref "invalid type in 'let rec', it uses '{fnName}' which is being defined simultaneously"
|
||||
|
||||
namespace MutualClosure
|
||||
|
||||
|
|
@ -274,28 +265,26 @@ Note that `g` is not a free variable at `(let g : B := ?m₂; body)`. We recover
|
|||
`f` depends on `g` because it contains `m₂`
|
||||
-/
|
||||
private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift)
|
||||
: UsedFVarsMap :=
|
||||
let sectionVarSet := sectionVars.foldl (fun (s : NameSet) (var : Expr) => s.insert var.fvarId!) {};
|
||||
let usedFVarMap := mainFVarIds.foldl
|
||||
(fun (usedFVarMap : UsedFVarsMap) mainFVarId =>
|
||||
usedFVarMap.insert mainFVarId sectionVarSet)
|
||||
{};
|
||||
letRecsToLift.foldl
|
||||
(fun (usedFVarMap : UsedFVarsMap) toLift =>
|
||||
let state := Lean.collectFVars {} toLift.val;
|
||||
let state := Lean.collectFVars state toLift.type;
|
||||
let set := state.fvarSet;
|
||||
/- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId
|
||||
for the associated let-rec because we need this information to compute the fixpoint later. -/
|
||||
let mvarIds := (toLift.val.collectMVars {}).result;
|
||||
let set := mvarIds.foldl
|
||||
(fun (set : NameSet) (mvarId : MVarId) =>
|
||||
match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with
|
||||
| some fvarId => set.insert fvarId
|
||||
| none => set)
|
||||
set;
|
||||
usedFVarMap.insert toLift.fvarId set)
|
||||
usedFVarMap
|
||||
: UsedFVarsMap := do
|
||||
let sectionVarSet := {}
|
||||
for var in sectionVars do
|
||||
sectionVarSet := sectionVarSet.insert var.fvarId!
|
||||
let usedFVarMap := {}
|
||||
for mainFVarId in mainFVarIds do
|
||||
usedFVarMap := usedFVarMap.insert mainFVarId sectionVarSet
|
||||
for toLift in letRecsToLift do
|
||||
let state := Lean.collectFVars {} toLift.val
|
||||
let state := Lean.collectFVars state toLift.type
|
||||
let set := state.fvarSet
|
||||
/- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId
|
||||
for the associated let-rec because we need this information to compute the fixpoint later. -/
|
||||
let mvarIds := (toLift.val.collectMVars {}).result
|
||||
for mvarId in mvarIds do
|
||||
match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with
|
||||
| some fvarId => set := set.insert fvarId
|
||||
| none => pure ()
|
||||
usedFVarMap := usedFVarMap.insert toLift.fvarId set
|
||||
pure usedFVarMap
|
||||
|
||||
/-
|
||||
The let-recs may invoke each other. Example:
|
||||
|
|
@ -324,10 +313,10 @@ structure State :=
|
|||
|
||||
abbrev M := ReaderT (List FVarId) $ StateM State
|
||||
|
||||
private def isModified : M Bool := do s ← get; pure s.modified
|
||||
private def isModified : M Bool := do pure (← get).modified
|
||||
private def resetModified : M Unit := modify fun s => { s with modified := false }
|
||||
private def markModified : M Unit := modify fun s => { s with modified := true }
|
||||
private def getUsedFVarsMap : M UsedFVarsMap := do s ← get; pure s.usedFVarsMap
|
||||
private def getUsedFVarsMap : M UsedFVarsMap := do pure (← get).usedFVarsMap
|
||||
private def modifyUsedFVars (f : UsedFVarsMap → UsedFVarsMap) : M Unit := modify fun s => { s with usedFVarsMap := f s.usedFVarsMap }
|
||||
|
||||
-- merge s₂ into s₁
|
||||
|
|
@ -337,16 +326,16 @@ s₂.foldM
|
|||
if s₁.contains k then
|
||||
pure s₁
|
||||
else do
|
||||
markModified;
|
||||
markModified
|
||||
pure $ s₁.insert k)
|
||||
s₁
|
||||
|
||||
private def updateUsedVarsOf (fvarId : FVarId) : M Unit := do
|
||||
usedFVarsMap ← getUsedFVarsMap;
|
||||
let usedFVarsMap ← getUsedFVarsMap
|
||||
match usedFVarsMap.find? fvarId with
|
||||
| none => pure ()
|
||||
| some fvarIds => do
|
||||
fvarIdsNew ← fvarIds.foldM
|
||||
| some fvarIds =>
|
||||
let fvarIdsNew ← fvarIds.foldM
|
||||
(fun (fvarIdsNew : NameSet) (fvarId' : FVarId) =>
|
||||
if fvarId == fvarId' then
|
||||
pure fvarIdsNew
|
||||
|
|
@ -357,18 +346,19 @@ match usedFVarsMap.find? fvarId with
|
|||
not in the context of the let-rec associated with fvarId.
|
||||
We filter these out-of-context free variables later. -/
|
||||
| some otherFVarIds => merge fvarIdsNew otherFVarIds)
|
||||
fvarIds;
|
||||
fvarIds
|
||||
modifyUsedFVars fun usedFVars => usedFVars.insert fvarId fvarIdsNew
|
||||
|
||||
private partial def fixpoint : Unit → M Unit
|
||||
| _ => do
|
||||
resetModified;
|
||||
letRecFVarIds ← read;
|
||||
letRecFVarIds.forM updateUsedVarsOf;
|
||||
whenM isModified $ fixpoint ()
|
||||
resetModified
|
||||
let letRecFVarIds ← read
|
||||
letRecFVarIds.forM updateUsedVarsOf
|
||||
if (← isModified) then
|
||||
fixpoint ()
|
||||
|
||||
def run (letRecFVarIds : List FVarId) (usedFVarsMap : UsedFVarsMap) : UsedFVarsMap :=
|
||||
let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap };
|
||||
let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap }
|
||||
s.usedFVarsMap
|
||||
|
||||
end FixPoint
|
||||
|
|
@ -377,23 +367,23 @@ abbrev FreeVarMap := NameMap (Array FVarId)
|
|||
|
||||
private def mkFreeVarMap
|
||||
(mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId)
|
||||
(recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap :=
|
||||
let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift;
|
||||
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
|
||||
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap;
|
||||
letRecsToLift.foldl
|
||||
(fun (freeVarMap : FreeVarMap) toLift =>
|
||||
let lctx := toLift.lctx;
|
||||
let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!;
|
||||
let fvarIds := fvarIdsSet.fold
|
||||
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
|
||||
if lctx.contains fvarId && !recFVarIds.contains fvarId then
|
||||
fvarIds.push fvarId
|
||||
else
|
||||
fvarIds)
|
||||
#[];
|
||||
freeVarMap.insert toLift.fvarId fvarIds)
|
||||
{}
|
||||
(recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap := do
|
||||
let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift
|
||||
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId
|
||||
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap
|
||||
let freeVarMap := {}
|
||||
for toLift in letRecsToLift do
|
||||
let lctx := toLift.lctx
|
||||
let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!
|
||||
let fvarIds := fvarIdsSet.fold
|
||||
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
|
||||
if lctx.contains fvarId && !recFVarIds.contains fvarId then
|
||||
fvarIds.push fvarId
|
||||
else
|
||||
fvarIds)
|
||||
#[]
|
||||
freeVarMap := freeVarMap.insert toLift.fvarId fvarIds
|
||||
pure freeVarMap
|
||||
|
||||
structure ClosureState :=
|
||||
(newLocalDecls : Array LocalDecl := #[])
|
||||
|
|
@ -405,9 +395,9 @@ private def pickMaxFVar? (lctx : LocalContext) (fvarIds : Array FVarId) : Option
|
|||
fvarIds.getMax? fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
|
||||
|
||||
private def preprocess (e : Expr) : TermElabM Expr := do
|
||||
e ← instantiateMVars e;
|
||||
let e ← instantiateMVars e
|
||||
-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect.
|
||||
liftM $ check e;
|
||||
Meta.check e
|
||||
pure e
|
||||
|
||||
/- Push free variables in `s` to `toProcess` if they are not already there. -/
|
||||
|
|
@ -418,47 +408,47 @@ s.fvarSet.fold
|
|||
|
||||
private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName : Name) (type : Expr) (bi := BinderInfo.default)
|
||||
: StateRefT ClosureState TermElabM (Array FVarId) := do
|
||||
type ← liftM $ preprocess type;
|
||||
let type ← preprocess type
|
||||
modify fun s => { s with
|
||||
newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) fvarId userName type bi,
|
||||
exprArgs := s.exprArgs.push (mkFVar fvarId)
|
||||
};
|
||||
}
|
||||
pure $ pushNewVars toProcess (collectFVars {} type)
|
||||
|
||||
private partial def mkClosureForAux : Array FVarId → StateRefT ClosureState TermElabM Unit
|
||||
| toProcess => do
|
||||
lctx ← getLCtx;
|
||||
let lctx ← getLCtx
|
||||
match pickMaxFVar? lctx toProcess with
|
||||
| none => pure ()
|
||||
| some fvarId => do
|
||||
trace `Elab.definition.mkClosure fun _ => "toProcess: " ++ (toProcess.map mkFVar) ++ ", maxVar: " ++ mkFVar fvarId;
|
||||
let toProcess := toProcess.erase fvarId;
|
||||
localDecl ← getLocalDecl fvarId;
|
||||
| some fvarId =>
|
||||
trace[Elab.definition.mkClosure]! "toProcess: {toProcess.map mkFVar}, maxVar: {mkFVar fvarId}"
|
||||
let toProcess := toProcess.erase fvarId
|
||||
let localDecl ← getLocalDecl fvarId
|
||||
match localDecl with
|
||||
| LocalDecl.cdecl _ _ userName type bi => do
|
||||
toProcess ← pushLocalDecl toProcess fvarId userName type bi;
|
||||
| LocalDecl.cdecl _ _ userName type bi =>
|
||||
let toProcess ← pushLocalDecl toProcess fvarId userName type bi
|
||||
mkClosureForAux toProcess
|
||||
| LocalDecl.ldecl _ _ userName type val _ => do
|
||||
zetaFVarIds ← getZetaFVarIds;
|
||||
if !zetaFVarIds.contains fvarId then do
|
||||
| LocalDecl.ldecl _ _ userName type val _ =>
|
||||
let zetaFVarIds ← getZetaFVarIds
|
||||
if !zetaFVarIds.contains fvarId then
|
||||
/- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/
|
||||
toProcess ← pushLocalDecl toProcess fvarId userName type;
|
||||
let toProcess ← pushLocalDecl toProcess fvarId userName type
|
||||
mkClosureForAux toProcess
|
||||
else do
|
||||
else
|
||||
/- Dependent let-decl. -/
|
||||
type ← liftM $ preprocess type;
|
||||
val ← liftM $ preprocess val;
|
||||
let type ← preprocess type
|
||||
let val ← preprocess val
|
||||
modify fun s => { s with
|
||||
newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) fvarId userName type val false,
|
||||
/- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of fvarId
|
||||
at `newLocalDecls` and `localDecls` -/
|
||||
newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val),
|
||||
localDecls := s.localDecls.map (replaceFVarIdAtLocalDecl fvarId val)
|
||||
};
|
||||
}
|
||||
mkClosureForAux (pushNewVars toProcess (collectFVars (collectFVars {} type) val))
|
||||
|
||||
private partial def mkClosureFor (freeVars : Array FVarId) (localDecls : Array LocalDecl) : TermElabM ClosureState := do
|
||||
(_, s) ← (mkClosureForAux freeVars).run { localDecls := localDecls };
|
||||
let (_, s) ← (mkClosureForAux freeVars).run { localDecls := localDecls }
|
||||
pure { s with
|
||||
newLocalDecls := s.newLocalDecls.reverse,
|
||||
newLetDecls := s.newLetDecls.reverse,
|
||||
|
|
@ -470,16 +460,16 @@ structure LetRecClosure :=
|
|||
(toLift : LetRecToLift)
|
||||
|
||||
private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) : TermElabM LetRecClosure := do
|
||||
let lctx := toLift.lctx;
|
||||
let lctx := toLift.lctx
|
||||
withLCtx lctx toLift.localInstances do
|
||||
lambdaTelescope toLift.val fun xs val => do
|
||||
type ← instantiateForall toLift.type xs;
|
||||
lctx ← getLCtx;
|
||||
s ← mkClosureFor freeVars $ xs.map fun x => lctx.get! x.fvarId!;
|
||||
let type := Closure.mkForall s.localDecls $ Closure.mkForall s.newLetDecls type;
|
||||
let val := Closure.mkLambda s.localDecls $ Closure.mkLambda s.newLetDecls val;
|
||||
let c := mkAppN (Lean.mkConst toLift.declName) s.exprArgs;
|
||||
assignExprMVar toLift.mvarId c;
|
||||
let type ← instantiateForall toLift.type xs
|
||||
let lctx ← getLCtx
|
||||
let s ← mkClosureFor freeVars $ xs.map fun x => lctx.get! x.fvarId!
|
||||
let type := Closure.mkForall s.localDecls $ Closure.mkForall s.newLetDecls type
|
||||
let val := Closure.mkLambda s.localDecls $ Closure.mkLambda s.newLetDecls val
|
||||
let c := mkAppN (Lean.mkConst toLift.declName) s.exprArgs
|
||||
assignExprMVar toLift.mvarId c
|
||||
pure ⟨s.newLocalDecls, c, { toLift with val := val, type := type }⟩
|
||||
|
||||
private def mkLetRecClosures (letRecsToLift : List LetRecToLift) (freeVarMap : FreeVarMap) : TermElabM (List LetRecClosure) :=
|
||||
|
|
@ -508,9 +498,9 @@ def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHea
|
|||
: TermElabM (Array PreDefinition) :=
|
||||
mainHeaders.size.foldM
|
||||
(fun i (preDefs : Array PreDefinition) => do
|
||||
let header := mainHeaders.get! i;
|
||||
val ← mkLambdaFVars sectionVars (mainVals.get! i);
|
||||
type ← mkForallFVars sectionVars header.type;
|
||||
let header := mainHeaders[i]
|
||||
let val ← mkLambdaFVars sectionVars mainVals[i]
|
||||
let type ← mkForallFVars sectionVars header.type
|
||||
pure $ preDefs.push {
|
||||
kind := header.kind,
|
||||
declName := header.declName,
|
||||
|
|
@ -524,8 +514,8 @@ mainHeaders.size.foldM
|
|||
def pushLetRecs (preDefs : Array PreDefinition) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDefinition :=
|
||||
letRecClosures.foldl
|
||||
(fun (preDefs : Array PreDefinition) (c : LetRecClosure) =>
|
||||
let type := Closure.mkForall c.localDecls c.toLift.type;
|
||||
let val := Closure.mkLambda c.localDecls c.toLift.val;
|
||||
let type := Closure.mkForall c.localDecls c.toLift.type
|
||||
let val := Closure.mkLambda c.localDecls c.toLift.val
|
||||
preDefs.push {
|
||||
kind := kind,
|
||||
declName := c.toLift.declName,
|
||||
|
|
@ -555,29 +545,29 @@ def getModifiersForLetRecs (mainHeaders : Array DefViewElabHeader) : Modifiers :
|
|||
def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) (mainVals : Array Expr) (letRecsToLift : List LetRecToLift)
|
||||
: TermElabM (Array PreDefinition) := do
|
||||
-- Store in recFVarIds the fvarId of every function being defined by the mutual block.
|
||||
let mainFVarIds := mainFVars.map Expr.fvarId!;
|
||||
let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds;
|
||||
let mainFVarIds := mainFVars.map Expr.fvarId!
|
||||
let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds
|
||||
-- Compute the set of free variables (excluding `recFVarIds`) for each let-rec.
|
||||
mctx ← getMCtx;
|
||||
let freeVarMap := mkFreeVarMap mctx sectionVars mainFVarIds recFVarIds letRecsToLift;
|
||||
resetZetaFVarIds;
|
||||
let mctx ← getMCtx
|
||||
let freeVarMap := mkFreeVarMap mctx sectionVars mainFVarIds recFVarIds letRecsToLift
|
||||
resetZetaFVarIds
|
||||
withTrackingZeta do
|
||||
-- By checking `toLift.type` and `toLift.val` we populate `zetaFVarIds`. See comments at `src/Lean/Meta/Closure.lean`.
|
||||
letRecsToLift.forM fun toLift => withLCtx toLift.lctx toLift.localInstances do { liftM $ check toLift.type; liftM $ check toLift.val };
|
||||
letRecClosures ← mkLetRecClosures letRecsToLift freeVarMap;
|
||||
-- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations.
|
||||
mainVals ← mainVals.mapM instantiateMVars;
|
||||
mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader;
|
||||
letRecClosures ← letRecClosures.mapM fun closure => do { toLift ← instantiateMVarsAtLetRecToLift closure.toLift; pure { closure with toLift := toLift } };
|
||||
-- Replace fvarIds for functions being defined with closed terms
|
||||
let r := insertReplacementForMainFns {} sectionVars mainHeaders mainFVars;
|
||||
let r := insertReplacementForLetRecs r letRecClosures;
|
||||
let mainVals := mainVals.map r.apply;
|
||||
let mainHeaders := mainHeaders.map fun h => { h with type := r.apply h.type };
|
||||
let letRecClosures := letRecClosures.map fun c => { c with toLift := { c.toLift with type := r.apply c.toLift.type, val := r.apply c.toLift.val } };
|
||||
let letRecKind := getKindForLetRecs mainHeaders;
|
||||
let letRecMods := getModifiersForLetRecs mainHeaders;
|
||||
pushMain (pushLetRecs #[] letRecClosures letRecKind letRecMods) sectionVars mainHeaders mainVals
|
||||
-- By checking `toLift.type` and `toLift.val` we populate `zetaFVarIds`. See comments at `src/Lean/Meta/Closure.lean`.
|
||||
letRecsToLift.forM fun toLift => withLCtx toLift.lctx toLift.localInstances do Meta.check toLift.type; Meta.check toLift.val
|
||||
let letRecClosures ← mkLetRecClosures letRecsToLift freeVarMap
|
||||
-- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations.
|
||||
let mainVals ← mainVals.mapM instantiateMVars
|
||||
let mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader
|
||||
let letRecClosures ← letRecClosures.mapM fun closure => do pure { closure with toLift := (← instantiateMVarsAtLetRecToLift closure.toLift) }
|
||||
-- Replace fvarIds for functions being defined with closed terms
|
||||
let r := insertReplacementForMainFns {} sectionVars mainHeaders mainFVars
|
||||
let r := insertReplacementForLetRecs r letRecClosures
|
||||
let mainVals := mainVals.map r.apply
|
||||
let mainHeaders := mainHeaders.map fun h => { h with type := r.apply h.type }
|
||||
let letRecClosures := letRecClosures.map fun c => { c with toLift := { c.toLift with type := r.apply c.toLift.type, val := r.apply c.toLift.val } }
|
||||
let letRecKind := getKindForLetRecs mainHeaders
|
||||
let letRecMods := getModifiersForLetRecs mainHeaders
|
||||
pushMain (pushLetRecs #[] letRecClosures letRecKind letRecMods) sectionVars mainHeaders mainVals
|
||||
|
||||
end MutualClosure
|
||||
|
||||
|
|
@ -589,34 +579,34 @@ else
|
|||
[]
|
||||
|
||||
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
|
||||
scopeLevelNames ← getLevelNames;
|
||||
headers ← elabHeaders views;
|
||||
let allUserLevelNames := getAllUserLevelNames headers;
|
||||
let scopeLevelNames ← getLevelNames
|
||||
let headers ← elabHeaders views
|
||||
let allUserLevelNames := getAllUserLevelNames headers
|
||||
withFunLocalDecls headers fun funFVars => do
|
||||
values ← elabFunValues headers;
|
||||
Term.synthesizeSyntheticMVarsNoPostponing;
|
||||
if isExample views then pure ()
|
||||
else do
|
||||
values ← values.mapM instantiateMVars;
|
||||
headers ← headers.mapM instantiateMVarsAtHeader;
|
||||
letRecsToLift ← getLetRecsToLift;
|
||||
letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift;
|
||||
checkLetRecsToLiftTypes funFVars letRecsToLift;
|
||||
let values ← elabFunValues headers
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
if isExample views then
|
||||
pure ()
|
||||
else
|
||||
let values ← values.mapM instantiateMVars
|
||||
let headers ← headers.mapM instantiateMVarsAtHeader
|
||||
let letRecsToLift ← getLetRecsToLift
|
||||
let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
|
||||
checkLetRecsToLiftTypes funFVars letRecsToLift
|
||||
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
|
||||
preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift;
|
||||
preDefs ← levelMVarToParamPreDecls preDefs;
|
||||
preDefs ← instantiateMVarsAtPreDecls preDefs;
|
||||
preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames;
|
||||
let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
|
||||
let preDefs ← levelMVarToParamPreDecls preDefs
|
||||
let preDefs ← instantiateMVarsAtPreDecls preDefs
|
||||
let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
|
||||
addPreDefinitions preDefs
|
||||
|
||||
end Term
|
||||
namespace Command
|
||||
|
||||
def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
|
||||
views ← ds.mapM fun d => do {
|
||||
modifiers ← elabModifiers (d.getArg 0);
|
||||
mkDefView modifiers (d.getArg 1)
|
||||
};
|
||||
let views ← ds.mapM fun d => do
|
||||
let modifiers ← elabModifiers d[0]
|
||||
mkDefView modifiers d[1]
|
||||
runTermElabM none fun vars => Term.elabMutualDef vars views
|
||||
|
||||
end Command
|
||||
|
|
|
|||
1080
stage0/stdlib/Lean/Elab/Declaration.c
generated
1080
stage0/stdlib/Lean/Elab/Declaration.c
generated
File diff suppressed because it is too large
Load diff
7277
stage0/stdlib/Lean/Elab/MutualDef.c
generated
7277
stage0/stdlib/Lean/Elab/MutualDef.c
generated
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue