From 387b6c22eed2d2460fb2770a47cdbfd58ebcd505 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 25 Jul 2022 14:25:44 -0700 Subject: [PATCH] chore: document and cleanup --- src/Lean/Elab/MutualDef.lean | 105 +++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 47 deletions(-) diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index e0db7b875c..d304462ffa 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -20,13 +20,23 @@ open Lean.Parser.Term structure DefViewElabHeader where ref : Syntax modifiers : Modifiers + /-- Stores whether this is the header of a definition, theorem, ... -/ kind : DefKind + /-- + Short name. Recall that all declarations in Lean 4 are potentially recursive. We use `shortDeclName` to refer + to them at `valueStx`, and other declarations in the same mutual block. -/ shortDeclName : Name + /-- Full name for this declaration. This is the name that will be added to the `Environment`. -/ declName : Name + /-- Universe level parameter names explicitly provided by the user. -/ levelNames : List Name + /-- Syntax objects for the binders occurring befor `:`, we use them to populate the `InfoTree` when elaborating `valueStx`. -/ binderIds : Array Syntax + /-- Number of parameters before `:`, it also includes auto-implicit parameters automatically added by Lean. -/ numParams : Nat - type : Expr -- including the parameters + /-- Type including parameters. -/ + type : Expr + /-- `Syntax` object the body/value of the definition. -/ valueStx : Syntax deriving Inhabited @@ -68,7 +78,7 @@ private def check (prevHeaders : Array DefViewElabHeader) (newHeader : DefViewEl checkModifiers newHeader.modifiers firstHeader.modifiers checkKinds newHeader.kind firstHeader.kind catch - | Exception.error ref msg => throw (Exception.error ref m!"invalid mutually recursive definitions, {msg}") + | .error ref msg => throw (.error ref m!"invalid mutually recursive definitions, {msg}") | ex => throw ex else pure () @@ -83,7 +93,7 @@ private def registerFailedToInferDefTypeInfo (type : Expr) (ref : Syntax) : Term ``` -/ private def isMultiConstant? (views : Array DefView) : Option (List Name) := if views.size == 1 && - views[0]!.kind == DefKind.opaque && + views[0]!.kind == .opaque && views[0]!.binders.getArgs.size > 0 && views[0]!.binders.getArgs.all (·.isIdent) then some (views[0]!.binders.getArgs.toList.map (·.getId)) @@ -99,6 +109,7 @@ private def getPendindMVarErrorMessage (views : Array DefView) : String := | none => "\nwhen the resulting type of a declaration is explicitly provided, all holes (e.g., `_`) in the header are resolved before the declaration body is processed" +/-- Elaborate only the declaration headers. We have to elaborate the headers first because we support mutually recursive declarations in Lean 4. -/ private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) := do let expandedDeclIds ← views.mapM fun view => withRef view.ref do Term.expandDeclId (← getCurrNamespace) (← getLevelNames) view.declId view.modifiers @@ -107,7 +118,7 @@ private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHe for view in views, ⟨shortDeclName, declName, levelNames⟩ in expandedDeclIds do let newHeader ← withRef view.ref do addDeclarationRanges declName view.ref - applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration + applyAttributesAt declName view.modifiers.attrs .beforeElaboration withDeclName declName <| withAutoBoundImplicit <| withLevelNames levelNames <| elabBindersEx view.binders.getArgs fun xs => do let refForElabFunType := view.value @@ -134,21 +145,23 @@ private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHe discard <| logUnassignedUsingErrorInfos pendingMVarIds <| getPendindMVarErrorMessage views let newHeader := { - ref := view.ref, - modifiers := view.modifiers, - kind := view.kind, - shortDeclName := shortDeclName, - declName := declName, - levelNames := levelNames, - binderIds := binderIds, - numParams := xs.size, - type := type, + ref := view.ref + modifiers := view.modifiers + kind := view.kind + shortDeclName := shortDeclName + declName, type, levelNames, binderIds + numParams := xs.size valueStx := view.value : DefViewElabHeader } check headers newHeader return newHeader headers := headers.push newHeader return headers +/-- + Create auxiliary local declarations `fs` for the given hearders using their `shortDeclName` and `type`, given hearders, and execute `k fs`. + The new free variables are tagged as `auxDecl`. + Remark: `fs.size = headers.size`. +-/ private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : TermElabM α := let rec loop (i : Nat) (fvars : Array Expr) := do if h : i < headers.size then @@ -164,10 +177,10 @@ private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) ( private def expandWhereStructInst : Macro | `(Parser.Command.whereStructInst|where $[$decls:letDecl];* $[$whereDecls?:whereDecls]?) => do let letIdDecls ← decls.mapM fun stx => match stx with - | `(letDecl|$_decl:letPatDecl) => Macro.throwErrorAt stx "patterns are not allowed here" + | `(letDecl|$_decl:letPatDecl) => Macro.throwErrorAt stx "patterns are not allowed here" | `(letDecl|$decl:letEqnsDecl) => expandLetEqnsDecl decl | `(letDecl|$decl:letIdDecl) => pure decl - | _ => Macro.throwUnsupported + | _ => Macro.throwUnsupported let structInstFields ← letIdDecls.mapM fun | stx@`(letIdDecl|$id:ident $binders* $[: $ty?]? := $val) => withRef stx do let mut val := val @@ -191,11 +204,11 @@ def declVal := declValSimple <|> declValEqns <|> Term.whereDecls ``` -/ private def declValToTerm (declVal : Syntax) : MacroM Syntax := withRef declVal do - if declVal.isOfKind ``Lean.Parser.Command.declValSimple then + if declVal.isOfKind ``Parser.Command.declValSimple then expandWhereDeclsOpt declVal[2] declVal[1] - else if declVal.isOfKind ``Lean.Parser.Command.declValEqns then + else if declVal.isOfKind ``Parser.Command.declValEqns then expandMatchAltsWhereDecls declVal[0] - else if declVal.isOfKind ``Lean.Parser.Command.whereStructInst then + else if declVal.isOfKind ``Parser.Command.whereStructInst then expandWhereStructInst declVal else if declVal.isMissing then Macro.throwErrorAt declVal "declaration body is missing" @@ -244,7 +257,7 @@ private def instantiateMVarsAtHeader (header : DefViewElabHeader) : TermElabM De private def instantiateMVarsAtLetRecToLift (toLift : LetRecToLift) : TermElabM LetRecToLift := do let type ← instantiateMVars toLift.type let val ← instantiateMVars toLift.val - pure { toLift with type := type, val := val } + pure { toLift with type, val } private def typeHasRecFun (type : Expr) (funFVars : Array Expr) (letRecsToLift : List LetRecToLift) : Option FVarId := let occ? := type.find? fun e => match e with @@ -393,14 +406,14 @@ private def merge (s₁ s₂ : FVarIdSet) : M FVarIdSet := private def updateUsedVarsOf (fvarId : FVarId) : M Unit := do let usedFVarsMap ← getUsedFVarsMap match usedFVarsMap.find? fvarId with - | none => pure () + | none => return () | some fvarIds => - let fvarIdsNew ← fvarIds.foldM (init := fvarIds) fun fvarIdsNew fvarId' => + let fvarIdsNew ← fvarIds.foldM (init := fvarIds) fun fvarIdsNew fvarId' => do if fvarId == fvarId' then - pure fvarIdsNew + return fvarIdsNew else match usedFVarsMap.find? fvarId' with - | none => pure fvarIdsNew + | none => return fvarIdsNew /- We are being sloppy here `otherFVarIds` may contain free variables that are not in the context of the let-rec associated with fvarId. We filter these out-of-context free variables later. -/ @@ -416,7 +429,7 @@ private partial def fixpoint : Unit → M Unit fixpoint () def run (letRecFVarIds : Array FVarId) (usedFVarsMap : UsedFVarsMap) : UsedFVarsMap := - let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap } + let (_, s) := fixpoint () |>.run letRecFVarIds |>.run { usedFVarsMap := usedFVarsMap } s.usedFVarsMap end FixPoint @@ -426,13 +439,13 @@ abbrev FreeVarMap := FVarIdMap (Array FVarId) private def mkFreeVarMap [Monad m] [MonadMCtx m] (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : m FreeVarMap := do - let usedFVarsMap ← mkInitialUsedFVarsMap sectionVars mainFVarIds letRecsToLift - let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId - let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap + let usedFVarsMap ← mkInitialUsedFVarsMap sectionVars mainFVarIds letRecsToLift + let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId + let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap let mut freeVarMap := {} for toLift in letRecsToLift do let lctx := toLift.lctx - let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get! + let fvarIdsSet := usedFVarsMap.find? toLift.fvarId |>.get! let fvarIds := fvarIdsSet.fold (init := #[]) fun fvarIds fvarId => if lctx.contains fvarId && !recFVarIds.contains fvarId then fvarIds.push fvarId @@ -465,7 +478,7 @@ private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName : StateRefT ClosureState TermElabM (Array FVarId) := do let type ← preprocess type modify fun s => { s with - newLocalDecls := s.newLocalDecls.push <| LocalDecl.cdecl default fvarId userName type bi, + newLocalDecls := s.newLocalDecls.push <| LocalDecl.cdecl default fvarId userName type bi exprArgs := s.exprArgs.push (mkFVar fvarId) } return pushNewVars toProcess (collectFVars {} type) @@ -473,16 +486,16 @@ private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName private partial def mkClosureForAux (toProcess : Array FVarId) : StateRefT ClosureState TermElabM Unit := do let lctx ← getLCtx match pickMaxFVar? lctx toProcess with - | none => pure () + | none => return () | some fvarId => trace[Elab.definition.mkClosure] "toProcess: {toProcess.map mkFVar}, maxVar: {mkFVar fvarId}" let toProcess := toProcess.erase fvarId let localDecl ← getLocalDecl fvarId match localDecl with - | LocalDecl.cdecl _ _ userName type bi => + | .cdecl _ _ userName type bi => let toProcess ← pushLocalDecl toProcess fvarId userName type bi mkClosureForAux toProcess - | LocalDecl.ldecl _ _ userName type val _ => + | .ldecl _ _ userName type val _ => let zetaFVarIds ← getZetaFVarIds if !zetaFVarIds.contains fvarId then /- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/ @@ -496,23 +509,24 @@ private partial def mkClosureForAux (toProcess : Array FVarId) : StateRefT Closu newLetDecls := s.newLetDecls.push <| LocalDecl.ldecl default fvarId userName type val false, /- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of fvarId at `newLocalDecls` and `localDecls` -/ - newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val), + newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val) localDecls := s.localDecls.map (replaceFVarIdAtLocalDecl fvarId val) } mkClosureForAux (pushNewVars toProcess (collectFVars (collectFVars {} type) val)) private partial def mkClosureFor (freeVars : Array FVarId) (localDecls : Array LocalDecl) : TermElabM ClosureState := do - let (_, s) ← (mkClosureForAux freeVars).run { localDecls := localDecls } + let (_, s) ← mkClosureForAux freeVars |>.run { localDecls := localDecls } pure { s with - newLocalDecls := s.newLocalDecls.reverse, - newLetDecls := s.newLetDecls.reverse, + newLocalDecls := s.newLocalDecls.reverse + newLetDecls := s.newLetDecls.reverse exprArgs := s.exprArgs.reverse } structure LetRecClosure where ref : Syntax localDecls : Array LocalDecl - closed : Expr -- expression used to replace occurrences of the let-rec FVarId + /-- Expression used to replace occurrences of the let-rec `FVarId`. -/ + closed : Expr toLift : LetRecToLift private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) : TermElabM LetRecClosure := do @@ -530,7 +544,7 @@ private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) ref := toLift.ref localDecls := s.newLocalDecls closed := c - toLift := { toLift with val := val, type := type } + toLift := { toLift with val, type } } private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : TermElabM (List LetRecClosure) := do @@ -565,7 +579,7 @@ def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecC def Replacement.apply (r : Replacement) (e : Expr) : Expr := e.replace fun e => match e with - | Expr.fvar fvarId => match r.find? fvarId with + | .fvar fvarId => match r.find? fvarId with | some c => some c | _ => none | _ => none @@ -574,7 +588,7 @@ def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHea : TermElabM (Array PreDefinition) := mainHeaders.size.foldM (init := preDefs) fun i preDefs => do let header := mainHeaders[i]! - let val ← mkLambdaFVars sectionVars mainVals[i]! + let value ← mkLambdaFVars sectionVars mainVals[i]! let type ← mkForallFVars sectionVars header.type return preDefs.push { ref := getDeclarationSelectionRef header.ref @@ -582,22 +596,19 @@ def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHea declName := header.declName levelParams := [], -- we set it later modifiers := header.modifiers - type := type - value := val + type, value } def pushLetRecs (preDefs : Array PreDefinition) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDefinition := letRecClosures.foldl (init := preDefs) fun preDefs c => - let type := Closure.mkForall c.localDecls c.toLift.type - let val := Closure.mkLambda c.localDecls c.toLift.val + let type := Closure.mkForall c.localDecls c.toLift.type + let value := Closure.mkLambda c.localDecls c.toLift.val preDefs.push { ref := c.ref - kind := kind declName := c.toLift.declName levelParams := [] -- we set it later modifiers := { modifiers with attrs := c.toLift.attrs } - type := type - value := val + kind, type, value } def getKindForLetRecs (mainHeaders : Array DefViewElabHeader) : DefKind :=