perf: scale Expr.replace cache with input size

This commit is contained in:
Sebastian Ullrich 2023-03-14 12:59:23 +01:00
parent 96aa021007
commit 3d21124445

View file

@ -1,7 +1,7 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
Authors: Leonardo de Moura, Gabriel Ebner, Sebastian Ullrich
-/
import Lean.Expr
@ -10,40 +10,40 @@ namespace Expr
namespace ReplaceImpl
@[inline] def cacheSize : USize := 8192 - 1
structure Cache where
-- First cacheSize elements are the keys.
-- Second cacheSize elements are the results.
size : USize
-- First `size` elements are the keys.
-- Second `size` elements are the results.
keysResults : Array NonScalar -- Either Expr or Unit (disjoint memory representation)
set_option compiler.extract_closed false in
unsafe def Cache.new : Cache :=
{ keysResults := mkArray (2 * cacheSize).toNat (unsafeCast ()) }
unsafe def Cache.new (e : Expr) : Cache :=
-- scale size with approximate number of subterms up to 8k
-- make sure size is coprime with power of two for collision avoidance
let size := (1 <<< min (max e.approxDepth.toUSize 1) 13) - 1
{ size, keysResults := mkArray (2 * size).toNat (unsafeCast ()) }
@[inline]
unsafe def Cache.keyIdx (key : Expr) : USize :=
ptrAddrUnsafe key % cacheSize
unsafe def Cache.keyIdx (c : Cache) (key : Expr) : USize :=
ptrAddrUnsafe key % c.size
@[inline]
unsafe def Cache.resultIdx (key : Expr) : USize :=
keyIdx key + cacheSize
unsafe def Cache.resultIdx (c : Cache) (key : Expr) : USize :=
c.keyIdx key + c.size
@[inline]
unsafe def Cache.hasResultFor (c : Cache) (key : Expr) : Bool :=
have : (keyIdx key).toNat < c.keysResults.size := lcProof
ptrEq (unsafeCast key) c.keysResults[keyIdx key]
have : (c.keyIdx key).toNat < c.keysResults.size := lcProof
ptrEq (unsafeCast key) c.keysResults[c.keyIdx key]
@[inline]
unsafe def Cache.getResultFor (c : Cache) (key : Expr) : Expr :=
have : (resultIdx key).toNat < c.keysResults.size := lcProof
unsafeCast c.keysResults[resultIdx key]
have : (c.resultIdx key).toNat < c.keysResults.size := lcProof
unsafeCast c.keysResults[c.resultIdx key]
@[inline]
unsafe def Cache.store (c : Cache) (key result : Expr) : Cache :=
Cache.mk <| c.keysResults
|>.uset (keyIdx key) (unsafeCast key) lcProof
|>.uset (resultIdx key) (unsafeCast result) lcProof
{ c with keysResults := c.keysResults
|>.uset (c.keyIdx key) (unsafeCast key) lcProof
|>.uset (c.resultIdx key) (unsafeCast result) lcProof }
abbrev ReplaceM := StateM Cache
@ -71,7 +71,7 @@ unsafe def replaceUnsafeM (f? : Expr → Option Expr) (e : Expr) : ReplaceM Expr
@[inline]
unsafe def replaceUnsafe (f? : Expr → Option Expr) (e : Expr) : Expr :=
(replaceUnsafeM f? e).run' Cache.new
(replaceUnsafeM f? e).run' (Cache.new e)
end ReplaceImpl