From a94805ff71f7576ee014aafc76b29e29b066bd80 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 18 Jul 2024 04:24:15 +0200 Subject: [PATCH] perf: ensure `Expr.replaceExpr` preserves DAG structure in `Expr`s (#4779) --- src/Lean/Util/PtrSet.lean | 19 +++++++++ src/Lean/Util/ReplaceExpr.lean | 71 ++++++++++------------------------ 2 files changed, 40 insertions(+), 50 deletions(-) diff --git a/src/Lean/Util/PtrSet.lean b/src/Lean/Util/PtrSet.lean index 05b33c0d36..33908d4bd2 100644 --- a/src/Lean/Util/PtrSet.lean +++ b/src/Lean/Util/PtrSet.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura prelude import Init.Data.Hashable import Lean.Data.HashSet +import Lean.Data.HashMap namespace Lean @@ -33,4 +34,22 @@ unsafe abbrev PtrSet.insert (s : PtrSet α) (a : α) : PtrSet α := unsafe abbrev PtrSet.contains (s : PtrSet α) (a : α) : Bool := HashSet.contains s { value := a } +/-- +Map of pointers. It is a low-level auxiliary datastructure used for traversing DAGs. +-/ +unsafe def PtrMap (α : Type) (β : Type) := + HashMap (Ptr α) β + +unsafe def mkPtrMap {α β : Type} (capacity : Nat := 64) : PtrMap α β := + mkHashMap capacity + +unsafe abbrev PtrMap.insert (s : PtrMap α β) (a : α) (b : β) : PtrMap α β := + HashMap.insert s { value := a } b + +unsafe abbrev PtrMap.contains (s : PtrMap α β) (a : α) : Bool := + HashMap.contains s { value := a } + +unsafe abbrev PtrMap.find? (s : PtrMap α β) (a : α) : Option β := + HashMap.find? s { value := a } + end Lean diff --git a/src/Lean/Util/ReplaceExpr.lean b/src/Lean/Util/ReplaceExpr.lean index d0ff29b0fc..7354d49082 100644 --- a/src/Lean/Util/ReplaceExpr.lean +++ b/src/Lean/Util/ReplaceExpr.lean @@ -5,74 +5,45 @@ Authors: Leonardo de Moura, Gabriel Ebner, Sebastian Ullrich -/ prelude import Lean.Expr +import Lean.Util.PtrSet namespace Lean namespace Expr namespace ReplaceImpl -structure Cache where - size : USize - -- First `size` elements are the keys. - -- Second `size` elements are the results. 
- keysResults : Array NonScalar -- Either Expr or Unit (disjoint memory representation) - -unsafe def Cache.new (e : Expr) : Cache := - -- scale size with approximate number of subterms up to 8k - -- make sure size is coprime with power of two for collision avoidance - let size := (1 <<< min (max e.approxDepth.toUSize 1) 13) - 1 - { size, keysResults := mkArray (2 * size).toNat (unsafeCast ()) } +unsafe abbrev ReplaceM := StateM (PtrMap Expr Expr) @[inline] -unsafe def Cache.keyIdx (c : Cache) (key : Expr) : USize := - ptrAddrUnsafe key % c.size - -@[inline] -unsafe def Cache.resultIdx (c : Cache) (key : Expr) : USize := - c.keyIdx key + c.size - -@[inline] -unsafe def Cache.hasResultFor (c : Cache) (key : Expr) : Bool := - have : (c.keyIdx key).toNat < c.keysResults.size := lcProof - ptrEq (unsafeCast key) c.keysResults[c.keyIdx key] - -@[inline] -unsafe def Cache.getResultFor (c : Cache) (key : Expr) : Expr := - have : (c.resultIdx key).toNat < c.keysResults.size := lcProof - unsafeCast c.keysResults[c.resultIdx key] - -unsafe def Cache.store (c : Cache) (key result : Expr) : Cache := - { c with keysResults := c.keysResults - |>.uset (c.keyIdx key) (unsafeCast key) lcProof - |>.uset (c.resultIdx key) (unsafeCast result) lcProof } - -abbrev ReplaceM := StateM Cache - -@[inline] -unsafe def cache (key : Expr) (result : Expr) : ReplaceM Expr := do - modify (·.store key result) +unsafe def cache (key : Expr) (exclusive : Bool) (result : Expr) : ReplaceM Expr := do + unless exclusive do + modify (·.insert key result) pure result @[specialize] unsafe def replaceUnsafeM (f? : Expr → Option Expr) (e : Expr) : ReplaceM Expr := do let rec @[specialize] visit (e : Expr) := do - if (← get).hasResultFor e then - return (← get).getResultFor e - else match f? e with - | some eNew => cache e eNew + -- TODO: We need better control over RC operations to ensure + -- the following (unsafe) optimization is correctly applied. 
+ let excl := isExclusiveUnsafe e + unless excl do + if let some result := (← get).find? e then + return result + match f? e with + | some eNew => cache e excl eNew | none => match e with - | Expr.forallE _ d b _ => cache e <| e.updateForallE! (← visit d) (← visit b) - | Expr.lam _ d b _ => cache e <| e.updateLambdaE! (← visit d) (← visit b) - | Expr.mdata _ b => cache e <| e.updateMData! (← visit b) - | Expr.letE _ t v b _ => cache e <| e.updateLet! (← visit t) (← visit v) (← visit b) - | Expr.app f a => cache e <| e.updateApp! (← visit f) (← visit a) - | Expr.proj _ _ b => cache e <| e.updateProj! (← visit b) - | e => pure e + | .forallE _ d b _ => cache e excl <| e.updateForallE! (← visit d) (← visit b) + | .lam _ d b _ => cache e excl <| e.updateLambdaE! (← visit d) (← visit b) + | .mdata _ b => cache e excl <| e.updateMData! (← visit b) + | .letE _ t v b _ => cache e excl <| e.updateLet! (← visit t) (← visit v) (← visit b) + | .app f a => cache e excl <| e.updateApp! (← visit f) (← visit a) + | .proj _ _ b => cache e excl <| e.updateProj! (← visit b) + | e => return e visit e @[inline] unsafe def replaceUnsafe (f? : Expr → Option Expr) (e : Expr) : Expr := - (replaceUnsafeM f? e).run' (Cache.new e) + (replaceUnsafeM f? e).run' mkPtrMap end ReplaceImpl