diff --git a/src/Std/Data/PersistentArray.lean b/src/Std/Data/PersistentArray.lean index c9c2701e51..1c0fa5f7d8 100644 --- a/src/Std/Data/PersistentArray.lean +++ b/src/Std/Data/PersistentArray.lean @@ -202,9 +202,16 @@ variable {β : Type v} else if start >= t.tailOff then t.tail.foldlM (init := init) (start := start - t.tailOff) f else do - let b ← foldlFromMAux f t.root (USize.ofNat start) t.shift init; + let b ← foldlFromMAux f t.root (USize.ofNat start) t.shift init t.tail.foldlM f b +@[specialize] private partial def foldrMAux [Monad m] (f : α → β → m β) : PersistentArrayNode α → β → m β + | node cs, b => cs.foldrM (fun c b => foldrMAux f c b) b + | leaf vs, b => vs.foldrM f b + +@[specialize] def foldrM [Monad m] (t : PersistentArray α) (f : α → β → m β) (init : β) : m β := do + foldrMAux f t.root (← t.tail.foldrM f init) + @[specialize] partial def forInAux {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] [inh : Inhabited β] (f : α → β → m (ForInStep β)) (n : PersistentArrayNode α) (b : β) : m (ForInStep β) := do @@ -264,8 +271,11 @@ instance : ForIn m (PersistentArray α) α where end -@[inline] def foldl {β} (t : PersistentArray α) (f : β → α → β) (init : β) (start : Nat := 0) : β := - Id.run $ t.foldlM f init start +@[inline] def foldl (t : PersistentArray α) (f : β → α → β) (init : β) (start : Nat := 0) : β := + Id.run <| t.foldlM f init start + +@[inline] def foldr (t : PersistentArray α) (f : α → β → β) (init : β) : β := + Id.run <| t.foldrM f init @[inline] def filter (as : PersistentArray α) (p : α → Bool) : PersistentArray α := as.foldl (init := {}) fun asNew a => if p a then asNew.push a else asNew diff --git a/tests/lean/run/parray1.lean b/tests/lean/run/parray1.lean new file mode 100644 index 0000000000..fe2112fdce --- /dev/null +++ b/tests/lean/run/parray1.lean @@ -0,0 +1,15 @@ +import Std.Data.PersistentArray + +def check [BEq α] (as : List α) : Bool := + as.toPersistentArray.foldr (.::.) [] == as + +def tst1 : IO Unit := do + assert! check [1, 2, 3] + assert! check ([] : List Nat) + assert! check (List.iota 17) + assert! check (List.iota 533) + assert! check (List.iota 1000) + assert! check (List.iota 2600) + IO.println "done" + +#eval tst1