From 466e8a6c5eb37d1382b7ab3d354b18f54c25662c Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Tue, 15 Jul 2025 17:10:01 -0700 Subject: [PATCH] fix: adjust unsafe trick for upcoming optimization (#9393) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR fixes an unsafe trick where a sentinel for a hash table of Exprs (keyed by pointer) is created by constructing a value whose runtime representation can never be a valid Expr. The value chosen for this purpose was Unit.unit, which violates the inference that Expr has no scalar constructors. Instead, we change this to a freshly allocated Unit × Unit value. --- src/Lean/Util/ForEachExprWhere.lean | 6 ++++-- src/Lean/Util/ReplaceLevel.lean | 4 +++- tests/lean/run/addDecorationsWithoutPartial.lean | 4 +++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/Lean/Util/ForEachExprWhere.lean b/src/Lean/Util/ForEachExprWhere.lean index 963e4baa33..88a5e18e1e 100644 --- a/src/Lean/Util/ForEachExprWhere.lean +++ b/src/Lean/Util/ForEachExprWhere.lean @@ -19,11 +19,13 @@ If `p` holds for most subterms, then it is more efficient to use `forEach f e`. namespace ForEachExprWhere abbrev cacheSize : USize := 8192 - 1 +private def notAnExpr : Unit × Unit := ⟨⟨⟩, ⟨⟩⟩ + structure State where /-- Implements caching trick similar to the one used at `FindExpr` and `ReplaceExpr`. -/ - visited : Array Expr -- Remark: our "unsafe" implementation relies on the fact that `()` is not a valid Expr + visited : Array Expr -- Remark: our "unsafe" implementation relies on the fact that `notAnExpr` is not a valid Expr /-- Set of visited subterms that satisfy the predicate `p`. We have to use this set to make sure `f` is applied at most once of each subterm that satisfies `p`. @@ -31,7 +33,7 @@ structure State where checked : Std.HashSet Expr unsafe def initCache : State := { - visited := .replicate cacheSize.toNat (cast lcProof ()) + visited := .replicate cacheSize.toNat (cast lcProof notAnExpr) checked := {} } diff --git a/src/Lean/Util/ReplaceLevel.lean b/src/Lean/Util/ReplaceLevel.lean index 1f49e4d29f..e159c3e20b 100644 --- a/src/Lean/Util/ReplaceLevel.lean +++ b/src/Lean/Util/ReplaceLevel.lean @@ -55,8 +55,10 @@ unsafe def replaceUnsafeM (f? : Level → Option Level) (size : USize) (e : Expr | e => pure e visit e +private def notAnExpr : Unit × Unit := ⟨⟨⟩, ⟨⟩⟩ + unsafe def initCache : State := - { keys := .replicate cacheSize.toNat (cast lcProof ()), -- `()` is not a valid `Expr` + { keys := .replicate cacheSize.toNat (cast lcProof notAnExpr), -- `notAnExpr` is not a valid `Expr` results := .replicate cacheSize.toNat default } unsafe def replaceUnsafe (f? : Level → Option Level) (e : Expr) : Expr := diff --git a/tests/lean/run/addDecorationsWithoutPartial.lean b/tests/lean/run/addDecorationsWithoutPartial.lean index 548b229aab..e8f648acf9 100644 --- a/tests/lean/run/addDecorationsWithoutPartial.lean +++ b/tests/lean/run/addDecorationsWithoutPartial.lean @@ -36,8 +36,10 @@ unsafe def replaceUnsafeM (size : USize) (e : Expr) (f? : (e' : Expr) → sizeOf | e => pure e visit e +private def notAnExpr : Unit × Unit := ⟨⟨⟩, ⟨⟩⟩ + unsafe def initCache : State := - { keys := mkArray cacheSize.toNat (cast lcProof ()), -- `()` is not a valid `Expr` + { keys := mkArray cacheSize.toNat (cast lcProof notAnExpr), -- `notAnExpr` is not a valid `Expr` results := mkArray cacheSize.toNat default } unsafe def replaceUnsafe (e : Expr) (f? : (e' : Expr) → sizeOf e' ≤ sizeOf e → Option Expr) : Expr :=