chore: fix tests

This commit is contained in:
Leonardo de Moura 2020-10-20 16:15:30 -07:00
parent 93f7b1d7bc
commit 192d45d867
3 changed files with 45 additions and 42 deletions

View file

@ -1,3 +1,4 @@
#lang lean4
import Std.Data.PersistentHashMap
import Lean.Data.Format
open Lean Std Std.PersistentHashMap

View file

@ -1,3 +1,4 @@
#lang lean4
import Std.Data.PersistentHashMap
import Lean.Data.Format
open Lean Std Std.PersistentHashMap
@ -26,8 +27,8 @@ partial def formatMap : Node Nat Nat → Format
Format.nil
def checkState (m : Map) : IO Unit := do
unless (m.stats.maxDepth == 1) (IO.println "unexpected max depth");
unless (m.stats.numCollisions == 0) (IO.println "unexpected number of collisions")
unless (m.stats.maxDepth == 1) do (IO.println "unexpected max depth");
unless (m.stats.numCollisions == 0) do (IO.println "unexpected number of collisions")
def main : IO Unit := do
let m : Map := PersistentHashMap.empty;
@ -37,13 +38,13 @@ let max := PersistentHashMap.maxDepth.toNat;
let m := m.insert (32^max + 1) 3;
let m := m.insert (32^(max+1) + 1) 4;
let m := m.insert (32^(max+2) + 1) 5;
unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat) (IO.println "unexpected max depth");
unless (m.stats.numCollisions == 3) (IO.println "unexpected number of collisions");
unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat) do (IO.println "unexpected max depth");
unless (m.stats.numCollisions == 3) do (IO.println "unexpected number of collisions");
IO.println m.stats;
let m := m.erase (32^(max+1) + 1);
let m := m.erase (32^(max+2) + 1);
let m := m.erase (32^max + 1);
unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat - 1) (IO.println "unexpected max depth");
unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat - 1) do (IO.println "unexpected max depth");
let m := m.erase (32^5 + 1);
checkState m;
IO.println m.stats

View file

@ -1,8 +1,9 @@
#lang lean4
import Std
open Std
def check (b : Bool) : IO Unit :=
unless b $ IO.println "ERROR"
def check (b : Bool) : IO Unit := do
unless b do IO.println "ERROR"
def sz {α β : Type} {lt : αα → Bool} (m : RBMap α β lt) : Nat :=
m.fold (fun sz _ _ => sz+1) 0
@ -11,29 +12,29 @@ def depth {α β : Type} {lt : αα → Bool} (m : RBMap α β lt) : Nat :=
m.depth Nat.max
def tst1 : IO Unit :=
do let Map := RBMap String Nat (fun a b => a < b);
let m : Map := {};
let m := m.insert "hello" 0;
let m := m.insert "world" 1;
check (m.find? "hello" == some 0);
check (m.find? "world" == some 1);
let m := m.erase "hello";
check (m.find? "hello" == none);
check (m.find? "world" == some 1);
do let Map := RBMap String Nat (fun a b => a < b)
let m : Map := {}
let m := m.insert "hello" 0
let m := m.insert "world" 1
check (m.find? "hello" == some 0)
check (m.find? "world" == some 1)
let m := m.erase "hello"
check (m.find? "hello" == none)
check (m.find? "world" == some 1)
pure ()
def tst2 : IO Unit :=
do let Map := RBMap Nat Nat (fun a b => a < b);
let m : Map := {};
let n : Nat := 10000;
let m := n.fold (fun i (m : Map) => m.insert i (i*10)) m;
check (m.all (fun k v => v == k*10));
check (sz m == n);
IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m));
let m := (n/2).fold (fun i (m : Map) => m.erase (2*i)) m;
check (m.all (fun k v => v == k*10));
check (sz m == n / 2);
IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m));
do let Map := RBMap Nat Nat (fun a b => a < b)
let m : Map := {}
let n : Nat := 10000
let m := n.fold (fun i (m : Map) => m.insert i (i*10)) m
check (m.all (fun k v => v == k*10))
check (sz m == n)
IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m))
let m := (n/2).fold (fun i (m : Map) => m.erase (2*i)) m
check (m.all (fun k v => v == k*10))
check (sz m == n / 2)
IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m))
pure ()
abbrev Map := RBMap Nat Nat (fun a b => a < b)
@ -41,25 +42,25 @@ abbrev Map := RBMap Nat Nat (fun a b => a < b)
def mkRandMap (max : Nat) : Nat → Map → Array (Nat × Nat) → IO (Map × Array (Nat × Nat))
| 0, m, a => pure (m, a)
| n+1, m, a => do
k ← IO.rand 0 max;
v ← IO.rand 0 max;
let k ← IO.rand 0 max
let v ← IO.rand 0 max
if m.find? k == none then do
let m := m.insert k v;
let a := a.push (k, v);
mkRandMap n m a
let m := m.insert k v
let a := a.push (k, v)
mkRandMap max n m a
else
mkRandMap n m a
mkRandMap max n m a
def tst3 (seed : Nat) (n : Nat) (max : Nat) : IO Unit :=
do IO.setRandSeed seed;
(m, a) ← mkRandMap max n {} Array.empty;
check (sz m == a.size);
check (a.all (fun ⟨k, v⟩ => m.find? k == some v));
IO.println ("tst3 size: " ++ toString a.size);
let m := a.iterate m (fun i ⟨k, v⟩ m => if i.val % 2 == 0 then m.erase k else m);
check (sz m == a.size / 2);
a.iterateM () (fun i ⟨k, v⟩ _ => when (i.val % 2 == 1) (check (m.find? k == some v)));
IO.println ("tst3 after, depth: " ++ toString (depth m) ++ ", size: " ++ toString (sz m));
do IO.setRandSeed seed
let (m, a) ← mkRandMap max n {} Array.empty
check (sz m == a.size)
check (a.all (fun ⟨k, v⟩ => m.find? k == some v))
IO.println ("tst3 size: " ++ toString a.size)
let m := a.iterate m (fun i ⟨k, v⟩ m => if i.val % 2 == 0 then m.erase k else m)
check (sz m == a.size / 2)
a.iterateM () (fun i ⟨k, v⟩ _ => when (i.val % 2 == 1) (check (m.find? k == some v)))
IO.println ("tst3 after, depth: " ++ toString (depth m) ++ ", size: " ++ toString (sz m))
pure ()
def main (xs : List String) : IO Unit :=