From 8753a45452d0982b1dcabdbc3b5eff9484a58ace Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 15 Oct 2020 14:19:06 -0700 Subject: [PATCH] chore: move to new frontend --- src/Lean/Elab/MutualDef.lean | 366 +++++++++++++++++------------------ 1 file changed, 179 insertions(+), 187 deletions(-) diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index a7cd859e5f..cadb1ce9ae 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -1,3 +1,4 @@ +#lang lean4 /- Copyright (c) 2020 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. @@ -9,9 +10,7 @@ import Lean.Elab.Command import Lean.Elab.DefView import Lean.Elab.PreDefinition -namespace Lean -namespace Elab -open MonadResolveName (getCurrNamespace getOpenDecls) -- HACK for old frontend +namespace Lean.Elab /- DefView after elaborating the header. -/ structure DefViewElabHeader := @@ -33,45 +32,42 @@ namespace Term open Meta private def checkModifiers (m₁ m₂ : Modifiers) : TermElabM Unit := do -unless (m₁.isUnsafe == m₂.isUnsafe) $ - throwError "cannot mix unsafe and safe definitions"; -unless (m₁.isNoncomputable == m₂.isNoncomputable) $ - throwError "cannot mix computable and non-computable definitions"; -unless (m₁.isPartial == m₂.isPartial) $ - throwError "cannot mix partial and non-partial definitions"; -pure () +unless m₁.isUnsafe == m₂.isUnsafe do + throwError "cannot mix unsafe and safe definitions" +unless m₁.isNoncomputable == m₂.isNoncomputable do + throwError "cannot mix computable and non-computable definitions" +unless m₁.isPartial == m₂.isPartial do + throwError "cannot mix partial and non-partial definitions" private def checkKinds (k₁ k₂ : DefKind) : TermElabM Unit := do -unless (k₁.isExample == k₂.isExample) $ - throwError "cannot mix examples and definitions"; -- Reason: we should discard examples -unless (k₁.isTheorem == k₂.isTheorem) $ - throwError "cannot mix theorems and definitions"; -- Reason: we will eventually elaborate theorems in `Task`s. -pure () +unless k₁.isExample == k₂.isExample do + throwError "cannot mix examples and definitions" -- Reason: we should discard examples +unless k₁.isTheorem == k₂.isTheorem do + throwError "cannot mix theorems and definitions" -- Reason: we will eventually elaborate theorems in `Task`s. private def check (prevHeaders : Array DefViewElabHeader) (newHeader : DefViewElabHeader) : TermElabM Unit := do -when (newHeader.kind.isTheorem && newHeader.modifiers.isUnsafe) $ - throwError "'unsafe' theorems are not allowed"; -when (newHeader.kind.isTheorem && newHeader.modifiers.isPartial) $ - throwError "'partial' theorems are not allowed, 'partial' is a code generation directive"; -when (newHeader.kind.isTheorem && newHeader.modifiers.isNoncomputable) $ +if newHeader.kind.isTheorem && newHeader.modifiers.isUnsafe then + throwError "'unsafe' theorems are not allowed" +if newHeader.kind.isTheorem && newHeader.modifiers.isPartial then + throwError "'partial' theorems are not allowed, 'partial' is a code generation directive" +if newHeader.kind.isTheorem && newHeader.modifiers.isNoncomputable then throwError "'theorem' subsumes 'noncomputable', code is not generated for theorems"; -when (newHeader.modifiers.isNoncomputable && newHeader.modifiers.isUnsafe) $ - throwError "'noncomputable unsafe' is not allowed"; -when (newHeader.modifiers.isNoncomputable && newHeader.modifiers.isPartial) $ - throwError "'noncomputable partial' is not allowed"; -when (newHeader.modifiers.isPartial && newHeader.modifiers.isUnsafe) $ - throwError "'unsafe' subsumes 'partial'"; +if newHeader.modifiers.isNoncomputable && newHeader.modifiers.isUnsafe then + throwError "'noncomputable unsafe' is not allowed" +if newHeader.modifiers.isNoncomputable && newHeader.modifiers.isPartial then + throwError "'noncomputable partial' is not allowed" +if newHeader.modifiers.isPartial && newHeader.modifiers.isUnsafe then + throwError "'unsafe' subsumes 'partial'" if h : 0 < prevHeaders.size then let firstHeader := prevHeaders.get ⟨0, h⟩; + try + unless newHeader.levelNames == firstHeader.levelNames do + throwError "universe parameters mismatch" + checkModifiers newHeader.modifiers firstHeader.modifiers + checkKinds newHeader.kind firstHeader.kind catch - (do - unless (newHeader.levelNames == firstHeader.levelNames) $ - throwError "universe parameters mismatch"; - checkModifiers newHeader.modifiers firstHeader.modifiers; - checkKinds newHeader.kind firstHeader.kind) - (fun ex => match ex with - | Exception.error ref msg => throw (Exception.error ref ("invalid mutually recursive definitions, " ++ msg)) - | _ => throw ex) + | Exception.error ref msg => throw (Exception.error ref msg!"invalid mutually recursive definitions, {msg}") + | ex => throw ex else pure () @@ -80,29 +76,28 @@ registerCustomErrorIfMVar type ref "failed to infer definition type" private def elabFunType (ref : Syntax) (xs : Array Expr) (view : DefView) : TermElabM Expr := do match view.type? with -| some typeStx => do - type ← elabType typeStx; - synthesizeSyntheticMVarsNoPostponing; - type ← instantiateMVars type; - registerFailedToInferDefTypeInfo type typeStx; +| some typeStx => + let type ← elabType typeStx + synthesizeSyntheticMVarsNoPostponing + let type ← instantiateMVars type + registerFailedToInferDefTypeInfo type typeStx mkForallFVars xs type -| none => do - let hole := mkHole ref; - type ← elabType hole; - registerFailedToInferDefTypeInfo type ref; +| none => + let hole := mkHole ref + let type ← elabType hole + registerFailedToInferDefTypeInfo type ref mkForallFVars xs type -private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) := -views.foldlM - (fun (headers : Array DefViewElabHeader) (view : DefView) => withRef view.ref do - currNamespace ← getCurrNamespace; - currLevelNames ← getLevelNames; - ⟨shortDeclName, declName, levelNames⟩ ← expandDeclId currNamespace currLevelNames view.declId view.modifiers; - applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration; +private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) := do +let headers := #[] +for view in views do + let newHeader ← withRef view.ref do + let ⟨shortDeclName, declName, levelNames⟩ ← expandDeclId (← getCurrNamespace) (← getLevelNames) view.declId view.modifiers + applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration withLevelNames levelNames $ elabBinders view.binders.getArgs fun xs => do - let refForElabFunType := view.value; - type ← elabFunType refForElabFunType xs view; - let newHeader : DefViewElabHeader := { + let refForElabFunType := view.value + let type ← elabFunType refForElabFunType xs view + let newHeader := { ref := view.ref, modifiers := view.modifiers, kind := view.kind, @@ -111,22 +106,20 @@ views.foldlM levelNames := levelNames, numParams := xs.size, type := type, - valueStx := view.value - }; - check headers newHeader; - pure $ headers.push newHeader) - #[] + valueStx := view.value : DefViewElabHeader } + check headers newHeader + pure newHeader + headers := headers.push newHeader +pure headers -private partial def withFunLocalDeclsAux {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : Nat → Array Expr → TermElabM α -| i, fvars => - if h : i < headers.size then do - let header := headers.get ⟨i, h⟩; - withLocalDecl header.shortDeclName BinderInfo.auxDecl header.type fun fvar => withFunLocalDeclsAux (i+1) (fvars.push fvar) +private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : TermElabM α := +let rec loop (i : Nat) (fvars : Array Expr) := do + if h : i < headers.size then + let header := headers.get ⟨i, h⟩ + withLocalDecl header.shortDeclName BinderInfo.auxDecl header.type fun fvar => loop (i+1) (fvars.push fvar) else k fvars - -private def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (k : Array Expr → TermElabM α) : TermElabM α := -withFunLocalDeclsAux headers k 0 #[] +loop 0 #[] /- Recall that @@ -142,54 +135,53 @@ def declVal := declValSimple <|> declValEqns -/ private def declValToTerm (declVal : Syntax) : MacroM Syntax := if declVal.isOfKind `Lean.Parser.Command.declValSimple then - pure $ declVal.getArg 1 + pure declVal[1] else if declVal.isOfKind `Lean.Parser.Command.declValEqns then - expandMatchAltsIntoMatch declVal (declVal.getArg 0) + expandMatchAltsIntoMatch declVal declVal[0] else Macro.throwError declVal "unexpected definition value" private def elabFunValues (headers : Array DefViewElabHeader) : TermElabM (Array Expr) := headers.mapM fun header => withDeclName header.declName $ withLevelNames header.levelNames do - valStx ← liftMacroM $ declValToTerm header.valueStx; + let valStx ← liftMacroM $ declValToTerm header.valueStx forallBoundedTelescope header.type header.numParams fun xs type => do - val ← elabTermEnsuringType valStx type; + let val ← elabTermEnsuringType valStx type mkLambdaFVars xs val private def collectUsed (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift) : StateRefT CollectFVars.State TermElabM Unit := do -headers.forM fun header => collectUsedFVars header.type; -values.forM collectUsedFVars; -toLift.forM fun letRecToLift => do { - collectUsedFVars letRecToLift.type; +headers.forM fun header => collectUsedFVars header.type +values.forM collectUsedFVars +toLift.forM fun letRecToLift => do + collectUsedFVars letRecToLift.type collectUsedFVars letRecToLift.val -} private def removeUnusedVars (vars : Array Expr) (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift) : TermElabM (LocalContext × LocalInstances × Array Expr) := do -(_, used) ← (collectUsed headers values toLift).run {}; +let (_, used) ← (collectUsed headers values toLift).run {} removeUnused vars used private def withUsedWhen {α} (vars : Array Expr) (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift) - (cond : Bool) (k : Array Expr → TermElabM α) : TermElabM α := -if cond then do - (lctx, localInsts, vars) ← removeUnusedVars vars headers values toLift; + (cond : Bool) (k : Array Expr → TermElabM α) : TermElabM α := do +if cond then + let (lctx, localInsts, vars) ← removeUnusedVars vars headers values toLift withLCtx lctx localInsts $ k vars else k vars private def isExample (views : Array DefView) : Bool := -views.any fun view => view.kind.isExample +views.any (·.kind.isExample) private def isTheorem (views : Array DefView) : Bool := -views.any fun view => view.kind.isTheorem +views.any (·.kind.isTheorem) private def instantiateMVarsAtHeader (header : DefViewElabHeader) : TermElabM DefViewElabHeader := do -type ← instantiateMVars header.type; +let type ← instantiateMVars header.type pure { header with type := type } private def instantiateMVarsAtLetRecToLift (toLift : LetRecToLift) : TermElabM LetRecToLift := do -type ← instantiateMVars toLift.type; -val ← instantiateMVars toLift.val; +let type ← instantiateMVars toLift.type +let val ← instantiateMVars toLift.val pure { toLift with type := type, val := val } private def typeHasRecFun (type : Expr) (funFVars : Array Expr) (letRecsToLift : List LetRecToLift) : Option FVarId := @@ -201,12 +193,11 @@ match occ? with | _ => none private def getFunName (fvarId : FVarId) (letRecsToLift : List LetRecToLift) : TermElabM Name := do -decl? ← findLocalDecl? fvarId; -match decl? with +match (← findLocalDecl? fvarId) with | some decl => pure decl.userName | none => /- Recall that the FVarId of nested let-recs are not in the current local context. -/ - match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.fvarId == fvarId then some toLift.shortDeclName else none with + match letRecsToLift.findSome? fun toLift => if toLift.fvarId == fvarId then some toLift.shortDeclName else none with | none => throwError "unknown function" | some n => pure n @@ -216,12 +207,12 @@ In principle, this test can be improved. We could perform it after we separate t However, this extra complication doesn't seem worth it. -/ private def checkLetRecsToLiftTypes (funVars : Array Expr) (letRecsToLift : List LetRecToLift) : TermElabM Unit := -letRecsToLift.forM fun toLift => do +letRecsToLift.forM fun toLift => match typeHasRecFun toLift.type funVars letRecsToLift with | none => pure () | some fvarId => do - fnName ← getFunName fvarId letRecsToLift; - throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously") + let fnName ← getFunName fvarId letRecsToLift + throwErrorAt! toLift.ref "invalid type in 'let rec', it uses '{fnName}' which is being defined simultaneously" namespace MutualClosure @@ -275,25 +266,25 @@ Note that `g` is not a free variable at `(let g : B := ?m₂; body)`. We recover -/ private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : UsedFVarsMap := -let sectionVarSet := sectionVars.foldl (fun (s : NameSet) (var : Expr) => s.insert var.fvarId!) {}; +let sectionVarSet := sectionVars.foldl (fun (s : NameSet) (var : Expr) => s.insert var.fvarId!) {} let usedFVarMap := mainFVarIds.foldl (fun (usedFVarMap : UsedFVarsMap) mainFVarId => usedFVarMap.insert mainFVarId sectionVarSet) - {}; + {} letRecsToLift.foldl (fun (usedFVarMap : UsedFVarsMap) toLift => - let state := Lean.collectFVars {} toLift.val; - let state := Lean.collectFVars state toLift.type; - let set := state.fvarSet; + let state := Lean.collectFVars {} toLift.val + let state := Lean.collectFVars state toLift.type + let set := state.fvarSet /- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId for the associated let-rec because we need this information to compute the fixpoint later. -/ - let mvarIds := (toLift.val.collectMVars {}).result; + let mvarIds := (toLift.val.collectMVars {}).result let set := mvarIds.foldl (fun (set : NameSet) (mvarId : MVarId) => match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with | some fvarId => set.insert fvarId | none => set) - set; + set usedFVarMap.insert toLift.fvarId set) usedFVarMap @@ -324,10 +315,10 @@ structure State := abbrev M := ReaderT (List FVarId) $ StateM State -private def isModified : M Bool := do s ← get; pure s.modified +private def isModified : M Bool := do pure (← get).modified private def resetModified : M Unit := modify fun s => { s with modified := false } private def markModified : M Unit := modify fun s => { s with modified := true } -private def getUsedFVarsMap : M UsedFVarsMap := do s ← get; pure s.usedFVarsMap +private def getUsedFVarsMap : M UsedFVarsMap := do pure (← get).usedFVarsMap private def modifyUsedFVars (f : UsedFVarsMap → UsedFVarsMap) : M Unit := modify fun s => { s with usedFVarsMap := f s.usedFVarsMap } -- merge s₂ into s₁ @@ -342,11 +333,11 @@ s₂.foldM s₁ private def updateUsedVarsOf (fvarId : FVarId) : M Unit := do -usedFVarsMap ← getUsedFVarsMap; +let usedFVarsMap ← getUsedFVarsMap match usedFVarsMap.find? fvarId with | none => pure () -| some fvarIds => do - fvarIdsNew ← fvarIds.foldM +| some fvarIds => + let fvarIdsNew ← fvarIds.foldM (fun (fvarIdsNew : NameSet) (fvarId' : FVarId) => if fvarId == fvarId' then pure fvarIdsNew @@ -362,13 +353,14 @@ match usedFVarsMap.find? fvarId with private partial def fixpoint : Unit → M Unit | _ => do - resetModified; - letRecFVarIds ← read; - letRecFVarIds.forM updateUsedVarsOf; - whenM isModified $ fixpoint () + resetModified + let letRecFVarIds ← read + letRecFVarIds.forM updateUsedVarsOf + if (← isModified) then + fixpoint () def run (letRecFVarIds : List FVarId) (usedFVarsMap : UsedFVarsMap) : UsedFVarsMap := -let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap }; +let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap } s.usedFVarsMap end FixPoint @@ -378,20 +370,20 @@ abbrev FreeVarMap := NameMap (Array FVarId) private def mkFreeVarMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap := -let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift; -let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId; -let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap; +let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift +let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId +let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap letRecsToLift.foldl (fun (freeVarMap : FreeVarMap) toLift => - let lctx := toLift.lctx; - let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!; + let lctx := toLift.lctx + let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get! let fvarIds := fvarIdsSet.fold (fun (fvarIds : Array FVarId) (fvarId : FVarId) => if lctx.contains fvarId && !recFVarIds.contains fvarId then fvarIds.push fvarId else fvarIds) - #[]; + #[] freeVarMap.insert toLift.fvarId fvarIds) {} @@ -405,9 +397,9 @@ private def pickMaxFVar? (lctx : LocalContext) (fvarIds : Array FVarId) : Option fvarIds.getMax? fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index private def preprocess (e : Expr) : TermElabM Expr := do -e ← instantiateMVars e; +let e ← instantiateMVars e -- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect. -liftM $ check e; +Meta.check e pure e /- Push free variables in `s` to `toProcess` if they are not already there. -/ @@ -418,47 +410,47 @@ s.fvarSet.fold private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName : Name) (type : Expr) (bi := BinderInfo.default) : StateRefT ClosureState TermElabM (Array FVarId) := do -type ← liftM $ preprocess type; +let type ← preprocess type modify fun s => { s with newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) fvarId userName type bi, exprArgs := s.exprArgs.push (mkFVar fvarId) -}; +} pure $ pushNewVars toProcess (collectFVars {} type) private partial def mkClosureForAux : Array FVarId → StateRefT ClosureState TermElabM Unit | toProcess => do - lctx ← getLCtx; + let lctx ← getLCtx match pickMaxFVar? lctx toProcess with | none => pure () - | some fvarId => do - trace `Elab.definition.mkClosure fun _ => "toProcess: " ++ (toProcess.map mkFVar) ++ ", maxVar: " ++ mkFVar fvarId; - let toProcess := toProcess.erase fvarId; - localDecl ← getLocalDecl fvarId; + | some fvarId => + trace[Elab.definition.mkClosure]! "toProcess: {toProcess.map mkFVar}, maxVar: {mkFVar fvarId}" + let toProcess := toProcess.erase fvarId + let localDecl ← getLocalDecl fvarId match localDecl with - | LocalDecl.cdecl _ _ userName type bi => do - toProcess ← pushLocalDecl toProcess fvarId userName type bi; + | LocalDecl.cdecl _ _ userName type bi => + let toProcess ← pushLocalDecl toProcess fvarId userName type bi mkClosureForAux toProcess - | LocalDecl.ldecl _ _ userName type val _ => do - zetaFVarIds ← getZetaFVarIds; - if !zetaFVarIds.contains fvarId then do + | LocalDecl.ldecl _ _ userName type val _ => + let zetaFVarIds ← getZetaFVarIds; + if !zetaFVarIds.contains fvarId then /- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/ - toProcess ← pushLocalDecl toProcess fvarId userName type; + let toProcess ← pushLocalDecl toProcess fvarId userName type mkClosureForAux toProcess - else do + else /- Dependent let-decl. -/ - type ← liftM $ preprocess type; - val ← liftM $ preprocess val; + let type ← preprocess type + let val ← preprocess val modify fun s => { s with newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) fvarId userName type val false, /- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of fvarId at `newLocalDecls` and `localDecls` -/ newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val), localDecls := s.localDecls.map (replaceFVarIdAtLocalDecl fvarId val) - }; + } mkClosureForAux (pushNewVars toProcess (collectFVars (collectFVars {} type) val)) private partial def mkClosureFor (freeVars : Array FVarId) (localDecls : Array LocalDecl) : TermElabM ClosureState := do -(_, s) ← (mkClosureForAux freeVars).run { localDecls := localDecls }; +let (_, s) ← (mkClosureForAux freeVars).run { localDecls := localDecls } pure { s with newLocalDecls := s.newLocalDecls.reverse, newLetDecls := s.newLetDecls.reverse, @@ -470,16 +462,16 @@ structure LetRecClosure := (toLift : LetRecToLift) private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) : TermElabM LetRecClosure := do -let lctx := toLift.lctx; +let lctx := toLift.lctx withLCtx lctx toLift.localInstances do lambdaTelescope toLift.val fun xs val => do - type ← instantiateForall toLift.type xs; - lctx ← getLCtx; - s ← mkClosureFor freeVars $ xs.map fun x => lctx.get! x.fvarId!; - let type := Closure.mkForall s.localDecls $ Closure.mkForall s.newLetDecls type; - let val := Closure.mkLambda s.localDecls $ Closure.mkLambda s.newLetDecls val; - let c := mkAppN (Lean.mkConst toLift.declName) s.exprArgs; - assignExprMVar toLift.mvarId c; + let type ← instantiateForall toLift.type xs + let lctx ← getLCtx + let s ← mkClosureFor freeVars $ xs.map fun x => lctx.get! x.fvarId! + let type := Closure.mkForall s.localDecls $ Closure.mkForall s.newLetDecls type + let val := Closure.mkLambda s.localDecls $ Closure.mkLambda s.newLetDecls val + let c := mkAppN (Lean.mkConst toLift.declName) s.exprArgs + assignExprMVar toLift.mvarId c pure ⟨s.newLocalDecls, c, { toLift with val := val, type := type }⟩ private def mkLetRecClosures (letRecsToLift : List LetRecToLift) (freeVarMap : FreeVarMap) : TermElabM (List LetRecClosure) := @@ -508,9 +500,9 @@ def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHea : TermElabM (Array PreDefinition) := mainHeaders.size.foldM (fun i (preDefs : Array PreDefinition) => do - let header := mainHeaders.get! i; - val ← mkLambdaFVars sectionVars (mainVals.get! i); - type ← mkForallFVars sectionVars header.type; + let header := mainHeaders[i] + let val ← mkLambdaFVars sectionVars mainVals[i] + let type ← mkForallFVars sectionVars header.type pure $ preDefs.push { kind := header.kind, declName := header.declName, @@ -524,8 +516,8 @@ mainHeaders.size.foldM def pushLetRecs (preDefs : Array PreDefinition) (letRecClosures : List LetRecClosure) (kind : DefKind) (modifiers : Modifiers) : Array PreDefinition := letRecClosures.foldl (fun (preDefs : Array PreDefinition) (c : LetRecClosure) => - let type := Closure.mkForall c.localDecls c.toLift.type; - let val := Closure.mkLambda c.localDecls c.toLift.val; + let type := Closure.mkForall c.localDecls c.toLift.type + let val := Closure.mkLambda c.localDecls c.toLift.val preDefs.push { kind := kind, declName := c.toLift.declName, @@ -555,29 +547,29 @@ def getModifiersForLetRecs (mainHeaders : Array DefViewElabHeader) : Modifiers : def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) (mainVals : Array Expr) (letRecsToLift : List LetRecToLift) : TermElabM (Array PreDefinition) := do -- Store in recFVarIds the fvarId of every function being defined by the mutual block. -let mainFVarIds := mainFVars.map Expr.fvarId!; -let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds; +let mainFVarIds := mainFVars.map Expr.fvarId! +let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds -- Compute the set of free variables (excluding `recFVarIds`) for each let-rec. -mctx ← getMCtx; -let freeVarMap := mkFreeVarMap mctx sectionVars mainFVarIds recFVarIds letRecsToLift; -resetZetaFVarIds; +let mctx ← getMCtx +let freeVarMap := mkFreeVarMap mctx sectionVars mainFVarIds recFVarIds letRecsToLift +resetZetaFVarIds withTrackingZeta do --- By checking `toLift.type` and `toLift.val` we populate `zetaFVarIds`. See comments at `src/Lean/Meta/Closure.lean`. -letRecsToLift.forM fun toLift => withLCtx toLift.lctx toLift.localInstances do { liftM $ check toLift.type; liftM $ check toLift.val }; -letRecClosures ← mkLetRecClosures letRecsToLift freeVarMap; --- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations. -mainVals ← mainVals.mapM instantiateMVars; -mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader; -letRecClosures ← letRecClosures.mapM fun closure => do { toLift ← instantiateMVarsAtLetRecToLift closure.toLift; pure { closure with toLift := toLift } }; --- Replace fvarIds for functions being defined with closed terms -let r := insertReplacementForMainFns {} sectionVars mainHeaders mainFVars; -let r := insertReplacementForLetRecs r letRecClosures; -let mainVals := mainVals.map r.apply; -let mainHeaders := mainHeaders.map fun h => { h with type := r.apply h.type }; -let letRecClosures := letRecClosures.map fun c => { c with toLift := { c.toLift with type := r.apply c.toLift.type, val := r.apply c.toLift.val } }; -let letRecKind := getKindForLetRecs mainHeaders; -let letRecMods := getModifiersForLetRecs mainHeaders; -pushMain (pushLetRecs #[] letRecClosures letRecKind letRecMods) sectionVars mainHeaders mainVals + -- By checking `toLift.type` and `toLift.val` we populate `zetaFVarIds`. See comments at `src/Lean/Meta/Closure.lean`. + letRecsToLift.forM fun toLift => withLCtx toLift.lctx toLift.localInstances do Meta.check toLift.type; Meta.check toLift.val + let letRecClosures ← mkLetRecClosures letRecsToLift freeVarMap + -- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations. + let mainVals ← mainVals.mapM instantiateMVars + let mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader + let letRecClosures ← letRecClosures.mapM fun closure => do pure { closure with toLift := (← instantiateMVarsAtLetRecToLift closure.toLift) } + -- Replace fvarIds for functions being defined with closed terms + let r := insertReplacementForMainFns {} sectionVars mainHeaders mainFVars + let r := insertReplacementForLetRecs r letRecClosures + let mainVals := mainVals.map r.apply + let mainHeaders := mainHeaders.map fun h => { h with type := r.apply h.type } + let letRecClosures := letRecClosures.map fun c => { c with toLift := { c.toLift with type := r.apply c.toLift.type, val := r.apply c.toLift.val } } + let letRecKind := getKindForLetRecs mainHeaders + let letRecMods := getModifiersForLetRecs mainHeaders + pushMain (pushLetRecs #[] letRecClosures letRecKind letRecMods) sectionVars mainHeaders mainVals end MutualClosure @@ -589,34 +581,34 @@ else [] def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do -scopeLevelNames ← getLevelNames; -headers ← elabHeaders views; -let allUserLevelNames := getAllUserLevelNames headers; +let scopeLevelNames ← getLevelNames +let headers ← elabHeaders views +let allUserLevelNames := getAllUserLevelNames headers withFunLocalDecls headers fun funFVars => do - values ← elabFunValues headers; - Term.synthesizeSyntheticMVarsNoPostponing; - if isExample views then pure () - else do - values ← values.mapM instantiateMVars; - headers ← headers.mapM instantiateMVarsAtHeader; - letRecsToLift ← getLetRecsToLift; - letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift; - checkLetRecsToLiftTypes funFVars letRecsToLift; + let values ← elabFunValues headers + Term.synthesizeSyntheticMVarsNoPostponing + if isExample views then + pure () + else + let values ← values.mapM instantiateMVars + let headers ← headers.mapM instantiateMVarsAtHeader + let letRecsToLift ← getLetRecsToLift + let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift + checkLetRecsToLiftTypes funFVars letRecsToLift withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do - preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift; - preDefs ← levelMVarToParamPreDecls preDefs; - preDefs ← instantiateMVarsAtPreDecls preDefs; - preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames; + let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift + let preDefs ← levelMVarToParamPreDecls preDefs + let preDefs ← instantiateMVarsAtPreDecls preDefs + let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames addPreDefinitions preDefs end Term namespace Command def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do -views ← ds.mapM fun d => do { - modifiers ← elabModifiers (d.getArg 0); - mkDefView modifiers (d.getArg 1) -}; +let views ← ds.mapM fun d => do + let modifiers ← elabModifiers d[0] + mkDefView modifiers d[1] runTermElabM none fun vars => Term.elabMutualDef vars views end Command