perf: scale Expr.replace cache with input size
This commit is contained in:
parent
96aa021007
commit
3d21124445
1 changed files with 21 additions and 21 deletions
|
|
@ -1,7 +1,7 @@
|
|||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
Authors: Leonardo de Moura, Gabriel Ebner, Sebastian Ullrich
|
||||
-/
|
||||
import Lean.Expr
|
||||
|
||||
|
|
@ -10,40 +10,40 @@ namespace Expr
|
|||
|
||||
namespace ReplaceImpl
|
||||
|
||||
@[inline] def cacheSize : USize := 8192 - 1
|
||||
|
||||
structure Cache where
|
||||
-- First cacheSize elements are the keys.
|
||||
-- Second cacheSize elements are the results.
|
||||
size : USize
|
||||
-- First `size` elements are the keys.
|
||||
-- Second `size` elements are the results.
|
||||
keysResults : Array NonScalar -- Either Expr or Unit (disjoint memory representation)
|
||||
|
||||
set_option compiler.extract_closed false in
|
||||
unsafe def Cache.new : Cache :=
|
||||
{ keysResults := mkArray (2 * cacheSize).toNat (unsafeCast ()) }
|
||||
unsafe def Cache.new (e : Expr) : Cache :=
|
||||
-- scale size with approximate number of subterms up to 8k
|
||||
-- make sure size is coprime with power of two for collision avoidance
|
||||
let size := (1 <<< min (max e.approxDepth.toUSize 1) 13) - 1
|
||||
{ size, keysResults := mkArray (2 * size).toNat (unsafeCast ()) }
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.keyIdx (key : Expr) : USize :=
|
||||
ptrAddrUnsafe key % cacheSize
|
||||
unsafe def Cache.keyIdx (c : Cache) (key : Expr) : USize :=
|
||||
ptrAddrUnsafe key % c.size
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.resultIdx (key : Expr) : USize :=
|
||||
keyIdx key + cacheSize
|
||||
unsafe def Cache.resultIdx (c : Cache) (key : Expr) : USize :=
|
||||
c.keyIdx key + c.size
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.hasResultFor (c : Cache) (key : Expr) : Bool :=
|
||||
have : (keyIdx key).toNat < c.keysResults.size := lcProof
|
||||
ptrEq (unsafeCast key) c.keysResults[keyIdx key]
|
||||
have : (c.keyIdx key).toNat < c.keysResults.size := lcProof
|
||||
ptrEq (unsafeCast key) c.keysResults[c.keyIdx key]
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.getResultFor (c : Cache) (key : Expr) : Expr :=
|
||||
have : (resultIdx key).toNat < c.keysResults.size := lcProof
|
||||
unsafeCast c.keysResults[resultIdx key]
|
||||
have : (c.resultIdx key).toNat < c.keysResults.size := lcProof
|
||||
unsafeCast c.keysResults[c.resultIdx key]
|
||||
|
||||
@[inline]
|
||||
unsafe def Cache.store (c : Cache) (key result : Expr) : Cache :=
|
||||
Cache.mk <| c.keysResults
|
||||
|>.uset (keyIdx key) (unsafeCast key) lcProof
|
||||
|>.uset (resultIdx key) (unsafeCast result) lcProof
|
||||
{ c with keysResults := c.keysResults
|
||||
|>.uset (c.keyIdx key) (unsafeCast key) lcProof
|
||||
|>.uset (c.resultIdx key) (unsafeCast result) lcProof }
|
||||
|
||||
abbrev ReplaceM := StateM Cache
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ unsafe def replaceUnsafeM (f? : Expr → Option Expr) (e : Expr) : ReplaceM Expr
|
|||
|
||||
@[inline]
|
||||
unsafe def replaceUnsafe (f? : Expr → Option Expr) (e : Expr) : Expr :=
|
||||
(replaceUnsafeM f? e).run' Cache.new
|
||||
(replaceUnsafeM f? e).run' (Cache.new e)
|
||||
|
||||
end ReplaceImpl
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue