diff --git a/src/Init/Core.lean b/src/Init/Core.lean index 52c20d8034..6f1c3d35ce 100644 --- a/src/Init/Core.lean +++ b/src/Init/Core.lean @@ -1152,6 +1152,8 @@ instance : Inhabited PNonScalar.{u} := ⟨⟨arbitrary _⟩⟩ instance : Inhabited PointedType := ⟨{type := PUnit, val := ⟨⟩}⟩ +instance {α} [Inhabited α] : Inhabited (ForInStep α) := ⟨ForInStep.done (arbitrary _)⟩ + class inductive Nonempty (α : Sort u) : Prop | intro (val : α) : Nonempty diff --git a/src/Std/Data/PersistentArray.lean b/src/Std/Data/PersistentArray.lean index 405925f9b1..b9ebc17820 100644 --- a/src/Std/Data/PersistentArray.lean +++ b/src/Std/Data/PersistentArray.lean @@ -198,6 +198,33 @@ variable {β : Type v} let b ← foldlMAux f t.root init t.tail.foldlM f b +@[specialize] +partial def forInAux {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] [inh : Inhabited β] + (f : α → β → m (ForInStep β)) (n : PersistentArrayNode α) (b : β) : m (ForInStep β) := do +match n with +| leaf vs => + for v in vs do + match (← f v b) with + | r@(ForInStep.done b) => return r + | ForInStep.yield bNew => b := bNew + return ForInStep.yield b +| node cs => + for c in cs do + match (← forInAux f c b) with + | r@(ForInStep.done b) => return r + | ForInStep.yield bNew => b := bNew + return ForInStep.yield b + +@[specialize] def forIn (t : PersistentArray α) (init : β) (f : α → β → m (ForInStep β)) : m β := do +match (← forInAux (inh := ⟨init⟩) f t.root init) with +| ForInStep.done b => b +| ForInStep.yield b => + for v in t.tail do + match (← f v b) with + | ForInStep.done b => return b + | ForInStep.yield bNew => b := bNew + return b + @[specialize] partial def findSomeMAux (f : α → m (Option β)) : PersistentArrayNode α → m (Option β) | node cs => cs.findSomeM? (fun c => findSomeMAux f c) | leaf vs => vs.findSomeM? f diff --git a/tests/lean/run/forInPArray.lean b/tests/lean/run/forInPArray.lean new file mode 100644 index 0000000000..19b83dba06 --- /dev/null +++ b/tests/lean/run/forInPArray.lean @@ -0,0 +1,34 @@ +#lang lean4 +import Std + +def check (x : IO Nat) (expected : IO Nat) : IO Unit := do +unless (← x) == (← expected) do + throw $ IO.userError "unexpected result" + +def f1 (xs : Std.PArray Nat) (top : Nat) : IO Nat := do +let sum := 0 +for x in xs do + if x % 2 == 0 then + IO.println s!"x: {x}" + sum := sum + x + if sum > top then + return sum +IO.println s!"sum: {sum}" +return sum + +#eval f1 [1, 2, 3, 4, 5, 10, 20].toPersistentArray 10 + +#eval check (f1 [1, 2, 3, 4, 5, 10, 20].toPersistentArray 10) (pure 16) + +def f2 (xs : Std.PArray Nat) (top : Nat) : IO Nat := do +let sum := 0 +for x in xs do + if x % 2 == 0 then + IO.println s!"x: {x}" + sum := sum + x + if sum > top then + break +IO.println s!"sum: {sum}" +return sum + +#eval check (f1 (List.iota 100).toPersistentArray 1000) (f2 (List.iota 100).toPersistentArray 1000)