From d3c55ef249dd511cadd8109eceaa2a13902ac4af Mon Sep 17 00:00:00 2001 From: Gabriel Ebner Date: Tue, 14 Mar 2023 12:08:36 -0700 Subject: [PATCH] perf: do not reset tc cache when adding local instances --- src/Lean/Elab/Binders.lean | 7 +++--- src/Lean/Meta/Basic.lean | 41 +++++--------------------------- src/Lean/Meta/SynthInstance.lean | 5 ++-- src/Lean/MetavarContext.lean | 3 +++ 4 files changed, 15 insertions(+), 41 deletions(-) diff --git a/src/Lean/Elab/Binders.lean b/src/Lean/Elab/Binders.lean index 4caa91f8d4..7db9a88b6d 100644 --- a/src/Lean/Elab/Binders.lean +++ b/src/Lean/Elab/Binders.lean @@ -421,9 +421,8 @@ private partial def elabFunBinderViews (binderViews : Array BinderView) (i : Nat let s := { s with lctx } match ← isClass? type, kind with | some className, .default => - resettingSynthInstanceCache do - let localInsts := s.localInsts.push { className, fvar := mkFVar fvarId } - elabFunBinderViews binderViews (i+1) { s with localInsts } + let localInsts := s.localInsts.push { className, fvar := mkFVar fvarId } + elabFunBinderViews binderViews (i+1) { s with localInsts } | _, _ => elabFunBinderViews binderViews (i+1) s else pure s @@ -445,7 +444,7 @@ def elabFunBinders (binders : Array Syntax) (expectedType? : Option Expr) (x : A let lctx ← getLCtx let localInsts ← getLocalInstances let s ← FunBinders.elabFunBindersAux binders 0 { lctx, localInsts, expectedType? } - resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) <| withLCtx s.lctx s.localInsts <| + withLCtx s.lctx s.localInsts do x s.fvars s.expectedType? def expandWhereDecls (whereDecls : Syntax) (body : Syntax) : MacroM Syntax := diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index 2a7df070f5..f1e4348182 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -195,7 +195,7 @@ instance : Hashable InfoCacheKey := ⟨fun ⟨transparency, expr, nargs⟩ => mixHash (hash transparency) <| mixHash (hash expr) (hash nargs)⟩ end InfoCacheKey -abbrev SynthInstanceCache := PersistentHashMap Expr (Option Expr) +abbrev SynthInstanceCache := PersistentHashMap (LocalInstances × Expr) (Option Expr) abbrev InferTypeCache := PersistentExprStructMap Expr abbrev FunInfoCache := PersistentHashMap InfoCacheKey FunInfo @@ -873,38 +873,15 @@ private partial def isClassQuick? : Expr → MetaM (LOption Name) | .lam .. => return .undef | _ => return .none -def saveAndResetSynthInstanceCache : MetaM SynthInstanceCache := do - let savedSythInstance := (← get).cache.synthInstance - modifyCache fun c => { c with synthInstance := {} } - return savedSythInstance - -def restoreSynthInstanceCache (cache : SynthInstanceCache) : MetaM Unit := - modifyCache fun c => { c with synthInstance := cache } - -@[inline] private def resettingSynthInstanceCacheImpl (x : MetaM α) : MetaM α := do - let savedSythInstance ← saveAndResetSynthInstanceCache - try x finally restoreSynthInstanceCache savedSythInstance - -/-- Reset `synthInstance` cache, execute `x`, and restore cache -/ -@[inline] def resettingSynthInstanceCache : n α → n α := - mapMetaM resettingSynthInstanceCacheImpl - -@[inline] def resettingSynthInstanceCacheWhen (b : Bool) (x : n α) : n α := - if b then resettingSynthInstanceCache x else x - private def withNewLocalInstanceImp (className : Name) (fvar : Expr) (k : MetaM α) : MetaM α := do let localDecl ← getFVarLocalDecl fvar if localDecl.isImplementationDetail then k else - resettingSynthInstanceCache <| - withReader - (fun ctx => { ctx with localInstances := ctx.localInstances.push { className := className, fvar := fvar } }) - k + withReader (fun ctx => { ctx with localInstances := ctx.localInstances.push { className := className, fvar := fvar } }) k /-- Add entry `{ className := className, fvar := fvar }` to localInstances, - and then execute continuation `k`. - It resets the type class cache using `resettingSynthInstanceCache`. -/ + and then execute continuation `k`. -/ def withNewLocalInstance (className : Name) (fvar : Expr) : n α → n α := mapMetaM <| withNewLocalInstanceImp className fvar @@ -919,9 +896,7 @@ mutual using free variables `fvars[j] ... fvars.back`, and execute `k`. - `isClassExpensive` is defined later. - - The type class chache is reset whenever a new local instance is found. - - `isClassExpensive` uses `whnf` which depends (indirectly) on the set of local instances. - Thus, each new local instance requires a new `resettingSynthInstanceCache`. -/ + - `isClassExpensive` uses `whnf` which depends (indirectly) on the set of local instances. -/ private partial def withNewLocalInstancesImp (fvars : Array Expr) (i : Nat) (k : MetaM α) : MetaM α := do if h : i < fvars.size then @@ -1278,7 +1253,7 @@ def withLocalInstancesImp (decls : List LocalDecl) (k : MetaM α) : MetaM α := if localInsts.size == size then k else - resettingSynthInstanceCache <| withReader (fun ctx => { ctx with localInstances := localInsts }) k + withReader (fun ctx => { ctx with localInstances := localInsts }) k /-- Register any local instance in `decls` -/ def withLocalInstances (decls : List LocalDecl) : n α → n α := @@ -1322,12 +1297,8 @@ def withNewMCtxDepth (k : n α) (allowLevelAssignments := false) : n α := mapMetaM (withNewMCtxDepthImp allowLevelAssignments) k private def withLocalContextImp (lctx : LocalContext) (localInsts : LocalInstances) (x : MetaM α) : MetaM α := do - let localInstsCurr ← getLocalInstances withReader (fun ctx => { ctx with lctx := lctx, localInstances := localInsts }) do - if localInsts == localInstsCurr then - x - else - resettingSynthInstanceCache x + x /-- `withLCtx lctx localInsts k` replaces the local context and local instances, and then executes `k`. diff --git a/src/Lean/Meta/SynthInstance.lean b/src/Lean/Meta/SynthInstance.lean index f25300cf98..68081b13d0 100644 --- a/src/Lean/Meta/SynthInstance.lean +++ b/src/Lean/Meta/SynthInstance.lean @@ -672,10 +672,11 @@ def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM ( withConfig (fun config => { config with isDefEqStuckEx := true, transparency := TransparencyMode.instances, foApprox := true, ctxApprox := true, constApprox := false, etaStruct }) do + let localInsts ← getLocalInstances let type ← instantiateMVars type let type ← preprocess type let s ← get - match s.cache.synthInstance.find? type with + match s.cache.synthInstance.find? (localInsts, type) with | some result => trace[Meta.synthInstance] "result {result} (cached)" pure result @@ -728,7 +729,7 @@ def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM ( if type.hasMVar || resultHasUnivMVars then pure result? else do - modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert type result? } + modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert (localInsts, type) result? } pure result? /-- diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index decc9dd73d..9b1733936a 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -234,6 +234,9 @@ abbrev LocalInstances := Array LocalInstance instance : BEq LocalInstance where beq i₁ i₂ := i₁.fvar == i₂.fvar +instance : Hashable LocalInstance where + hash i := hash i.fvar + /-- Remove local instance with the given `fvarId`. Do nothing if `localInsts` does not contain any free variable with id `fvarId`. -/ def LocalInstances.erase (localInsts : LocalInstances) (fvarId : FVarId) : LocalInstances := match localInsts.findIdx? (fun inst => inst.fvar.fvarId! == fvarId) with