lean4-htt/src/Lean/Util/ForEachExprWhere.lean

90 lines
3 KiB
Text

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
import Lean.Util.MonadCache
namespace Lean
/-!
`forEachWhere p f e` is similar to `forEach f e`, but only applies `f` to subterms that satisfy the
(pure) predicate `p`.
It also uses the pointer-based caching trick from `FindExpr` and `ReplaceExpr`. This can be very effective
when the subterms satisfying `p` are only a small fraction of all subterms.
If `p` holds for most subterms, it is more efficient to use `forEach f e`.
-/
namespace ForEachExprWhere
/-- Number of slots in the pointer-indexed `visited` cache (`2^13 - 1`). -/
abbrev cacheSize : USize := 8191
structure State where
  mk ::
  /--
  Pointer-indexed cache of recently visited subterms, in the style of the caches used by
  `FindExpr` and `ReplaceExpr`: a term is stored in the slot given by its address modulo
  `cacheSize`, evicting whatever was there before.
  -/
  visited : Array Expr -- the unsafe implementation fills empty slots with `()` cast to `Expr`, relying on `()` never being a valid `Expr`
  /--
  Subterms satisfying `p` that have already been handed to `f`.
  This set guarantees that `f` is applied at most once to each subterm satisfying `p`,
  even when the pointer cache in `visited` does not recognize a repeat.
  -/
  checked : Std.HashSet Expr
/--
Initial traversal state: every `visited` slot holds the sentinel `()` reinterpreted as an
`Expr` (so it never matches the address of a real term), and `checked` starts empty.
-/
unsafe def initCache : State :=
  let sentinel : Expr := cast lcProof ()
  { visited := mkArray cacheSize.toNat sentinel, checked := {} }
/-- Traversal monad: `m` extended (via `StateRefT'`) with the mutable traversal `State`. -/
abbrev ForEachM {ω : Type} (m : Type → Type) [STWorld ω m] : Type → Type :=
  StateRefT' ω State m

-- Ambient parameters shared by `visited`, `checked` and `visit`.
variable {ω : Type} {m : Type → Type} [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m]
/--
Pointer-cache lookup for `e`. Returns `true` when the slot selected by `e`'s address
(modulo `cacheSize`) already holds a term with that same address. Otherwise it stores `e`
in that slot, evicting the previous occupant, and returns `false`.
-/
unsafe def visited (e : Expr) : ForEachM m Bool := do
  let addr := ptrAddrUnsafe e
  let slot := addr % cacheSize
  let occupant := (← get).visited.uget slot lcProof
  if ptrAddrUnsafe occupant != addr then
    modify fun st => { st with visited := st.visited.uset slot e lcProof }
    return false
  return true
/--
Reports whether `e` is already in `State.checked`, inserting it when it is not.
The first call for a term therefore yields `false`, and later calls with an equal term yield `true`.
-/
def checked (e : Expr) : ForEachM m Bool := do
  let seen := (← get).checked.contains e
  unless seen do
    modify fun st => { st with checked := st.checked.insert e }
  return seen
/--
`Expr.forEachWhere` (unsafe) implementation.

Traverses `e`, calling `f` on each subterm `s` with `p s = true`, at most once per subterm
as recorded in `State.checked`. Subterms recognized by the pointer cache (`visited`) are
skipped without being re-examined.

If `stopWhenVisited` is `true`, the traversal never descends into a subterm satisfying `p`.
This holds whether `f` is applied now or was already applied to an equal term earlier.
-/
unsafe def visit (p : Expr → Bool) (f : Expr → m Unit) (e : Expr) (stopWhenVisited : Bool := false) : m Unit := do
  go e |>.run' initCache
where
  go (e : Expr) : StateRefT' ω State m Unit := do
    unless (← visited e) do
      if p e then
        unless (← checked e) do
          f e
        -- Deliberately outside the `unless`: a repeated term satisfying `p` (already in
        -- `checked`, but missed by the pointer cache) must not have its subterms visited either.
        if stopWhenVisited then
          return ()
      match e with
      | .forallE _ d b _ => go d; go b
      | .lam _ d b _     => go d; go b
      | .letE _ t v b _  => go t; go v; go b
      | .app func arg    => go func; go arg  -- named so as not to shadow the callback `f`
      | .mdata _ b       => go b
      | .proj _ _ b      => go b
      | _                => return ()
end ForEachExprWhere
/--
`e.forEachWhere p f` calls `f` on every subterm of `e` that satisfies `p`.

`stopWhenVisited` defaults to `false`. When it is `true`, the subterms of a term satisfying `p`
are not traversed.
-/
@[implemented_by ForEachExprWhere.visit]
opaque Expr.forEachWhere {ω : Type} {m : Type → Type} [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m]
    (p : Expr → Bool) (f : Expr → m Unit) (e : Expr)
    (stopWhenVisited : Bool := false) : m Unit
end Lean