lean4-htt/src/Init/Control/StateCps.lean
Leonardo de Moura 81ba5485dd refactor: StateCpsT
Define `run` and `run'` using `runK`.
Write auxiliary simp lemmas using `runK`.
2021-03-02 06:22:22 -08:00

73 lines
3.6 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Control.Lawful
/--
The state monad transformer in continuation-passing style (CPS).

A computation `x : StateCpsT σ m α` is polymorphic in the final answer type `δ`:
given an initial state of type `σ` and a continuation `α → σ → m δ` (which receives
the computed value and the resulting state), it produces an `m δ`.
-/
def StateCpsT (σ : Type u) (m : Type u → Type v) (α : Type u) :=
  (δ : Type u) → σ → (α → σ → m δ) → m δ
namespace StateCpsT

/--
Run `x` starting from state `s`, handing the computed value and the final state
to the continuation `k`.
-/
@[inline] def runK {α σ : Type u} {m : Type u → Type v} (x : StateCpsT σ m α) (s : σ) (k : α → σ → m β) : m β :=
  x _ s k

/-- Run `x` starting from state `s`, returning the value paired with the final state. -/
@[inline] def run {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m (α × σ) :=
  runK x s (fun a s => pure (a, s))

/-- Run `x` starting from state `s`, returning only the value (the final state is discarded). -/
@[inline] def run' {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m α :=
  runK x s (fun a _ => pure a)

/-
Every operation receives the answer type `δ`, the current state `s`, and the
continuation `k`. This instance needs no `Monad m` assumption: `map`, `pure`, and
`bind` only rearrange continuations.
-/
instance : Monad (StateCpsT σ m) where
  map f x  := fun δ s k => x δ s fun a s => k (f a) s
  pure a   := fun δ s k => k a s
  bind x f := fun δ s k => x δ s fun a s => f a δ s k

/-- The monad laws hold by `rfl` for this CPS encoding. -/
instance : LawfulMonad (StateCpsT σ m) := by
  refine! { .. } <;> intros <;> rfl

/-
`get` passes the current state as both value and new state; `set s` replaces the
state and yields `⟨⟩`; `modifyGet f` splits `f s` into the value and the new state.
-/
instance : MonadStateOf σ (StateCpsT σ m) where
  get         := fun _ s k => k s s
  set s       := fun _ _ k => k ⟨⟩ s
  modifyGet f := fun _ s k => let (a, s) := f s; k a s

/-- Lift an action of `m`: run `x`, then pass its result and the unchanged state to the continuation. -/
@[inline] protected def lift [Monad m] (x : m α) : StateCpsT σ m α :=
  fun _ s k => x >>= (k . s)

instance [Monad m] : MonadLift m (StateCpsT σ m) where
  monadLift := StateCpsT.lift

/- Simp lemmas describing `runK` on the primitive operations; each holds by `rfl`. -/

@[simp] theorem runK_pure {m : Type u → Type v} (a : α) (s : σ) (k : α → σ → m β) :
    (pure a : StateCpsT σ m α).runK s k = k a s := rfl

@[simp] theorem runK_get {m : Type u → Type v} (s : σ) (k : σ → σ → m β) :
    (get : StateCpsT σ m σ).runK s k = k s s := rfl

@[simp] theorem runK_set {m : Type u → Type v} (s s' : σ) (k : PUnit → σ → m β) :
    (set s' : StateCpsT σ m PUnit).runK s k = k ⟨⟩ s' := rfl

@[simp] theorem runK_modify {m : Type u → Type v} (f : σ → σ) (s : σ) (k : PUnit → σ → m β) :
    (modify f : StateCpsT σ m PUnit).runK s k = k ⟨⟩ (f s) := rfl

@[simp] theorem runK_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) (k : α → σ → m β) :
    (StateCpsT.lift x : StateCpsT σ m α).runK s k = x >>= (k . s) := rfl

@[simp] theorem runK_monadLift {σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) (k : α → σ → m β) :
    (monadLift x : StateCpsT σ m α).runK s k = (monadLift x : m α) >>= (k . s) := rfl

@[simp] theorem runK_bind_pure {α σ : Type u} [Monad m] (a : α) (f : α → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) :
    (pure a >>= f).runK s k = (f a).runK s k := rfl

@[simp] theorem runK_bind_lift {α σ : Type u} [Monad m] (x : m α) (f : α → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) :
    (StateCpsT.lift x >>= f).runK s k = x >>= fun a => (f a).runK s k := rfl

@[simp] theorem runK_bind_get {σ : Type u} [Monad m] (f : σ → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) :
    (get >>= f).runK s k = (f s).runK s k := rfl

@[simp] theorem runK_bind_set {σ : Type u} [Monad m] (f : PUnit → StateCpsT σ m β) (s s' : σ) (k : β → σ → m γ) :
    (set s' >>= f).runK s k = (f ⟨⟩).runK s' k := rfl

@[simp] theorem runK_bind_modify {σ : Type u} [Monad m] (f : σ → σ) (g : PUnit → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) :
    (modify f >>= g).runK s k = (g ⟨⟩).runK (f s) k := rfl

/- `run` and `run'` unfold to `runK` with a continuation that returns via `pure`. -/

@[simp] theorem run_eq [Monad m] (x : StateCpsT σ m α) (s : σ) :
    x.run s = x.runK s (fun a s => pure (a, s)) := rfl

@[simp] theorem run'_eq [Monad m] (x : StateCpsT σ m α) (s : σ) :
    x.run' s = x.runK s (fun a _ => pure a) := rfl

end StateCpsT