perf: ensure Expr.replaceExpr preserve DAG structure in Exprs (#4779)
This commit is contained in:
parent
4eb842560c
commit
a94805ff71
2 changed files with 40 additions and 50 deletions
|
|
@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
|||
prelude
|
||||
import Init.Data.Hashable
|
||||
import Lean.Data.HashSet
|
||||
import Lean.Data.HashMap
|
||||
|
||||
namespace Lean
|
||||
|
||||
|
|
@ -33,4 +34,22 @@ unsafe abbrev PtrSet.insert (s : PtrSet α) (a : α) : PtrSet α :=
|
|||
unsafe abbrev PtrSet.contains (s : PtrSet α) (a : α) : Bool :=
|
||||
HashSet.contains s { value := a }
|
||||
|
||||
/--
|
||||
Map of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
|
||||
-/
|
||||
unsafe def PtrMap (α : Type) (β : Type) :=
|
||||
HashMap (Ptr α) β
|
||||
|
||||
unsafe def mkPtrMap {α β : Type} (capacity : Nat := 64) : PtrMap α β :=
|
||||
mkHashMap capacity
|
||||
|
||||
unsafe abbrev PtrMap.insert (s : PtrMap α β) (a : α) (b : β) : PtrMap α β :=
|
||||
HashMap.insert s { value := a } b
|
||||
|
||||
unsafe abbrev PtrMap.contains (s : PtrMap α β) (a : α) : Bool :=
|
||||
HashMap.contains s { value := a }
|
||||
|
||||
unsafe abbrev PtrMap.find? (s : PtrMap α β) (a : α) : Option β :=
|
||||
HashMap.find? s { value := a }
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -5,74 +5,45 @@ Authors: Leonardo de Moura, Gabriel Ebner, Sebastian Ullrich
|
|||
-/
|
||||
prelude
|
||||
import Lean.Expr
|
||||
import Lean.Util.PtrSet
|
||||
|
||||
namespace Lean
|
||||
namespace Expr
|
||||
|
||||
namespace ReplaceImpl
|
||||
|
||||
structure Cache where
|
||||
size : USize
|
||||
-- First `size` elements are the keys.
|
||||
-- Second `size` elements are the results.
|
||||
keysResults : Array NonScalar -- Either Expr or Unit (disjoint memory representation)
|
||||
|
||||
unsafe def Cache.new (e : Expr) : Cache :=
|
||||
-- scale size with approximate number of subterms up to 8k
|
||||
-- make sure size is coprime with power of two for collision avoidance
|
||||
let size := (1 <<< min (max e.approxDepth.toUSize 1) 13) - 1
|
||||
{ size, keysResults := mkArray (2 * size).toNat (unsafeCast ()) }
|
||||
unsafe abbrev ReplaceM := StateM (PtrMap Expr Expr)
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.keyIdx (c : Cache) (key : Expr) : USize :=
|
||||
ptrAddrUnsafe key % c.size
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.resultIdx (c : Cache) (key : Expr) : USize :=
|
||||
c.keyIdx key + c.size
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.hasResultFor (c : Cache) (key : Expr) : Bool :=
|
||||
have : (c.keyIdx key).toNat < c.keysResults.size := lcProof
|
||||
ptrEq (unsafeCast key) c.keysResults[c.keyIdx key]
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.getResultFor (c : Cache) (key : Expr) : Expr :=
|
||||
have : (c.resultIdx key).toNat < c.keysResults.size := lcProof
|
||||
unsafeCast c.keysResults[c.resultIdx key]
|
||||
|
||||
unsafe def Cache.store (c : Cache) (key result : Expr) : Cache :=
|
||||
{ c with keysResults := c.keysResults
|
||||
|>.uset (c.keyIdx key) (unsafeCast key) lcProof
|
||||
|>.uset (c.resultIdx key) (unsafeCast result) lcProof }
|
||||
|
||||
abbrev ReplaceM := StateM Cache
|
||||
|
||||
@[inline]
|
||||
unsafe def cache (key : Expr) (result : Expr) : ReplaceM Expr := do
|
||||
modify (·.store key result)
|
||||
unsafe def cache (key : Expr) (exclusive : Bool) (result : Expr) : ReplaceM Expr := do
|
||||
unless exclusive do
|
||||
modify (·.insert key result)
|
||||
pure result
|
||||
|
||||
@[specialize]
|
||||
unsafe def replaceUnsafeM (f? : Expr → Option Expr) (e : Expr) : ReplaceM Expr := do
|
||||
let rec @[specialize] visit (e : Expr) := do
|
||||
if (← get).hasResultFor e then
|
||||
return (← get).getResultFor e
|
||||
else match f? e with
|
||||
| some eNew => cache e eNew
|
||||
-- TODO: We need better control over RC operations to ensure
|
||||
-- the following (unsafe) optimization is correctly applied.
|
||||
let excl := isExclusiveUnsafe e
|
||||
unless excl do
|
||||
if let some result := (← get).find? e then
|
||||
return result
|
||||
match f? e with
|
||||
| some eNew => cache e excl eNew
|
||||
| none => match e with
|
||||
| Expr.forallE _ d b _ => cache e <| e.updateForallE! (← visit d) (← visit b)
|
||||
| Expr.lam _ d b _ => cache e <| e.updateLambdaE! (← visit d) (← visit b)
|
||||
| Expr.mdata _ b => cache e <| e.updateMData! (← visit b)
|
||||
| Expr.letE _ t v b _ => cache e <| e.updateLet! (← visit t) (← visit v) (← visit b)
|
||||
| Expr.app f a => cache e <| e.updateApp! (← visit f) (← visit a)
|
||||
| Expr.proj _ _ b => cache e <| e.updateProj! (← visit b)
|
||||
| e => pure e
|
||||
| .forallE _ d b _ => cache e excl <| e.updateForallE! (← visit d) (← visit b)
|
||||
| .lam _ d b _ => cache e excl <| e.updateLambdaE! (← visit d) (← visit b)
|
||||
| .mdata _ b => cache e excl <| e.updateMData! (← visit b)
|
||||
| .letE _ t v b _ => cache e excl <| e.updateLet! (← visit t) (← visit v) (← visit b)
|
||||
| .app f a => cache e excl <| e.updateApp! (← visit f) (← visit a)
|
||||
| .proj _ _ b => cache e excl <| e.updateProj! (← visit b)
|
||||
| e => return e
|
||||
visit e
|
||||
|
||||
@[inline]
|
||||
unsafe def replaceUnsafe (f? : Expr → Option Expr) (e : Expr) : Expr :=
|
||||
(replaceUnsafeM f? e).run' (Cache.new e)
|
||||
(replaceUnsafeM f? e).run' mkPtrMap
|
||||
|
||||
end ReplaceImpl
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue