feat: forIn for PersistentArray
This commit is contained in:
parent
702ceb7a3f
commit
fac2849e50
3 changed files with 63 additions and 0 deletions
|
|
@ -1152,6 +1152,8 @@ instance : Inhabited PNonScalar.{u} := ⟨⟨arbitrary _⟩⟩
|
|||
|
||||
instance : Inhabited PointedType := ⟨{type := PUnit, val := ⟨⟩}⟩
|
||||
|
||||
instance {α} [Inhabited α] : Inhabited (ForInStep α) := ⟨ForInStep.done (arbitrary _)⟩
|
||||
|
||||
class inductive Nonempty (α : Sort u) : Prop
|
||||
| intro (val : α) : Nonempty
|
||||
|
||||
|
|
|
|||
|
|
@ -198,6 +198,33 @@ variable {β : Type v}
|
|||
let b ← foldlMAux f t.root init
|
||||
t.tail.foldlM f b
|
||||
|
||||
@[specialize]
|
||||
partial def forInAux {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] [inh : Inhabited β]
|
||||
(f : α → β → m (ForInStep β)) (n : PersistentArrayNode α) (b : β) : m (ForInStep β) := do
|
||||
match n with
|
||||
| leaf vs =>
|
||||
for v in vs do
|
||||
match (← f v b) with
|
||||
| r@(ForInStep.done b) => return r
|
||||
| ForInStep.yield bNew => b := bNew
|
||||
return ForInStep.yield b
|
||||
| node cs =>
|
||||
for c in cs do
|
||||
match (← forInAux f c b) with
|
||||
| r@(ForInStep.done b) => return r
|
||||
| ForInStep.yield bNew => b := bNew
|
||||
return ForInStep.yield b
|
||||
|
||||
@[specialize] def forIn (t : PersistentArray α) (init : β) (f : α → β → m (ForInStep β)) : m β := do
|
||||
match (← forInAux (inh := ⟨init⟩) f t.root init) with
|
||||
| ForInStep.done b => b
|
||||
| ForInStep.yield b =>
|
||||
for v in t.tail do
|
||||
match (← f v b) with
|
||||
| ForInStep.done b => return b
|
||||
| ForInStep.yield bNew => b := bNew
|
||||
return b
|
||||
|
||||
@[specialize] partial def findSomeMAux (f : α → m (Option β)) : PersistentArrayNode α → m (Option β)
|
||||
| node cs => cs.findSomeM? (fun c => findSomeMAux f c)
|
||||
| leaf vs => vs.findSomeM? f
|
||||
|
|
|
|||
34
tests/lean/run/forInPArray.lean
Normal file
34
tests/lean/run/forInPArray.lean
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
#lang lean4
|
||||
import Std
|
||||
|
||||
def check (x : IO Nat) (expected : IO Nat) : IO Unit := do
|
||||
unless (← x) == (← expected) do
|
||||
throw $ IO.userError "unexpected result"
|
||||
|
||||
def f1 (xs : Std.PArray Nat) (top : Nat) : IO Nat := do
|
||||
let sum := 0
|
||||
for x in xs do
|
||||
if x % 2 == 0 then
|
||||
IO.println s!"x: {x}"
|
||||
sum := sum + x
|
||||
if sum > top then
|
||||
return sum
|
||||
IO.println s!"sum: {sum}"
|
||||
return sum
|
||||
|
||||
#eval f1 [1, 2, 3, 4, 5, 10, 20].toPersistentArray 10
|
||||
|
||||
#eval check (f1 [1, 2, 3, 4, 5, 10, 20].toPersistentArray 10) (pure 16)
|
||||
|
||||
def f2 (xs : Std.PArray Nat) (top : Nat) : IO Nat := do
|
||||
let sum := 0
|
||||
for x in xs do
|
||||
if x % 2 == 0 then
|
||||
IO.println s!"x: {x}"
|
||||
sum := sum + x
|
||||
if sum > top then
|
||||
break
|
||||
IO.println s!"sum: {sum}"
|
||||
return sum
|
||||
|
||||
#eval check (f1 (List.iota 100).toPersistentArray 1000) (f2 (List.iota 100).toPersistentArray 1000)
|
||||
Loading…
Add table
Reference in a new issue