diff --git a/src/Lean/Util/ReplaceExpr.lean b/src/Lean/Util/ReplaceExpr.lean index bd2ba7a7e1..2c0d27742f 100644 --- a/src/Lean/Util/ReplaceExpr.lean +++ b/src/Lean/Util/ReplaceExpr.lean @@ -10,43 +10,68 @@ namespace Expr namespace ReplaceImpl -abbrev cacheSize : USize := 8192 +@[inline] def cacheSize : USize := 8192 -structure State where - keys : Array Expr -- Remark: our "unsafe" implementation relies on the fact that `()` is not a valid Expr - results : Array Expr +structure Cache where + -- First cacheSize elements are the keys. + -- Second cacheSize elements are the results. + keysResults : Array NonScalar -- Either Expr or Unit (disjoint memory representation) -abbrev ReplaceM := StateM State +set_option compiler.extract_closed false in +unsafe def Cache.new : Cache := + { keysResults := mkArray (2 * cacheSize).toNat (unsafeCast ()) } -unsafe def cache (i : USize) (key : Expr) (result : Expr) : ReplaceM Expr := do - modify fun ⟨keys, results⟩ => { keys := keys.uset i key lcProof, results := results.uset i result lcProof }; +@[inline] +unsafe def Cache.keyIdx (key : Expr) : USize := + (ptrAddrUnsafe key >>> 4) % cacheSize + +@[inline] +unsafe def Cache.resultIdx (key : Expr) : USize := + keyIdx key + cacheSize + +@[inline] +unsafe def Cache.hasResultFor (c : Cache) (key : Expr) : Bool := + have : (keyIdx key).toNat < c.keysResults.size := lcProof + ptrEq (unsafeCast key) c.keysResults[keyIdx key] + +@[inline] +unsafe def Cache.getResultFor (c : Cache) (key : Expr) : Expr := + have : (resultIdx key).toNat < c.keysResults.size := lcProof + unsafeCast c.keysResults[resultIdx key] + +@[inline] +unsafe def Cache.store (c : Cache) (key result : Expr) : Cache := + Cache.mk <| c.keysResults + |>.uset (keyIdx key) (unsafeCast key) lcProof + |>.uset (resultIdx key) (unsafeCast result) lcProof + +abbrev ReplaceM := StateM Cache + +@[inline] +unsafe def cache (key : Expr) (result : Expr) : ReplaceM Expr := do + modify (·.store key result) pure result -unsafe def replaceUnsafeM (f? : Expr → Option Expr) (size : USize) (e : Expr) : ReplaceM Expr := do - let rec visit (e : Expr) := do - let c ← get - let h := ptrAddrUnsafe e - let i := h % size - if ptrAddrUnsafe (c.keys.uget i lcProof) == h then - pure <| c.results.uget i lcProof +@[specialize] +unsafe def replaceUnsafeM (f? : Expr → Option Expr) (e : Expr) : ReplaceM Expr := do + let rec @[specialize] visit (e : Expr) := do + if (← get).hasResultFor e then + return (← get).getResultFor e else match f? e with - | some eNew => cache i e eNew + | some eNew => cache e eNew | none => match e with - | Expr.forallE _ d b _ => cache i e <| e.updateForallE! (← visit d) (← visit b) - | Expr.lam _ d b _ => cache i e <| e.updateLambdaE! (← visit d) (← visit b) - | Expr.mdata _ b => cache i e <| e.updateMData! (← visit b) - | Expr.letE _ t v b _ => cache i e <| e.updateLet! (← visit t) (← visit v) (← visit b) - | Expr.app f a => cache i e <| e.updateApp! (← visit f) (← visit a) - | Expr.proj _ _ b => cache i e <| e.updateProj! (← visit b) + | Expr.forallE _ d b _ => cache e <| e.updateForallE! (← visit d) (← visit b) + | Expr.lam _ d b _ => cache e <| e.updateLambdaE! (← visit d) (← visit b) + | Expr.mdata _ b => cache e <| e.updateMData! (← visit b) + | Expr.letE _ t v b _ => cache e <| e.updateLet! (← visit t) (← visit v) (← visit b) + | Expr.app f a => cache e <| e.updateApp! (← visit f) (← visit a) + | Expr.proj _ _ b => cache e <| e.updateProj! (← visit b) | e => pure e visit e -unsafe def initCache : State := - { keys := mkArray cacheSize.toNat (cast lcProof ()), -- `()` is not a valid `Expr` - results := mkArray cacheSize.toNat default } - +@[inline] unsafe def replaceUnsafe (f? : Expr → Option Expr) (e : Expr) : Expr := - (replaceUnsafeM f? cacheSize e).run' initCache + (replaceUnsafeM f? e).run' Cache.new end ReplaceImpl