diff --git a/tests/compiler/phashmap2.lean b/tests/compiler/phashmap2.lean index 6285c13184..b9270fff97 100644 --- a/tests/compiler/phashmap2.lean +++ b/tests/compiler/phashmap2.lean @@ -1,3 +1,4 @@ +#lang lean4 import Std.Data.PersistentHashMap import Lean.Data.Format open Lean Std Std.PersistentHashMap diff --git a/tests/compiler/phashmap3.lean b/tests/compiler/phashmap3.lean index b968f8f229..33fbeb2221 100644 --- a/tests/compiler/phashmap3.lean +++ b/tests/compiler/phashmap3.lean @@ -1,3 +1,4 @@ +#lang lean4 import Std.Data.PersistentHashMap import Lean.Data.Format open Lean Std Std.PersistentHashMap @@ -26,8 +27,8 @@ partial def formatMap : Node Nat Nat → Format Format.nil def checkState (m : Map) : IO Unit := do -unless (m.stats.maxDepth == 1) (IO.println "unexpected max depth"); -unless (m.stats.numCollisions == 0) (IO.println "unexpected number of collisions") +unless (m.stats.maxDepth == 1) do (IO.println "unexpected max depth"); +unless (m.stats.numCollisions == 0) do (IO.println "unexpected number of collisions") def main : IO Unit := do let m : Map := PersistentHashMap.empty; @@ -37,13 +38,13 @@ let max := PersistentHashMap.maxDepth.toNat; let m := m.insert (32^max + 1) 3; let m := m.insert (32^(max+1) + 1) 4; let m := m.insert (32^(max+2) + 1) 5; -unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat) (IO.println "unexpected max depth"); -unless (m.stats.numCollisions == 3) (IO.println "unexpected number of collisions"); +unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat) do (IO.println "unexpected max depth"); +unless (m.stats.numCollisions == 3) do (IO.println "unexpected number of collisions"); IO.println m.stats; let m := m.erase (32^(max+1) + 1); let m := m.erase (32^(max+2) + 1); let m := m.erase (32^max + 1); -unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat - 1) (IO.println "unexpected max depth"); +unless (m.stats.maxDepth == PersistentHashMap.maxDepth.toNat - 1) do (IO.println "unexpected max depth"); let m := m.erase (32^5 + 1); checkState m; IO.println m.stats diff --git a/tests/compiler/rbmap_library.lean b/tests/compiler/rbmap_library.lean index a79b14e79c..4b28d491bf 100644 --- a/tests/compiler/rbmap_library.lean +++ b/tests/compiler/rbmap_library.lean @@ -1,8 +1,9 @@ +#lang lean4 import Std open Std -def check (b : Bool) : IO Unit := -unless b $ IO.println "ERROR" +def check (b : Bool) : IO Unit := do +unless b do IO.println "ERROR" def sz {α β : Type} {lt : α → α → Bool} (m : RBMap α β lt) : Nat := m.fold (fun sz _ _ => sz+1) 0 @@ -11,29 +12,29 @@ def depth {α β : Type} {lt : α → α → Bool} (m : RBMap α β lt) : Nat := m.depth Nat.max def tst1 : IO Unit := -do let Map := RBMap String Nat (fun a b => a < b); - let m : Map := {}; - let m := m.insert "hello" 0; - let m := m.insert "world" 1; - check (m.find? "hello" == some 0); - check (m.find? "world" == some 1); - let m := m.erase "hello"; - check (m.find? "hello" == none); - check (m.find? "world" == some 1); +do let Map := RBMap String Nat (fun a b => a < b) + let m : Map := {} + let m := m.insert "hello" 0 + let m := m.insert "world" 1 + check (m.find? "hello" == some 0) + check (m.find? "world" == some 1) + let m := m.erase "hello" + check (m.find? "hello" == none) + check (m.find? "world" == some 1) pure () def tst2 : IO Unit := -do let Map := RBMap Nat Nat (fun a b => a < b); - let m : Map := {}; - let n : Nat := 10000; - let m := n.fold (fun i (m : Map) => m.insert i (i*10)) m; - check (m.all (fun k v => v == k*10)); - check (sz m == n); - IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m)); - let m := (n/2).fold (fun i (m : Map) => m.erase (2*i)) m; - check (m.all (fun k v => v == k*10)); - check (sz m == n / 2); - IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m)); +do let Map := RBMap Nat Nat (fun a b => a < b) + let m : Map := {} + let n : Nat := 10000 + let m := n.fold (fun i (m : Map) => m.insert i (i*10)) m + check (m.all (fun k v => v == k*10)) + check (sz m == n) + IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m)) + let m := (n/2).fold (fun i (m : Map) => m.erase (2*i)) m + check (m.all (fun k v => v == k*10)) + check (sz m == n / 2) + IO.println (">> " ++ toString (depth m) ++ ", " ++ toString (sz m)) pure () abbrev Map := RBMap Nat Nat (fun a b => a < b) @@ -41,25 +42,25 @@ abbrev Map := RBMap Nat Nat (fun a b => a < b) def mkRandMap (max : Nat) : Nat → Map → Array (Nat × Nat) → IO (Map × Array (Nat × Nat)) | 0, m, a => pure (m, a) | n+1, m, a => do - k ← IO.rand 0 max; - v ← IO.rand 0 max; + let k ← IO.rand 0 max + let v ← IO.rand 0 max if m.find? k == none then do - let m := m.insert k v; - let a := a.push (k, v); - mkRandMap n m a + let m := m.insert k v + let a := a.push (k, v) + mkRandMap max n m a else - mkRandMap n m a + mkRandMap max n m a def tst3 (seed : Nat) (n : Nat) (max : Nat) : IO Unit := -do IO.setRandSeed seed; - (m, a) ← mkRandMap max n {} Array.empty; - check (sz m == a.size); - check (a.all (fun ⟨k, v⟩ => m.find? k == some v)); - IO.println ("tst3 size: " ++ toString a.size); - let m := a.iterate m (fun i ⟨k, v⟩ m => if i.val % 2 == 0 then m.erase k else m); - check (sz m == a.size / 2); - a.iterateM () (fun i ⟨k, v⟩ _ => when (i.val % 2 == 1) (check (m.find? k == some v))); - IO.println ("tst3 after, depth: " ++ toString (depth m) ++ ", size: " ++ toString (sz m)); +do IO.setRandSeed seed + let (m, a) ← mkRandMap max n {} Array.empty + check (sz m == a.size) + check (a.all (fun ⟨k, v⟩ => m.find? k == some v)) + IO.println ("tst3 size: " ++ toString a.size) + let m := a.iterate m (fun i ⟨k, v⟩ m => if i.val % 2 == 0 then m.erase k else m) + check (sz m == a.size / 2) + a.iterateM () (fun i ⟨k, v⟩ _ => when (i.val % 2 == 1) (check (m.find? k == some v))) + IO.println ("tst3 after, depth: " ++ toString (depth m) ++ ", size: " ++ toString (sz m)) pure () def main (xs : List String) : IO Unit :=