chore: document and cleanup

This commit is contained in:
Leonardo de Moura 2022-07-25 14:25:44 -07:00
parent da44604c1b
commit 387b6c22ee

View file

@ -20,13 +20,23 @@ open Lean.Parser.Term
structure DefViewElabHeader where
ref : Syntax
modifiers : Modifiers
/-- Stores whether this is the header of a definition, theorem, ... -/
kind : DefKind
/--
Short name. Recall that all declarations in Lean 4 are potentially recursive. We use `shortDeclName` to refer
to them at `valueStx`, and other declarations in the same mutual block. -/
shortDeclName : Name
/-- Full name for this declaration. This is the name that will be added to the `Environment`. -/
declName : Name
/-- Universe level parameter names explicitly provided by the user. -/
levelNames : List Name
/-- Syntax objects for the binders occurring befor `:`, we use them to populate the `InfoTree` when elaborating `valueStx`. -/
binderIds : Array Syntax
/-- Number of parameters before `:`, it also includes auto-implicit parameters automatically added by Lean. -/
numParams : Nat
type : Expr -- including the parameters
/-- Type including parameters. -/
type : Expr
/-- `Syntax` object the body/value of the definition. -/
valueStx : Syntax
deriving Inhabited
@ -68,7 +78,7 @@ private def check (prevHeaders : Array DefViewElabHeader) (newHeader : DefViewEl
checkModifiers newHeader.modifiers firstHeader.modifiers
checkKinds newHeader.kind firstHeader.kind
catch
| Exception.error ref msg => throw (Exception.error ref m!"invalid mutually recursive definitions, {msg}")
| .error ref msg => throw (.error ref m!"invalid mutually recursive definitions, {msg}")
| ex => throw ex
else
pure ()
@ -83,7 +93,7 @@ private def registerFailedToInferDefTypeInfo (type : Expr) (ref : Syntax) : Term
``` -/
private def isMultiConstant? (views : Array DefView) : Option (List Name) :=
if views.size == 1 &&
views[0]!.kind == DefKind.opaque &&
views[0]!.kind == .opaque &&
views[0]!.binders.getArgs.size > 0 &&
views[0]!.binders.getArgs.all (·.isIdent) then
some (views[0]!.binders.getArgs.toList.map (·.getId))
@ -99,6 +109,7 @@ private def getPendindMVarErrorMessage (views : Array DefView) : String :=
| none =>
"\nwhen the resulting type of a declaration is explicitly provided, all holes (e.g., `_`) in the header are resolved before the declaration body is processed"
/-- Elaborate only the declaration headers. We have to elaborate the headers first because we support mutually recursive declarations in Lean 4. -/
private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) := do
let expandedDeclIds ← views.mapM fun view => withRef view.ref do
Term.expandDeclId (← getCurrNamespace) (← getLevelNames) view.declId view.modifiers
@ -107,7 +118,7 @@ private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHe
for view in views, ⟨shortDeclName, declName, levelNames⟩ in expandedDeclIds do
let newHeader ← withRef view.ref do
addDeclarationRanges declName view.ref
applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration
applyAttributesAt declName view.modifiers.attrs .beforeElaboration
withDeclName declName <| withAutoBoundImplicit <| withLevelNames levelNames <|
elabBindersEx view.binders.getArgs fun xs => do
let refForElabFunType := view.value
@ -134,21 +145,23 @@ private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHe
discard <| logUnassignedUsingErrorInfos pendingMVarIds <|
getPendindMVarErrorMessage views
let newHeader := {
ref := view.ref,
modifiers := view.modifiers,
kind := view.kind,
shortDeclName := shortDeclName,
declName := declName,
levelNames := levelNames,
binderIds := binderIds,
numParams := xs.size,
type := type,
ref := view.ref
modifiers := view.modifiers
kind := view.kind
shortDeclName := shortDeclName
declName, type, levelNames, binderIds
numParams := xs.size
valueStx := view.value : DefViewElabHeader }
check headers newHeader
return newHeader
headers := headers.push newHeader
return headers
/--
Create auxiliary local declarations `fs` for the given hearders using their `shortDeclName` and `type`, given hearders, and execute `k fs`.
The new free variables are tagged as `auxDecl`.
Remark: `fs.size = headers.size`.
-/
private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : TermElabM α :=
let rec loop (i : Nat) (fvars : Array Expr) := do
if h : i < headers.size then
@ -164,10 +177,10 @@ private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (
private def expandWhereStructInst : Macro
| `(Parser.Command.whereStructInst|where $[$decls:letDecl];* $[$whereDecls?:whereDecls]?) => do
let letIdDecls ← decls.mapM fun stx => match stx with
| `(letDecl|$_decl:letPatDecl) => Macro.throwErrorAt stx "patterns are not allowed here"
| `(letDecl|$_decl:letPatDecl) => Macro.throwErrorAt stx "patterns are not allowed here"
| `(letDecl|$decl:letEqnsDecl) => expandLetEqnsDecl decl
| `(letDecl|$decl:letIdDecl) => pure decl
| _ => Macro.throwUnsupported
| _ => Macro.throwUnsupported
let structInstFields ← letIdDecls.mapM fun
| stx@`(letIdDecl|$id:ident $binders* $[: $ty?]? := $val) => withRef stx do
let mut val := val
@ -191,11 +204,11 @@ def declVal := declValSimple <|> declValEqns <|> Term.whereDecls
```
-/
private def declValToTerm (declVal : Syntax) : MacroM Syntax := withRef declVal do
if declVal.isOfKind ``Lean.Parser.Command.declValSimple then
if declVal.isOfKind ``Parser.Command.declValSimple then
expandWhereDeclsOpt declVal[2] declVal[1]
else if declVal.isOfKind ``Lean.Parser.Command.declValEqns then
else if declVal.isOfKind ``Parser.Command.declValEqns then
expandMatchAltsWhereDecls declVal[0]
else if declVal.isOfKind ``Lean.Parser.Command.whereStructInst then
else if declVal.isOfKind ``Parser.Command.whereStructInst then
expandWhereStructInst declVal
else if declVal.isMissing then
Macro.throwErrorAt declVal "declaration body is missing"
@ -244,7 +257,7 @@ private def instantiateMVarsAtHeader (header : DefViewElabHeader) : TermElabM De
private def instantiateMVarsAtLetRecToLift (toLift : LetRecToLift) : TermElabM LetRecToLift := do
let type ← instantiateMVars toLift.type
let val ← instantiateMVars toLift.val
pure { toLift with type := type, val := val }
pure { toLift with type, val }
private def typeHasRecFun (type : Expr) (funFVars : Array Expr) (letRecsToLift : List LetRecToLift) : Option FVarId :=
let occ? := type.find? fun e => match e with
@ -393,14 +406,14 @@ private def merge (s₁ s₂ : FVarIdSet) : M FVarIdSet :=
private def updateUsedVarsOf (fvarId : FVarId) : M Unit := do
let usedFVarsMap ← getUsedFVarsMap
match usedFVarsMap.find? fvarId with
| none => pure ()
| none => return ()
| some fvarIds =>
let fvarIdsNew ← fvarIds.foldM (init := fvarIds) fun fvarIdsNew fvarId' =>
let fvarIdsNew ← fvarIds.foldM (init := fvarIds) fun fvarIdsNew fvarId' => do
if fvarId == fvarId' then
pure fvarIdsNew
return fvarIdsNew
else
match usedFVarsMap.find? fvarId' with
| none => pure fvarIdsNew
| none => return fvarIdsNew
/- We are being sloppy here `otherFVarIds` may contain free variables that are
not in the context of the let-rec associated with fvarId.
We filter these out-of-context free variables later. -/
@ -416,7 +429,7 @@ private partial def fixpoint : Unit → M Unit
fixpoint ()
def run (letRecFVarIds : Array FVarId) (usedFVarsMap : UsedFVarsMap) : UsedFVarsMap :=
let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap }
let (_, s) := fixpoint () |>.run letRecFVarIds |>.run { usedFVarsMap := usedFVarsMap }
s.usedFVarsMap
end FixPoint
@ -426,13 +439,13 @@ abbrev FreeVarMap := FVarIdMap (Array FVarId)
private def mkFreeVarMap [Monad m] [MonadMCtx m]
(sectionVars : Array Expr) (mainFVarIds : Array FVarId)
(recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : m FreeVarMap := do
let usedFVarsMap ← mkInitialUsedFVarsMap sectionVars mainFVarIds letRecsToLift
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap
let usedFVarsMap ← mkInitialUsedFVarsMap sectionVars mainFVarIds letRecsToLift
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap
let mut freeVarMap := {}
for toLift in letRecsToLift do
let lctx := toLift.lctx
let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!
let fvarIdsSet := usedFVarsMap.find? toLift.fvarId |>.get!
let fvarIds := fvarIdsSet.fold (init := #[]) fun fvarIds fvarId =>
if lctx.contains fvarId && !recFVarIds.contains fvarId then
fvarIds.push fvarId
@ -465,7 +478,7 @@ private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName
: StateRefT ClosureState TermElabM (Array FVarId) := do
let type ← preprocess type
modify fun s => { s with
newLocalDecls := s.newLocalDecls.push <| LocalDecl.cdecl default fvarId userName type bi,
newLocalDecls := s.newLocalDecls.push <| LocalDecl.cdecl default fvarId userName type bi
exprArgs := s.exprArgs.push (mkFVar fvarId)
}
return pushNewVars toProcess (collectFVars {} type)
@ -473,16 +486,16 @@ private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName
private partial def mkClosureForAux (toProcess : Array FVarId) : StateRefT ClosureState TermElabM Unit := do
let lctx ← getLCtx
match pickMaxFVar? lctx toProcess with
| none => pure ()
| none => return ()
| some fvarId =>
trace[Elab.definition.mkClosure] "toProcess: {toProcess.map mkFVar}, maxVar: {mkFVar fvarId}"
let toProcess := toProcess.erase fvarId
let localDecl ← getLocalDecl fvarId
match localDecl with
| LocalDecl.cdecl _ _ userName type bi =>
| .cdecl _ _ userName type bi =>
let toProcess ← pushLocalDecl toProcess fvarId userName type bi
mkClosureForAux toProcess
| LocalDecl.ldecl _ _ userName type val _ =>
| .ldecl _ _ userName type val _ =>
let zetaFVarIds ← getZetaFVarIds
if !zetaFVarIds.contains fvarId then
/- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/
@ -496,23 +509,24 @@ private partial def mkClosureForAux (toProcess : Array FVarId) : StateRefT Closu
newLetDecls := s.newLetDecls.push <| LocalDecl.ldecl default fvarId userName type val false,
/- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of fvarId
at `newLocalDecls` and `localDecls` -/
newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val),
newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val)
localDecls := s.localDecls.map (replaceFVarIdAtLocalDecl fvarId val)
}
mkClosureForAux (pushNewVars toProcess (collectFVars (collectFVars {} type) val))
private partial def mkClosureFor (freeVars : Array FVarId) (localDecls : Array LocalDecl) : TermElabM ClosureState := do
let (_, s) ← (mkClosureForAux freeVars).run { localDecls := localDecls }
let (_, s) ← mkClosureForAux freeVars |>.run { localDecls := localDecls }
pure { s with
newLocalDecls := s.newLocalDecls.reverse,
newLetDecls := s.newLetDecls.reverse,
newLocalDecls := s.newLocalDecls.reverse
newLetDecls := s.newLetDecls.reverse
exprArgs := s.exprArgs.reverse
}
structure LetRecClosure where
ref : Syntax
localDecls : Array LocalDecl
closed : Expr -- expression used to replace occurrences of the let-rec FVarId
/-- Expression used to replace occurrences of the let-rec `FVarId`. -/
closed : Expr
toLift : LetRecToLift
private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) : TermElabM LetRecClosure := do
@ -530,7 +544,7 @@ private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId)
ref := toLift.ref
localDecls := s.newLocalDecls
closed := c
toLift := { toLift with val := val, type := type }
toLift := { toLift with val, type }
}
private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : TermElabM (List LetRecClosure) := do
@ -565,7 +579,7 @@ def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecC
def Replacement.apply (r : Replacement) (e : Expr) : Expr :=
e.replace fun e => match e with
| Expr.fvar fvarId => match r.find? fvarId with
| .fvar fvarId => match r.find? fvarId with
| some c => some c
| _ => none
| _ => none
@ -574,7 +588,7 @@ def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHea
: TermElabM (Array PreDefinition) :=
mainHeaders.size.foldM (init := preDefs) fun i preDefs => do
let header := mainHeaders[i]!
let val ← mkLambdaFVars sectionVars mainVals[i]!
let value ← mkLambdaFVars sectionVars mainVals[i]!
let type ← mkForallFVars sectionVars header.type
return preDefs.push {
ref := getDeclarationSelectionRef header.ref
@ -582,22 +596,19 @@ def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHea
declName := header.declName
levelParams := [], -- we set it later
modifiers := header.modifiers
type := type
value := val
type, value
}
def pushLetRecs (preDefs : Array PreDefinition) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDefinition :=
letRecClosures.foldl (init := preDefs) fun preDefs c =>
let type := Closure.mkForall c.localDecls c.toLift.type
let val := Closure.mkLambda c.localDecls c.toLift.val
let type := Closure.mkForall c.localDecls c.toLift.type
let value := Closure.mkLambda c.localDecls c.toLift.val
preDefs.push {
ref := c.ref
kind := kind
declName := c.toLift.declName
levelParams := [] -- we set it later
modifiers := { modifiers with attrs := c.toLift.attrs }
type := type
value := val
kind, type, value
}
def getKindForLetRecs (mainHeaders : Array DefViewElabHeader) : DefKind :=