90 lines
3 KiB
Text
90 lines
3 KiB
Text
/-
|
|
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
|
Released under Apache 2.0 license as described in the file LICENSE.
|
|
Authors: Leonardo de Moura
|
|
-/
|
|
prelude
|
|
import Lean.Expr
|
|
import Lean.Util.MonadCache
|
|
|
|
namespace Lean
|
|
/-!
`forEachWhere p f e` is similar to `forEach f e`, but only applies `f` to subterms that satisfy the
(pure) predicate `p`.
It also uses the caching trick used in `FindExpr` and `ReplaceExpr`. This can be very effective
if the subterms satisfying `p` are a small subset of all subterms.
If `p` holds for most subterms, then it is more efficient to use `forEach f e`.
-/
|
|
|
|
namespace ForEachExprWhere
|
|
abbrev cacheSize : USize := 8192 - 1
|
|
|
|
structure State where
|
|
/--
|
|
Implements caching trick similar to the one used at `FindExpr` and `ReplaceExpr`.
|
|
-/
|
|
visited : Array Expr -- Remark: our "unsafe" implementation relies on the fact that `()` is not a valid Expr
|
|
/--
|
|
Set of visited subterms that satisfy the predicate `p`.
|
|
We have to use this set to make sure `f` is applied at most once of each subterm that satisfies `p`.
|
|
-/
|
|
checked : Std.HashSet Expr
|
|
|
|
unsafe def initCache : State := {
|
|
visited := mkArray cacheSize.toNat (cast lcProof ())
|
|
checked := {}
|
|
}
|
|
|
|
abbrev ForEachM {ω : Type} (m : Type → Type) [STWorld ω m] := StateRefT' ω State m
|
|
|
|
variable {ω : Type} {m : Type → Type} [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m]
|
|
|
|
unsafe def visited (e : Expr) : ForEachM m Bool := do
|
|
let s ← get
|
|
let h := ptrAddrUnsafe e
|
|
let i := h % cacheSize
|
|
let k := s.visited.uget i lcProof
|
|
if ptrAddrUnsafe k == h then
|
|
return true
|
|
else
|
|
modify fun s => { s with visited := s.visited.uset i e lcProof }
|
|
return false
|
|
|
|
def checked (e : Expr) : ForEachM m Bool := do
|
|
if (← get).checked.contains e then
|
|
return true
|
|
else
|
|
modify fun s => { s with checked := s.checked.insert e }
|
|
return false
|
|
|
|
/-- `Expr.forEachWhere` (unsafe) implementation -/
|
|
unsafe def visit (p : Expr → Bool) (f : Expr → m Unit) (e : Expr) (stopWhenVisited : Bool := false) : m Unit := do
|
|
go e |>.run' initCache
|
|
where
|
|
go (e : Expr) : StateRefT' ω State m Unit := do
|
|
unless (← visited e) do
|
|
if p e then
|
|
unless (← checked e) do
|
|
f e
|
|
if stopWhenVisited then
|
|
return ()
|
|
match e with
|
|
| .forallE _ d b _ => go d; go b
|
|
| .lam _ d b _ => go d; go b
|
|
| .letE _ t v b _ => go t; go v; go b
|
|
| .app f a => go f; go a
|
|
| .mdata _ b => go b
|
|
| .proj _ _ b => go b
|
|
| _ => return ()
|
|
|
|
end ForEachExprWhere
|
|
|
|
/--
|
|
`e.forEachWhere p f` applies `f` to each subterm that satisfies `p`.
|
|
If `stopWhenVisited` is `true`, the function doesn't visit subterms of terms
|
|
which satisfy `p`.
|
|
-/
|
|
@[implemented_by ForEachExprWhere.visit]
|
|
opaque Expr.forEachWhere {ω : Type} {m : Type → Type} [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m] (p : Expr → Bool) (f : Expr → m Unit) (e : Expr) (stopWhenVisited : Bool := false) : m Unit
|
|
|
|
end Lean
|