From 5c333d88c01f71180f1ceeb00fcdad2bea1bdbf9 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Fri, 14 Mar 2025 00:28:23 +0000 Subject: [PATCH] feat: mark `forIn_pure_yield` lemmas simp (#7433) This PR makes `simp` able to simplify basic `for` loops in monads other than `Id`. This is some prework for #7352, where the `Id` lemmas will be deprecated. --- src/Init/Data/Array/Monadic.lean | 4 ++-- src/Init/Data/List/Monadic.lean | 4 ++-- src/Init/Data/Option/Monadic.lean | 4 ++-- src/Init/Data/Vector/Monadic.lean | 4 ++-- tests/lean/run/list_simp.lean | 9 +++++++++ 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/Init/Data/Array/Monadic.lean b/src/Init/Data/Array/Monadic.lean index ed3e58e3b7..92a3504f79 100644 --- a/src/Init/Data/Array/Monadic.lean +++ b/src/Init/Data/Array/Monadic.lean @@ -169,7 +169,7 @@ theorem forIn'_eq_foldlM [Monad m] [LawfulMonad m] rcases xs with ⟨xs⟩ simp [List.foldlM_map] -theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m] +@[simp] theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m] (xs : Array α) (f : (a : α) → a ∈ xs → β → β) (init : β) : forIn' xs init (fun a m b => pure (.yield (f a m b))) = pure (f := m) (xs.attach.foldl (fun b ⟨a, h⟩ => f a h b) init) := by @@ -211,7 +211,7 @@ theorem forIn_eq_foldlM [Monad m] [LawfulMonad m] rcases xs with ⟨xs⟩ simp [List.foldlM_map] -theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m] +@[simp] theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m] (xs : Array α) (f : α → β → β) (init : β) : forIn xs init (fun a b => pure (.yield (f a b))) = pure (f := m) (xs.foldl (fun b a => f a b) init) := by diff --git a/src/Init/Data/List/Monadic.lean b/src/Init/Data/List/Monadic.lean index f4fe358c23..7528cebdbe 100644 --- a/src/Init/Data/List/Monadic.lean +++ b/src/Init/Data/List/Monadic.lean @@ -330,7 +330,7 @@ theorem forIn'_eq_foldlM [Monad m] [LawfulMonad m] simp only [forIn'_eq_foldlM] induction l.attach generalizing init <;> simp_all -theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m] +@[simp] theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m] (l : List α) (f : (a : α) → a ∈ l → β → β) (init : β) : forIn' l init (fun a m b => pure (.yield (f a m b))) = pure (f := m) (l.attach.foldl (fun b ⟨a, h⟩ => f a h b) init) := by @@ -383,7 +383,7 @@ theorem forIn_eq_foldlM [Monad m] [LawfulMonad m] simp only [forIn_eq_foldlM] induction l generalizing init <;> simp_all -theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m] +@[simp] theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m] (l : List α) (f : α → β → β) (init : β) : forIn l init (fun a b => pure (.yield (f a b))) = pure (f := m) (l.foldl (fun b a => f a b) init) := by diff --git a/src/Init/Data/Option/Monadic.lean b/src/Init/Data/Option/Monadic.lean index 2d520c680a..89abb46d6b 100644 --- a/src/Init/Data/Option/Monadic.lean +++ b/src/Init/Data/Option/Monadic.lean @@ -46,7 +46,7 @@ theorem forIn'_eq_pelim [Monad m] [LawfulMonad m] o.pelim (pure b) (fun a h => g a h b <$> f a h b) := by cases o <;> simp -theorem forIn'_pure_yield_eq_pelim [Monad m] [LawfulMonad m] +@[simp] theorem forIn'_pure_yield_eq_pelim [Monad m] [LawfulMonad m] (o : Option α) (f : (a : α) → a ∈ o → β → β) (b : β) : forIn' o b (fun a m b => pure (.yield (f a m b))) = pure (f := m) (o.pelim b (fun a h => f a h b)) := by @@ -75,7 +75,7 @@ theorem forIn_eq_elim [Monad m] [LawfulMonad m] o.elim (pure b) (fun a => g a b <$> f a b) := by cases o <;> simp -theorem forIn_pure_yield_eq_elim [Monad m] [LawfulMonad m] +@[simp] theorem forIn_pure_yield_eq_elim [Monad m] [LawfulMonad m] (o : Option α) (f : (a : α) → β → β) (b : β) : forIn o b (fun a b => pure (.yield (f a b))) = pure (f := m) (o.elim b (fun a => f a b)) := by diff --git a/src/Init/Data/Vector/Monadic.lean b/src/Init/Data/Vector/Monadic.lean index 18dfdc5b15..e4c8b104d3 100644 --- a/src/Init/Data/Vector/Monadic.lean +++ b/src/Init/Data/Vector/Monadic.lean @@ -159,7 +159,7 @@ theorem forIn'_eq_foldlM [Monad m] [LawfulMonad m] rcases xs with ⟨xs, rfl⟩ simp -theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m] +@[simp] theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m] (xs : Vector α n) (f : (a : α) → a ∈ xs → β → β) (init : β) : forIn' xs init (fun a m b => pure (.yield (f a m b))) = pure (f := m) (xs.attach.foldl (fun b ⟨a, h⟩ => f a h b) init) := by @@ -201,7 +201,7 @@ theorem forIn_eq_foldlM [Monad m] [LawfulMonad m] rcases xs with ⟨xs, rfl⟩ simp -theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m] +@[simp] theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m] (xs : Vector α n) (f : α → β → β) (init : β) : forIn xs init (fun a b => pure (.yield (f a b))) = pure (f := m) (xs.foldl (fun b a => f a b) init) := by diff --git a/tests/lean/run/list_simp.lean b/tests/lean/run/list_simp.lean index 3c71b3c471..4ecfaf657a 100644 --- a/tests/lean/run/list_simp.lean +++ b/tests/lean/run/list_simp.lean @@ -492,6 +492,15 @@ variable (l : List α) (k m : Nat) in x := x + k pure x) ~> m + k * l.length +-- as above, but for an arbitrary monad +variable (l : List α) (k m : Nat) {M} [Monad M] [LawfulMonad M] in +#check_simp + (show M _ from do + let mut x := m + for _ in l do + x := x + k + pure x) ~> pure (m + k * l.length) + /-! ### mapM -/ /-! ### forM -/