feat(library/init/data/persistenthashmap/basic): add PersistentHashMap.contains

This commit is contained in:
Leonardo de Moura 2019-08-09 11:20:56 -07:00
parent b8cd88a827
commit 92da659ec7
2 changed files with 29 additions and 0 deletions

View file

@ -153,6 +153,26 @@ partial def findAux [HasBeq α] : Node α β → USize → α → Option β
def find [HasBeq α] [Hashable α] : PersistentHashMap α β → α → Option β
| { root := n, .. }, k => findAux n (hash k) k
partial def containsAtAux [HasBeq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) : Nat → α → Bool
| i, k =>
if h : i < keys.size then
let k' := keys.fget ⟨i, h⟩;
if k == k' then true
else containsAtAux (i+1) k
else false
partial def containsAux [HasBeq α] : Node α β → USize → α → Bool
| Node.entries entries, h, k =>
let j := (mod2Shift h shift).toNat;
match entries.get j with
| Entry.null => false
| Entry.ref node => containsAux node (div2Shift h shift) k
| Entry.entry k' v => k == k'
| Node.collision keys vals heq, _, k => containsAtAux keys vals heq 0 k
def contains [HasBeq α] [Hashable α] : PersistentHashMap α β → α → Bool
| { root := n, .. }, k => containsAux n (hash k) k
partial def isUnaryEntries (a : Array (Entry α β (Node α β))) : Nat → Option (α × β) → Option (α × β)
| i, acc =>
if h : i < a.size then

View file

@ -49,6 +49,12 @@ n.mfor $ fun i =>
def delLess (n : Nat) (m : Map) : Map :=
n.fold (fun i m => m.erase i) m
def checkContains (n : Nat) (m : Map) : IO Unit :=
n.mfor $ fun i =>
match m.find i with
| none => unless (!m.contains i) (IO.println "bug at contains!")
| some _ => unless (m.contains i) (IO.println "bug at contains!")
def main (xs : List String) : IO Unit :=
do
let n := 500000;
@ -56,14 +62,17 @@ let m := mkMap n;
-- IO.println (formatMap m.root);
IO.println m.stats;
check n m;
checkContains n m;
let m := delOdd n m;
IO.println m.stats;
check2 n 0 m;
checkContains n m;
let m := delLess 499000 m;
check2 n 499000 m;
IO.println m.size;
IO.println m.stats;
let m := delLess 499900 m;
check2 n 499900 m;
checkContains n m;
IO.println m.size;
IO.println m.stats