feat: forIn for PersistentArray

This commit is contained in:
Leonardo de Moura 2020-10-20 06:09:00 -07:00
parent 702ceb7a3f
commit fac2849e50
3 changed files with 63 additions and 0 deletions

View file

@ -1152,6 +1152,8 @@ instance : Inhabited PNonScalar.{u} := ⟨⟨arbitrary _⟩⟩
instance : Inhabited PointedType := ⟨{type := PUnit, val := ⟨⟩}⟩
instance {α} [Inhabited α] : Inhabited (ForInStep α) := ⟨ForInStep.done (arbitrary _)⟩
class inductive Nonempty (α : Sort u) : Prop
| intro (val : α) : Nonempty

View file

@ -198,6 +198,33 @@ variable {β : Type v}
let b ← foldlMAux f t.root init
t.tail.foldlM f b
@[specialize]
partial def forInAux {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] [inh : Inhabited β]
(f : α → β → m (ForInStep β)) (n : PersistentArrayNode α) (b : β) : m (ForInStep β) := do
match n with
| leaf vs =>
for v in vs do
match (← f v b) with
| r@(ForInStep.done b) => return r
| ForInStep.yield bNew => b := bNew
return ForInStep.yield b
| node cs =>
for c in cs do
match (← forInAux f c b) with
| r@(ForInStep.done b) => return r
| ForInStep.yield bNew => b := bNew
return ForInStep.yield b
@[specialize] def forIn (t : PersistentArray α) (init : β) (f : α → β → m (ForInStep β)) : m β := do
match (← forInAux (inh := ⟨init⟩) f t.root init) with
| ForInStep.done b => b
| ForInStep.yield b =>
for v in t.tail do
match (← f v b) with
| ForInStep.done b => return b
| ForInStep.yield bNew => b := bNew
return b
@[specialize] partial def findSomeMAux (f : α → m (Option β)) : PersistentArrayNode α → m (Option β)
| node cs => cs.findSomeM? (fun c => findSomeMAux f c)
| leaf vs => vs.findSomeM? f

View file

@ -0,0 +1,34 @@
#lang lean4
import Std
def check (x : IO Nat) (expected : IO Nat) : IO Unit := do
unless (← x) == (← expected) do
throw $ IO.userError "unexpected result"
def f1 (xs : Std.PArray Nat) (top : Nat) : IO Nat := do
let sum := 0
for x in xs do
if x % 2 == 0 then
IO.println s!"x: {x}"
sum := sum + x
if sum > top then
return sum
IO.println s!"sum: {sum}"
return sum
#eval f1 [1, 2, 3, 4, 5, 10, 20].toPersistentArray 10
#eval check (f1 [1, 2, 3, 4, 5, 10, 20].toPersistentArray 10) (pure 16)
def f2 (xs : Std.PArray Nat) (top : Nat) : IO Nat := do
let sum := 0
for x in xs do
if x % 2 == 0 then
IO.println s!"x: {x}"
sum := sum + x
if sum > top then
break
IO.println s!"sum: {sum}"
return sum
#eval check (f1 (List.iota 100).toPersistentArray 1000) (f2 (List.iota 100).toPersistentArray 1000)