chore: fix tests
This commit is contained in:
parent
93f7b1d7bc
commit
192d45d867
3 changed files with 45 additions and 42 deletions
|
|
@ -1,3 +1,4 @@
|
|||
#lang lean4
|
||||
import Std.Data.PersistentHashMap
|
||||
import Lean.Data.Format
|
||||
open Lean Std Std.PersistentHashMap
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#lang lean4
|
||||
import Std.Data.PersistentHashMap
|
||||
import Lean.Data.Format
|
||||
open Lean Std Std.PersistentHashMap
|
||||
|
|
@ -26,8 +27,8 @@ partial def formatMap : Node Nat Nat → Format
|
|||
Format.nil
|
||||
|
||||
def checkState (m : Map) : IO Unit := do
|
||||
unless (m.stats.maxDepth == 1) (IO.println "unexpected max depth");
|
||||
unless (m.stats.numCollisions == 0) (IO.println "unexpected number of collisions")
|
||||
unless (m.stats.maxDepth == 1) do (IO.println "unexpected max depth");
|
||||
unless (m.stats.numCollisions == 0) do (IO.println "unexpected number of collisions")
|
||||
|
||||
def main : IO Unit := do
|
||||
let m : Map := PersistentHashMap.empty;
|
||||
|
|
@ -37,13 +38,13 @@ let max := PersistentHashMap.maxDepth.toNat;
|
|||
let m := m.insert (32^max + 1) 3;
|
||||
let m := m.insert (32^(max+1) + 1) 4;
|
||||
let m := m.insert (32^(max+2) + 1) 5;
|
||||
unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat) (IO.println "unexpected max depth");
|
||||
unless (m.stats.numCollisions == 3) (IO.println "unexpected number of collisions");
|
||||
unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat) do (IO.println "unexpected max depth");
|
||||
unless (m.stats.numCollisions == 3) do (IO.println "unexpected number of collisions");
|
||||
IO.println m.stats;
|
||||
let m := m.erase (32^(max+1) + 1);
|
||||
let m := m.erase (32^(max+2) + 1);
|
||||
let m := m.erase (32^max + 1);
|
||||
unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat - 1) (IO.println "unexpected max depth");
|
||||
unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat - 1) do (IO.println "unexpected max depth");
|
||||
let m := m.erase (32^5 + 1);
|
||||
checkState m;
|
||||
IO.println m.stats
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
#lang lean4
|
||||
import Std
|
||||
open Std
|
||||
|
||||
def check (b : Bool) : IO Unit :=
|
||||
unless b $ IO.println "ERROR"
|
||||
def check (b : Bool) : IO Unit := do
|
||||
unless b do IO.println "ERROR"
|
||||
|
||||
def sz {α β : Type} {lt : α → α → Bool} (m : RBMap α β lt) : Nat :=
|
||||
m.fold (fun sz _ _ => sz+1) 0
|
||||
|
|
@ -11,29 +12,29 @@ def depth {α β : Type} {lt : α → α → Bool} (m : RBMap α β lt) : Nat :=
|
|||
m.depth Nat.max
|
||||
|
||||
def tst1 : IO Unit :=
|
||||
do let Map := RBMap String Nat (fun a b => a < b);
|
||||
let m : Map := {};
|
||||
let m := m.insert "hello" 0;
|
||||
let m := m.insert "world" 1;
|
||||
check (m.find? "hello" == some 0);
|
||||
check (m.find? "world" == some 1);
|
||||
let m := m.erase "hello";
|
||||
check (m.find? "hello" == none);
|
||||
check (m.find? "world" == some 1);
|
||||
do let Map := RBMap String Nat (fun a b => a < b)
|
||||
let m : Map := {}
|
||||
let m := m.insert "hello" 0
|
||||
let m := m.insert "world" 1
|
||||
check (m.find? "hello" == some 0)
|
||||
check (m.find? "world" == some 1)
|
||||
let m := m.erase "hello"
|
||||
check (m.find? "hello" == none)
|
||||
check (m.find? "world" == some 1)
|
||||
pure ()
|
||||
|
||||
def tst2 : IO Unit :=
|
||||
do let Map := RBMap Nat Nat (fun a b => a < b);
|
||||
let m : Map := {};
|
||||
let n : Nat := 10000;
|
||||
let m := n.fold (fun i (m : Map) => m.insert i (i*10)) m;
|
||||
check (m.all (fun k v => v == k*10));
|
||||
check (sz m == n);
|
||||
IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m));
|
||||
let m := (n/2).fold (fun i (m : Map) => m.erase (2*i)) m;
|
||||
check (m.all (fun k v => v == k*10));
|
||||
check (sz m == n / 2);
|
||||
IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m));
|
||||
do let Map := RBMap Nat Nat (fun a b => a < b)
|
||||
let m : Map := {}
|
||||
let n : Nat := 10000
|
||||
let m := n.fold (fun i (m : Map) => m.insert i (i*10)) m
|
||||
check (m.all (fun k v => v == k*10))
|
||||
check (sz m == n)
|
||||
IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m))
|
||||
let m := (n/2).fold (fun i (m : Map) => m.erase (2*i)) m
|
||||
check (m.all (fun k v => v == k*10))
|
||||
check (sz m == n / 2)
|
||||
IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m))
|
||||
pure ()
|
||||
|
||||
abbrev Map := RBMap Nat Nat (fun a b => a < b)
|
||||
|
|
@ -41,25 +42,25 @@ abbrev Map := RBMap Nat Nat (fun a b => a < b)
|
|||
def mkRandMap (max : Nat) : Nat → Map → Array (Nat × Nat) → IO (Map × Array (Nat × Nat))
|
||||
| 0, m, a => pure (m, a)
|
||||
| n+1, m, a => do
|
||||
k ← IO.rand 0 max;
|
||||
v ← IO.rand 0 max;
|
||||
let k ← IO.rand 0 max
|
||||
let v ← IO.rand 0 max
|
||||
if m.find? k == none then do
|
||||
let m := m.insert k v;
|
||||
let a := a.push (k, v);
|
||||
mkRandMap n m a
|
||||
let m := m.insert k v
|
||||
let a := a.push (k, v)
|
||||
mkRandMap max n m a
|
||||
else
|
||||
mkRandMap n m a
|
||||
mkRandMap max n m a
|
||||
|
||||
def tst3 (seed : Nat) (n : Nat) (max : Nat) : IO Unit :=
|
||||
do IO.setRandSeed seed;
|
||||
(m, a) ← mkRandMap max n {} Array.empty;
|
||||
check (sz m == a.size);
|
||||
check (a.all (fun ⟨k, v⟩ => m.find? k == some v));
|
||||
IO.println ("tst3 size: " ++ toString a.size);
|
||||
let m := a.iterate m (fun i ⟨k, v⟩ m => if i.val % 2 == 0 then m.erase k else m);
|
||||
check (sz m == a.size / 2);
|
||||
a.iterateM () (fun i ⟨k, v⟩ _ => when (i.val % 2 == 1) (check (m.find? k == some v)));
|
||||
IO.println ("tst3 after, depth: " ++ toString (depth m) ++ ", size: " ++ toString (sz m));
|
||||
do IO.setRandSeed seed
|
||||
let (m, a) ← mkRandMap max n {} Array.empty
|
||||
check (sz m == a.size)
|
||||
check (a.all (fun ⟨k, v⟩ => m.find? k == some v))
|
||||
IO.println ("tst3 size: " ++ toString a.size)
|
||||
let m := a.iterate m (fun i ⟨k, v⟩ m => if i.val % 2 == 0 then m.erase k else m)
|
||||
check (sz m == a.size / 2)
|
||||
a.iterateM () (fun i ⟨k, v⟩ _ => when (i.val % 2 == 1) (check (m.find? k == some v)))
|
||||
IO.println ("tst3 after, depth: " ++ toString (depth m) ++ ", size: " ++ toString (sz m))
|
||||
pure ()
|
||||
|
||||
def main (xs : List String) : IO Unit :=
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue