perf: speed up Expr.replace

This commit is contained in:
Gabriel Ebner 2022-10-24 11:08:46 -07:00
parent dcc97c9bbe
commit d87c36157a

View file

@ -10,43 +10,68 @@ namespace Expr
namespace ReplaceImpl
abbrev cacheSize : USize := 8192
@[inline] def cacheSize : USize := 8192
structure State where
keys : Array Expr -- Remark: our "unsafe" implementation relies on the fact that `()` is not a valid Expr
results : Array Expr
structure Cache where
-- First cacheSize elements are the keys.
-- Second cacheSize elements are the results.
keysResults : Array NonScalar -- Either Expr or Unit (disjoint memory representation)
abbrev ReplaceM := StateM State
set_option compiler.extract_closed false in
unsafe def Cache.new : Cache :=
{ keysResults := mkArray (2 * cacheSize).toNat (unsafeCast ()) }
unsafe def cache (i : USize) (key : Expr) (result : Expr) : ReplaceM Expr := do
modify fun ⟨keys, results⟩ => { keys := keys.uset i key lcProof, results := results.uset i result lcProof };
@[inline]
unsafe def Cache.keyIdx (key : Expr) : USize :=
(ptrAddrUnsafe key >>> 4) % cacheSize
@[inline]
unsafe def Cache.resultIdx (key : Expr) : USize :=
keyIdx key + cacheSize
@[inline]
unsafe def Cache.hasResultFor (c : Cache) (key : Expr) : Bool :=
have : (keyIdx key).toNat < c.keysResults.size := lcProof
ptrEq (unsafeCast key) c.keysResults[keyIdx key]
@[inline]
unsafe def Cache.getResultFor (c : Cache) (key : Expr) : Expr :=
have : (resultIdx key).toNat < c.keysResults.size := lcProof
unsafeCast c.keysResults[resultIdx key]
@[inline]
unsafe def Cache.store (c : Cache) (key result : Expr) : Cache :=
Cache.mk <| c.keysResults
|>.uset (keyIdx key) (unsafeCast key) lcProof
|>.uset (resultIdx key) (unsafeCast result) lcProof
abbrev ReplaceM := StateM Cache
@[inline]
unsafe def cache (key : Expr) (result : Expr) : ReplaceM Expr := do
modify (·.store key result)
pure result
unsafe def replaceUnsafeM (f? : Expr → Option Expr) (size : USize) (e : Expr) : ReplaceM Expr := do
let rec visit (e : Expr) := do
let c ← get
let h := ptrAddrUnsafe e
let i := h % size
if ptrAddrUnsafe (c.keys.uget i lcProof) == h then
pure <| c.results.uget i lcProof
@[specialize]
unsafe def replaceUnsafeM (f? : Expr → Option Expr) (e : Expr) : ReplaceM Expr := do
let rec @[specialize] visit (e : Expr) := do
if (← get).hasResultFor e then
return (← get).getResultFor e
else match f? e with
| some eNew => cache i e eNew
| some eNew => cache e eNew
| none => match e with
| Expr.forallE _ d b _ => cache i e <| e.updateForallE! (← visit d) (← visit b)
| Expr.lam _ d b _ => cache i e <| e.updateLambdaE! (← visit d) (← visit b)
| Expr.mdata _ b => cache i e <| e.updateMData! (← visit b)
| Expr.letE _ t v b _ => cache i e <| e.updateLet! (← visit t) (← visit v) (← visit b)
| Expr.app f a => cache i e <| e.updateApp! (← visit f) (← visit a)
| Expr.proj _ _ b => cache i e <| e.updateProj! (← visit b)
| Expr.forallE _ d b _ => cache e <| e.updateForallE! (← visit d) (← visit b)
| Expr.lam _ d b _ => cache e <| e.updateLambdaE! (← visit d) (← visit b)
| Expr.mdata _ b => cache e <| e.updateMData! (← visit b)
| Expr.letE _ t v b _ => cache e <| e.updateLet! (← visit t) (← visit v) (← visit b)
| Expr.app f a => cache e <| e.updateApp! (← visit f) (← visit a)
| Expr.proj _ _ b => cache e <| e.updateProj! (← visit b)
| e => pure e
visit e
unsafe def initCache : State :=
{ keys := mkArray cacheSize.toNat (cast lcProof ()), -- `()` is not a valid `Expr`
results := mkArray cacheSize.toNat default }
@[inline]
unsafe def replaceUnsafe (f? : Expr → Option Expr) (e : Expr) : Expr :=
(replaceUnsafeM f? cacheSize e).run' initCache
(replaceUnsafeM f? e).run' Cache.new
end ReplaceImpl