perf: do not reset tc cache when adding local instances

This commit is contained in:
Gabriel Ebner 2023-03-14 12:08:36 -07:00
parent 83c1a1ab77
commit d3c55ef249
4 changed files with 15 additions and 41 deletions

View file

@ -421,9 +421,8 @@ private partial def elabFunBinderViews (binderViews : Array BinderView) (i : Nat
let s := { s with lctx }
match ← isClass? type, kind with
| some className, .default =>
resettingSynthInstanceCache do
let localInsts := s.localInsts.push { className, fvar := mkFVar fvarId }
elabFunBinderViews binderViews (i+1) { s with localInsts }
let localInsts := s.localInsts.push { className, fvar := mkFVar fvarId }
elabFunBinderViews binderViews (i+1) { s with localInsts }
| _, _ => elabFunBinderViews binderViews (i+1) s
else
pure s
@ -445,7 +444,7 @@ def elabFunBinders (binders : Array Syntax) (expectedType? : Option Expr) (x : A
let lctx ← getLCtx
let localInsts ← getLocalInstances
let s ← FunBinders.elabFunBindersAux binders 0 { lctx, localInsts, expectedType? }
resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) <| withLCtx s.lctx s.localInsts <|
withLCtx s.lctx s.localInsts do
x s.fvars s.expectedType?
def expandWhereDecls (whereDecls : Syntax) (body : Syntax) : MacroM Syntax :=

View file

@ -195,7 +195,7 @@ instance : Hashable InfoCacheKey :=
⟨fun ⟨transparency, expr, nargs⟩ => mixHash (hash transparency) <| mixHash (hash expr) (hash nargs)⟩
end InfoCacheKey
abbrev SynthInstanceCache := PersistentHashMap Expr (Option Expr)
abbrev SynthInstanceCache := PersistentHashMap (LocalInstances × Expr) (Option Expr)
abbrev InferTypeCache := PersistentExprStructMap Expr
abbrev FunInfoCache := PersistentHashMap InfoCacheKey FunInfo
@ -873,38 +873,15 @@ private partial def isClassQuick? : Expr → MetaM (LOption Name)
| .lam .. => return .undef
| _ => return .none
def saveAndResetSynthInstanceCache : MetaM SynthInstanceCache := do
let savedSythInstance := (← get).cache.synthInstance
modifyCache fun c => { c with synthInstance := {} }
return savedSythInstance
def restoreSynthInstanceCache (cache : SynthInstanceCache) : MetaM Unit :=
modifyCache fun c => { c with synthInstance := cache }
@[inline] private def resettingSynthInstanceCacheImpl (x : MetaM α) : MetaM α := do
let savedSythInstance ← saveAndResetSynthInstanceCache
try x finally restoreSynthInstanceCache savedSythInstance
/-- Reset `synthInstance` cache, execute `x`, and restore cache -/
@[inline] def resettingSynthInstanceCache : n α → n α :=
mapMetaM resettingSynthInstanceCacheImpl
@[inline] def resettingSynthInstanceCacheWhen (b : Bool) (x : n α) : n α :=
if b then resettingSynthInstanceCache x else x
private def withNewLocalInstanceImp (className : Name) (fvar : Expr) (k : MetaM α) : MetaM α := do
let localDecl ← getFVarLocalDecl fvar
if localDecl.isImplementationDetail then
k
else
resettingSynthInstanceCache <|
withReader
(fun ctx => { ctx with localInstances := ctx.localInstances.push { className := className, fvar := fvar } })
k
withReader (fun ctx => { ctx with localInstances := ctx.localInstances.push { className := className, fvar := fvar } }) k
/-- Add entry `{ className := className, fvar := fvar }` to localInstances,
and then execute continuation `k`.
It resets the type class cache using `resettingSynthInstanceCache`. -/
and then execute continuation `k`. -/
def withNewLocalInstance (className : Name) (fvar : Expr) : n α → n α :=
mapMetaM <| withNewLocalInstanceImp className fvar
@ -919,9 +896,7 @@ mutual
using free variables `fvars[j] ... fvars.back`, and execute `k`.
- `isClassExpensive` is defined later.
- The type class chache is reset whenever a new local instance is found.
- `isClassExpensive` uses `whnf` which depends (indirectly) on the set of local instances.
Thus, each new local instance requires a new `resettingSynthInstanceCache`. -/
- `isClassExpensive` uses `whnf` which depends (indirectly) on the set of local instances. -/
private partial def withNewLocalInstancesImp
(fvars : Array Expr) (i : Nat) (k : MetaM α) : MetaM α := do
if h : i < fvars.size then
@ -1278,7 +1253,7 @@ def withLocalInstancesImp (decls : List LocalDecl) (k : MetaM α) : MetaM α :=
if localInsts.size == size then
k
else
resettingSynthInstanceCache <| withReader (fun ctx => { ctx with localInstances := localInsts }) k
withReader (fun ctx => { ctx with localInstances := localInsts }) k
/-- Register any local instance in `decls` -/
def withLocalInstances (decls : List LocalDecl) : n α → n α :=
@ -1322,12 +1297,8 @@ def withNewMCtxDepth (k : n α) (allowLevelAssignments := false) : n α :=
mapMetaM (withNewMCtxDepthImp allowLevelAssignments) k
private def withLocalContextImp (lctx : LocalContext) (localInsts : LocalInstances) (x : MetaM α) : MetaM α := do
let localInstsCurr ← getLocalInstances
withReader (fun ctx => { ctx with lctx := lctx, localInstances := localInsts }) do
if localInsts == localInstsCurr then
x
else
resettingSynthInstanceCache x
x
/--
`withLCtx lctx localInsts k` replaces the local context and local instances, and then executes `k`.

View file

@ -672,10 +672,11 @@ def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM (
withConfig (fun config => { config with isDefEqStuckEx := true, transparency := TransparencyMode.instances,
foApprox := true, ctxApprox := true, constApprox := false,
etaStruct }) do
let localInsts ← getLocalInstances
let type ← instantiateMVars type
let type ← preprocess type
let s ← get
match s.cache.synthInstance.find? type with
match s.cache.synthInstance.find? (localInsts, type) with
| some result =>
trace[Meta.synthInstance] "result {result} (cached)"
pure result
@ -728,7 +729,7 @@ def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM (
if type.hasMVar || resultHasUnivMVars then
pure result?
else do
modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert type result? }
modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert (localInsts, type) result? }
pure result?
/--

View file

@ -234,6 +234,9 @@ abbrev LocalInstances := Array LocalInstance
instance : BEq LocalInstance where
beq i₁ i₂ := i₁.fvar == i₂.fvar
instance : Hashable LocalInstance where
hash i := hash i.fvar
/-- Remove local instance with the given `fvarId`. Do nothing if `localInsts` does not contain any free variable with id `fvarId`. -/
def LocalInstances.erase (localInsts : LocalInstances) (fvarId : FVarId) : LocalInstances :=
match localInsts.findIdx? (fun inst => inst.fvar.fvarId! == fvarId) with