fix: adjust unsafe trick for upcoming optimization (#9393)

This PR fixes an unsafe trick where a sentinel for a hash table of Exprs
(keyed by pointer) is created by constructing a value whose runtime
representation can never be a valid Expr. The value chosen for this
purpose was Unit.unit, which violates the inference that Expr has no
scalar constructors. Instead, we change this to a freshly allocated Unit
× Unit value.
This commit is contained in:
Cameron Zwarich 2025-07-15 17:10:01 -07:00 committed by GitHub
parent b131e8b97f
commit 466e8a6c5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 10 additions and 4 deletions

View file

@ -19,11 +19,13 @@ If `p` holds for most subterms, then it is more efficient to use `forEach f e`.
namespace ForEachExprWhere
abbrev cacheSize : USize := 8192 - 1
private def notAnExpr : Unit × Unit := ⟨⟨⟩, ⟨⟩⟩
structure State where
/--
Implements caching trick similar to the one used at `FindExpr` and `ReplaceExpr`.
-/
visited : Array Expr -- Remark: our "unsafe" implementation relies on the fact that `()` is not a valid Expr
visited : Array Expr -- Remark: our "unsafe" implementation relies on the fact that `notAnExpr` is not a valid Expr
/--
Set of visited subterms that satisfy the predicate `p`.
We have to use this set to make sure `f` is applied at most once of each subterm that satisfies `p`.
@ -31,7 +33,7 @@ structure State where
checked : Std.HashSet Expr
unsafe def initCache : State := {
visited := .replicate cacheSize.toNat (cast lcProof ())
visited := .replicate cacheSize.toNat (cast lcProof notAnExpr)
checked := {}
}

View file

@ -55,8 +55,10 @@ unsafe def replaceUnsafeM (f? : Level → Option Level) (size : USize) (e : Expr
| e => pure e
visit e
private def notAnExpr : Unit × Unit := ⟨⟨⟩, ⟨⟩⟩
unsafe def initCache : State :=
{ keys := .replicate cacheSize.toNat (cast lcProof ()), -- `()` is not a valid `Expr`
{ keys := .replicate cacheSize.toNat (cast lcProof notAnExpr), -- `notAnExpr` is not a valid `Expr`
results := .replicate cacheSize.toNat default }
unsafe def replaceUnsafe (f? : Level → Option Level) (e : Expr) : Expr :=

View file

@ -36,8 +36,10 @@ unsafe def replaceUnsafeM (size : USize) (e : Expr) (f? : (e' : Expr) → sizeOf
| e => pure e
visit e
private def notAnExpr : Unit × Unit := ⟨⟨⟩, ⟨⟩⟩
unsafe def initCache : State :=
{ keys := mkArray cacheSize.toNat (cast lcProof ()), -- `()` is not a valid `Expr`
{ keys := mkArray cacheSize.toNat (cast lcProof notAnExpr), -- `notAnExpr` is not a valid `Expr`
results := mkArray cacheSize.toNat default }
unsafe def replaceUnsafe (e : Expr) (f? : (e' : Expr) → sizeOf e' ≤ sizeOf e → Option Expr) : Expr :=