feat: add Expr.foldConsts

This commit is contained in:
Leonardo de Moura 2020-03-16 12:44:38 -07:00
parent 0ee6672a77
commit b69c851a5a
3 changed files with 78 additions and 0 deletions

View file

@ -19,3 +19,4 @@ import Init.Lean.Util.Trace
import Init.Lean.Util.WHNF
import Init.Lean.Util.FindExpr
import Init.Lean.Util.ReplaceExpr
import Init.Lean.Util.FoldConsts

View file

@ -0,0 +1,63 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Control.Option
import Init.Lean.Expr
namespace Lean
namespace Expr
namespace FoldConstsImpl
abbrev cacheSize : USize := 8192
structure State :=
(visitedTerms : Array Expr) -- Remark: cache based on pointer address. Our "unsafe" implementation relies on the fact that `()` is not a valid Expr
(visitedConsts : NameHashSet) -- cache based on structural equality
abbrev FoldM := StateM State
@[inline] unsafe def visited (e : Expr) (size : USize) : FoldM Bool := do
s ← get;
let h := ptrAddrUnsafe e;
let i := h % size;
let k := s.visitedTerms.uget i lcProof;
if ptrAddrUnsafe k == h then pure true
else do
modify $ fun s => { visitedTerms := s.visitedTerms.uset i e lcProof, .. s };
pure false
@[specialize] unsafe partial def fold {α : Type} (f : Name → αα) (size : USize) : Expr → α → FoldM α
| e, acc => condM (liftM $ visited e size) (pure acc) $
match e with
| Expr.forallE _ d b _ => do acc ← fold d acc; fold b acc
| Expr.lam _ d b _ => do acc ← fold d acc; fold b acc
| Expr.mdata _ b _ => fold b acc
| Expr.letE _ t v b _ => do acc ← fold t acc; acc ← fold v acc; fold b acc
| Expr.app f a _ => do acc ← fold f acc; fold a acc
| Expr.proj _ _ b _ => fold b acc
| Expr.const c _ _ => do
s ← get;
if s.visitedConsts.contains c then pure acc
else do
modify $ fun s => { visitedConsts := s.visitedConsts.insert c, .. s };
pure $ f c acc
| _ => pure acc
unsafe def initCache : State :=
{ visitedTerms := mkArray cacheSize.toNat (cast lcProof ()),
visitedConsts := {} }
@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → αα) : α :=
(fold f cacheSize e init).run' initCache
end FoldConstsImpl
/-- Apply `f` to every constant occurring in `e` once. -/
@[implementedBy FoldConstsImpl.foldUnsafe]
constant foldConsts {α : Type} (e : Expr) (init : α) (f : Name → αα) : α := init
end Expr
end Lean

View file

@ -0,0 +1,14 @@
import Init.Lean
open Lean
def mkTerm : Nat → Expr
| 0 => mkApp (mkConst `a) (mkConst `b)
| n+1 => mkApp (mkTerm n) (mkTerm n)
def collectConsts (e : Expr) : List Name :=
e.foldConsts [] List.cons
def tst1 : IO Unit :=
IO.println $ collectConsts (mkTerm 1000)
#eval tst1