From d0574d8eb1cdac06f3f73e4fd63a6552da850fcf Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 21 Feb 2021 10:52:53 -0800 Subject: [PATCH] feat: add `LawfulMonad` for `StateT` --- src/Init/Control/Lawful.lean | 80 +++++++++++++++++++++++++++++++++++- src/Init/Core.lean | 9 ++-- src/Init/Prelude.lean | 2 + 3 files changed, 84 insertions(+), 7 deletions(-) diff --git a/src/Init/Control/Lawful.lean b/src/Init/Control/Lawful.lean index 4d78ef693b..313c9d785c 100644 --- a/src/Init/Control/Lawful.lean +++ b/src/Init/Control/Lawful.lean @@ -6,6 +6,7 @@ Authors: Sebastian Ullrich, Leonardo de Moura prelude import Init.SimpLemmas import Init.Control.Except +import Init.Control.StateRef open Function @@ -18,6 +19,9 @@ export LawfulFunctor (map_const id_map comp_map) attribute [simp] id_map +@[simp] theorem id_map' [Functor m] [LawfulFunctor m] (x : m α) : (fun a => a) <$> x = x := + id_map x + class LawfulApplicative (f : Type u → Type v) [Applicative f] extends LawfulFunctor f : Prop where seqLeft_eq (x : f α) (y : f β) : x <* y = const β <$> x <*> y seqRight_eq (x : f α) (y : f β) : x *> y = const α id <$> x <*> y @@ -175,9 +179,9 @@ theorem ext [Monad m] {x y : ReaderT ρ m α} (h : ∀ ctx, x.run ctx = y.run ct : (x >>= f).run ctx = x.run ctx >>= λ a => (f a).run ctx := rfl @[simp] theorem run_map [Monad m] (f : α → β) (x : ReaderT ρ m α) (ctx : ρ) : (f <$> x).run ctx = f <$> x.run ctx := rfl -@[simp] theorem run_monad_lift [MonadLiftT n m] (x : n α) (ctx : ρ) +@[simp] theorem run_monadLift [MonadLiftT n m] (x : n α) (ctx : ρ) : (monadLift x : ReaderT ρ m α).run ctx = (monadLift x : m α) := rfl -@[simp] theorem run_monad_map [Monad m] [MonadFunctor n m] (f : {β : Type u} → n β → n β) (x : ReaderT ρ m α) (ctx : ρ) +@[simp] theorem run_monadMap [Monad m] [MonadFunctor n m] (f : {β : Type u} → n β → n β) (x : ReaderT ρ m α) (ctx : ρ) : (monadMap @f x : ReaderT ρ m α).run ctx = monadMap @f (x.run ctx) := rfl @[simp] theorem run_read [Monad m] (ctx : ρ) : (ReaderT.read : ReaderT ρ m ρ).run ctx = pure ctx := rfl @[simp] theorem run_seq {α β : Type u} [Monad m] [LawfulMonad m] (f : ReaderT ρ m (α → β)) (x : ReaderT ρ m α) (ctx : ρ) : (f <*> x).run ctx = (f.run ctx <*> x.run ctx) := by @@ -199,3 +203,75 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (ReaderT ρ m) where bind_assoc := by intros; apply ext; intros; simp end ReaderT + +/- StateRefT -/ + +instance [Monad m] [LawfulMonad m] : LawfulMonad (StateRefT' ω σ m) := + inferInstanceAs (LawfulMonad (ReaderT (ST.Ref ω σ) m)) + +/- StateT -/ + +namespace StateT + +theorem ext {x y : StateT σ m α} (h : ∀ s, x.run s = y.run s) : x = y := + funext h + +@[simp] theorem run_pure [Monad m] (a : α) (s : σ) : (pure a : StateT σ m α).run s = pure (a, s) := rfl + +@[simp] theorem run_bind [Monad m] (x : StateT σ m α) (f : α → StateT σ m β) (s : σ) + : (x >>= f).run s = x.run s >>= λ p => (f p.1).run p.2 := by + simp [bind, StateT.bind, run] + apply bind_congr + intro p; cases p; rfl + +@[simp] theorem run_map {α β σ : Type u} [Monad m] [LawfulMonad m] (f : α → β) (x : StateT σ m α) (s : σ) : (f <$> x).run s = (fun (p : α × σ) => (f p.1, p.2)) <$> x.run s := by + simp [Functor.map, StateT.map, run] + rw [← bind_pure_comp] + apply bind_congr + intro p; cases p; rfl + +@[simp] theorem run_get [Monad m] (s : σ) : (get : StateT σ m σ).run s = pure (s, s) := rfl + +@[simp] theorem run_set [Monad m] (s s' : σ) : (set s' : StateT σ m PUnit).run s = pure (⟨⟩, s') := rfl + +@[simp] theorem run_monadLift [Monad m] [MonadLiftT n m] (x : n α) (s : σ) : (monadLift x : StateT σ m α).run s = (monadLift x : m α) >>= fun a => pure (a, s) := rfl + +@[simp] theorem run_monadMap [Monad m] [MonadFunctor n m] (f : {β : Type u} → n β → n β) (x : StateT σ m α) (s : σ) + : (monadMap @f x : StateT σ m α).run s = monadMap @f (x.run s) := rfl + +@[simp] theorem run_seq {α β σ : Type u} [Monad m] [LawfulMonad m] (f : StateT σ m (α → β)) (x : StateT σ m α) (s : σ) : (f <*> x).run s = (f.run s >>= fun fs => (fun (p : α × σ) => (fs.1 p.1, p.2)) <$> x.run fs.2) := by + show (f >>= fun g => g <$> x).run s = _ + simp + +@[simp] theorem run_seqRight [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) (s : σ) : (x *> y).run s = (x.run s >>= fun p => y.run p.2) := by + show (x >>= fun _ => y).run s = _ + simp + +@[simp] theorem run_seqLeft {α β σ : Type u} [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) (s : σ) : (x <* y).run s = (x.run s >>= fun p => y.run p.2 >>= fun p' => pure (p.1, p'.2)) := by + show (x >>= fun a => y >>= fun _ => pure a).run s = _ + simp + +theorem seqRight_eq [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) : x *> y = const α id <$> x <*> y := by + apply ext; intro s + simp; rw [← bind_pure_comp]; simp + apply bind_congr; intro p; cases p + simp[Prod.ext] + +theorem seqLeft_eq [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) : x <* y = const β <$> x <*> y := by + apply ext; intro s + simp; rw [← bind_pure_comp]; simp + apply bind_congr; intro p; cases p + simp[Prod.ext, const]; rw [← bind_pure_comp] + +instance [Monad m] [LawfulMonad m] : LawfulMonad (StateT σ m) where + id_map := by intros; apply ext; intros; simp[Prod.ext] + map_const := by intros; rfl + seqLeft_eq := seqLeft_eq + seqRight_eq := seqRight_eq + pure_seq := by intros; apply ext; intros; simp + bind_pure_comp := by intros; apply ext; intros; simp; apply LawfulMonad.bind_pure_comp + bind_map := by intros; rfl + pure_bind := by intros; apply ext; intros; simp + bind_assoc := by intros; apply ext; intros; simp + +end StateT diff --git a/src/Init/Core.lean b/src/Init/Core.lean index de20f01541..8eae03ab07 100644 --- a/src/Init/Core.lean +++ b/src/Init/Core.lean @@ -556,9 +556,6 @@ end /- Product -/ -section -variable {α : Type u} {β : Type v} - instance [Inhabited α] [Inhabited β] : Inhabited (α × β) where default := (arbitrary, arbitrary) @@ -585,9 +582,11 @@ instance prodHasDecidableLt theorem Prod.ltDef [HasLess α] [HasLess β] (s t : α × β) : (s < t) = (s.1 < t.1 ∨ (s.1 = t.1 ∧ s.2 < t.2)) := rfl -end -def Prod.map.{u₁, u₂, v₁, v₂} {α₁ : Type u₁} {α₂ : Type u₂} {β₁ : Type v₁} {β₂ : Type v₂} +theorem Prod.ext (p : α × β) : (p.1, p.2) = p := by + cases p; rfl + +def Prod.map {α₁ : Type u₁} {α₂ : Type u₂} {β₁ : Type v₁} {β₂ : Type v₂} (f : α₁ → α₂) (g : β₁ → β₂) : α₁ × β₁ → α₂ × β₂ | (a, b) => (f a, g b) diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index 8a93546630..b04bd24b84 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -61,6 +61,8 @@ abbrev Eq.ndrec.{u1, u2} {α : Sort u2} {a : α} {motive : α → Sort u1} (m : @[matchPattern] def rfl {α : Sort u} {a : α} : Eq a a := Eq.refl a +@[simp] theorem id_eq (a : α) : Eq (id a) a := rfl + theorem Eq.subst {α : Sort u} {motive : α → Prop} {a b : α} (h₁ : Eq a b) (h₂ : motive a) : motive b := Eq.ndrec h₂ h₁