feat: add LawfulMonad for StateT

This commit is contained in:
Leonardo de Moura 2021-02-21 10:52:53 -08:00
parent b1faed895a
commit d0574d8eb1
3 changed files with 84 additions and 7 deletions

View file

@ -6,6 +6,7 @@ Authors: Sebastian Ullrich, Leonardo de Moura
prelude
import Init.SimpLemmas
import Init.Control.Except
import Init.Control.StateRef
open Function
@ -18,6 +19,9 @@ export LawfulFunctor (map_const id_map comp_map)
attribute [simp] id_map
@[simp] theorem id_map' [Functor m] [LawfulFunctor m] (x : m α) : (fun a => a) <$> x = x :=
id_map x
class LawfulApplicative (f : Type u → Type v) [Applicative f] extends LawfulFunctor f : Prop where
seqLeft_eq (x : f α) (y : f β) : x <* y = const β <$> x <*> y
seqRight_eq (x : f α) (y : f β) : x *> y = const α id <$> x <*> y
@ -175,9 +179,9 @@ theorem ext [Monad m] {x y : ReaderT ρ m α} (h : ∀ ctx, x.run ctx = y.run ct
: (x >>= f).run ctx = x.run ctx >>= λ a => (f a).run ctx := rfl
@[simp] theorem run_map [Monad m] (f : α → β) (x : ReaderT ρ m α) (ctx : ρ)
: (f <$> x).run ctx = f <$> x.run ctx := rfl
@[simp] theorem run_monad_lift [MonadLiftT n m] (x : n α) (ctx : ρ)
@[simp] theorem run_monadLift [MonadLiftT n m] (x : n α) (ctx : ρ)
: (monadLift x : ReaderT ρ m α).run ctx = (monadLift x : m α) := rfl
@[simp] theorem run_monad_map [Monad m] [MonadFunctor n m] (f : {β : Type u} → n β → n β) (x : ReaderT ρ m α) (ctx : ρ)
@[simp] theorem run_monadMap [Monad m] [MonadFunctor n m] (f : {β : Type u} → n β → n β) (x : ReaderT ρ m α) (ctx : ρ)
: (monadMap @f x : ReaderT ρ m α).run ctx = monadMap @f (x.run ctx) := rfl
@[simp] theorem run_read [Monad m] (ctx : ρ) : (ReaderT.read : ReaderT ρ m ρ).run ctx = pure ctx := rfl
@[simp] theorem run_seq {α β : Type u} [Monad m] [LawfulMonad m] (f : ReaderT ρ m (α → β)) (x : ReaderT ρ m α) (ctx : ρ) : (f <*> x).run ctx = (f.run ctx <*> x.run ctx) := by
@ -199,3 +203,75 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (ReaderT ρ m) where
bind_assoc := by intros; apply ext; intros; simp
end ReaderT
/- StateRefT -/
instance [Monad m] [LawfulMonad m] : LawfulMonad (StateRefT' ω σ m) :=
inferInstanceAs (LawfulMonad (ReaderT (ST.Ref ω σ) m))
/- StateT -/
namespace StateT
theorem ext {x y : StateT σ m α} (h : ∀ s, x.run s = y.run s) : x = y :=
funext h
@[simp] theorem run_pure [Monad m] (a : α) (s : σ) : (pure a : StateT σ m α).run s = pure (a, s) := rfl
@[simp] theorem run_bind [Monad m] (x : StateT σ m α) (f : α → StateT σ m β) (s : σ)
: (x >>= f).run s = x.run s >>= λ p => (f p.1).run p.2 := by
simp [bind, StateT.bind, run]
apply bind_congr
intro p; cases p; rfl
@[simp] theorem run_map {α β σ : Type u} [Monad m] [LawfulMonad m] (f : α → β) (x : StateT σ m α) (s : σ) : (f <$> x).run s = (fun (p : α × σ) => (f p.1, p.2)) <$> x.run s := by
simp [Functor.map, StateT.map, run]
rw [← bind_pure_comp]
apply bind_congr
intro p; cases p; rfl
@[simp] theorem run_get [Monad m] (s : σ) : (get : StateT σ m σ).run s = pure (s, s) := rfl
@[simp] theorem run_set [Monad m] (s s' : σ) : (set s' : StateT σ m PUnit).run s = pure (⟨⟩, s') := rfl
@[simp] theorem run_monadLift [Monad m] [MonadLiftT n m] (x : n α) (s : σ) : (monadLift x : StateT σ m α).run s = (monadLift x : m α) >>= fun a => pure (a, s) := rfl
@[simp] theorem run_monadMap [Monad m] [MonadFunctor n m] (f : {β : Type u} → n β → n β) (x : StateT σ m α) (s : σ)
: (monadMap @f x : StateT σ m α).run s = monadMap @f (x.run s) := rfl
@[simp] theorem run_seq {α β σ : Type u} [Monad m] [LawfulMonad m] (f : StateT σ m (α → β)) (x : StateT σ m α) (s : σ) : (f <*> x).run s = (f.run s >>= fun fs => (fun (p : α × σ) => (fs.1 p.1, p.2)) <$> x.run fs.2) := by
show (f >>= fun g => g <$> x).run s = _
simp
@[simp] theorem run_seqRight [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) (s : σ) : (x *> y).run s = (x.run s >>= fun p => y.run p.2) := by
show (x >>= fun _ => y).run s = _
simp
@[simp] theorem run_seqLeft {α β σ : Type u} [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) (s : σ) : (x <* y).run s = (x.run s >>= fun p => y.run p.2 >>= fun p' => pure (p.1, p'.2)) := by
show (x >>= fun a => y >>= fun _ => pure a).run s = _
simp
theorem seqRight_eq [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) : x *> y = const α id <$> x <*> y := by
apply ext; intro s
simp; rw [← bind_pure_comp]; simp
apply bind_congr; intro p; cases p
simp[Prod.ext]
theorem seqLeft_eq [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) : x <* y = const β <$> x <*> y := by
apply ext; intro s
simp; rw [← bind_pure_comp]; simp
apply bind_congr; intro p; cases p
simp[Prod.ext, const]; rw [← bind_pure_comp]
instance [Monad m] [LawfulMonad m] : LawfulMonad (StateT σ m) where
id_map := by intros; apply ext; intros; simp[Prod.ext]
map_const := by intros; rfl
seqLeft_eq := seqLeft_eq
seqRight_eq := seqRight_eq
pure_seq := by intros; apply ext; intros; simp
bind_pure_comp := by intros; apply ext; intros; simp; apply LawfulMonad.bind_pure_comp
bind_map := by intros; rfl
pure_bind := by intros; apply ext; intros; simp
bind_assoc := by intros; apply ext; intros; simp
end StateT

View file

@ -556,9 +556,6 @@ end
/- Product -/
section
variable {α : Type u} {β : Type v}
instance [Inhabited α] [Inhabited β] : Inhabited (α × β) where
default := (arbitrary, arbitrary)
@ -585,9 +582,11 @@ instance prodHasDecidableLt
theorem Prod.ltDef [HasLess α] [HasLess β] (s t : α × β) : (s < t) = (s.1 < t.1 (s.1 = t.1 ∧ s.2 < t.2)) :=
rfl
end
def Prod.map.{u₁, u₂, v₁, v₂} {α₁ : Type u₁} {α₂ : Type u₂} {β₁ : Type v₁} {β₂ : Type v₂}
theorem Prod.ext (p : α × β) : (p.1, p.2) = p := by
cases p; rfl
def Prod.map {α₁ : Type u₁} {α₂ : Type u₂} {β₁ : Type v₁} {β₂ : Type v₂}
(f : α₁ → α₂) (g : β₁ → β₂) : α₁ × β₁ → α₂ × β₂
| (a, b) => (f a, g b)

View file

@ -61,6 +61,8 @@ abbrev Eq.ndrec.{u1, u2} {α : Sort u2} {a : α} {motive : α → Sort u1} (m :
@[matchPattern] def rfl {α : Sort u} {a : α} : Eq a a := Eq.refl a
@[simp] theorem id_eq (a : α) : Eq (id a) a := rfl
theorem Eq.subst {α : Sort u} {motive : α → Prop} {a b : α} (h₁ : Eq a b) (h₂ : motive a) : motive b :=
Eq.ndrec h₂ h₁