feat: mark forIn_pure_yield lemmas simp (#7433)

This PR makes `simp` able to simplify basic `for` loops in monads other
than `Id`.

This is some prework for #7352, where the `Id` lemmas will be
deprecated.
This commit is contained in:
Eric Wieser 2025-03-14 00:28:23 +00:00 committed by GitHub
parent 07ee2eea21
commit 5c333d88c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 17 additions and 8 deletions

View file

@ -169,7 +169,7 @@ theorem forIn'_eq_foldlM [Monad m] [LawfulMonad m]
rcases xs with ⟨xs⟩
simp [List.foldlM_map]
theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
@[simp] theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
(xs : Array α) (f : (a : α) → a ∈ xs → β → β) (init : β) :
forIn' xs init (fun a m b => pure (.yield (f a m b))) =
pure (f := m) (xs.attach.foldl (fun b ⟨a, h⟩ => f a h b) init) := by
@ -211,7 +211,7 @@ theorem forIn_eq_foldlM [Monad m] [LawfulMonad m]
rcases xs with ⟨xs⟩
simp [List.foldlM_map]
theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
@[simp] theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
(xs : Array α) (f : α → β → β) (init : β) :
forIn xs init (fun a b => pure (.yield (f a b))) =
pure (f := m) (xs.foldl (fun b a => f a b) init) := by

View file

@ -330,7 +330,7 @@ theorem forIn'_eq_foldlM [Monad m] [LawfulMonad m]
simp only [forIn'_eq_foldlM]
induction l.attach generalizing init <;> simp_all
theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
@[simp] theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
(l : List α) (f : (a : α) → a ∈ l → β → β) (init : β) :
forIn' l init (fun a m b => pure (.yield (f a m b))) =
pure (f := m) (l.attach.foldl (fun b ⟨a, h⟩ => f a h b) init) := by
@ -383,7 +383,7 @@ theorem forIn_eq_foldlM [Monad m] [LawfulMonad m]
simp only [forIn_eq_foldlM]
induction l generalizing init <;> simp_all
theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
@[simp] theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
(l : List α) (f : α → β → β) (init : β) :
forIn l init (fun a b => pure (.yield (f a b))) =
pure (f := m) (l.foldl (fun b a => f a b) init) := by

View file

@ -46,7 +46,7 @@ theorem forIn'_eq_pelim [Monad m] [LawfulMonad m]
o.pelim (pure b) (fun a h => g a h b <$> f a h b) := by
cases o <;> simp
theorem forIn'_pure_yield_eq_pelim [Monad m] [LawfulMonad m]
@[simp] theorem forIn'_pure_yield_eq_pelim [Monad m] [LawfulMonad m]
(o : Option α) (f : (a : α) → a ∈ o → β → β) (b : β) :
forIn' o b (fun a m b => pure (.yield (f a m b))) =
pure (f := m) (o.pelim b (fun a h => f a h b)) := by
@ -75,7 +75,7 @@ theorem forIn_eq_elim [Monad m] [LawfulMonad m]
o.elim (pure b) (fun a => g a b <$> f a b) := by
cases o <;> simp
theorem forIn_pure_yield_eq_elim [Monad m] [LawfulMonad m]
@[simp] theorem forIn_pure_yield_eq_elim [Monad m] [LawfulMonad m]
(o : Option α) (f : (a : α) → β → β) (b : β) :
forIn o b (fun a b => pure (.yield (f a b))) =
pure (f := m) (o.elim b (fun a => f a b)) := by

View file

@ -159,7 +159,7 @@ theorem forIn'_eq_foldlM [Monad m] [LawfulMonad m]
rcases xs with ⟨xs, rfl⟩
simp
theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
@[simp] theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
(xs : Vector α n) (f : (a : α) → a ∈ xs → β → β) (init : β) :
forIn' xs init (fun a m b => pure (.yield (f a m b))) =
pure (f := m) (xs.attach.foldl (fun b ⟨a, h⟩ => f a h b) init) := by
@ -201,7 +201,7 @@ theorem forIn_eq_foldlM [Monad m] [LawfulMonad m]
rcases xs with ⟨xs, rfl⟩
simp
theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
@[simp] theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
(xs : Vector α n) (f : α → β → β) (init : β) :
forIn xs init (fun a b => pure (.yield (f a b))) =
pure (f := m) (xs.foldl (fun b a => f a b) init) := by

View file

@ -492,6 +492,15 @@ variable (l : List α) (k m : Nat) in
x := x + k
pure x) ~> m + k * l.length
-- as above, but for an arbitrary monad
variable (l : List α) (k m : Nat) {M} [Monad M] [LawfulMonad M] in
#check_simp
(show M _ from do
let mut x := m
for _ in l do
x := x + k
pure x) ~> pure (m + k * l.length)
/-! ### mapM -/
/-! ### forM -/