diff --git a/library/init/data/array/basic.lean b/library/init/data/array/basic.lean index 902c646412..9e6ec02b6b 100644 --- a/library/init/data/array/basic.lean +++ b/library/init/data/array/basic.lean @@ -124,6 +124,26 @@ miterateAux a f 0 b @[inline] def mfoldl (a : Array α) (b : β) (f : α → β → m β) : m β := miterate a b (λ _, f) + +@[inline] def mfoldlFrom (a : Array α) (b : β) (f : α → β → m β) (ini : Nat := 0) : m β := +miterateAux a (λ _, f) ini b + +local attribute [instance] monadInhabited + +-- TODO(Leo): justify termination using wf-rec +@[specialize] partial def mfindAux (a : Array α) (f : α → m (Option β)) : Nat → m (Option β) +| i := + if h : i < a.sz then + let idx : Fin a.sz := ⟨i, h⟩ in + do r ← f (a.index idx), + (match r with + | some v := pure r + | none := mfindAux (i+1)) + else pure none + +@[inline] def mfind (a : Array α) (f : α → m (Option β)) : m (Option β) := +mfindAux a f 0 + end @[inline] def iterate (a : Array α) (b : β) (f : Π i : Fin a.sz, α → β → β) : β := @@ -132,6 +152,12 @@ Id.run $ miterateAux a f 0 b @[inline] def foldl (a : Array α) (f : α → β → β) (b : β) : β := iterate a b (λ _, f) +@[inline] def foldlFrom (a : Array α) (f : α → β → β) (b : β) (ini : Nat := 0) : β := +Id.run $ mfoldlFrom a b f ini + +@[inline] def find (a : Array α) (f : α → Option β) : Option β := +Id.run $ mfindAux a f 0 + @[specialize] private def revIterateAux (a : Array α) (f : Π i : Fin a.sz, α → β → β) : Π (i : Nat), i ≤ a.sz → β → β | 0 h b := b | (j+1) h b := @@ -175,14 +201,18 @@ Id.run $ mforeach a f theorem szForeachEq (a : Array α) (f : Π i : Fin a.sz, α → α) : (foreach a f).sz = a.sz := (Id.run $ mforeachAux a f).property -@[inline] def map (f : α → α) (a : Array α) : Array α := +/- Homogeneous map -/ +@[inline] def hmap (f : α → α) (a : Array α) : Array α := foreach a (λ _, f) -@[inline] def map₂ (f : α → α → α) (a b : Array α) : Array α := +@[inline] def hmap₂ (f : α → α → α) (a b : Array α) : Array α := if h : a.size ≤ b.size then foreach a (λ ⟨i, h'⟩, f (b.index ⟨i, Nat.ltOfLtOfLe h' h⟩)) else foreach b (λ ⟨i, h'⟩, f (a.index ⟨i, Nat.ltTrans h' (Nat.gtOfNotLe h)⟩)) +def map (f : α → β) (as : Array α) : Array β := +as.foldl (λ a bs, bs.push (f a)) (mkEmpty as.sz) + end Array export Array (mkArray) diff --git a/tests/compiler/array_test.lean b/tests/compiler/array_test.lean index 3aa48bb9d3..e080cb3d64 100644 --- a/tests/compiler/array_test.lean +++ b/tests/compiler/array_test.lean @@ -12,7 +12,7 @@ do IO.println (toString a.sz), let a := foo a, IO.println (toString a), - let a := a.map (+10), + let a := a.hmap (+10), IO.println (toString a), IO.println (toString a.sz), let a1 := a.pop, @@ -21,4 +21,5 @@ do IO.println (toString a2), let a2 := a.pop, IO.println a2, + IO.println $ (([1, 2, 3, 4].toArray).hmap (+2)).map toString, pure 0 diff --git a/tests/compiler/array_test.lean.expected.out b/tests/compiler/array_test.lean.expected.out index e9dbce964e..c58dc4632c 100644 --- a/tests/compiler/array_test.lean.expected.out +++ b/tests/compiler/array_test.lean.expected.out @@ -6,3 +6,4 @@ [10, 11, 12] [10, 11, 12, 13, 100] [10, 11, 12] +[3, 4, 5, 6]