From ace8ef286a735780a3ee0bd727b61d450b60c52a Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Tue, 23 Jan 2018 13:43:14 +0100 Subject: [PATCH] feat(init/category): even more refactorings, simp lemmas --- library/init/category/except.lean | 12 ++++++------ library/init/category/lawful.lean | 21 +++++++++++++++++++++ library/init/category/reader.lean | 7 +++++-- library/init/category/state.lean | 23 ++++++++++++----------- library/init/category/transformers.lean | 13 +++++++++---- 5 files changed, 53 insertions(+), 23 deletions(-) diff --git a/library/init/category/except.lean b/library/init/category/except.lean index c80d2efc35..3a98405c2f 100644 --- a/library/init/category/except.lean +++ b/library/init/category/except.lean @@ -95,6 +95,12 @@ section | except.error e := (handle e).run end⟩ + protected def monad_map {m'} [monad m'] {α} (f : ∀ {α}, m α → m' α) : except_t ε m α → except_t ε m' α := + λ x, ⟨f x.run⟩ + + instance (m') [monad m'] : monad_functor m m' (except_t ε m) (except_t ε m') := + ⟨@monad_map m' _⟩ + instance : monad (except_t ε m) := { pure := @return, bind := @bind } end @@ -103,11 +109,5 @@ end except_t instance (m ε) [monad m] : monad_except ε (except_t ε m) := { throw := λ α, except_t.mk ∘ pure ∘ except.error, catch := @except_t.catch ε _ _ } -def map_except_t {ε m m'} [monad m] [monad m'] {α} (f : ∀ {α}, m α → m' α) : except_t ε m α → except_t ε m' α := -λ x, ⟨f x.run⟩ - -instance (ε m m') [monad m] [monad m'] : monad_functor m m' (except_t ε m) (except_t ε m') := -⟨@map_except_t ε m m' _ _⟩ - instance (ε m out) [monad_run out m] : monad_run (λ α, out (except ε α)) (except_t ε m) := ⟨λ α, run ∘ except_t.run, λ α, except_t.mk ∘ unrun⟩ diff --git a/library/init/category/lawful.lean b/library/init/category/lawful.lean index c0d0ebc488..de222c9acd 100644 --- a/library/init/category/lawful.lean +++ b/library/init/category/lawful.lean @@ -76,6 +76,17 @@ lemma map_ext_congr {α β} {m : Type u → Type v} [has_map m] {x : m α} {f : -- instances of previously defined monads +namespace id +variables {α β : Type} +@[simp] lemma map_eq (x : id α) (f : α → β) : f <$> x = f x := rfl +@[simp] lemma bind_eq (x : id α) (f : α → id β) : x >>= f = f x := rfl +@[simp] lemma pure_eq (a : α) : (pure a : id α) = a := rfl +end id + +instance : is_lawful_monad id := +by refine { id_map := _, bind_assoc := _, pure_bind := _ }; + intros; refl + namespace state_t section variable {σ : Type u} @@ -96,6 +107,11 @@ section change (x >>= pure ∘ f).run st = _, simp end + @[simp] lemma run_monad_lift {n} [has_monad_lift_t n m] (x : n α) : (monad_lift x : state_t σ m α).run st = do a ← (monad_lift x : m α), pure (a, st) := rfl + @[simp] lemma run_monad_map {m' n n'} [monad m'] [monad_functor_t n n' m m'] (f : ∀ {α}, n α → n' α) : (monad_map @f x : state_t σ m' α).run st = monad_map @f (x.run st) := rfl + @[simp] lemma run_zoom {σ'} (st get set) : (state_t.zoom get set x : state_t σ' m α).run st = (λ p : α × σ, (p.1, set p.2 st)) <$> x.run (get st) := rfl + @[simp] lemma run_get : (state_t.get : state_t σ m σ).run st = pure (st, st) := rfl + @[simp] lemma run_put (st') : (state_t.put st' : state_t σ m _).run st = pure (punit.star, st') := rfl end end state_t @@ -121,6 +137,8 @@ namespace except_t rw [bind_ext_congr], intro a; cases a; simp [except_t.bind_cont, except.map] end + @[simp] lemma run_monad_lift {n} [has_monad_lift_t n m] (x : n α) : (@monad_lift _ _ _ _ x : except_t ε m α).run = except.ok <$> (monad_lift x : m α) := rfl + @[simp] lemma run_monad_map {m' n n'} [monad m'] [monad_functor_t n n' m m'] (f : ∀ {α}, n α → n' α) : (monad_map @f x : except_t ε m' α).run = monad_map @f x.run := rfl end except_t instance (m : Type u → Type v) [monad m] [is_lawful_monad m] (ε : Type u) : is_lawful_monad (except_t ε m) := @@ -157,6 +175,9 @@ section @[simp] lemma run_bind (f : α → reader_t r m β) : (x >>= f).run cfg = x.run cfg >>= λ a, (f a).run cfg := rfl @[simp] lemma run_map (f : α → β) [is_lawful_monad m] : (f <$> x).run cfg = f <$> x.run cfg := by rw ←bind_pure_comp_eq_map m; refl + @[simp] lemma run_monad_lift {n} [has_monad_lift_t n m] (x : n α) : (@monad_lift _ _ _ _ x : reader_t r m α).run cfg = (monad_lift x : m α) := rfl + @[simp] lemma run_monad_map {m' n n'} [monad m'] [monad_functor_t n n' m m'] (f : ∀ {α}, n α → n' α) : (monad_map @f x : reader_t r m' α).run cfg = monad_map @f (x.run cfg) := rfl + @[simp] lemma run_read : (reader_t.read : reader_t r m r).run cfg = pure cfg := rfl end end reader_t diff --git a/library/init/category/reader.lean b/library/init/category/reader.lean index 19b8be0835..5d9618937d 100644 --- a/library/init/category/reader.lean +++ b/library/init/category/reader.lean @@ -39,11 +39,11 @@ section instance (m) [monad m] : has_monad_lift m (reader_t r m) := ⟨@reader_t.lift r m _⟩ - protected def map {r m m'} [monad m] [monad m'] {α} (f : Π {α}, m α → m' α) : reader_t r m α → reader_t r m' α := + protected def monad_map {r m m'} [monad m] [monad m'] {α} (f : Π {α}, m α → m' α) : reader_t r m α → reader_t r m' α := λ x, ⟨λ r, f (x.run r)⟩ instance (r m m') [monad m] [monad m'] : monad_functor m m' (reader_t r m) (reader_t r m') := - ⟨@reader_t.map r m m' _ _⟩ + ⟨@reader_t.monad_map r m m' _ _⟩ end end reader_t @@ -68,3 +68,6 @@ def with_reader_t {r r' m} [monad m] {α} (f : r' → r) : reader_t r m α → r def with_reader {r r'} {m n n'} [monad m] [monad_reader_functor r r' m n n'] {α : Type u} (f : r' → r) : n α → n' α := monad_map $ λ α, (with_reader_t f : reader_t r m α → reader_t r' m α) + +instance (r : Type u) (m out) [monad_run out m] : monad_run (λ α, r → out α) (reader_t r m) := +⟨λ α x, run ∘ x.run, λ α a, reader_t.mk (unrun ∘ a)⟩ diff --git a/library/init/category/state.lean b/library/init/category/state.lean index a185cef27b..70a4c1caeb 100644 --- a/library/init/category/state.lean +++ b/library/init/category/state.lean @@ -17,12 +17,13 @@ namespace state_t section variable {σ : Type u} variable {m : Type u → Type v} + + @[inline] protected def run {α : Type u} (st : σ) (x : state_t σ m α) : m (α × σ) := + state_t.run' x st + variable [monad m] variables {α β : Type u} - @[inline] protected def run (st : σ) (x : state_t σ m α) : m (α × σ) := - state_t.run' x st - protected def pure (a : α) : state_t σ m α := ⟨λ s, pure (a, s)⟩ @@ -56,18 +57,18 @@ section protected def lift {α : Type u} (t : m α) : state_t σ m α := ⟨λ s, do a ← t, return (a, s)⟩ - instance (m) [monad m] : has_monad_lift m (state_t σ m) := + instance : has_monad_lift m (state_t σ m) := ⟨@state_t.lift σ m _⟩ - protected def map {σ m m'} [monad m] [monad m'] {α} (f : Π {α}, m α → m' α) : state_t σ m α → state_t σ m' α := + protected def monad_map {σ m m'} [monad m] [monad m'] {α} (f : Π {α}, m α → m' α) : state_t σ m α → state_t σ m' α := λ x, ⟨λ st, f (x.run st)⟩ instance (σ m m') [monad m] [monad m'] : monad_functor m m' (state_t σ m) (state_t σ m') := - ⟨@state_t.map σ m m' _ _⟩ + ⟨@state_t.monad_map σ m m' _ _⟩ -- TODO(Sebastian): uses lenses as in https://hackage.haskell.org/package/lens-4.15.4/docs/Control-Lens-Zoom.html#t:Zoom ? - protected def zoom {σ σ' α : Type u} {m : Type u → Type v} [monad m] (f : σ → σ') (f' : σ' → σ) (x : state_t σ' m α) : state_t σ m α := - ⟨λ st, (λ p : α × σ', (p.fst, f' p.snd)) <$> x.run (f st)⟩ + protected def zoom {σ σ' α : Type u} {m : Type u → Type v} [monad m] (get : σ → σ') (set : σ' → σ → σ) (x : state_t σ' m α) : state_t σ m α := + ⟨λ st, (λ p : α × σ', (p.fst, set p.snd st)) <$> x.run (get st)⟩ instance (ε) [monad_except ε m] : monad_except ε (state_t σ m) := { throw := λ α, state_t.lift ∘ throw, @@ -106,8 +107,8 @@ class monad_state_functor (σ σ' : out_param (Type u)) (m : out_param (Type u attribute [instance] monad_state_functor.mk local attribute [instance] monad_state_functor.functor -def zoom {σ σ'} {m n n'} [monad m] {α : Type u} (f : σ → σ') (f' : σ' → σ) [monad_state_functor σ' σ m n n'] : n α → n' α := -monad_map $ λ α, (state_t.zoom f f' : state_t σ' m α → state_t σ m α) +def zoom {σ σ'} {m n n'} [monad m] {α : Type u} (get : σ → σ') (set : σ' → σ → σ) [monad_state_functor σ' σ m n n'] : n α → n' α := +monad_map $ λ α, (state_t.zoom get set : state_t σ' m α → state_t σ m α) instance (σ m out) [monad_run out m] : monad_run (λ α, σ → out (α × σ)) (state_t σ m) := -⟨λ α x, run ∘ x.run', λ α a, state_t.mk (unrun ∘ a)⟩ +⟨λ α x, run ∘ (λ σ, x.run σ), λ α a, state_t.mk (unrun ∘ a)⟩ diff --git a/library/init/category/transformers.lean b/library/init/category/transformers.lean index 98ff0f062a..63d40a0621 100644 --- a/library/init/category/transformers.lean +++ b/library/init/category/transformers.lean @@ -23,20 +23,22 @@ instance monad_transformer_lift (t m) [monad_transformer t] [monad m] : has_mona ⟨monad_transformer.monad_lift t m⟩ class has_monad_lift_t (m : Type u → Type v) (n : Type u → Type w) := -(monad_lift : ∀ α, m α → n α) +(monad_lift {} : ∀ {α}, m α → n α) -def monad_lift {m n} [has_monad_lift_t m n] {α} : m α → n α := -has_monad_lift_t.monad_lift n α +export has_monad_lift_t (monad_lift) @[reducible] def has_monad_lift_to_has_coe {m n} [has_monad_lift_t m n] {α} : has_coe (m α) (n α) := ⟨monad_lift⟩ instance has_monad_lift_t_trans (m n o) [has_monad_lift n o] [has_monad_lift_t m n] : has_monad_lift_t m o := -⟨λ α (ma : m α), has_monad_lift.monad_lift o α $ has_monad_lift_t.monad_lift n α ma⟩ +⟨λ α (ma : m α), has_monad_lift.monad_lift o α $ @monad_lift m n _ _ ma⟩ instance has_monad_lift_t_refl (m) : has_monad_lift_t m m := ⟨λ α, id⟩ +@[simp] lemma monad_lift_refl {m : Type u → Type v} {α} : (monad_lift : m α → m α) = id := rfl + + /-- A functor in the category of monads. Can be used to lift monad-transforming functions. Based on https://hackage.haskell.org/package/pipes-2.4.0/docs/Control-MFunctor.html, but not restricted to monad transformers. -/ @@ -58,6 +60,9 @@ instance monad_functor_t_trans (m m' n n' o o') [monad_functor n n' o o'] [monad instance monad_functor_t_refl (m m') : monad_functor_t m m' m m' := ⟨λ α f, f⟩ +@[simp] lemma monad_map_refl {m m' : Type u → Type v} (f : ∀ {α}, m α → m' α) {α} : (monad_map @f : m α → m' α) = f := rfl + + /-- Run a monad stack to completion. -/ class monad_run (out : out_param $ Type u → Type v) (m : Type u → Type v) := (run {} {α : Type u} : m α → out α)