From 81ba5485ddb198c9688957c304dc46a1e9ec69ff Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 1 Mar 2021 15:24:26 -0800 Subject: [PATCH] refactor: `StateCpsT` Define `run` and `run'` using `runK`. Write auxiliary simp lemmas using `runK`. --- src/Init/Control/StateCps.lean | 40 +++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/Init/Control/StateCps.lean b/src/Init/Control/StateCps.lean index bea0c2cc5a..6b00d4030e 100644 --- a/src/Init/Control/StateCps.lean +++ b/src/Init/Control/StateCps.lean @@ -14,15 +14,15 @@ def StateCpsT (σ : Type u) (m : Type u → Type v) (α : Type u) := (δ : Type namespace StateCpsT -@[inline] def run {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m (α × σ) := - x _ s fun a s => pure (a, s) - -@[inline] def run' {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m α := - (·.1) <$> run x s - @[inline] def runK {α σ : Type u} {m : Type u → Type v} (x : StateCpsT σ m α) (s : σ) (k : α → σ → m β) : m β := x _ s k +@[inline] def run {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m (α × σ) := + runK x s (fun a s => pure (a, s)) + +@[inline] def run' {α σ : Type u} {m : Type u → Type v} [Monad m] (x : StateCpsT σ m α) (s : σ) : m α := + runK x s (fun a s => pure a) + instance : Monad (StateCpsT σ m) where map f x := fun δ s k => x δ s fun a s => k (f a) s pure a := fun δ s k => k a s @@ -42,28 +42,32 @@ instance : MonadStateOf σ (StateCpsT σ m) where instance [Monad m] : MonadLift m (StateCpsT σ m) where monadLift := StateCpsT.lift -@[simp] theorem run'_eq [Monad m] (x : StateCpsT σ m α) (s : σ) : x.run' s = (·.1) <$> run x s := rfl +@[simp] theorem runK_pure {m : Type u → Type v} (a : α) (s : σ) (k : α → σ → m β) : (pure a : StateCpsT σ m α).runK s k = k a s := rfl -@[simp] theorem run_get [Monad m] (s : σ) : (get : StateCpsT σ m σ).run s = pure (s, s) := rfl +@[simp] theorem runK_get {m : Type u → Type v} (s : σ) (k : σ → σ → m β) : (get : StateCpsT σ m σ).runK s k = k s s := rfl -@[simp] theorem run_set [Monad m] (s s' : σ) : (set s' : StateCpsT σ m PUnit).run s = pure (⟨⟩, s') := rfl +@[simp] theorem runK_set {m : Type u → Type v} (s s' : σ) (k : PUnit → σ → m β) : (set s' : StateCpsT σ m PUnit).runK s k = k ⟨⟩ s' := rfl -@[simp] theorem run_modify [Monad m] (f : σ → σ) (s : σ) : (modify f : StateCpsT σ m PUnit).run s = pure (⟨⟩, f s) := rfl +@[simp] theorem runK_modify {m : Type u → Type v} (f : σ → σ) (s : σ) (k : PUnit → σ → m β) : (modify f : StateCpsT σ m PUnit).runK s k = k ⟨⟩ (f s) := rfl -@[simp] theorem run_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) : (StateCpsT.lift x : StateCpsT σ m α).run s = x >>= fun a => pure (a, s) := rfl +@[simp] theorem runK_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) (k : α → σ → m β) : (StateCpsT.lift x : StateCpsT σ m α).runK s k = x >>= (k . s) := rfl -@[simp] theorem run_bind_pure {α σ : Type u} [Monad m] (a : α) (f : α → StateCpsT σ m β) (s : σ) : (pure a >>= f).run s = (f a).run s := rfl +@[simp] theorem runK_monadLift {σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) (k : α → σ → m β) + : (monadLift x : StateCpsT σ m α).runK s k = (monadLift x : m α) >>= (k . s) := rfl -@[simp] theorem run_bind_lift {α σ : Type u} [Monad m] (x : m α) (f : α → StateCpsT σ m β) (s : σ) : (StateCpsT.lift x >>= f).run s = x >>= fun a => (f a).run s := rfl +@[simp] theorem runK_bind_pure {α σ : Type u} [Monad m] (a : α) (f : α → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) : (pure a >>= f).runK s k = (f a).runK s k := rfl -@[simp] theorem run_bind_get {σ : Type u} [Monad m] (f : σ → StateCpsT σ m β) (s : σ) : (get >>= f).run s = (f s).run s := rfl +@[simp] theorem runK_bind_lift {α σ : Type u} [Monad m] (x : m α) (f : α → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) + : (StateCpsT.lift x >>= f).runK s k = x >>= fun a => (f a).runK s k := rfl -@[simp] theorem run_bind_set {σ : Type u} [Monad m] (f : PUnit → StateCpsT σ m β) (s s' : σ) : (set s' >>= f).run s = (f ⟨⟩).run s' := rfl +@[simp] theorem runK_bind_get {σ : Type u} [Monad m] (f : σ → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) : (get >>= f).runK s k = (f s).runK s k := rfl -@[simp] theorem run_bind_modify {σ : Type u} [Monad m] (f : σ → σ) (g : PUnit → StateCpsT σ m β) (s : σ) : (modify f >>= g).run s = (g ⟨⟩).run (f s) := rfl +@[simp] theorem runK_bind_set {σ : Type u} [Monad m] (f : PUnit → StateCpsT σ m β) (s s' : σ) (k : β → σ → m γ) : (set s' >>= f).runK s k = (f ⟨⟩).runK s' k := rfl -@[simp] theorem run_monadLift {σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) : (monadLift x : StateCpsT σ m α).run s = (monadLift x : m α) >>= fun a => pure (a, s) := rfl +@[simp] theorem runK_bind_modify {σ : Type u} [Monad m] (f : σ → σ) (g : PUnit → StateCpsT σ m β) (s : σ) (k : β → σ → m γ) : (modify f >>= g).runK s k = (g ⟨⟩).runK (f s) k := rfl -@[simp] theorem run_pure [Monad m] (a : α) (s : σ) : (pure a : StateCpsT σ m α).run s = pure (a, s) := rfl +@[simp] theorem run_eq [Monad m] (x : StateCpsT σ m α) (s : σ) : x.run s = x.runK s (fun a s => pure (a, s)) := rfl + +@[simp] theorem run'_eq [Monad m] (x : StateCpsT σ m α) (s : σ) : x.run' s = x.runK s (fun a s => pure a) := rfl end StateCpsT