diff --git a/src/Init/Control/Lawful/Lemmas.lean b/src/Init/Control/Lawful/Lemmas.lean index aaf8ecc472..5dfdd1e6e6 100644 --- a/src/Init/Control/Lawful/Lemmas.lean +++ b/src/Init/Control/Lawful/Lemmas.lean @@ -9,25 +9,49 @@ import Init.RCases import Init.ByCases -- Mapping by a function with a left inverse is injective. -theorem map_inj_of_left_inverse [Applicative m] [LawfulApplicative m] {f : α → β} - (w : ∃ g : β → α, ∀ x, g (f x) = x) {x y : m α} - (h : f <$> x = f <$> y) : x = y := by - rcases w with ⟨g, w⟩ - replace h := congrArg (g <$> ·) h - simpa [w] using h +theorem map_inj_of_left_inverse [Functor m] [LawfulFunctor m] {f : α → β} + (w : ∃ g : β → α, ∀ x, g (f x) = x) {x y : m α} : + f <$> x = f <$> y ↔ x = y := by + constructor + · intro h + rcases w with ⟨g, w⟩ + replace h := congrArg (g <$> ·) h + simpa [w] using h + · rintro rfl + rfl -- Mapping by an injective function is injective, as long as the domain is nonempty. -theorem map_inj_of_inj [Applicative m] [LawfulApplicative m] [Nonempty α] {f : α → β} - (w : ∀ x y, f x = f y → x = y) {x y : m α} - (h : f <$> x = f <$> y) : x = y := by - apply map_inj_of_left_inverse ?_ h - let ⟨a⟩ := ‹Nonempty α› - refine ⟨?_, ?_⟩ - · intro b - by_cases p : ∃ a, f a = b - · exact Exists.choose p - · exact a - · intro b - simp only [exists_apply_eq_apply, ↓reduceDIte] - apply w - apply Exists.choose_spec (p := fun a => f a = f b) +@[simp] theorem map_inj_right_of_nonempty [Functor m] [LawfulFunctor m] [Nonempty α] {f : α → β} + (w : ∀ {x y}, f x = f y → x = y) {x y : m α} : + f <$> x = f <$> y ↔ x = y := by + constructor + · intro h + apply (map_inj_of_left_inverse ?_).mp h + let ⟨a⟩ := ‹Nonempty α› + refine ⟨?_, ?_⟩ + · intro b + by_cases p : ∃ a, f a = b + · exact Exists.choose p + · exact a + · intro b + simp only [exists_apply_eq_apply, ↓reduceDIte] + apply w + apply Exists.choose_spec (p := fun a => f a = f b) + · rintro rfl + rfl + +@[simp] theorem map_inj_right [Monad m] [LawfulMonad m] + {f : α → β} (h : ∀ {x y : α}, f x = f y → x = y) {x y : m α} : + f <$> x = f <$> y ↔ x = y := by + by_cases hempty : Nonempty α + · exact map_inj_right_of_nonempty h + · constructor + · intro h' + have (z : m α) : z = (do let a ← z; let b ← pure (f a); x) := by + conv => lhs; rw [← bind_pure z] + congr; funext a + exact (hempty ⟨a⟩).elim + rw [this x, this y] + rw [← bind_assoc, ← map_eq_pure_bind, h', map_eq_pure_bind, bind_assoc] + · intro h' + rw [h'] diff --git a/src/Init/Data/Array/MapIdx.lean b/src/Init/Data/Array/MapIdx.lean index e6dd3e63f8..9e4c002eb0 100644 --- a/src/Init/Data/Array/MapIdx.lean +++ b/src/Init/Data/Array/MapIdx.lean @@ -434,3 +434,57 @@ theorem mapIdx_eq_mkArray_iff {l : Array α} {f : Nat → α → β} {b : β} : simp [List.mapIdx_reverse] end Array + +namespace List + +theorem mapFinIdxM_toArray [Monad m] [LawfulMonad m] (l : List α) + (f : (i : Nat) → α → (h : i < l.length) → m β) : + l.toArray.mapFinIdxM f = toArray <$> l.mapFinIdxM f := by + let rec go (i : Nat) (acc : Array β) (inv : i + acc.size = l.length) : + Array.mapFinIdxM.map l.toArray f i acc.size inv acc + = toArray <$> mapFinIdxM.go l f (l.drop acc.size) acc + (by simp [Nat.sub_add_cancel (Nat.le.intro (Nat.add_comm _ _ ▸ inv))]) := by + match i with + | 0 => + rw [Nat.zero_add] at inv + simp only [Array.mapFinIdxM.map, inv, drop_length, mapFinIdxM.go, map_pure] + | k + 1 => + conv => enter [2, 2, 3]; rw [← getElem_cons_drop l acc.size (by omega)] + simp only [Array.mapFinIdxM.map, mapFinIdxM.go, _root_.map_bind] + congr; funext x + conv => enter [1, 4]; rw [← Array.size_push _ x] + conv => enter [2, 2, 3]; rw [← Array.size_push _ x] + refine go k (acc.push x) _ + simp only [Array.mapFinIdxM, mapFinIdxM] + exact go _ #[] _ + +theorem mapIdxM_toArray [Monad m] [LawfulMonad m] (l : List α) + (f : Nat → α → m β) : + l.toArray.mapIdxM f = toArray <$> l.mapIdxM f := by + let rec go (bs : List α) (acc : Array β) (inv : bs.length + acc.size = l.length) : + mapFinIdxM.go l (fun i a h => f i a) bs acc inv = mapIdxM.go f bs acc := by + match bs with + | [] => simp only [mapFinIdxM.go, mapIdxM.go] + | x :: xs => simp only [mapFinIdxM.go, mapIdxM.go, go] + unfold Array.mapIdxM + rw [mapFinIdxM_toArray] + simp only [mapFinIdxM, mapIdxM] + rw [go] + +end List + +namespace Array + +theorem toList_mapFinIdxM [Monad m] [LawfulMonad m] (l : Array α) + (f : (i : Nat) → α → (h : i < l.size) → m β) : + toList <$> l.mapFinIdxM f = l.toList.mapFinIdxM f := by + rw [List.mapFinIdxM_toArray] + simp only [Functor.map_map, id_map'] + +theorem toList_mapIdxM [Monad m] [LawfulMonad m] (l : Array α) + (f : Nat → α → m β) : + toList <$> l.mapIdxM f = l.toList.mapIdxM f := by + rw [List.mapIdxM_toArray] + simp only [Functor.map_map, id_map'] + +end Array diff --git a/src/Init/Data/Array/Monadic.lean b/src/Init/Data/Array/Monadic.lean index b3428aa201..9da57e8ebe 100644 --- a/src/Init/Data/Array/Monadic.lean +++ b/src/Init/Data/Array/Monadic.lean @@ -221,6 +221,121 @@ theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m] cases l simp +end Array + +namespace List + +theorem filterM_toArray [Monad m] [LawfulMonad m] (l : List α) (p : α → m Bool) : + l.toArray.filterM p = toArray <$> l.filterM p := by + simp only [Array.filterM, filterM, foldlM_toArray, bind_pure_comp, Functor.map_map] + conv => lhs; rw [← reverse_nil] + generalize [] = acc + induction l generalizing acc with simp + | cons x xs ih => + congr; funext b + cases b + · simp only [Bool.false_eq_true, ↓reduceIte, pure_bind, cond_false] + exact ih acc + · simp only [↓reduceIte, ← reverse_cons, pure_bind, cond_true] + exact ih (x :: acc) + +/-- Variant of `filterM_toArray` with a side condition for the stop position. -/ +@[simp] theorem filterM_toArray' [Monad m] [LawfulMonad m] (l : List α) (p : α → m Bool) (w : stop = l.length) : + l.toArray.filterM p 0 stop = toArray <$> l.filterM p := by + subst w + rw [filterM_toArray] + +theorem filterRevM_toArray [Monad m] [LawfulMonad m] (l : List α) (p : α → m Bool) : + l.toArray.filterRevM p = toArray <$> l.filterRevM p := by + simp [Array.filterRevM, filterRevM] + rw [← foldlM_reverse, ← foldlM_toArray, ← Array.filterM, filterM_toArray] + simp only [filterM, bind_pure_comp, Functor.map_map, reverse_toArray, reverse_reverse] + +/-- Variant of `filterRevM_toArray` with a side condition for the start position. -/ +@[simp] theorem filterRevM_toArray' [Monad m] [LawfulMonad m] (l : List α) (p : α → m Bool) (w : start = l.length) : + l.toArray.filterRevM p start 0 = toArray <$> l.filterRevM p := by + subst w + rw [filterRevM_toArray] + +theorem filterMapM_toArray [Monad m] [LawfulMonad m] (l : List α) (f : α → m (Option β)) : + l.toArray.filterMapM f = toArray <$> l.filterMapM f := by + simp [Array.filterMapM, filterMapM] + conv => lhs; rw [← reverse_nil] + generalize [] = acc + induction l generalizing acc with simp [filterMapM.loop] + | cons x xs ih => + congr; funext o + cases o + · simp only [pure_bind]; exact ih acc + · simp only [pure_bind]; rw [← List.reverse_cons]; exact ih _ + +/-- Variant of `filterMapM_toArray` with a side condition for the stop position. -/ +@[simp] theorem filterMapM_toArray' [Monad m] [LawfulMonad m] (l : List α) (f : α → m (Option β)) (w : stop = l.length) : + l.toArray.filterMapM f 0 stop = toArray <$> l.filterMapM f := by + subst w + rw [filterMapM_toArray] + +@[simp] theorem flatMapM_toArray [Monad m] [LawfulMonad m] (l : List α) (f : α → m (Array β)) : + l.toArray.flatMapM f = toArray <$> l.flatMapM (fun a => Array.toList <$> f a) := by + simp only [Array.flatMapM, bind_pure_comp, foldlM_toArray, flatMapM] + conv => lhs; arg 2; change [].reverse.flatten.toArray + generalize [] = acc + induction l generalizing acc with + | nil => simp only [foldlM_nil, flatMapM.loop, map_pure] + | cons x xs ih => + simp only [foldlM_cons, bind_map_left, flatMapM.loop, _root_.map_bind] + congr; funext a + conv => lhs; rw [Array.toArray_append, ← flatten_concat, ← reverse_cons] + exact ih _ + +end List + +namespace Array + +@[congr] theorem filterM_congr [Monad m] {as bs : Array α} (w : as = bs) + {p : α → m Bool} {q : α → m Bool} (h : ∀ a, p a = q a) : + as.filterM p = bs.filterM q := by + subst w + simp [filterM, h] + +@[congr] theorem filterRevM_congr [Monad m] {as bs : Array α} (w : as = bs) + {p : α → m Bool} {q : α → m Bool} (h : ∀ a, p a = q a) : + as.filterRevM p = bs.filterRevM q := by + subst w + simp [filterRevM, h] + +@[congr] theorem filterMapM_congr [Monad m] {as bs : Array α} (w : as = bs) + {f : α → m (Option β)} {g : α → m (Option β)} (h : ∀ a, f a = g a) : + as.filterMapM f = bs.filterMapM g := by + subst w + simp [filterMapM, h] + +@[congr] theorem flatMapM_congr [Monad m] {as bs : Array α} (w : as = bs) + {f : α → m (Array β)} {g : α → m (Array β)} (h : ∀ a, f a = g a) : + as.flatMapM f = bs.flatMapM g := by + subst w + simp [flatMapM, h] + +theorem toList_filterM [Monad m] [LawfulMonad m] (a : Array α) (p : α → m Bool) : + toList <$> a.filterM p = a.toList.filterM p := by + rw [List.filterM_toArray] + simp only [Functor.map_map, id_map'] + +theorem toList_filterRevM [Monad m] [LawfulMonad m] (a : Array α) (p : α → m Bool) : + toList <$> a.filterRevM p = a.toList.filterRevM p := by + rw [List.filterRevM_toArray] + simp only [Functor.map_map, id_map'] + +theorem toList_filterMapM [Monad m] [LawfulMonad m] (a : Array α) (f : α → m (Option β)) : + toList <$> a.filterMapM f = a.toList.filterMapM f := by + rw [List.filterMapM_toArray] + simp only [Functor.map_map, id_map'] + +theorem toList_flatMapM [Monad m] [LawfulMonad m] (a : Array α) (f : α → m (Array β)) : + toList <$> a.flatMapM f = a.toList.flatMapM (fun a => toList <$> f a) := by + rw [List.flatMapM_toArray] + simp only [Functor.map_map, id_map'] + /-! ### Recognizing higher order functions using a function that only depends on the value. -/ /-- @@ -260,20 +375,20 @@ and simplifies these to the function directly taking the value. simp rw [List.mapM_subtype hf] --- Without `filterMapM_toArray` relating `filterMapM` on `List` and `Array` we can't prove this yet: --- @[simp] theorem filterMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }} --- {f : { x // p x } → m (Option β)} {g : α → m (Option β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) : --- l.filterMapM f = l.unattach.filterMapM g := by --- rcases l with ⟨l⟩ --- simp --- rw [List.filterMapM_subtype hf] +@[simp] theorem filterMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }} + {f : { x // p x } → m (Option β)} {g : α → m (Option β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) (w : stop = l.size) : + l.filterMapM f 0 stop = l.unattach.filterMapM g := by + subst w + rcases l with ⟨l⟩ + simp + rw [List.filterMapM_subtype hf] --- Without `flatMapM_toArray` relating `flatMapM` on `List` and `Array` we can't prove this yet: --- @[simp] theorem flatMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }} --- {f : { x // p x } → m (Array β)} {g : α → m (Array β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) : --- (l.flatMapM f) = l.unattach.flatMapM g := by --- rcases l with ⟨l⟩ --- simp --- rw [List.flatMapM_subtype hf] +@[simp] theorem flatMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }} + {f : { x // p x } → m (Array β)} {g : α → m (Array β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) : + (l.flatMapM f) = l.unattach.flatMapM g := by + rcases l with ⟨l⟩ + simp + rw [List.flatMapM_subtype] + simp [hf] end Array diff --git a/src/Init/Data/List/Control.lean b/src/Init/Data/List/Control.lean index 8b7d1a35c5..91c224dc2e 100644 --- a/src/Init/Data/List/Control.lean +++ b/src/Init/Data/List/Control.lean @@ -128,7 +128,7 @@ Applies the monadic function `f` on every element `x` in the list, left-to-right results `y` for which `f x` returns `some y`. -/ @[inline] -def filterMapM {m : Type u → Type v} [Monad m] {α β : Type u} (f : α → m (Option β)) (as : List α) : m (List β) := +def filterMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (Option β)) (as : List α) : m (List β) := let rec @[specialize] loop | [], bs => pure bs.reverse | a :: as, bs => do diff --git a/src/Init/Data/Vector/MapIdx.lean b/src/Init/Data/Vector/MapIdx.lean index 4c8c878e65..83ff181cc0 100644 --- a/src/Init/Data/Vector/MapIdx.lean +++ b/src/Init/Data/Vector/MapIdx.lean @@ -363,4 +363,25 @@ theorem mapIdx_eq_mkVector_iff {l : Vector α n} {f : Nat → α → β} {b : β rcases l with ⟨l, rfl⟩ simp [Array.mapIdx_reverse] +theorem toArray_mapFinIdxM [Monad m] [LawfulMonad m] + (a : Vector α n) (f : (i : Nat) → α → (h : i < n) → m β) : + toArray <$> a.mapFinIdxM f = a.toArray.mapFinIdxM + (fun i x h => f i x (size_toArray a ▸ h)) := by + let rec go (i j : Nat) (inv : i + j = n) (bs : Vector β (n - i)) : + toArray <$> mapFinIdxM.map a f i j inv bs + = Array.mapFinIdxM.map a.toArray (fun i x h => f i x (size_toArray a ▸ h)) + i j (size_toArray _ ▸ inv) bs.toArray := by + match i with + | 0 => simp only [mapFinIdxM.map, map_pure, Array.mapFinIdxM.map, Nat.sub_zero] + | k + 1 => + simp only [mapFinIdxM.map, map_bind, Array.mapFinIdxM.map, getElem_toArray] + conv => lhs; arg 2; intro; rw [go] + rfl + simp only [mapFinIdxM, Array.mapFinIdxM, size_toArray] + exact go _ _ _ _ + +theorem toArray_mapIdxM [Monad m] [LawfulMonad m] (a : Vector α n) (f : Nat → α → m β) : + toArray <$> a.mapIdxM f = a.toArray.mapIdxM f := by + exact toArray_mapFinIdxM _ _ + end Vector diff --git a/src/Init/Data/Vector/Monadic.lean b/src/Init/Data/Vector/Monadic.lean index 6230a721ae..39e5f7e78f 100644 --- a/src/Init/Data/Vector/Monadic.lean +++ b/src/Init/Data/Vector/Monadic.lean @@ -19,11 +19,10 @@ open Nat /-! ## Monadic operations -/ -theorem map_toArray_inj [Monad m] [LawfulMonad m] [Nonempty α] - {v₁ : m (Vector α n)} {v₂ : m (Vector α n)} (w : toArray <$> v₁ = toArray <$> v₂) : - v₁ = v₂ := by - apply map_inj_of_inj ?_ w - simp +@[simp] theorem map_toArray_inj [Monad m] [LawfulMonad m] + {v₁ : m (Vector α n)} {v₂ : m (Vector α n)} : + toArray <$> v₁ = toArray <$> v₂ ↔ v₁ = v₂ := + _root_.map_inj_right (by simp) /-! ### mapM -/ @@ -39,11 +38,10 @@ theorem map_toArray_inj [Monad m] [LawfulMonad m] [Nonempty α] unfold mapM.go simp --- The `[Nonempty β]` hypothesis should be avoidable by unfolding `mapM` directly. -@[simp] theorem mapM_append [Monad m] [LawfulMonad m] [Nonempty β] +@[simp] theorem mapM_append [Monad m] [LawfulMonad m] (f : α → m β) {l₁ : Vector α n} {l₂ : Vector α n'} : (l₁ ++ l₂).mapM f = (return (← l₁.mapM f) ++ (← l₂.mapM f)) := by - apply map_toArray_inj + apply map_toArray_inj.mp suffices toArray <$> (l₁ ++ l₂).mapM f = (return (← toArray <$> l₁.mapM f) ++ (← toArray <$> l₂.mapM f)) by rw [this] simp only [bind_pure_comp, Functor.map_map, bind_map_left, map_bind, toArray_append]