From 4bd347de3a9fffe14e7dfc17460585d13f4188ae Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 4 Aug 2019 11:50:05 -0700 Subject: [PATCH] feat(library/init/data/persistentarray/basic): `PersistentArray.pop` --- library/init/data/persistentarray/basic.lean | 45 +++++++++++++++++++- tests/playground/persistentarray.lean | 22 ++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/library/init/data/persistentarray/basic.lean b/library/init/data/persistentarray/basic.lean index 8815102933..f7b96c6dd5 100644 --- a/library/init/data/persistentarray/basic.lean +++ b/library/init/data/persistentarray/basic.lean @@ -131,6 +131,49 @@ if r.tail.size < branching.toNat || t.size >= tooBig then else mkNewTail r +private def emptyArray {α : Type u} : Array (PersistentArrayNode α) := +Array.mkEmpty PersistentArray.branching.toNat + +partial def popLeaf : PersistentArrayNode α → Option (Array α) × Array (PersistentArrayNode α) +| n@(node cs) := + if h : cs.size ≠ 0 then + let idx : Fin cs.size := ⟨cs.size - 1, Nat.predLt h⟩; + let last := cs.fget idx; + match popLeaf last with + | (none, _) => (none, emptyArray) + | (some l, newLast) => + if newLast.size == 0 then + let cs := cs.pop; + if cs.isEmpty then (some l, emptyArray) else (some l, cs) + else + (some l, cs.fset idx (node newLast)) + else + (none, emptyArray) +| (leaf vs) := (some vs, emptyArray) + +def pop (t : PersistentArray α) : PersistentArray α := +if t.tail.size > 0 then + { tail := t.tail.pop, size := t.size - 1, .. t } +else + match popLeaf t.root with + | (none, _) => t + | (some last, newRoots) => + let last := last.pop; + let newSize := t.size - 1; + let newTailOff := newSize - last.size; + if newRoots.size == 1 then + { root := newRoots.get 0, + shift := t.shift - initShift, + size := newSize, + tail := last, + tailOff := newTailOff } + else + { root := node newRoots, + size := newSize, + tail := last, + tailOff := newTailOff, + .. t } + section variables {m : Type u → Type v} [Monad m] variable {β: Type u} @@ -183,7 +226,7 @@ def stats (r : PersistentArray α) : Stats := collectStats r.root { numNodes := 0, depth := 0, tailSize := r.tail.size } 0 def Stats.toString (s : Stats) : String := -toString [s.numNodes, s.depth, s.tailSize] +"{nodes := " ++ toString s.numNodes ++ ", depth := " ++ toString s.depth ++ ", tail size := " ++ toString s.tailSize ++ "}" instance : HasToString Stats := ⟨Stats.toString⟩ diff --git a/tests/playground/persistentarray.lean b/tests/playground/persistentarray.lean index d942964392..44b2d57b7b 100644 --- a/tests/playground/persistentarray.lean +++ b/tests/playground/persistentarray.lean @@ -16,6 +16,17 @@ n.fold (λ i s => s.set i (s.get i + 1)) s def checkId (n : Nat) (s : MyArray) : IO Unit := check n (fun a b => a == b) s +def popTest (n : Nat) (p : Nat → Nat → Bool) (s : MyArray) : IO MyArray := +n.mfold (λ i s => do + -- IO.println i; + check (n - i) p s; + let s := s.pop; + -- IO.println s.stats; + -- IO.println ("size: " ++ toString s.size ++ ", tailOff " ++ toString s.tailOff ++ ", shift: " ++ toString s.shift); + -- IO.println s.tail; + pure s) + s + def main (xs : List String) : IO Unit := do let n := xs.head.toNat; @@ -25,4 +36,15 @@ let t := inc1 n t; check n (λ i v => v == i + 1) t; IO.println t.size; IO.println t.stats; +IO.println "popping..."; +t ← popTest (n - 33) (λ i v => v == i + 1) t; +IO.println t.size; +check 33 (λ i v => v == i + 1) t; +let t : MyArray := (1000 : Nat).fold (fun i s => s.push i) t; +check t.size (λ i v => if i < 33 then v == i + 1 else v == i - 33) t; +IO.println t.size; +IO.println t.stats; +t ← popTest t.size (λ i v => if i < 33 then v == i + 1 else v == i - 33) t; +IO.println t.size; +IO.println t.stats; pure ()