refactor: StateCpsT

Define `run` and `run'` using `runK`.
Write auxiliary simp lemmas using `runK`.
This commit is contained in:
Leonardo de Moura 2021-03-01 15:24:26 -08:00
parent bbf158005f
commit 81ba5485dd

View file

@ -14,15 +14,15 @@ def StateCpsT (σ : Type u) (m : Type u → Type v) (α : Type u) := (δ : Type
namespace StateCpsT
@[inline] def run {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m (α × σ) :=
x _ s fun a s => pure (a, s)
@[inline] def run' {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m α :=
(·.1) <$> run x s
@[inline] def runK {α σ : Type u} {m : Type u → Type v} (x : StateCpsT σ m α) (s : σ) (k : ασ → m β) : m β :=
x _ s k
@[inline] def run {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m (α × σ) :=
runK x s (fun a s => pure (a, s))
@[inline] def run' {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m α :=
runK x s (fun a s => pure a)
instance : Monad (StateCpsT σ m) where
map f x := fun δ s k => x δ s fun a s => k (f a) s
pure a := fun δ s k => k a s
@ -42,28 +42,32 @@ instance : MonadStateOf σ (StateCpsT σ m) where
instance [Monad m] : MonadLift m (StateCpsT σ m) where
monadLift := StateCpsT.lift
@[simp] theorem run'_eq [Monad m] (x : StateCpsT σ m α) (s : σ) : x.run' s = (·.1) <$> run x s := rfl
@[simp] theorem runK_pure {m : Type u → Type v} (a : α) (s : σ) (k : ασ → m β) : (pure a : StateCpsT σ m α).runK s k = k a s := rfl
@[simp] theorem run_get [Monad m] (s : σ) : (get : StateCpsT σ m σ).run s = pure (s, s) := rfl
@[simp] theorem runK_get {m : Type u → Type v} (s : σ) (k : σσ → m β) : (get : StateCpsT σ m σ).runK s k = k s s := rfl
@[simp] theorem run_set [Monad m] (s s' : σ) : (set s' : StateCpsT σ m PUnit).run s = pure (⟨⟩, s') := rfl
@[simp] theorem runK_set {m : Type u → Type v} (s s' : σ) (k : PUnit → σ → m β) : (set s' : StateCpsT σ m PUnit).runK s k = k ⟨⟩ s' := rfl
@[simp] theorem run_modify [Monad m] (f : σσ) (s : σ) : (modify f : StateCpsT σ m PUnit).run s = pure (⟨⟩, f s) := rfl
@[simp] theorem runK_modify {m : Type u → Type v} (f : σσ) (s : σ) (k : PUnit → σ → m β) : (modify f : StateCpsT σ m PUnit).runK s k = k ⟨⟩ (f s) := rfl
@[simp] theorem run_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) : (StateCpsT.lift x : StateCpsT σ m α).run s = x >>= fun a => pure (a, s) := rfl
@[simp] theorem runK_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) (k : ασ → m β) : (StateCpsT.lift x : StateCpsT σ m α).runK s k = x >>= (k . s) := rfl
@[simp] theorem run_bind_pure {α σ : Type u} [Monad m] (a : α) (f : α → StateCpsT σ m β) (s : σ) : (pure a >>= f).run s = (f a).run s := rfl
@[simp] theorem runK_monadLift {σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) (k : ασ → m β)
: (monadLift x : StateCpsT σ m α).runK s k = (monadLift x : m α) >>= (k . s) := rfl
@[simp] theorem run_bind_lift {α σ : Type u} [Monad m] (x : m α) (f : α → StateCpsT σ m β) (s : σ) : (StateCpsT.lift x >>= f).run s = x >>= fun a => (f a).run s := rfl
@[simp] theorem runK_bind_pure {α σ : Type u} [Monad m] (a : α) (f : α → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) : (pure a >>= f).runK s k = (f a).runK s k := rfl
@[simp] theorem run_bind_get {σ : Type u} [Monad m] (f : σ → StateCpsT σ m β) (s : σ) : (get >>= f).run s = (f s).run s := rfl
@[simp] theorem runK_bind_lift {α σ : Type u} [Monad m] (x : m α) (f : α → StateCpsT σ m β) (s : σ) (k : β → σ → m γ)
: (StateCpsT.lift x >>= f).runK s k = x >>= fun a => (f a).runK s k := rfl
@[simp] theorem run_bind_set {σ : Type u} [Monad m] (f : PUnit → StateCpsT σ m β) (s s' : σ) : (set s' >>= f).run s = (f ⟨⟩).run s' := rfl
@[simp] theorem runK_bind_get {σ : Type u} [Monad m] (f : σ → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) : (get >>= f).runK s k = (f s).runK s k := rfl
@[simp] theorem run_bind_modify {σ : Type u} [Monad m] (f : σσ) (g : PUnit → StateCpsT σ m β) (s : σ) : (modify f >>= g).run s = (g ⟨⟩).run (f s) := rfl
@[simp] theorem runK_bind_set {σ : Type u} [Monad m] (f : PUnit → StateCpsT σ m β) (s s' : σ) (k : β → σ → m γ) : (set s' >>= f).runK s k = (f ⟨⟩).runK s' k := rfl
@[simp] theorem run_monadLift {σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) : (monadLift x : StateCpsT σ m α).run s = (monadLift x : m α) >>= fun a => pure (a, s) := rfl
@[simp] theorem runK_bind_modify {σ : Type u} [Monad m] (f : σσ) (g : PUnit → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) : (modify f >>= g).runK s k = (g ⟨⟩).runK (f s) k := rfl
@[simp] theorem run_pure [Monad m] (a : α) (s : σ) : (pure a : StateCpsT σ m α).run s = pure (a, s) := rfl
@[simp] theorem run_eq [Monad m] (x : StateCpsT σ m α) (s : σ) : x.run s = x.runK s (fun a s => pure (a, s)) := rfl
@[simp] theorem run'_eq [Monad m] (x : StateCpsT σ m α) (s : σ) : x.run' s = x.runK s (fun a s => pure a) := rfl
end StateCpsT