perf: speed up Expr.replace
This commit is contained in:
parent
dcc97c9bbe
commit
d87c36157a
1 changed files with 51 additions and 26 deletions
|
|
@ -10,43 +10,68 @@ namespace Expr
|
|||
|
||||
namespace ReplaceImpl
|
||||
|
||||
abbrev cacheSize : USize := 8192
|
||||
@[inline] def cacheSize : USize := 8192
|
||||
|
||||
structure State where
|
||||
keys : Array Expr -- Remark: our "unsafe" implementation relies on the fact that `()` is not a valid Expr
|
||||
results : Array Expr
|
||||
structure Cache where
|
||||
-- First cacheSize elements are the keys.
|
||||
-- Second cacheSize elements are the results.
|
||||
keysResults : Array NonScalar -- Either Expr or Unit (disjoint memory representation)
|
||||
|
||||
abbrev ReplaceM := StateM State
|
||||
set_option compiler.extract_closed false in
|
||||
unsafe def Cache.new : Cache :=
|
||||
{ keysResults := mkArray (2 * cacheSize).toNat (unsafeCast ()) }
|
||||
|
||||
unsafe def cache (i : USize) (key : Expr) (result : Expr) : ReplaceM Expr := do
|
||||
modify fun ⟨keys, results⟩ => { keys := keys.uset i key lcProof, results := results.uset i result lcProof };
|
||||
@[inline]
|
||||
unsafe def Cache.keyIdx (key : Expr) : USize :=
|
||||
(ptrAddrUnsafe key >>> 4) % cacheSize
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.resultIdx (key : Expr) : USize :=
|
||||
keyIdx key + cacheSize
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.hasResultFor (c : Cache) (key : Expr) : Bool :=
|
||||
have : (keyIdx key).toNat < c.keysResults.size := lcProof
|
||||
ptrEq (unsafeCast key) c.keysResults[keyIdx key]
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.getResultFor (c : Cache) (key : Expr) : Expr :=
|
||||
have : (resultIdx key).toNat < c.keysResults.size := lcProof
|
||||
unsafeCast c.keysResults[resultIdx key]
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.store (c : Cache) (key result : Expr) : Cache :=
|
||||
Cache.mk <| c.keysResults
|
||||
|>.uset (keyIdx key) (unsafeCast key) lcProof
|
||||
|>.uset (resultIdx key) (unsafeCast result) lcProof
|
||||
|
||||
abbrev ReplaceM := StateM Cache
|
||||
|
||||
@[inline]
|
||||
unsafe def cache (key : Expr) (result : Expr) : ReplaceM Expr := do
|
||||
modify (·.store key result)
|
||||
pure result
|
||||
|
||||
unsafe def replaceUnsafeM (f? : Expr → Option Expr) (size : USize) (e : Expr) : ReplaceM Expr := do
|
||||
let rec visit (e : Expr) := do
|
||||
let c ← get
|
||||
let h := ptrAddrUnsafe e
|
||||
let i := h % size
|
||||
if ptrAddrUnsafe (c.keys.uget i lcProof) == h then
|
||||
pure <| c.results.uget i lcProof
|
||||
@[specialize]
|
||||
unsafe def replaceUnsafeM (f? : Expr → Option Expr) (e : Expr) : ReplaceM Expr := do
|
||||
let rec @[specialize] visit (e : Expr) := do
|
||||
if (← get).hasResultFor e then
|
||||
return (← get).getResultFor e
|
||||
else match f? e with
|
||||
| some eNew => cache i e eNew
|
||||
| some eNew => cache e eNew
|
||||
| none => match e with
|
||||
| Expr.forallE _ d b _ => cache i e <| e.updateForallE! (← visit d) (← visit b)
|
||||
| Expr.lam _ d b _ => cache i e <| e.updateLambdaE! (← visit d) (← visit b)
|
||||
| Expr.mdata _ b => cache i e <| e.updateMData! (← visit b)
|
||||
| Expr.letE _ t v b _ => cache i e <| e.updateLet! (← visit t) (← visit v) (← visit b)
|
||||
| Expr.app f a => cache i e <| e.updateApp! (← visit f) (← visit a)
|
||||
| Expr.proj _ _ b => cache i e <| e.updateProj! (← visit b)
|
||||
| Expr.forallE _ d b _ => cache e <| e.updateForallE! (← visit d) (← visit b)
|
||||
| Expr.lam _ d b _ => cache e <| e.updateLambdaE! (← visit d) (← visit b)
|
||||
| Expr.mdata _ b => cache e <| e.updateMData! (← visit b)
|
||||
| Expr.letE _ t v b _ => cache e <| e.updateLet! (← visit t) (← visit v) (← visit b)
|
||||
| Expr.app f a => cache e <| e.updateApp! (← visit f) (← visit a)
|
||||
| Expr.proj _ _ b => cache e <| e.updateProj! (← visit b)
|
||||
| e => pure e
|
||||
visit e
|
||||
|
||||
unsafe def initCache : State :=
|
||||
{ keys := mkArray cacheSize.toNat (cast lcProof ()), -- `()` is not a valid `Expr`
|
||||
results := mkArray cacheSize.toNat default }
|
||||
|
||||
@[inline]
|
||||
unsafe def replaceUnsafe (f? : Expr → Option Expr) (e : Expr) : Expr :=
|
||||
(replaceUnsafeM f? cacheSize e).run' initCache
|
||||
(replaceUnsafeM f? e).run' Cache.new
|
||||
|
||||
end ReplaceImpl
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue