perf: do not reset tc cache when adding local instances
This commit is contained in:
parent
83c1a1ab77
commit
d3c55ef249
4 changed files with 15 additions and 41 deletions
|
|
@ -421,9 +421,8 @@ private partial def elabFunBinderViews (binderViews : Array BinderView) (i : Nat
|
|||
let s := { s with lctx }
|
||||
match ← isClass? type, kind with
|
||||
| some className, .default =>
|
||||
resettingSynthInstanceCache do
|
||||
let localInsts := s.localInsts.push { className, fvar := mkFVar fvarId }
|
||||
elabFunBinderViews binderViews (i+1) { s with localInsts }
|
||||
let localInsts := s.localInsts.push { className, fvar := mkFVar fvarId }
|
||||
elabFunBinderViews binderViews (i+1) { s with localInsts }
|
||||
| _, _ => elabFunBinderViews binderViews (i+1) s
|
||||
else
|
||||
pure s
|
||||
|
|
@ -445,7 +444,7 @@ def elabFunBinders (binders : Array Syntax) (expectedType? : Option Expr) (x : A
|
|||
let lctx ← getLCtx
|
||||
let localInsts ← getLocalInstances
|
||||
let s ← FunBinders.elabFunBindersAux binders 0 { lctx, localInsts, expectedType? }
|
||||
resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) <| withLCtx s.lctx s.localInsts <|
|
||||
withLCtx s.lctx s.localInsts do
|
||||
x s.fvars s.expectedType?
|
||||
|
||||
def expandWhereDecls (whereDecls : Syntax) (body : Syntax) : MacroM Syntax :=
|
||||
|
|
|
|||
|
|
@ -195,7 +195,7 @@ instance : Hashable InfoCacheKey :=
|
|||
⟨fun ⟨transparency, expr, nargs⟩ => mixHash (hash transparency) <| mixHash (hash expr) (hash nargs)⟩
|
||||
end InfoCacheKey
|
||||
|
||||
abbrev SynthInstanceCache := PersistentHashMap Expr (Option Expr)
|
||||
abbrev SynthInstanceCache := PersistentHashMap (LocalInstances × Expr) (Option Expr)
|
||||
|
||||
abbrev InferTypeCache := PersistentExprStructMap Expr
|
||||
abbrev FunInfoCache := PersistentHashMap InfoCacheKey FunInfo
|
||||
|
|
@ -873,38 +873,15 @@ private partial def isClassQuick? : Expr → MetaM (LOption Name)
|
|||
| .lam .. => return .undef
|
||||
| _ => return .none
|
||||
|
||||
def saveAndResetSynthInstanceCache : MetaM SynthInstanceCache := do
|
||||
let savedSythInstance := (← get).cache.synthInstance
|
||||
modifyCache fun c => { c with synthInstance := {} }
|
||||
return savedSythInstance
|
||||
|
||||
def restoreSynthInstanceCache (cache : SynthInstanceCache) : MetaM Unit :=
|
||||
modifyCache fun c => { c with synthInstance := cache }
|
||||
|
||||
@[inline] private def resettingSynthInstanceCacheImpl (x : MetaM α) : MetaM α := do
|
||||
let savedSythInstance ← saveAndResetSynthInstanceCache
|
||||
try x finally restoreSynthInstanceCache savedSythInstance
|
||||
|
||||
/-- Reset `synthInstance` cache, execute `x`, and restore cache -/
|
||||
@[inline] def resettingSynthInstanceCache : n α → n α :=
|
||||
mapMetaM resettingSynthInstanceCacheImpl
|
||||
|
||||
@[inline] def resettingSynthInstanceCacheWhen (b : Bool) (x : n α) : n α :=
|
||||
if b then resettingSynthInstanceCache x else x
|
||||
|
||||
private def withNewLocalInstanceImp (className : Name) (fvar : Expr) (k : MetaM α) : MetaM α := do
|
||||
let localDecl ← getFVarLocalDecl fvar
|
||||
if localDecl.isImplementationDetail then
|
||||
k
|
||||
else
|
||||
resettingSynthInstanceCache <|
|
||||
withReader
|
||||
(fun ctx => { ctx with localInstances := ctx.localInstances.push { className := className, fvar := fvar } })
|
||||
k
|
||||
withReader (fun ctx => { ctx with localInstances := ctx.localInstances.push { className := className, fvar := fvar } }) k
|
||||
|
||||
/-- Add entry `{ className := className, fvar := fvar }` to localInstances,
|
||||
and then execute continuation `k`.
|
||||
It resets the type class cache using `resettingSynthInstanceCache`. -/
|
||||
and then execute continuation `k`. -/
|
||||
def withNewLocalInstance (className : Name) (fvar : Expr) : n α → n α :=
|
||||
mapMetaM <| withNewLocalInstanceImp className fvar
|
||||
|
||||
|
|
@ -919,9 +896,7 @@ mutual
|
|||
using free variables `fvars[j] ... fvars.back`, and execute `k`.
|
||||
|
||||
- `isClassExpensive` is defined later.
|
||||
- The type class chache is reset whenever a new local instance is found.
|
||||
- `isClassExpensive` uses `whnf` which depends (indirectly) on the set of local instances.
|
||||
Thus, each new local instance requires a new `resettingSynthInstanceCache`. -/
|
||||
- `isClassExpensive` uses `whnf` which depends (indirectly) on the set of local instances. -/
|
||||
private partial def withNewLocalInstancesImp
|
||||
(fvars : Array Expr) (i : Nat) (k : MetaM α) : MetaM α := do
|
||||
if h : i < fvars.size then
|
||||
|
|
@ -1278,7 +1253,7 @@ def withLocalInstancesImp (decls : List LocalDecl) (k : MetaM α) : MetaM α :=
|
|||
if localInsts.size == size then
|
||||
k
|
||||
else
|
||||
resettingSynthInstanceCache <| withReader (fun ctx => { ctx with localInstances := localInsts }) k
|
||||
withReader (fun ctx => { ctx with localInstances := localInsts }) k
|
||||
|
||||
/-- Register any local instance in `decls` -/
|
||||
def withLocalInstances (decls : List LocalDecl) : n α → n α :=
|
||||
|
|
@ -1322,12 +1297,8 @@ def withNewMCtxDepth (k : n α) (allowLevelAssignments := false) : n α :=
|
|||
mapMetaM (withNewMCtxDepthImp allowLevelAssignments) k
|
||||
|
||||
private def withLocalContextImp (lctx : LocalContext) (localInsts : LocalInstances) (x : MetaM α) : MetaM α := do
|
||||
let localInstsCurr ← getLocalInstances
|
||||
withReader (fun ctx => { ctx with lctx := lctx, localInstances := localInsts }) do
|
||||
if localInsts == localInstsCurr then
|
||||
x
|
||||
else
|
||||
resettingSynthInstanceCache x
|
||||
x
|
||||
|
||||
/--
|
||||
`withLCtx lctx localInsts k` replaces the local context and local instances, and then executes `k`.
|
||||
|
|
|
|||
|
|
@ -672,10 +672,11 @@ def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM (
|
|||
withConfig (fun config => { config with isDefEqStuckEx := true, transparency := TransparencyMode.instances,
|
||||
foApprox := true, ctxApprox := true, constApprox := false,
|
||||
etaStruct }) do
|
||||
let localInsts ← getLocalInstances
|
||||
let type ← instantiateMVars type
|
||||
let type ← preprocess type
|
||||
let s ← get
|
||||
match s.cache.synthInstance.find? type with
|
||||
match s.cache.synthInstance.find? (localInsts, type) with
|
||||
| some result =>
|
||||
trace[Meta.synthInstance] "result {result} (cached)"
|
||||
pure result
|
||||
|
|
@ -728,7 +729,7 @@ def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM (
|
|||
if type.hasMVar || resultHasUnivMVars then
|
||||
pure result?
|
||||
else do
|
||||
modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert type result? }
|
||||
modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert (localInsts, type) result? }
|
||||
pure result?
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -234,6 +234,9 @@ abbrev LocalInstances := Array LocalInstance
|
|||
instance : BEq LocalInstance where
|
||||
beq i₁ i₂ := i₁.fvar == i₂.fvar
|
||||
|
||||
instance : Hashable LocalInstance where
|
||||
hash i := hash i.fvar
|
||||
|
||||
/-- Remove local instance with the given `fvarId`. Do nothing if `localInsts` does not contain any free variable with id `fvarId`. -/
|
||||
def LocalInstances.erase (localInsts : LocalInstances) (fvarId : FVarId) : LocalInstances :=
|
||||
match localInsts.findIdx? (fun inst => inst.fvar.fvarId! == fvarId) with
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue