From f9b2e550bb6be77df08f159a9469bfb9fafb3a97 Mon Sep 17 00:00:00 2001 From: Kim Morrison <477956+kim-em@users.noreply.github.com> Date: Tue, 9 Sep 2025 09:51:55 +1000 Subject: [PATCH] feat: grind annotations for basic monad transformers (#10227) This PR adds `@[grind]` annotations (nearly all `@[grind =]` annotations parallel to existing `@[simp]`s) for `ReaderT`, `StateT`, `ExceptT`. --- src/Init/Control/Lawful/Instances.lean | 60 +++++++++++++------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/src/Init/Control/Lawful/Instances.lean b/src/Init/Control/Lawful/Instances.lean index c9b5773dde..32b932a4e3 100644 --- a/src/Init/Control/Lawful/Instances.lean +++ b/src/Init/Control/Lawful/Instances.lean @@ -22,23 +22,24 @@ open Function namespace ExceptT -@[ext] theorem ext {x y : ExceptT ε m α} (h : x.run = y.run) : x = y := by +@[ext, grind ext] theorem ext {x y : ExceptT ε m α} (h : x.run = y.run) : x = y := by simp [run] at h assumption -@[simp] theorem run_pure [Monad m] (x : α) : run (pure x : ExceptT ε m α) = pure (Except.ok x) := rfl +@[simp, grind =] theorem run_pure [Monad m] (x : α) : run (pure x : ExceptT ε m α) = pure (Except.ok x) := rfl -@[simp] theorem run_lift [Monad.{u, v} m] (x : m α) : run (ExceptT.lift x : ExceptT ε m α) = (Except.ok <$> x : m (Except ε α)) := rfl +@[simp, grind =] theorem run_lift [Monad.{u, v} m] (x : m α) : run (ExceptT.lift x : ExceptT ε m α) = (Except.ok <$> x : m (Except ε α)) := rfl -@[simp] theorem run_throw [Monad m] : run (throw e : ExceptT ε m β) = pure (Except.error e) := rfl +@[simp, grind =] theorem run_throw [Monad m] : run (throw e : ExceptT ε m β) = pure (Except.error e) := rfl -@[simp] theorem run_bind_lift [Monad m] [LawfulMonad m] (x : m α) (f : α → ExceptT ε m β) : run (ExceptT.lift x >>= f : ExceptT ε m β) = x >>= fun a => run (f a) := by +@[simp, grind =] theorem run_bind_lift [Monad m] [LawfulMonad m] (x : m α) (f : α → ExceptT ε m β) : run (ExceptT.lift x >>= f : ExceptT ε m β) = x >>= fun a => run (f a) := by simp [ExceptT.run, ExceptT.lift, bind, ExceptT.bind, ExceptT.mk, ExceptT.bindCont] -@[simp] theorem bind_throw [Monad m] [LawfulMonad m] (f : α → ExceptT ε m β) : (throw e >>= f) = throw e := by +@[simp, grind =] theorem bind_throw [Monad m] [LawfulMonad m] (f : α → ExceptT ε m β) : (throw e >>= f) = throw e := by simp [throw, throwThe, MonadExceptOf.throw, bind, ExceptT.bind, ExceptT.bindCont, ExceptT.mk] -theorem run_bind [Monad m] (x : ExceptT ε m α) +@[grind =] +theorem run_bind [Monad m] (x : ExceptT ε m α) (f : α → ExceptT ε m β) : run (x >>= f : ExceptT ε m β) = run x >>= fun @@ -46,10 +47,10 @@ theorem run_bind [Monad m] (x : ExceptT ε m α) | Except.error e => pure (Except.error e) := rfl -@[simp] theorem lift_pure [Monad m] [LawfulMonad m] (a : α) : ExceptT.lift (pure a) = (pure a : ExceptT ε m α) := by +@[simp, grind =] theorem lift_pure [Monad m] [LawfulMonad m] (a : α) : ExceptT.lift (pure a) = (pure a : ExceptT ε m α) := by simp [ExceptT.lift, pure, ExceptT.pure] -@[simp] theorem run_map [Monad m] [LawfulMonad m] (f : α → β) (x : ExceptT ε m α) +@[simp, grind =] theorem run_map [Monad m] [LawfulMonad m] (f : α → β) (x : ExceptT ε m α) : (f <$> x).run = Except.map f <$> x.run := by simp [Functor.map, ExceptT.map, ←bind_pure_comp] apply bind_congr @@ -113,28 +114,28 @@ instance : LawfulFunctor (Except ε) := inferInstance namespace ReaderT -@[ext] theorem ext {x y : ReaderT ρ m α} (h : ∀ ctx, x.run ctx = y.run ctx) : x = y := by +@[ext, grind ext] theorem ext {x y : ReaderT ρ m α} (h : ∀ ctx, x.run ctx = y.run ctx) : x = y := by simp [run] at h exact funext h -@[simp] theorem run_pure [Monad m] (a : α) (ctx : ρ) : (pure a : ReaderT ρ m α).run ctx = pure a := rfl +@[simp, grind =] theorem run_pure [Monad m] (a : α) (ctx : ρ) : (pure a : ReaderT ρ m α).run ctx = pure a := rfl -@[simp] theorem run_bind [Monad m] (x : ReaderT ρ m α) (f : α → ReaderT ρ m β) (ctx : ρ) +@[simp, grind =] theorem run_bind [Monad m] (x : ReaderT ρ m α) (f : α → ReaderT ρ m β) (ctx : ρ) : (x >>= f).run ctx = x.run ctx >>= λ a => (f a).run ctx := rfl -@[simp] theorem run_mapConst [Monad m] (a : α) (x : ReaderT ρ m β) (ctx : ρ) +@[simp, grind =] theorem run_mapConst [Monad m] (a : α) (x : ReaderT ρ m β) (ctx : ρ) : (Functor.mapConst a x).run ctx = Functor.mapConst a (x.run ctx) := rfl -@[simp] theorem run_map [Monad m] (f : α → β) (x : ReaderT ρ m α) (ctx : ρ) +@[simp, grind =] theorem run_map [Monad m] (f : α → β) (x : ReaderT ρ m α) (ctx : ρ) : (f <$> x).run ctx = f <$> x.run ctx := rfl -@[simp] theorem run_monadLift [MonadLiftT n m] (x : n α) (ctx : ρ) +@[simp, grind =] theorem run_monadLift [MonadLiftT n m] (x : n α) (ctx : ρ) : (monadLift x : ReaderT ρ m α).run ctx = (monadLift x : m α) := rfl -@[simp] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} → n β → n β) (x : ReaderT ρ m α) (ctx : ρ) +@[simp, grind =] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} → n β → n β) (x : ReaderT ρ m α) (ctx : ρ) : (monadMap @f x : ReaderT ρ m α).run ctx = monadMap @f (x.run ctx) := rfl -@[simp] theorem run_read [Monad m] (ctx : ρ) : (ReaderT.read : ReaderT ρ m ρ).run ctx = pure ctx := rfl +@[simp, grind =] theorem run_read [Monad m] (ctx : ρ) : (ReaderT.read : ReaderT ρ m ρ).run ctx = pure ctx := rfl @[simp] theorem run_seq {α β : Type u} [Monad m] (f : ReaderT ρ m (α → β)) (x : ReaderT ρ m α) (ctx : ρ) : (f <*> x).run ctx = (f.run ctx <*> x.run ctx) := rfl @@ -175,38 +176,39 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (StateRefT' ω σ m) := namespace StateT -@[ext] theorem ext {x y : StateT σ m α} (h : ∀ s, x.run s = y.run s) : x = y := +@[ext, grind ext] theorem ext {x y : StateT σ m α} (h : ∀ s, x.run s = y.run s) : x = y := funext h -@[simp] theorem run'_eq [Monad m] (x : StateT σ m α) (s : σ) : run' x s = (·.1) <$> run x s := +@[simp, grind =] theorem run'_eq [Monad m] (x : StateT σ m α) (s : σ) : run' x s = (·.1) <$> run x s := rfl -@[simp] theorem run_pure [Monad m] (a : α) (s : σ) : (pure a : StateT σ m α).run s = pure (a, s) := rfl +@[simp, grind =] theorem run_pure [Monad m] (a : α) (s : σ) : (pure a : StateT σ m α).run s = pure (a, s) := rfl -@[simp] theorem run_bind [Monad m] (x : StateT σ m α) (f : α → StateT σ m β) (s : σ) +@[simp, grind =] theorem run_bind [Monad m] (x : StateT σ m α) (f : α → StateT σ m β) (s : σ) : (x >>= f).run s = x.run s >>= λ p => (f p.1).run p.2 := by simp [bind, StateT.bind, run] -@[simp] theorem run_map {α β σ : Type u} [Monad m] [LawfulMonad m] (f : α → β) (x : StateT σ m α) (s : σ) : (f <$> x).run s = (fun (p : α × σ) => (f p.1, p.2)) <$> x.run s := by +@[simp, grind =] theorem run_map {α β σ : Type u} [Monad m] [LawfulMonad m] (f : α → β) (x : StateT σ m α) (s : σ) : (f <$> x).run s = (fun (p : α × σ) => (f p.1, p.2)) <$> x.run s := by simp [Functor.map, StateT.map, run, ←bind_pure_comp] -@[simp] theorem run_get [Monad m] (s : σ) : (get : StateT σ m σ).run s = pure (s, s) := rfl +@[simp, grind =] theorem run_get [Monad m] (s : σ) : (get : StateT σ m σ).run s = pure (s, s) := rfl -@[simp] theorem run_set [Monad m] (s s' : σ) : (set s' : StateT σ m PUnit).run s = pure (⟨⟩, s') := rfl +@[simp, grind =] theorem run_set [Monad m] (s s' : σ) : (set s' : StateT σ m PUnit).run s = pure (⟨⟩, s') := rfl -@[simp] theorem run_modify [Monad m] (f : σ → σ) (s : σ) : (modify f : StateT σ m PUnit).run s = pure (⟨⟩, f s) := rfl +@[simp, grind =] theorem run_modify [Monad m] (f : σ → σ) (s : σ) : (modify f : StateT σ m PUnit).run s = pure (⟨⟩, f s) := rfl -@[simp] theorem run_modifyGet [Monad m] (f : σ → α × σ) (s : σ) : (modifyGet f : StateT σ m α).run s = pure ((f s).1, (f s).2) := by +@[simp, grind =] theorem run_modifyGet [Monad m] (f : σ → α × σ) (s : σ) : (modifyGet f : StateT σ m α).run s = pure ((f s).1, (f s).2) := by simp [modifyGet, MonadStateOf.modifyGet, StateT.modifyGet, run] -@[simp] theorem run_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) : (StateT.lift x : StateT σ m α).run s = x >>= fun a => pure (a, s) := rfl +@[simp, grind =] theorem run_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) : (StateT.lift x : StateT σ m α).run s = x >>= fun a => pure (a, s) := rfl +@[grind =] theorem run_bind_lift {α σ : Type u} [Monad m] [LawfulMonad m] (x : m α) (f : α → StateT σ m β) (s : σ) : (StateT.lift x >>= f).run s = x >>= fun a => (f a).run s := by simp [StateT.lift, StateT.run, bind, StateT.bind] -@[simp] theorem run_monadLift {α σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) : (monadLift x : StateT σ m α).run s = (monadLift x : m α) >>= fun a => pure (a, s) := rfl +@[simp, grind =] theorem run_monadLift {α σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) : (monadLift x : StateT σ m α).run s = (monadLift x : m α) >>= fun a => pure (a, s) := rfl -@[simp] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} → n β → n β) (x : StateT σ m α) (s : σ) : +@[simp, grind =] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} → n β → n β) (x : StateT σ m α) (s : σ) : (monadMap @f x : StateT σ m α).run s = monadMap @f (x.run s) := rfl @[simp] theorem run_seq {α β σ : Type u} [Monad m] [LawfulMonad m] (f : StateT σ m (α → β)) (x : StateT σ m α) (s : σ) : (f <*> x).run s = (f.run s >>= fun fs => (fun (p : α × σ) => (fs.1 p.1, p.2)) <$> x.run fs.2) := by