From 3d21124445b32e61af1ba313eed40b708117dfeb Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Tue, 14 Mar 2023 12:59:23 +0100 Subject: [PATCH] perf: scale `Expr.replace` cache with input size --- src/Lean/Util/ReplaceExpr.lean | 42 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/Lean/Util/ReplaceExpr.lean b/src/Lean/Util/ReplaceExpr.lean index 853fc717c0..5d39d08c4e 100644 --- a/src/Lean/Util/ReplaceExpr.lean +++ b/src/Lean/Util/ReplaceExpr.lean @@ -1,7 +1,7 @@ /- Copyright (c) 2020 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Leonardo de Moura +Authors: Leonardo de Moura, Gabriel Ebner, Sebastian Ullrich -/ import Lean.Expr @@ -10,40 +10,40 @@ namespace Expr namespace ReplaceImpl -@[inline] def cacheSize : USize := 8192 - 1 - structure Cache where - -- First cacheSize elements are the keys. - -- Second cacheSize elements are the results. + size : USize + -- First `size` elements are the keys. + -- Second `size` elements are the results. keysResults : Array NonScalar -- Either Expr or Unit (disjoint memory representation) -set_option compiler.extract_closed false in -unsafe def Cache.new : Cache := - { keysResults := mkArray (2 * cacheSize).toNat (unsafeCast ()) } +unsafe def Cache.new (e : Expr) : Cache := + -- scale size with approximate number of subterms up to 8k + -- make sure size is coprime with power of two for collision avoidance + let size := (1 <<< min (max e.approxDepth.toUSize 1) 13) - 1 + { size, keysResults := mkArray (2 * size).toNat (unsafeCast ()) } @[inline] -unsafe def Cache.keyIdx (key : Expr) : USize := - ptrAddrUnsafe key % cacheSize +unsafe def Cache.keyIdx (c : Cache) (key : Expr) : USize := + ptrAddrUnsafe key % c.size @[inline] -unsafe def Cache.resultIdx (key : Expr) : USize := - keyIdx key + cacheSize +unsafe def Cache.resultIdx (c : Cache) (key : Expr) : USize := + c.keyIdx key + c.size @[inline] unsafe def Cache.hasResultFor (c : Cache) (key : Expr) : Bool := - have : (keyIdx key).toNat < c.keysResults.size := lcProof - ptrEq (unsafeCast key) c.keysResults[keyIdx key] + have : (c.keyIdx key).toNat < c.keysResults.size := lcProof + ptrEq (unsafeCast key) c.keysResults[c.keyIdx key] @[inline] unsafe def Cache.getResultFor (c : Cache) (key : Expr) : Expr := - have : (resultIdx key).toNat < c.keysResults.size := lcProof - unsafeCast c.keysResults[resultIdx key] + have : (c.resultIdx key).toNat < c.keysResults.size := lcProof + unsafeCast c.keysResults[c.resultIdx key] -@[inline] unsafe def Cache.store (c : Cache) (key result : Expr) : Cache := - Cache.mk <| c.keysResults - |>.uset (keyIdx key) (unsafeCast key) lcProof - |>.uset (resultIdx key) (unsafeCast result) lcProof + { c with keysResults := c.keysResults + |>.uset (c.keyIdx key) (unsafeCast key) lcProof + |>.uset (c.resultIdx key) (unsafeCast result) lcProof } abbrev ReplaceM := StateM Cache @@ -71,7 +71,7 @@ unsafe def replaceUnsafeM (f? : Expr → Option Expr) (e : Expr) : ReplaceM Expr @[inline] unsafe def replaceUnsafe (f? : Expr → Option Expr) (e : Expr) : Expr := - (replaceUnsafeM f? e).run' Cache.new + (replaceUnsafeM f? e).run' (Cache.new e) end ReplaceImpl