From ae48feeb0729dd53b26345d695cb98568552771f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 21 Feb 2021 08:27:59 -0800 Subject: [PATCH] feat: add `LawfulMonad` for `ReaderT` --- src/Init/Control/Lawful.lean | 51 ++++++++++++++++++++++++++++++++++++ src/Init/Prelude.lean | 6 ++--- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/Init/Control/Lawful.lean b/src/Init/Control/Lawful.lean index a0ec869a98..4d78ef693b 100644 --- a/src/Init/Control/Lawful.lean +++ b/src/Init/Control/Lawful.lean @@ -67,6 +67,19 @@ theorem bind_congr [Bind m] {x : m α} {f g : α → m β} (h : ∀ a, f a = g a theorem map_congr [Functor m] {x : m α} {f g : α → β} (h : ∀ a, f a = g a) : (f <$> x : m β) = g <$> x := by simp [funext h] +theorem seq_eq_bind {α β : Type u} [Monad m] [LawfulMonad m] (mf : m (α → β)) (x : m α) : mf <*> x = mf >>= fun f => f <$> x := by + rw [bind_map] + +theorem seqRight_eq_bind [Monad m] [LawfulMonad m] (x : m α) (y : m β) : x *> y = x >>= fun _ => y := by + rw [seqRight_eq, ← bind_map, ← bind_pure_comp] + simp [Function.const] + +theorem seqLeft_eq_bind [Monad m] [LawfulMonad m] (x : m α) (y : m β) : x <* y = x >>= fun a => y >>= fun _ => pure a := by + rw [seqLeft_eq, ← bind_map, ← bind_pure_comp] + simp + apply bind_congr; intro + rw [← bind_pure_comp] + /- Id -/ namespace Id @@ -148,3 +161,41 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (ExceptT ε m) where bind_assoc := by intros; apply ext; simp; apply bind_congr; intro a; cases a <;> simp end ExceptT + +/- ReaderT -/ + +namespace ReaderT + +theorem ext [Monad m] {x y : ReaderT ρ m α} (h : ∀ ctx, x.run ctx = y.run ctx) : x = y := by + simp [run] at h + exact funext h + +@[simp] theorem run_pure [Monad m] (a : α) (ctx : ρ) : (pure a : ReaderT ρ m α).run ctx = pure a := rfl +@[simp] theorem run_bind [Monad m] (x : ReaderT ρ m α) (f : α → ReaderT ρ m β) (ctx : ρ) + : (x >>= f).run ctx = x.run ctx >>= λ a => (f a).run ctx := rfl +@[simp] theorem run_map [Monad m] (f : α → β) (x : ReaderT ρ m α) (ctx : ρ) + : (f <$> x).run ctx = f <$> x.run ctx := rfl +@[simp] theorem run_monad_lift [MonadLiftT n m] (x : n α) (ctx : ρ) + : (monadLift x : ReaderT ρ m α).run ctx = (monadLift x : m α) := rfl +@[simp] theorem run_monad_map [Monad m] [MonadFunctor n m] (f : {β : Type u} → n β → n β) (x : ReaderT ρ m α) (ctx : ρ) + : (monadMap @f x : ReaderT ρ m α).run ctx = monadMap @f (x.run ctx) := rfl +@[simp] theorem run_read [Monad m] (ctx : ρ) : (ReaderT.read : ReaderT ρ m ρ).run ctx = pure ctx := rfl +@[simp] theorem run_seq {α β : Type u} [Monad m] [LawfulMonad m] (f : ReaderT ρ m (α → β)) (x : ReaderT ρ m α) (ctx : ρ) : (f <*> x).run ctx = (f.run ctx <*> x.run ctx) := by + rw [seq_eq_bind (m := m)]; rfl +@[simp] theorem run_seqRight [Monad m] [LawfulMonad m] (x : ReaderT ρ m α) (y : ReaderT ρ m β) (ctx : ρ) : (x *> y).run ctx = (x.run ctx *> y.run ctx) := by + rw [seqRight_eq_bind (m := m)]; rfl +@[simp] theorem run_seqLeft [Monad m] [LawfulMonad m] (x : ReaderT ρ m α) (y : ReaderT ρ m β) (ctx : ρ) : (x <* y).run ctx = (x.run ctx <* y.run ctx) := by + rw [seqLeft_eq_bind (m := m)]; rfl + +instance [Monad m] [LawfulMonad m] : LawfulMonad (ReaderT ρ m) where + id_map := by intros; apply ext; intros; simp + map_const := by intros; rfl + seqLeft_eq := by intros; apply ext; intros; simp; apply LawfulApplicative.seqLeft_eq + seqRight_eq := by intros; apply ext; intros; simp; apply LawfulApplicative.seqRight_eq + pure_seq := by intros; apply ext; intros; simp; apply LawfulApplicative.pure_seq + bind_pure_comp := by intros; apply ext; intros; simp; apply LawfulMonad.bind_pure_comp + bind_map := by intros; rfl + pure_bind := by intros; apply ext; intros; simp + bind_assoc := by intros; apply ext; intros; simp + +end ReaderT diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index bd26ea7e11..ceb94899cc 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -1183,12 +1183,12 @@ instance (m) : MonadLiftT m m where but not restricted to monad transformers. Alternatively, an implementation of [MonadTransFunctor](http://duairc.netsoc.ie/layers-docs/Control-Monad-Layer.html#t:MonadTransFunctor). -/ class MonadFunctor (m : Type u → Type v) (n : Type u → Type w) where - monadMap {α : Type u} : (∀ {β}, m β → m β) → n α → n α + monadMap {α : Type u} : ({β : Type u} → m β → m β) → n α → n α /-- The reflexive-transitive closure of `MonadFunctor`. `monadMap` is used to transitively lift Monad morphisms -/ class MonadFunctorT (m : Type u → Type v) (n : Type u → Type w) where - monadMap {α : Type u} : (∀ {β}, m β → m β) → n α → n α + monadMap {α : Type u} : ({β : Type u} → m β → m β) → n α → n α export MonadFunctorT (monadMap) @@ -1302,7 +1302,7 @@ end ReaderT Note: This class can be seen as a simplification of the more "principled" definition ``` class MonadReader (ρ : outParam (Type u)) (n : Type u → Type u) where - lift {α : Type u} : (∀ {m : Type u → Type u} [Monad m], ReaderT ρ m α) → n α + lift {α : Type u} : ({m : Type u → Type u} → [Monad m] → ReaderT ρ m α) → n α ``` -/ class MonadReaderOf (ρ : Type u) (m : Type u → Type v) where