chore: document and cleanup
This commit is contained in:
parent
da44604c1b
commit
387b6c22ee
1 changed files with 58 additions and 47 deletions
|
|
@ -20,13 +20,23 @@ open Lean.Parser.Term
|
|||
structure DefViewElabHeader where
|
||||
ref : Syntax
|
||||
modifiers : Modifiers
|
||||
/-- Stores whether this is the header of a definition, theorem, ... -/
|
||||
kind : DefKind
|
||||
/--
|
||||
Short name. Recall that all declarations in Lean 4 are potentially recursive. We use `shortDeclName` to refer
|
||||
to them at `valueStx`, and other declarations in the same mutual block. -/
|
||||
shortDeclName : Name
|
||||
/-- Full name for this declaration. This is the name that will be added to the `Environment`. -/
|
||||
declName : Name
|
||||
/-- Universe level parameter names explicitly provided by the user. -/
|
||||
levelNames : List Name
|
||||
/-- Syntax objects for the binders occurring befor `:`, we use them to populate the `InfoTree` when elaborating `valueStx`. -/
|
||||
binderIds : Array Syntax
|
||||
/-- Number of parameters before `:`, it also includes auto-implicit parameters automatically added by Lean. -/
|
||||
numParams : Nat
|
||||
type : Expr -- including the parameters
|
||||
/-- Type including parameters. -/
|
||||
type : Expr
|
||||
/-- `Syntax` object the body/value of the definition. -/
|
||||
valueStx : Syntax
|
||||
deriving Inhabited
|
||||
|
||||
|
|
@ -68,7 +78,7 @@ private def check (prevHeaders : Array DefViewElabHeader) (newHeader : DefViewEl
|
|||
checkModifiers newHeader.modifiers firstHeader.modifiers
|
||||
checkKinds newHeader.kind firstHeader.kind
|
||||
catch
|
||||
| Exception.error ref msg => throw (Exception.error ref m!"invalid mutually recursive definitions, {msg}")
|
||||
| .error ref msg => throw (.error ref m!"invalid mutually recursive definitions, {msg}")
|
||||
| ex => throw ex
|
||||
else
|
||||
pure ()
|
||||
|
|
@ -83,7 +93,7 @@ private def registerFailedToInferDefTypeInfo (type : Expr) (ref : Syntax) : Term
|
|||
``` -/
|
||||
private def isMultiConstant? (views : Array DefView) : Option (List Name) :=
|
||||
if views.size == 1 &&
|
||||
views[0]!.kind == DefKind.opaque &&
|
||||
views[0]!.kind == .opaque &&
|
||||
views[0]!.binders.getArgs.size > 0 &&
|
||||
views[0]!.binders.getArgs.all (·.isIdent) then
|
||||
some (views[0]!.binders.getArgs.toList.map (·.getId))
|
||||
|
|
@ -99,6 +109,7 @@ private def getPendindMVarErrorMessage (views : Array DefView) : String :=
|
|||
| none =>
|
||||
"\nwhen the resulting type of a declaration is explicitly provided, all holes (e.g., `_`) in the header are resolved before the declaration body is processed"
|
||||
|
||||
/-- Elaborate only the declaration headers. We have to elaborate the headers first because we support mutually recursive declarations in Lean 4. -/
|
||||
private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) := do
|
||||
let expandedDeclIds ← views.mapM fun view => withRef view.ref do
|
||||
Term.expandDeclId (← getCurrNamespace) (← getLevelNames) view.declId view.modifiers
|
||||
|
|
@ -107,7 +118,7 @@ private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHe
|
|||
for view in views, ⟨shortDeclName, declName, levelNames⟩ in expandedDeclIds do
|
||||
let newHeader ← withRef view.ref do
|
||||
addDeclarationRanges declName view.ref
|
||||
applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration
|
||||
applyAttributesAt declName view.modifiers.attrs .beforeElaboration
|
||||
withDeclName declName <| withAutoBoundImplicit <| withLevelNames levelNames <|
|
||||
elabBindersEx view.binders.getArgs fun xs => do
|
||||
let refForElabFunType := view.value
|
||||
|
|
@ -134,21 +145,23 @@ private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHe
|
|||
discard <| logUnassignedUsingErrorInfos pendingMVarIds <|
|
||||
getPendindMVarErrorMessage views
|
||||
let newHeader := {
|
||||
ref := view.ref,
|
||||
modifiers := view.modifiers,
|
||||
kind := view.kind,
|
||||
shortDeclName := shortDeclName,
|
||||
declName := declName,
|
||||
levelNames := levelNames,
|
||||
binderIds := binderIds,
|
||||
numParams := xs.size,
|
||||
type := type,
|
||||
ref := view.ref
|
||||
modifiers := view.modifiers
|
||||
kind := view.kind
|
||||
shortDeclName := shortDeclName
|
||||
declName, type, levelNames, binderIds
|
||||
numParams := xs.size
|
||||
valueStx := view.value : DefViewElabHeader }
|
||||
check headers newHeader
|
||||
return newHeader
|
||||
headers := headers.push newHeader
|
||||
return headers
|
||||
|
||||
/--
|
||||
Create auxiliary local declarations `fs` for the given hearders using their `shortDeclName` and `type`, given hearders, and execute `k fs`.
|
||||
The new free variables are tagged as `auxDecl`.
|
||||
Remark: `fs.size = headers.size`.
|
||||
-/
|
||||
private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : TermElabM α :=
|
||||
let rec loop (i : Nat) (fvars : Array Expr) := do
|
||||
if h : i < headers.size then
|
||||
|
|
@ -164,10 +177,10 @@ private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (
|
|||
private def expandWhereStructInst : Macro
|
||||
| `(Parser.Command.whereStructInst|where $[$decls:letDecl];* $[$whereDecls?:whereDecls]?) => do
|
||||
let letIdDecls ← decls.mapM fun stx => match stx with
|
||||
| `(letDecl|$_decl:letPatDecl) => Macro.throwErrorAt stx "patterns are not allowed here"
|
||||
| `(letDecl|$_decl:letPatDecl) => Macro.throwErrorAt stx "patterns are not allowed here"
|
||||
| `(letDecl|$decl:letEqnsDecl) => expandLetEqnsDecl decl
|
||||
| `(letDecl|$decl:letIdDecl) => pure decl
|
||||
| _ => Macro.throwUnsupported
|
||||
| _ => Macro.throwUnsupported
|
||||
let structInstFields ← letIdDecls.mapM fun
|
||||
| stx@`(letIdDecl|$id:ident $binders* $[: $ty?]? := $val) => withRef stx do
|
||||
let mut val := val
|
||||
|
|
@ -191,11 +204,11 @@ def declVal := declValSimple <|> declValEqns <|> Term.whereDecls
|
|||
```
|
||||
-/
|
||||
private def declValToTerm (declVal : Syntax) : MacroM Syntax := withRef declVal do
|
||||
if declVal.isOfKind ``Lean.Parser.Command.declValSimple then
|
||||
if declVal.isOfKind ``Parser.Command.declValSimple then
|
||||
expandWhereDeclsOpt declVal[2] declVal[1]
|
||||
else if declVal.isOfKind ``Lean.Parser.Command.declValEqns then
|
||||
else if declVal.isOfKind ``Parser.Command.declValEqns then
|
||||
expandMatchAltsWhereDecls declVal[0]
|
||||
else if declVal.isOfKind ``Lean.Parser.Command.whereStructInst then
|
||||
else if declVal.isOfKind ``Parser.Command.whereStructInst then
|
||||
expandWhereStructInst declVal
|
||||
else if declVal.isMissing then
|
||||
Macro.throwErrorAt declVal "declaration body is missing"
|
||||
|
|
@ -244,7 +257,7 @@ private def instantiateMVarsAtHeader (header : DefViewElabHeader) : TermElabM De
|
|||
private def instantiateMVarsAtLetRecToLift (toLift : LetRecToLift) : TermElabM LetRecToLift := do
|
||||
let type ← instantiateMVars toLift.type
|
||||
let val ← instantiateMVars toLift.val
|
||||
pure { toLift with type := type, val := val }
|
||||
pure { toLift with type, val }
|
||||
|
||||
private def typeHasRecFun (type : Expr) (funFVars : Array Expr) (letRecsToLift : List LetRecToLift) : Option FVarId :=
|
||||
let occ? := type.find? fun e => match e with
|
||||
|
|
@ -393,14 +406,14 @@ private def merge (s₁ s₂ : FVarIdSet) : M FVarIdSet :=
|
|||
private def updateUsedVarsOf (fvarId : FVarId) : M Unit := do
|
||||
let usedFVarsMap ← getUsedFVarsMap
|
||||
match usedFVarsMap.find? fvarId with
|
||||
| none => pure ()
|
||||
| none => return ()
|
||||
| some fvarIds =>
|
||||
let fvarIdsNew ← fvarIds.foldM (init := fvarIds) fun fvarIdsNew fvarId' =>
|
||||
let fvarIdsNew ← fvarIds.foldM (init := fvarIds) fun fvarIdsNew fvarId' => do
|
||||
if fvarId == fvarId' then
|
||||
pure fvarIdsNew
|
||||
return fvarIdsNew
|
||||
else
|
||||
match usedFVarsMap.find? fvarId' with
|
||||
| none => pure fvarIdsNew
|
||||
| none => return fvarIdsNew
|
||||
/- We are being sloppy here `otherFVarIds` may contain free variables that are
|
||||
not in the context of the let-rec associated with fvarId.
|
||||
We filter these out-of-context free variables later. -/
|
||||
|
|
@ -416,7 +429,7 @@ private partial def fixpoint : Unit → M Unit
|
|||
fixpoint ()
|
||||
|
||||
def run (letRecFVarIds : Array FVarId) (usedFVarsMap : UsedFVarsMap) : UsedFVarsMap :=
|
||||
let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap }
|
||||
let (_, s) := fixpoint () |>.run letRecFVarIds |>.run { usedFVarsMap := usedFVarsMap }
|
||||
s.usedFVarsMap
|
||||
|
||||
end FixPoint
|
||||
|
|
@ -426,13 +439,13 @@ abbrev FreeVarMap := FVarIdMap (Array FVarId)
|
|||
private def mkFreeVarMap [Monad m] [MonadMCtx m]
|
||||
(sectionVars : Array Expr) (mainFVarIds : Array FVarId)
|
||||
(recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : m FreeVarMap := do
|
||||
let usedFVarsMap ← mkInitialUsedFVarsMap sectionVars mainFVarIds letRecsToLift
|
||||
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId
|
||||
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap
|
||||
let usedFVarsMap ← mkInitialUsedFVarsMap sectionVars mainFVarIds letRecsToLift
|
||||
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId
|
||||
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap
|
||||
let mut freeVarMap := {}
|
||||
for toLift in letRecsToLift do
|
||||
let lctx := toLift.lctx
|
||||
let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!
|
||||
let fvarIdsSet := usedFVarsMap.find? toLift.fvarId |>.get!
|
||||
let fvarIds := fvarIdsSet.fold (init := #[]) fun fvarIds fvarId =>
|
||||
if lctx.contains fvarId && !recFVarIds.contains fvarId then
|
||||
fvarIds.push fvarId
|
||||
|
|
@ -465,7 +478,7 @@ private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName
|
|||
: StateRefT ClosureState TermElabM (Array FVarId) := do
|
||||
let type ← preprocess type
|
||||
modify fun s => { s with
|
||||
newLocalDecls := s.newLocalDecls.push <| LocalDecl.cdecl default fvarId userName type bi,
|
||||
newLocalDecls := s.newLocalDecls.push <| LocalDecl.cdecl default fvarId userName type bi
|
||||
exprArgs := s.exprArgs.push (mkFVar fvarId)
|
||||
}
|
||||
return pushNewVars toProcess (collectFVars {} type)
|
||||
|
|
@ -473,16 +486,16 @@ private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName
|
|||
private partial def mkClosureForAux (toProcess : Array FVarId) : StateRefT ClosureState TermElabM Unit := do
|
||||
let lctx ← getLCtx
|
||||
match pickMaxFVar? lctx toProcess with
|
||||
| none => pure ()
|
||||
| none => return ()
|
||||
| some fvarId =>
|
||||
trace[Elab.definition.mkClosure] "toProcess: {toProcess.map mkFVar}, maxVar: {mkFVar fvarId}"
|
||||
let toProcess := toProcess.erase fvarId
|
||||
let localDecl ← getLocalDecl fvarId
|
||||
match localDecl with
|
||||
| LocalDecl.cdecl _ _ userName type bi =>
|
||||
| .cdecl _ _ userName type bi =>
|
||||
let toProcess ← pushLocalDecl toProcess fvarId userName type bi
|
||||
mkClosureForAux toProcess
|
||||
| LocalDecl.ldecl _ _ userName type val _ =>
|
||||
| .ldecl _ _ userName type val _ =>
|
||||
let zetaFVarIds ← getZetaFVarIds
|
||||
if !zetaFVarIds.contains fvarId then
|
||||
/- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/
|
||||
|
|
@ -496,23 +509,24 @@ private partial def mkClosureForAux (toProcess : Array FVarId) : StateRefT Closu
|
|||
newLetDecls := s.newLetDecls.push <| LocalDecl.ldecl default fvarId userName type val false,
|
||||
/- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of fvarId
|
||||
at `newLocalDecls` and `localDecls` -/
|
||||
newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val),
|
||||
newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val)
|
||||
localDecls := s.localDecls.map (replaceFVarIdAtLocalDecl fvarId val)
|
||||
}
|
||||
mkClosureForAux (pushNewVars toProcess (collectFVars (collectFVars {} type) val))
|
||||
|
||||
private partial def mkClosureFor (freeVars : Array FVarId) (localDecls : Array LocalDecl) : TermElabM ClosureState := do
|
||||
let (_, s) ← (mkClosureForAux freeVars).run { localDecls := localDecls }
|
||||
let (_, s) ← mkClosureForAux freeVars |>.run { localDecls := localDecls }
|
||||
pure { s with
|
||||
newLocalDecls := s.newLocalDecls.reverse,
|
||||
newLetDecls := s.newLetDecls.reverse,
|
||||
newLocalDecls := s.newLocalDecls.reverse
|
||||
newLetDecls := s.newLetDecls.reverse
|
||||
exprArgs := s.exprArgs.reverse
|
||||
}
|
||||
|
||||
structure LetRecClosure where
|
||||
ref : Syntax
|
||||
localDecls : Array LocalDecl
|
||||
closed : Expr -- expression used to replace occurrences of the let-rec FVarId
|
||||
/-- Expression used to replace occurrences of the let-rec `FVarId`. -/
|
||||
closed : Expr
|
||||
toLift : LetRecToLift
|
||||
|
||||
private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) : TermElabM LetRecClosure := do
|
||||
|
|
@ -530,7 +544,7 @@ private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId)
|
|||
ref := toLift.ref
|
||||
localDecls := s.newLocalDecls
|
||||
closed := c
|
||||
toLift := { toLift with val := val, type := type }
|
||||
toLift := { toLift with val, type }
|
||||
}
|
||||
|
||||
private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : TermElabM (List LetRecClosure) := do
|
||||
|
|
@ -565,7 +579,7 @@ def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecC
|
|||
|
||||
def Replacement.apply (r : Replacement) (e : Expr) : Expr :=
|
||||
e.replace fun e => match e with
|
||||
| Expr.fvar fvarId => match r.find? fvarId with
|
||||
| .fvar fvarId => match r.find? fvarId with
|
||||
| some c => some c
|
||||
| _ => none
|
||||
| _ => none
|
||||
|
|
@ -574,7 +588,7 @@ def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHea
|
|||
: TermElabM (Array PreDefinition) :=
|
||||
mainHeaders.size.foldM (init := preDefs) fun i preDefs => do
|
||||
let header := mainHeaders[i]!
|
||||
let val ← mkLambdaFVars sectionVars mainVals[i]!
|
||||
let value ← mkLambdaFVars sectionVars mainVals[i]!
|
||||
let type ← mkForallFVars sectionVars header.type
|
||||
return preDefs.push {
|
||||
ref := getDeclarationSelectionRef header.ref
|
||||
|
|
@ -582,22 +596,19 @@ def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHea
|
|||
declName := header.declName
|
||||
levelParams := [], -- we set it later
|
||||
modifiers := header.modifiers
|
||||
type := type
|
||||
value := val
|
||||
type, value
|
||||
}
|
||||
|
||||
def pushLetRecs (preDefs : Array PreDefinition) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDefinition :=
|
||||
letRecClosures.foldl (init := preDefs) fun preDefs c =>
|
||||
let type := Closure.mkForall c.localDecls c.toLift.type
|
||||
let val := Closure.mkLambda c.localDecls c.toLift.val
|
||||
let type := Closure.mkForall c.localDecls c.toLift.type
|
||||
let value := Closure.mkLambda c.localDecls c.toLift.val
|
||||
preDefs.push {
|
||||
ref := c.ref
|
||||
kind := kind
|
||||
declName := c.toLift.declName
|
||||
levelParams := [] -- we set it later
|
||||
modifiers := { modifiers with attrs := c.toLift.attrs }
|
||||
type := type
|
||||
value := val
|
||||
kind, type, value
|
||||
}
|
||||
|
||||
def getKindForLetRecs (mainHeaders : Array DefViewElabHeader) : DefKind :=
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue