From 92da659ec741d8c110e44c92abf828dbc8429853 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 9 Aug 2019 11:20:56 -0700 Subject: [PATCH] feat(library/init/data/persistenthashmap/basic): add `PersistentHashMap.contains` --- .../init/data/persistenthashmap/basic.lean | 20 +++++++++++++++++++ tests/playground/phashmap.lean | 9 +++++++++ 2 files changed, 29 insertions(+) diff --git a/library/init/data/persistenthashmap/basic.lean b/library/init/data/persistenthashmap/basic.lean index 4fabc7650e..cb67296ffc 100644 --- a/library/init/data/persistenthashmap/basic.lean +++ b/library/init/data/persistenthashmap/basic.lean @@ -153,6 +153,26 @@ partial def findAux [HasBeq α] : Node α β → USize → α → Option β def find [HasBeq α] [Hashable α] : PersistentHashMap α β → α → Option β | { root := n, .. }, k => findAux n (hash k) k +partial def containsAtAux [HasBeq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) : Nat → α → Bool +| i, k => + if h : i < keys.size then + let k' := keys.fget ⟨i, h⟩; + if k == k' then true + else containsAtAux (i+1) k + else false + +partial def containsAux [HasBeq α] : Node α β → USize → α → Bool +| Node.entries entries, h, k => + let j := (mod2Shift h shift).toNat; + match entries.get j with + | Entry.null => false + | Entry.ref node => containsAux node (div2Shift h shift) k + | Entry.entry k' v => k == k' +| Node.collision keys vals heq, _, k => containsAtAux keys vals heq 0 k + +def contains [HasBeq α] [Hashable α] : PersistentHashMap α β → α → Bool +| { root := n, .. }, k => containsAux n (hash k) k + partial def isUnaryEntries (a : Array (Entry α β (Node α β))) : Nat → Option (α × β) → Option (α × β) | i, acc => if h : i < a.size then diff --git a/tests/playground/phashmap.lean b/tests/playground/phashmap.lean index fe889f2487..00f8b8aa40 100644 --- a/tests/playground/phashmap.lean +++ b/tests/playground/phashmap.lean @@ -49,6 +49,12 @@ n.mfor $ fun i => def delLess (n : Nat) (m : Map) : Map := n.fold (fun i m => m.erase i) m +def checkContains (n : Nat) (m : Map) : IO Unit := +n.mfor $ fun i => + match m.find i with + | none => unless (!m.contains i) (IO.println "bug at contains!") + | some _ => unless (m.contains i) (IO.println "bug at contains!") + def main (xs : List String) : IO Unit := do let n := 500000; @@ -56,14 +62,17 @@ let m := mkMap n; -- IO.println (formatMap m.root); IO.println m.stats; check n m; +checkContains n m; let m := delOdd n m; IO.println m.stats; check2 n 0 m; +checkContains n m; let m := delLess 499000 m; check2 n 499000 m; IO.println m.size; IO.println m.stats; let m := delLess 499900 m; check2 n 499900 m; +checkContains n m; IO.println m.size; IO.println m.stats