From a1413b8fa130bb06e75e3ff346880ce6a1b7f268 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 3 Jul 2022 21:49:45 -0700 Subject: [PATCH] feat: cache failures at `isDefEq` We can compile Lean with these changes, but 3 tests are still broken. This cache is used to address a performance issue reported at https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/deterministic.20timeout.20with.20structures/near/288180087 --- src/Lean/Elab/App.lean | 2 +- src/Lean/Elab/BuiltinTerm.lean | 9 +- src/Lean/Elab/Match.lean | 3 +- src/Lean/Elab/MutualDef.lean | 20 +- src/Lean/Elab/Term.lean | 6 +- src/Lean/Meta/Basic.lean | 56 +++--- src/Lean/Meta/CollectMVars.lean | 2 +- src/Lean/Meta/DecLevel.lean | 3 +- src/Lean/Meta/ExprDefEq.lean | 79 ++++---- src/Lean/Meta/GeneralizeVars.lean | 2 +- src/Lean/Meta/PPGoal.lean | 4 +- src/Lean/Meta/Tactic/Cases.lean | 16 +- src/Lean/Meta/Tactic/Clear.lean | 4 +- src/Lean/Meta/Tactic/Induction.lean | 9 +- src/Lean/Meta/Tactic/Revert.lean | 46 +++-- src/Lean/Meta/Tactic/Subst.lean | 24 +-- src/Lean/Meta/WHNF.lean | 2 +- src/Lean/MetavarContext.lean | 283 ++++++++++++++++------------ src/Lean/Server/Rpc/Deriving.lean | 2 +- src/Lean/Util/OccursCheck.lean | 18 +- 20 files changed, 320 insertions(+), 270 deletions(-) diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index ae222e6350..65535684b3 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -343,7 +343,7 @@ private def anyNamedArgDependsOnCurrent : M Bool := do for i in [1:xs.size] do let xDecl ← getLocalDecl xs[i]!.fvarId! if s.namedArgs.any fun arg => arg.name == xDecl.userName then - if (← getMCtx).localDeclDependsOn xDecl curr.fvarId! then + if (← MetavarContext.localDeclDependsOn xDecl curr.fvarId!) then return true return false diff --git a/src/Lean/Elab/BuiltinTerm.lean b/src/Lean/Elab/BuiltinTerm.lean index 70afe85867..77901fa469 100644 --- a/src/Lean/Elab/BuiltinTerm.lean +++ b/src/Lean/Elab/BuiltinTerm.lean @@ -72,8 +72,7 @@ private def elabOptLevel (stx : Syntax) : TermElabM Level := if userName.isAnonymous || (← read).inPattern then mkNewHole () else - let mctx ← getMCtx - match mctx.findUserName? userName with + match (← getMCtx).findUserName? userName with | none => mkNewHole () | some mvarId => let mvar := mkMVar mvarId @@ -81,16 +80,16 @@ private def elabOptLevel (stx : Syntax) : TermElabM Level := let lctx ← getLCtx if mvarDecl.lctx.isSubPrefixOf lctx then return mvar - else match mctx.getExprAssignment? mvarId with + else match (← getExprMVarAssignment? mvarId) with | some val => let val ← instantiateMVars val - if mctx.isWellFormed lctx val then + if (← MetavarContext.isWellFormed lctx val) then return val else withLCtx mvarDecl.lctx mvarDecl.localInstances do throwError "synthetic hole has already been defined and assigned to value incompatible with the current context{indentExpr val}" | none => - if mctx.isDelayedAssigned mvarId then + if (← getMCtx).isDelayedAssigned mvarId then -- We can try to improve this case if needed. throwError "synthetic hole has already beend defined and delayed assigned with an incompatible local context" else if lctx.isSubPrefixOf mvarDecl.lctx then diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 06d71e02f5..549906a159 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -982,9 +982,8 @@ where for fvarId in s.fvarSet.toList do unless containsFVar discrs fvarId || containsFVar indices fvarId do let localDecl ← getLocalDecl fvarId - let mctx ← getMCtx for indexFVarId in indicesFVar do - if mctx.localDeclDependsOn localDecl indexFVarId then + if (← MetavarContext.localDeclDependsOn localDecl indexFVarId) then toAdd := toAdd.push fvarId let indicesFVar ← sortFVarIds (indicesFVar ++ toAdd) return indicesFVar.map mkFVar ++ indicesNonFVar diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index eee16a5841..7f07b62197 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -326,8 +326,8 @@ we would have a `LetRecToLift` containing: Note that `g` is not a free variable at `(let g : B := ?m₂; body)`. We recover the fact that `f` depends on `g` because it contains `m₂` -/ -private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) - : UsedFVarsMap := Id.run do +private def mkInitialUsedFVarsMap [Monad m] [MonadMCtx m] (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) + : m UsedFVarsMap := do let mut sectionVarSet := {} for var in sectionVars do sectionVarSet := sectionVarSet.insert var.fvarId! @@ -342,11 +342,11 @@ private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array E for the associated let-rec because we need this information to compute the fixpoint later. -/ let mvarIds := (toLift.val.collectMVars {}).result for mvarId in mvarIds do - match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with + match (← letRecsToLift.findSomeM? fun (toLift : LetRecToLift) => return if toLift.mvarId == (← getDelayedMVarRoot mvarId) then some toLift.fvarId else none) with | some fvarId => set := set.insert fvarId | none => pure () usedFVarMap := usedFVarMap.insert toLift.fvarId set - pure usedFVarMap + return usedFVarMap /- The let-recs may invoke each other. Example: @@ -423,10 +423,10 @@ end FixPoint abbrev FreeVarMap := FVarIdMap (Array FVarId) -private def mkFreeVarMap - (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) - (recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : FreeVarMap := Id.run do - let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift +private def mkFreeVarMap [Monad m] [MonadMCtx m] + (sectionVars : Array Expr) (mainFVarIds : Array FVarId) + (recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : m FreeVarMap := do + let usedFVarsMap ← mkInitialUsedFVarsMap sectionVars mainFVarIds letRecsToLift let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap let mut freeVarMap := {} @@ -536,7 +536,7 @@ private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : TermElabM (List LetRecClosure) := do -- Compute the set of free variables (excluding `recFVarIds`) for each let-rec. let mut letRecsToLift := letRecsToLift - let mut freeVarMap := mkFreeVarMap (← getMCtx) sectionVars mainFVarIds recFVarIds letRecsToLift + let mut freeVarMap ← mkFreeVarMap sectionVars mainFVarIds recFVarIds letRecsToLift let mut result := #[] for i in [:letRecsToLift.size] do if letRecsToLift[i]!.val.hasExprMVar then @@ -546,7 +546,7 @@ private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVa let valNew ← instantiateMVars letRecsToLift[i]!.val letRecsToLift := letRecsToLift.modify i fun t => { t with val := valNew } -- We have to recompute the `freeVarMap` in this case. This overhead should not be an issue in practice. - freeVarMap := mkFreeVarMap (← getMCtx) sectionVars mainFVarIds recFVarIds letRecsToLift + freeVarMap ← mkFreeVarMap sectionVars mainFVarIds recFVarIds letRecsToLift let toLift := letRecsToLift[i]! result := result.push (← mkLetRecClosureFor toLift (freeVarMap.find? toLift.fvarId).get!) return result.toList diff --git a/src/Lean/Elab/Term.lean b/src/Lean/Elab/Term.lean index 4e0855fda8..47d83db169 100644 --- a/src/Lean/Elab/Term.lean +++ b/src/Lean/Elab/Term.lean @@ -694,7 +694,7 @@ partial def visit (e : Expr) : M Unit := do if e' != e then visit e' else - match (← getDelayedAssignment? mvarId) with + match (← getDelayedMVarAssignment? mvarId) with | some d => visit d.val | none => failure | _ => return () @@ -1531,7 +1531,7 @@ where if auto.isFVar then let localDecl ← getLocalDecl auto.fvarId! for x in xs do - if (← getMCtx).localDeclDependsOn localDecl x.fvarId! then + if (← MetavarContext.localDeclDependsOn localDecl x.fvarId!) then throwError "invalid auto implicit argument '{auto}', it depends on explicitly provided argument '{x}'" return autos ++ xs | auto :: todo => @@ -1562,7 +1562,7 @@ builtin_initialize registerTraceClass `Elab.letrec is delayed assigned to one. -/ def isLetRecAuxMVar (mvarId : MVarId) : TermElabM Bool := do trace[Elab.letrec] "mvarId: {mkMVar mvarId} letrecMVars: {(← get).letRecsToLift.map (mkMVar $ ·.mvarId)}" - let mvarId := (← getMCtx).getDelayedRoot mvarId + let mvarId ← getDelayedMVarRoot mvarId trace[Elab.letrec] "mvarId root: {mkMVar mvarId}" return (← get).letRecsToLift.any (·.mvarId == mvarId) diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index 24cb189be1..2dd67ccc4c 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -192,9 +192,11 @@ abbrev InferTypeCache := PersistentExprStructMap Expr abbrev FunInfoCache := PersistentHashMap InfoCacheKey FunInfo abbrev WhnfCache := PersistentExprStructMap Expr -/- A set of pairs. TODO: consider more efficient representations (e.g., a proper set) and caching policies (e.g., imperfect cache). - We should also investigate the impact on memory consumption. -/ -abbrev DefEqCache := PersistentHashMap (Expr × Expr) Unit +/-- + A mapping `(s, t) ↦ isDefEq s t`. + TODO: consider more efficient representations (e.g., a proper set) and caching policies (e.g., imperfect cache). + We should also investigate the impact on memory consumption. -/ +abbrev DefEqCache := PersistentHashMap (Expr × Expr) Bool /-- Cache datastructures for type inference, type class resolution, whnf, and definitional equality. @@ -237,12 +239,12 @@ structure PostponedEntry where `MetaM` monad state. -/ structure State where - mctx : MetavarContext := {} - cache : Cache := {} + mctx : MetavarContext := {} + cache : Cache := {} /-- When `trackZeta == true`, then any let-decl free variable that is zeta expansion performed by `MetaM` is stored in `zetaFVarIds`. -/ - zetaFVarIds : FVarIdSet := {} + zetaFVarIds : FVarIdSet := {} /-- Array of postponed universe level constraints -/ - postponed : PersistentArray PostponedEntry := {} + postponed : PersistentArray PostponedEntry := {} deriving Inhabited /-- @@ -305,7 +307,8 @@ protected def saveState : MetaM SavedState := /-- Restore backtrackable parts of the state. -/ def SavedState.restore (b : SavedState) : MetaM Unit := do Core.restore b.core - modify fun s => { s with mctx := b.meta.mctx, zetaFVarIds := b.meta.zetaFVarIds, postponed := b.meta.postponed } + let usedAssignment := (← getMCtx).usedAssignment + modify fun s => { s with mctx := { b.meta.mctx with usedAssignment }, zetaFVarIds := b.meta.zetaFVarIds, postponed := b.meta.postponed } instance : MonadBacktrack SavedState MetaM where saveState := Meta.saveState @@ -488,8 +491,11 @@ def shouldReduceAll : MetaM Bool := def shouldReduceReducibleOnly : MetaM Bool := return (← getTransparency) == TransparencyMode.reducible +def findMVarDecl? (mvarId : MVarId) : MetaM (Option MetavarDecl) := + return (← getMCtx).findDecl? mvarId + def getMVarDecl (mvarId : MVarId) : MetaM MetavarDecl := do - match (← getMCtx).findDecl? mvarId with + match (← findMVarDecl? mvarId) with | some d => pure d | none => throwError "unknown metavariable '?{mvarId.name}'" @@ -527,22 +533,12 @@ def setMVarUserName (mvarId : MVarId) (newUserName : Name) : MetaM Unit := def isExprMVarAssigned (mvarId : MVarId) : MetaM Bool := return (← getMCtx).isExprAssigned mvarId -def getExprMVarAssignment? (mvarId : MVarId) : MetaM (Option Expr) := - return (← getMCtx).getExprAssignment? mvarId - -/-- Return true if `e` contains `mvarId` directly or indirectly -/ -def occursCheck (mvarId : MVarId) (e : Expr) : MetaM Bool := - return (← getMCtx).occursCheck mvarId e - def assignExprMVar (mvarId : MVarId) (val : Expr) : MetaM Unit := modifyMCtx fun mctx => mctx.assignExpr mvarId val def isDelayedAssigned (mvarId : MVarId) : MetaM Bool := return (← getMCtx).isDelayedAssigned mvarId -def getDelayedAssignment? (mvarId : MVarId) : MetaM (Option DelayedMetavarAssignment) := - return (← getMCtx).getDelayedAssignment? mvarId - def hasAssignableMVar (e : Expr) : MetaM Bool := return (← getMCtx).hasAssignableMVar e @@ -595,6 +591,10 @@ def abstractRange (e : Expr) (n : Nat) (xs : Array Expr) : MetaM Expr := def abstract (e : Expr) (xs : Array Expr) : MetaM Expr := abstractRange e xs.size xs +def collectForwardDeps (toRevert : Array Expr) (preserveOrder : Bool) : MetaM (Array Expr) := do + liftMkBindingM <| MetavarContext.collectForwardDeps toRevert preserveOrder + + /-- Takes an array `xs` of free variables or metavariables and a term `e` that may contain those variables, and abstracts and binds them as universal quantifiers. - if `usedOnly = true` then only variables that the expression body depends on will appear. @@ -1144,7 +1144,13 @@ private def withNewMCtxDepthImp (x : MetaM α) : MetaM α := do try x finally - modify fun s => { s with mctx := saved.mctx, postponed := saved.postponed } + -- TODO: document why we need to restore defEqCache + modify fun s => { s with + mctx := saved.mctx + cache.defEqDefault := saved.cache.defEqDefault + cache.defEqAll := saved.cache.defEqAll + postponed := saved.postponed + } /-- Save cache and `MetavarContext`, bump the `MetavarContext` depth, execute `x`, @@ -1283,19 +1289,19 @@ def instantiateLambda (e : Expr) (ps : Array Expr) : MetaM Expr := /-- Return true iff `e` depends on the free variable `fvarId` -/ def dependsOn (e : Expr) (fvarId : FVarId) : MetaM Bool := - return (← getMCtx).exprDependsOn e fvarId + MetavarContext.exprDependsOn e fvarId /-- Return true iff `e` depends on the free variable `fvarId` -/ def localDeclDependsOn (localDecl : LocalDecl) (fvarId : FVarId) : MetaM Bool := - return (← getMCtx).localDeclDependsOn localDecl fvarId + MetavarContext.localDeclDependsOn localDecl fvarId /-- Return true iff `e` depends on a free variable `x` s.t. `pf x`, or an unassigned metavariable `?m` s.t. `pm ?m` is true. -/ def dependsOnPred (e : Expr) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : MetaM Bool := - return (← getMCtx).findExprDependsOn e pf pm + MetavarContext.findExprDependsOn e pf pm /-- Return true iff the local declaration `localDecl` depends on a free variable `x` s.t. `pf x`, an unassigned metavariable `?m` s.t. `pm ?m` is true. -/ def localDeclDependsOnPred (localDecl : LocalDecl) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : MetaM Bool := do - return (← getMCtx).findLocalDeclDependsOn localDecl pf pm + MetavarContext.findLocalDeclDependsOn localDecl pf pm /-- Pretty-print the given expression. -/ def ppExpr (e : Expr) : MetaM Format := do @@ -1379,7 +1385,7 @@ def isListLevelDefEqAux : List Level → List Level → MetaM Bool | u::us, v::vs => isLevelDefEqAux u v <&&> isListLevelDefEqAux us vs | _, _ => return false -private def getNumPostponed : MetaM Nat := do +def getNumPostponed : MetaM Nat := do return (← getPostponed).size def getResetPostponed : MetaM (PersistentArray PostponedEntry) := do diff --git a/src/Lean/Meta/CollectMVars.lean b/src/Lean/Meta/CollectMVars.lean index 3bc3373217..8d681dbb68 100644 --- a/src/Lean/Meta/CollectMVars.lean +++ b/src/Lean/Meta/CollectMVars.lean @@ -24,7 +24,7 @@ partial def collectMVars (e : Expr) : StateRefT CollectMVars.State MetaM Unit := let s := e.collectMVars s set s for mvarId in s.result[resultSavedSize:] do - match (← getDelayedAssignment? mvarId) with + match (← getDelayedMVarAssignment? mvarId) with | none => pure () | some d => collectMVars d.val diff --git a/src/Lean/Meta/DecLevel.lean b/src/Lean/Meta/DecLevel.lean index e938946ed2..a3e36db58d 100644 --- a/src/Lean/Meta/DecLevel.lean +++ b/src/Lean/Meta/DecLevel.lean @@ -19,8 +19,7 @@ private partial def decAux? : Level → ReaderT DecLevelContext MetaM (Option Le | Level.zero _ => return none | Level.param _ _ => return none | Level.mvar mvarId _ => do - let mctx ← getMCtx - match mctx.getLevelAssignment? mvarId with + match (← getLevelMVarAssignment? mvarId) with | some u => decAux? u | none => if (← isReadOnlyLevelMVar mvarId) || !(← read).canAssignMVars then diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index 0cdb2d2f8c..24be82cb20 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -16,7 +16,7 @@ import Lean.Meta.UnificationHint namespace Lean.Meta /-- - Return true `b` is of the form `mk a.1 ... a.n`, and `a` is not a constructor application. + Return true if `b` is of the form `mk a.1 ... a.n`, and `a` is not a constructor application. If `a` and `b` are constructor applications, the method returns `false` to force `isDefEq` to use `isDefEqArgs`. For example, suppose we are trying to solve the constraint @@ -644,14 +644,13 @@ mutual partial def checkMVar (mvar : Expr) : CheckAssignmentM Expr := do let mvarId := mvar.mvarId! let ctx ← read - let mctx ← getMCtx if mvarId == ctx.mvarId then traceM `Meta.isDefEq.assign.occursCheck <| addAssignmentInfo "occurs check failed" throwCheckAssignmentFailure - else match mctx.getExprAssignment? mvarId with + else match (← getExprMVarAssignment? mvarId) with | some v => check v | none => - match mctx.findDecl? mvarId with + match (← findMVarDecl? mvarId) with | none => throwUnknownMVar mvarId | some mvarDecl => if ctx.hasCtxLocals then @@ -660,7 +659,7 @@ mutual /- The local context of `mvar` - free variables being abstracted is a subprefix of the metavariable being assigned. We "substract" variables being abstracted because we use `elimMVarDeps` -/ pure mvar - else if mvarDecl.depth != mctx.depth || mvarDecl.kind.isSyntheticOpaque then + else if mvarDecl.depth != (← getMCtx).depth || mvarDecl.kind.isSyntheticOpaque then traceM `Meta.isDefEq.assign.readOnlyMVarWithBiggerLCtx <| addAssignmentInfo (mkMVar mvarId) throwCheckAssignmentFailure else @@ -678,17 +677,17 @@ mutual Notat that if a variable is `ctx.fvars`, but it depends on variable at `toErase`, we must also erase it. -/ - let toErase := mvarDecl.lctx.foldl (init := #[]) fun toErase localDecl => + let toErase ← mvarDecl.lctx.foldlM (init := #[]) fun toErase localDecl => do if ctx.mvarDecl.lctx.contains localDecl.fvarId then - toErase + return toErase else if ctx.fvars.any fun fvar => fvar.fvarId! == localDecl.fvarId then - if mctx.findLocalDeclDependsOn localDecl fun fvarId => toErase.contains fvarId then + if (← MetavarContext.findLocalDeclDependsOn localDecl fun fvarId => toErase.contains fvarId) then -- localDecl depends on a variable that will be erased. So, we must add it to `toErase` too - toErase.push localDecl.fvarId + return toErase.push localDecl.fvarId else - toErase + return toErase else - toErase.push localDecl.fvarId + return toErase.push localDecl.fvarId let lctx := toErase.foldl (init := mvarDecl.lctx) fun lctx toEraseFVar => lctx.erase toEraseFVar /- Compute new set of local instances. -/ @@ -808,7 +807,7 @@ partial def check if fvars.any fun x => x.fvarId! == fvarId then true else false -- We could throw an exception here, but we would have to use ExceptM. So, we let CheckAssignment.check do it | Expr.mvar mvarId' _ => - match mctx.getExprAssignment? mvarId' with + match mctx.getExprAssignmentCore? mvarId' with | some _ => false -- use CheckAssignment.check to instantiate | none => if mvarId' == mvarId then false -- occurs check failed, use CheckAssignment.check to throw exception @@ -1620,18 +1619,28 @@ private def skipDefEqCache : MetaM Bool := do private def mkCacheKey (t : Expr) (s : Expr) : Expr × Expr := if Expr.quickLt t s then (t, s) else (s, t) -private def isCached (key : Expr × Expr) : MetaM Bool := do - match (← getConfig).transparency with - | TransparencyMode.default => return (← get).cache.defEqDefault.contains key - | TransparencyMode.all => return (← get).cache.defEqAll.contains key - | _ => return false +private def getCachedResult (key : Expr × Expr) : MetaM LBool := do + let cache ← match (← getConfig).transparency with + | TransparencyMode.default => pure (← get).cache.defEqDefault + | TransparencyMode.all => pure (← get).cache.defEqAll + | _ => return .undef + match cache.find? key with + | some val => return val.toLBool + | none => return .undef -private def cacheResult (key : Expr × Expr) : MetaM Unit := do +private def cacheResult (key : Expr × Expr) (result : Bool) : MetaM Unit := do match (← getConfig).transparency with - | TransparencyMode.default => modify fun s => { s with cache.defEqDefault := s.cache.defEqDefault.insert key () } - | TransparencyMode.all => modify fun s => { s with cache.defEqAll := s.cache.defEqAll.insert key () } + | TransparencyMode.default => modify fun s => { s with cache.defEqDefault := s.cache.defEqDefault.insert key result } + | TransparencyMode.all => modify fun s => { s with cache.defEqAll := s.cache.defEqAll.insert key result } | _ => pure () +private abbrev withResetUsedAssignment (k : MetaM α) : MetaM α := do + let usedAssignment := (← getMCtx).usedAssignment + modifyMCtx fun mctx => { mctx with usedAssignment := false } + try + k + finally + modifyMCtx fun mctx => { mctx with usedAssignment } @[export lean_is_expr_def_eq] partial def isExprDefEqAuxImpl (t : Expr) (s : Expr) : MetaM Bool := withIncRecDepth do @@ -1658,18 +1667,24 @@ partial def isExprDefEqAuxImpl (t : Expr) (s : Expr) : MetaM Bool := withIncRecD -/ let t ← instantiateMVars t let s ← instantiateMVars s - if t.hasMVar || s.hasMVar then - -- It is not safe to use DefEq cache if terms contain metavariables - isExprDefEqExpensive t s - else - let k := mkCacheKey t s - if (← isCached k) then - return true - else if (← isExprDefEqExpensive t s) then - cacheResult k - return true - else - return false + let numPostponed ← getNumPostponed + let k := mkCacheKey t s + match (← getCachedResult k) with + | .true => + trace[Meta.isDefEq.cache] "cache hit 'true' for {t} =?= {s}" + return true + | .false => + trace[Meta.isDefEq.cache] "cache hit 'false' for {t} =?= {s}" + return false + | .undef => + withResetUsedAssignment do + let result ← isExprDefEqExpensive t s + if numPostponed == (← getNumPostponed) && !(← getMCtx).usedAssignment then + -- It is only safe to cache the result if the mvars assignments have not been accessed/used, + -- and universe level variables have not been postponed. + trace[Meta.isDefEq.cache] "cache {result} for {t} =?= {s}" + cacheResult k result + return result builtin_initialize registerTraceClass `Meta.isDefEq diff --git a/src/Lean/Meta/GeneralizeVars.lean b/src/Lean/Meta/GeneralizeVars.lean index bd2867828a..c6dba7af09 100644 --- a/src/Lean/Meta/GeneralizeVars.lean +++ b/src/Lean/Meta/GeneralizeVars.lean @@ -60,7 +60,7 @@ def getFVarSetToGeneralize (targets : Array Expr) (forbidden : FVarIdSet) (ignor for localDecl in (← getLCtx) do unless forbidden.contains localDecl.fvarId do unless localDecl.isAuxDecl || localDecl.binderInfo.isInstImplicit || (ignoreLetDecls && localDecl.isLet) do - if (← getMCtx).findLocalDeclDependsOn localDecl (s.contains ·) then + if (← MetavarContext.findLocalDeclDependsOn localDecl (s.contains ·)) then r := r.insert localDecl.fvarId s := s.insert localDecl.fvarId return r diff --git a/src/Lean/Meta/PPGoal.lean b/src/Lean/Meta/PPGoal.lean index 004ff02b0f..5651036fdc 100644 --- a/src/Lean/Meta/PPGoal.lean +++ b/src/Lean/Meta/PPGoal.lean @@ -69,13 +69,13 @@ def moveToHiddeProp (fvarId : FVarId) : M Unit := do Recall that hiddenInaccessibleProps are visible, only their names are hidden -/ def hasVisibleDep (localDecl : LocalDecl) : M Bool := do let s ← get - return (← getMCtx).findLocalDeclDependsOn localDecl (!s.hiddenInaccessible.contains ·) + MetavarContext.findLocalDeclDependsOn localDecl (!s.hiddenInaccessible.contains ·) /- Return true if the given local declaration has a "nonvisible dependency", that is, it contains a free variable that is `hiddenInaccessible` or `hiddenInaccessibleProp` -/ def hasInaccessibleNameDep (localDecl : LocalDecl) : M Bool := do let s ← get - return (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId => + MetavarContext.findLocalDeclDependsOn localDecl fun fvarId => s.hiddenInaccessible.contains fvarId || s.hiddenInaccessibleProp.contains fvarId /- If `e` is visible, then any inaccessible in `e` marked as hidden should be unmarked. -/ diff --git a/src/Lean/Meta/Tactic/Cases.lean b/src/Lean/Meta/Tactic/Cases.lean index c70fb8c5ce..ca312ffd95 100644 --- a/src/Lean/Meta/Tactic/Cases.lean +++ b/src/Lean/Meta/Tactic/Cases.lean @@ -163,20 +163,18 @@ We say the major premise has independent indices IF -/ private def hasIndepIndices (ctx : Context) : MetaM Bool := do if ctx.majorTypeIndices.isEmpty then - pure true + return true else if ctx.majorTypeIndices.any fun idx => !idx.isFVar then /- One of the indices is not a free variable. -/ - pure false + return false else if ctx.majorTypeIndices.size.any fun i => i.any fun j => ctx.majorTypeIndices[i]! == ctx.majorTypeIndices[j]! then /- An index ocurrs more than once -/ - pure false + return false else - let lctx ← getLCtx - let mctx ← getMCtx - return lctx.all fun decl => - decl.fvarId == ctx.majorDecl.fvarId || -- decl is the major - ctx.majorTypeIndices.any (fun index => decl.fvarId == index.fvarId!) || -- decl is one of the indices - mctx.findLocalDeclDependsOn decl (fun fvarId => ctx.majorTypeIndices.all fun idx => idx.fvarId! != fvarId) -- or does not depend on any index + (← getLCtx).allM fun decl => + pure (decl.fvarId == ctx.majorDecl.fvarId) <||> -- decl is the major + pure (ctx.majorTypeIndices.any (fun index => decl.fvarId == index.fvarId!)) <||> -- decl is one of the indices + MetavarContext.findLocalDeclDependsOn decl (fun fvarId => ctx.majorTypeIndices.all fun idx => idx.fvarId! != fvarId) -- or does not depend on any index private def elimAuxIndices (s₁ : GeneralizeIndicesSubgoal) (s₂ : Array CasesSubgoal) : MetaM (Array CasesSubgoal) := let indicesFVarIds := s₁.indicesFVarIds diff --git a/src/Lean/Meta/Tactic/Clear.lean b/src/Lean/Meta/Tactic/Clear.lean index cc9b4a81d3..4c93c12e49 100644 --- a/src/Lean/Meta/Tactic/Clear.lean +++ b/src/Lean/Meta/Tactic/Clear.lean @@ -17,10 +17,10 @@ def clear (mvarId : MVarId) (fvarId : FVarId) : MetaM MVarId := let mctx ← getMCtx lctx.forM fun localDecl => do unless localDecl.fvarId == fvarId do - if mctx.localDeclDependsOn localDecl fvarId then + if (← MetavarContext.localDeclDependsOn localDecl fvarId) then throwTacticEx `clear mvarId m!"variable '{localDecl.toExpr}' depends on '{mkFVar fvarId}'" let mvarDecl ← getMVarDecl mvarId - if mctx.exprDependsOn mvarDecl.type fvarId then + if (← MetavarContext.exprDependsOn mvarDecl.type fvarId) then throwTacticEx `clear mvarId m!"target depends on '{mkFVar fvarId}'" let lctx := lctx.erase fvarId let localInsts ← getLocalInstances diff --git a/src/Lean/Meta/Tactic/Induction.lean b/src/Lean/Meta/Tactic/Induction.lean index e9421941a3..2315e62d63 100644 --- a/src/Lean/Meta/Tactic/Induction.lean +++ b/src/Lean/Meta/Tactic/Induction.lean @@ -143,17 +143,18 @@ def induction (mvarId : MVarId) (majorFVarId : FVarId) (recursorName : Name) (gi let arg := majorTypeArgs[i]! if i != idxPos && arg == idx then throwTacticEx `induction mvarId m!"'{idx}' is an index in major premise, but it occurs more than once{indentExpr majorType}" - if i < idxPos && mctx.exprDependsOn arg idx.fvarId! then - throwTacticEx `induction mvarId m!"'{idx}' is an index in major premise, but it occurs in previous arguments{indentExpr majorType}" + if i < idxPos then + if (← MetavarContext.exprDependsOn arg idx.fvarId!) then + throwTacticEx `induction mvarId m!"'{idx}' is an index in major premise, but it occurs in previous arguments{indentExpr majorType}" -- If arg is also and index and a variable occurring after `idx`, we need to make sure it doesn't depend on `idx`. -- Note that if `arg` is not a variable, we will fail anyway when we visit it. if i > idxPos && recursorInfo.indicesPos.contains i && arg.isFVar then let idxDecl ← getLocalDecl idx.fvarId! - if mctx.localDeclDependsOn idxDecl arg.fvarId! then + if (← MetavarContext.localDeclDependsOn idxDecl arg.fvarId!) then throwTacticEx `induction mvarId m!"'{idx}' is an index in major premise, but it depends on index occurring at position #{i+1}" pure idx let target ← getMVarType mvarId - if !recursorInfo.depElim && mctx.exprDependsOn target majorFVarId then + if (← pure !recursorInfo.depElim <&&> MetavarContext.exprDependsOn target majorFVarId) then throwTacticEx `induction mvarId m!"recursor '{recursorName}' does not support dependent elimination, but conclusion depends on major premise" -- Revert indices and major premise preserving variable order let (reverted, mvarId) ← revert mvarId ((indices.map Expr.fvarId!).push majorFVarId) true diff --git a/src/Lean/Meta/Tactic/Revert.lean b/src/Lean/Meta/Tactic/Revert.lean index 944a070f32..7dacfea3d7 100644 --- a/src/Lean/Meta/Tactic/Revert.lean +++ b/src/Lean/Meta/Tactic/Revert.lean @@ -16,30 +16,28 @@ def revert (mvarId : MVarId) (fvarIds : Array FVarId) (preserveOrder : Bool := f if (← getLocalDecl fvarId) |>.isAuxDecl then throwError "failed to revert {mkFVar fvarId}, it is an auxiliary declaration created to represent recursive definitions" let fvars := fvarIds.map mkFVar - match MetavarContext.MkBinding.collectForwardDeps (← getMCtx) (← getLCtx) fvars preserveOrder with - | Except.error _ => throwError "failed to revert variables {fvars}" - | Except.ok toRevert => - /- We should clear any `auxDecl` in `toRevert` -/ - let mut mvarId := mvarId - let mut toRevertNew := #[] - for x in toRevert do - if (← getLocalDecl x.fvarId!) |>.isAuxDecl then - mvarId ← clear mvarId x.fvarId! - else - toRevertNew := toRevertNew.push x - let tag ← getMVarTag mvarId - -- TODO: the following code can be optimized because `MetavarContext.revert` will compute `collectDeps` again. - -- We should factor out the relevat part + let toRevert ← collectForwardDeps fvars preserveOrder + /- We should clear any `auxDecl` in `toRevert` -/ + let mut mvarId := mvarId + let mut toRevertNew := #[] + for x in toRevert do + if (← getLocalDecl x.fvarId!) |>.isAuxDecl then + mvarId ← clear mvarId x.fvarId! + else + toRevertNew := toRevertNew.push x + let tag ← getMVarTag mvarId + -- TODO: the following code can be optimized because `MetavarContext.revert` will compute `collectDeps` again. + -- We should factor out the relevat part - -- Set metavariable kind to natural to make sure `revert` will assign it. - setMVarKind mvarId MetavarKind.natural - let (e, toRevert) ← - try - liftMkBindingM <| MetavarContext.revert toRevertNew mvarId preserveOrder - finally - setMVarKind mvarId MetavarKind.syntheticOpaque - let mvar := e.getAppFn - setMVarTag mvar.mvarId! tag - return (toRevert.map Expr.fvarId!, mvar.mvarId!) + -- Set metavariable kind to natural to make sure `revert` will assign it. + setMVarKind mvarId MetavarKind.natural + let (e, toRevert) ← + try + liftMkBindingM <| MetavarContext.revert toRevertNew mvarId preserveOrder + finally + setMVarKind mvarId MetavarKind.syntheticOpaque + let mvar := e.getAppFn + setMVarTag mvar.mvarId! tag + return (toRevert.map Expr.fvarId!, mvar.mvarId!) end Lean.Meta diff --git a/src/Lean/Meta/Tactic/Subst.lean b/src/Lean/Meta/Tactic/Subst.lean index 3358c13b82..5b8106bcd4 100644 --- a/src/Lean/Meta/Tactic/Subst.lean +++ b/src/Lean/Meta/Tactic/Subst.lean @@ -29,8 +29,7 @@ def substCore (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst : | Expr.fvar aFVarId _ => do let aFVarIdOriginal := aFVarId trace[Meta.Tactic.subst] "substituting {a} (id: {aFVarId.name}) with {b}" - let mctx ← getMCtx - if mctx.exprDependsOn b aFVarId then + if (← MetavarContext.exprDependsOn b aFVarId) then throwTacticEx `subst mvarId m!"'{a}' occurs at{indentExpr b}" let (vars, mvarId) ← revert mvarId #[aFVarId, hFVarId] true trace[Meta.Tactic.subst] "after revert {MessageData.ofGoal mvarId}" @@ -46,8 +45,9 @@ def substCore (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst : pure false else let mvarType ← getMVarType mvarId - let mctx ← getMCtx - pure (!mctx.exprDependsOn mvarType aFVarId && !mctx.exprDependsOn mvarType hFVarId) + if (← MetavarContext.exprDependsOn mvarType aFVarId) then pure false + else if (← MetavarContext.exprDependsOn mvarType hFVarId) then pure false + else pure true if skip then if clearH then let mvarId ← clear mvarId hFVarId @@ -64,8 +64,7 @@ def substCore (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst : | none => unreachable! | some (_, lhs, rhs) => do let b ← instantiateMVars <| if symm then lhs else rhs - let mctx ← getMCtx - let depElim := mctx.exprDependsOn mvarDecl.type hFVarId + let depElim ← MetavarContext.exprDependsOn mvarDecl.type hFVarId let cont (motive : Expr) (newType : Expr) : MetaM (FVarSubst × MVarId) := do let major ← if symm then pure h else mkEqSymm h let newMVar ← mkFreshExprSyntheticOpaqueMVar newType tag @@ -192,12 +191,13 @@ where | some (_, lhs, rhs) => let lhs ← instantiateMVars lhs let rhs ← instantiateMVars rhs - if rhs.isFVar && rhs.fvarId! == h && !mctx.exprDependsOn lhs h then - return some (localDecl.fvarId, true) - else if lhs.isFVar && lhs.fvarId! == h && !mctx.exprDependsOn rhs h then - return some (localDecl.fvarId, false) - else - return none + if rhs.isFVar && rhs.fvarId! == h then + if !(← MetavarContext.exprDependsOn lhs h) then + return some (localDecl.fvarId, true) + if lhs.isFVar && lhs.fvarId! == h then + if !(← MetavarContext.exprDependsOn rhs h) then + return some (localDecl.fvarId, false) + return none | _ => return none | throwTacticEx `subst mvarId m!"did not find equation for eliminating '{mkFVar h}'" return (← substCore mvarId fvarId (symm := symm) (tryToSkip := true)).2 diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index 2546d86f42..8dd7b25d95 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -450,7 +450,7 @@ def reduceProj? (e : Expr) : MetaM (Option Expr) := do -/ private def whnfDelayedAssigned? (f' : Expr) (e : Expr) : MetaM (Option Expr) := do if f'.isMVar then - match (← getDelayedAssignment? f'.mvarId!) with + match (← getDelayedMVarAssignment? f'.mvarId!) with | none => return none | some { fvars := fvars, val := val, .. } => let args := e.getAppArgs diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index ceafb60ddc..29648aaceb 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -271,14 +271,15 @@ structure DelayedMetavarAssignment where open Std (HashMap PersistentHashMap) structure MetavarContext where - depth : Nat := 0 - mvarCounter : Nat := 0 -- Counter for setting the field `index` at `MetavarDecl` - lDepth : PersistentHashMap MVarId Nat := {} - decls : PersistentHashMap MVarId MetavarDecl := {} - userNames : PersistentHashMap Name MVarId := {} - lAssignment : PersistentHashMap MVarId Level := {} - eAssignment : PersistentHashMap MVarId Expr := {} - dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {} + depth : Nat := 0 + mvarCounter : Nat := 0 -- Counter for setting the field `index` at `MetavarDecl` + lDepth : PersistentHashMap MVarId Nat := {} + decls : PersistentHashMap MVarId MetavarDecl := {} + userNames : PersistentHashMap Name MVarId := {} + lAssignment : PersistentHashMap MVarId Level := {} + eAssignment : PersistentHashMap MVarId Expr := {} + dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {} + usedAssignment : Bool := false class MonadMCtx (m : Type → Type) where getMCtx : m MetavarContext @@ -290,6 +291,48 @@ instance (m n) [MonadLift m n] [MonadMCtx m] : MonadMCtx n where getMCtx := liftM (getMCtx : m _) modifyMCtx := fun f => liftM (modifyMCtx f : m _) +def markUsedAssignment [MonadMCtx m] : m Unit := + modifyMCtx fun mctx => { mctx with usedAssignment := true } + +abbrev setMCtx [MonadMCtx m] (mctx : MetavarContext) : m Unit := + modifyMCtx fun _ => mctx + +abbrev getLevelMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Level) := do + let result? := (← getMCtx).lAssignment.find? mvarId + if result?.isSome then + markUsedAssignment + return result? + +def MetavarContext.getExprAssignmentCore? (m : MetavarContext) (mvarId : MVarId) : Option Expr := + m.eAssignment.find? mvarId + +def getExprMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Expr) := do + let result? := (← getMCtx).getExprAssignmentCore? mvarId + if result?.isSome then + markUsedAssignment + return result? + +def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) := do + let result? := (← getMCtx).dAssignment.find? mvarId + if result?.isSome then + markUsedAssignment + return result? + +/- Given a sequence of delayed assignments + ``` + mvarId₁ := mvarId₂ ...; + ... + mvarIdₙ := mvarId_root ... -- where `mvarId_root` is not delayed assigned + ``` + in `mctx`, `getDelayedRoot mctx mvarId₁` return `mvarId_root`. + If `mvarId₁` is not delayed assigned then return `mvarId₁` -/ +partial def getDelayedMVarRoot [Monad m] [MonadMCtx m] (mvarId : MVarId) : m MVarId := do + match (← getDelayedMVarAssignment? mvarId) with + | some d => match d.val.getAppFn with + | Expr.mvar mvarId _ => getDelayedMVarRoot mvarId + | _ => return mvarId + | none => return mvarId + namespace MetavarContext instance : Inhabited MetavarContext := ⟨{}⟩ @@ -387,22 +430,13 @@ def isAnonymousMVar (mctx : MetavarContext) (mvarId : MVarId) : Bool := | some mvarDecl => mvarDecl.userName.isAnonymous def assignLevel (m : MetavarContext) (mvarId : MVarId) (val : Level) : MetavarContext := - { m with lAssignment := m.lAssignment.insert mvarId val } + { m with lAssignment := m.lAssignment.insert mvarId val, usedAssignment := true } def assignExpr (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext := - { m with eAssignment := m.eAssignment.insert mvarId val } + { m with eAssignment := m.eAssignment.insert mvarId val, usedAssignment := true } def assignDelayed (m : MetavarContext) (mvarId : MVarId) (fvars : Array Expr) (val : Expr) : MetavarContext := - { m with dAssignment := m.dAssignment.insert mvarId { fvars, val } } - -def getLevelAssignment? (m : MetavarContext) (mvarId : MVarId) : Option Level := - m.lAssignment.find? mvarId - -def getExprAssignment? (m : MetavarContext) (mvarId : MVarId) : Option Expr := - m.eAssignment.find? mvarId - -def getDelayedAssignment? (m : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment := - m.dAssignment.find? mvarId + { m with dAssignment := m.dAssignment.insert mvarId { fvars, val }, usedAssignment := true } def isLevelAssigned (m : MetavarContext) (mvarId : MVarId) : Bool := m.lAssignment.contains mvarId @@ -416,21 +450,6 @@ def isDelayedAssigned (m : MetavarContext) (mvarId : MVarId) : Bool := def eraseDelayed (m : MetavarContext) (mvarId : MVarId) : MetavarContext := { m with dAssignment := m.dAssignment.erase mvarId } -/- Given a sequence of delayed assignments - ``` - mvarId₁ := mvarId₂ ...; - ... - mvarIdₙ := mvarId_root ... -- where `mvarId_root` is not delayed assigned - ``` - in `mctx`, `getDelayedRoot mctx mvarId₁` return `mvarId_root`. - If `mvarId₁` is not delayed assigned then return `mvarId₁` -/ -partial def getDelayedRoot (m : MetavarContext) : MVarId → MVarId - | mvarId => match getDelayedAssignment? m mvarId with - | some d => match d.val.getAppFn with - | Expr.mvar mvarId _ => getDelayedRoot m mvarId - | _ => mvarId - | none => mvarId - def isLevelAssignable (mctx : MetavarContext) (mvarId : MVarId) : Bool := match mctx.lDepth.find? mvarId with | some d => d == mctx.depth @@ -515,7 +534,7 @@ partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level | lvl@(Level.max lvl₁ lvl₂ _) => return Level.updateMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂) | lvl@(Level.imax lvl₁ lvl₂ _) => return Level.updateIMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂) | lvl@(Level.mvar mvarId _) => do - match getLevelAssignment? (← getMCtx) mvarId with + match (← getLevelMVarAssignment? mvarId) with | some newLvl => if !newLvl.hasMVar then pure newLvl else do @@ -551,8 +570,7 @@ partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLi instArgs f match f with | Expr.mvar mvarId _ => - let mctx ← getMCtx - match mctx.getDelayedAssignment? mvarId with + match (← getDelayedMVarAssignment? mvarId) with | none => instApp | some { fvars, val, .. } => /- @@ -589,8 +607,7 @@ partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLi pure result | _ => instApp | e@(Expr.mvar mvarId _) => checkCache { val := e : ExprStructEq } fun _ => do - let mctx ← getMCtx - match mctx.getExprAssignment? mvarId with + match (← getExprMVarAssignment? mvarId) with | some newE => do let newE' ← instantiateExprMVars newE modifyMCtx fun mctx => mctx.assignExpr mvarId newE' @@ -629,18 +646,26 @@ def instantiateMVarDeclMVars (mctx : MetavarContext) (mvarId : MVarId) : Metavar namespace DependsOn -private abbrev M := StateM ExprSet +structure State where + visited : ExprSet := {} + mctx : MetavarContext + +private abbrev M := StateM State + +instance : MonadMCtx M where + getMCtx := return (← get).mctx + modifyMCtx f := modify fun s => { s with mctx := f s.mctx } private def shouldVisit (e : Expr) : M Bool := do if !e.hasMVar && !e.hasFVar then return false - else if (← get).contains e then + else if (← get).visited.contains e then return false else - modify fun s => s.insert e + modify fun s => { s with visited := s.visited.insert e } return true -@[specialize] private partial def dep (mctx : MetavarContext) (pf : FVarId → Bool) (pm : MVarId → Bool) (e : Expr) : M Bool := +@[specialize] private partial def dep (pf : FVarId → Bool) (pm : MVarId → Bool) (e : Expr) : M Bool := let rec visit (e : Expr) : M Bool := do if !(← shouldVisit e) then @@ -656,10 +681,10 @@ private def shouldVisit (e : Expr) : M Bool := do | Expr.lam _ d b _ => visit d <||> visit b | Expr.letE _ t v b _ => visit t <||> visit v <||> visit b | Expr.mdata _ b _ => visit b - | e@(Expr.app ..) => + | e@(Expr.app ..) => do let f := e.getAppFn if f.isMVar then - let (e', _) := instantiateMVars mctx e + let e' ← modifyGet fun ⟨visited, mctx⟩ => let (e, mctx) := instantiateMVars mctx e; (e, ⟨visited, mctx⟩) if e'.getAppFn != f then visitMain e' else if pm f.mvarId! then @@ -668,21 +693,21 @@ private def shouldVisit (e : Expr) : M Bool := do visitApp e else visitApp e - | Expr.mvar mvarId _ => - match mctx.getExprAssignment? mvarId with + | Expr.mvar mvarId _ => do + match (← getExprMVarAssignment? mvarId) with | some a => visit a | none => if pm mvarId then return true else - let lctx := (mctx.getDecl mvarId).lctx + let lctx := (← getMCtx).getDecl mvarId |>.lctx return lctx.any fun decl => pf decl.fvarId | Expr.fvar fvarId _ => return pf fvarId | _ => pure false visit e -@[inline] partial def main (mctx : MetavarContext) (pf : FVarId → Bool) (pm : MVarId → Bool) (e : Expr) : M Bool := - if !e.hasFVar && !e.hasMVar then pure false else dep mctx pf pm e +@[inline] partial def main (pf : FVarId → Bool) (pm : MVarId → Bool) (e : Expr) : M Bool := + if !e.hasFVar && !e.hasMVar then pure false else dep pf pm e end DependsOn @@ -692,40 +717,45 @@ end DependsOn 1- If `?m := t`, then we visit `t` looking for `x` 2- If `?m` is unassigned, then we consider the worst case and check whether `x` is in the local context of `?m`. This case is a "may dependency". That is, we may assign a term `t` to `?m` s.t. `t` contains `x`. -/ -@[inline] def findExprDependsOn (mctx : MetavarContext) (e : Expr) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : Bool := - DependsOn.main mctx pf pm e |>.run' {} +@[inline] def findExprDependsOn [Monad m] [MonadMCtx m] (e : Expr) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : m Bool := do + let (result, { mctx, .. }) := DependsOn.main pf pm e |>.run { mctx := (← getMCtx) } + setMCtx mctx + return result /-- Similar to `findExprDependsOn`, but checks the expressions in the given local declaration depends on a free variable `x` s.t. `pf x` is `true` or an unassigned metavariable `?m` s.t. `pm ?m` is true. -/ -@[inline] def findLocalDeclDependsOn (mctx : MetavarContext) (localDecl : LocalDecl) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : Bool := +@[inline] def findLocalDeclDependsOn [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : m Bool := do match localDecl with - | LocalDecl.cdecl (type := t) .. => findExprDependsOn mctx t pf pm - | LocalDecl.ldecl (type := t) (value := v) .. => (DependsOn.main mctx pf pm t <||> DependsOn.main mctx pf pm v).run' {} + | LocalDecl.cdecl (type := t) .. => findExprDependsOn t pf pm + | LocalDecl.ldecl (type := t) (value := v) .. => + let (result, { mctx, .. }) := (DependsOn.main pf pm t <||> DependsOn.main pf pm v).run { mctx := (← getMCtx) } + setMCtx mctx + return result -def exprDependsOn (mctx : MetavarContext) (e : Expr) (fvarId : FVarId) : Bool := - findExprDependsOn mctx e (fvarId == ·) +def exprDependsOn [Monad m] [MonadMCtx m] (e : Expr) (fvarId : FVarId) : m Bool := + findExprDependsOn e (fvarId == ·) -def localDeclDependsOn (mctx : MetavarContext) (localDecl : LocalDecl) (fvarId : FVarId) : Bool := - findLocalDeclDependsOn mctx localDecl (fvarId == ·) +def localDeclDependsOn [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (fvarId : FVarId) : m Bool := + findLocalDeclDependsOn localDecl (fvarId == ·) /-- Similar to `exprDependsOn`, but `x` can be a free variable or an unassigned metavariable. -/ -def exprDependsOn' (mctx : MetavarContext) (e : Expr) (x : Expr) : Bool := +def exprDependsOn' [Monad m] [MonadMCtx m] (e : Expr) (x : Expr) : m Bool := if x.isFVar then - findExprDependsOn mctx e (x.fvarId! == ·) + findExprDependsOn e (x.fvarId! == ·) else if x.isMVar then - findExprDependsOn mctx e (pm := (x.mvarId! == ·)) + findExprDependsOn e (pm := (x.mvarId! == ·)) else - false + return false /-- Similar to `localDeclDependsOn`, but `x` can be a free variable or an unassigned metavariable. -/ -def localDeclDependsOn' (mctx : MetavarContext) (localDecl : LocalDecl) (x : Expr) : Bool := +def localDeclDependsOn' [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (x : Expr) : m Bool := if x.isFVar then - findLocalDeclDependsOn mctx localDecl (x.fvarId! == ·) + findLocalDeclDependsOn localDecl (x.fvarId! == ·) else if x.isMVar then - findLocalDeclDependsOn mctx localDecl (pm := (x.mvarId! == ·)) + findLocalDeclDependsOn localDecl (pm := (x.mvarId! == ·)) else - false + return false namespace MkBinding @@ -762,6 +792,10 @@ structure Context where abbrev MCore := EStateM Exception State abbrev M := ReaderT Context MCore +instance : MonadMCtx M where + getMCtx := return (← get).mctx + modifyMCtx f := modify fun s => { s with mctx := f s.mctx } + private def mkFreshBinderName (n : Name := `x) : M Name := do let fresh ← modifyGet fun s => (s.nextMacroScope, { s with nextMacroScope := s.nextMacroScope + 1 }) return addMacroScope (← read).mainModule n fresh @@ -816,20 +850,20 @@ private def getLocalDeclWithSmallestIdx (lctx : LocalContext) (xs : Array Expr) Note that https://github.com/leanprover/lean/issues/1258 is not an issue in Lean4 because we have changed how we compile recursive definitions. -/ -def collectForwardDeps (mctx : MetavarContext) (lctx : LocalContext) (toRevert : Array Expr) (preserveOrder : Bool) : Except Exception (Array Expr) := do +def collectForwardDeps (lctx : LocalContext) (toRevert : Array Expr) : M (Array Expr) := do if toRevert.size == 0 then pure toRevert else - if preserveOrder then + if (← preserveOrder) then -- Make sure toRevert[j] does not depend on toRevert[i] for j > i toRevert.size.forM fun i => do let fvar := toRevert[i]! i.forM fun j => do let prevFVar := toRevert[j]! let prevDecl := lctx.getFVar! prevFVar - if localDeclDependsOn mctx prevDecl fvar.fvarId! then - throw (Exception.revertFailure mctx lctx toRevert prevDecl.userName.toString) - let newToRevert := if preserveOrder then toRevert else Array.mkEmpty toRevert.size + if (← localDeclDependsOn prevDecl fvar.fvarId!) then + throw (Exception.revertFailure (← getMCtx) lctx toRevert prevDecl.userName.toString) + let newToRevert := if (← preserveOrder) then toRevert else Array.mkEmpty toRevert.size let firstDeclToVisit := getLocalDeclWithSmallestIdx lctx toRevert let initSize := newToRevert.size lctx.foldlM (init := newToRevert) (start := firstDeclToVisit.index) fun (newToRevert : Array Expr) decl => do @@ -837,7 +871,7 @@ def collectForwardDeps (mctx : MetavarContext) (lctx : LocalContext) (toRevert : return newToRevert else if toRevert.any fun x => decl.fvarId == x.fvarId! then return newToRevert.push decl.toExpr - else if findLocalDeclDependsOn mctx decl (newToRevert.any fun x => x.fvarId! == ·) then + else if (← findLocalDeclDependsOn decl (newToRevert.any fun x => x.fvarId! == ·)) then return newToRevert.push decl.toExpr else return newToRevert @@ -848,9 +882,6 @@ def reduceLocalContext (lctx : LocalContext) (toRevert : Array Expr) : LocalCont toRevert.foldr (init := lctx) fun x lctx => if x.isFVar then lctx.erase x.fvarId! else lctx -@[inline] private def getMCtx : M MetavarContext := - return (← get).mctx - /-- Return free variables in `xs` that are in the local context `lctx` -/ private def getInScope (lctx : LocalContext) (xs : Array Expr) : Array Expr := xs.foldl (init := #[]) fun scope x => @@ -930,8 +961,7 @@ mutual pure (e.abstractRange i xs) private partial def elimMVar (xs : Array Expr) (mvarId : MVarId) (args : Array Expr) : M (Expr × Array Expr) := do - let mctx ← getMCtx - let mvarDecl := mctx.getDecl mvarId + let mvarDecl := (← getMCtx).getDecl mvarId let mvarLCtx := mvarDecl.lctx let toRevert := getInScope mvarLCtx xs if toRevert.size == 0 then @@ -949,7 +979,7 @@ mutual A potential disadvantage is that `isDefEq` will not eagerly use `synthPending` for natural metavariables. That being said, we should try this approach as soon as we have an extensive test suite. -/ - let newMVarKind := if !mctx.isExprAssignable mvarId then MetavarKind.syntheticOpaque else mvarDecl.kind + let newMVarKind := if !(← getMCtx).isExprAssignable mvarId then MetavarKind.syntheticOpaque else mvarDecl.kind /- If `mvarId` is the lhs of a delayed assignment `?m #[x_1, ... x_n] := val`, then `nestedFVars` is `#[x_1, ..., x_n]`. In this case, we produce a new `syntheticOpaque` metavariable `?n` and a delayed assignment @@ -962,39 +992,36 @@ mutual -/ let rec cont (nestedFVars : Array Expr) : M (Expr × Array Expr) := do let args ← args.mapM (visit xs) - let preserve ← preserveOrder -- Note that `toRevert` only contains free variables since it is the result of `getInScope` - match collectForwardDeps mctx mvarLCtx toRevert preserve with - | Except.error ex => throw ex - | Except.ok toRevert => - let newMVarLCtx := reduceLocalContext mvarLCtx toRevert - let newLocalInsts := mvarDecl.localInstances.filter fun inst => toRevert.all fun x => inst.fvar != x - -- Remark: we must reset the before processing `mkAuxMVarType` because `toRevert` may not be equal to `xs` - let newMVarType ← withFreshCache do mkAuxMVarType mvarLCtx toRevert newMVarKind mvarDecl.type - let newMVarId := { name := (← get).ngen.curr } - let newMVar := mkMVar newMVarId - let result := mkMVarApp mvarLCtx newMVar toRevert newMVarKind - let numScopeArgs := mvarDecl.numScopeArgs + result.getAppNumArgs - modify fun s => { s with - mctx := s.mctx.addExprMVarDecl newMVarId Name.anonymous newMVarLCtx newLocalInsts newMVarType newMVarKind numScopeArgs, - ngen := s.ngen.next - } - match newMVarKind with - | MetavarKind.syntheticOpaque => - modify fun s => { s with mctx := assignDelayed s.mctx newMVarId (toRevert ++ nestedFVars) (mkAppN (mkMVar mvarId) nestedFVars) } - | _ => - modify fun s => { s with mctx := assignExpr s.mctx mvarId result } - return (mkAppN result args, toRevert) + let toRevert ← collectForwardDeps mvarLCtx toRevert + let newMVarLCtx := reduceLocalContext mvarLCtx toRevert + let newLocalInsts := mvarDecl.localInstances.filter fun inst => toRevert.all fun x => inst.fvar != x + -- Remark: we must reset the before processing `mkAuxMVarType` because `toRevert` may not be equal to `xs` + let newMVarType ← withFreshCache do mkAuxMVarType mvarLCtx toRevert newMVarKind mvarDecl.type + let newMVarId := { name := (← get).ngen.curr } + let newMVar := mkMVar newMVarId + let result := mkMVarApp mvarLCtx newMVar toRevert newMVarKind + let numScopeArgs := mvarDecl.numScopeArgs + result.getAppNumArgs + modify fun s => { s with + mctx := s.mctx.addExprMVarDecl newMVarId Name.anonymous newMVarLCtx newLocalInsts newMVarType newMVarKind numScopeArgs, + ngen := s.ngen.next + } + match newMVarKind with + | MetavarKind.syntheticOpaque => + modify fun s => { s with mctx := assignDelayed s.mctx newMVarId (toRevert ++ nestedFVars) (mkAppN (mkMVar mvarId) nestedFVars) } + | _ => + modify fun s => { s with mctx := assignExpr s.mctx mvarId result } + return (mkAppN result args, toRevert) if !mvarDecl.kind.isSyntheticOpaque then cont #[] - else match mctx.getDelayedAssignment? mvarId with + else match (← getDelayedMVarAssignment? mvarId) with | none => cont #[] | some { fvars, .. } => cont fvars private partial def elimApp (xs : Array Expr) (f : Expr) (args : Array Expr) : M Expr := do match f with | Expr.mvar mvarId _ => - match (← getMCtx).getExprAssignment? mvarId with + match (← getExprMVarAssignment? mvarId) with | some newF => if newF.isLambda then let args ← args.mapM (visit xs) @@ -1096,28 +1123,32 @@ def mkBinding (isLambda : Bool) (xs : Array Expr) (e : Expr) (usedOnly : Bool := @[inline] def abstractRange (e : Expr) (n : Nat) (xs : Array Expr) : MkBindingM Expr := fun ctx => MkBinding.abstractRange xs n e { preserveOrder := false, mainModule := ctx.mainModule } +@[inline] def collectForwardDeps (toRevert : Array Expr) (preserveOrder : Bool) : MkBindingM (Array Expr) := fun ctx => + MkBinding.collectForwardDeps ctx.lctx toRevert { preserveOrder, mainModule := ctx.mainModule } + /-- `isWellFormed mctx lctx e` return true if - All locals in `e` are declared in `lctx` - All metavariables `?m` in `e` have a local context which is a subprefix of `lctx` or are assigned, and the assignment is well-formed. -/ -partial def isWellFormed (mctx : MetavarContext) (lctx : LocalContext) : Expr → Bool - | Expr.mdata _ e _ => isWellFormed mctx lctx e - | Expr.proj _ _ e _ => isWellFormed mctx lctx e - | e@(Expr.app f a _) => (!e.hasExprMVar && !e.hasFVar) || (isWellFormed mctx lctx f && isWellFormed mctx lctx a) - | e@(Expr.lam _ d b _) => (!e.hasExprMVar && !e.hasFVar) || (isWellFormed mctx lctx d && isWellFormed mctx lctx b) - | e@(Expr.forallE _ d b _) => (!e.hasExprMVar && !e.hasFVar) || (isWellFormed mctx lctx d && isWellFormed mctx lctx b) - | e@(Expr.letE _ t v b _) => (!e.hasExprMVar && !e.hasFVar) || (isWellFormed mctx lctx t && isWellFormed mctx lctx v && isWellFormed mctx lctx b) - | Expr.const .. => true - | Expr.bvar .. => true - | Expr.sort .. => true - | Expr.lit .. => true - | Expr.mvar mvarId _ => - let mvarDecl := mctx.getDecl mvarId; - if mvarDecl.lctx.isSubPrefixOf lctx then true - else match mctx.getExprAssignment? mvarId with - | none => false - | some v => isWellFormed mctx lctx v - | Expr.fvar fvarId _ => lctx.contains fvarId +partial def isWellFormed [Monad m] [MonadMCtx m] (lctx : LocalContext) : Expr → m Bool + | Expr.mdata _ e _ => isWellFormed lctx e + | Expr.proj _ _ e _ => isWellFormed lctx e + | e@(Expr.app f a _) => pure (!e.hasExprMVar && !e.hasFVar) <||> (isWellFormed lctx f <&&> isWellFormed lctx a) + | e@(Expr.lam _ d b _) => pure (!e.hasExprMVar && !e.hasFVar) <||> (isWellFormed lctx d <&&> isWellFormed lctx b) + | e@(Expr.forallE _ d b _) => pure (!e.hasExprMVar && !e.hasFVar) <||> (isWellFormed lctx d <&&> isWellFormed lctx b) + | e@(Expr.letE _ t v b _) => pure (!e.hasExprMVar && !e.hasFVar) <||> (isWellFormed lctx t <&&> isWellFormed lctx v <&&> isWellFormed lctx b) + | Expr.const .. => return true + | Expr.bvar .. => return true + | Expr.sort .. => return true + | Expr.lit .. => return true + | Expr.mvar mvarId _ => do + let mvarDecl := (← getMCtx).getDecl mvarId; + if mvarDecl.lctx.isSubPrefixOf lctx then + return true + else match (← getExprMVarAssignment? mvarId) with + | none => return false + | some v => isWellFormed lctx v + | Expr.fvar fvarId _ => return lctx.contains fvarId namespace LevelMVarToParam @@ -1134,6 +1165,10 @@ structure State where abbrev M := ReaderT Context <| StateM State +instance : MonadMCtx M where + getMCtx := return (← get).mctx + modifyMCtx f := modify fun s => { s with mctx := f s.mctx } + instance : MonadCache ExprStructEq Expr M where findCached? e := return (← get).cache.find? e cache e v := modify fun s => { s with cache := s.cache.insert e v } @@ -1158,7 +1193,7 @@ partial def visitLevel (u : Level) : M Level := do | Level.param .. => return u | Level.mvar mvarId _ => let s ← get - match s.mctx.getLevelAssignment? mvarId with + match (← getLevelMVarAssignment? mvarId) with | some v => visitLevel v | none => if (← read).except mvarId then @@ -1189,7 +1224,7 @@ where visitApp (f : Expr) (args : Array Expr) : M Expr := do match f with | Expr.mvar mvarId .. => - match (← get).mctx.getExprAssignment? mvarId with + match (← getExprMVarAssignment? mvarId) with | some v => return (← visitApp v args).headBeta | none => return mkAppN f (← args.mapM main) | _ => return mkAppN (← main f) (← args.mapM main) diff --git a/src/Lean/Server/Rpc/Deriving.lean b/src/Lean/Server/Rpc/Deriving.lean index ecb25bc016..0906a1969c 100644 --- a/src/Lean/Server/Rpc/Deriving.lean +++ b/src/Lean/Server/Rpc/Deriving.lean @@ -177,7 +177,7 @@ private def deriveInductiveInstance (indVal : InductiveVal) (params : Array Expr let argFVars ← argVars.mapM (LocalDecl.fvarId <$> getFVarLocalDecl ·) for arg in argVars do let argTp ← inferType arg - if (← getMCtx).findExprDependsOn argTp (pf := fun fv => argFVars.contains fv) then + if (← MetavarContext.findExprDependsOn argTp (pf := fun fv => argFVars.contains fv)) then throwError "cross-argument dependencies are not supported ({arg} : {argTp})" if (← acc.encArgTypes.getMatch argTp).isEmpty then diff --git a/src/Lean/Util/OccursCheck.lean b/src/Lean/Util/OccursCheck.lean index 203d26b9ce..e48c0a6e60 100644 --- a/src/Lean/Util/OccursCheck.lean +++ b/src/Lean/Util/OccursCheck.lean @@ -10,26 +10,26 @@ namespace Lean /-- Return true if `e` does **not** contain `mvarId` directly or indirectly This function considers assigments and delayed assignments. -/ -partial def MetavarContext.occursCheck (mctx : MetavarContext) (mvarId : MVarId) (e : Expr) : Bool := +partial def occursCheck [Monad m] [MonadMCtx m] (mvarId : MVarId) (e : Expr) : m Bool := do if !e.hasExprMVar then - true + return true else - match visit e |>.run {} with - | EStateM.Result.ok .. => true - | EStateM.Result.error .. => false + match (← visit e |>.run |>.run {}) with + | (.ok .., _) => return true + | (.error .., _) => return false where - visitMVar (mvarId' : MVarId) : EStateM Unit ExprSet Unit := do + visitMVar (mvarId' : MVarId) : ExceptT Unit (StateT ExprSet m) Unit := do if mvarId == mvarId' then throw () -- found else - match mctx.getExprAssignment? mvarId' with + match (← getExprMVarAssignment? mvarId') with | some v => visit v | none => - match mctx.getDelayedAssignment? mvarId' with + match (← getDelayedMVarAssignment? mvarId') with | some d => visit d.val | none => return () - visit (e : Expr) : EStateM Unit ExprSet Unit := do + visit (e : Expr) : ExceptT Unit (StateT ExprSet m) Unit := do if !e.hasExprMVar then return () else if (← get).contains e then