From b69c851a5a3beaffeed85949db02c09a4f90591e Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 16 Mar 2020 12:44:38 -0700 Subject: [PATCH] feat: add `Expr.foldConsts` --- src/Init/Lean/Util.lean | 1 + src/Init/Lean/Util/FoldConsts.lean | 63 ++++++++++++++++++++++++++++++ tests/lean/run/foldConsts.lean | 14 +++++++ 3 files changed, 78 insertions(+) create mode 100644 src/Init/Lean/Util/FoldConsts.lean create mode 100644 tests/lean/run/foldConsts.lean diff --git a/src/Init/Lean/Util.lean b/src/Init/Lean/Util.lean index 3b439978df..aa614b1111 100644 --- a/src/Init/Lean/Util.lean +++ b/src/Init/Lean/Util.lean @@ -19,3 +19,4 @@ import Init.Lean.Util.Trace import Init.Lean.Util.WHNF import Init.Lean.Util.FindExpr import Init.Lean.Util.ReplaceExpr +import Init.Lean.Util.FoldConsts diff --git a/src/Init/Lean/Util/FoldConsts.lean b/src/Init/Lean/Util/FoldConsts.lean new file mode 100644 index 0000000000..94a281f9ba --- /dev/null +++ b/src/Init/Lean/Util/FoldConsts.lean @@ -0,0 +1,63 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Control.Option +import Init.Lean.Expr + +namespace Lean +namespace Expr +namespace FoldConstsImpl + +abbrev cacheSize : USize := 8192 + +structure State := +(visitedTerms : Array Expr) -- Remark: cache based on pointer address. Our "unsafe" implementation relies on the fact that `()` is not a valid Expr +(visitedConsts : NameHashSet) -- cache based on structural equality + +abbrev FoldM := StateM State + +@[inline] unsafe def visited (e : Expr) (size : USize) : FoldM Bool := do +s ← get; +let h := ptrAddrUnsafe e; +let i := h % size; +let k := s.visitedTerms.uget i lcProof; +if ptrAddrUnsafe k == h then pure true +else do + modify $ fun s => { visitedTerms := s.visitedTerms.uset i e lcProof, .. s }; + pure false + +@[specialize] unsafe partial def fold {α : Type} (f : Name → α → α) (size : USize) : Expr → α → FoldM α +| e, acc => condM (liftM $ visited e size) (pure acc) $ + match e with + | Expr.forallE _ d b _ => do acc ← fold d acc; fold b acc + | Expr.lam _ d b _ => do acc ← fold d acc; fold b acc + | Expr.mdata _ b _ => fold b acc + | Expr.letE _ t v b _ => do acc ← fold t acc; acc ← fold v acc; fold b acc + | Expr.app f a _ => do acc ← fold f acc; fold a acc + | Expr.proj _ _ b _ => fold b acc + | Expr.const c _ _ => do + s ← get; + if s.visitedConsts.contains c then pure acc + else do + modify $ fun s => { visitedConsts := s.visitedConsts.insert c, .. s }; + pure $ f c acc + | _ => pure acc + +unsafe def initCache : State := +{ visitedTerms := mkArray cacheSize.toNat (cast lcProof ()), + visitedConsts := {} } + +@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → α) : α := +(fold f cacheSize e init).run' initCache + +end FoldConstsImpl + +/-- Apply `f` to every constant occurring in `e` once. -/ +@[implementedBy FoldConstsImpl.foldUnsafe] +constant foldConsts {α : Type} (e : Expr) (init : α) (f : Name → α → α) : α := init + +end Expr +end Lean diff --git a/tests/lean/run/foldConsts.lean b/tests/lean/run/foldConsts.lean new file mode 100644 index 0000000000..ba00314037 --- /dev/null +++ b/tests/lean/run/foldConsts.lean @@ -0,0 +1,14 @@ +import Init.Lean +open Lean + +def mkTerm : Nat → Expr +| 0 => mkApp (mkConst `a) (mkConst `b) +| n+1 => mkApp (mkTerm n) (mkTerm n) + +def collectConsts (e : Expr) : List Name := +e.foldConsts [] List.cons + +def tst1 : IO Unit := +IO.println $ collectConsts (mkTerm 1000) + +#eval tst1